diff --git a/.bazelrc b/.bazelrc index 590a87f5732..7a32ca68e40 100644 --- a/.bazelrc +++ b/.bazelrc @@ -30,6 +30,10 @@ build:monolithic --define framework_shared_object=false # opts in to modular op registration support by default. build --define framework_shared_object=true +# Flags for open source build, always set to be true. +build --define open_source_build=true +test --define open_source_build=true + # Please note that MKL on MacOS or windows is still not supported. # If you would like to use a local MKL instead of downloading, please set the # environment variable "TF_MKL_ROOT" every time before build. @@ -108,6 +112,10 @@ build --spawn_strategy=standalone build --strategy=Genrule=standalone build -c opt +# By default, build TF in C++ 14 mode. +build --cxxopt=-std=c++14 +build --host_cxxopt=-std=c++14 + # Make Bazel print out all options from rc files. build --announce_rc diff --git a/CODEOWNERS b/CODEOWNERS index 2828cf3baf8..25ff318d2d8 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,13 +1,14 @@ # Where component owners are known, add them here. -/tensorflow/c/eager @jaingurav @alextp +/tensorflow/c/eager @jaingaurav @alextp /tensorflow/core/common_runtime/eager @jaingaurav @alextp /tenosrflow/core/debug @caisq /tensorflow/core/nccl/ @azaks2 @chsigg /tensorflow/core/platform/windows/ @mrry /tensorflow/core/platform/s3 @yongtang +/tensorflow/python/autograph/ @mdanatg @kkimdev /tensorflow/python/debug @caisq -/tensorflow/python/eager @jaingurav @alextp +/tensorflow/python/eager @jaingaurav @alextp /tensorflow/python/tools/api/generator/ @annarev /tensorflow/tensorboard/ @jart /tensorflow/tools/docs/ @markdaoust @@ -15,6 +16,7 @@ # contrib # NEED OWNER: /tensorflow/contrib/all_reduce +/tensorflow/contrib/autograph/ @mdanatg @kkimdev /tensorflow/contrib/batching/ @alextp @chrisolston /tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon /tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva @@ -26,11 +28,10 @@ /tensorflow/contrib/data/ @mrry /tensorflow/tensorflow/contrib/distribute @joshl @priyag @sourabhbajaj @frankchn /tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi -/tensorflow/contrib/eager @jaingurav @alextp +/tensorflow/contrib/eager @jaingaurav @alextp /tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo /tensorflow/contrib/ffmpeg/ @fredbertsch /tensorflow/contrib/framework/ @ebrevdo -/tensorflow/contrib/gan/ @joel-shor /tensorflow/contrib/graph_editor/ @purpledog # NEED OWNER: /tensorflow/contrib/grid_rnn/ /tensorflow/contrib/hadoop @yongtang diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index a4647020ff7..72304bee694 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -60,7 +60,13 @@ If you are experiencing or witnessing conflict, we ask you to use the following ## Reporting Violations -Violations of the Code of Conduct can be reported to TensorFlow’s Project Stewards, Edd Wilder-James (ewj@google.com) and Sarah Novotny (sarahnovotny@google.com). The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report. +Violations of the Code of Conduct can be reported to TensorFlow’s Project +Stewards, Edd Wilder-James (ewj@google.com) and Thea Lamkin +(thealamkin@google.com). 
The Project Steward will determine whether the Code of +Conduct was violated, and will issue an appropriate sanction, possibly including +a written warning or expulsion from the project, project sponsored spaces, or +project forums. We ask that you make a good-faith effort to resolve your +conflict via the conflict resolution policy before submitting a report. Violations of the Code of Conduct can occur in any setting, even those unrelated to the project. We will only consider complaints about conduct that has occurred within one year of the report. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4ed8a8bf2b2..2b285cd91d7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -29,7 +29,8 @@ Follow either of the two links above to access the appropriate CLA and instructi ### Contributing code If you have improvements to TensorFlow, send us your pull requests! For those -just getting started, Github has a [howto](https://help.github.com/articles/using-pull-requests/). +just getting started, Github has a +[how to](https://help.github.com/articles/using-pull-requests/). TensorFlow team members will be assigned to review your pull requests. Once the pull requests are approved and pass continuous integration checks, a TensorFlow diff --git a/README.md b/README.md index 5a66b9bb03a..1eb06225176 100644 --- a/README.md +++ b/README.md @@ -2,61 +2,58 @@ ------------------ - - | **`Documentation`** | |-----------------| | [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | -**TensorFlow** is an open source software library for numerical computation -using data flow graphs. The graph nodes represent mathematical operations, while -the graph edges represent the multidimensional data arrays (tensors) that flow -between them. This flexible architecture enables you to deploy computation to -one or more CPUs or GPUs in a desktop, server, or mobile device without -rewriting code. TensorFlow also includes -[TensorBoard](https://github.com/tensorflow/tensorboard), a data visualization -toolkit. +[TensorFlow](https://www.tensorflow.org/) is an end-to-end open source platform +for machine learning. It has a comprehensive, flexible ecosystem of +[tools](https://www.tensorflow.org/resources/tools), +[libraries](https://www.tensorflow.org/resources/libraries-extensions), and +[community](https://www.tensorflow.org/community) resources that lets +researchers push the state-of-the-art in ML and developers easily build and +deploy ML powered applications. -TensorFlow was originally developed by researchers and engineers -working on the Google Brain team within Google's Machine Intelligence Research -organization for the purposes of conducting machine learning and deep neural -networks research. The system is general enough to be applicable in a wide -variety of other domains, as well. +TensorFlow was originally developed by researchers and engineers working on the +Google Brain team within Google's Machine Intelligence Research organization for +the purposes of conducting machine learning and deep neural networks research. +The system is general enough to be applicable in a wide variety of other +domains, as well. -TensorFlow provides stable Python and C APIs as well as non-guaranteed backwards -compatible API's for C++, Go, Java, JavaScript, and Swift. 
+TensorFlow provides stable [Python](https://www.tensorflow.org/api_docs/python) +and [C++](https://www.tensorflow.org/api_docs/cc) APIs, as well as +non-guaranteed backwards compatible API for +[other languages](https://www.tensorflow.org/api_docs). -Keep up to date with release announcements and security updates by -subscribing to +Keep up-to-date with release announcements and security updates by subscribing +to [announce@tensorflow.org](https://groups.google.com/a/tensorflow.org/forum/#!forum/announce). +See all the [mailing lists](https://www.tensorflow.org/community/forums). -## Installation +## Install + +See the [TensorFlow install guide](https://www.tensorflow.org/install) for the +[pip package](https://www.tensorflow.org/install/pip), to +[enable GPU support](https://www.tensorflow.org/install/gpu), use a +[Docker container](https://www.tensorflow.org/install/docker), and +[build from source](https://www.tensorflow.org/install/source). To install the current release for CPU-only: ``` -pip install tensorflow +$ pip install tensorflow ``` -Use the GPU package for CUDA-enabled GPU cards: +Use the GPU package for +[CUDA-enabled GPU cards](https://www.tensorflow.org/install/gpu): ``` -pip install tensorflow-gpu +$ pip install tensorflow-gpu ``` -*See [Installing TensorFlow](https://www.tensorflow.org/install) for detailed -instructions, and how to build from source.* - -People who are a little more adventurous can also try our nightly binaries: - -**Nightly pip packages** * We are pleased to announce that TensorFlow now offers -nightly pip packages under the +*Nightly binaries are available for testing using the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and -[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on PyPi. -Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean -environment to install the nightly TensorFlow build. We support CPU and GPU -packages on Linux, Mac, and Windows. +[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) packages on PyPi.* #### *Try your first TensorFlow program* @@ -74,8 +71,8 @@ $ python 'Hello, TensorFlow!' ``` -Learn more examples about how to do specific tasks in TensorFlow at the -[tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). +For more examples, see the +[TensorFlow tutorials](https://www.tensorflow.org/tutorials/). 
## Contribution guidelines @@ -116,6 +113,8 @@ The TensorFlow project strives to abide by generally accepted best practices in Build Type | Status | Artifacts --------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/) +**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | [Release](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) **Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) **Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) **Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) @@ -126,20 +125,23 @@ Build Type **Linux CPU with Intel® MKL-DNN**
**Supports Python 2.7, 3.4, 3.5, and 3.6** | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/lastStableBuild) | [1.13.1 pypi](https://pypi.org/project/intel-tensorflow/) **Red Hat® Enterprise Linux® 7.6 CPU & GPU**
Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 pypi](https://tensorflow.pypi.thoth-station.ninja/index/) -## For more information +## Resources -* [TensorFlow Website](https://www.tensorflow.org) -* [TensorFlow Tutorials](https://www.tensorflow.org/tutorials/) -* [TensorFlow Model Zoo](https://github.com/tensorflow/models) +* [TensorFlow.org](https://www.tensorflow.org) +* [TensorFlow tutorials](https://www.tensorflow.org/tutorials/) +* [TensorFlow official models](https://github.com/tensorflow/models/tree/master/official) +* [TensorFlow examples](https://github.com/tensorflow/examples) +* [TensorFlow in Practice from Coursera](https://www.coursera.org/specializations/tensorflow-in-practice) +* [TensorFlow blog](https://blog.tensorflow.org) * [TensorFlow Twitter](https://twitter.com/tensorflow) -* [TensorFlow Blog](https://blog.tensorflow.org) -* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) -* [TensorFlow Roadmap](https://www.tensorflow.org/community/roadmap) -* [TensorFlow White Papers](https://www.tensorflow.org/about/bib) -* [TensorFlow YouTube Channel](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) -* [TensorFlow Visualization Toolkit](https://github.com/tensorflow/tensorboard) +* [TensorFlow YouTube](https://www.youtube.com/channel/UC0rqucBdTuFTjJiefW5t-IQ) +* [TensorFlow roadmap](https://www.tensorflow.org/community/roadmap) +* [TensorFlow white papers](https://www.tensorflow.org/about/bib) +* [TensorBoard visualization toolkit](https://github.com/tensorflow/tensorboard) -Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate. +Learn more about the +[TensorFlow community](https://www.tensorflow.org/community) and how to +[contribute](https://www.tensorflow.org/community/contribute). ## License diff --git a/RELEASE.md b/RELEASE.md index 6a4c2d6486d..801b9c8a2c8 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -43,6 +43,11 @@ * Transitive dependencies on :pooling_ops were removed. Some users may need to add explicit dependencies on :pooling_ops if they reference the operators from that library. +* tf.keras.optimizers default learning rate changes: + * Adadelta: 1.000 to 0.001 + * Adagrad: 0.01 to 0.001 + * Adamax: 0.002 to 0.001 + * NAdam: 0.002 to 0.001 ## Bug Fixes and Other Changes @@ -746,7 +751,7 @@ Ag Ramesh, Alex Wiltschko, Alexander Pantyukhin, Amogh Mannekote, An Jiaoyang, A and [programmers guide page](http://tensorflow.org/versions/r1.9/programmers_guide/keras). * Update `tf.keras` to the Keras 2.1.6 API. * Added [`tf.keras.layers.CuDNNGRU`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNGRU) and [`tf.keras.layers.CuDNNLSTM`](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/keras/layers/CuDNNLSTM) layers. [Try it](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb?linkId=53292082). -* Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/boosted_trees). 
+* Adding support of core [feature columns](https://www.tensorflow.org/get_started/feature_columns) and [losses](https://www.tensorflow.org/api_docs/python/tf/losses) to [gradient boosted trees estimators](https://github.com/tensorflow/models/tree/master/official/r1/boosted_trees). * The [python interface](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/lite) for the [TFLite Optimizing Converter](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/toco/README.md) has been expanded, and the command line interface (AKA: `toco`, `tflite_convert`) is once again diff --git a/WORKSPACE b/WORKSPACE index 43312f350d6..74ea14d0fd7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -7,7 +7,7 @@ http_archive( sha256 = "5b00383d08dd71f28503736db0500b6fb4dda47489ff5fc6bed42557c07c6ba9", strip_prefix = "rules_closure-308b05b2419edb5c8ee0471b67a40403df940149", urls = [ - "http://mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", "https://github.com/bazelbuild/rules_closure/archive/308b05b2419edb5c8ee0471b67a40403df940149.tar.gz", # 2019-06-13 ], ) @@ -49,9 +49,14 @@ remote_config_workspace() # Apple and Swift rules. http_archive( name = "build_bazel_rules_apple", - sha256 = "23792cd999f97fc97284d1c44cb1324bfdd0bc54aa68ad513fa3705aca3b1f9e", - urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.15.0/rules_apple.0.15.0.tar.gz"], + sha256 = "6efdde60c91724a2be7f89b0c0a64f01138a45e63ba5add2dca2645d981d23a1", + urls = ["https://github.com/bazelbuild/rules_apple/releases/download/0.17.2/rules_apple.0.17.2.tar.gz"], ) # https://github.com/bazelbuild/rules_apple/releases +http_archive( + name = "build_bazel_rules_swift", + sha256 = "96a86afcbdab215f8363e65a10cf023b752e90b23abf02272c4fc668fcb70311", + urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.11.1/rules_swift.0.11.1.tar.gz"], +) # https://github.com/bazelbuild/rules_swift/releases http_archive( name = "build_bazel_apple_support", sha256 = "7356dbd44dea71570a929d1d4731e870622151a5f27164d966dda97305f33471", @@ -62,11 +67,6 @@ http_archive( sha256 = "2ef429f5d7ce7111263289644d233707dba35e39696377ebab8b0bc701f7818e", urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/0.8.0/bazel-skylib.0.8.0.tar.gz"], ) # https://github.com/bazelbuild/bazel-skylib/releases -http_archive( - name = "build_bazel_rules_swift", - sha256 = "9efe9699e9765e6b4a5e063e4a08f6b163cccaf0443f775d935baf5c3cd6ed0e", - urls = ["https://github.com/bazelbuild/rules_swift/releases/download/0.9.0/rules_swift.0.9.0.tar.gz"], -) # https://github.com/bazelbuild/rules_swift/releases http_archive( name = "com_github_apple_swift_swift_protobuf", type = "zip", @@ -104,8 +104,7 @@ http_archive( build_file = "//:models.BUILD", sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105", urls = [ - "http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip", - "http://download.tensorflow.org/models/inception_v1.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip", ], ) @@ -114,8 +113,7 @@ http_archive( build_file = "//:models.BUILD", sha256 = "bddd81ea5c80a97adfac1c9f770e6f55cbafd7cce4d3bbe15fbeb041e6b8f3e8", urls = [ - "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip", - 
"http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_android_export.zip", ], ) @@ -124,8 +122,7 @@ http_archive( build_file = "//:models.BUILD", sha256 = "859edcddf84dddb974c36c36cfc1f74555148e9c9213dedacf1d6b613ad52b96", urls = [ - "http://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip", - "http://download.tensorflow.org/models/mobile_multibox_v1a.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/mobile_multibox_v1a.zip", ], ) @@ -134,8 +131,7 @@ http_archive( build_file = "//:models.BUILD", sha256 = "3d374a730aef330424a356a8d4f04d8a54277c425e274ecb7d9c83aa912c6bfa", urls = [ - "http://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip", - "http://download.tensorflow.org/models/stylize_v1.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip", ], ) @@ -144,7 +140,6 @@ http_archive( build_file = "//:models.BUILD", sha256 = "c3ec4fea3158eb111f1d932336351edfe8bd515bb6e87aad4f25dbad0a600d0c", urls = [ - "http://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip", - "http://download.tensorflow.org/models/speech_commands_v0.01.zip", + "https://storage.googleapis.com/download.tensorflow.org/models/speech_commands_v0.01.zip", ], ) diff --git a/configure.py b/configure.py index 64022101e97..a01d952bb1e 100644 --- a/configure.py +++ b/configure.py @@ -1145,78 +1145,6 @@ def set_trisycl_include_dir(environ_cp): write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) -def set_mpi_home(environ_cp): - """Set MPI_HOME.""" - - default_mpi_home = which('mpirun') or which('mpiexec') or '' - default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home)) - - def valid_mpi_path(mpi_home): - exists = ( - os.path.exists(os.path.join(mpi_home, 'include')) and - (os.path.exists(os.path.join(mpi_home, 'lib')) or - os.path.exists(os.path.join(mpi_home, 'lib64')) or - os.path.exists(os.path.join(mpi_home, 'lib32')))) - if not exists: - print( - 'Invalid path to the MPI Toolkit. 
%s or %s or %s or %s cannot be found' - % (os.path.join(mpi_home, 'include'), - os.path.exists(os.path.join(mpi_home, 'lib')), - os.path.exists(os.path.join(mpi_home, 'lib64')), - os.path.exists(os.path.join(mpi_home, 'lib32')))) - return exists - - _ = prompt_loop_or_load_from_env( - environ_cp, - var_name='MPI_HOME', - var_default=default_mpi_home, - ask_for_var='Please specify the MPI toolkit folder.', - check_success=valid_mpi_path, - error_msg='', - suppress_default_error=True) - - -def set_other_mpi_vars(environ_cp): - """Set other MPI related variables.""" - # Link the MPI header files - mpi_home = environ_cp.get('MPI_HOME') - symlink_force('%s/include/mpi.h' % mpi_home, 'third_party/mpi/mpi.h') - - # Determine if we use OpenMPI or MVAPICH, these require different header files - # to be included here to make bazel dependency checker happy - if os.path.exists(os.path.join(mpi_home, 'include/mpi_portable_platform.h')): - symlink_force( - os.path.join(mpi_home, 'include/mpi_portable_platform.h'), - 'third_party/mpi/mpi_portable_platform.h') - # TODO(gunan): avoid editing files in configure - sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI = False', - 'MPI_LIB_IS_OPENMPI = True') - else: - # MVAPICH / MPICH - symlink_force( - os.path.join(mpi_home, 'include/mpio.h'), 'third_party/mpi/mpio.h') - symlink_force( - os.path.join(mpi_home, 'include/mpicxx.h'), 'third_party/mpi/mpicxx.h') - # TODO(gunan): avoid editing files in configure - sed_in_place('third_party/mpi/mpi.bzl', 'MPI_LIB_IS_OPENMPI = True', - 'MPI_LIB_IS_OPENMPI = False') - - if os.path.exists(os.path.join(mpi_home, 'lib/libmpi.so')): - symlink_force( - os.path.join(mpi_home, 'lib/libmpi.so'), 'third_party/mpi/libmpi.so') - elif os.path.exists(os.path.join(mpi_home, 'lib64/libmpi.so')): - symlink_force( - os.path.join(mpi_home, 'lib64/libmpi.so'), 'third_party/mpi/libmpi.so') - elif os.path.exists(os.path.join(mpi_home, 'lib32/libmpi.so')): - symlink_force( - os.path.join(mpi_home, 'lib32/libmpi.so'), 'third_party/mpi/libmpi.so') - - else: - raise ValueError( - 'Cannot find the MPI library file in %s/lib or %s/lib64 or %s/lib32' % - (mpi_home, mpi_home, mpi_home)) - - def system_specific_test_config(env): """Add default build and test flags required for TF tests to bazelrc.""" write_to_bazelrc('test --flaky_test_attempts=3') @@ -1549,11 +1477,6 @@ def main(): raise UserInputError('SYCL / CUDA / ROCm are mututally exclusive. 
' 'At most 1 GPU platform can be configured.') - set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) - if environ_cp.get('TF_NEED_MPI') == '1': - set_mpi_home(environ_cp) - set_other_mpi_vars(environ_cp) - set_cc_opt_flags(environ_cp) set_system_libs_flag(environ_cp) if is_windows(): diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 61539c5e586..4d34f9849b7 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -7,7 +7,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library_additional_deps_impl") load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_additional_binary_deps", ) load( @@ -356,6 +356,15 @@ config_setting( }, ) +# Flag to indicate open source build, .bazelrc always has it set to be true +config_setting( + name = "oss", + define_values = { + "open_source_build": "true", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "using_cuda_clang_with_dynamic_build", define_values = { @@ -364,11 +373,20 @@ config_setting( }, ) +config_setting( + name = "build_oss_using_cuda_clang", + define_values = { + "using_cuda_clang": "true", + "open_source_build": "true", + }, +) + # Setting to use when loading kernels dynamically config_setting( name = "dynamic_loaded_kernels", define_values = { "dynamic_loaded_kernels": "true", + "framework_shared_object": "true", }, visibility = ["//visibility:public"], ) @@ -389,16 +407,18 @@ config_setting( ) config_setting( - name = "using_rocm_hipcc", + name = "build_oss_using_cuda_nvcc", define_values = { - "using_rocm_hipcc": "true", + "using_cuda_nvcc": "true", + "open_source_build": "true", }, ) config_setting( - name = "with_mpi_support", - values = {"define": "with_mpi_support=true"}, - visibility = ["//visibility:public"], + name = "using_rocm_hipcc", + define_values = { + "using_rocm_hipcc": "true", + }, ) config_setting( @@ -444,6 +464,7 @@ config_setting( package_group( name = "internal", packages = [ + "//perftools/accelerators/xprof/api/...", "//tensorflow/...", "//tensorflow_estimator/python/estimator/...", "//tensorflow_models/official/...", @@ -607,6 +628,7 @@ tf_cc_shared_object( "//tensorflow/c:version_script.lds", "//tensorflow/c/eager:c_api", "//tensorflow/core:tensorflow", + "//tensorflow/core/distributed_runtime/rpc:grpc_session", ], ) @@ -750,8 +772,8 @@ genrule( mkdir $@ for f in $(SRCS); do d="$${f%/*}" - d="$${d#bazel-out*genfiles/}" - d="$${d#*external/eigen_archive/}" + d="$${d#bazel-out/*/genfiles/}" + d="$${d#bazel-out/*/bin/}" if [[ $${d} == *local_config_* ]]; then continue @@ -763,6 +785,9 @@ genrule( if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then continue fi + + d="$${d#*external/farmhash_archive/src}" + d="$${d#*external/$${extname}/}" fi mkdir -p "$@/$${d}" diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 6d1c40a2428..2962a7a60e2 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -27,11 +27,27 @@ import sys as _sys # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python.tools import module_util as _module_util +from tensorflow.python.platform import tf_logging as _logging # API IMPORTS PLACEHOLDER # WRAPPER_PLACEHOLDER +if "dev" in __version__: # pylint: disable=undefined-variable + 
_logging.warning(""" + + TensorFlow's `tf-nightly` package will soon be updated to TensorFlow 2.0. + + Please upgrade your code to TensorFlow 2.0: + * https://www.tensorflow.org/beta/guide/migration_guide + + Or install the latest stable TensorFlow 1.X release: + * `pip install -U "tensorflow==1.*"` + + Otherwise your code may be broken by the change. + + """) + # Make sure directory containing top level submodules is in # the __path__ so that "from tensorflow.foo import bar" works. # We're using bitwise, but there's nothing special about that. diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index dd5a3a08765..ffc457de4aa 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -73,7 +73,7 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:lib_platform", + "//tensorflow/core/platform:platform", "//tensorflow/core:op_gen_lib", "//tensorflow/core/distributed_runtime:server_lib", ], @@ -264,10 +264,10 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:lib_platform", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/platform", "@com_google_absl//absl/strings", ], ) @@ -355,6 +355,7 @@ tf_cuda_library( deps = [ ":tf_status", ":tf_status_helper", + ":tf_tensor_internal", ] + select({ "//tensorflow:android": [ ":c_api_internal", @@ -467,7 +468,6 @@ tf_cuda_cc_test( "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:no_op_op_lib", - "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", "//tensorflow/core:sendrecv_ops_op_lib", "//tensorflow/core:spectral_ops_op_lib", @@ -503,6 +503,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/types:optional", ], ) @@ -579,7 +580,7 @@ tf_cuda_cc_test( "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], }), - tags = ["noasan"], + tags = ["no_cuda_on_cpu_tap"], # We must ensure that the dependencies can be dynamically linked since # the shared library must be able to use core:framework. # linkstatic = tf_kernel_tests_linkstatic(), @@ -588,10 +589,11 @@ tf_cuda_cc_test( ":kernels", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/kernels:ops_testutil", + "//third_party/eigen3", ], ) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 62b2504a26d..ed4f10e0f77 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -1024,7 +1024,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, desc->colocation_constraints.insert(location); } } else { - desc->node_builder.Attr(attr_name, attr_value); + desc->node_builder.Attr(attr_name, std::move(attr_value)); } status->status = Status::OK(); @@ -1045,7 +1045,8 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, std::vector(desc->colocation_constraints.begin(), desc->colocation_constraints.end())); } - status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret); + status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret, + /*consume=*/true); if (TF_GetCode(status) == TF_OK) { // Run shape inference function for newly added node. 
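The c_api.cc hunk above switches TF_SetAttrValueProto to `desc->node_builder.Attr(attr_name, std::move(attr_value))` and finalizes the builder with `/*consume=*/true`, so the attribute proto and the builder state are handed off rather than deep-copied. The toy C++ sketch below only illustrates the lvalue/rvalue overload pattern that change relies on; `ToyBuilder` and `HeavyAttr` are made-up stand-ins for illustration, not TensorFlow's actual NodeBuilder or AttrValue types.

```
// Illustrative only: why passing std::move(attr_value) lets the callee take
// ownership of a heavyweight attribute instead of copying it.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct HeavyAttr {  // stands in for a large proto such as AttrValue
  std::vector<std::string> values;
};

class ToyBuilder {
 public:
  // Copying overload: the caller keeps its value.
  ToyBuilder& Attr(const std::string& name, const HeavyAttr& value) {
    attrs_.emplace_back(name, value);  // deep copy
    return *this;
  }
  // Moving overload: the caller is done with the value, as in the hunk above.
  ToyBuilder& Attr(const std::string& name, HeavyAttr&& value) {
    attrs_.emplace_back(name, std::move(value));  // steals the storage
    return *this;
  }
  size_t num_attrs() const { return attrs_.size(); }

 private:
  std::vector<std::pair<std::string, HeavyAttr>> attrs_;
};

int main() {
  HeavyAttr a{{"lots", "of", "strings"}};
  ToyBuilder b;
  b.Attr("copied", a);            // a is still usable afterwards
  b.Attr("moved", std::move(a));  // a's contents are transferred
  std::cout << b.num_attrs() << " attrs; moved-from source now holds "
            << a.values.size() << " strings\n";  // typically 0 after the move
  return 0;
}
```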
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index ad0c4068d45..f04f0175696 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -24,6 +24,8 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" @@ -596,7 +598,10 @@ struct TF_CheckpointReader : public tensorflow::checkpoint::CheckpointReader { TF_CheckpointReader* TF_NewCheckpointReader(const char* filename, TF_Status* status) { TF_CheckpointReader* reader = new TF_CheckpointReader(filename, status); - if (!status->status.ok()) return nullptr; + if (!status->status.ok()) { + TF_DeleteCheckpointReader(reader); + return nullptr; + } const auto& m = reader->GetVariableToDataTypeMap(); for (auto it = m.begin(); it != m.end(); ++it) reader->variable_list.push_back(it->first); @@ -995,3 +1000,170 @@ TFE_TensorHandle* TFE_ConsumeInputConcreteTensorFromTraceContext( << handle->DebugString(); return ret; } + +TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) { + TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList; + result->num_items = num_items; + result->items = (num_items == 0) ? nullptr : new TF_ShapeAndType[num_items](); + return result; +} + +void TF_ShapeAndTypeListSetShape(TF_ShapeAndTypeList* shape_list, int index, + const int64_t* dims, int num_dims) { + DCHECK(index >= 0 && index < shape_list->num_items); + TF_ShapeAndType& shape = shape_list->items[index]; + DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!"; + DCHECK(num_dims >= 0) << "Number of dimensions cannot be negative!"; + shape.num_dims = num_dims; + shape.dims = new int64_t[num_dims]; + memcpy(shape.dims, dims, sizeof(int64_t) * num_dims); +} + +void TF_ShapeAndTypeListSetUnknownShape(TF_ShapeAndTypeList* shape_list, + int index) { + DCHECK(index >= 0 && index < shape_list->num_items); + TF_ShapeAndType& shape = shape_list->items[index]; + DCHECK(shape.dims == nullptr) << "Shape at " << index << " is already set!"; + shape.num_dims = -1; + shape.dims = nullptr; +} + +void TF_ShapeAndTypeListSetDtype(TF_ShapeAndTypeList* shape_list, int index, + TF_DataType dtype) { + DCHECK(index >= 0 && index < shape_list->num_items); + TF_ShapeAndType& shape_and_type = shape_list->items[index]; + shape_and_type.dtype = dtype; +} + +void TF_DeleteShapeAndTypeList(TF_ShapeAndTypeList* shape_list) { + if (shape_list == nullptr) return; + for (size_t i = 0; i < shape_list->num_items; ++i) { + delete[] shape_list->items[i].dims; + } + delete[] shape_list->items; + delete shape_list; +} + +void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array, + int num_items) { + if (shape_list_array == nullptr) return; + for (int i = 0; i < num_items; ++i) { + TF_DeleteShapeAndTypeList(shape_list_array[i]); + } + delete[] shape_list_array; +} + +namespace tensorflow { +Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); +} // namespace tensorflow + +void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, + TF_Tensor** input_tensors, + TF_ShapeAndTypeList* input_tensors_as_shapes, + TF_ShapeAndTypeList** input_resource_shapes_and_types, + TF_ShapeAndTypeList** 
output_shapes, + TF_ShapeAndTypeList*** output_resource_shapes_and_types, + TF_Status* status) { + using tensorflow::NodeDef; + using tensorflow::OpRegistrationData; + using tensorflow::Tensor; + using tensorflow::shape_inference::DimensionHandle; + using tensorflow::shape_inference::InferenceContext; + using tensorflow::shape_inference::ShapeAndType; + using tensorflow::shape_inference::ShapeHandle; + + const int num_inputs = input_shapes->num_items; + NodeDef node_def; + node_def.set_name(tfe_op->operation.Name()); + node_def.set_op(tfe_op->operation.Name()); + for (int i = 0; i < num_inputs; ++i) { + node_def.add_input("dummy_input"); + } + tfe_op->operation.Attrs().FillAttrValueMap(node_def.mutable_attr()); + + const tensorflow::OpRegistrationData* op_reg_data; + status->status = + tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data); + if (!status->status.ok()) return; + + // Initialize a input_tensor vector with `nullptr` values. + std::vector input_tensors_vector(num_inputs, nullptr); + // A vector to keep track of newly created `tf::Tensor` objects. + std::vector all_input_tensors; + // Update the vector with information from `input_tensors` if provided. + if (input_tensors != nullptr) { + // Note that we take the address of the elements in `all_input_tensors` + // below. Allocate enough space so that no reallocation happens, which will + // make the pointers invalid. + all_input_tensors.reserve(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + if (input_tensors[i] == nullptr) continue; + all_input_tensors.emplace_back(); + Tensor& input_tensor = all_input_tensors.back(); + status->status = TF_TensorToTensor(input_tensors[i], &input_tensor); + if (!status->status.ok()) return; + input_tensors_vector[i] = &input_tensor; + } + } + + // Create an inference context with dummy values, which will be updated later. + InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_reg_data->op_def, + std::vector(num_inputs), input_tensors_vector, + {}, + std::vector>>()); + + // Set input_shapes. + for (int i = 0; i < num_inputs; ++i) { + std::vector dims; + const TF_ShapeAndType& input_shape = input_shapes->items[i]; + if (input_shape.num_dims == InferenceContext::kUnknownRank) { + c.SetInput(i, c.UnknownShape()); + continue; + } + for (int j = 0; j < input_shape.num_dims; ++j) { + dims.push_back(c.MakeDim(input_shape.dims[j])); + } + c.SetInput(i, c.MakeShape(dims)); + } + + // TODO(bgogul): Handle input_tensors_as_shapes. + // TODO(bgogul): Handle input_resource_shapes_and_types. + + status->status = c.construction_status(); + if (!status->status.ok()) return; + + if (op_reg_data->shape_inference_fn == nullptr) { + status->status = + InvalidArgument("No shape inference function exists for op '", + node_def.op(), "', did you forget to define it?"); + return; + } + + status->status = c.Run(op_reg_data->shape_inference_fn); + if (!status->status.ok()) return; + + // Set output_shapes. 
+ TF_ShapeAndTypeList* output_shapes_result = + TF_NewShapeAndTypeList(c.num_outputs()); + for (int i = 0; i < c.num_outputs(); ++i) { + ShapeHandle shape_handle = c.output(i); + TF_ShapeAndType& shape = output_shapes_result->items[i]; + shape.num_dims = c.Rank(shape_handle); + if (shape.num_dims == InferenceContext::kUnknownRank) { + shape.dims = nullptr; + continue; + } + shape.dims = new int64_t[shape.num_dims]; + for (size_t j = 0; j < shape.num_dims; ++j) { + shape.dims[j] = c.Value(c.Dim(shape_handle, j)); + } + } + if (output_shapes != nullptr) *output_shapes = output_shapes_result; + + // TODO(bgogul): Set output_resource_shapes_and_types. +} + +void TF_ImportGraphDefOptionsSetValidateColocationConstraints( + TF_ImportGraphDefOptions* opts, unsigned char enable) { + opts->opts.validate_colocation_constraints = enable; +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index d91f3ab8b05..126db2640f6 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -343,6 +343,65 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_ConsumeInputConcreteTensorFromTraceContext(TFE_TraceContext* trace_ctx, unsigned int idx); +// Information about the shape of a Tensor and its type. +struct TF_ShapeAndType { + // Number of dimensions. -1 indicates unknown rank. + int num_dims; + // Array of dimensions. -1 indicates unknown dim. + int64_t* dims; + // The data type. May be 0 to denote unknown type. + TF_DataType dtype; +}; + +typedef struct TF_ShapeAndType TF_ShapeAndType; + +// A list of TF_ShapeAndType elements.. +struct TF_ShapeAndTypeList { + int num_items; + TF_ShapeAndType* items; +}; +typedef struct TF_ShapeAndTypeList TF_ShapeAndTypeList; + +// API for manipulating TF_ShapeAndTypeList objects. +// +TF_CAPI_EXPORT extern TF_ShapeAndTypeList* TF_NewShapeAndTypeList( + int num_shapes); +TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetShape( + TF_ShapeAndTypeList* shape_list, int index, const int64_t* dims, + int num_dims); +TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetUnknownShape( + TF_ShapeAndTypeList* shape_list, int index); +TF_CAPI_EXPORT extern void TF_ShapeAndTypeListSetDtype( + TF_ShapeAndTypeList* shape_list, int index, TF_DataType dtype); +TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeList( + TF_ShapeAndTypeList* shape_list); +TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeListArray( + TF_ShapeAndTypeList** shape_list_array, int num_items); + +// Infer shapes for the given `op`. The arguments mimic the arguments of the +// `shape_inference::InferenceContext` constructor. Note the following: +// - The inputs of the `op` are not used for shape inference. So, it is +// OK to not have the inputs properly set in `op`. See `input_tensors` +// if you want shape inference to consider the input tensors of the +// op for shape inference. +// - The types need not be set in `input_shapes` as it is not used. +// - The number of `input_tensors` should be the same as the number of items +// in `input_shapes`. +// +// The results are returned in `output_shapes` and +// `output_resource_shapes_and_types`. The caller is responsible for freeing the +// memory in these buffers by calling `TF_DeleteShapeAndTypeList`. 
+TF_CAPI_EXPORT extern void TFE_InferShapes( + TFE_Op* op, TF_ShapeAndTypeList* input_shapes, TF_Tensor** input_tensors, + TF_ShapeAndTypeList* input_tensor_as_shapes, + TF_ShapeAndTypeList** input_resource_shapes_and_types, + TF_ShapeAndTypeList** output_shapes, + TF_ShapeAndTypeList*** output_resource_shapes_and_types, TF_Status* status); + +TF_CAPI_EXPORT extern void +TF_ImportGraphDefOptionsSetValidateColocationConstraints( + TF_ImportGraphDefOptions* opts, unsigned char enable); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 55f3a8599fd..ed0ab7c26f8 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/c_api_experimental.h" + +#include "absl/types/optional.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/c/eager/c_api.h" @@ -431,5 +433,155 @@ TEST_F(AddEagerOpToGraphTest, TFE_DeleteTensorHandle(matrix); } +class ShapeInferenceTest : public ::testing::Test { + protected: + ShapeInferenceTest() + : status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) { + tfe_context_ = TFE_NewContext(tfe_context_options_, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + } + + ~ShapeInferenceTest() override { + TFE_DeleteContextOptions(tfe_context_options_); + TFE_DeleteContext(tfe_context_); + TF_DeleteStatus(status_); + } + + // Checks the expected result of shape inference for the given `op`. + void CheckOutputShapes( + TFE_Op* op, + const std::vector>>& input_shapes_vec, + const std::vector& input_tensors, + const absl::optional>& expected_shape) { + // Create input_shapes. + TF_ShapeAndTypeList* input_shapes = + TF_NewShapeAndTypeList(input_shapes_vec.size()); + for (size_t i = 0; i < input_shapes_vec.size(); ++i) { + const auto& input_shape = input_shapes_vec[i]; + if (input_shape.has_value()) { + TF_ShapeAndTypeListSetShape(input_shapes, i, input_shape->data(), + input_shape->size()); + } else { + TF_ShapeAndTypeListSetUnknownShape(input_shapes, i); + } + } + TF_ShapeAndTypeList* output_shapes; + TFE_InferShapes(op, input_shapes, + input_tensors.empty() + ? 
nullptr + : const_cast(input_tensors.data()), + /*input_tensors_as_shapes*/ nullptr, + /*input_resource_shapes_and_types*/ nullptr, &output_shapes, + /*output_resource_shapes_and_types*/ nullptr, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CHECK_EQ(output_shapes->num_items, 1); + + int num_dims = output_shapes->items[0].num_dims; + int64_t* dims = output_shapes->items[0].dims; + + if (!expected_shape.has_value()) { + EXPECT_EQ(num_dims, -1); + EXPECT_EQ(dims, nullptr); + return; + } + + EXPECT_EQ(num_dims, expected_shape->size()); + for (size_t i = 0; i < num_dims; ++i) { + EXPECT_EQ(dims[i], (*expected_shape)[i]); + } + TF_DeleteShapeAndTypeList(input_shapes); + TF_DeleteShapeAndTypeList(output_shapes); + } + + absl::optional> make_shape( + std::vector&& dims) const { + return absl::make_optional(dims); + } + + absl::optional> unknown_shape() const { + return absl::nullopt; + } + + static constexpr int64_t kUnknownDim = + shape_inference::InferenceContext::kUnknownDim; + TF_Status* status_; + TFE_ContextOptions* tfe_context_options_; + TFE_Context* tfe_context_; +}; + +TEST_F(ShapeInferenceTest, InfersShapesFromInputShapes) { + TFE_Op* matmul_op; + matmul_op = TFE_NewOp(tfe_context_, "MatMul", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + + // Infer shape when everything is known. + CheckOutputShapes(matmul_op, + /*input_shapes*/ {make_shape({3, 2}), make_shape({2, 4})}, + /*input_tensors*/ {}, + /*expected_shape*/ make_shape({3, 4})); + + // Infer shape when second operand has unknown shape. + CheckOutputShapes(matmul_op, + /*input_shapes*/ {make_shape({3, 2}), unknown_shape()}, + /*input_tensors*/ {}, + /*expected_shape*/ make_shape({3, kUnknownDim})); + + // Infer shape when some dimensions are unknown. + CheckOutputShapes( + matmul_op, + /*input_shapes*/ {make_shape({kUnknownDim, 2}), make_shape({2, 4})}, + /*input_tensors*/ {}, + /*expected_shape*/ make_shape({kUnknownDim, 4})); + + // Infer shape when everything is unknown. + CheckOutputShapes(matmul_op, + /*input_shapes*/ {unknown_shape(), unknown_shape()}, + /*input_tensors*/ {}, + /*expected_shape*/ make_shape({kUnknownDim, kUnknownDim})); + + TFE_DeleteOp(matmul_op); + // TODO(bgogul): Add some death tests where status is not OK. +} + +TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) { + // Prepare some tensors for shape. 
+ TF_Tensor* tensor_1X6 = Int32Tensor({1, 6}); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TF_Tensor* tensor_1X1X6 = Int32Tensor({1, 1, 6}); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + + TFE_Op* reshape_op = TFE_NewOp(tfe_context_, "Reshape", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpSetAttrType(reshape_op, "T", TF_FLOAT); + TFE_OpSetAttrType(reshape_op, "Tshape", TF_INT32); + CheckOutputShapes(reshape_op, + /* input_shapes*/ {unknown_shape(), unknown_shape()}, + /* input_tensors*/ {nullptr, tensor_1X6}, + /*expected_shape*/ make_shape({1, 6})); + TFE_DeleteOp(reshape_op); + reshape_op = nullptr; + + TFE_Op* fill_op = TFE_NewOp(tfe_context_, "Fill", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpSetAttrType(fill_op, "T", TF_FLOAT); + TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32); + + float five = 5.0; + TFE_TensorHandle* scalar = TestScalarTensorHandle(five); + TF_Tensor* scalarTensor = TFE_TensorHandleResolve(scalar, status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + CheckOutputShapes(fill_op, + /* input_shapes*/ {unknown_shape(), unknown_shape()}, + /* input_tensors*/ {tensor_1X1X6, scalarTensor}, + /*expected_shape*/ make_shape({1, 1, 6})); + TFE_DeleteOp(fill_op); + fill_op = nullptr; + + TFE_DeleteTensorHandle(scalar); + TF_DeleteTensor(scalarTensor); + TF_DeleteTensor(tensor_1X1X6); + TF_DeleteTensor(tensor_1X6); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 20815813d06..bb2be3db087 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -41,6 +41,7 @@ namespace { // node names, so if necessary we add a suffix to make // names unique. If we have an input named "A" and a node in the function // body named "a", they will be renamed to "a" and "a_0". +// TODO(b/139886381) Unify this and the one in graph_to_functiondef.cc class NodeNameMapping { public: NodeNameMapping() = default; @@ -64,14 +65,14 @@ class NodeNameMapping { string Lookup(const string& name) const; private: - string UniquifyHelper(const string& name) const; + string UniquifyHelper(const string& name); static string Normalize(string name); // The normalized/uniquified names already used as // input names (in signature), output names (in signature), and node names // (in node_def). // This is a superset of values in name_mapping_. - std::unordered_set used_names_; + std::unordered_map used_names_; // Mapping from original node name from the graph to the normalized // and uniquified version of it. std::unordered_map name_mapping_; @@ -102,13 +103,16 @@ string NodeNameMapping::Normalize(string name) { return i == n ? "unknown" : name.substr(i); } -string NodeNameMapping::UniquifyHelper(const string& name) const { +string NodeNameMapping::UniquifyHelper(const string& name) { + auto it = used_names_.emplace(name, 0); // If the name hasn't been used yet, use it as-is. - if (used_names_.find(name) == used_names_.end()) return name; + if (it.second) return name; + // Add a suffix to name to make it unique. 
- for (int i = 0;; ++i) { - const string candidate = strings::StrCat(name, "_", i); - if (used_names_.find(candidate) == used_names_.end()) return candidate; + while (true) { + const string candidate = strings::StrCat(name, "_", it.first->second); + it.first->second++; + if (used_names_.emplace(candidate, 0).second) return candidate; } } @@ -120,16 +124,13 @@ string NodeNameMapping::GetInputName(const string& name) { string NodeNameMapping::GetOutputName(const string& name) { const string& input_name = UniquifyHelper(Normalize(name)); - // Record that we used this name, but don't add it to name_mapping_ - // since this name is not for a node. - used_names_.insert(input_name); + // Don't add it to name_mapping_ since this name is not for a node. return input_name; } string NodeNameMapping::Uniquify(const string& name) { const string uniqued = UniquifyHelper(name); name_mapping_[name] = uniqued; - used_names_.insert(uniqued); return uniqued; } @@ -139,7 +140,7 @@ Status NodeNameMapping::UseOutputName(const string& name) { return InvalidArgument("Cannot have duplicate output names. Name '", name, "' appears more than once in 'output_names' array."); } - used_names_.insert(iter, name); + used_names_.emplace(name, 0); return Status::OK(); } diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 49076039fa7..c97fa93e3a5 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -22,15 +22,16 @@ limitations under the License. #include #include "tensorflow/c/c_test_util.h" +#include "tensorflow/c/tf_status.h" #include "tensorflow/cc/saved_model/signature_constants.h" #include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/graph.pb_text.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" -#include "tensorflow/core/framework/node_def.pb_text.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -233,7 +234,7 @@ void TestEncodeDecode(int line, const std::vector& data) { // Create C++ Tensor Tensor src(tensorflow::DT_STRING, TensorShape(dims)); for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) { - src.flat()(i) = data[i]; + src.flat()(i) = data[i]; } TF_Tensor* dst = TF_TensorFromTensor(src, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -243,7 +244,7 @@ void TestEncodeDecode(int line, const std::vector& data) { ASSERT_EQ(Status::OK(), TF_TensorToTensor(dst, &output)) << line; ASSERT_EQ(src.NumElements(), output.NumElements()) << line; for (tensorflow::int64 i = 0; i < src.NumElements(); ++i) { - ASSERT_EQ(data[i], output.flat()(i)) << line; + ASSERT_EQ(data[i], output.flat()(i)) << line; } TF_DeleteTensor(dst); @@ -556,7 +557,7 @@ TEST(CAPI, Graph) { EXPECT_FALSE(found_add); found_add = true; } else { - ADD_FAILURE() << "Unexpected NodeDef: " << ProtoDebugString(n); + ADD_FAILURE() << "Unexpected NodeDef: " << n.DebugString(); } } EXPECT_TRUE(found_placeholder); @@ -581,20 +582,20 @@ TEST(CAPI, Graph) { // Compare with first GraphDef + added NodeDef. 
NodeDef* added_node = graph_def.add_node(); *added_node = node_def; - EXPECT_EQ(ProtoDebugString(graph_def), ProtoDebugString(graph_def2)); + EXPECT_EQ(graph_def.DebugString(), graph_def2.DebugString()); // Look up some nodes by name. TF_Operation* neg2 = TF_GraphOperationByName(graph, "neg"); EXPECT_TRUE(neg == neg2); NodeDef node_def2; ASSERT_TRUE(GetNodeDef(neg2, &node_def2)); - EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2)); + EXPECT_EQ(node_def.DebugString(), node_def2.DebugString()); TF_Operation* feed2 = TF_GraphOperationByName(graph, "feed"); EXPECT_TRUE(feed == feed2); ASSERT_TRUE(GetNodeDef(feed, &node_def)); ASSERT_TRUE(GetNodeDef(feed2, &node_def2)); - EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2)); + EXPECT_EQ(node_def.DebugString(), node_def2.DebugString()); // Test iterating through the nodes of a graph. found_placeholder = false; @@ -618,7 +619,7 @@ TEST(CAPI, Graph) { found_neg = true; } else { ASSERT_TRUE(GetNodeDef(oper, &node_def)); - ADD_FAILURE() << "Unexpected Node: " << ProtoDebugString(node_def); + ADD_FAILURE() << "Unexpected Node: " << node_def.DebugString(); } } EXPECT_TRUE(found_placeholder); @@ -1385,7 +1386,7 @@ TEST(CAPI, SavedModel) { tensorflow::Example example; auto* feature_map = example.mutable_features()->mutable_feature(); (*feature_map)["x"].mutable_float_list()->add_value(i); - input.flat()(i) = example.SerializeAsString(); + input.flat()(i) = example.SerializeAsString(); } const tensorflow::string input_op_name( @@ -2498,6 +2499,38 @@ TEST(TestKernel, TestGetRegisteredKernelsForOpNoKernels) { #undef EXPECT_TF_META +TEST(CAPI, TestTensorAligned) { + int64_t dim = 7; + size_t tensor_size_bytes = dim * TF_DataTypeSize(TF_FLOAT); + TF_Tensor* a = TF_AllocateTensor( + /*dtype=*/TF_FLOAT, /*dims=*/&dim, /*num_dims=*/1, + /*len=*/tensor_size_bytes); + float* data = reinterpret_cast(TF_TensorData(a)); + for (int i = 0; i < dim; ++i) { + data[i] = 0; + } + if (EIGEN_MAX_ALIGN_BYTES > 0) { + EXPECT_TRUE(TF_TensorIsAligned(a)); + } + TF_DeleteTensor(a); +} + +TEST(CAPI, TestTensorIsNotAligned) { + // Test unaligned access via a Slice. + Tensor x(DT_FLOAT, TensorShape({30})); + x.flat().setConstant(0.0); + + // Take an unaligned slice. 
+ Tensor y = x.Slice(1, 13); + TF_Status* status = TF_NewStatus(); + TF_Tensor* a = TF_TensorFromTensor(y, status); + if (EIGEN_MAX_ALIGN_BYTES > 0) { + EXPECT_FALSE(TF_TensorIsAligned(a)); + } + TF_DeleteStatus(status); + TF_DeleteTensor(a); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 7eddc17a8e5..5c42e508f71 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -8,12 +8,12 @@ load( "tfe_xla_copts", ) load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_additional_device_tracer_test_flags", "tf_kernel_tests_linkstatic", ) load( - "//tensorflow/core:platform/default/build_config_root.bzl", + "//tensorflow/core/platform:default/build_config_root.bzl", "tf_cuda_tests_tags", ) @@ -156,6 +156,7 @@ tf_cuda_cc_test( ], deps = [ ":c_api", + ":c_api_experimental", ":c_api_internal", ":c_api_test_util", "//tensorflow/c:c_test_util", @@ -235,9 +236,11 @@ tf_cuda_cc_test( ], args = ["--heap_check=local"] + tf_additional_device_tracer_test_flags(), + extra_copts = tfe_xla_copts(), linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + ["nomac"], deps = [ + ":c_api", ":c_api_experimental", ":c_api_test_util", "//tensorflow/c:c_test_util", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 22c1f219f38..b70f40cc46a 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -202,9 +202,11 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( "Currently, TFE_NewContext only supports tensorflow::GrpcServer.")); } - LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); - - tensorflow::uint64 context_id = tensorflow::random::New64(); + tensorflow::uint64 context_id = tensorflow::EagerContext::NewContextId(); + // Make master eager context accessible by local eager service, which might + // receive send tensor requests from remote workers. + LOG_AND_RETURN_IF_ERROR(grpc_server->AddMasterEagerContextToEagerService( + context_id, ctx->context)); std::vector remote_workers; grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers); @@ -240,9 +242,11 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( &remote_eager_workers)); // Initialize remote eager workers. - LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( - remote_workers, context_id, keep_alive_secs, server_def, - remote_eager_workers.get(), ctx->context->Async(), base_request)); + // TODO(b/138847548) Create remote eager contexts in async mode by default. 
+ LOG_AND_RETURN_IF_ERROR( + CreateRemoteContexts(remote_workers, context_id, keep_alive_secs, + server_def, remote_eager_workers.get(), + ctx->context->Executor()->Async(), base_request)); tensorflow::RemoteRendezvous* r = grpc_server->worker_env()->rendezvous_mgr->Find(context_id); @@ -261,15 +265,21 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); auto* device_mgr = grpc_server->worker_env()->device_mgr; - auto remote_mgr = - absl::make_unique(/*is_master=*/true); + auto remote_mgr = absl::make_unique( + /*is_master=*/true, ctx->context); - return ctx->context->InitializeRemoteMaster( + LOG_AND_RETURN_IF_ERROR(ctx->context->InitializeRemoteMaster( std::move(server), grpc_server->worker_env(), worker_session, std::move(remote_eager_workers), std::move(remote_device_mgr), remote_workers, context_id, r, device_mgr, keep_alive_secs, - worker_session->cluster_flr.get(), std::move(remote_mgr)); + worker_session->cluster_flr.get(), std::move(remote_mgr))); + + // NOTE: We start the server after all other initialization, because the + // GrpcServer cannot be destroyed after it is started. + LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); #undef LOG_AND_RETURN_IF_ERROR + + return tensorflow::Status::OK(); } #endif // !IS_MOBILE_PLATFORM @@ -365,12 +375,6 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( options->device_placement_policy = policy; } -TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, - unsigned char enable, - TF_Status* status) { - status->status = ctx->context->SetAsyncForThread(enable); -} - void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { @@ -455,18 +459,6 @@ extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( ctx->context->GetDevicePlacementPolicy()); } -void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context->AsyncWait(); -} - -void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) { - status->status = ctx->context->GetStatus(); -} - -void TFE_ContextAsyncClearError(TFE_Context* ctx) { - ctx->context->ClearAsyncError(); -} - TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); @@ -571,7 +563,8 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { const tensorflow::Tensor* t = nullptr; tensorflow::TensorHandle* h_cpu = nullptr; status->status = EagerCopyToDevice( - handle, handle->Context(), handle->Context()->HostCPU(), false, &h_cpu); + handle, handle->Context(), handle->Context()->Executor(), + handle->Context()->HostCPU(), false, &h_cpu); if (!status->status.ok()) { return nullptr; } @@ -671,7 +664,7 @@ void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status) { - TF_AttrType ret; + TF_AttrType ret = TF_ATTR_INT; status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(), attr_name, &ret, is_list); return ret; @@ -683,10 +676,11 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, TF_Status* status) { TF_AttrType ret; TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status); - if (!status->status.ok()) { - return TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType. 
+ if (status->status.ok()) { + ret = TFE_OpGetAttrType(op, attr_name, is_list, status); + } else { + ret = TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType. } - ret = TFE_OpGetAttrType(op, attr_name, is_list, status); TFE_DeleteOp(op); return ret; } @@ -922,6 +916,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, return nullptr; } status->status = tensorflow::EagerCopyToDevice(h->handle, ctx->context, + ctx->context->Executor(), device, false, &handle); if (status->status.ok()) { return new TFE_TensorHandle(handle); @@ -957,12 +952,10 @@ unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) { void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { ctx->context->SetShouldStoreGraphs(true); - ctx->context->SetShouldStoreStepStats(true); } void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { ctx->context->SetShouldStoreGraphs(false); - ctx->context->SetShouldStoreStepStats(false); } } // extern "C" @@ -974,7 +967,7 @@ TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t, void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, TF_Status* status) { - TFE_ContextAsyncWait(ctx, status); + status->status = ctx->context->Executor()->WaitForAllPendingNodes(); if (!status->status.ok()) return; tensorflow::mutex_lock ml(*ctx->context->MetadataMu()); status->status = MessageToBuffer(*ctx->context->RunMetadataProto(), buf); diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h old mode 100755 new mode 100644 index f6850118b89..d29e66dc1b8 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -77,7 +77,7 @@ typedef enum TFE_ContextDevicePlacementPolicy { // LINT.ThenChange(//tensorflow/core/common_runtime/eager/context.h) // Sets the default execution mode (sync/async). Note that this can be -// overridden per thread using TFE_ContextSetAsyncForThread. +// overridden per thread using TFE_ContextSetExecutorForThread. TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, unsigned char enable); @@ -89,6 +89,9 @@ TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); // "Context" under which operations/functions are executed. It encapsulates // things like the available devices, resource manager etc. +// TFE_Context must outlive all tensor handles created using it. In other +// words, TFE_DeleteContext() must be called after all tensor handles have +// been deleted (with TFE_DeleteTensorHandle). // // TODO(ashankar): Merge with TF_Session? typedef struct TFE_Context TFE_Context; @@ -115,11 +118,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy( TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(TFE_Context* ctx); -// Overrides the execution mode (sync/async) for the current thread. -TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, - unsigned char enable, - TF_Status* status); - // A tensorflow.ServerDef specifies remote workers (in addition to the current // workers name). Operations created on this context can then be executed on // any of these remote workers by setting an appropriate device. @@ -132,25 +130,6 @@ TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx, size_t proto_len, TF_Status* status); -// Causes the calling thread to block till all ops dispatched in async mode -// have been executed. Note that "execution" here refers to kernel execution / -// scheduling of copies, etc. 
Similar to sync execution, it doesn't guarantee -// that lower level device queues (like GPU streams) have been flushed. -// -// This call may not block for execution of ops enqueued concurrently with this -// call. -TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context*, - TF_Status* status); - -// When an error happens, any pending operations are discarded and newly issued -// ops return an error. This call clears the error state and re-enables -// execution of newly issued ops. -// -// Note that outputs of discarded ops remain in a corrupt state and should not -// be used for future calls. -// TODO(agarwal): mark the affected handles and raise errors if they are used. -TF_CAPI_EXPORT extern void TFE_ContextAsyncClearError(TFE_Context*); - // A handle to a tensor on a device. // // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 32f28a0712c..a9ad77198e7 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -32,9 +32,7 @@ void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { op->operation.ConsumeInput(h->handle); } -TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx) { - return new TFE_Profiler(ctx); -} +TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); } bool TFE_ProfilerIsOk(TFE_Profiler* profiler) { return profiler->profiler->Status().ok(); @@ -55,23 +53,10 @@ void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf, }; } -TFE_ProfilerContext* TFE_NewProfilerContext() { - return new TFE_ProfilerContext; -} - -void TFE_ProfilerContextSetEagerContext(TFE_ProfilerContext* profiler_context, - TFE_Context* eager_context) { - profiler_context->profiler_context.eager_context = eager_context->context; -} - -void TFE_DeleteProfilerContext(TFE_ProfilerContext* profiler_context) { - delete profiler_context; -} - -void TFE_StartProfilerServer(TFE_ProfilerContext* context, int port) { - // Release child thread intentionally. The child thread can be terminate by +void TFE_StartProfilerServer(int port) { + // Release child thread intentionally. The child thread can be terminated by // terminating the main thread. 
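To put the profiler changes in this file in context, a usage sketch (not code from this patch): TFE_NewProfiler no longer takes a TFE_ProfilerContext, and the profiler gRPC server needs only a port.

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

void SketchProfileAndSerialize() {
  TF_Status* status = TF_NewStatus();
  TFE_Profiler* profiler = TFE_NewProfiler();  // no TFE_ProfilerContext argument anymore
  if (!TFE_ProfilerIsOk(profiler)) {
    // Only one profiler can be active at a time; a second TFE_NewProfiler
    // yields a profiler whose status is not OK (see MultipleProfilerSession below).
  }

  // ... run eager ops here while the profiler is alive ...

  TF_Buffer* trace = TF_NewBuffer();
  TFE_ProfilerSerializeToString(profiler, trace, status);
  TF_DeleteBuffer(trace);
  TFE_DeleteProfiler(profiler);
  TF_DeleteStatus(status);

  // The standalone profiling service likewise needs only a port now:
  TFE_StartProfilerServer(/*port=*/6009);
}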
- tensorflow::StartProfilerServer(&context->profiler_context, port).release(); + tensorflow::StartProfilerServer(port).release(); } void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { @@ -587,3 +572,30 @@ void TFE_OpSetCancellationManager(TFE_Op* op, op->operation.SetCancellationManager( &cancellation_manager->cancellation_manager); } + +TFE_Executor* TFE_NewExecutor(bool is_async) { + return new TFE_Executor(is_async); +} + +void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; } + +bool TFE_ExecutorIsAsync(TFE_Executor* executor) { + return executor->executor()->Async(); +} + +void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor* executor, + TF_Status* status) { + status->status = executor->executor()->WaitForAllPendingNodes(); +} + +void TFE_ExecutorClearError(TFE_Executor* executor) { + executor->executor()->ClearError(); +} + +void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { + ctx->context->SetExecutorForThread(executor->executor()); +} + +TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { + return new TFE_Executor(ctx->context->Executor()); +} diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index cdf1492c0bc..e5a9459faff 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -25,8 +25,6 @@ extern "C" { TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); -typedef struct TFE_ProfilerContext TFE_ProfilerContext; - // A profiler which will start profiling when creating the object and will stop // when the object is destroyed. It will profile all operations run under the // given TFE_Context. Multiple instance of it can be created, but at most one @@ -34,7 +32,7 @@ typedef struct TFE_ProfilerContext TFE_ProfilerContext; // Thread-safety: TFE_Profiler is thread-safe. typedef struct TFE_Profiler TFE_Profiler; -TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(TFE_ProfilerContext* ctx); +TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler(); TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler); TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler); @@ -44,27 +42,14 @@ TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf, TF_Status* status); -// Return a new profiler context object. -TF_CAPI_EXPORT extern TFE_ProfilerContext* TFE_NewProfilerContext(void); - -// Set the eager context in TFE_ProfilerServerOptions -TF_CAPI_EXPORT extern void TFE_ProfilerContextSetEagerContext( - TFE_ProfilerContext* profiler_context, TFE_Context* eager_context); - -// Destroy a profiler context object. -TF_CAPI_EXPORT extern void TFE_DeleteProfilerContext( - TFE_ProfilerContext* profiler_context); - // Start a profiler grpc server which listens to specified port. It will start // the server on its own thread. It can be shutdown by terminating tensorflow. // It can be used in both Eager mode and graph mode. Creating multiple profiler // server is allowed. The service defined in // tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use -// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture tracable -// file following -// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace. 
-TF_CAPI_EXPORT extern void TFE_StartProfilerServer(TFE_ProfilerContext* context, - int port); +// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture trace file +// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace. +TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port); // Enables only graph collection in RunMetadata on the functions executed from // this context. @@ -367,6 +352,51 @@ TF_CAPI_EXPORT extern void TFE_OpSetCancellationManager( TFE_Op* op, TFE_CancellationManager* cancellation_manager, TF_Status* status); +// ----------------------------------------------------------------------------- +// Eager Executor APIs. +typedef struct TFE_Executor TFE_Executor; + +// Creates a new eager Executor. Nodes in one executor are guaranteed to be +// executed in sequence. Assigning nodes to different executors allows executing +// nodes in parallel. +TF_CAPI_EXPORT extern TFE_Executor* TFE_NewExecutor(bool is_async); + +// Deletes the eager Executor without waiting for enqueued nodes. Please call +// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to +// make sure all nodes are finished. +TF_CAPI_EXPORT extern void TFE_DeleteExecutor(TFE_Executor*); + +// Returns true if the executor is in async mode. +TF_CAPI_EXPORT extern bool TFE_ExecutorIsAsync(TFE_Executor*); + +// Causes the calling thread to block till all ops dispatched in this executor +// have been executed. Note that "execution" here refers to kernel execution / +// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee +// that lower level device queues (like GPU streams) have been flushed. +// +// This call may not block for execution of ops enqueued concurrently with this +// call. +TF_CAPI_EXPORT extern void TFE_ExecutorWaitForAllPendingNodes( + TFE_Executor*, TF_Status* status); + +// When an error happens, any pending operations are discarded and newly issued +// ops return an error. This call clears the error state and re-enables +// execution of newly issued ops. +// +// Note that outputs of discarded ops remain in a corrupt state and should not +// be used for future calls. +// TODO(agarwal): mark the affected handles and raise errors if they are used. +TF_CAPI_EXPORT extern void TFE_ExecutorClearError(TFE_Executor*); + +// Sets a custom Executor for current thread. All nodes created by this thread +// will be added to this Executor. It will override current executor. +TF_CAPI_EXPORT extern void TFE_ContextSetExecutorForThread(TFE_Context*, + TFE_Executor*); + +// Returns the Executor for current thread. +TF_CAPI_EXPORT extern TFE_Executor* TFE_ContextGetExecutorForThread( + TFE_Context*); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_experimental_test.cc b/tensorflow/c/eager/c_api_experimental_test.cc index 249d6c8960b..ab76ad10adc 100644 --- a/tensorflow/c/eager/c_api_experimental_test.cc +++ b/tensorflow/c/eager/c_api_experimental_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
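Before the test updates below: taken together, the executor functions declared above replace the removed TFE_ContextAsyncWait / TFE_ContextGetStatus / TFE_ContextAsyncClearError. A minimal usage sketch (not part of this patch), mirroring the pattern the updated tests use:

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

void RunOnPrivateAsyncExecutor(TFE_Context* ctx, TF_Status* status) {
  // Save the current per-thread executor and install a fresh async one.
  TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx);
  TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/true);
  TFE_ContextSetExecutorForThread(ctx, executor);

  // ... enqueue eager ops with TFE_Execute(); they are dispatched on `executor` ...

  // Drain the executor; on failure, clear the error so new ops can run.
  TFE_ExecutorWaitForAllPendingNodes(executor, status);
  if (TF_GetCode(status) != TF_OK) {
    TFE_ExecutorClearError(executor);
  }

  // Restore the previous executor and release both handles.
  TFE_ContextSetExecutorForThread(ctx, old_executor);
  TFE_DeleteExecutor(executor);
  TFE_DeleteExecutor(old_executor);
}

Note that TFE_ContextGetExecutorForThread returns a new wrapper around the context's current executor, so the caller deletes the wrapper with TFE_DeleteExecutor without affecting the underlying executor.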
#include +#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/cc/profiler/profiler.h" #include "tensorflow/core/lib/monitoring/collection_registry.h" @@ -43,12 +44,9 @@ void ExecuteWithProfiling(bool async) { TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetAsync(opts, static_cast(async)); TFE_Context* ctx = TFE_NewContext(opts, status); - TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext(); - TFE_ProfilerContextSetEagerContext(profiler_context, ctx); - TFE_Profiler* profiler = TFE_NewProfiler(profiler_context); + TFE_Profiler* profiler = TFE_NewProfiler(); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); - TFE_DeleteProfilerContext(profiler_context); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); @@ -71,8 +69,10 @@ void ExecuteWithProfiling(bool async) { ASSERT_EQ(1, num_retvals); TF_Buffer* profiler_result = TF_NewBuffer(); if (async) { - TFE_ContextAsyncWait(ctx, status); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); } TFE_ProfilerSerializeToString(profiler, profiler_result, status); TFE_DeleteProfiler(profiler); @@ -85,7 +85,10 @@ void ExecuteWithProfiling(bool async) { if (!gpu_device_name.empty()) { EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0")); // device name with "stream:all" is collected by Device Tracer. +#ifndef TENSORFLOW_USE_ROCM + // ROCm platform does not yet support stream level tracing EXPECT_TRUE(HasSubstr(profile_proto_str, "stream:all")); +#endif } // "/host:CPU" is collected by TraceMe EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU")); @@ -110,27 +113,14 @@ TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); } TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); } TEST(CAPI, MultipleProfilerSession) { - TF_Status* status = TF_NewStatus(); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetAsync(opts, static_cast(false)); - TFE_Context* ctx = TFE_NewContext(opts, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContextOptions(opts); - - TFE_ProfilerContext* profiler_context = TFE_NewProfilerContext(); - TFE_ProfilerContextSetEagerContext(profiler_context, ctx); - - TFE_Profiler* profiler1 = TFE_NewProfiler(profiler_context); + TFE_Profiler* profiler1 = TFE_NewProfiler(); EXPECT_TRUE(TFE_ProfilerIsOk(profiler1)); - TFE_Profiler* profiler2 = TFE_NewProfiler(profiler_context); + TFE_Profiler* profiler2 = TFE_NewProfiler(); EXPECT_FALSE(TFE_ProfilerIsOk(profiler2)); TFE_DeleteProfiler(profiler1); TFE_DeleteProfiler(profiler2); - TFE_DeleteProfilerContext(profiler_context); - TFE_DeleteContext(ctx); - TF_DeleteStatus(status); } TEST(CAPI, MonitoringCounter0) { @@ -307,5 +297,205 @@ TEST(CAPI, CancellationManager) { TFE_DeleteCancellationManager(c_mgr); } +TEST(CAPI, Function_ident_CPU) { + // First create a simple identity function. 
+ TF_Graph* function_graph = TF_NewGraph(); + TF_OperationDescription* arg_descr = + TF_NewOperation(function_graph, "Placeholder", "arg"); + TF_SetAttrType(arg_descr, "dtype", TF_INT32); + TF_Status* status = TF_NewStatus(); + TF_Operation* arg = TF_FinishOperation(arg_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_OperationDescription* id_descr = + TF_NewOperation(function_graph, "Identity", "id"); + TF_SetAttrType(id_descr, "T", TF_INT32); + TF_AddInput(id_descr, {arg, 0}); + TF_Operation* id = TF_FinishOperation(id_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_Output input{arg, 0}; + TF_Output output{id, 0}; + TF_Function* fn = + TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, + &output, nullptr, nullptr, "test", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteGraph(function_graph); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + TFE_ContextAddFunction(ctx, fn, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteFunction(fn); + + for (bool async : {false, true, false}) { + TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx); + TFE_Executor* executor = TFE_NewExecutor(async); + TFE_ContextSetExecutorForThread(ctx, executor); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); + + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); + + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_ContextSetExecutorForThread(ctx, old_executor); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); + TFE_DeleteExecutor(old_executor); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + } + TFE_ContextRemoveFunction(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_DeleteContext(ctx); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteStatus(status); +} + +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST(CAPI, Function_ident_XLA_CPU) { + // First create a simple identity function. 
+ TF_Graph* function_graph = TF_NewGraph(); + TF_OperationDescription* arg_descr = + TF_NewOperation(function_graph, "Placeholder", "arg"); + TF_SetAttrType(arg_descr, "dtype", TF_INT32); + TF_Status* status = TF_NewStatus(); + TF_Operation* arg = TF_FinishOperation(arg_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_OperationDescription* id_descr = + TF_NewOperation(function_graph, "Identity", "id"); + TF_SetAttrType(id_descr, "T", TF_INT32); + TF_AddInput(id_descr, {arg, 0}); + TF_Operation* id = TF_FinishOperation(id_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_Output input{arg, 0}; + TF_Output output{id, 0}; + TF_Function* fn = + TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, + &output, nullptr, nullptr, "test", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteGraph(function_graph); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + TFE_ContextAddFunction(ctx, fn, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteFunction(fn); + + for (bool async : {false, true, false}) { + TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx); + TFE_Executor* executor = TFE_NewExecutor(async); + TFE_ContextSetExecutorForThread(ctx, executor); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK); + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); + + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + + // Now run it via XLA. 
+ TFE_OpSetXLACompilation(op, true); + + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); + + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_ContextSetExecutorForThread(ctx, old_executor); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); + TFE_DeleteExecutor(old_executor); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + } + TFE_ContextRemoveFunction(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_DeleteContext(ctx); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteStatus(status); +} +#endif // TENSORFLOW_EAGER_USE_XLA + +void Executor_MatMul_CPU(bool async) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_Executor* old_executor = TFE_ContextGetExecutorForThread(ctx); + TFE_Executor* executor = TFE_NewExecutor(async); + TFE_ContextSetExecutorForThread(ctx, executor); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + TFE_TensorHandle* retvals[2] = {nullptr, nullptr}; + int num_retvals = 2; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(1, num_retvals); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retvals[0]); + TFE_ContextSetExecutorForThread(ctx, old_executor); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); + TFE_DeleteExecutor(old_executor); + TFE_DeleteContext(ctx); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + TF_DeleteStatus(status); +} +TEST(CAPI, Executor_MatMul_CPU) { Executor_MatMul_CPU(false); } +TEST(CAPI, Executor_MatMul_CPUAsync) { Executor_MatMul_CPU(true); } + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index fe0c952dacb..5efed2ca76d 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -76,7 +76,14 @@ struct TFE_Context { async, device_mgr, device_mgr_owned, rendezvous, custom_kernel_creator)) {} - ~TFE_Context() { context->Unref(); } + ~TFE_Context() { + // TODO(iga): Add a separate API method to shutdown TFE_Context so that we + // don't send RPCs and block in destructor. + context->WaitForAndCloseRemoteContexts(); + // context->RefCountIsOne() should be true here. 
+ // TODO(iga): Remove EagerContext refcounting. + context->Unref(); + } tensorflow::EagerContext* context; }; @@ -130,14 +137,8 @@ struct TFE_Op { std::unique_ptr inference_ctx; }; -struct TFE_ProfilerContext { - tensorflow::ProfilerContext profiler_context; -}; - struct TFE_Profiler { - explicit TFE_Profiler(TFE_ProfilerContext* ctx) { - profiler = tensorflow::ProfilerSession::Create(&ctx->profiler_context); - } + explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); } std::unique_ptr profiler; }; @@ -291,4 +292,19 @@ struct TFE_CancellationManager { tensorflow::CancellationManager cancellation_manager; }; +struct TFE_Executor { + explicit TFE_Executor(bool async) + : owned_executor(new tensorflow::EagerExecutor(async)) {} + + explicit TFE_Executor(tensorflow::EagerExecutor* executor) + : owned_executor(nullptr), unowned_executor(executor) {} + + tensorflow::EagerExecutor* executor() { + return owned_executor == nullptr ? unowned_executor : owned_executor.get(); + } + + std::unique_ptr owned_executor; + tensorflow::EagerExecutor* unowned_executor; +}; + #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index e80620c9a64..d3b755fee6e 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/strings/match.h" +#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" @@ -78,7 +79,10 @@ void BM_Execute(int iters, int async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } if (async) { - TFE_ContextAsyncWait(ctx, status); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); } tensorflow::testing::StopTiming(); TFE_DeleteOp(matmul); @@ -89,6 +93,41 @@ void BM_Execute(int iters, int async) { } BENCHMARK(BM_Execute)->Arg(0)->Arg(1); +void BM_Execute_Identity(int iters, int async) { + tensorflow::testing::StopTiming(); + tensorflow::testing::SetLabel(async ? 
"ExecuteIdentityAsync" + : "ExecuteIdentity"); + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetAsync(opts, static_cast(async)); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* identity = IdentityOp(ctx, m); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + TFE_Execute(identity, &retvals[0], &num_retvals, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + } + if (async) { + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); + } + tensorflow::testing::StopTiming(); + TFE_DeleteOp(identity); + TFE_DeleteTensorHandle(m); + TFE_DeleteContext(ctx); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); +} +BENCHMARK(BM_Execute_Identity)->Arg(0)->Arg(1); + TEST(CAPI, Context) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -196,8 +235,10 @@ void TestRemoteExecute(bool async) { TFE_DeleteOp(matmul); - TFE_ContextAsyncWait(ctx, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); TFE_DeleteContext(ctx); TF_DeleteStatus(status); @@ -282,9 +323,11 @@ void TestRemoteExecuteSilentCopies(bool async) { TFE_DeleteOp(matmul); - TFE_ContextAsyncWait(ctx, status); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); TFE_DeleteContext(ctx); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TF_DeleteStatus(status); @@ -298,7 +341,7 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsync) { TestRemoteExecuteSilentCopies(true); } -void TestRemoteExecuteDeleteTensorAfterContext(bool async) { +void TestRemoteExecuteDeleteContextWithOutstandingRPC(bool async) { tensorflow::ServerDef server_def = GetServerDef(2); // This server def has the task index set to 0. @@ -324,33 +367,49 @@ void TestRemoteExecuteDeleteTensorAfterContext(bool async) { TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + // Use large matrices so that RPCs don't return before we get a chance + // to call TFE_DeleteContext. 
+ TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle100x100(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle100x100(); const char remote_device_name[] = "/job:localhost/replica:0/task:1/device:CPU:0"; auto* h0_task1 = TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* h1_task1 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); TFE_DeleteTensorHandle(h0_task0); - - TFE_ContextAsyncWait(ctx, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_DeleteContext(ctx); - - // Delete tensors after context is deleted. + TFE_DeleteTensorHandle(h1_task0); TFE_DeleteTensorHandle(h0_task1); + TFE_DeleteTensorHandle(h1_task1); + TFE_DeleteTensorHandle(retvals[0]); - TF_DeleteStatus(status); + TFE_DeleteOp(matmul); + + TFE_DeleteContext(ctx); // TODO(b/136478427): Figure out how to correctly shut the server down. worker_server.release(); } -TEST(CAPI, RemoteExecuteDeleteTensorAfterContext) { - TestRemoteExecuteDeleteTensorAfterContext(false); +TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPC) { + TestRemoteExecuteDeleteContextWithOutstandingRPC(false); } -TEST(CAPI, RemoteExecuteDeleteTensorAfterContextAsync) { - TestRemoteExecuteDeleteTensorAfterContext(true); + +TEST(CAPI, RemoteExecuteDeleteContextWithOutstandingRPCAsync) { + TestRemoteExecuteDeleteContextWithOutstandingRPC(true); } void CheckTFE_TensorHandleHasFloats(TFE_TensorHandle* handle, @@ -397,8 +456,10 @@ void CheckRemoteMatMulExecutesOK(TFE_Context* ctx, TFE_DeleteOp(matmul); - TFE_ContextAsyncWait(ctx, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); TF_DeleteStatus(status); } @@ -433,8 +494,9 @@ void TestRemoteExecuteChangeServerDef(bool async) { "/job:localhost/replica:0/task:0/device:CPU:0"; CheckRemoteMatMulExecutesOK(ctx, remote_device_name, local_device_name); - TFE_ContextAsyncWait(ctx, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); // TODO(b/136478427): Figure out how to correctly shut the server down. 
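For orientation before the remaining remote-execution hunks, the remote tests in this file all follow the same setup. A condensed sketch (not from this patch; GetServerDef, TestMatrixTensorHandle, and MatMulOp are helpers from these test files, and a worker server for task 1 is assumed to be running, as the tests arrange):

void SketchRemoteMatMul() {
  tensorflow::ServerDef server_def = GetServerDef(2);  // two-task cluster
  const std::string serialized = server_def.SerializeAsString();

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // Attach the cluster; task 0 is this process, task 1 is the remote worker.
  TFE_ContextSetServerDef(ctx, /*keep_alive_secs=*/0, serialized.data(),
                          serialized.size(), status);

  // Copy an input to the remote device and run an op there.
  // (Error checking omitted for brevity.)
  const char remote_device[] = "/job:localhost/replica:0/task:1/device:CPU:0";
  TFE_TensorHandle* h = TestMatrixTensorHandle();
  TFE_TensorHandle* h_remote =
      TFE_TensorHandleCopyToDevice(h, ctx, remote_device, status);
  TFE_Op* matmul = MatMulOp(ctx, h_remote, h_remote);
  TFE_OpSetDevice(matmul, remote_device, status);
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);

  // Tear down: op, handles, then the context.
  TFE_DeleteOp(matmul);
  TFE_DeleteTensorHandle(h);
  TFE_DeleteTensorHandle(h_remote);
  TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
}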
worker_server.release(); @@ -476,8 +538,9 @@ void TestRemoteExecuteChangeServerDef(bool async) { CheckRemoteMatMulExecutesOK(ctx, new_remote_device_name, new_local_device_name); - TFE_ContextAsyncWait(ctx, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); TF_DeleteStatus(status); @@ -610,8 +673,11 @@ void TensorHandleCopyBetweenDevicesError(bool async) { TFE_TensorHandle* hcopy = TFE_TensorHandleCopyToDevice(hcpu, ctx, kCPUDevice, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - TFE_ContextAsyncWait(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())); + + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteExecutor(executor); TFE_DeleteTensorHandle(hcopy); TFE_DeleteTensorHandle(hcpu); if (hdevice != nullptr) TFE_DeleteTensorHandle(hdevice); @@ -740,8 +806,10 @@ void TensorHandleSilentCopy(bool async) { TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_ContextAsyncWait(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteExecutor(executor); TFE_DeleteContext(ctx); } @@ -786,8 +854,10 @@ void TensorHandleSilentCopyLocal(bool async) { TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); - TFE_ContextAsyncWait(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteExecutor(executor); TFE_DeleteContext(ctx); } TEST(CAPI, TensorHandleSilentCopyLocal) { TensorHandleSilentCopyLocal(false); } @@ -921,8 +991,10 @@ TEST(CAPI, TensorHandleDevices) { } TFE_DeleteTensorHandle(hcpu); - TFE_ContextAsyncWait(ctx, status.get()); - EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteExecutor(executor); TFE_DeleteContext(ctx); } @@ -1000,9 +1072,11 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { retvals[0] = nullptr; TFE_Execute(matmul2, &retvals[0], &num_retvals, status); EXPECT_NE(TF_OK, TF_GetCode(status)); - TFE_ContextAsyncClearError(ctx); - TFE_ContextAsyncWait(ctx, status); - EXPECT_EQ(TF_OK, TF_GetCode(status)); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorClearError(executor); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); } // Following works in async mode since TFE_ContextAsyncClearError was called. 
TF_SetStatus(status, TF_OK, ""); @@ -1220,147 +1294,6 @@ void ExecuteWithTracing(bool async) { TEST(CAPI, ExecuteWithTracing) { ExecuteWithTracing(false); } TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithTracing(true); } -TEST(CAPI, Function_ident_CPU) { - // First create a simple identity function. - TF_Graph* function_graph = TF_NewGraph(); - TF_OperationDescription* arg_descr = - TF_NewOperation(function_graph, "Placeholder", "arg"); - TF_SetAttrType(arg_descr, "dtype", TF_INT32); - TF_Status* status = TF_NewStatus(); - TF_Operation* arg = TF_FinishOperation(arg_descr, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_OperationDescription* id_descr = - TF_NewOperation(function_graph, "Identity", "id"); - TF_SetAttrType(id_descr, "T", TF_INT32); - TF_AddInput(id_descr, {arg, 0}); - TF_Operation* id = TF_FinishOperation(id_descr, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_Output input{arg, 0}; - TF_Output output{id, 0}; - TF_Function* fn = - TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, - &output, nullptr, nullptr, "test", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteGraph(function_graph); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_Context* ctx = TFE_NewContext(opts, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_DeleteContextOptions(opts); - TFE_ContextAddFunction(ctx, fn, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteFunction(fn); - - for (bool async : {false, true, false}) { - TFE_ContextSetAsyncForThread(ctx, static_cast(async), - status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); - - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); - - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); - } - TFE_ContextRemoveFunction(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_DeleteContext(ctx); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteStatus(status); -} - -#ifdef TENSORFLOW_EAGER_USE_XLA -TEST(CAPI, Function_ident_XLA_CPU) { - // First create a simple identity function. 
- TF_Graph* function_graph = TF_NewGraph(); - TF_OperationDescription* arg_descr = - TF_NewOperation(function_graph, "Placeholder", "arg"); - TF_SetAttrType(arg_descr, "dtype", TF_INT32); - TF_Status* status = TF_NewStatus(); - TF_Operation* arg = TF_FinishOperation(arg_descr, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_OperationDescription* id_descr = - TF_NewOperation(function_graph, "Identity", "id"); - TF_SetAttrType(id_descr, "T", TF_INT32); - TF_AddInput(id_descr, {arg, 0}); - TF_Operation* id = TF_FinishOperation(id_descr, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_Output input{arg, 0}; - TF_Output output{id, 0}; - TF_Function* fn = - TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, - &output, nullptr, nullptr, "test", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteGraph(function_graph); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_Context* ctx = TFE_NewContext(opts, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_DeleteContextOptions(opts); - TFE_ContextAddFunction(ctx, fn, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteFunction(fn); - - for (bool async : {false, true, false}) { - TFE_ContextSetAsyncForThread(ctx, static_cast(async), - status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK); - TF_Tensor* t = - TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); - *reinterpret_cast(TF_TensorData(t)) = 42; - TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteTensor(t); - - TFE_Op* op = TFE_NewOp(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_OpAddInput(op, h, status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - - // Now run it via XLA. 
- TFE_OpSetXLACompilation(op, true); - - std::vector result; - result.push_back(nullptr); - int num_retvals = 1; - TFE_Execute(op, result.data(), &num_retvals, status); - TFE_DeleteOp(op); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - ASSERT_EQ(num_retvals, 1); - - TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); - TFE_DeleteTensorHandle(h); - TF_DeleteTensor(r); - TFE_DeleteTensorHandle(result[0]); - } - TFE_ContextRemoveFunction(ctx, "ident", status); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TFE_DeleteContext(ctx); - ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); - TF_DeleteStatus(status); -} -#endif // TENSORFLOW_EAGER_USE_XLA - string MatMulFunction() { tensorflow::FunctionDef def; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( @@ -1474,7 +1407,10 @@ void BM_ExecuteFunction(int iters, int async) { CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); } if (async) { - TFE_ContextAsyncWait(ctx, status); + TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx); + TFE_ExecutorWaitForAllPendingNodes(executor, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteExecutor(executor); } tensorflow::testing::StopTiming(); TFE_DeleteTensorHandle(m); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 17d17c0b7f7..51566b35a9f 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -85,6 +85,24 @@ TFE_TensorHandle* TestMatrixTensorHandle() { return th; } +TFE_TensorHandle* TestMatrixTensorHandle100x100() { + constexpr int64_t dims[] = {100, 100}; + constexpr int num_elements = dims[0] * dims[1]; + float data[num_elements]; + for (int i = 0; i < num_elements; ++i) { + data[i] = 1.0f; + } + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() { int64_t dims[] = {3, 2}; double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; @@ -128,6 +146,19 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { return op; } +TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Identity", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; +} + TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a) { TF_Status* status = TF_NewStatus(); diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index 4ff3ff4301f..28062222cf0 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -16,7 +16,6 @@ limitations under the License. 
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ #include "tensorflow/c/eager/c_api.h" - #include "tensorflow/core/platform/types.h" // Return a tensor handle containing a float scalar @@ -34,6 +33,9 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle(); // Return a tensor handle containing a 2x2 matrix of floats TFE_TensorHandle* TestMatrixTensorHandle(); +// Return a tensor handle containing a 100x100 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle100x100(); + // Return a tensor handle containing a 3x2 matrix of doubles TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(); @@ -43,6 +45,9 @@ TFE_TensorHandle* TestMatrixTensorHandle3X2(); // Return a matmul op multiplying `a` by `b`. TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); +// Return an identity op. +TFE_Op* IdentityOp(TFE_Context* ctx, TFE_TensorHandle* a); + // Return a shape op fetching the shape of `a`. TFE_Op* ShapeOp(TFE_Context* ctx, TFE_TensorHandle* a); diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 0545e3f7ce0..edb2733ab32 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -18,6 +18,7 @@ limitations under the License. // Language-agnostic gradient tape. Does not perform backpropagation, just // maintains the data structures required to do so. +#include #include #include "tensorflow/core/framework/tensor_shape.h" @@ -209,7 +210,9 @@ class ForwardAccumulator { // ForwardAccumulator. explicit ForwardAccumulator( const VSpace& vspace) - : vspace_(vspace), backward_tape_(nullptr), accumulating_(false) {} + : vspace_(vspace) { + call_state_.emplace(nullptr, false); + } virtual ~ForwardAccumulator() { for (auto accumulated : accumulated_gradients_) { @@ -262,6 +265,12 @@ class ForwardAccumulator { const std::function& backward_function_getter, const std::function& backward_function_deleter); + // Returns true if `Accumulate` is active somewhere above on the stack and + // there isn't an intervening PushState. This is useful for ordering + // ForwardAccumulators, where more deeply nested accumulators should not see + // computations from less deeply nested accumulators. + bool BusyAccumulating() const { return call_state_.top().accumulating; } + // Fetches the current Jacobian-vector product associated with `tensor_id`, or // a nullptr if none is available. // @@ -276,6 +285,15 @@ class ForwardAccumulator { bool ShouldRecord(gtl::ArraySlice tensor_ids, gtl::ArraySlice dtypes); + // Temporarily push or pop transient state for this accumulator. + // + // Allows an accumulator which is currently processing an operation to + // temporarily reset its state. Without pushing and poping, accumulators + // ignore operations executed as a direct result of their own jvp + // computations. + void PushState() { call_state_.emplace(nullptr, false); } + void PopState() { call_state_.pop(); } + private: // Helper for Accumulate: uses a GradientTape to compute forward gradients // from a backward gradient function. Fills `out_grads` corresponding to @@ -283,7 +301,7 @@ class ForwardAccumulator { // // Executes the backward function in order to trace its gradient, which will // waste computation if executing eagerly (when graph building the unneeded - // computation is pruned). Temporarily sets `backward_tape_` so that + // computation is pruned). 
Temporarily sets `backward_tape` so that // Accumulate will forward op executions to the tape while the backward // function is running; this effectively adds the backward tape to the active // set (but does not require complicated callbacks to the language bindings). @@ -299,16 +317,26 @@ class ForwardAccumulator { // Not owned; provides operations on Tensors which are currently only // available in language bindings (e.g. Python). const VSpace& vspace_; - // Set temporarily while in the Accumulate method; if backward_tape_ is not - // nullptr then we forward op executions to it so Accumulate can compute a - // backward pass on its backward function. - // - // Not owned by the ForwardAccumulator. The method which sets `backward_tape_` - // keeps ownership. - GradientTape* backward_tape_; - // While the Accumulate method is running (accumulating_ is True), any op - // executions not forwarded to backward_tape_ should be ignored. - bool accumulating_; + + struct AccumulatorCallState { + AccumulatorCallState( + GradientTape* backward_tape, + bool accumulating) + : backward_tape(backward_tape), accumulating(accumulating) {} + // Set temporarily while in the Accumulate method; if backward_tape is not + // nullptr then we forward op executions to it so Accumulate can compute a + // backward pass on its backward function. + // + // Not owned by the ForwardAccumulator. The method which sets + // `backward_tape` keeps ownership. + GradientTape* backward_tape; + // While the Accumulate method is running (accumulating is True), any op + // executions not forwarded to backward_tape should be ignored. + bool accumulating; + }; + // A deque-backed stack, whose element references are not invalidated by + // pushes and pops at the back. + std::stack call_state_; }; // Template instantiations here @@ -841,12 +869,12 @@ template bool ForwardAccumulator::ShouldRecord( gtl::ArraySlice tensor_ids, gtl::ArraySlice dtypes) { - if (backward_tape_ != nullptr) { - // If we're forwarding Accumulate calls to backward_tape_'s RecordOperation, + if (call_state_.top().backward_tape != nullptr) { + // If we're forwarding Accumulate calls to backward_tape's RecordOperation, // we should also delegate ShouldRecord. - return backward_tape_->ShouldRecord(tensor_ids, dtypes); + return call_state_.top().backward_tape->ShouldRecord(tensor_ids, dtypes); } - if (accumulating_) { + if (call_state_.top().accumulating) { return false; } for (int i = 0; i < tensor_ids.size(); ++i) { @@ -878,9 +906,10 @@ ForwardAccumulator::ForwardpropFromTape( */ std::unique_ptr> tape( new GradientTape(false)); - backward_tape_ = tape.get(); + AccumulatorCallState& call_state = call_state_.top(); + call_state.backward_tape = tape.get(); auto pop_backward_tape = - gtl::MakeCleanup([this] { this->backward_tape_ = nullptr; }); + gtl::MakeCleanup([&call_state] { call_state.backward_tape = nullptr; }); std::vector forwardprop_aids; std::vector sources; std::unordered_set sources_set; @@ -955,10 +984,10 @@ Status ForwardAccumulator::Accumulate( const ForwardFunction* forward_function, const std::function& backward_function_getter, const std::function& backward_function_deleter) { - if (backward_tape_ != nullptr) { - // If backward_tape_ is not null, then this call to Accumulate is the result + if (call_state_.top().backward_tape != nullptr) { + // If backward_tape is not null, then this call to Accumulate is the result // of a still-active call to Accumulate which is running operations. 
We - // forward these operations to backward_tape_ so the outer Accumulate call + // forward these operations to backward_tape so the outer Accumulate call // can do its work. // // Rather than re-entering and delegating Accumulate like this, we could @@ -966,9 +995,9 @@ Status ForwardAccumulator::Accumulate( // (so it can deactivate itself and activate its GradientTape). Currently // that is managed by the language binding and would require relatively // messy callbacks. - backward_tape_->RecordOperation(op_type, output_tensors, input_tensor_id, - input_dtypes, backward_function_getter, - backward_function_deleter); + call_state_.top().backward_tape->RecordOperation( + op_type, output_tensors, input_tensor_id, input_dtypes, + backward_function_getter, backward_function_deleter); return Status::OK(); } if (!ShouldRecord(input_tensor_id, input_dtypes)) { @@ -1006,9 +1035,8 @@ Status ForwardAccumulator::Accumulate( // Avoid infinite recursion. Whichever forward function we run, it'll end up // executing ops, and we don't want to watch those with this accumulator. - accumulating_ = true; - auto reset_accumulating = - gtl::MakeCleanup([this] { this->accumulating_ = false; }); + call_state_.emplace(nullptr, true); + auto pop_call_state = gtl::MakeCleanup([this] { this->call_state_.pop(); }); std::vector forward_grads; if (forward_function == nullptr) { diff --git a/tensorflow/c/experimental/rendezvous.cc b/tensorflow/c/experimental/rendezvous.cc index 0ee4907b7a4..7a90bde8fe4 100644 --- a/tensorflow/c/experimental/rendezvous.cc +++ b/tensorflow/c/experimental/rendezvous.cc @@ -45,6 +45,9 @@ CRemoteRendezvous::CRemoteRendezvous(const WorkerEnv* env, int64 step_id, void CRemoteRendezvous::RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& args, DoneCallback done) { + if (args.cancellation_manager != nullptr) { + VLOG(1) << "WARNING: CRemoteRendezvous does not support cancellation."; + } TF_ParsedKey key; key.src_device = parsed.src_device.data(); key.src_device_len = parsed.src_device.size(); diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh index 7184ad68fb7..a4d51a1b3b2 100755 --- a/tensorflow/c/generate-pc.sh +++ b/tensorflow/c/generate-pc.sh @@ -63,12 +63,26 @@ cat << EOF > tensorflow.pc prefix=${TF_PREFIX} exec_prefix=\${prefix} libdir=\${exec_prefix}/${LIBDIR} -includedir=\${prefix}/include +includedir=\${prefix}/include/tensorflow Name: TensorFlow Version: ${TF_VERSION} Description: Library for computation using data flow graphs for scalable machine learning Requires: -Libs: -L\${libdir} -ltensorflow +Libs: -L\${libdir} -ltensorflow -ltensorflow_framework +Cflags: -I\${includedir} +EOF + +cat << EOF > tensorflow_cc.pc +prefix=${TF_PREFIX} +exec_prefix=\${prefix} +libdir=\${exec_prefix}/${LIBDIR} +includedir=\${prefix}/include/tensorflow + +Name: TensorFlow +Version: ${TF_VERSION} +Description: Library for computation using data flow graphs for scalable machine learning +Requires: +Libs: -L\${libdir} -ltensorflow_cc -ltensorflow_framework Cflags: -I\${includedir} EOF diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 94685c8ffaf..b067176f3be 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -19,6 +19,7 @@ limitations under the License. 
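Stepping back to the tape.h change above before the kernels.cc hunk continues: per-call accumulator state (the backward tape pointer plus the accumulating flag) now lives on an explicit stack, and PushState/PopState expose that stack to callers. A hypothetical RAII wrapper (not part of this patch) shows the intended bracketing:

// Hypothetical guard: gives nested computations a clean accumulator state
// and restores the previous state on scope exit.
template <typename Accumulator>
class ScopedAccumulatorState {
 public:
  explicit ScopedAccumulatorState(Accumulator* acc) : acc_(acc) {
    acc_->PushState();  // pushes {backward_tape = nullptr, accumulating = false}
  }
  ~ScopedAccumulatorState() { acc_->PopState(); }

  ScopedAccumulatorState(const ScopedAccumulatorState&) = delete;
  ScopedAccumulatorState& operator=(const ScopedAccumulatorState&) = delete;

 private:
  Accumulator* const acc_;
};

Because call_state_ is a std::stack (deque-backed), references to existing frames are not invalidated by pushes and pops at the back, which is what the emplace/pop pairs in Accumulate and ForwardpropFromTape rely on.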
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -189,8 +190,8 @@ void TF_GetInput(TF_OpKernelContext* ctx, int i, TF_Tensor** tensor, void TF_SetOutput(TF_OpKernelContext* ctx, int i, const TF_Tensor* tensor, TF_Status* status) { auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); - if (i < 0 || i >= cc_ctx->num_inputs()) { - TF_SetStatus(status, TF_OUT_OF_RANGE, "input index out of range"); + if (i < 0 || i >= cc_ctx->num_outputs()) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "output index out of range"); return; } ::tensorflow::Tensor cc_tensor; @@ -240,3 +241,14 @@ TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) { int64_t TF_StepId(TF_OpKernelContext* ctx) { return reinterpret_cast<::tensorflow::OpKernelContext*>(ctx)->step_id(); } + +TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index, + TF_DataType dtype, int64_t* dims, int num_dims, + size_t len) { + auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); + tensorflow::AllocatorAttributes attr = cc_ctx->output_alloc_attr(index); + auto* allocator = cc_ctx->get_allocator(attr); + void* data = tensorflow::allocate_tensor("TF_AllocateOutput", len, allocator); + return TF_NewTensor(dtype, dims, num_dims, data, len, + tensorflow::deallocate_buffer, allocator); +} diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index a192437a52f..8d0518ae170 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -180,6 +180,16 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32( TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val, TF_Status* status); +// Allocates Tensor for output at given index. Caller takes ownership of +// returned TF_Tensor and should deallocate it using TF_DeleteTensor(tensor). +// +// This function should be used to allocate outputs inside kernel +// compute function. +TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, + int index, TF_DataType dtype, + int64_t* dims, int num_dims, + size_t len); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index 0e65d18ec81..05277b6c12c 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -12,17 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
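Before the kernels_test.cc additions below, a minimal CPU kernel sketch showing the new TF_AllocateOutput in context (hypothetical op and kernel names, not part of this patch; the device tests that follow exercise the same calls on CPU and, when built with CUDA, on GPU):

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/kernels.h"

// Hypothetical compute function: writes a single float 1.0 to output 0.
// (Assumes an op named "FillOne" with one float output has been registered.)
static void FillOneCompute(void* kernel, TF_OpKernelContext* ctx) {
  int64_t dim = 1;
  TF_Tensor* out = TF_AllocateOutput(ctx, /*index=*/0, TF_FLOAT, &dim,
                                     /*num_dims=*/1,
                                     /*len=*/TF_DataTypeSize(TF_FLOAT));
  *static_cast<float*>(TF_TensorData(out)) = 1.0f;  // CPU-only sketch

  TF_Status* s = TF_NewStatus();
  TF_SetOutput(ctx, 0, out, s);
  TF_DeleteStatus(s);
  TF_DeleteTensor(out);  // the caller owns the allocated tensor
}

// Registration, following the pattern used elsewhere in kernels_test.cc.
static void RegisterFillOneKernel() {
  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      "FillOne", "CPU", /*create_func=*/nullptr, &FillOneCompute,
      /*delete_func=*/nullptr);
  TF_Status* status = TF_NewStatus();
  TF_RegisterKernelBuilder("FillOneKernel", builder, status);
  TF_DeleteStatus(status);
}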
==============================================================================*/ +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#define EIGEN_USE_GPU +#endif #include "tensorflow/c/kernels.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/c/c_api.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/kernel_def.pb.h" -#include "tensorflow/core/framework/node_def.pb_text.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -309,4 +315,144 @@ TEST(TestKernel, TestHostMemory) { TF_DeleteKernelBuilder(builder); ASSERT_TRUE(delete_called); } + +class DeviceKernelOpTest : public OpsTestBase { + protected: + void SetupOp(const char* op_name, const char* kernel_name, + void (*compute_func)(void*, TF_OpKernelContext*)) { + TF_KernelBuilder* builder = TF_NewKernelBuilder( + op_name, device_name_, nullptr, compute_func, nullptr); + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(kernel_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + +#if GOOGLE_CUDA + std::unique_ptr device( + DeviceFactory::NewDevice(device_name_, {}, "/job:a/replica:0/task:0")); + OpsTestBase::SetDevice(DEVICE_GPU, std::move(device)); +#endif + TF_ASSERT_OK(NodeDefBuilder(op_name, op_name).Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + } + +#if GOOGLE_CUDA + const char* device_name_ = tensorflow::DEVICE_GPU; +#else + const char* device_name_ = tensorflow::DEVICE_CPU; +#endif +}; + +REGISTER_OP("AllocateOutputOp1").Output("output1: float"); + +TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) { + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + // Allocate output + int64_t dim = 1; + size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT); + TF_Tensor* output = TF_AllocateOutput( + /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim, + /*num_dims=*/1, /*len=*/tensor_size_bytes); + EXPECT_EQ(TF_FLOAT, TF_TensorType(output)); + EXPECT_EQ(1, TF_NumDims(output)); + EXPECT_EQ(1, TF_Dim(output, 0)); + + // Set output to 3 + float* data = reinterpret_cast(TF_TensorData(output)); + float value = 3.0f; +#if GOOGLE_CUDA + OpKernelContext* cc_ctx = reinterpret_cast(ctx); + cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, &value, + tensor_size_bytes); +#else + *data = value; +#endif + + TF_Status* s = TF_NewStatus(); + TF_SetOutput(ctx, 0, output, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + + TF_DeleteStatus(s); + TF_DeleteTensor(output); + }; + + SetupOp("AllocateOutputOp1", "AllocateOutput1", my_compute_func); + + TF_ASSERT_OK(RunOpKernel()); + Tensor* output = GetOutput(0); + EXPECT_EQ("Tensor", + output->DebugString(100)); +} + +REGISTER_OP("AllocateOutputOp0").Output("output1: float"); + +TEST_F(DeviceKernelOpTest, TestAllocateEmptyOutput) { + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + // Allocate empty output + int64_t dim = 0; + TF_Tensor* output = TF_AllocateOutput( + /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim, + /*num_dims=*/1, /*len=*/0); + + EXPECT_EQ(TF_FLOAT, TF_TensorType(output)); + EXPECT_EQ(1, TF_NumDims(output)); + EXPECT_EQ(0, 
TF_Dim(output, 0)); + + TF_Status* s = TF_NewStatus(); + TF_SetOutput(ctx, 0, output, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + + TF_DeleteStatus(s); + TF_DeleteTensor(output); + }; + + SetupOp("AllocateOutputOp0", "AllocateOutput0", my_compute_func); + + TF_ASSERT_OK(RunOpKernel()); + Tensor* output = GetOutput(0); + EXPECT_EQ("Tensor", + output->DebugString(100)); +} + +REGISTER_OP("AllocateOutputOp2x3").Output("output1: float"); + +TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) { + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + // Allocate 2x3 output + int64_t dim[2] = {2, 3}; + size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT); + TF_Tensor* output = TF_AllocateOutput( + /*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim, + /*num_dims=*/2, /*len=*/tensor_size_bytes); + EXPECT_EQ(TF_FLOAT, TF_TensorType(output)); + EXPECT_EQ(2, TF_NumDims(output)); + EXPECT_EQ(2, TF_Dim(output, 0)); + EXPECT_EQ(3, TF_Dim(output, 1)); + + // Set output to [1 2 3 4 5 6] + void* data = TF_TensorData(output); + float value[6] = {1, 2, 3, 4, 5, 6}; +#if GOOGLE_CUDA + OpKernelContext* cc_ctx = reinterpret_cast(ctx); + cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, value, + tensor_size_bytes); +#else + memcpy(data, value, tensor_size_bytes); +#endif + + TF_Status* s = TF_NewStatus(); + TF_SetOutput(ctx, 0, output, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + + TF_DeleteStatus(s); + TF_DeleteTensor(output); + }; + + SetupOp("AllocateOutputOp2x3", "AllocateOutput2x3", my_compute_func); + + TF_ASSERT_OK(RunOpKernel()); + Tensor* output = GetOutput(0); + EXPECT_EQ("Tensor", + output->DebugString(100)); +} } // namespace tensorflow diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index deb36166a47..2ad778d6057 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -31,6 +31,37 @@ using tensorflow::TensorBuffer; using tensorflow::errors::FailedPrecondition; using tensorflow::errors::InvalidArgument; +namespace tensorflow { +void* allocate_tensor(const char* operation, size_t len, Allocator* allocator) { + void* data = allocator->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len); + if (LogMemory::IsEnabled() && data != nullptr) { + LogMemory::RecordRawAllocation( + operation, LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, len, data, + allocator); + } + return data; +} + +void* allocate_tensor(const char* operation, size_t len) { + return allocate_tensor(operation, len, cpu_allocator()); +} + +void deallocate_buffer(void* data, size_t len, void* arg) { + Allocator* allocator = nullptr; + if (arg == nullptr) { + allocator = cpu_allocator(); + } else { + allocator = reinterpret_cast(arg); + } + if (LogMemory::IsEnabled() && data != nullptr) { + LogMemory::RecordRawDeallocation( + "TensorFlow C Api", LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data, + allocator, false); + } + allocator->DeallocateRaw(data); +} +} // namespace tensorflow + namespace { class TF_ManagedBuffer : public TensorBuffer { public: @@ -63,36 +94,15 @@ class TF_ManagedBuffer : public TensorBuffer { bool OwnsMemory() const override { return false; } }; -void* allocate_tensor(const char* operation, size_t len) { - void* data = - tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len); - if (tensorflow::LogMemory::IsEnabled() && data != nullptr) { - tensorflow::LogMemory::RecordRawAllocation( - operation, tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, - len, data, tensorflow::cpu_allocator()); - } - return data; -} - -void deallocate_buffer(void* data, 
size_t len, void* arg) { - if (tensorflow::LogMemory::IsEnabled() && data != nullptr) { - tensorflow::LogMemory::RecordRawDeallocation( - "TensorFlow C Api", - tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data, - tensorflow::cpu_allocator(), false); - } - tensorflow::cpu_allocator()->DeallocateRaw(data); -} - } // namespace -TF_Tensor::~TF_Tensor() { buffer->Unref(); } - TF_Tensor* TF_AllocateTensor(TF_DataType dtype, const int64_t* dims, int num_dims, size_t len) { - void* data = allocate_tensor("TF_AllocateTensor", len); - return TF_NewTensor(dtype, dims, num_dims, data, len, deallocate_buffer, - nullptr); + void* data = tensorflow::allocate_tensor("TF_AllocateTensor", len, + tensorflow::cpu_allocator()); + return TF_NewTensor(dtype, dims, num_dims, data, len, + tensorflow::deallocate_buffer, + tensorflow::cpu_allocator()); } TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, @@ -117,8 +127,8 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, // // Other types have the same representation, so copy only if it is safe to // do so. - buf = new TF_ManagedBuffer(allocate_tensor("TF_NewTensor", len), len, - deallocate_buffer, nullptr); + buf = new TF_ManagedBuffer(tensorflow::allocate_tensor("TF_NewTensor", len), + len, tensorflow::deallocate_buffer, nullptr); std::memcpy(buf->data(), data, len); // Free the original buffer. deallocator(data, len, deallocator_arg); @@ -126,9 +136,12 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, buf = new TF_ManagedBuffer(data, len, deallocator, deallocator_arg); } - TF_Tensor* ret = new TF_Tensor{dtype, tensorflow::TensorShape(dimvec), buf}; + TF_Tensor* ret = + new TF_Tensor{Tensor(static_cast(dtype), + tensorflow::TensorShape(dimvec), buf)}; + buf->Unref(); size_t elem_size = TF_DataTypeSize(dtype); - if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) { + if (elem_size > 0 && len < (elem_size * ret->tensor.NumElements())) { delete ret; return nullptr; } @@ -139,7 +152,7 @@ TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) { // It is safe to move the Tensor if and only if we own the unique reference to // it. In that case, we might as well not delete and reallocate, but a future // implementation might need to do so. 
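The new TF_AllocateOutput entry point added to tensorflow/c/kernels.h above is meant to be called from a kernel's compute callback. Below is a minimal, hedged sketch of that usage for a hypothetical "MyFill" CPU kernel; the names and shapes are illustrative and error handling is trimmed.

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_tensor.h"

// Sketch only (not part of the diff): allocate output 0, fill it, publish it.
static void MyFillCompute(void* kernel, TF_OpKernelContext* ctx) {
  int64_t dim = 4;
  size_t len = 4 * TF_DataTypeSize(TF_FLOAT);
  TF_Tensor* out = TF_AllocateOutput(ctx, /*index=*/0, TF_FLOAT, &dim,
                                     /*num_dims=*/1, len);
  float* data = static_cast<float*>(TF_TensorData(out));
  for (int i = 0; i < 4; ++i) data[i] = 1.0f;  // fill with ones (CPU only)
  TF_Status* s = TF_NewStatus();
  TF_SetOutput(ctx, 0, out, s);  // publish the tensor as output 0
  TF_DeleteStatus(s);
  TF_DeleteTensor(out);          // the caller owns the returned TF_Tensor
}

On a GPU device the buffer returned by TF_TensorData lives in device memory, so it must be written with a device copy, as the memcpyHostToDevice calls in kernels_test.cc above do, rather than through a host pointer.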
- TensorBuffer* buf = tensor->buffer; + TensorBuffer* buf = tensorflow::TensorCApi::Buffer(tensor->tensor); if (buf->RefCountIsOne() && buf->root_buffer()->RefCountIsOne() && buf->OwnsMemory()) { return tensor; @@ -149,13 +162,23 @@ TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) { void TF_DeleteTensor(TF_Tensor* t) { delete t; } -TF_DataType TF_TensorType(const TF_Tensor* t) { return t->dtype; } -int TF_NumDims(const TF_Tensor* t) { return t->shape.dims(); } -int64_t TF_Dim(const TF_Tensor* t, int dim_index) { - return static_cast(t->shape.dim_size(dim_index)); +TF_DataType TF_TensorType(const TF_Tensor* t) { + return static_cast(t->tensor.dtype()); +} + +int TF_NumDims(const TF_Tensor* t) { return t->tensor.dims(); } + +int64_t TF_Dim(const TF_Tensor* t, int dim_index) { + return static_cast(t->tensor.dim_size(dim_index)); +} + +size_t TF_TensorByteSize(const TF_Tensor* t) { + return tensorflow::TensorCApi::Buffer(t->tensor)->size(); +} + +void* TF_TensorData(const TF_Tensor* t) { + return tensorflow::TensorCApi::Buffer(t->tensor)->data(); } -size_t TF_TensorByteSize(const TF_Tensor* t) { return t->buffer->size(); } -void* TF_TensorData(const TF_Tensor* t) { return t->buffer->data(); } int64_t TF_TensorElementCount(const TF_Tensor* t) { int64_t result = 1; @@ -166,63 +189,17 @@ int64_t TF_TensorElementCount(const TF_Tensor* t) { return result; } -// Returns the number of elements that would be present in a tensor with the -// given shape. -static int64_t ShapeNumElements(const int64_t* dims, int num_dims) { - int64_t result = 1; - for (int dim = 0; dim < num_dims; ++dim) { - result *= dims[dim]; - } - return result; -} - -static void UnrefIfNonNull(::tensorflow::TensorBuffer* buf) { - if (buf != nullptr) { - buf->Unref(); - } -} - -static void RefIfNonNull(::tensorflow::TensorBuffer* buf) { - if (buf != nullptr) { - buf->Ref(); - } -} - void TF_TensorBitcastFrom(const TF_Tensor* from, TF_DataType type, TF_Tensor* to, const int64_t* new_dims, int num_new_dims, TF_Status* status) { TF_SetStatus(status, TF_OK, ""); - size_t in_size = TF_DataTypeSize(TF_TensorType(from)); - if (in_size == 0) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - "input tensor has a zero-sized data type"); - return; - } - size_t out_size = TF_DataTypeSize(type); - if (out_size == 0) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - "output tensor has a zero-sized data type"); - return; - } - - if (ShapeNumElements(new_dims, num_new_dims) * out_size != - TF_TensorElementCount(from) * in_size) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - "input tensor is not compatible with output shape"); - return; - } - - tensorflow::TensorShapeProto p; + tensorflow::TensorShape s; for (int i = 0; i < num_new_dims; ++i) { - p.add_dim()->set_size(new_dims[i]); - } - to->shape = tensorflow::TensorShape(p); - to->dtype = type; - if (to->buffer != from->buffer) { - UnrefIfNonNull(to->buffer); - to->buffer = from->buffer; - RefIfNonNull(to->buffer); + s.AddDim(new_dims[i]); } + Status cc_status(to->tensor.BitcastFrom( + from->tensor, static_cast(type), s)); + Set_TF_Status_from_Status(status, cc_status); } // -------------------------------------------------------------------------- @@ -332,17 +309,19 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, return t; } if (src.dtype() != tensorflow::DT_STRING) { - TensorBuffer* buf = tensorflow::TensorCApi::Buffer(src); - buf->Ref(); - return new TF_Tensor{static_cast(src.dtype()), src.shape(), - buf}; + auto* result = new TF_Tensor(); + if (!result->tensor.CopyFrom(src, 
src.shape())) { + delete result; + return nullptr; + } + return result; } // DT_STRING tensors require a copying since TF_Tensor.buffer expects a flatly // encoded sequence of strings. // Compute bytes needed for encoding. size_t size = 0; - const auto& srcarray = src.flat(); + const auto& srcarray = src.flat(); for (int i = 0; i < srcarray.size(); ++i) { const string& s = srcarray(i); // uint64 starting_offset, TF_StringEncode-d string. @@ -393,14 +372,14 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, } Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { - if (src->dtype == TF_RESOURCE) { - if (src->shape.dims() != 0) { + if (src->tensor.dtype() == DT_RESOURCE) { + if (src->tensor.dims() != 0) { return InvalidArgument( "Malformed TF_RESOURCE tensor: expected a scalar, got a tensor with " "shape ", - src->shape.DebugString()); + src->tensor.shape().DebugString()); } - *dst = Tensor(tensorflow::DT_RESOURCE, src->shape); + *dst = Tensor(tensorflow::DT_RESOURCE, src->tensor.shape()); if (!dst->scalar()().ParseFromString( string(static_cast(TF_TensorData(src)), TF_TensorByteSize(src)))) { @@ -409,14 +388,13 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { } return Status::OK(); } - if (src->dtype != TF_STRING) { - *dst = - tensorflow::TensorCApi::MakeTensor(src->dtype, src->shape, src->buffer); + if (src->tensor.dtype() != DT_STRING) { + *dst = src->tensor; return Status::OK(); } // TF_STRING tensors require copying since Tensor class expects a sequence of // string objects. - const tensorflow::int64 num_elements = src->shape.num_elements(); + const tensorflow::int64 num_elements = src->tensor.NumElements(); const char* input = reinterpret_cast(TF_TensorData(src)); const size_t src_size = TF_TensorByteSize(src); if (static_cast(src_size / sizeof(tensorflow::uint64)) < @@ -427,8 +405,8 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { const char* data_start = input + sizeof(tensorflow::uint64) * num_elements; const char* limit = input + src_size; - *dst = Tensor(static_cast(src->dtype), src->shape); - auto dstarray = dst->flat(); + *dst = Tensor(src->tensor.dtype(), src->tensor.shape()); + auto dstarray = dst->flat(); for (tensorflow::int64 i = 0; i < num_elements; ++i) { tensorflow::uint64 offset = reinterpret_cast(input)[i]; @@ -447,3 +425,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { } } // namespace tensorflow + +bool TF_TensorIsAligned(const TF_Tensor* tensor) { + return tensor->tensor.IsAligned(); +} diff --git a/tensorflow/c/tf_tensor.h b/tensorflow/c/tf_tensor.h index 5d4f70c1b6b..462fdc8b497 100644 --- a/tensorflow/c/tf_tensor.h +++ b/tensorflow/c/tf_tensor.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_C_TF_TENSOR_H_ #define TENSORFLOW_C_TF_TENSOR_H_ +#include #include #include "tensorflow/c/tf_datatype.h" @@ -175,6 +176,9 @@ TF_CAPI_EXPORT extern size_t TF_StringDecode(const char* src, size_t src_len, // TF_STRING tensor. TF_CAPI_EXPORT extern size_t TF_StringEncodedSize(size_t len); +// Returns bool iff this tensor is aligned. +TF_CAPI_EXPORT extern bool TF_TensorIsAligned(const TF_Tensor*); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/tf_tensor_internal.h b/tensorflow/c/tf_tensor_internal.h index 6def66c9412..ea7d49b5966 100644 --- a/tensorflow/c/tf_tensor_internal.h +++ b/tensorflow/c/tf_tensor_internal.h @@ -23,13 +23,12 @@ limitations under the License. // Internal structures used by the C API. 
These are likely to change and should // not be depended on. -struct TF_Tensor { - ~TF_Tensor(); - - TF_DataType dtype; - tensorflow::TensorShape shape; - tensorflow::TensorBuffer* buffer; -}; +// This struct forms part of the C API's public interface. It must strictly be +// passed to or returned from C functions *by pointer*. Otherwise, changes to +// its internal structure will break the C API's binary interface. +typedef struct TF_Tensor { + ::tensorflow::Tensor tensor; +} TF_Tensor; namespace tensorflow { @@ -42,5 +41,13 @@ class TensorCApi { } }; +// Allocates tensor data buffer using specified allocator. +// `operation` is a name for this operation. +void* allocate_tensor(const char* operation, size_t len, Allocator* allocator); + +// Deallocates tensor data buffer. +// Defaults to deallocating using CPU allocator. You can pass pointer to +// a different Allocator as `arg`. +void deallocate_buffer(void* data, size_t len, void* arg); } // namespace tensorflow #endif // TENSORFLOW_C_TF_TENSOR_INTERNAL_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 07de89f997e..40b182c8acf 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -649,7 +649,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", - "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", ], @@ -667,7 +666,6 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", - "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index a0353bf17a6..919e2dfc638 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/framework/types.pb_text.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/hash/hash.h" @@ -193,12 +193,12 @@ string PrintTensor(const TensorProto& tensor_proto) { string ret; for (int64 i = 0; i < num_elts; ++i) { if (i > 0) strings::StrAppend(&ret, " "); - strings::StrAppend(&ret, absl::CEscape(t.flat()(i))); + strings::StrAppend(&ret, absl::CEscape(t.flat()(i))); } return ret; } default: { - LOG(FATAL) << "Not handling type " << EnumName_DataType(t.dtype()); + LOG(FATAL) << "Not handling type " << DataType_Name(t.dtype()); return string(); } } @@ -223,7 +223,7 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) { case AttrValue::kB: return attr_value.b() ? 
"true" : "false"; case AttrValue::kType: - return EnumName_DataType(attr_value.type()); + return DataType_Name(attr_value.type()); case AttrValue::kShape: return PrintTensorShape(attr_value.shape()); case AttrValue::kTensor: @@ -254,8 +254,7 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) { } else if (attr_value.list().type_size() > 0) { for (int i = 0; i < attr_value.list().type_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); - strings::StrAppend(&ret, - EnumName_DataType(attr_value.list().type(i))); + strings::StrAppend(&ret, DataType_Name(attr_value.list().type(i))); } } else if (attr_value.list().shape_size() > 0) { for (int i = 0; i < attr_value.list().shape_size(); ++i) { diff --git a/tensorflow/cc/framework/cc_ops_test.cc b/tensorflow/cc/framework/cc_ops_test.cc index ac05e3cf95b..178b4da972a 100644 --- a/tensorflow/cc/framework/cc_ops_test.cc +++ b/tensorflow/cc/framework/cc_ops_test.cc @@ -200,10 +200,10 @@ TEST(CCOpTest, TemplatedConst) { test::ExpectTensorEqual( out, test::AsTensor({3.f, 2.f, -1.f, 0.f}, {2, 2})); - auto c2 = ops::Const(root, {{"this"}, {"is"}, {"a"}, {"constant"}}); + auto c2 = ops::Const(root, {{"this"}, {"is"}, {"a"}, {"constant"}}); test::GetTensor(root, c2, &out); - test::ExpectTensorEqual( - out, test::AsTensor({"this", "is", "a", "constant"}, {4, 1})); + test::ExpectTensorEqual( + out, test::AsTensor({"this", "is", "a", "constant"}, {4, 1})); } TEST(CCOpTest, EmptyConst) { diff --git a/tensorflow/cc/framework/ops.cc b/tensorflow/cc/framework/ops.cc index 920a8e79556..8516dfd7a29 100644 --- a/tensorflow/cc/framework/ops.cc +++ b/tensorflow/cc/framework/ops.cc @@ -97,7 +97,7 @@ Input::Initializer::Initializer( Tensor elem = e.tensor; if (first.tensor.dtype() == DT_STRING) { for (int i = 0; i < elem.NumElements(); ++i) { - t.flat()(offset + i) = elem.flat()(i); + t.flat()(offset + i) = elem.flat()(i); } offset += elem.NumElements(); } else { diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index 0717e7dd4b3..1414e861002 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -111,7 +111,7 @@ class Input { Initializer(const T& v) { // NOLINT(runtime/explicit) typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), TensorShape()); - t.flat()(0) = RealT(v); + t.flat()(0) = RealT(v); tensor = t; } @@ -125,7 +125,7 @@ class Input { typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), shape); for (int64 i = 0; i < t.NumElements(); ++i) { - t.flat()(i) = RealT(v); + t.flat()(i) = RealT(v); } tensor = t; } @@ -170,7 +170,7 @@ class Input { // START_SKIP_DOXYGEN template ::value> struct RealType { - typedef string type; + typedef tstring type; }; template diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index e93ca8633e6..b5cac5fec28 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -272,7 +272,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( std::unordered_set current_constraints(colocation_constraints_); const AttrSlice attrs = colocate_with_op.node()->attrs(); std::vector node_constraints; - if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) { + if (TryGetNodeAttr(attrs, kColocationAttrName, &node_constraints)) { for (const string& entry : node_constraints) { StringPiece s(entry); if (absl::ConsumePrefix(&s, kColocationGroupPrefix)) { @@ -299,7 +299,7 @@ const std::vector& Scope::control_deps() const { return impl()->control_deps_; } -void 
Scope::UpdateStatus(const Status s) const { +void Scope::UpdateStatus(const Status& s) const { impl()->status_->Update(s); if (impl()->exit_on_error_ && !ok()) { LOG(FATAL) << *impl()->status_; @@ -318,7 +318,7 @@ Status Scope::ToGraph(Graph* g, GraphConstructorOptions opts) const { if (ok()) { GraphDef graph_def; graph()->ToGraphDef(&graph_def); - UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g)); + UpdateStatus(ConvertGraphDefToGraph(opts, std::move(graph_def), g)); } return *impl()->status_; } diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index ef2daff1357..63a555b7217 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -177,7 +177,7 @@ class Scope { /// Note: The status object is shared between all children of this scope. /// If the resulting status is not Status::OK() and exit_on_error_ is set on /// this scope, this function exits by calling LOG(FATAL). - void UpdateStatus(const Status s) const; + void UpdateStatus(const Status& s) const; // START_SKIP_DOXYGEN diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc index 69b5d7fd47c..345cd23b9ec 100644 --- a/tensorflow/cc/ops/const_op_test.cc +++ b/tensorflow/cc/ops/const_op_test.cc @@ -97,7 +97,7 @@ TEST(ConstOpTest, WithExplicitShape) { auto d = ops::Const(root, {"1", "2", "3", "4", "5", "6"}, {2, 3}); TF_CHECK_OK(root.status()); EXPECT_EQ(d.op().output_type(0), DT_STRING); - ExpectNodeEqual(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3}); + ExpectNodeEqual(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3}); } TEST(ConstOpTest, FromProto) { @@ -144,7 +144,7 @@ TEST(ConstOpTest, TemplatedConst) { auto c1 = ops::Const(root, {1, 2}); ExpectTypeAndShape(c1.node(), DT_INT32, {2}); - auto c2 = ops::Const(root, {{"this"}, {"is"}, {"a"}, {"constant"}}); + auto c2 = ops::Const(root, {{"this"}, {"is"}, {"a"}, {"constant"}}); ExpectTypeAndShape(c2.node(), DT_STRING, {4, 1}); } diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD index d18a0bcab0c..5b4a105eb28 100644 --- a/tensorflow/cc/profiler/BUILD +++ b/tensorflow/cc/profiler/BUILD @@ -9,6 +9,7 @@ tf_cuda_cc_test( name = "profiler_test", srcs = ["profiler_test.cc"], tags = [ + "no_rocm", # stream level tracing not supported on ROCm "nogpu", # b/77649654 ], deps = [ diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 01752b65f2f..39b84922d13 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -10,7 +10,7 @@ load( "tf_cc_test", ) load( - "//tensorflow/core:platform/default/build_config_root.bzl", + "//tensorflow/core/platform:default/build_config_root.bzl", "if_static", "if_static_and_not_mobile", ) diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index dfc7ccd9542..a3b80fbdba5 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -75,7 +75,7 @@ Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, Tensor CreateStringTensor(const string& value) { Tensor tensor(DT_STRING, TensorShape({})); - tensor.scalar()() = value; + tensor.scalar()() = value; return tensor; } @@ -219,7 +219,7 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, // Add variables to the graph. 
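A recurring change in the hunks around here is the switch of tensor element accessors from std::string to tensorflow::tstring. A minimal sketch of the new pattern, assuming only the standard Tensor API; the helper names are illustrative and not part of the diff.

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
// Sketch only: DT_STRING elements are now read and written as tstring.
Tensor MakeScalarStringTensor(const string& value) {
  Tensor t(DT_STRING, TensorShape({}));
  t.scalar<tstring>()() = value;  // write through the tstring accessor
  return t;
}

string FirstElement(const Tensor& t) {
  return string(t.flat<tstring>()(0));  // read back and convert to string
}
}  // namespace tensorflow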
Tensor variables_path_tensor(DT_STRING, TensorShape({})); - variables_path_tensor.scalar()() = variables_path; + variables_path_tensor.scalar()() = variables_path; std::vector> inputs = { {string(variable_filename_const_op_name), variables_path_tensor}}; diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 422994ba07c..aa2031d17d2 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -63,8 +63,8 @@ class LoaderTest : public ::testing::Test { bundle.session->Run({}, {"filename_tensor:0"}, {}, &path_outputs)); ASSERT_EQ(1, path_outputs.size()); - test::ExpectTensorEqual( - test::AsTensor({"foo.txt"}, TensorShape({})), path_outputs[0]); + test::ExpectTensorEqual( + test::AsTensor({"foo.txt"}, TensorShape({})), path_outputs[0]); } void CheckSavedModelBundle(const string& export_dir, @@ -78,14 +78,14 @@ class LoaderTest : public ::testing::Test { const string output_name = signature_def.outputs().at(kRegressOutputs).name(); - std::vector serialized_examples; + std::vector serialized_examples; for (float x : {0, 1, 2, 3}) { serialized_examples.push_back(MakeSerializedExample(x)); } // Validate the half plus two behavior. Tensor input = - test::AsTensor(serialized_examples, TensorShape({4})); + test::AsTensor(serialized_examples, TensorShape({4})); std::vector outputs; TF_ASSERT_OK(bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs)); diff --git a/tensorflow/cc/saved_model/python/BUILD b/tensorflow/cc/saved_model/python/BUILD index fca45c869fd..b1440655c72 100644 --- a/tensorflow/cc/saved_model/python/BUILD +++ b/tensorflow/cc/saved_model/python/BUILD @@ -1,7 +1,7 @@ # Description: # CLIF wrappers for TensorFlow SavedModels. -load("//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc") +load("//tensorflow/core/platform:default/build_config.bzl", "tf_py_clif_cc") package( default_visibility = ["//visibility:public"], diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc index 799856f7fd4..d6d99229372 100644 --- a/tensorflow/cc/saved_model/reader.cc +++ b/tensorflow/cc/saved_model/reader.cc @@ -48,12 +48,12 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { export_dir); } -Status FindMetaGraphDef(const SavedModel& saved_model_proto, - const std::unordered_set& tags, +Status FindMetaGraphDef(const std::unordered_set& tags, + SavedModel* saved_model_proto, MetaGraphDef* meta_graph_def) { LOG(INFO) << "Reading meta graph with tags { " << absl::StrJoin(tags, " ") << " }"; - for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) { + for (MetaGraphDef& graph_def : *saved_model_proto->mutable_meta_graphs()) { // Get tags from the graph_def. std::unordered_set graph_tags; for (const string& tag : graph_def.meta_info_def().tags()) { @@ -61,7 +61,7 @@ Status FindMetaGraphDef(const SavedModel& saved_model_proto, } // Match with the set of tags provided. 
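The reworked FindMetaGraphDef above now receives the SavedModel proto through a mutable pointer so the matching MetaGraphDef can be moved out rather than deep-copied, which matters for large graphs. Callers of the public API are unaffected; a hedged sketch (the wrapper function here is illustrative, not from the diff):

#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

// Sketch only (not part of the diff): load the MetaGraphDef tagged "serve".
tensorflow::Status LoadServeMetaGraph(const std::string& export_dir,
                                      tensorflow::MetaGraphDef* meta_graph_def) {
  return tensorflow::ReadMetaGraphDefFromSavedModel(
      export_dir, {tensorflow::kSavedModelTagServe}, meta_graph_def);
}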
if (graph_tags == tags) { - *meta_graph_def = graph_def; + *meta_graph_def = std::move(graph_def); return Status::OK(); } } @@ -81,7 +81,8 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir, MetaGraphDef* const meta_graph_def) { SavedModel saved_model_proto; TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); - TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def)); + TF_RETURN_IF_ERROR( + FindMetaGraphDef(tags, &saved_model_proto, meta_graph_def)); return Status::OK(); } diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index eeb91017890..0ec48ec9357 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -42,6 +42,10 @@ void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info, tensor_names->insert(coo_sparse.values_tensor_name()); tensor_names->insert(coo_sparse.indices_tensor_name()); tensor_names->insert(coo_sparse.dense_shape_tensor_name()); + } else if (tensor_info.has_composite_tensor()) { + for (const auto& component : tensor_info.composite_tensor().components()) { + tensor_names->insert(component.name()); + } } else { tensor_names->insert(tensor_info.name()); } diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index 979b23c3fc5..274a1630a05 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -425,5 +425,63 @@ TEST_F(FreezeTest, GraphDefWithAndWithoutDependentResourceVariables) { TestFreezeGraphWithAndWithoutDependentVariables(true); } +TEST_F(FreezeTest, InputsAndOutputsCompositeTensorSignatureDef) { + // Test that inputs and outputs get correctly populated for a + // SignatureDef containing composite tensor inputs and outputs. + SavedModelBundle saved_model_bundle; + SignatureDef signature_def; + + TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"]; + in.mutable_composite_tensor()->add_components()->set_name("input1:0"); + in.mutable_composite_tensor()->add_components()->set_name("input2:0"); + + TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"]; + out.mutable_composite_tensor()->add_components()->set_name("output2:0"); + out.mutable_composite_tensor()->add_components()->set_name("output1:0"); + + AddSignatureDefToSavedModelBundle(signature_def, "signature_def", + &saved_model_bundle); + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + std::unordered_set expected_inputs = {"input1:0", "input2:0"}; + std::unordered_set expected_outputs = {"output1:0", "output2:0"}; + EXPECT_EQ(expected_inputs, inputs); + EXPECT_EQ(expected_outputs, outputs); +} + +TEST_F(FreezeTest, InputsAndOutputsSparseCooSignatureDef) { + // Test that inputs and outputs get correctly populated for a + // SignatureDef containing composite tensor inputs and outputs. 
+ SavedModelBundle saved_model_bundle; + SignatureDef signature_def; + + TensorInfo& in = (*signature_def.mutable_inputs())["input_arg"]; + in.mutable_coo_sparse()->set_values_tensor_name("input1:0"); + in.mutable_coo_sparse()->set_indices_tensor_name("input2:0"); + in.mutable_coo_sparse()->set_dense_shape_tensor_name("input3:0"); + + TensorInfo& out = (*signature_def.mutable_outputs())["output_arg"]; + out.mutable_coo_sparse()->set_values_tensor_name("output1:0"); + out.mutable_coo_sparse()->set_indices_tensor_name("output2:0"); + out.mutable_coo_sparse()->set_dense_shape_tensor_name("output3:0"); + + AddSignatureDefToSavedModelBundle(signature_def, "signature_def", + &saved_model_bundle); + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + std::unordered_set expected_inputs = {"input1:0", "input2:0", + "input3:0"}; + std::unordered_set expected_outputs = {"output1:0", "output2:0", + "output3:0"}; + EXPECT_EQ(expected_inputs, inputs); + EXPECT_EQ(expected_outputs, outputs); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 88b00cb2eea..bff56bdda89 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,7 +1,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library", "tf_jit_compilation_passes_extra_deps") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library") +load("//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library") package( default_visibility = [ @@ -144,8 +144,57 @@ cc_library( ], ) +XLA_DEVICE_DEPS = [ + ":common", + ":xla_launch_util", + ":xla_tensor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:optional", + "//tensorflow/compiler/jit/ops:xla_ops", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:tf2xla_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:stream_pool", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:control_flow_ops_op_lib", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:functional_ops_op_lib", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:no_op_op_lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:sendrecv_ops_op_lib", + "//tensorflow/core:state_ops_op_lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/kernels:constant_op", + "//tensorflow/core/kernels:fifo_queue", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:identity_op", + "//tensorflow/core/kernels:resource_variable_ops", + "//tensorflow/core/kernels:shape_ops", + "//tensorflow/core/kernels:variable_ops", + 
"//tensorflow/core/kernels/data:generator_dataset_op", + "//tensorflow/core/kernels/data:iterator_ops", + "//tensorflow/core/kernels/data:optional_ops", + "//tensorflow/core/kernels/data:prefetch_dataset_op", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor/platform", +] + cc_library( - name = "xla_device", + name = "xla_device_no_jit_rewrite_registration", srcs = [ "xla_compile_on_demand_op.cc", "xla_device.cc", @@ -158,56 +207,22 @@ cc_library( "xla_device_context.h", "xla_device_ops.h", ], + deps = XLA_DEVICE_DEPS, +) + +cc_library( + name = "xla_device", + hdrs = [ + "xla_compile_on_demand_op.h", + "xla_device.h", + "xla_device_context.h", + "xla_device_ops.h", + ], # Public visibility is needed for external TF/XLA backends. visibility = ["//visibility:public"], - deps = [ - ":common", + deps = XLA_DEVICE_DEPS + [ ":jit_compilation_passes", - ":xla_launch_util", - ":xla_tensor", - "//tensorflow/compiler/jit/ops:xla_ops", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:tf2xla_util", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:stream_pool", - "//tensorflow/core:array_ops_op_lib", - "//tensorflow/core:control_flow_ops_op_lib", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:dataset_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:functional_ops_op_lib", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:math_ops_op_lib", - "//tensorflow/core:nn_ops_op_lib", - "//tensorflow/core:no_op_op_lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:resource_variable_ops_op_lib", - "//tensorflow/core:sendrecv_ops_op_lib", - "//tensorflow/core:state_ops_op_lib", - "//tensorflow/core:stream_executor_no_cuda", - "//tensorflow/core/kernels:constant_op", - "//tensorflow/core/kernels:fifo_queue", - "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:identity_op", - "//tensorflow/core/kernels:resource_variable_ops", - "//tensorflow/core/kernels:shape_ops", - "//tensorflow/core/kernels:variable_ops", - "//tensorflow/core/kernels/data:generator_dataset_op", - "//tensorflow/core/kernels/data:iterator_ops", - "//tensorflow/core/kernels/data:optional_ops", - "//tensorflow/core/kernels/data:prefetch_dataset_op", - "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/stream_executor/platform", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:optional", + ":xla_device_no_jit_rewrite_registration", ], ) @@ -281,6 +296,7 @@ cc_library( hdrs = ["xla_compilation_cache.h"], deps = [ ":xla_activity_listener", + ":xla_activity_proto_cc", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", @@ -292,6 +308,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:logging", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -324,17 +342,21 @@ cc_library( alwayslink = 1, ) +# Linked by tensorflow core, without registration of jit compilation passes +# 
which is not necessary to create and run a XlaLocalLaunchBase kernel. +# Linking jit compilation passes could cause programs stuck right now (b/140069592). cc_library( - name = "xla_kernel_creator", + name = "xla_kernel_creator_util", srcs = [ - "xla_kernel_creator.cc", - "xla_kernel_creator.h", + "xla_kernel_creator_util.cc", ], + hdrs = ["xla_kernel_creator_util.h"], + visibility = ["//tensorflow/core/common_runtime/eager:__pkg__"], deps = [ ":common", ":compilability_check_util", ":compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_ops", + "//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -347,6 +369,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_kernel_creator", + srcs = [ + "xla_kernel_creator.cc", + "xla_kernel_creator.h", + ], + deps = [ + ":jit_compilation_passes", + ":xla_kernel_creator_util", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + tf_cc_test( name = "xla_kernel_creator_test", srcs = [ @@ -498,6 +537,7 @@ cc_library( srcs = [ "build_xla_ops_pass.cc", "clone_constants_for_better_clustering.cc", + "cluster_scoping_pass.cc", "deadness_analysis.cc", "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", @@ -513,6 +553,7 @@ cc_library( hdrs = [ "build_xla_ops_pass.h", "clone_constants_for_better_clustering.h", + "cluster_scoping_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", "encapsulate_xla_computations_pass.h", @@ -677,6 +718,7 @@ tf_cc_test( srcs = [ "build_xla_ops_pass_test.cc", "clone_constants_for_better_clustering_test.cc", + "cluster_scoping_pass_test.cc", "encapsulate_subgraphs_pass_test.cc", "encapsulate_xla_computations_pass_test.cc", "extract_outside_compilation_pass_test.cc", @@ -800,6 +842,8 @@ cc_library( ":flags", ":resource_operation_safety_analysis", ":union_find", + ":xla_activity_listener", + ":xla_activity_proto_cc", ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/tf2xla:resource_operation_table", @@ -837,6 +881,7 @@ tf_cc_test( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:ops", + "//tensorflow/core:protos_all_proto_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/memory", @@ -901,6 +946,7 @@ cc_library( srcs = ["xla_activity_logging_listener.cc"], deps = [ ":xla_activity_listener", + ":xla_activity_proto_cc", "//tensorflow/core:logger", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 1265ff9138a..61695d532d1 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -48,6 +48,19 @@ limitations under the License. namespace tensorflow { namespace { +struct DebuggingOpts { + // If true, insert Print nodes to print every output from an XLA cluster. + bool print_outputs; + + // If true, insert CheckNumerics nodes for every floating point typed input to + // an XLA cluster. + bool check_input_numerics; + + // If true, insert CheckNumerics nodes for every floating point typed output + // from an XLA cluster. 
+ bool check_output_numerics; +}; + void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) { std::vector out_edges(old_node->out_edges().begin(), old_node->out_edges().end()); @@ -78,7 +91,8 @@ Operation DataToControl(const Scope& scope, Output data) { // Replaces each outgoing edge from `old_node` with a merge node that merges in // the corresponding output from `new_node`. void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node, - bool insert_print_nodes) { + absl::string_view cluster_name, + const DebuggingOpts& debugging_opts) { if (!s.status().ok()) { return; } @@ -93,23 +107,36 @@ void MergeOutgoingDataEdges(const Scope& s, Node* old_node, Node* new_node, int oidx = e->src_output(); Output merged_output = merged_outputs[oidx]; if (merged_output.node() == nullptr) { - ops::Merge merge_op(s.WithOpName(absl::StrCat("merge_oidx_", oidx)), - {Output(old_node, oidx), Output(new_node, oidx)}); - if (insert_print_nodes) { + Output new_output(new_node, oidx); + if (debugging_opts.print_outputs) { string cpu_device = "/job:localhost/replica:0/task:0/device:CPU:0"; - ops::Print print_op(s.WithOpName(absl::StrCat("print_", oidx)) + ops::Print print_op(s.WithOpName("print_", oidx) .WithDevice(cpu_device) .WithAssignedDevice(cpu_device), - merge_op.output, {merge_op.output}, + new_output, {new_output}, ops::Print::Attrs{} .Message(absl::StrCat("output ", oidx, " from ", old_node->name(), " is ")) .FirstN(1000) .Summarize(-1)); - merged_output = merged_outputs[oidx] = print_op; - } else { - merged_output = merged_outputs[oidx] = merge_op.output; + new_output = print_op; } + + if (debugging_opts.check_output_numerics && + DataTypeIsFloating(new_output.type())) { + ops::CheckNumerics check_numerics_op( + s.WithOpName("check_output_", oidx) + .WithDevice(new_node->requested_device()) + .WithAssignedDevice(new_node->assigned_device_name()), + new_output, + absl::StrCat("CheckNumerics failed for output ", oidx, "(", + new_output.name(), ") from cluster ", cluster_name)); + new_output = check_numerics_op; + } + + ops::Merge merge_op(s.WithOpName("merge_oidx_", oidx), + {Output(old_node, oidx), new_output}); + merged_output = merged_outputs[oidx] = merge_op.output; } Node* dst = e->dst(); @@ -324,11 +351,34 @@ xla::StatusOr InferDeviceForCluster( return result; } +std::vector GetXlaRunArgs(const Scope& s, + const XlaClusterInfo& cluster_info, + const DebuggingOpts& debugging_opts) { + std::vector xla_run_args; + xla_run_args.reserve(cluster_info.non_constant_inputs.size() + + cluster_info.resource_inputs.size()); + int input_idx = 0; + for (const Output& o : cluster_info.non_constant_inputs) { + if (debugging_opts.check_input_numerics && DataTypeIsFloating(o.type())) { + ops::CheckNumerics check_numerics_op( + s.WithOpName("check_input_", input_idx), o, + absl::StrCat("CheckNumerics failed for input ", input_idx, "(", + o.name(), ") into ", cluster_info.function.name())); + xla_run_args.push_back(check_numerics_op); + } else { + xla_run_args.push_back(o); + } + input_idx++; + } + absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args)); + return xla_run_args; +} + Status ReplaceNodeWithXlaCompileAndXlaRun( jit::DeviceInfoCache* device_info_cache, const GraphOptimizationPassOptions& options, const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled, - bool insert_print_nodes, Graph* g, Node* n) { + const DebuggingOpts& debugging_opts, Graph* g, Node* n) { XlaClusterInfo cluster_info; TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info)); @@ 
-361,12 +411,12 @@ Status ReplaceNodeWithXlaCompileAndXlaRun( TF_RETURN_IF_ERROR( CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node())); + std::vector xla_run_args = + GetXlaRunArgs(root, cluster_info, debugging_opts); + if (requires_compilation) { // "Strict" compilation: every _XlaCompile invocation must compile the // cluster. - std::vector xla_run_args = cluster_info.non_constant_inputs; - absl::c_copy(cluster_info.resource_inputs, - std::back_inserter(xla_run_args)); ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, xla_compile.key, n->output_types()); @@ -391,9 +441,6 @@ Status ReplaceNodeWithXlaCompileAndXlaRun( Output predicated_compilation_key = s.output_true; Output inverse_predicated_compilation_key = s.output_false; - std::vector xla_run_args = cluster_info.non_constant_inputs; - absl::c_copy(cluster_info.resource_inputs, - std::back_inserter(xla_run_args)); ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, predicated_compilation_key, n->output_types()); @@ -402,7 +449,7 @@ Status ReplaceNodeWithXlaCompileAndXlaRun( MergeOutgoingDataEdges(root, /*old_node=*/n, /*new_node=*/xla_run.operation.node(), - insert_print_nodes); + cluster_info.function.name(), debugging_opts); TF_RETURN_IF_ERROR(root.status()); @@ -443,15 +490,25 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { enable_lazy_compilation_ ? *enable_lazy_compilation_ : GetBuildXlaOpsPassFlags()->tf_xla_enable_lazy_compilation; - bool insert_print_nodes = - GetBuildXlaOpsPassFlags()->tf_xla_print_cluster_outputs; jit::DeviceInfoCache device_info_cache; + const BuildXlaOpsPassFlags& flags = *GetBuildXlaOpsPassFlags(); + + DebuggingOpts debugging_opts; + debugging_opts.print_outputs = flags.tf_xla_print_cluster_outputs; + debugging_opts.check_input_numerics = + flags.tf_xla_check_cluster_input_numerics; + debugging_opts.check_output_numerics = + flags.tf_xla_check_cluster_output_numerics; + + VLOG(1) << "print_outputs = " << debugging_opts.print_outputs; + VLOG(1) << "check_input_numerics = " << debugging_opts.check_input_numerics; + VLOG(1) << "check_output_numerics = " << debugging_opts.check_output_numerics; for (Node* n : xla_compiled_kernels) { TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun( &device_info_cache, options, *options.flib_def, - lazy_compilation_enabled, insert_print_nodes, graph, n)); + lazy_compilation_enabled, debugging_opts, graph, n)); } if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/jit/cluster_scoping_pass.cc b/tensorflow/compiler/jit/cluster_scoping_pass.cc new file mode 100644 index 00000000000..f4b9f93c616 --- /dev/null +++ b/tensorflow/compiler/jit/cluster_scoping_pass.cc @@ -0,0 +1,163 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/cluster_scoping_pass.h" + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { +namespace { + +class ClusterScopingPassImpl { + public: + ClusterScopingPassImpl(Graph* graph, + OptimizerOptions::GlobalJitLevel global_jit_level) + : graph_(graph), + global_jit_level_(global_jit_level), + unique_scope_id_(0) {} + + Status Run(); + + private: + Status ScopingForPipelineStages(); + + size_t GetUniqueScopeId() { return unique_scope_id_++; } + + void AddScopeToAllTransitivePredecessors(Node* start); + + void AddScopeToAllTransitiveSuccessors(Node* start); + + private: + Graph* graph_; + OptimizerOptions::GlobalJitLevel global_jit_level_; + size_t unique_scope_id_; +}; + +absl::optional GetXlaInternalScope(Node* node) { + string scope; + if (GetNodeAttr(node->attrs(), kXlaInternalScopeAttr, &scope).ok()) { + return scope; + } + + return absl::nullopt; +} + +void SetXlaInternalScope(Node* node, StringPiece scope) { + node->AddAttr(kXlaInternalScopeAttr, scope); +} + +// NB! We append a new scope as suffix to the _XlaInternalScope attribute +// instead of overriding the old value. In other words, appending scope B to +// scope A creates the conjunction of the scopes A and B (i.e, A & B) and, +// in effect, the node gets both the old and new scopes. As a unique scope +// disallows a node being merged with nodes in other scopes, the scope +// conjunction preserves the semantic of the old scope (i.e., the node still +// cannot be merged with the previously incompatible nodes.) +// +// For example, the below case should be rare in practice but can serve for the +// purpose of discussion. After adding scopes for both Stage and Unstage, +// Node_Y will receive both scopes "unstage" and "stage", while Node_X receives +// only scope "stage". The semantic of scope "unstage" is preserved although +// scope "stage" is later appended. As a result, Node_X and Node_Y will be put +// into different clusters. 
+// +// Unstage -> Node_Y (scope "unstage & stage") +// | +// V +// Node_X (scope "stage") -> Stage +// +void AddOrAppendXlaInternalScope(Node* node, absl::string_view suffix) { + string updated_scope; + absl::optional cur_scope = GetXlaInternalScope(node); + if (cur_scope == absl::nullopt) { + updated_scope = std::string(suffix); + } else { + updated_scope = absl::StrCat(cur_scope.value(), "&", suffix); + } + SetXlaInternalScope(node, updated_scope); +} + +void ClusterScopingPassImpl::AddScopeToAllTransitivePredecessors(Node* start) { + const string unique_suffix = absl::StrCat("_", GetUniqueScopeId()); + + std::vector starts; + starts.push_back(start); + auto enter = [&](Node* n) { AddOrAppendXlaInternalScope(n, unique_suffix); }; + ReverseDFSFrom(*graph_, starts, enter, /*leave=*/nullptr, + /*stable_comparator=*/NodeComparatorName()); +} + +void ClusterScopingPassImpl::AddScopeToAllTransitiveSuccessors(Node* start) { + const string unique_suffix = absl::StrCat("_", GetUniqueScopeId()); + + std::vector starts; + starts.push_back(start); + auto enter = [&](Node* n) { AddOrAppendXlaInternalScope(n, unique_suffix); }; + DFSFrom(*graph_, starts, enter, /*leave=*/nullptr, + /*stable_comparator=*/NodeComparatorName(), + // Do not filter any edges to better capture the semantics of + // transitive closure of successors. We may revisit this when + // we see more cases needing cluster scoping in the future. + /*edge_filter=*/nullptr); +} + +// This preserves the parallelism between pipeline stages. For example, below +// is a typical pattern of input pipelining in Tensorflow and this heuristic +// ensures Node_X and Node_Y are put into different clusters. Without the +// heuristic, they may be put into the same cluster and it can introduce +// artificial dependencies and incur great performance loss. In this example, +// Node_Y becomes dependent on IteratorGetNext and the latencies add up if +// Node_X and Node_Y are in the same cluster. +// +// IteratorGetNext -> Node_X -> Stage +// +// Unstage -> Node_Y +// +Status ClusterScopingPassImpl::ScopingForPipelineStages() { + for (Node* n : graph_->nodes()) { + DCHECK(n); + if (n->type_string() == "Unstage") { + AddScopeToAllTransitiveSuccessors(n); + } + if (n->type_string() == "Stage") { + AddScopeToAllTransitivePredecessors(n); + } + } + + return Status::OK(); +} + +Status ClusterScopingPassImpl::Run() { + if (global_jit_level_ == OptimizerOptions::OFF) { + return Status::OK(); + } + + return ScopingForPipelineStages(); +} +} // namespace + +Status ClusterScopingPass::Run(const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + + return ClusterScopingPassImpl{graph, GetGlobalJitLevelForGraph(options)} + .Run(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/cluster_scoping_pass.h b/tensorflow/compiler/jit/cluster_scoping_pass.h new file mode 100644 index 00000000000..9651c3f878c --- /dev/null +++ b/tensorflow/compiler/jit/cluster_scoping_pass.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_CLUSTER_SCOPING_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_CLUSTER_SCOPING_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// This pass adds scopes to nodes via the _XlaInternalScope attribute to guide +// the later clustering passes. A major reason for doing this is to prevent +// clustering from destroying critical parallelism in the TensorFlow graph, +// which can cause severe performance degradation. +// +// This pass must be run before MarkForCompilationPass, as it records the +// scoping information that MarkForCompilationPass must respect when making its +// clustering decisions. +class ClusterScopingPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_CLUSTER_SCOPING_PASS_H_ diff --git a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc new file mode 100644 index 00000000000..b3e63b8c298 --- /dev/null +++ b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/cluster_scoping_pass.h" + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { + +Status ClusterScoping(std::unique_ptr* graph) { + FixupSourceAndSinkEdges(graph->get()); + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + FunctionDefLibrary fdef_lib; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + opt_options.flib_def = &flib_def; + SessionOptions session_options; + session_options.env = Env::Default(); + session_options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_global_jit_level(OptimizerOptions::ON_2); + opt_options.session_options = &session_options; + + ClusterScopingPass pass; + return pass.Run(opt_options); +} + +absl::flat_hash_map GetXlaInternalScopes(const Graph& graph) { + absl::flat_hash_map scopes; + for (Node* node : graph.nodes()) { + string scope; + if (GetNodeAttr(node->attrs(), kXlaInternalScopeAttr, &scope).ok()) { + scopes[node->name()] = scope; + } + } + + if (VLOG_IS_ON(2)) { + VLOG(2) << "_XlaInternalScopes:"; + for (const auto& p : scopes) { + VLOG(2) << " " << p.first << " -> " << p.second; + } + } + return scopes; +} + +Node* BuildStageNode(GraphDefBuilder& builder, string name, + std::initializer_list dtypes, + absl::Span values) { + auto opts = builder.opts() + .WithName(std::move(name)) + .WithAttr("dtypes", std::move(dtypes)); + if (opts.HaveError()) { + return nullptr; + } + + NodeBuilder node_builder(name, "Stage", opts.op_registry()); + node_builder.Input(values); + return opts.FinalizeBuilder(&node_builder); +} + +TEST(XlaCompilationTest, StagePipelinePreserved) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + // Graph: + // b + // | + // v + // a -> add0 (ClusterX) -> relu0 (ClusterX) -> stage + // + // b + // | + // v + // unstage -> add1 (ClusterY) -> relu1 (ClusterY) + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("a") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::SourceOp("Const", builder.opts() + .WithName("b") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* unstage = ops::SourceOp( + "Unstage", + builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT})); + + Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0")); + Node* add1 = + ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1")); + Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0")); + ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1")); + BuildStageNode(builder, "stage", {DT_FLOAT}, {relu0}); + + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(ClusterScoping(&graph)); + + auto scopes = GetXlaInternalScopes(*graph); + EXPECT_NE(scopes["add0"], scopes["add1"]); + EXPECT_EQ(scopes["add0"], scopes["relu0"]); + 
EXPECT_EQ(scopes["add1"], scopes["relu1"]); +} + +TEST(XlaCompilationTest, StagePipelinePreservedAndInitialScopesRespected) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + // Graph: + // b + // | + // v + // a -> add0 (ClusterA) -> relu0 (ClusterB) -> stage + // + // b + // | + // v + // unstage -> add1 (ClusterC) -> relu1 (ClusterD) + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("a") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::SourceOp("Const", builder.opts() + .WithName("b") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* unstage = ops::SourceOp( + "Unstage", + builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT})); + + // Intentionally give add0 and add1 the same initial scope but they should + // be separated by the ClusterScopingPass. + Node* add0 = ops::BinaryOp("Add", a, b, + builder.opts().WithName("add0").WithAttr( + kXlaInternalScopeAttr, "ClusterA")); + Node* add1 = ops::BinaryOp("Add", unstage, b, + builder.opts().WithName("add1").WithAttr( + kXlaInternalScopeAttr, "ClusterA")); + Node* relu0 = ops::UnaryOp("Relu", add0, + builder.opts().WithName("relu0").WithAttr( + kXlaInternalScopeAttr, "ClusterB")); + ops::UnaryOp("Relu", add1, + builder.opts().WithName("relu1").WithAttr( + kXlaInternalScopeAttr, "ClusterD")); + BuildStageNode(builder, "stage", {DT_FLOAT}, {relu0}); + + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(ClusterScoping(&graph)); + + auto scopes = GetXlaInternalScopes(*graph); + EXPECT_NE(scopes["add0"], scopes["add1"]); + EXPECT_NE(scopes["add0"], scopes["relu0"]); + EXPECT_NE(scopes["add1"], scopes["relu1"]); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 5e3b93d30e5..6498436fbd9 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -37,6 +37,8 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_activity.pb.h" +#include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" @@ -44,6 +46,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/memory_types.h" @@ -83,7 +86,7 @@ Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name, } // anonymous namespace -std::vector +RecursiveCompilabilityChecker::UncompilableNodesMap RecursiveCompilabilityChecker::FindUncompilableNodes( const Node& node, FunctionLibraryRuntime* lib_runtime, const std::vector* @@ -98,12 +101,14 @@ RecursiveCompilabilityChecker::FindUncompilableNodes( } } stack_trace.emplace_back(StackFrameView{node.name(), ""}); - std::vector uncompilable_nodes; - IsCompilableNode(node, lib_runtime, &stack_trace, &uncompilable_nodes); + + RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes; + IsCompilableNode(node, lib_runtime, &stack_trace, + /*encapsulating_function=*/nullptr, &uncompilable_nodes); return uncompilable_nodes; } -std::vector +RecursiveCompilabilityChecker::UncompilableNodesMap RecursiveCompilabilityChecker::FindUncompilableNodes( const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime, const std::vector* @@ -118,8 +123,10 @@ RecursiveCompilabilityChecker::FindUncompilableNodes( } } stack_trace.emplace_back(StackFrameView{call_def.name(), ""}); - std::vector uncompilable_nodes; - IsCompilableCall(call_def, lib_runtime, &stack_trace, &uncompilable_nodes); + + RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes; + IsCompilableCall(call_def, lib_runtime, &stack_trace, + /*encapsulating_function=*/nullptr, &uncompilable_nodes); return uncompilable_nodes; } @@ -154,16 +161,18 @@ bool RecursiveCompilabilityChecker::HasXLAKernel(const Node& node) const { bool RecursiveCompilabilityChecker::IsCompilableIf( const Node& if_node, FunctionLibraryRuntime* lib_runtime, std::vector* stack_trace, - std::vector* uncompilable_nodes) const { + NameAttrList* encapsulating_function, + RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) + const { bool is_compilable = true; is_compilable &= ExtractNodeDefAndCheckCompilability( - if_node, "then_branch", "if_then", lib_runtime, stack_trace, - uncompilable_nodes); + if_node, "then_branch", "if_then", encapsulating_function, lib_runtime, + stack_trace, uncompilable_nodes); if (!uncompilable_nodes && !is_compilable) return is_compilable; is_compilable &= ExtractNodeDefAndCheckCompilability( - if_node, "else_branch", "if_else", lib_runtime, stack_trace, - uncompilable_nodes); + if_node, "else_branch", "if_else", encapsulating_function, lib_runtime, + stack_trace, uncompilable_nodes); return is_compilable; } @@ -174,37 +183,43 @@ bool RecursiveCompilabilityChecker::IsCompilableIf( bool RecursiveCompilabilityChecker::IsCompilableWhile( const Node& while_node, FunctionLibraryRuntime* lib_runtime, std::vector* stack_trace, - std::vector* uncompilable_nodes) const { + NameAttrList* encapsulating_function, + RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) + const { bool is_compilable = true; is_compilable &= ExtractNodeDefAndCheckCompilability( - while_node, "cond", "while_cond", lib_runtime, stack_trace, - uncompilable_nodes); + while_node, "cond", "while_cond", encapsulating_function, lib_runtime, + stack_trace, uncompilable_nodes); + if (!uncompilable_nodes && !is_compilable) return is_compilable; 
is_compilable &= ExtractNodeDefAndCheckCompilability( - while_node, "body", "while_body", lib_runtime, stack_trace, - uncompilable_nodes); + while_node, "body", "while_body", encapsulating_function, lib_runtime, + stack_trace, uncompilable_nodes); return is_compilable; } bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability( const Node& node, const std::string& attr_name, - const std::string& call_name, FunctionLibraryRuntime* lib_runtime, + const std::string& call_name, NameAttrList* encapsulating_function, + FunctionLibraryRuntime* lib_runtime, std::vector* stack_trace, - std::vector* uncompilable_nodes) const { + RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) + const { NodeDef call; call.set_name(call_name); if (!MakeCallNodeFromAttribute(node, attr_name, &call).ok()) { const auto uncompilable_reason = absl::StrCat( "missing '", attr_name, "' attribute from node", node.name()); MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); VLOG(2) << "Rejecting node " << node.name() << ": " << uncompilable_reason << "."; return false; } - if (!IsCompilableCall(call, lib_runtime, stack_trace, uncompilable_nodes)) { + if (!IsCompilableCall(call, lib_runtime, stack_trace, encapsulating_function, + uncompilable_nodes)) { VLOG(2) << "Rejecting node " << node.name() << ": can't compile : " << call.op(); return false; @@ -218,24 +233,33 @@ bool RecursiveCompilabilityChecker::ExtractNodeDefAndCheckCompilability( bool RecursiveCompilabilityChecker::IsCompilableCall( const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime, std::vector* stack_trace, - std::vector* uncompilable_nodes) const { + NameAttrList* encapsulating_function, + RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) + const { if (stack_trace->size() > kMaxRecursionDepth) { std::string uncompilable_reason = "function depth limit exceeded"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); VLOG(2) << "Rejecting " << call_def.op() << ": " << uncompilable_reason << "."; return false; } FunctionLibraryRuntime::Handle handle; - Status status = InstantiateFunctionCall(call_def, lib_runtime, &handle); - if (!status.ok()) { + Status s; + NameAttrList function; + s = NameAndAttrsFromFunctionCall(call_def, &function); + if (s.ok()) { + s = lib_runtime->Instantiate(function.name(), AttrSlice(&function.attr()), + &handle); + } + + if (!s.ok()) { std::string uncompilable_reason = "could not instantiate call"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); VLOG(2) << "Rejecting " << call_def.DebugString() << ": " - << uncompilable_reason << " : " << status; + << uncompilable_reason << " : " << s; return false; } @@ -244,9 +268,9 @@ bool RecursiveCompilabilityChecker::IsCompilableCall( const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); bool is_compilable = true; for (const Node* node : fbody->graph->op_nodes()) { - stack_trace->emplace_back(StackFrameView{node->name(), call_def.op()}); - is_compilable &= - IsCompilableNode(*node, lib_runtime, stack_trace, uncompilable_nodes); + stack_trace->emplace_back(StackFrameView{node->name(), function.name()}); + is_compilable &= IsCompilableNode(*node, lib_runtime, stack_trace, + &function, uncompilable_nodes); stack_trace->pop_back(); if (!uncompilable_nodes && !is_compilable) return 
is_compilable; } @@ -263,20 +287,28 @@ bool RecursiveCompilabilityChecker::OpIsInaccurate(const Node& node) const { bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const { // b/128001705: SelfAdjointEigV2 and Svd performance issues. // b/135640736: MatrixInverse performance issues. + // https://github.com/tensorflow/tensorflow/pull/31012: + // ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes + // create convolutions too large for CuDNN to handle. return node.type_string() == "SelfAdjointEigV2" || node.type_string() == "Svd" || node.type_string() == "Qr" || - node.type_string() == "MatrixInverse"; + node.type_string() == "MatrixInverse" || + node.type_string() == "ResizeNearestNeighbor" || + node.type_string() == "ResizeBilinear" || + node.type_string() == "ResizeBilinearGrad"; } bool RecursiveCompilabilityChecker::IsCompilableNode( const Node& node, FunctionLibraryRuntime* lib_runtime, std::vector* stack_trace, - std::vector* uncompilable_nodes) const { + NameAttrList* encapsulating_function, + RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) + const { auto stack_depth = stack_trace->size(); if (node.IsSource() || node.IsSink()) { absl::string_view uncompilable_reason = "source or sink node"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } @@ -287,7 +319,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( (node.type_string() == "_Arg" || node.type_string() == "_Retval")) { absl::string_view uncompilable_reason = "top level _Arg or _Retval"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } @@ -299,33 +331,35 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( absl::string_view uncompilable_reason = "_scoped_allocator or _forward_from attribute"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), node)) { if (!IsCompilableCall(node.def(), lib_runtime, stack_trace, - uncompilable_nodes)) { + encapsulating_function, uncompilable_nodes)) { LogNotCompilable(node, "unsupported function"); return false; } } else if (!HasXLAKernel(node)) { absl::string_view uncompilable_reason = "unsupported op"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } - if (node.type_string() == "While" && - !IsCompilableWhile(node, lib_runtime, stack_trace, uncompilable_nodes)) { + if (node.IsWhileNode() && + !IsCompilableWhile(node, lib_runtime, stack_trace, encapsulating_function, + uncompilable_nodes)) { LogNotCompilable(node, "unsupported while"); return false; } if (node.IsIfNode() && - !IsCompilableIf(node, lib_runtime, stack_trace, uncompilable_nodes)) { + !IsCompilableIf(node, lib_runtime, stack_trace, encapsulating_function, + uncompilable_nodes)) { LogNotCompilable(node, "unsupported if"); return false; } @@ -334,7 +368,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( IsStatefulRandomOp(node.type_string())) { absl::string_view uncompilable_reason = "stateful random op"; 
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } @@ -342,7 +376,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( if (!op_filter_.allow_control_trigger && node.IsControlTrigger()) { absl::string_view uncompilable_reason = "not allowed control trigger"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } @@ -351,7 +385,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( IsAssertOrCheckNumerics(node.type_string())) { absl::string_view uncompilable_reason = "Assert or CheckNumerics"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } @@ -360,7 +394,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( OpProducesOrConsumesVariant(node)) { absl::string_view uncompilable_reason = "DT_VARIANT producer/consumer"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } @@ -368,7 +402,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( if (!op_filter_.allow_stack_ops && IsStackOp(node)) { absl::string_view uncompilable_reason = "Stack op"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } @@ -376,7 +410,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( if (!op_filter_.allow_tensor_array_ops && IsTensorArrayOp(node)) { absl::string_view uncompilable_reason = "TensorArray op"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } @@ -386,7 +420,7 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( absl::string_view uncompilable_reason = "resource variable op in called function"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } @@ -394,16 +428,22 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( if (!op_filter_.allow_inaccurate_ops && OpIsInaccurate(node)) { absl::string_view uncompilable_reason = "operation with numerical accuracy issues"; + BroadcastOptimizationRemark(XlaOptimizationRemark::INACCURATE_OPERATION, + node.DebugString()) + .IgnoreError(); MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } if (!op_filter_.allow_slow_ops && OpIsSlow(node)) { absl::string_view uncompilable_reason = "slow operation"; + BroadcastOptimizationRemark(XlaOptimizationRemark::SLOW_OPERATION, + node.DebugString()) + .IgnoreError(); MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, - uncompilable_nodes); + encapsulating_function, uncompilable_nodes); LogNotCompilable(node, uncompilable_reason); return false; } @@ -432,8 +472,9 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( /*static*/ void 
RecursiveCompilabilityChecker::MaybeMarkUncompilableNode( const absl::string_view reason, const std::vector& stack_trace, - std::vector* uncompilable_node_list) { - if (!uncompilable_node_list) return; + NameAttrList* encapsulating_function, + RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_nodes) { + if (!uncompilable_nodes) return; UncompilableNodeInfo node_info; node_info.uncompilable_reason = std::string(reason); @@ -445,7 +486,20 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( }); node_info.name = std::string(stack_trace.back().name); - (*uncompilable_node_list).push_back(std::move(node_info)); + auto function = + encapsulating_function ? *encapsulating_function : NameAttrList(); + auto function_identifier = function.ShortDebugString(); + + auto it = uncompilable_nodes->find(function_identifier); + if (it == uncompilable_nodes->end()) { + std::vector + uncompileable_node_info{std::move(node_info)}; + uncompilable_nodes->emplace( + std::move(function_identifier), + std::make_pair(function, std::move(uncompileable_node_info))); + } else { + it->second.second.emplace_back(std::move(node_info)); + } } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 2ad3496bb7c..04639df14a1 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_def_util.h" @@ -129,19 +130,35 @@ class RecursiveCompilabilityChecker { const DeviceType* jit_device_type) : op_filter_(*op_filter), jit_device_type_(*jit_device_type) {} - // Returns a list of uncompilable nodes. When `node` is inside a function - // body, users can set `node_stack_trace` to provide an additional - // context for `node`'s placement within the outer most graph. - std::vector FindUncompilableNodes( + using UncompilableNodesMap = + std::map>>; + + // Returns a map where the key is the function identifier(short debug + // string) of the function encapsulating the uncompilable nodes, and the + // value is a pair of NameAttrList of the function and a vector of + // uncompilable node info. When uncompilable node is not inside any + // function call nodes, then key is a ShortDebugString() of an empty + // NameAttrList. + // + // Also, when `node` is inside a function body, users can set + // `node_stack_trace` to provide an additional context for `node`'s + // placement within the outer most graph. + UncompilableNodesMap FindUncompilableNodes( const Node& node, FunctionLibraryRuntime* lib_runtime, const std::vector* node_stack_trace = nullptr) const; - // Returns a list of uncompilable nodes in `call_def` that cannot be - // compiled by XLA. It is assumed that `call_def` is a call operation. - // When `node` is inside a function body, users can set + // Returns a map where the key is the function identifier(short debug + // string) of the function encapsulating the uncompilable nodes, and the + // value is a pair of NameAttrList of the function and a vector of + // uncompilable node info. 
When uncompilable node is not inside any + // function call nodes, then key is a ShortDebugString() of an empty + // NameAttrList. + // + // Also, when `node` is inside a function body, users can set // `node_stack_trace` to provide an additional context for `node`'s // placement within the outer most graph. - std::vector FindUncompilableNodes( + UncompilableNodesMap FindUncompilableNodes( const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime, const std::vector* node_stack_trace = nullptr) const; @@ -176,27 +193,31 @@ class RecursiveCompilabilityChecker { bool IsCompilableNode( const Node& node, FunctionLibraryRuntime* lib_runtime, std::vector* stack_trace, - std::vector* uncompilable_nodes = nullptr) const; + NameAttrList* encapsulating_function = nullptr, + UncompilableNodesMap* uncompilable_nodes = nullptr) const; bool IsCompilableCall( const NodeDef& call_def, FunctionLibraryRuntime* lib_runtime, std::vector* stack_trace, - std::vector* uncompilable_nodes = nullptr) const; - bool IsCompilableIf( - const Node& if_node, FunctionLibraryRuntime* lib_runtime, - std::vector* stack_trace, - std::vector* uncompilable_nodes) const; - bool IsCompilableWhile( - const Node& while_node, FunctionLibraryRuntime* lib_runtime, - std::vector* stack_trace, - std::vector* uncompilable_nodes) const; + NameAttrList* encapsulating_function = nullptr, + UncompilableNodesMap* uncompilable_nodes = nullptr) const; + bool IsCompilableIf(const Node& if_node, FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function, + UncompilableNodesMap* uncompilable_nodes) const; + bool IsCompilableWhile(const Node& while_node, + FunctionLibraryRuntime* lib_runtime, + std::vector* stack_trace, + NameAttrList* encapsulating_function, + UncompilableNodesMap* uncompilable_nodes) const; // Returns compilability of node def retrieved from `node`'s attribute with // name `attr_name`. bool ExtractNodeDefAndCheckCompilability( const Node& node, const std::string& attr_name, - const std::string& call_name, FunctionLibraryRuntime* lib_runtime, + const std::string& call_name, NameAttrList* encapsulating_function, + FunctionLibraryRuntime* lib_runtime, std::vector* stack_trace, - std::vector* uncompilable_nodes) const; + UncompilableNodesMap* uncompilable_nodes) const; bool IsStackOp(const Node& node) const { const XlaResourceOpInfo* op_info = @@ -231,7 +252,8 @@ class RecursiveCompilabilityChecker { static void MaybeMarkUncompilableNode( const absl::string_view reason, const std::vector& stack_trace, - std::vector* uncompilable_node_list); + NameAttrList* encapsulating_function, + UncompilableNodesMap* uncompilable_nodes_map); // Make sure we don't recurse infinitely on recursive functions. const int kMaxRecursionDepth = 10; diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc index 90d69680514..0dd3b8141c9 100644 --- a/tensorflow/compiler/jit/compilability_check_util_test.cc +++ b/tensorflow/compiler/jit/compilability_check_util_test.cc @@ -21,8 +21,10 @@ limitations under the License. 
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -117,10 +119,15 @@ TEST_F(CompilabilityCheckUtilTest, CheckNonFunctionalNodes) { const auto uncompilable_nodes = checker_->FindUncompilableNodes(*uncompilable_op, flib_runtime); ASSERT_EQ(1, uncompilable_nodes.size()); - const auto& node_info = uncompilable_nodes.at(0); - EXPECT_EQ("unsupported op", node_info.uncompilable_reason); - ASSERT_EQ(1, node_info.stack_trace.size()); - ASSERT_EQ("", node_info.stack_trace.at(0).function_name); + auto node_info_it = + uncompilable_nodes.find(NameAttrList().ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), node_info_it); + const auto& uncompilable_nodes_inside_function = node_info_it->second.second; + ASSERT_EQ(1, uncompilable_nodes_inside_function.size()); + const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0); + EXPECT_EQ("unsupported op", uncompilable_node_info.uncompilable_reason); + ASSERT_EQ(1, uncompilable_node_info.stack_trace.size()); + ASSERT_EQ("", uncompilable_node_info.stack_trace.at(0).function_name); } TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) { @@ -147,12 +154,18 @@ TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) { checker_->FindUncompilableNodes(*functional_node, flib_runtime); EXPECT_EQ(1, uncompilable_nodes.size()); - const auto& node_info = uncompilable_nodes.at(0); + NameAttrList function; + function.set_name(kUncompilableFunctionName); + const auto node_info_it = + uncompilable_nodes.find(function.ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), node_info_it); + const auto& uncompilable_node_list = node_info_it->second.second; + ASSERT_EQ(1, uncompilable_node_list.size()); + const auto& node_info = uncompilable_node_list.at(0); const auto& node_stack = node_info.stack_trace; ASSERT_EQ(2, node_stack.size()); EXPECT_EQ("D", node_stack.at(0).name); EXPECT_EQ(kUncompilableFunctionNodeName, node_stack.at(1).name); - EXPECT_EQ(kUncompilableFunctionNodeName, node_info.name); EXPECT_EQ("unsupported op", node_info.uncompilable_reason); } @@ -212,7 +225,15 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalWhileNode) { checker_->FindUncompilableNodes(**while_node_it, flib_runtime); ASSERT_EQ(1, uncompilable_nodes.size()); - const auto& node_info = uncompilable_nodes.at(0); + NameAttrList function; + function.set_name(kUncompilableFunctionName); + const auto node_info_it = + uncompilable_nodes.find(function.ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), node_info_it); + const auto& uncompilable_node_list = node_info_it->second.second; + ASSERT_EQ(1, uncompilable_node_list.size()); + const auto& node_info = uncompilable_node_list.at(0); + const auto& node_stack = node_info.stack_trace; ASSERT_EQ(2, node_stack.size()); const auto& stacktrace_first_node_info = node_stack.at(0); @@ -280,7 +301,14 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) { checker_->FindUncompilableNodes(**if_node_it, flib_runtime); ASSERT_EQ(2, uncompilable_nodes.size()); - const auto& uncompilable_node_one = uncompilable_nodes.at(0); + 
NameAttrList function_one; + function_one.set_name(kUncompilableFunctionName); + auto it = uncompilable_nodes.find(function_one.ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), it); + + const auto& uncompilable_node_list = it->second.second; + ASSERT_EQ(1, uncompilable_node_list.size()); + const auto& uncompilable_node_one = uncompilable_node_list.at(0); const auto& node_one_stack = uncompilable_node_one.stack_trace; ASSERT_EQ(2, node_one_stack.size()); @@ -296,7 +324,14 @@ TEST_F(CompilabilityCheckUtilTest, CheckFunctionalIfNode) { EXPECT_EQ(kUncompilableFunctionNodeName, uncompilable_node_one.name); EXPECT_EQ("unsupported op", uncompilable_node_one.uncompilable_reason); - const auto& uncompilable_node_two = uncompilable_nodes.at(1); + NameAttrList function_two; + function_two.set_name(kUncompilableFunctionTwoName); + it = uncompilable_nodes.find(function_two.ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), it); + + const auto& uncompilable_node_two_list = it->second.second; + ASSERT_EQ(1, uncompilable_node_two_list.size()); + const auto& uncompilable_node_two = uncompilable_node_two_list.at(0); const auto& node_two_stack = uncompilable_node_two.stack_trace; ASSERT_EQ(2, node_two_stack.size()); const auto& node_two_stacktrace_first_node = node_two_stack.at(0); diff --git a/tensorflow/compiler/jit/defs.cc b/tensorflow/compiler/jit/defs.cc index f847d66f3c6..b23f6ec35f5 100644 --- a/tensorflow/compiler/jit/defs.cc +++ b/tensorflow/compiler/jit/defs.cc @@ -18,6 +18,12 @@ limitations under the License. namespace tensorflow { const char* const kXlaCompileAttr = "_XlaCompile"; + +// User-provided through jit_scope APIs. Effective only when auto_jit is OFF. const char* const kXlaScopeAttr = "_XlaScope"; +// Automatically inserted by auto_jit to guide clustering results. Effective +// only when auto_jit is ON. +const char* const kXlaInternalScopeAttr = "_XlaInternalScope"; + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/defs.h b/tensorflow/compiler/jit/defs.h index a3aabc949db..bf8009344df 100644 --- a/tensorflow/compiler/jit/defs.h +++ b/tensorflow/compiler/jit/defs.h @@ -24,6 +24,7 @@ namespace tensorflow { // Name of attribute used to tag operators for compilation with XLA extern const char* const kXlaCompileAttr; // "_XlaCompile" extern const char* const kXlaScopeAttr; // "_XlaScope" +extern const char* const kXlaInternalScopeAttr; // "_XlaInternalScope" } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 6992a0165d4..e0c0c0b18cc 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -1317,7 +1317,7 @@ Status EncapsulateSubgraphsPass::Run( bool IsXlaCompiledKernel(const Node& node) { bool is_compiled = false; bool has_compilation_attr = - GetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled).ok() && + TryGetNodeAttr(node.attrs(), kXlaCompiledKernelAttr, &is_compiled) && is_compiled; return has_compilation_attr ? is_compiled : false; } diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 2c2cd094133..b9889988cc0 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -245,8 +245,8 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, // while iterating. 
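The hunk below, like several others in this patch (for example in encapsulate_subgraphs_pass.cc above), replaces the GetNodeAttr(...).ok() idiom with the lighter attribute helpers GetNodeAttrString and TryGetNodeAttr from tensorflow/core/framework/node_def_util.h. A minimal sketch of the two styles follows; it is not part of the patch, and the helper name HasNonEmptyClusterAttr is made up purely for illustration:

#include "tensorflow/core/framework/node_def_util.h"

namespace tensorflow {

// Hypothetical helper, for illustration only.
bool HasNonEmptyClusterAttr(const Node& n, absl::string_view attr_name) {
  // Old idiom: copy the value and construct a Status just to test presence:
  //   string name;
  //   if (GetNodeAttr(n.attrs(), attr_name, &name).ok()) { ... }
  // Idiom adopted by this patch: GetNodeAttrString returns a reference to the
  // stored string, or an empty string when the attribute is absent, so no
  // Status or copy is needed. TryGetNodeAttr plays the analogous role for
  // non-string attributes.
  const string& name = GetNodeAttrString(n.attrs(), attr_name);
  return !name.empty();
}

}  // namespace tensorflow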
std::vector launch_nodes; for (Node* n : graph->nodes()) { - string name; - if (GetNodeAttr(n->attrs(), kXlaClusterAttr, &name).ok()) { + const string& name = GetNodeAttrString(n->attrs(), kXlaClusterAttr); + if (!name.empty()) { launch_nodes.push_back(n); } } diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 8935cdfc240..b35e08fb1f0 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -33,6 +32,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/dump_graph.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -369,7 +369,8 @@ xla::StatusOr BuildXlaHostComputeNodeDef( return new_def; } -Status ValidateOutsideCompilationCallNode(Node* call_node) { +TF_ATTRIBUTE_NOINLINE Status +ValidateOutsideCompilationCallNode(Node* call_node) { // DT_INT64 as input/output for outside compilation is not supported yet: // b/120809951. for (const Edge* e : call_node->in_edges()) { @@ -402,7 +403,7 @@ Status ValidateOutsideCompilationCallNode(Node* call_node) { } // Replace outside compilation function call node with XlaHostCompute node. -xla::StatusOr ReplaceOutsideCompilationCallNode( +TF_ATTRIBUTE_NOINLINE xla::StatusOr ReplaceOutsideCompilationCallNode( Graph* g, Node* call_node, const std::map& host_compute_core, const absl::flat_hash_map>& cluster_deps) { // Build XlaHostCompute NodeDef. @@ -440,7 +441,7 @@ Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { n->ClearAttr(attr_name); n->AddAttr(attr_name, branch_func); } - } else if (n->type_string() == "While") { + } else if (n->IsWhileNode()) { for (const string& attr_name : std::vector{"cond", "body"}) { NameAttrList branch_func; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), attr_name, &branch_func)); @@ -523,16 +524,14 @@ xla::StatusOr> UpdateTypesAttribute( // Add edges from lifted outside compilation argument nodes to `n` in Graph `g`. void AddEdgesFromOutsideCompilationNodes( - const int original_arg_count, const std::vector& data_types, - const std::vector>& - lifted_arg_nodes_and_outside_compilation_nodes, - Graph* g, Node* n) { + const int original_arg_count, const int arg_to_input_edge_offset, + const std::vector& data_types, + const std::vector& outside_compilation_nodes, Graph* g, Node* n) { // Add edges from outside compilation nodes to While node. 
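+  // The offset accounts for call-node inputs that do not correspond to _Arg
+  // nodes of the callee function: it is 0 for While and function-call nodes,
+  // and 1 for If nodes, whose input #0 is the predicate, so _Arg #i maps to
+  // input #i + 1.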
for (int i = original_arg_count; i < data_types.size(); i++) { Node* outside_compilation_node = - lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count] - .second; - g->AddEdge(outside_compilation_node, 0, n, i); + outside_compilation_nodes[i - original_arg_count]; + g->AddEdge(outside_compilation_node, 0, n, i + arg_to_input_edge_offset); } } @@ -573,14 +572,15 @@ Status AddMatchingRetvalNode(const FunctionBody& function_body, void ReplaceLiftedArgNodePlaceholderWithArg( const FunctionBody& function_body, const int original_arg_count, - const int arg_idx, - const std::vector>& - lifted_arg_nodes_and_outside_compilation_nodes, + const int arg_idx, const std::vector& lifted_arg_nodes, Node* arg_node) { - Node* lifted_arg_node = - lifted_arg_nodes_and_outside_compilation_nodes[arg_idx - - original_arg_count] - .first; + Node* lifted_arg_node = lifted_arg_nodes[arg_idx - original_arg_count]; + // This might happen because lifted_arg_node only exists in one branch of an + // If node, and we are handling the other branch. + if (!lifted_arg_node) { + return; + } + for (const Edge* e : lifted_arg_node->out_edges()) { if (e->IsControlEdge()) { function_body.graph->AddControlEdge(arg_node, e->dst()); @@ -588,7 +588,6 @@ void ReplaceLiftedArgNodePlaceholderWithArg( function_body.graph->AddEdge(arg_node, 0, e->dst(), e->dst_input()); } } - function_body.graph->RemoveNode(lifted_arg_node); } @@ -597,7 +596,7 @@ void ReplaceLiftedArgNodePlaceholderWithArg( Status PostprocessLiftedArgsForWhile( const std::unordered_map& outside_compilation_attr_to_node, Graph* g, Node* n, FunctionLibraryDefinition* fld) { - TF_RET_CHECK(n->type_string() == "While"); + TF_RET_CHECK(n->IsWhileNode()); // Check if there is any lifted args in body function. NameAttrList body_func; @@ -629,12 +628,25 @@ Status PostprocessLiftedArgsForWhile( n)); // Add edges from outside compilation nodes to While node. - AddEdgesFromOutsideCompilationNodes( - original_arg_count, data_types, - lifted_arg_nodes_and_outside_compilation_nodes, g, n); + std::vector outside_compilation_nodes; + std::transform( + lifted_arg_nodes_and_outside_compilation_nodes.begin(), + lifted_arg_nodes_and_outside_compilation_nodes.end(), + std::back_inserter(outside_compilation_nodes), + [](const std::pair& pair) { return pair.second; }); + AddEdgesFromOutsideCompilationNodes(original_arg_count, + /*arg_to_input_edge_offset=*/0, + data_types, outside_compilation_nodes, g, + n); // In body_graph, create new _Arg/_Retval nodes, and replace lifted arg // nodes with the new _Arg nodes. 
+ std::vector lifted_arg_nodes; + std::transform( + lifted_arg_nodes_and_outside_compilation_nodes.begin(), + lifted_arg_nodes_and_outside_compilation_nodes.end(), + std::back_inserter(lifted_arg_nodes), + [](const std::pair& pair) { return pair.first; }); for (int i = original_arg_count; i < data_types.size(); i++) { TF_ASSIGN_OR_RETURN(Node * arg_node, AddOutsideCompilationInputArgToFunctionBody( @@ -644,8 +656,7 @@ Status PostprocessLiftedArgsForWhile( AddMatchingRetvalNode(*body_function_body, i, data_types[i], arg_node)); ReplaceLiftedArgNodePlaceholderWithArg( - *body_function_body, original_arg_count, i, - lifted_arg_nodes_and_outside_compilation_nodes, arg_node); + *body_function_body, original_arg_count, i, lifted_arg_nodes, arg_node); } FunctionDef rewritten_body_function_def; @@ -682,6 +693,219 @@ Status PostprocessLiftedArgsForWhile( return Status::OK(); } +Status PostprocessLiftedArgsForIf( + const std::unordered_map& outside_compilation_attr_to_node, + Graph* g, Node* n, FunctionLibraryDefinition* fld) { + TF_RET_CHECK(n->IsIfNode()); + + NameAttrList then_branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "then_branch", &then_branch_func)); + const FunctionDef* then_branch_function_def = + fld->Find(then_branch_func.name()); + TF_RET_CHECK(then_branch_function_def); + + NameAttrList else_branch_func; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "else_branch", &else_branch_func)); + const FunctionDef* else_branch_function_def = + fld->Find(else_branch_func.name()); + TF_RET_CHECK(else_branch_function_def); + + // Nothing to do if neither branch contains any lifted arguments. + if (!HasLiftedArgs(*then_branch_function_def) && + !HasLiftedArgs(*else_branch_function_def)) { + return Status::OK(); + } + + std::unique_ptr then_branch_function_body; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *then_branch_function_def, AttrSlice(&then_branch_func.attr()), fld, + &then_branch_function_body)); + + std::unique_ptr else_branch_function_body; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( + *else_branch_function_def, AttrSlice(&else_branch_func.attr()), fld, + &else_branch_function_body)); + + // Then and else branches have same argument count and argument data types. + int original_arg_count = then_branch_function_body->arg_nodes.size(); + + TF_ASSIGN_OR_RETURN( + auto then_branch_lifted_arg_nodes_and_outside_compilation_nodes, + LiftedArgsAndOutsideCompilationNodesInFunctionBody( + *then_branch_function_body, outside_compilation_attr_to_node)); + + TF_ASSIGN_OR_RETURN( + auto else_branch_lifted_arg_nodes_and_outside_compilation_nodes, + LiftedArgsAndOutsideCompilationNodesInFunctionBody( + *else_branch_function_body, outside_compilation_attr_to_node)); + + // Merge lifted args from then and else branches. + std::vector outside_compilation_nodes; + std::vector then_branch_lifted_arg_nodes; + for (const auto& pair : + then_branch_lifted_arg_nodes_and_outside_compilation_nodes) { + outside_compilation_nodes.push_back(pair.second); + then_branch_lifted_arg_nodes.push_back(pair.first); + } + for (const auto& pair : + else_branch_lifted_arg_nodes_and_outside_compilation_nodes) { + if (std::find(outside_compilation_nodes.begin(), + outside_compilation_nodes.end(), + pair.second) == outside_compilation_nodes.end()) { + outside_compilation_nodes.push_back(pair.second); + // Then branch does not contain this lifted arg. Add an empty item to + // then_branch_lifted_arg_nodes. 
+ then_branch_lifted_arg_nodes.push_back(nullptr); + } + } + // Reorder else_branch_lifted_arg_nodes_and_outside_compilation_nodes. + std::vector else_branch_lifted_arg_nodes( + outside_compilation_nodes.size()); + for (const auto& pair : + else_branch_lifted_arg_nodes_and_outside_compilation_nodes) { + auto iter = std::find(outside_compilation_nodes.begin(), + outside_compilation_nodes.end(), pair.second); + TF_RET_CHECK(iter != outside_compilation_nodes.end()); + int index = iter - outside_compilation_nodes.begin(); + else_branch_lifted_arg_nodes[index] = pair.first; + } + + // Append lifted args' types to If node's Tin attribute. + std::vector data_types; + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "Tin", &data_types)); + for (Node* n : outside_compilation_nodes) { + data_types.push_back(n->output_type(0)); + } + n->ClearAttr("Tin"); + n->AddAttr("Tin", data_types); + + // Add edges from outside compilation nodes to If node. If node's input #0 + // is predicate input, input #1 maps to _Arg #0 of branch functions, thus + // arg_to_input_edge_offset is set to 1. + AddEdgesFromOutsideCompilationNodes(original_arg_count, + /*arg_to_input_edge_offset=*/1, + data_types, outside_compilation_nodes, g, + n); + + for (int i = original_arg_count; i < data_types.size(); ++i) { + TF_ASSIGN_OR_RETURN(Node * then_branch_arg_node, + AddOutsideCompilationInputArgToFunctionBody( + *then_branch_function_body, i, data_types[i])); + + ReplaceLiftedArgNodePlaceholderWithArg( + *then_branch_function_body, original_arg_count, i, + then_branch_lifted_arg_nodes, then_branch_arg_node); + + TF_ASSIGN_OR_RETURN(Node * else_branch_arg_node, + AddOutsideCompilationInputArgToFunctionBody( + *else_branch_function_body, i, data_types[i])); + + ReplaceLiftedArgNodePlaceholderWithArg( + *else_branch_function_body, original_arg_count, i, + else_branch_lifted_arg_nodes, else_branch_arg_node); + } + + FunctionDef rewritten_then_branch_function_def; + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *then_branch_function_body->graph, then_branch_func.name(), + HostGraphControlRetMapping, &rewritten_then_branch_function_def)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(then_branch_func.name(), + rewritten_then_branch_function_def)); + + FunctionDef rewritten_else_branch_function_def; + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *else_branch_function_body->graph, else_branch_func.name(), + HostGraphControlRetMapping, &rewritten_else_branch_function_def)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(else_branch_func.name(), + rewritten_else_branch_function_def)); + return Status::OK(); +} + +Status PostprocessLiftedArgsForCall( + const std::unordered_map& outside_compilation_attr_to_node, + Graph* g, Node* n, FunctionLibraryDefinition* fld) { + const FunctionDef* fdef = fld->Find(n->type_string()); + TF_RET_CHECK(fdef); + + // Nothing to do if the function does not contain any lifted arguments. + if (!HasLiftedArgs(*fdef)) { + return Status::OK(); + } + + std::unique_ptr fbody; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, n->attrs(), fld, &fbody)); + + int original_arg_count = fbody->arg_nodes.size(); + + TF_ASSIGN_OR_RETURN(auto lifted_arg_nodes_and_outside_compilation_nodes, + LiftedArgsAndOutsideCompilationNodesInFunctionBody( + *fbody, outside_compilation_attr_to_node)); + + // Append lifted args' types to call node's input data types. 
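+  // Each appended type is read off the outside compilation node itself:
+  // Identity nodes carry it in their "T" attribute and Placeholder nodes in
+  // "dtype", as checked below.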
+ std::vector data_types(n->input_types().begin(), + n->input_types().end()); + for (auto pair : lifted_arg_nodes_and_outside_compilation_nodes) { + Node* outside_compilation_node = pair.second; + DataType data_type; + TF_RET_CHECK(outside_compilation_node->IsIdentity() || + outside_compilation_node->type_string() == "Placeholder"); + if (outside_compilation_node->IsIdentity()) { + TF_RETURN_IF_ERROR( + GetNodeAttr(outside_compilation_node->def(), "T", &data_type)); + } else { + TF_RETURN_IF_ERROR( + GetNodeAttr(outside_compilation_node->def(), "dtype", &data_type)); + } + data_types.push_back(data_type); + } + + std::vector lifted_arg_nodes; + std::transform( + lifted_arg_nodes_and_outside_compilation_nodes.begin(), + lifted_arg_nodes_and_outside_compilation_nodes.end(), + std::back_inserter(lifted_arg_nodes), + [](const std::pair& pair) { return pair.first; }); + for (int i = original_arg_count; i < data_types.size(); ++i) { + TF_ASSIGN_OR_RETURN( + Node * arg_node, + AddOutsideCompilationInputArgToFunctionBody(*fbody, i, data_types[i])); + + ReplaceLiftedArgNodePlaceholderWithArg(*fbody, original_arg_count, i, + lifted_arg_nodes, arg_node); + } + + FunctionDef rewritten_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*fbody->graph, n->type_string(), + HostGraphControlRetMapping, + &rewritten_fdef)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(n->type_string(), rewritten_fdef)); + + // We need to recreate the node. Otherwise TF will not know n->num_inputs() + // has increased. + NodeDef node_def = n->def(); + for (int i = original_arg_count; i < data_types.size(); i++) { + Node* outside_compilation_node = + lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count] + .second; + node_def.add_input(absl::StrCat(outside_compilation_node->name(), ":", 0)); + } + TF_ASSIGN_OR_RETURN(n, ReplaceNode(g, n, node_def)); + + // Add edges from outside compilation nodes to call node. + std::vector outside_compilation_nodes; + std::transform( + lifted_arg_nodes_and_outside_compilation_nodes.begin(), + lifted_arg_nodes_and_outside_compilation_nodes.end(), + std::back_inserter(outside_compilation_nodes), + [](const std::pair& pair) { return pair.second; }); + AddEdgesFromOutsideCompilationNodes(original_arg_count, + /*arg_to_input_edge_offset=*/0, + data_types, outside_compilation_nodes, g, + n); + + return Status::OK(); +} + // Creates a mapping from outside compilation cluster name to lifted argument // placeholder. 
xla::StatusOr> OutsideCompilationAttrToNode( @@ -690,10 +914,9 @@ xla::StatusOr> OutsideCompilationAttrToNode( for (Node* n : g.op_nodes()) { bool is_lifted_arg; string outside_compilation_attr; - if (GetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg).ok() && - GetNodeAttr(n->def(), "_xla_outside_compilation", - &outside_compilation_attr) - .ok()) { + if (TryGetNodeAttr(n->def(), kXlaIsLiftedArgAttrName, &is_lifted_arg) && + TryGetNodeAttr(n->def(), "_xla_outside_compilation", + &outside_compilation_attr)) { TF_RET_CHECK(is_lifted_arg); TF_RET_CHECK(n->IsIdentity() || n->type_string() == "Placeholder"); outside_compilation_attr_to_node[outside_compilation_attr] = n; @@ -707,15 +930,34 @@ Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) { TF_ASSIGN_OR_RETURN(auto outside_compilation_attr_to_node, OutsideCompilationAttrToNode(*g)); + std::vector call_nodes; for (Node* n : g->op_nodes()) { if (!HasNodeAttr(n->def(), kXlaHasHostTransferAttrName)) { continue; } - if (n->type_string() == "While") { + if (n->IsWhileNode()) { TF_RETURN_IF_ERROR(PostprocessLiftedArgsForWhile( outside_compilation_attr_to_node, g, n, fld)); } + + if (n->IsIfNode()) { + TF_RETURN_IF_ERROR(PostprocessLiftedArgsForIf( + outside_compilation_attr_to_node, g, n, fld)); + } + + // Outside compilation host side function call will always be direct + // function call nodes. + // Function call nodes need to be handled separately because we rewrite + // nodes in `PostprocessLiftedArgsForCall`. + if (fld->Contains(n->type_string())) { + call_nodes.push_back(n); + } + } + + for (Node* n : call_nodes) { + TF_RETURN_IF_ERROR(PostprocessLiftedArgsForCall( + outside_compilation_attr_to_node, g, n, fld)); } return Status::OK(); @@ -1065,9 +1307,9 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, } // Builds XlaSendToHost node which sends cond predicate to host. -xla::StatusOr BuildSendIfPredNode(const string& name, - const string& host_transfer_key, - Node* pred_node, Graph* g) { +TF_ATTRIBUTE_NOINLINE xla::StatusOr BuildSendIfPredNode( + const string& name, const string& host_transfer_key, Node* pred_node, + Graph* g) { NodeDefBuilder send_pred_builder(name, "XlaSendToHost"); send_pred_builder.Attr("Tinput", DT_BOOL); send_pred_builder.Attr("key", absl::StrCat(host_transfer_key, "_dtoh_0")); @@ -1130,15 +1372,13 @@ Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, } // Builds host side graph for If node. -Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name, - const string& outside_compilation_attr_name, - const string& xla_cluster_name, - const string& if_node_name, - const string& host_transfer_key, - const string& host_graph_func_name, - FunctionLibraryDefinition* fld, - const string& then_branch_host_func_name, - const string& else_branch_host_func_name) { +TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const string& if_node_name, const string& host_transfer_key, + const string& host_graph_func_name, FunctionLibraryDefinition* fld, + const string& then_branch_host_func_name, + const string& else_branch_host_func_name) { Graph host_graph(fld); string outside_compilation_name = absl::StrCat("oc_if_", if_node_name); AttrValue device_ordinal_value; @@ -1215,10 +1455,9 @@ Status BuildHostGraphForIfNode(const string& xla_cluster_attr_name, } // Rewrites loop cond to add a node which sends loop cond to host. 
-Status AddSendLoopPredToLoopCond(FunctionLibraryDefinition* fld, - const NameAttrList& loop_cond_func, - const string& while_node_name, - const string& host_transfer_key) { +TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond( + FunctionLibraryDefinition* fld, const NameAttrList& loop_cond_func, + const string& while_node_name, const string& host_transfer_key) { // Instantiate the loop cond function. std::unique_ptr fbody; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld->Find(loop_cond_func.name()), @@ -1406,7 +1645,7 @@ Status RewriteHostWhileLoopBody( } // Builds host side graph for while node. -Status BuildHostGraphForWhileNode( +TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const string& while_node_name, const string& host_transfer_key, @@ -1503,10 +1742,6 @@ Status BuildHostGraphForFuncCallNode( call_builder.Attr(kXlaHasHostTransferAttrName, true); call_builder.Attr(xla_cluster_attr_name, xla_cluster_name); call_builder.Attr(outside_compilation_attr_name, call_builder.node_name()); - // Make sure control outputs of this function call node will be respected when - // this node is lowered. - call_builder.Attr(LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr, - true); NodeDef call_def; TF_RETURN_IF_ERROR(call_builder.Finalize(&call_def)); Status s; @@ -1529,6 +1764,221 @@ Status BuildHostGraphForFuncCallNode( return Status::OK(); } +TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const std::map& host_compute_core, Graph* g, Node* n, + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + std::vector* host_graphs, + std::vector* shape_inference_graphs, + bool* has_outside_compilation) { + bool func_has_outside_compilation = false; + NameAttrList func; + if (fld->Contains(n->type_string())) { + func.set_name(n->type_string()); + typedef protobuf::Map AttrMap; + *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end()); + } else if (n->IsPartitionedCall()) { + TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "f", &func)); + } else { + TF_RET_CHECK(n->type_string() == FunctionLibraryDefinition::kGradientOp); + func.set_name(FunctionLibraryDefinition::kGradientOp); + *func.mutable_attr() = n->def().attr(); + } + string new_func_name = absl::StrCat(n->name(), "_oc"); + string host_func_name = absl::StrCat("oc_func_call_host_", n->name()); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + func, new_func_name, host_func_name, host_compute_core, flr, fld, + shape_inference_graphs, &func_has_outside_compilation)); + + // If the function call does not have outside compilation, nothing to do. + if (!func_has_outside_compilation) { + return Status::OK(); + } + + *has_outside_compilation = true; + + // Change `n` to call the new function directly. 
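+  // Inputs are first collected into a vector indexed by dst_input() and only
+  // then fed to the builder, so the rebuilt node keeps the original input
+  // order even though in_edges() does not necessarily yield edges in that
+  // order.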
+ auto replace_builder = + absl::make_unique(n->name(), new_func_name, fld); + std::vector inputs(n->num_inputs()); + for (const Edge* e : n->in_edges()) { + if (e->IsControlEdge()) { + continue; + } + + TF_RET_CHECK(e->dst_input() >= 0 && e->dst_input() < inputs.size()); + inputs[e->dst_input()] = + NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(), + e->src()->output_type(e->src_output())}; + } + for (const auto& input : inputs) { + replace_builder->Input(input); + } + for (const auto& attr : n->attrs()) { + replace_builder->Attr(attr.first, attr.second); + } + auto replace_def = absl::make_unique(); + TF_RETURN_IF_ERROR(replace_builder->Finalize(replace_def.get())); + TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, *replace_def)); + replace->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{kXlaTokenArgNodeName}); + + // Build host side graph for the function call. + string oc_host_graph_name = + absl::StrCat("oc_func_host_graph_", replace->name()); + TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode( + xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name, + replace->name(), host_func_name, oc_host_graph_name, fld)); + + // Record the host graph. + host_graphs->push_back(oc_host_graph_name); + + return Status::OK(); +} + +Status ExtractOutsideCompilationForIfNode( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const std::map& host_compute_core, Graph* g, Node* n, + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + std::vector* host_graphs, + std::vector* shape_inference_graphs, + bool* has_outside_compilation) { + // Instantiate "then_branch" and "else_branch". + NameAttrList then_branch, else_branch; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch)); + + // Extract outside compilation for then_branch and else_branch. + bool then_branch_has_outside_compilation = false; + bool else_branch_has_outside_compilation = false; + string then_branch_host_func_name = + absl::StrCat("oc_then_branch_host_if_", n->name()), + else_branch_host_func_name = + absl::StrCat("oc_else_branch_host_if_", n->name()); + string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"), + else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc"); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + then_branch, then_branch_xla_func_name, then_branch_host_func_name, + host_compute_core, flr, fld, shape_inference_graphs, + &then_branch_has_outside_compilation)); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + else_branch, else_branch_xla_func_name, else_branch_host_func_name, + host_compute_core, flr, fld, shape_inference_graphs, + &else_branch_has_outside_compilation)); + + // If then/else branch do not have outside compilation, nothing to do. + if (!then_branch_has_outside_compilation && + !else_branch_has_outside_compilation) { + return Status::OK(); + } + + *has_outside_compilation = true; + + // Change If node to call the new functions. 
+ then_branch.set_name(then_branch_xla_func_name); + n->ClearAttr("then_branch"); + n->AddAttr("then_branch", then_branch); + else_branch.set_name(else_branch_xla_func_name); + n->ClearAttr("else_branch"); + n->AddAttr("else_branch", else_branch); + + string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); + + // XLA computation: add a SendToHost node to send cond predicate. + Node* pred_node; + TF_RETURN_IF_ERROR(n->input_node(0, &pred_node)); + TF_ASSIGN_OR_RETURN( + Node * send_pred_node, + BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()), + host_transfer_key, pred_node, g)); + n->AddAttr(kXlaTokenInputNodesAttrName, + std::vector{send_pred_node->name()}); + + // Add a control edge from `send_pred_node` to If node, so XlaCompiler will + // visit If node after `send_pred_node`, thus the token output for + // `send_pred_node` has been generated. + g->AddControlEdge(send_pred_node, n); + + // Build host side graph for the "If" node. + string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); + TF_RETURN_IF_ERROR(BuildHostGraphForIfNode( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + n->name(), host_transfer_key, oc_host_graph_name, fld, + then_branch_host_func_name, else_branch_host_func_name)); + host_graphs->push_back(oc_host_graph_name); + + return Status::OK(); +} + +Status ExtractOutsideCompilationForWhileNode( + const string& xla_cluster_attr_name, + const string& outside_compilation_attr_name, const string& xla_cluster_name, + const std::map& host_compute_core, Graph* g, Node* n, + FunctionLibraryRuntime* flr, FunctionLibraryDefinition* fld, + std::vector* host_graphs, + std::vector* shape_inference_graphs, + bool* has_outside_compilation) { + // Instantiate "cond" and "body". + NameAttrList cond, body; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond)); + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body)); + + // Extract outside compilation for cond and body. + bool cond_has_outside_compilation = false; + bool body_has_outside_compilation = false; + string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()), + body_host_func_name = absl::StrCat("oc_body_host_while_", n->name()); + string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), + body_xla_func_name = absl::StrCat(body.name(), "_oc"); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr, + fld, shape_inference_graphs, &cond_has_outside_compilation)); + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, + body, body_xla_func_name, body_host_func_name, host_compute_core, flr, + fld, shape_inference_graphs, &body_has_outside_compilation)); + + // If cond/body do not have outside compilation, nothing to do. + if (!cond_has_outside_compilation && !body_has_outside_compilation) { + return Status::OK(); + } + + *has_outside_compilation = true; + + // Change While node to call the new functions. + cond.set_name(cond_xla_func_name); + n->ClearAttr("cond"); + n->AddAttr("cond", cond); + body.set_name(body_xla_func_name); + n->ClearAttr("body"); + n->AddAttr("body", body); + + string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); + + // XLA computation: rewrite cond function to add a SendToHost node to send + // loop predicate. 
+  TF_RETURN_IF_ERROR(
+      AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key));
+  n->AddAttr(kXlaTokenInputNodesAttrName,
+             std::vector<string>{kXlaTokenArgNodeName});
+
+  // Build host side graph for the "While" node.
+  string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name());
+  TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode(
+      xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
+      n->name(), host_transfer_key, oc_host_graph_name, fld,
+      cond_host_func_name, body_host_func_name));
+  host_graphs->push_back(oc_host_graph_name);
+
+  return Status::OK();
+}
+
 Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
     Graph* g, const string& xla_cluster_attr_name,
     const string& outside_compilation_attr_name, const string& xla_cluster_name,
@@ -1540,193 +1990,32 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions(
   for (Node* n : g->nodes()) {
     if (n->IsIfNode()) {
       if_nodes.push_back(n);
-    } else if (n->type_string() == "While") {
+    } else if (n->IsWhileNode()) {
       while_nodes.push_back(n);
-    } else if (fld->Contains(n->type_string())) {
+    } else if (IsFunctionCall(*fld, *n)) {
       func_call_nodes.push_back(n);
-    } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) {
-      // Only gradient for user-defined function should be considered as
-      // function call node.
-      NameAttrList original_func;
-      TF_RETURN_IF_ERROR(GetNodeAttr(
-          n->def(), FunctionLibraryDefinition::kFuncAttr, &original_func));
-      if (fld->Contains(original_func.name())) {
-        func_call_nodes.push_back(n);
-      }
     }
   }
 
   for (Node* n : func_call_nodes) {
-    // Extract outside compilation for the function call.
-    bool func_has_outside_compilation = false;
-    NameAttrList func;
-    func.set_name(n->type_string());
-    typedef protobuf::Map<string, AttrValue> AttrMap;
-    *func.mutable_attr() = AttrMap(n->attrs().begin(), n->attrs().end());
-    string new_func_name = absl::StrCat(n->name(), "_oc");
-    string host_func_name = absl::StrCat("oc_func_call_host_", n->name());
-    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction(
+    TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFuncCallNode(
         xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name,
-        func, new_func_name, host_func_name, host_compute_core, flr, fld,
-        shape_inference_graphs, &func_has_outside_compilation));
-
-    // If the function call does not have outside compilation, nothing to do.
-    if (!func_has_outside_compilation) {
-      continue;
-    }
-
-    *has_outside_compilation = true;
-
-    // Change `n` to call the new function directly.
-    NodeDefBuilder replace_builder(n->name(), new_func_name, fld);
-    for (const Edge* e : n->in_edges()) {
-      if (e->IsControlEdge()) {
-        continue;
-      }
-      replace_builder.Input(e->src()->name(), e->src_output(),
-                            e->src()->output_type(e->src_output()));
-    }
-    for (const auto& attr : n->attrs()) {
-      replace_builder.Attr(attr.first, attr.second);
-    }
-    NodeDef replace_def;
-    TF_RETURN_IF_ERROR(replace_builder.Finalize(&replace_def));
-    TF_ASSIGN_OR_RETURN(Node * replace, ReplaceNode(g, n, replace_def));
-    replace->AddAttr(kXlaTokenInputNodesAttrName,
-                     std::vector<string>{kXlaTokenArgNodeName});
-
-    // Build host side graph for the function call.
-    string oc_host_graph_name =
-        absl::StrCat("oc_func_host_graph_", replace->name());
-    TF_RETURN_IF_ERROR(BuildHostGraphForFuncCallNode(
-        xla_cluster_attr_name, xla_cluster_name, outside_compilation_attr_name,
-        replace->name(), host_func_name, oc_host_graph_name, fld));
-
-    // Record the host graph.
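// [Editor's note] Illustrative only, not part of the patch: the token-input
// attribute set by these helpers is a plain list-of-strings attr, so a later
// pass can recover the ordering constraint with an ordinary attribute read.
std::vector<string> token_input_nodes;
TF_CHECK_OK(
    GetNodeAttr(n->attrs(), kXlaTokenInputNodesAttrName, &token_input_nodes));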
- host_graphs->push_back(oc_host_graph_name); + host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs, + has_outside_compilation)); } for (Node* n : if_nodes) { - // Instantiate "then_branch" and "else_branch". - NameAttrList then_branch, else_branch; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "then_branch", &then_branch)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "else_branch", &else_branch)); - - // Extract outside compilation for then_branch and else_branch. - bool then_branch_has_outside_compilation = false; - bool else_branch_has_outside_compilation = false; - string then_branch_host_func_name = - absl::StrCat("oc_then_branch_host_if_", n->name()), - else_branch_host_func_name = - absl::StrCat("oc_else_branch_host_if_", n->name()); - string then_branch_xla_func_name = absl::StrCat(then_branch.name(), "_oc"), - else_branch_xla_func_name = absl::StrCat(else_branch.name(), "_oc"); - TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForIfNode( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - then_branch, then_branch_xla_func_name, then_branch_host_func_name, - host_compute_core, flr, fld, shape_inference_graphs, - &then_branch_has_outside_compilation)); - TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( - xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - else_branch, else_branch_xla_func_name, else_branch_host_func_name, - host_compute_core, flr, fld, shape_inference_graphs, - &else_branch_has_outside_compilation)); - - // If then/else branch do not have outside compilation, nothing to do. - if (!then_branch_has_outside_compilation && - !else_branch_has_outside_compilation) { - continue; - } - - *has_outside_compilation = true; - - // Change If node to call the new functions. - then_branch.set_name(then_branch_xla_func_name); - n->ClearAttr("then_branch"); - n->AddAttr("then_branch", then_branch); - else_branch.set_name(else_branch_xla_func_name); - n->ClearAttr("else_branch"); - n->AddAttr("else_branch", else_branch); - - string host_transfer_key = absl::StrCat("oc_if_pred_", n->name()); - - // XLA computation: add a SendToHost node to send cond predicate. - Node* pred_node; - TF_RETURN_IF_ERROR(n->input_node(0, &pred_node)); - TF_ASSIGN_OR_RETURN( - Node * send_pred_node, - BuildSendIfPredNode(absl::StrCat("send_oc_if_pred_", n->name()), - host_transfer_key, pred_node, g)); - n->AddAttr(kXlaTokenInputNodesAttrName, - std::vector{send_pred_node->name()}); - - // Add a control edge from `send_pred_node` to If node, so XlaCompiler will - // visit If node after `send_pred_node`, thus the token output for - // `send_pred_node` has been generated. - g->AddControlEdge(send_pred_node, n); - - // Build host side graph for the "If" node. - string oc_host_graph_name = absl::StrCat("oc_if_host_graph_", n->name()); - TF_RETURN_IF_ERROR(BuildHostGraphForIfNode( - xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - n->name(), host_transfer_key, oc_host_graph_name, fld, - then_branch_host_func_name, else_branch_host_func_name)); - host_graphs->push_back(oc_host_graph_name); + host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs, + has_outside_compilation)); } for (Node* n : while_nodes) { - // Instantiate "cond" and "body". - NameAttrList cond, body; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "cond", &cond)); - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "body", &body)); - - // Extract outside compilation for cond and body. 
- bool cond_has_outside_compilation = false; - bool body_has_outside_compilation = false; - string cond_host_func_name = absl::StrCat("oc_cond_host_while_", n->name()), - body_host_func_name = absl::StrCat("oc_body_host_while_", n->name()); - string cond_xla_func_name = absl::StrCat(cond.name(), "_oc"), - body_xla_func_name = absl::StrCat(body.name(), "_oc"); - TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( + TF_RETURN_IF_ERROR(ExtractOutsideCompilationForWhileNode( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - cond, cond_xla_func_name, cond_host_func_name, host_compute_core, flr, - fld, shape_inference_graphs, &cond_has_outside_compilation)); - TF_RETURN_IF_ERROR(ExtractOutsideCompilationForFunction( - xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - body, body_xla_func_name, body_host_func_name, host_compute_core, flr, - fld, shape_inference_graphs, &body_has_outside_compilation)); - - // If cond/body do not have outside compilation, nothing to do. - if (!cond_has_outside_compilation && !body_has_outside_compilation) { - continue; - } - - *has_outside_compilation = true; - - // Change While node to call the new functions. - cond.set_name(cond_xla_func_name); - n->ClearAttr("cond"); - n->AddAttr("cond", cond); - body.set_name(body_xla_func_name); - n->ClearAttr("body"); - n->AddAttr("body", body); - - string host_transfer_key = absl::StrCat("oc_while_pred_", n->name()); - - // XLA computation: rewrite cond function to add a SendToHost node to send - // loop predicate. - TF_RETURN_IF_ERROR( - AddSendLoopPredToLoopCond(fld, cond, n->name(), host_transfer_key)); - n->AddAttr(kXlaTokenInputNodesAttrName, - std::vector{kXlaTokenArgNodeName}); - - // Build host side graph for the "While" node. - string oc_host_graph_name = absl::StrCat("oc_while_host_graph_", n->name()); - TF_RETURN_IF_ERROR(BuildHostGraphForWhileNode( - xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, - n->name(), host_transfer_key, oc_host_graph_name, fld, - cond_host_func_name, body_host_func_name)); - host_graphs->push_back(oc_host_graph_name); + host_compute_core, g, n, flr, fld, host_graphs, shape_inference_graphs, + has_outside_compilation)); } return Status::OK(); @@ -1889,11 +2178,11 @@ Status ExtractOutsideCompilationForFunction( // Encapsulate outside_compilation cluster into function call node. std::unique_ptr graph_out; - RewriteOutsideCompilationSubgraphFn rewrite_fn( + auto rewrite_fn = absl::make_unique( xla_cluster_attr_name, outside_compilation_attr_name, xla_cluster_name, new_func_name); TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( - outside_compilation_attr_name, *fbody->graph, rewrite_fn, + outside_compilation_attr_name, *fbody->graph, *rewrite_fn, /*reuse_existing_functions=*/true, &graph_out, fld)); // Replace outside_compilation function nodes with HostCompute ops. @@ -1908,26 +2197,26 @@ Status ExtractOutsideCompilationForFunction( // If we could not infer shapes for XlaSendFromHost inputs statically, we // will set the "shape_inference_graph" attribute. In that case, copy // outside compilation subgraph as shape inference graph in `fld`. 
- NameAttrList shape_inference_graph; + auto shape_inference_graph = absl::make_unique(); TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "shape_inference_graph", - &shape_inference_graph)); - if (!shape_inference_graph.name().empty()) { - shape_inference_graphs->push_back(shape_inference_graph.name()); + shape_inference_graph.get())); + if (!shape_inference_graph->name().empty()) { + shape_inference_graphs->push_back(shape_inference_graph->name()); shape_inference_graphs_to_rewrite.push_back( - shape_inference_graph.name()); + shape_inference_graph->name()); const FunctionDef* xla_fdef = fld->Find(n->name()); if (!xla_fdef) { return errors::Internal("Cannot find XLA function ", n->name()); } - FunctionDef shape_inference_fdef = *xla_fdef; - shape_inference_fdef.mutable_signature()->set_name( - shape_inference_graph.name()); - if (fld->Find(shape_inference_graph.name())) { - TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph.name(), - shape_inference_fdef)); + auto shape_inference_fdef = absl::make_unique(*xla_fdef); + shape_inference_fdef->mutable_signature()->set_name( + shape_inference_graph->name()); + if (fld->Find(shape_inference_graph->name())) { + TF_RETURN_IF_ERROR(fld->ReplaceFunction(shape_inference_graph->name(), + *shape_inference_fdef)); } else { - TF_RETURN_IF_ERROR(fld->AddFunctionDef(shape_inference_fdef)); + TF_RETURN_IF_ERROR(fld->AddFunctionDef(*shape_inference_fdef)); } } } @@ -1972,15 +2261,15 @@ Status ExtractOutsideCompilationForFunction( TF_RETURN_IF_ERROR( ConstructHostGraph(xla_cluster_name, outside_compilation_attr_name, outside_compilation_host_graphs, fld, &host_graph)); - FunctionDef host_graph_fdef; + auto host_graph_fdef = absl::make_unique(); TF_RETURN_IF_ERROR(GraphToFunctionDef(*host_graph, host_graph_func_name, HostGraphControlRetMapping, - &host_graph_fdef)); + host_graph_fdef.get())); if (fld->Find(host_graph_func_name)) { TF_RETURN_IF_ERROR( - fld->ReplaceFunction(host_graph_func_name, host_graph_fdef)); + fld->ReplaceFunction(host_graph_func_name, *host_graph_fdef)); } else { - TF_RETURN_IF_ERROR(fld->AddFunctionDef(host_graph_fdef)); + TF_RETURN_IF_ERROR(fld->AddFunctionDef(*host_graph_fdef)); } // Shape inference graphs might contain Placeholder nodes for outside @@ -1999,19 +2288,19 @@ Status ExtractOutsideCompilationForFunction( } // Replace original function. 
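// [Editor's note] A minimal sketch (not part of the patch) of the pattern this
// hunk and the ones above apply throughout the pass: large protos such as
// FunctionDef are moved off the stack and heap-allocated, then passed by
// pointer to builders and dereferenced where the existing by-value API is
// kept. Names (graph_out, new_func_name, fld) are taken from the surrounding
// code.
auto fdef = absl::make_unique<FunctionDef>();  // heap, not stack
TF_CHECK_OK(GraphToFunctionDef(*graph_out, new_func_name, fdef.get()));
TF_CHECK_OK(fld->AddFunctionDef(*fdef));       // dereference at the call site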
- FunctionDef updated_fdef; + auto updated_fdef = absl::make_unique(); TF_RETURN_IF_ERROR( - GraphToFunctionDef(*graph_out, new_func_name, &updated_fdef)); + GraphToFunctionDef(*graph_out, new_func_name, updated_fdef.get())); const FunctionDef* original_fdef = fld->Find(func_name); if (original_fdef) { for (const auto& attr : original_fdef->attr()) { - (*updated_fdef.mutable_attr())[attr.first] = attr.second; + (*updated_fdef->mutable_attr())[attr.first] = attr.second; } } if (fld->Find(new_func_name)) { - TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, updated_fdef)); + TF_RETURN_IF_ERROR(fld->ReplaceFunction(new_func_name, *updated_fdef)); } else { - TF_RETURN_IF_ERROR(fld->AddFunctionDef(updated_fdef)); + TF_RETURN_IF_ERROR(fld->AddFunctionDef(*updated_fdef)); } if (VLOG_IS_ON(4)) { DumpGraphToFile( diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index f69a28b71b3..53f9b70c876 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -105,6 +105,8 @@ void AllocateAndParseFlags() { build_ops_flags = new BuildXlaOpsPassFlags; build_ops_flags->tf_xla_enable_lazy_compilation = true; build_ops_flags->tf_xla_print_cluster_outputs = false; + build_ops_flags->tf_xla_check_cluster_input_numerics = false; + build_ops_flags->tf_xla_check_cluster_output_numerics = false; build_ops_flags->tf_xla_disable_constant_folding = false; mark_for_compilation_flags = new MarkForCompilationPassFlags; @@ -144,6 +146,14 @@ void AllocateAndParseFlags() { &build_ops_flags->tf_xla_print_cluster_outputs, "If true then insert Print nodes to print out values produced by " "XLA clusters."), + Flag("tf_xla_check_cluster_input_numerics", + &build_ops_flags->tf_xla_check_cluster_input_numerics, + "If true then insert CheckNumerics nodes to to check all cluster " + "inputs."), + Flag("tf_xla_check_cluster_output_numerics", + &build_ops_flags->tf_xla_check_cluster_output_numerics, + "If true then insert CheckNumerics nodes to to check all cluster " + "outputs."), Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand, "Switch a device into 'on-demand' mode, where instead of " diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 91e93f30d11..9307874133c 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -103,6 +103,14 @@ struct BuildXlaOpsPassFlags { // clusters. Useful for debugging. bool tf_xla_print_cluster_outputs; + // If true, insert CheckNumerics nodes for every floating point typed input to + // an XLA cluster. + bool tf_xla_check_cluster_input_numerics; + + // If true, insert CheckNumerics nodes for every floating point typed output + // from an XLA cluster. + bool tf_xla_check_cluster_output_numerics; + // Disables all constant folding. The primary use for this is for testing to // guarantee that tests are run on XLA and not on TF's CPU implementation. bool tf_xla_disable_constant_folding; diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 127f0d4a82e..4773e8dc562 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/build_xla_ops_pass.h" #include "tensorflow/compiler/jit/clone_constants_for_better_clustering.h" +#include "tensorflow/compiler/jit/cluster_scoping_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" @@ -50,6 +51,9 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25, REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 5, CloneConstantsForBetterClusteringPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 9, + ClusterScopingPass); + REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 49b8731ca0b..e09dfd2b49c 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -5,33 +5,48 @@ package( licenses = ["notice"], # Apache 2.0 ) +XLA_OPS_DEPS = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", + "//tensorflow/compiler/jit:common", + "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/jit:xla_activity_listener", + "//tensorflow/compiler/jit:xla_activity_proto_cc", + "//tensorflow/compiler/jit:xla_compilation_cache", + "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration", + "//tensorflow/compiler/jit:xla_launch_util", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:tf2xla_util", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:state_ops_op_lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:tf_allocator_adapter", +] + +# Linked by tensorflow core, without registration of jit compilation passes. 
cc_library( - name = "xla_ops", + name = "xla_ops_no_jit_rewrite_registration", srcs = ["xla_ops.cc"], hdrs = ["xla_ops.h"], - deps = [ - "//tensorflow/compiler/jit:common", - "//tensorflow/compiler/jit:flags", - "//tensorflow/compiler/jit:xla_compilation_cache", - "//tensorflow/compiler/jit:xla_device", - "//tensorflow/compiler/jit:xla_launch_util", - "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:tf2xla_util", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:state_ops_op_lib", - "//tensorflow/core:stream_executor_no_cuda", - "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/stream_executor:tf_allocator_adapter", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/memory", + deps = XLA_OPS_DEPS, + alwayslink = 1, +) + +cc_library( + name = "xla_ops", + hdrs = ["xla_ops.h"], + deps = XLA_OPS_DEPS + [ + ":xla_ops_no_jit_rewrite_registration", + "//tensorflow/compiler/jit:jit_compilation_passes", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 788e90ffe99..fabd0374013 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -62,8 +63,7 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { DeviceType device_type = ctx->device_type(); se::Platform::Id platform_id = nullptr; const XlaDevice::Metadata* xla_device_metadata = nullptr; - std::unique_ptr xla_allocator; - se::DeviceMemoryAllocator* device_allocator = nullptr; + se::DeviceMemoryAllocator* custom_allocator = nullptr; if (ctx->device_type() == DeviceType(DEVICE_CPU)) { platform_id = se::host::kHostPlatformId; @@ -83,23 +83,13 @@ XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { // (which xla_allocator above uses) as on an XlaDevice, this is a dummy // allocator that returns XlaTensor objects. The XlaCompiler needs a real // allocator to allocate real buffers. - platform_id = xla_device_metadata->platform()->id(); - device_allocator = + custom_allocator = xla_device_metadata->client()->backend().memory_allocator(); } - if (!device_allocator) { - xla::StatusOr maybe_platform = - se::MultiPlatformManager::PlatformWithId(platform_id); - OP_REQUIRES_OK_RETURN(ctx, XlaPlatformInfo(), maybe_platform.status()); - - xla_allocator = absl::make_unique( - maybe_platform.ValueOrDie(), ctx->device()->GetAllocator({})); - } - return XlaPlatformInfo(device_type, platform_id, xla_device_metadata, - std::move(xla_allocator), device_allocator); + custom_allocator); } // A closure describing how to run a compiled version of a TensorFlow function. 
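// [Editor's note] Illustrative sketch only, not part of the patch: after this
// change XlaPlatformInfo no longer owns an se::TfAllocatorAdapter. Regular
// CPU/GPU devices leave custom_allocator() null, and the adapter is built
// lazily at execution time from the device allocator, as the GetAllocator
// helper added in the next hunk does. `ctx` and `platform` are assumed to be
// in scope.
absl::optional<se::TfAllocatorAdapter> adapter;
adapter.emplace(ctx->device()->GetAllocator({}), platform);
se::DeviceMemoryAllocator* allocator = &adapter.value();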
@@ -184,6 +174,31 @@ class XlaExecutableClosureStore { TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); }; +// Return allocator from platform info if non-null, or populate and return a +// pointer to the allocator adapter with allocator from context. +// +// This is necessary because for XLA devices the underlying TF allocator returns +// dummy tensors. +se::DeviceMemoryAllocator* GetAllocator( + absl::optional* tf_allocator_adapter, + OpKernelContext* ctx, const XlaPlatformInfo& platform_info) { + if (platform_info.custom_allocator()) { + return platform_info.custom_allocator(); + } + if (!ctx->op_device_context()) { + // Stream is not set for the host platform. + se::Platform* platform = + se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()) + .ValueOrDie(); + tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform); + return &tf_allocator_adapter->value(); + } + // platform_info. + tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), + ctx->op_device_context()->stream()); + return &tf_allocator_adapter->value(); +} + } // namespace XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, @@ -280,6 +295,7 @@ static Status CompileToLocalExecutable( TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables)); *client = static_cast(cache->client()); + absl::optional tf_allocator_adapter; XlaCompiler::Options options; options.client = *client; if (ctx->op_device_context() != nullptr) { @@ -291,7 +307,8 @@ static Status CompileToLocalExecutable( options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_info.platform_id() == se::host::kHostPlatformId); - options.device_allocator = platform_info.allocator(); + options.device_allocator = + GetAllocator(&tf_allocator_adapter, ctx, platform_info); if (platform_info.xla_device_metadata()) { options.shape_representation_fn = platform_info.xla_device_metadata()->shape_representation_fn(); @@ -349,8 +366,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(1) << "Executing XLA Computation..."; + absl::optional tf_allocator_adapter; + se::DeviceMemoryAllocator* allocator = + GetAllocator(&tf_allocator_adapter, ctx, platform_info_); XlaComputationLaunchContext launch_context( - client, platform_info_.allocator(), + client, allocator, /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), platform_info_.UseMultipleStreams()); launch_context.PopulateInputs(ctx, kernel, variables, @@ -360,21 +380,28 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; run_options.set_stream(stream); - run_options.set_allocator(platform_info_.allocator()); + run_options.set_allocator(allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); Env* env = Env::Default(); auto start_time = env->NowMicros(); - auto run_result = executable->Run(launch_context.arguments(), run_options); + xla::StatusOr run_result; + if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) { + run_result = executable->Run(launch_context.arguments(), run_options); + } else { + run_result = executable->RunAsync(launch_context.arguments(), run_options); + } OP_REQUIRES(ctx, run_result.ok(), run_result.status()); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; + const xla::HloInputOutputAliasConfig& input_output_alias = + 
executable->executable()->module().input_output_alias_config(); OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( ctx, kernel, run_result.ConsumeValueOrDie(), - /*missing_ctx_input_prefix=*/0)); + /*missing_ctx_input_prefix=*/0, input_output_alias)); VLOG(1) << "Done"; } @@ -467,6 +494,10 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { if (status.code() == error::UNIMPLEMENTED) { LOG(WARNING) << "Compilation failed:" << status.ToString() << ". Falling back to TF function call."; + + BroadcastOptimizationRemark( + XlaOptimizationRemark::UNIMPLEMENTED_OPERATION, status.ToString()) + .IgnoreError(); executable = nullptr; mutex_lock guard(cannot_compile_cluster_mu_); cannot_compile_cluster_ = true; @@ -498,7 +529,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { client, executable, kernel, std::move(variables), constants_.size())); Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); - compilation_key.flat()(0) = key; + compilation_key.flat()(0) = key; Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); compilation_successful.flat()(0) = true; @@ -513,13 +544,16 @@ XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); Tensor key_tensor = ctx->input(ctx->num_inputs() - 1); - const XlaExecutableClosureStore::KeyT& key = key_tensor.flat()(0); + const XlaExecutableClosureStore::KeyT& key = key_tensor.flat()(0); XlaExecutableClosure closure = XlaExecutableClosureStore::Global()->Consume(key); + absl::optional tf_allocator_adapter; + se::DeviceMemoryAllocator* allocator = + GetAllocator(&tf_allocator_adapter, ctx, platform_info_); XlaComputationLaunchContext launch_context( - closure.client(), platform_info_.allocator(), + closure.client(), allocator, /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), /*use_multiple_streams=*/platform_info_.UseMultipleStreams()); @@ -544,19 +578,28 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; xla::ExecutableRunOptions run_options; run_options.set_stream(stream); - run_options.set_allocator(platform_info_.allocator()); + run_options.set_allocator(allocator); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); Env* env = Env::Default(); auto start_time = env->NowMicros(); - auto run_result = - closure.executable()->Run(launch_context.arguments(), run_options); + xla::StatusOr run_result; + if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) { + run_result = + closure.executable()->Run(launch_context.arguments(), run_options); + } else { + run_result = + closure.executable()->RunAsync(launch_context.arguments(), run_options); + } OP_REQUIRES(ctx, run_result.ok(), run_result.status()); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time in computation: " << elapsed << "us"; + const xla::HloInputOutputAliasConfig& input_output_alias = + closure.executable()->executable()->module().input_output_alias_config(); + tensorflow::profiler::TraceMe hlo_module_activity( [&] { return absl::StrCat("Populate Outputs (", ctx->num_outputs(), ")"); @@ -567,7 +610,8 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { ctx, launch_context.PopulateOutputs( ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(), - /*missing_ctx_input_prefix=*/closure.num_constant_args())); + /*missing_ctx_input_prefix=*/closure.num_constant_args(), + input_output_alias)); } REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index 3a1009ec8a7..bc6829a6c77 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -37,18 +37,14 @@ class XlaPlatformInfo { public: XlaPlatformInfo() : device_type_("") {} XlaPlatformInfo(XlaPlatformInfo&&) = default; - explicit XlaPlatformInfo( - const DeviceType device_type, se::Platform::Id platform_id, - const XlaDevice::Metadata* xla_device_metadata, - std::unique_ptr xla_allocator, - se::DeviceMemoryAllocator* device_allocator) + explicit XlaPlatformInfo(const DeviceType device_type, + se::Platform::Id platform_id, + const XlaDevice::Metadata* xla_device_metadata, + se::DeviceMemoryAllocator* device_allocator) : device_type_(device_type), platform_id_(platform_id), xla_device_metadata_(xla_device_metadata), - xla_allocator_(std::move(xla_allocator)), - device_allocator_(device_allocator) { - CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr)); - } + device_allocator_(device_allocator) {} XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; @@ -56,9 +52,11 @@ class XlaPlatformInfo { return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); } - se::DeviceMemoryAllocator* allocator() const { - return device_allocator_ ? device_allocator_ : xla_allocator_.get(); + // Non-null only when run on an XLA device. + se::DeviceMemoryAllocator* custom_allocator() const { + return device_allocator_; } + DeviceType device_type() const { return device_type_; } // This is equal to xla_device_metadata()->platform()->id() if @@ -82,11 +80,8 @@ class XlaPlatformInfo { const XlaDevice::Metadata* xla_device_metadata_; // If the op associated with this XlaPlatformInfo is placed on an XLA device - // then device_allocator_ is the xla::Backend's memory allocator and - // xla_allocator_ is null. 
If the op is placed on a regular CPU or GPU device - // then device_allocator_ is null and xla_allocator_ points to an appropriate - // se::TfAllocatorAdapter instance. - std::unique_ptr xla_allocator_; + // then device_allocator_ is the xla::Backend's memory allocator. If the op + // is placed on a regular CPU or GPU device then device_allocator_ is null. se::DeviceMemoryAllocator* device_allocator_; TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index b819998bdc7..90755a1cb70 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -677,8 +677,7 @@ bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation( } DataType dtype; - if (!GetNodeAttr(n->def(), "dtype", &dtype).ok() || - !DataTypeIsInteger(dtype)) { + if (!TryGetNodeAttr(n->def(), "dtype", &dtype) || !DataTypeIsInteger(dtype)) { return false; } @@ -695,7 +694,7 @@ bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation( } const TensorProto* proto = nullptr; - if (!GetNodeAttr(const_input->def(), "value", &proto).ok()) { + if (!TryGetNodeAttr(const_input->def(), "value", &proto)) { return false; } @@ -924,20 +923,35 @@ MarkForCompilationPassImpl::ClusteringWillIntroduceInterDeviceDependency( } absl::optional MarkForCompilationPassImpl::GetXlaScope(Node* node) { - // Look for an _XlaScope on both nodes. If both nodes have a scope and the - // scopes do not match, do not cluster along this edge. This restriction is - // overridden if the global_jit_level_ is ON. If even one of the nodes lacks - // an _XlaScope attribute, then it is treated as a "bridge" and a cluster may - // be created along it. We may want to restrict this behavior to require all - // nodes marked with _XlaCompile=true to also have a _XlaScope property set - // (and raise an error otherwise); but for now we don't do this. - if (global_jit_level_ != OptimizerOptions::OFF) { - return absl::nullopt; - } + // Look for either _XlaScope or _XlaInternalScope on both nodes to guide + // clustering. If both nodes have a scope and the scopes do not match, do + // not cluster along this edge. If even one of the nodes lacks a scope + // attribute, then it is treated as a "bridge" and a cluster may be created + // along it. + // + // The difference between _XlaScope and _XlaInternalScope is that _XlaScope is + // provided by users through jit_scope APIs, while _XlaInternalScope is + // automatically generated by the ClusterScopingPass when auto_jit is on. As + // such, we respect _XlaScope only when auto_jit is off, while respecting + // _XlaInternalScope only when auto_jit is on. + // + // We may want to restrict the _XlaScope behavior to require all nodes marked + // with _XlaCompile=true to also have a _XlaScope property set (and raise an + // error otherwise); but for now we don't do this. - string scope; - if (GetNodeAttr(node->attrs(), kXlaScopeAttr, &scope).ok()) { - return scope; + if (global_jit_level_ != OptimizerOptions::OFF) { + // If global_jit_level_ is ON, respect only _XlaInternalScope. + const string& scope = + GetNodeAttrString(node->attrs(), kXlaInternalScopeAttr); + if (!scope.empty()) { + return scope; + } + } else { + // If global_jit_level_ is OFF, respect only _XlaScope. 
+ const string& scope = GetNodeAttrString(node->attrs(), kXlaScopeAttr); + if (!scope.empty()) { + return scope; + } } return absl::nullopt; @@ -970,8 +984,7 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() { int effective_cluster_size = (node->IsIdentity() || node->IsConstant()) ? 0 : 1; - bool has_functional_control_flow = - node->type_string() == "While" || node->IsIfNode(); + bool has_functional_control_flow = node->IsWhileNode() || node->IsIfNode(); absl::optional deadness_predicate; if (deadness_analysis_) { @@ -1000,7 +1013,7 @@ Status MarkForCompilationPassImpl::BuildInitialClusterSet() { bool is_xla_compile_attr_true = false; bool xla_compile_attr; - if (GetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr).ok()) { + if (TryGetNodeAttr(node->attrs(), kXlaCompileAttr, &xla_compile_attr)) { is_xla_compile_attr_true |= xla_compile_attr; } @@ -1549,9 +1562,7 @@ StatusOr MarkForCompilationPassImpl::ShouldCompileClusterImpl( XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally && global_jit_level_ != OptimizerOptions::OFF); - if (!should_compile && - registration->autoclustering_policy == - XlaOpRegistry::AutoclusteringPolicy::kIfExplicitlyRequested && + if (!should_compile && global_jit_level_ != OptimizerOptions::OFF && device_type.type_string() == DEVICE_CPU) { static std::once_flag once; std::call_once(once, [] { @@ -1628,10 +1639,9 @@ std::atomic* GetPointerToFuel(int64 initial_value) { } } // anonymous namespace -bool IsCompilable( - FunctionLibraryRuntime* flr, const NodeDef& ndef, - std::vector* - uncompilable_node_info) { +bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, + RecursiveCompilabilityChecker::UncompilableNodesMap* + uncompilable_node_info) { Device* device = flr->device(); const XlaOpRegistry::DeviceRegistration* registration; CHECK(XlaOpRegistry::GetCompilationDevice(device->device_type(), @@ -1657,8 +1667,8 @@ bool IsCompilable( return checker.IsCompilableCall(ndef, flr); } - std::vector - uncompilable_node_result = checker.FindUncompilableNodes(ndef, flr); + RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_node_result = + checker.FindUncompilableNodes(ndef, flr); uncompilable_node_info->swap(uncompilable_node_result); return uncompilable_node_info->empty(); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index e186763b5e4..7adfc1419bf 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -52,10 +52,9 @@ class MarkForCompilationPass : public GraphOptimizationPass { // function is compilable iff every operator in the function body is // compilable. If 'ndef' is not compilable and 'uncompilable_node_info' is not // null, we will populate 'uncompilable_node_info' with uncompilable node info. -bool IsCompilable( - FunctionLibraryRuntime* flr, const NodeDef& ndef, - std::vector* - uncompilable_node_info = nullptr); +bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, + RecursiveCompilabilityChecker::UncompilableNodesMap* + uncompilable_node_info = nullptr); namespace testing { // DO NOT USE IN PRODUCTION. 
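// [Editor's note] A hedged sketch (not part of the patch) of the scope rule
// described in GetXlaScope above: with autoclustering ON only the
// ClusterScopingPass-generated _XlaInternalScope is honored; with it OFF only
// the user-provided _XlaScope is. `auto_jit_on` is a hypothetical flag
// standing in for the global_jit_level_ check.
bool auto_jit_on = true;
const string& scope = GetNodeAttrString(
    node->attrs(), auto_jit_on ? kXlaInternalScopeAttr : kXlaScopeAttr);
absl::optional<string> result;
if (!scope.empty()) result = scope;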
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index cbe60b05eef..f10b4d0b4cb 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -52,7 +52,7 @@ std::unordered_map GetClusters(const Graph& graph) { std::unordered_map ids; for (Node* node : graph.nodes()) { string cluster; - if (GetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster).ok()) { + if (TryGetNodeAttr(node->attrs(), kXlaClusterAttr, &cluster)) { CHECK(!cluster.empty()); ids[node->name()] = cluster; } @@ -1718,5 +1718,91 @@ TEST(XlaCompilationTest, UnsupportedEnterExitPattern) { EXPECT_EQ(0, clusters.size()); } +namespace { +Node* MakeStageNode(GraphDefBuilder& builder, string name, + std::initializer_list dtypes, + absl::Span values) { + auto opts = builder.opts() + .WithName(std::move(name)) + .WithAttr("dtypes", std::move(dtypes)); + if (opts.HaveError()) { + return nullptr; + } + + NodeBuilder node_builder(name, "Stage", opts.op_registry()); + node_builder.Input(values); + return opts.FinalizeBuilder(&node_builder); +} +} // namespace + +TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) { + auto build_staged_graph = [](std::unique_ptr* graph) -> Status { + // Construct a graph as below with two pipeline stages and test that nodes + // in different stages will not be merged if ClusterScopingPass is on. + // + // b + // | + // v + // a -> add0 -> relu0 -> stage + // + // b + // | + // v + // unstage -> add1 -> relu1 + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("a") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::SourceOp("Const", builder.opts() + .WithName("b") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* unstage = ops::SourceOp( + "Unstage", + builder.opts().WithName("unstage").WithAttr("dtypes", {DT_FLOAT})); + + Node* add0 = ops::BinaryOp("Add", a, b, builder.opts().WithName("add0")); + Node* add1 = + ops::BinaryOp("Add", unstage, b, builder.opts().WithName("add1")); + Node* relu0 = ops::UnaryOp("Relu", add0, builder.opts().WithName("relu0")); + ops::UnaryOp("Relu", add1, builder.opts().WithName("relu1")); + MakeStageNode(builder, "stage", {DT_FLOAT}, {relu0}); + + return GraphDefBuilderToGraph(builder, graph->get()); + }; + + // All nodes go into the same cluster if ClusterScopingPass is off. + { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(build_staged_graph(&graph)); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation( + &graph, + MarkForCompilationPassTestHelper::Options().WithNoClusterScoping())); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_EQ(clusters["add0"], clusters["add1"]); + EXPECT_EQ(clusters["add0"], clusters["relu1"]); + EXPECT_EQ(clusters["relu0"], clusters["add1"]); + EXPECT_EQ(clusters["relu0"], clusters["relu1"]); + } + + // By default, ClusterScopingPass is on and different pipeline stages should + // not be merged. 
+ { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(build_staged_graph(&graph)); + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + EXPECT_NE(clusters["add0"], clusters["add1"]); + EXPECT_NE(clusters["add0"], clusters["relu1"]); + EXPECT_NE(clusters["relu0"], clusters["add1"]); + EXPECT_NE(clusters["relu0"], clusters["relu1"]); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index fa5abdfe508..44bd7b47d54 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" + +#include "tensorflow/compiler/jit/cluster_scoping_pass.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/public/session_options.h" @@ -48,8 +50,14 @@ namespace tensorflow { opt_options.graph = graph; opt_options.session_options = &session_options; opt_options.flib_def = flib_def; - MarkForCompilationPass pass; - return pass.RunForTest( + + if (options.enable_cluster_scoping) { + ClusterScopingPass cluster_scoping_pass; + TF_RETURN_IF_ERROR(cluster_scoping_pass.Run(opt_options)); + } + + MarkForCompilationPass mark_for_compilation_pass; + return mark_for_compilation_pass.RunForTest( opt_options, /*disable_deadness_analysis=*/options.disable_deadness_analysis); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h index b81fca43c80..f482a80f5b5 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h @@ -24,8 +24,12 @@ class MarkForCompilationPassTestHelper { struct Options { bool enable_global_jit; bool disable_deadness_analysis; + bool enable_cluster_scoping; - Options() : enable_global_jit(true), disable_deadness_analysis(true) {} + Options() + : enable_global_jit(true), + disable_deadness_analysis(true), + enable_cluster_scoping(true) {} Options WithNoGlobalJit() { Options copy = *this; @@ -38,6 +42,12 @@ class MarkForCompilationPassTestHelper { copy.disable_deadness_analysis = false; return copy; } + + Options WithNoClusterScoping() { + Options copy = *this; + copy.enable_cluster_scoping = false; + return copy; + } }; // Runs the MarkForCompilation pass on `graph` after assigning all nodes in diff --git a/tensorflow/compiler/jit/node_matchers.cc b/tensorflow/compiler/jit/node_matchers.cc index b878f05e1df..932e0769813 100644 --- a/tensorflow/compiler/jit/node_matchers.cc +++ b/tensorflow/compiler/jit/node_matchers.cc @@ -135,7 +135,7 @@ struct NodeMatcher : public ::testing::MatcherInterface { if (constant_value) { const TensorProto* proto = nullptr; - if (!GetNodeAttr(node->def(), "value", &proto).ok()) { + if (!TryGetNodeAttr(node->def(), "value", &proto)) { if (listener->IsInterested()) { *listener << "\ncould not find \"value\" attribute in node"; } diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test.cc b/tensorflow/compiler/jit/tests/auto_clustering_test.cc index 2154e371e83..c4db4b082ad 100644 --- 
a/tensorflow/compiler/jit/tests/auto_clustering_test.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test.cc @@ -45,8 +45,8 @@ class AutoClusteringTestImpl : public AutoClusteringTest { TEST_F(AutoClusteringTestImpl, KerasImagenetMain) { // Generated from // - // bazel run -c opt --config=cuda \ - // tensorflow_models/official/resnet/keras:keras_imagenet_main \ + // TARGET_PATH=tensorflow_models/official/vision/image_classification \ + // bazel run -c opt --config=cuda ${TARGET_PATH}:resnet_imagenet_main \ // -- --skip_eval --num_gpus=1 --dtype=fp16 --batch_size=192 \ // --train_steps=210 --enable_xla --enable_eager=true // @@ -57,8 +57,8 @@ TEST_F(AutoClusteringTestImpl, KerasImagenetMain) { TEST_F(AutoClusteringTestImpl, KerasImagenetMainGraphMode) { // Generated from // - // bazel run -c opt --config=cuda \ - // tensorflow_models/official/resnet/keras:keras_imagenet_main \ + // TARGET_PATH=tensorflow_models/official/vision/image_classification \ + // bazel run -c opt --config=cuda ${TARGET_PATH}:resnet_imagenet_main \ // -- --use_synthetic_data --num_gpus=1 --batch_size=117 --train_steps=600 \ // --skip_eval=True --logtostderr --enable_xla TF_ASSERT_OK( diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc index faeb3883b48..726f7f0b068 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc @@ -186,7 +186,7 @@ Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt( /*input_buffer_bytes=*/k_buffer_size, /*output_buffer_bytes=*/k_buffer_size, io::ZlibCompressionOptions::GZIP()); - string decompressed_pbtxt_string; + tstring decompressed_pbtxt_string; Status s = in.ReadNBytes(INT_MAX, &decompressed_pbtxt_string); if (!s.ok() && !errors::IsOutOfRange(s)) { // OutOfRange is fine since we set the number of read bytes to INT_MAX. diff --git a/tensorflow/compiler/jit/xla_activity.proto b/tensorflow/compiler/jit/xla_activity.proto index 1edde32cc46..50bfb297fa1 100644 --- a/tensorflow/compiler/jit/xla_activity.proto +++ b/tensorflow/compiler/jit/xla_activity.proto @@ -94,3 +94,27 @@ message XlaJitCompilationActivity { // Total microseconds spent in (re-)compiling this cluster so far. int64 cumulative_compile_time_us = 4; } + +// LINT.IfChange +// +// Used for logging situations seen in Tensorflow models being optimized that +// are known to not perform well with XLA. +// +// Next ID: 3 +message XlaOptimizationRemark { + // Next ID: 6 + enum Warning { + NONE = 0; + INACCURATE_OPERATION = 1; + SLOW_OPERATION = 2; + UNIMPLEMENTED_OPERATION = 3; + SLOW_IMAGE_RESIZE_DIMENSIONS = 4; + MEGAMORPHIC_FUNCTION = 5; + } + + Warning warning = 1; + + // Information such as which node was the problem. + string debug_information = 2; +} +// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/compiler/jit/xla_activity_listener.h) diff --git a/tensorflow/compiler/jit/xla_activity_listener.cc b/tensorflow/compiler/jit/xla_activity_listener.cc index 1f14cc90527..a1ea6a6bf8e 100644 --- a/tensorflow/compiler/jit/xla_activity_listener.cc +++ b/tensorflow/compiler/jit/xla_activity_listener.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/xla_activity_listener.h" #include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -71,6 +72,21 @@ Status BroadcastXlaActivity( }); } +Status BroadcastOptimizationRemark(XlaOptimizationRemark optimization_remark) { + VLOG(2) << "OptimizationRemark: " << optimization_remark.DebugString(); + return ForEachListener([&](XlaActivityListener* listener) { + return listener->Listen(optimization_remark); + }); +} + +Status BroadcastOptimizationRemark( + XlaOptimizationRemark::Warning optimization_warning, + string debug_information) { + XlaOptimizationRemark remark; + remark.set_warning(optimization_warning); + remark.set_debug_information(std::move(debug_information)); + return BroadcastOptimizationRemark(std::move(remark)); +} void RegisterXlaActivityListener( std::unique_ptr listener) { XlaActivityListenerList* listener_list = GetXlaActivityListenerList(); diff --git a/tensorflow/compiler/jit/xla_activity_listener.h b/tensorflow/compiler/jit/xla_activity_listener.h index 547181d6010..05328c896d3 100644 --- a/tensorflow/compiler/jit/xla_activity_listener.h +++ b/tensorflow/compiler/jit/xla_activity_listener.h @@ -27,6 +27,18 @@ Status BroadcastXlaActivity(XlaAutoClusteringActivity auto_clustering_activity); // Broadcast `jit_compilation_activity` to all the registered listeners. Status BroadcastXlaActivity(XlaJitCompilationActivity jit_compilation_activity); +// Broadcast `jit_compilation_activity` to all the registered listeners. +Status BroadcastOptimizationRemark(XlaOptimizationRemark optimization_remark); + +// LINT.IfChange +// Called after TensorFlow realizes possible lost performance. The parameters in +// this should match all of the values in the XlaOptimizationRemark proto. +Status BroadcastOptimizationRemark( + XlaOptimizationRemark::Warning optimization_warning, + string debug_information); + +// LINT.ThenChange(//tensorflow/compiler/jit/xla_activity.proto) + // Various components of the system can subclass XlaActivityListener to // notifications on auto-clustering and JIT compilation events. // @@ -41,6 +53,9 @@ class XlaActivityListener { virtual Status Listen( const XlaJitCompilationActivity& jit_compilation_activity) = 0; + // Called after TensorFlow realizes possible lost performance. + virtual Status Listen(const XlaOptimizationRemark& optimization_remark) = 0; + // Called at program exit in best-effort manner to give listeners a chance to // flush their state. // diff --git a/tensorflow/compiler/jit/xla_activity_listener_test.cc b/tensorflow/compiler/jit/xla_activity_listener_test.cc index 4d087e2caac..034adbf44fe 100644 --- a/tensorflow/compiler/jit/xla_activity_listener_test.cc +++ b/tensorflow/compiler/jit/xla_activity_listener_test.cc @@ -43,6 +43,10 @@ class TestListener : public XlaActivityListener { return Status::OK(); } + Status Listen(const XlaOptimizationRemark& optimization_remark) override { + return Status::OK(); + } + ~TestListener() override {} const XlaAutoClusteringActivity& auto_clustering_activity() const { diff --git a/tensorflow/compiler/jit/xla_activity_logging_listener.cc b/tensorflow/compiler/jit/xla_activity_logging_listener.cc index a36bd3bd707..87e39a5481f 100644 --- a/tensorflow/compiler/jit/xla_activity_logging_listener.cc +++ b/tensorflow/compiler/jit/xla_activity_logging_listener.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/core/platform/logger.h" @@ -59,6 +60,23 @@ class XlaActivityLoggingListener final : public XlaActivityListener { return Status::OK(); } + Status Listen(const XlaOptimizationRemark& optimization_remark) override { + if (!IsEnabled()) { + VLOG(3) << "Logging XlaJitCompilationActivity disabled"; + return Status::OK(); + } + + if (Logger* logger = Logger::GetSingletonAsync()) { + VLOG(2) << "Logging XlaJitCompilationActivity"; + VLOG(3) << optimization_remark.DebugString(); + logger->LogProto(optimization_remark); + } else { + VLOG(2) << "Not logging: logger not ready yet."; + } + + return Status::OK(); + } + private: bool IsEnabled() { static bool result = ComputeIsEnabled(); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 035a50e1852..1e440031570 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -17,8 +17,10 @@ limitations under the License. #include +#include "absl/base/call_once.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -27,6 +29,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/common_runtime/metrics.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -224,6 +227,20 @@ Status XlaCompilationCache::CompileSingleOp( out_compilation_result, out_executable); } +namespace { +// Print something that users can search for to definitively ascertain that XLA +// was used for their TF model. +// +// Prints only once to avoid spamming LOG(INFO). +void LogOnceXlaCompiledFirstCluster() { + static absl::once_flag log_once; + absl::call_once(log_once, [] { + LOG(INFO) << "Compiled cluster using XLA! 
This line is logged at most " + "once for the lifetime of the process."; + }); +} +} // namespace + Status XlaCompilationCache::CompileImpl( const XlaCompiler::Options& options, const NameAttrList& function, absl::Span args, @@ -301,6 +318,9 @@ Status XlaCompilationCache::CompileImpl( } if (is_megamorphic) { + BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION, + function.name()) + .IgnoreError(); VLOG(3) << "Not compiling cluster " << function.name() << " because it is megamorphic."; return false; @@ -346,11 +366,13 @@ Status XlaCompilationCache::CompileImpl( const uint64 compile_end_us = env->NowMicros(); const uint64 compile_time_us = compile_end_us - compile_start_us; + metrics::UpdateXlaCompilationTime(compile_time_us); { mutex_lock lock(cluster_compile_stats_mu_); auto it = cluster_compile_stats_.find(function.name()); it->second.compile_count++; it->second.cumulative_compile_time_us += compile_time_us; + LogOnceXlaCompiledFirstCluster(); VLOG(1) << "compiled " << function.name() << " " << it->second.compile_count << " times, compile time: " << compile_time_us diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 24d29f4c808..3dc8379ebaa 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -83,9 +83,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, executable->Run(launch_context.arguments(), run_options); TF_RETURN_IF_ERROR(run_result.status()); + const xla::HloInputOutputAliasConfig& input_output_alias = + executable->executable()->module().input_output_alias_config(); TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( ctx, result, run_result.ConsumeValueOrDie(), - /*missing_ctx_input_prefix=*/0)); + /*missing_ctx_input_prefix=*/0, input_output_alias)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index fbfda449ebd..85c09a027d3 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -98,10 +98,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { - {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, - DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, - DT_BFLOAT16}}; +constexpr std::array kAllXlaCpuTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 1d8b4beb8bd..be2038a7a8a 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -203,6 +203,8 @@ XlaDevice::XlaDevice(const SessionOptions& session_options, device_ordinal_(options.device_ordinal), jit_device_name_(options.compilation_device_name), platform_(options.platform), + intra_op_parallelism_threads_( + session_options.config.intra_op_parallelism_threads()), use_multiple_streams_(options.use_multiple_streams), shape_representation_fn_(options.shape_representation_fn), allowed_devices_(options.allowed_devices) { @@ -233,10 +235,13 @@ xla::LocalClient* XlaDevice::client() const { // 
don't want to do it until we get a chance to hook the platform up // to a simulator. + xla::LocalClientOptions options; + options.set_platform(platform_) + .set_allowed_devices(allowed_devices_) + .set_intra_op_parallelism_threads(intra_op_parallelism_threads_); // TODO(b/78468222): This can fail, at least when the backend is GPU and // there is no GPU on the host. - return xla::ClientLibrary::GetOrCreateLocalClient(platform_, allowed_devices_) - .ValueOrDie(); + return xla::ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie(); } Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) { diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 51910c6fabc..877580e73f9 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -202,6 +202,8 @@ class XlaDevice : public LocalDevice { const DeviceType jit_device_name_; // The platform for this device. se::Platform* const platform_; // Not owned. + // Intra-op threads to spawn (from SessionOptions). + const int intra_op_parallelism_threads_; // Memory allocator associated with this device. Allocator* xla_allocator_ GUARDED_BY(mu_) = nullptr; // Not owned. diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index ea784e72137..5e4c6340f42 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -90,8 +90,9 @@ XlaDeviceContext::XlaDeviceContext( CHECK(host_to_device_stream_ != nullptr); CHECK(stream_ != nullptr); if (!shape_representation_fn_) { - shape_representation_fn_ = [](const TensorShape& shape, - DataType dtype) -> xla::StatusOr { + shape_representation_fn_ = + [](const TensorShape& shape, DataType dtype, + bool use_fast_memory) -> xla::StatusOr { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); return xla_shape; @@ -130,9 +131,10 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, CHECK(xla_tensor); Status status = [&]() -> Status { - TF_ASSIGN_OR_RETURN(xla::Shape shape, - shape_representation_fn_(device_tensor->shape(), - device_tensor->dtype())); + TF_ASSIGN_OR_RETURN( + xla::Shape shape, + shape_representation_fn_(device_tensor->shape(), device_tensor->dtype(), + /*use_fast_memory=*/false)); // The device tensor should always be fresh. 
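// [Editor's note] Illustrative only, not part of the patch: the shape
// representation callback now takes a third `use_fast_memory` argument, which
// the default above simply ignores. Assuming the XlaCompiler::
// ShapeRepresentationFn alias, an existing two-argument lambda would be
// updated roughly like this:
XlaCompiler::ShapeRepresentationFn fn =
    [](const TensorShape& shape, DataType dtype,
       bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
  return xla_shape;
};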
TF_RET_CHECK(!xla_tensor->has_shaped_buffer()); diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 2c8203b1c5d..99e95314f64 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -212,11 +212,11 @@ class XlaAssignVariableOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ .Device(DEVICE) \ .HostMemory("output") \ - .TypeConstraint("T"), \ + .TypeConstraint("T"), \ ArgOp); \ REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kRetOp) \ .Device(DEVICE) \ - .TypeConstraint("T") \ + .TypeConstraint("T") \ .HostMemory("input"), \ RetvalOp); diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 8934b52d686..cead23d816e 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -147,10 +147,10 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { - {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, - DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, - DT_BFLOAT16}}; +constexpr std::array kAllXlaGpuTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes); diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index c138fd1ff39..e3706a09278 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -14,243 +14,20 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/xla_kernel_creator.h" -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "tensorflow/compiler/jit/compilability_check_util.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/kernels/xla_ops.h" -#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" -#include "tensorflow/compiler/tf2xla/const_analysis.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/jit/xla_kernel_creator_util.h" #include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { -namespace { - -// Utility which searches for values in a sorted list by scanning over it once. -// No matter how many times ScanForValue is called, the list is scanned at most -// once. However, if a call to ScanForValue skips over a value, that value is -// not revisited in future calls to ScanForValue, so callers must take -// care to order their calls. -// -// Useful for merging multiple sorted lists in O(n) time. -class SinglePassSearch { - public: - // Creates a SinglePassSearch object that can be used to search in `values`. - // Does not take ownership of `values`. `values` must outlive this. - // `values` must be sorted. 
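The SinglePassSearch helper documented above (removed here and re-added in xla_kernel_creator_util.cc later in this diff) answers repeated ascending membership queries against a sorted list while scanning it at most once. A standalone sketch of the same idea, with hypothetical argument indices, showing how two such searches mark constant and resource arguments as host memory:

#include <cstddef>
#include <iostream>
#include <vector>

// Standalone copy of the SinglePassSearch idea from the hunk above: the sorted
// list is scanned at most once across all ScanForValue calls, so callers must
// query values in ascending order.
class SinglePassSearch {
 public:
  explicit SinglePassSearch(const std::vector<int>* values)
      : current_index_(0), values_(values) {}

  bool ScanForValue(int value) {
    while (current_index_ < values_->size() &&
           (*values_)[current_index_] <= value) {
      if ((*values_)[current_index_] == value) {
        current_index_++;
        return true;
      }
      current_index_++;
    }
    return false;
  }

 private:
  size_t current_index_;
  const std::vector<int>* values_;
};

int main() {
  // Hypothetical constant/resource argument indices (sorted), standing in for
  // the vectors produced by GetBodyAndConstantsAndResources.
  const std::vector<int> constant_arg_indices = {0, 3};
  const std::vector<int> resource_arg_indices = {2, 5};
  SinglePassSearch constants_search(&constant_arg_indices);
  SinglePassSearch resources_search(&resource_arg_indices);

  // Mark which of six arguments would live in host memory, the same way the
  // kernel creator fills input_memory_types.
  for (int i = 0; i < 6; ++i) {
    bool host = constants_search.ScanForValue(i) ||
                resources_search.ScanForValue(i);
    std::cout << "arg " << i << ": "
              << (host ? "HOST_MEMORY" : "DEVICE_MEMORY") << "\n";
  }
  return 0;
}

Because the query indices only ever increase, each sorted list is traversed once no matter how many arguments are checked, which keeps the marking loop linear even for functions with hundreds of captured arguments.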
- explicit SinglePassSearch(const std::vector* values) - : current_index_(0), values_(values) {} - - // Scans forward in the vector looking for "value", updating the internal - // position in to the vector. - // Returns true iff the vector contains the given value at or after current - // position. - // Not thread-safe. - bool ScanForValue(int value) { - while (current_index_ < values_->size() && - (*values_)[current_index_] <= value) { - if ((*values_)[current_index_] == value) { - current_index_++; - return true; - } - current_index_++; - } - return false; - } - - private: - int current_index_; - const std::vector* values_; -}; -} // namespace bool XlaKernelCreator::CanCreateKernel(const FunctionLibraryRuntime& flr, const NodeDef& node_def) const { - const FunctionDef* function_def = - flr.GetFunctionLibraryDefinition()->Find(node_def.name()); - if (function_def == nullptr) { - // The node def is not calling a function. Individual ops can be - // run directly using on-demand mode, no need to create XlaLaunch - // kernel for them. - return false; - } - - // If kXlaCompileAttr is set on the node_def, use its value. - const auto& it = node_def.attr().find(kXlaCompileAttr); - if (it != node_def.attr().end()) { - return it->second.b(); - } - - // kXlaCompileAttr is not set on node_def, check if it is set on - // FunctionDef. - bool xla_compile = false; - Status status = flr.GetFunctionLibraryDefinition()->GetAttr( - node_def, kXlaCompileAttr, &xla_compile); - if (!status.ok() || !xla_compile) { - if (VLOG_IS_ON(3)) { - if (!status.ok()) { - VLOG(3) << "No " << kXlaCompileAttr << " attr defined for " - << node_def.op() << ". status=" << status.ToString(); - } else { - VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; - } - } - return false; - } - return true; -} - -// Given a FunctionLibraryRuntime and a NodeDef calling a function in the -// runtime, returns this function's body in `fbody` as well as the indices -// of its constant and resource arguments. -// `fbody` is owned by `flr`. -// `constant_arg_indices` and `resource_arg_indices` should be empty vector. -// They are sorted in ascending order on this function's return. -Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, - const NodeDef& node_def, - const FunctionBody** fbody, - std::vector* constant_arg_indices, - std::vector* resource_arg_indices) { - FunctionLibraryRuntime::Handle handle; - // If node_def is not instantiable, e.g., the function does not exist, - // simply bail out. - TF_RETURN_IF_ERROR( - flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle)); - *fbody = flr->GetFunctionBody(handle); - CHECK(*fbody); // Can't be nullptr since we just instantiated it. - const DataTypeVector& arg_types = (*fbody)->arg_types; - std::vector const_args(arg_types.size()); - // If we can't analyze the const args. Bail out. - TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(*((*fbody)->graph), &const_args, - /*compile_time_const_nodes=*/nullptr, flr)); - - for (int i = 0; i < const_args.size(); ++i) { - if (const_args[i]) { - constant_arg_indices->push_back(i); - } - } - - // There can be hundreds of resource variables. Reserve the space for them. - // We don't reserve for constants above as they are usually few. 
- resource_arg_indices->reserve(arg_types.size()); - for (int i = 0; i < arg_types.size(); ++i) { - if (arg_types[i] == DT_RESOURCE) { - resource_arg_indices->push_back(i); - } - } - - return Status::OK(); + return CanCreateXlaKernel(flr, node_def); } Status XlaKernelCreator::CreateKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, std::unique_ptr* kernel) const { - if (!CanCreateKernel(*flr, node_def)) { - return errors::Internal("Invalid node: ", node_def.ShortDebugString()); - } - - VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString(); - - // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterCompilationKernels(); - std::vector - uncompilable_node_info; - if (!IsCompilable(flr, node_def, &uncompilable_node_info)) { - string message = absl::StrCat( - "Function invoked by the following node is not compilable: ", - node_def.ShortDebugString(), ".\n"); - absl::StrAppend(&message, "Uncompilable nodes:\n"); - for (const auto& node_info : uncompilable_node_info) { - string node_message = - absl::StrCat("\t", node_info.name, ": ", - node_info.uncompilable_reason, "\n", "\tStacktrace:\n"); - for (const auto& stack_frame : node_info.stack_trace) { - absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n", - stack_frame.name, stack_frame.function_name); - } - absl::StrAppend(&message, node_message); - } - VLOG(1) << message; - // node_def is calling a function that XLA can't compile. - return errors::InvalidArgument(message); - } - - // Get function body, constant args, and resource args. - const FunctionBody* fbody = nullptr; - std::vector constant_arg_indices; - std::vector resource_arg_indices; - TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( - flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); - - // Set input and output memory types. - MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); - // These indices are used only for optimization purposes. They allow us - // to loop over constant_arg_indices and resource_arg_indices only once - // while iterating over all the function arguments checking if it is a - // resource or a constant. - // The reason we optimized this code is because functions can have a lot of - // captured arguments. For example, the backward pass of ResNet50 takes in all - // 214 variables and a similar number of activations. - SinglePassSearch constants_search(&constant_arg_indices); - SinglePassSearch resources_search(&resource_arg_indices); - for (int i = 0; i < fbody->arg_types.size(); ++i) { - if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { - // Compile-time constants and resource handles are expected to be in - // host memory. - input_memory_types[i] = HOST_MEMORY; - } - } - // One might wonder, about the case where a compile-time constant argument - // (which must be in host memory) is also used as an input into an op, - // e.g. Add, that expects its inputs in device memory. Here is how it - // works now. - // First, what do we mean by "op expects an input in XYZ memory"? - // There are two types of "ops" here: the tf2xla kernel and the HLO - // computation it builds. The tf2xla kernel needs to retrieve the actual - // numeric value of the compile-time constant tensors, so it really expects - // them to be on in host memory. However, for other inputs, it refers to them - // using xla::ComputationDataHandle, which is just a symbolic handle that - // xla::ComputationBuilder assigns. 
How does this handle gets assigned for - // constant arguments? Even constant arguments get an _Arg node in the graph - // instatiated for Function compilation. The tf2xla kernel for constant _Arg - // nodes takes the constant value, converts it to XlaLiteral, and feeds it - // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This - // constant XlaLiteral is included in the HLO graph, and subsequently, in - // the actual executable, which is copied to the device before being - // executed. Thus, when this executable runs, the constant is available in - // device memory. - - // XlaLaunch kernel keeps all outputs (including constants, which it copies), - // in device memory except for resources. - MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); - for (int i = 0; i < fbody->ret_types.size(); ++i) { - if (fbody->ret_types[i] == DT_RESOURCE) { - output_memory_types[i] = HOST_MEMORY; - } - } - - // Create the kernel. - NameAttrList function; - function.set_name(node_def.op()); - *(function.mutable_attr()) = node_def.attr(); - - Device* dev = flr->device(); - Status s; - OpKernelConstruction construction( - DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), &node_def, - &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, - fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - - *kernel = absl::make_unique( - &construction, constant_arg_indices, resource_arg_indices, function); - return s; + return CreateXlaKernel(flr, node_def, kernel); } namespace { diff --git a/tensorflow/compiler/jit/xla_kernel_creator.h b/tensorflow/compiler/jit/xla_kernel_creator.h index 739cf02d877..8815ee49ce5 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.h +++ b/tensorflow/compiler/jit/xla_kernel_creator.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ -#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ +#ifndef TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_ #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -39,4 +39,4 @@ class XlaKernelCreator : public CustomKernelCreator { } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ +#endif // TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_H_ diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc new file mode 100644 index 00000000000..96bde65003f --- /dev/null +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.cc @@ -0,0 +1,259 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/jit/xla_kernel_creator_util.h" + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "tensorflow/compiler/jit/compilability_check_util.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { + +// Utility which searches for values in a sorted list by scanning over it once. +// No matter how many times ScanForValue is called, the list is scanned at most +// once. However, if a call to ScanForValue skips over a value, that value is +// not revisited in future calls to ScanForValue, so callers must take +// care to order their calls. +// +// Useful for merging multiple sorted lists in O(n) time. +class SinglePassSearch { + public: + // Creates a SinglePassSearch object that can be used to search in `values`. + // Does not take ownership of `values`. `values` must outlive this. + // `values` must be sorted. + explicit SinglePassSearch(const std::vector* values) + : current_index_(0), values_(values) {} + + // Scans forward in the vector looking for "value", updating the internal + // position in to the vector. + // Returns true iff the vector contains the given value at or after current + // position. + // Not thread-safe. + bool ScanForValue(int value) { + while (current_index_ < values_->size() && + (*values_)[current_index_] <= value) { + if ((*values_)[current_index_] == value) { + current_index_++; + return true; + } + current_index_++; + } + return false; + } + + private: + int current_index_; + const std::vector* values_; +}; +} // namespace + +bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr, + const NodeDef& node_def) { + const FunctionDef* function_def = + flr.GetFunctionLibraryDefinition()->Find(node_def.name()); + if (function_def == nullptr) { + // The node def is not calling a function. Individual ops can be + // run directly using on-demand mode, no need to create XlaLaunch + // kernel for them. + return false; + } + + // If kXlaCompileAttr is set on the node_def, use its value. + const auto& it = node_def.attr().find(kXlaCompileAttr); + if (it != node_def.attr().end()) { + return it->second.b(); + } + + // kXlaCompileAttr is not set on node_def, check if it is set on + // FunctionDef. + bool xla_compile = false; + Status status = flr.GetFunctionLibraryDefinition()->GetAttr( + node_def, kXlaCompileAttr, &xla_compile); + if (!status.ok() || !xla_compile) { + if (VLOG_IS_ON(3)) { + if (!status.ok()) { + VLOG(3) << "No " << kXlaCompileAttr << " attr defined for " + << node_def.op() << ". status=" << status.ToString(); + } else { + VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; + } + } + return false; + } + return true; +} + +// Given a FunctionLibraryRuntime and a NodeDef calling a function in the +// runtime, returns this function's body in `fbody` as well as the indices +// of its constant and resource arguments. +// `fbody` is owned by `flr`. +// `constant_arg_indices` and `resource_arg_indices` should be empty vector. 
+// They are sorted in ascending order on this function's return. +Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, + const NodeDef& node_def, + const FunctionBody** fbody, + std::vector* constant_arg_indices, + std::vector* resource_arg_indices) { + FunctionLibraryRuntime::Handle handle; + // If node_def is not instantiable, e.g., the function does not exist, + // simply bail out. + TF_RETURN_IF_ERROR( + flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle)); + *fbody = flr->GetFunctionBody(handle); + CHECK(*fbody); // Can't be nullptr since we just instantiated it. + const DataTypeVector& arg_types = (*fbody)->arg_types; + std::vector const_args(arg_types.size()); + // If we can't analyze the const args. Bail out. + TF_RETURN_IF_ERROR( + BackwardsConstAnalysis(*((*fbody)->graph), &const_args, + /*compile_time_const_nodes=*/nullptr, flr)); + + for (int i = 0; i < const_args.size(); ++i) { + if (const_args[i]) { + constant_arg_indices->push_back(i); + } + } + + // There can be hundreds of resource variables. Reserve the space for them. + // We don't reserve for constants above as they are usually few. + resource_arg_indices->reserve(arg_types.size()); + for (int i = 0; i < arg_types.size(); ++i) { + if (arg_types[i] == DT_RESOURCE) { + resource_arg_indices->push_back(i); + } + } + + return Status::OK(); +} + +Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr* kernel) { + if (!CanCreateXlaKernel(*flr, node_def)) { + return errors::Internal("Invalid node: ", node_def.ShortDebugString()); + } + + VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString(); + + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterCompilationKernels(); + RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; + if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { + std::vector + uncompilable_node_info; + for (const auto& it : uncompilable_nodes_map) { + for (const auto& info : it.second.second) { + uncompilable_node_info.emplace_back(info); + } + } + string message = absl::StrCat( + "Function invoked by the following node is not compilable: ", + node_def.ShortDebugString(), ".\n"); + absl::StrAppend(&message, "Uncompilable nodes:\n"); + for (const auto& node_info : uncompilable_node_info) { + string node_message = + absl::StrCat("\t", node_info.name, ": ", + node_info.uncompilable_reason, "\n", "\tStacktrace:\n"); + for (const auto& stack_frame : node_info.stack_trace) { + absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n", + stack_frame.name, stack_frame.function_name); + } + absl::StrAppend(&message, node_message); + } + VLOG(1) << message; + // node_def is calling a function that XLA can't compile. + return errors::InvalidArgument(message); + } + + // Get function body, constant args, and resource args. + const FunctionBody* fbody = nullptr; + std::vector constant_arg_indices; + std::vector resource_arg_indices; + TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( + flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); + + // Set input and output memory types. + MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); + // These indices are used only for optimization purposes. They allow us + // to loop over constant_arg_indices and resource_arg_indices only once + // while iterating over all the function arguments checking if it is a + // resource or a constant. 
+ // The reason we optimized this code is because functions can have a lot of + // captured arguments. For example, the backward pass of ResNet50 takes in all + // 214 variables and a similar number of activations. + SinglePassSearch constants_search(&constant_arg_indices); + SinglePassSearch resources_search(&resource_arg_indices); + for (int i = 0; i < fbody->arg_types.size(); ++i) { + if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { + // Compile-time constants and resource handles are expected to be in + // host memory. + input_memory_types[i] = HOST_MEMORY; + } + } + // One might wonder, about the case where a compile-time constant argument + // (which must be in host memory) is also used as an input into an op, + // e.g. Add, that expects its inputs in device memory. Here is how it + // works now. + // First, what do we mean by "op expects an input in XYZ memory"? + // There are two types of "ops" here: the tf2xla kernel and the HLO + // computation it builds. The tf2xla kernel needs to retrieve the actual + // numeric value of the compile-time constant tensors, so it really expects + // them to be on in host memory. However, for other inputs, it refers to them + // using xla::ComputationDataHandle, which is just a symbolic handle that + // xla::ComputationBuilder assigns. How does this handle gets assigned for + // constant arguments? Even constant arguments get an _Arg node in the graph + // instatiated for Function compilation. The tf2xla kernel for constant _Arg + // nodes takes the constant value, converts it to XlaLiteral, and feeds it + // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This + // constant XlaLiteral is included in the HLO graph, and subsequently, in + // the actual executable, which is copied to the device before being + // executed. Thus, when this executable runs, the constant is available in + // device memory. + + // XlaLaunch kernel keeps all outputs (including constants, which it copies), + // in device memory except for resources. + MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + for (int i = 0; i < fbody->ret_types.size(); ++i) { + if (fbody->ret_types[i] == DT_RESOURCE) { + output_memory_types[i] = HOST_MEMORY; + } + } + + // Create the kernel. + NameAttrList function; + function.set_name(node_def.op()); + *(function.mutable_attr()) = node_def.attr(); + + Device* dev = flr->device(); + Status s; + OpKernelConstruction construction( + DeviceType(dev->device_type()), dev, + dev->GetAllocator(AllocatorAttributes()), &node_def, + &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, + fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); + + *kernel = absl::make_unique( + &construction, constant_arg_indices, resource_arg_indices, function); + return s; +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.h b/tensorflow/compiler/jit/xla_kernel_creator_util.h new file mode 100644 index 00000000000..71398c334fc --- /dev/null +++ b/tensorflow/compiler/jit/xla_kernel_creator_util.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class FunctionLibraryRuntime; +class OpKernel; + + // Given a NodeDef 'node_def' and the function library runtime 'flr', returns + // true if 'node_def' is a call to a compilable function defined in 'flr', + // with the kXlaCompileAttr set. +bool CanCreateXlaKernel(const FunctionLibraryRuntime& flr, + const NodeDef& node_def); + +// Given a supported NodeDef, returns a XlaLaunchOp that computes the node. +Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr* kernel); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index e9c4eb6e8ee..176c39aeb4c 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -42,6 +42,13 @@ namespace tensorflow { namespace { using xla::ScopedShapedBuffer; using xla::ShapedBuffer; + +const char kPossibleNonVariableResourceHintMessage[] = + "If the error is similar to `Trying to access resource using the wrong " + "type`, this is likely because XLA only accepts Resource Variables as " + "inputs by snapshotting their values. Other TensorFlow resource types like " + "TensorList/TensorArray/Stack are not supported. 
Try removing non-variable " + "resource inputs to XLA."; } // anonymous namespace VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {} @@ -88,7 +95,12 @@ static Status GetVariableInfosFromCtxInputs( [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); }); std::vector> variables; - TF_RETURN_IF_ERROR(LookupResources(ctx, resource_handles, &variables)); + + Status s = LookupResources(ctx, resource_handles, &variables); + if (!s.ok()) { + errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage); + return s; + } result->clear(); result->reserve(variable_indices.size()); @@ -235,9 +247,32 @@ void XlaComputationLaunchContext::PopulateInputs( } } +namespace { + +bool MustAliasOutput(const xla::HloInputOutputAliasConfig& input_output_alias, + int output_num) { + xla::ShapeIndex output_index; + if (input_output_alias.shape().IsTuple()) { + output_index = {output_num}; + } else { + DCHECK_EQ(output_num, 0) + << "output_num must be 0 for non-tuple shapes but is " << output_num; + output_index = {}; + } + if (input_output_alias.shape().tuple_shapes_size() == 0) { + return false; + } + return input_output_alias.OutputHasAlias(output_index) && + input_output_alias.GetAliasedParameter(output_index).value().kind == + xla::HloInputOutputAliasConfig::kUserAlias; +} + +} // namespace + Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - ScopedShapedBuffer output, int missing_ctx_input_prefix) { + ScopedShapedBuffer output, int missing_ctx_input_prefix, + const xla::HloInputOutputAliasConfig& input_output_alias) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; @@ -331,8 +366,16 @@ Status XlaComputationLaunchContext::PopulateOutputs( << "Invalid input for outputs " << i << ": " << input_index; ctx->set_output(i, ctx->input(input_index)); } else { + if (MustAliasOutput(input_output_alias, output_num)) { + DCHECK(output.buffer({output_num}).is_null()) + << "Expected output buffer to be aliased, but it is not nil."; + } se::DeviceMemoryBase buffer = output.buffer({output_num}); if (allocate_xla_tensors_) { + if (MustAliasOutput(input_output_alias, output_num)) { + return errors::Unimplemented( + "Aliasing is not yet supported for allocate_xla_tensors_."); + } Tensor* output_tensor; TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); @@ -347,8 +390,18 @@ Status XlaComputationLaunchContext::PopulateOutputs( CHECK_EQ(output_tensor->TotalBytes(), 0); } } else { + bool is_aliased = false; + if (MustAliasOutput(input_output_alias, output_num)) { + int xla_param = input_output_alias.GetAliasedParameter({output_num}) + .value() + .parameter_number; + DCHECK(arg_ptrs_[xla_param] != nullptr); + buffer = arg_ptrs_[xla_param]->buffer({}); + is_aliased = true; + } Tensor output_tensor = XlaTensorBuffer::MakeTensor( - ctx->expected_output_dtype(i), shape, buffer, allocator); + ctx->expected_output_dtype(i), shape, + /*unref_buffer=*/!is_aliased, buffer, allocator); output.set_buffer(se::OwningDeviceMemory(), {output_num}); ctx->set_output(i, output_tensor); } @@ -412,7 +465,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( se::DeviceMemoryBase buffer = output.buffer({output_num}); output.set_buffer(se::OwningDeviceMemory(), {output_num}); Tensor output_tensor = XlaTensorBuffer::MakeTensor( - write.type, write.shape, buffer, allocator); + write.type, write.shape, 
/*unref_buffer=*/true, buffer, allocator); *variable_infos[i].var()->tensor() = output_tensor; } ++output_num; diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 429ff0a065c..3df36e25daa 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -149,10 +149,10 @@ class XlaComputationLaunchContext { // // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are // missing and adjusts input indices accordingly. - Status PopulateOutputs(OpKernelContext* ctx, - const XlaCompiler::CompilationResult* kernel, - xla::ScopedShapedBuffer output, - int missing_ctx_input_prefix); + Status PopulateOutputs( + OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, + xla::ScopedShapedBuffer output, int missing_ctx_input_prefix, + const xla::HloInputOutputAliasConfig& input_output_alias); // Return the argument list. Only valid after PopulateInputs() has been // called. @@ -193,12 +193,15 @@ class XlaTensorBuffer : public TensorBuffer { } static Tensor MakeTensor(DataType dtype, const TensorShape& shape, - se::DeviceMemoryBase buffer, Allocator* allocator) { + bool unref_buffer, se::DeviceMemoryBase buffer, + Allocator* allocator) { size_t expected_size = shape.num_elements() * DataTypeSize(dtype); auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size, buffer.size(), allocator); Tensor t(dtype, shape, tensor_buffer); - tensor_buffer->Unref(); + if (unref_buffer) { + tensor_buffer->Unref(); + } return t; } diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 247bb83e7f7..1e556822f4b 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -19,10 +19,23 @@ filegroup( srcs = glob(["**/*.td"]), ) +cc_library( + name = "op_name_mapper", + srcs = ["op_name_mapper.cc"], + hdrs = ["op_name_mapper.h"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@llvm//:support", + "@local_config_mlir//:IR", + ], +) + cc_library( name = "tf_mlir_opt_main", srcs = ["tf_mlir_opt_main.cc"], + copts = ["-std=c++14"], deps = [ + ":init_mlir", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf", @@ -31,12 +44,14 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", - "//tensorflow/compiler/mlir/xla", + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/mlir/xla:lhlo", "//tensorflow/compiler/mlir/xla:xla_dialect_registration", "//tensorflow/compiler/mlir/xla:xla_legalize_control_flow", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_to_standard", "//tensorflow/core:lib", + "//tensorflow/core/platform:logging", "@llvm//:support", "@local_config_mlir//:AffineDialectRegistration", "@local_config_mlir//:MlirOptLib", @@ -49,16 +64,29 @@ cc_library( ], ) +cc_library( + name = "init_mlir", + srcs = ["init_mlir.cc"], + hdrs = ["init_mlir.h"], + deps = [ + "//tensorflow/core:lib", + "@llvm//:support", + ], +) + tf_cc_binary( name = "tf-opt", deps = [ ":tf_mlir_opt_main", + "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", ], ) tf_cc_binary( name = "tf-mlir-translate", + srcs = ["tf_mlir_translate_main.cc"], deps = [ + ":init_mlir", 
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", @@ -66,12 +94,14 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow:translate_registration", "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op", "//tensorflow/compiler/mlir/xla:xla_mlir_translate", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_proto_cc", "//tensorflow/stream_executor/lib", "@llvm//:support", "@local_config_mlir//:IR", + "@local_config_mlir//:Support", + "@local_config_mlir//:TranslateClParser", "@local_config_mlir//:Translation", - "@local_config_mlir//:tools/mlir-translate/mlir-translate", ], ) diff --git a/tensorflow/compiler/mlir/init_mlir.cc b/tensorflow/compiler/mlir/init_mlir.cc new file mode 100644 index 00000000000..54f8a57d8a6 --- /dev/null +++ b/tensorflow/compiler/mlir/init_mlir.cc @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/init_mlir.h" + +#include "tensorflow/core/platform/init_main.h" + +namespace tensorflow { + +InitMlir::InitMlir(int *argc, char ***argv) : init_llvm_(*argc, *argv) { + constexpr char kSeparator[] = "--"; + + // Find index of separator between two sets of flags. + int pass_remainder = 1; + bool split = false; + for (int i = 0; i < *argc; ++i) { + if (llvm::StringRef((*argv)[i]) == kSeparator) { + pass_remainder = i; + *argc -= (i + 1); + split = true; + break; + } + } + + tensorflow::port::InitMain((*argv)[0], &pass_remainder, argv); + if (split) { + *argc += pass_remainder; + (*argv)[1] = (*argv)[0]; + ++*argv; + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/init_mlir.h b/tensorflow/compiler/mlir/init_mlir.h new file mode 100644 index 00000000000..91020c1758b --- /dev/null +++ b/tensorflow/compiler/mlir/init_mlir.h @@ -0,0 +1,40 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_ +#define TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_ + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" + +namespace tensorflow { + +// Initializer to perform both InitLLVM and TF"s InitMain initialization. 
+// InitMain also performs flag parsing and '--' is used to separate flags passed +// to it: Flags before the first '--' are parsed by InitMain and argc and argv +// progressed to the flags post. If there is no separator, then no flags are +// parsed by InitMain and argc/argv left unadjusted. +// TODO(jpienaar): The way help flag is handled could be improved. +class InitMlir { + public: + InitMlir(int *argc, char ***argv); + + private: + llvm::InitLLVM init_llvm_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_INIT_MLIR_H_ diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 7846716e9dd..663740bf692 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -26,8 +26,8 @@ filegroup( name = "tensorflow_lite_ops_td_files", srcs = [ "ir/tfl_ops.td", + "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "@local_config_mlir//:OpBaseTdFiles", - "@local_config_mlir//:QuantizationOpsTdFiles", ], ) @@ -146,6 +146,7 @@ cc_library( hdrs = [ "utils/validators.h", ], + copts = ["-std=c++14"], deps = [ "@local_config_mlir//:Dialect", "@local_config_mlir//:IR", @@ -166,8 +167,9 @@ cc_library( "ir/tfl_traits.h", "transforms/passes.h", "utils/attribute_utils.h", - "utils/quantization_utils.h", + "//tensorflow/compiler/mlir/lite/quantization:quantization_traits.h", ], + copts = ["-std=c++14"], deps = [ ":tensorflow_lite_ops_inc_gen", ":validators", @@ -181,51 +183,36 @@ cc_library( "@local_config_mlir//:QuantOps", "@local_config_mlir//:StandardOps", "@local_config_mlir//:Support", - "@local_config_mlir//:TypeUtilities", ], alwayslink = 1, ) -cc_library( - name = "tensorflow_lite_quantization_utils", - srcs = [ - "utils/generated_op_quant_spec_getters.inc", - "utils/quantization_driver.cc", - "utils/quantization_utils.cc", - ], - hdrs = [ - "utils/quantization_utils.h", - ], - deps = [ - ":tensorflow_lite", - "//tensorflow/core:lib_proto_parsing", - "@com_google_absl//absl/memory", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:QuantOps", - "@local_config_mlir//:StandardOps", - "@local_config_mlir//:Support", - ], -) - cc_library( name = "tensorflow_lite_legalize_tf", srcs = [ + "transforms/extract_ophint.cc", "transforms/generated_legalize_tf.inc", "transforms/generated_lower_static_tensor_list.inc", "transforms/generated_prepare_tf.inc", + "transforms/legalize_ophint_func_op.cc", "transforms/legalize_tf.cc", "transforms/lower_static_tensor_list.cc", + "transforms/prepare_composite_functions_tf.cc", "transforms/prepare_tf.cc", + "transforms/trim_functions_tf.cc", ], hdrs = [ "transforms/passes.h", ], + copts = ["-std=c++14"], deps = [ + ":common", ":tensorflow_lite", - ":tensorflow_lite_quantization_utils", ":validators", + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:framework", + "//tensorflow/core:lib", "@com_google_absl//absl/memory", "@llvm//:support", "@local_config_mlir//:Analysis", @@ -234,7 +221,6 @@ cc_library( "@local_config_mlir//:QuantOps", "@local_config_mlir//:StandardOps", "@local_config_mlir//:Support", - "@local_config_mlir//:TypeUtilities", ], alwayslink = 1, ) @@ -248,13 +234,16 @@ cc_library( hdrs = [ "transforms/passes.h", ], + copts = ["-std=c++14"], deps = [ ":tensorflow_lite", ":validators", + "//tensorflow/compiler/mlir/tensorflow", "@llvm//:support", "@local_config_mlir//:Analysis", "@local_config_mlir//:IR", "@local_config_mlir//:Pass", + 
"@local_config_mlir//:StandardOps", "@local_config_mlir//:Support", ], alwayslink = 1, @@ -267,14 +256,16 @@ cc_library( "transforms/post_quantize.cc", "transforms/prepare_quantize.cc", "transforms/quantize.cc", + "utils/generated_op_quant_spec_getters.inc", ], hdrs = [ "transforms/passes.h", ], + copts = ["-std=c++14"], deps = [ ":tensorflow_lite", - ":tensorflow_lite_quantization_utils", ":validators", + "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "@com_google_absl//absl/memory", "@llvm//:support", "@local_config_mlir//:Analysis", @@ -287,32 +278,26 @@ cc_library( alwayslink = 1, ) -tf_native_cc_binary( - name = "op_quant_spec_getters_gen", +filegroup( + name = "generated_op_quant_spec_getters", srcs = [ - "tools/op_quant_spec_getters_gen.cc", - ], - deps = [ - "@llvm//:support", - "@llvm//:tablegen", - "@local_config_mlir//:TableGen", + "utils/generated_op_quant_spec_getters.inc", ], ) genrule( name = "op_quant_spec_getters_inc", srcs = [ - "@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td", - "@local_config_mlir//:include/mlir/IR/OpBase.td", - ":ir/tfl_ops.td", + "ir/tfl_ops.td", + "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", ], outs = [ "utils/generated_op_quant_spec_getters.inc", ], - cmd = ("$(location :op_quant_spec_getters_gen) " + + cmd = ("$(location //tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen) " + "-I external/local_config_mlir/include " + "$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"), - tools = [":op_quant_spec_getters_gen"], + tools = ["//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen"], ) # Library with tensorflow Lite dialect static initialization. @@ -321,6 +306,7 @@ cc_library( srcs = [ "ir/dialect_registration.cc", ], + copts = ["-std=c++14"], deps = [ ":tensorflow_lite", "@local_config_mlir//:IR", @@ -329,9 +315,9 @@ cc_library( ) tf_native_cc_binary( - name = "operator-writer-gen", + name = "operator-converter-gen", srcs = [ - "operator_writer_gen.cc", + "operator_converter_gen.cc", ], deps = [ "@llvm//:support", @@ -341,30 +327,30 @@ tf_native_cc_binary( ) genrule( - name = "operator_writer_inc", + name = "operator_converter_inc", srcs = [ - "@local_config_mlir//:include/mlir/Dialect/QuantOps/QuantPredicates.td", - "@local_config_mlir//:include/mlir/IR/OpBase.td", - ":ir/tfl_ops.td", + "ir/tfl_ops.td", + "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", ], outs = [ - "operator_writers.inc", + "operator_converters.inc", ], - cmd = ("$(location :operator-writer-gen) " + + cmd = ("$(location :operator-converter-gen) " + "-I external/local_config_mlir/include " + "$(location //tensorflow/compiler/mlir/lite:ir/tfl_ops.td) " + " -o $@"), - tools = [":operator-writer-gen"], + tools = [":operator-converter-gen"], ) cc_library( name = "flatbuffer_tflite_operator_lib", srcs = [ "flatbuffer_operator.cc", - "operator_writers.inc", + "operator_converters.inc", ], hdrs = [ "flatbuffer_operator.h", ], + copts = ["-std=c++14"], deps = [ ":tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", @@ -394,6 +380,7 @@ cc_library( hdrs = [ "emit_error_reporter.h", ], + copts = ["-std=c++14"], deps = [ "//tensorflow/lite/core/api", "@local_config_mlir//:IR", @@ -405,18 +392,23 @@ cc_library( srcs = [ "flatbuffer_import.cc", "flatbuffer_translate.cc", + "utils/convert_type.cc", ], hdrs = [ "flatbuffer_import.h", "flatbuffer_translate.h", + "utils/convert_type.h", ], + copts = ["-std=c++14"], deps = [ 
":flatbuffer_tflite_operator_lib", ":tensorflow_lite", ":tensorflow_lite_dialect_registration", + "//tensorflow/compiler/mlir:op_name_mapper", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", + "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:framework", @@ -426,6 +418,7 @@ cc_library( "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib", "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -459,12 +452,24 @@ cc_library( hdrs = [ "tf_tfl_translate_cl.h", ], + copts = ["-std=c++14"], deps = [ "@llvm//:support", ], alwayslink = 1, ) +cc_library( + name = "common", + hdrs = [ + "common/tfl_pass_config.h", + ], + copts = ["-std=c++14"], + deps = [ + "@llvm//:support", + ], +) + filegroup( name = "tf_tfl_translate_main", srcs = [ @@ -476,10 +481,13 @@ tf_cc_binary( name = "tf_tfl_translate", srcs = [":tf_tfl_translate_main"], deps = [ + ":common", ":flatbuffer_translate_lib", ":tensorflow_lite", + ":tf_tfl_passes", ":tf_tfl_translate_cl_options", ":tf_to_tfl_flatbuffer", + "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/core:lib", "//tensorflow/lite:framework", @@ -497,6 +505,7 @@ tf_cc_binary( deps = [ ":flatbuffer_translate_lib", "//tensorflow/core:lib", + "//tensorflow/core/platform:logging", "//tensorflow/core/platform/default/build_config:base", "//tensorflow/lite:framework", "//tensorflow/lite/delegates/flex:delegate", @@ -510,12 +519,42 @@ tf_cc_binary( ], ) +cc_library( + name = "tf_tfl_passes", + srcs = ["tf_tfl_passes.cc"], + hdrs = [ + "tf_tfl_passes.h", + ], + copts = ["-std=c++14"], + deps = [ + ":common", + ":tensorflow_lite_legalize_tf", + ":tensorflow_lite_optimize", + ":tensorflow_lite_quantize", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", + "//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "@llvm//:support", + "@local_config_mlir//:Analysis", + "@local_config_mlir//:IR", + "@local_config_mlir//:Parser", + "@local_config_mlir//:Pass", + "@local_config_mlir//:QuantOps", + "@local_config_mlir//:QuantOpsDialectRegistration", + "@local_config_mlir//:Support", + "@local_config_mlir//:Transforms", + ], +) + cc_library( name = "tf_to_tfl_flatbuffer", srcs = ["tf_to_tfl_flatbuffer.cc"], hdrs = [ "tf_to_tfl_flatbuffer.h", ], + copts = ["-std=c++14"], deps = [ ":flatbuffer_translate_lib", ":tensorflow_lite", diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h new file mode 100644 index 00000000000..3b3ba4dc686 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" + +namespace mlir { +namespace TFL { + +// A config that controls which passes get run as part TFLite converter. +struct PassConfig { + PassConfig() + : emit_builtin_tflite_ops(true), + run_quantize(false), + emit_quant_adaptor_ops(false), + lower_tensor_list_ops(false), + trim_functions_whitelist({}) {} + + // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be + // added, which produces TF Lite ops. + bool emit_builtin_tflite_ops; + // If run_quantize is true, quantization passes will be added. + bool run_quantize; + // If `emit_quant_adaptor_ops` is true, Quantize and + // Dequantize ops are added as part of running quantization passes. + bool emit_quant_adaptor_ops; + // If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic + // TF ops before legalization to TF Lite dialect. + bool lower_tensor_list_ops; + // The whitelist of functions that would be preserved after trimming. + llvm::ArrayRef trim_functions_whitelist; +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_COMMON_TFL_PASS_CONFIG_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 5256013bbce..74cecd6fbb6 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -15,13 +15,29 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include +#include #include +#include #include +#include +#include "absl/base/casts.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MemoryBuffer.h" +#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Diagnostics.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir @@ -31,78 +47,590 @@ limitations under the License. 
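Returning to the PassConfig struct added above in tfl_pass_config.h: callers are expected to start from the defaults and flip only the fields they need. A small sketch, assuming the new header is on the include path and that trim_functions_whitelist is an llvm::ArrayRef of std::string (the template argument is elided above); the function name and whitelist contents are hypothetical:

#include <string>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"

// Sketch of how a converter entry point might fill in the new PassConfig.
mlir::TFL::PassConfig MakeExamplePassConfig(
    const std::vector<std::string>& whitelist) {
  mlir::TFL::PassConfig pass_config;
  pass_config.emit_builtin_tflite_ops = true;  // produce TF Lite ops
  pass_config.lower_tensor_list_ops = true;    // lower tensorlist ops first
  pass_config.run_quantize = false;            // skip quantization passes
  // ArrayRef does not own its storage, so `whitelist` must outlive the
  // returned config.
  pass_config.trim_functions_whitelist = llvm::makeArrayRef(whitelist);
  return pass_config;
}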
#include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/OperationSupport.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Value.h" // TF:local_config_mlir #include "mlir/Support/FileUtilities.h" // TF:local_config_mlir #include "mlir/Translation.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" +using llvm::ArrayRef; using mlir::Builder; +using mlir::DenseElementsAttr; using mlir::FuncOp; using mlir::Location; using mlir::MLIRContext; using mlir::OpBuilder; +using mlir::Operation; +using mlir::OperationState; using mlir::OwningModuleRef; +using mlir::Value; +using mlir::quant::QuantizedType; using tflite::TensorT; using xla::StatusOr; namespace errors = tensorflow::errors; +namespace tfl = mlir::TFL; namespace { bool IsScalar(const TensorT& tensor) { - // TODO(krzysd): We can't distinguish scalars and unranked tensors + // TODO(b/138222071) We can't distinguish scalars and unranked tensors // Work out a way to handle this and stub out the code until then return tensor.shape.empty() && false; } -StatusOr GetTensorElementType(const TensorT& tensor, - Builder builder) { - switch (tensor.type) { - case tflite::TensorType_FLOAT32: - return builder.getF32Type(); - case tflite::TensorType_FLOAT16: - return builder.getF16Type(); - case tflite::TensorType_INT32: - return builder.getIntegerType(32); - case tflite::TensorType_UINT8: - return builder.getIntegerType(8); - case tflite::TensorType_INT64: - return builder.getIntegerType(64); - case tflite::TensorType_STRING: - return errors::InvalidArgument("String tensors are not supported"); - case tflite::TensorType_BOOL: - return builder.getI1Type(); - case tflite::TensorType_INT16: - return builder.getIntegerType(16); - case tflite::TensorType_COMPLEX64: - return mlir::ComplexType::get(builder.getF32Type()); - case tflite::TensorType_INT8: - return builder.getIntegerType(8); - } - return errors::OutOfRange("Unknown tensor type"); +bool IsQuantized(const TensorT& tensor) { + return (tensor.quantization != nullptr) && + !tensor.quantization->zero_point.empty(); } -StatusOr GetTensorType(const TensorT& tensor, Builder builder) { - TF_ASSIGN_OR_RETURN(auto elem_type, GetTensorElementType(tensor, builder)); - if (IsScalar(tensor)) { +// Create the MLIR NamedLoc location corresponding to a given tensor +Location TensorLoc(const TensorT& tensor, Builder builder, Location base) { + if (tensor.name.empty()) { + return base; + } + return mlir::NameLoc::get(builder.getIdentifier(tensor.name), base); +} + +// Returns the correct type for a quantized tensor +// We have a special case for constants since they have a higher minimum value. 
+StatusOr GetQuantizedType(const TensorT& tensor, Builder builder, + bool is_constant = false) { + tflite::QuantizationParametersT& quant_params = *tensor.quantization; + if (quant_params.details.AsCustomQuantization()) { + return errors::Unimplemented("Cannot handle experimental quantization"); + } + + bool is_signed = true; + mlir::IntegerType storage_type; + if (tensor.type == tflite::TensorType_UINT8) { + is_signed = false; + storage_type = builder.getIntegerType(8); + } else { + auto raw_elem_type = ConvertElementType(tensor.type, builder); + if (!raw_elem_type.isa()) { + return errors::InvalidArgument( + "Quantized tensors must be stored as integers"); + } + storage_type = raw_elem_type.cast(); + } + + // TFlite uses narrow-range [u]int8 for constant buffers of quantized weights. + // Since we don't know which ones are weights, we represent this optimization + // as a change in the storage bounds for the type for all constants of this + // type. + bool is_weight_buffer = is_constant && (storage_type.getWidth() == 8); + + int64_t storage_min = QuantizedType::getDefaultMininumForInteger( + is_signed, storage_type.getWidth()) + + is_weight_buffer; + int64_t storage_max = QuantizedType::getDefaultMaxinumForInteger( + is_signed, storage_type.getWidth()); + uint32_t flags = + is_signed ? mlir::quant::QuantizationFlags::FlagValue::Signed : 0; + + if (0 != quant_params.quantized_dimension) { + llvm::SmallVector scales(quant_params.scale.begin(), + quant_params.scale.end()); + return mlir::quant::UniformQuantizedPerAxisType::get( + flags, storage_type, builder.getF32Type(), scales, + quant_params.zero_point, quant_params.quantized_dimension, storage_min, + storage_max); + } + return mlir::quant::UniformQuantizedType::get( + flags, storage_type, builder.getF32Type(), quant_params.scale.at(0), + quant_params.zero_point.at(0), storage_min, storage_max); +} + +// TODO(b/138222071) Remove shapeless_are_scalars once we can reliably +// make that distinction and don't have to rely on context +// (input to main and constants must have static shape) +StatusOr GetTensorType(const TensorT& tensor, Builder builder, + bool shapeless_are_scalars = false, + bool is_constant = false) { + mlir::Type elem_type = ConvertElementType(tensor.type, builder); + // TODO(b/139554398) Store min/max (even for non-quantized tensors) somewhere + // if it's set + if (IsQuantized(tensor)) { + TF_ASSIGN_OR_RETURN(elem_type, + GetQuantizedType(tensor, builder, is_constant)); + } + + if (IsScalar(tensor) || (shapeless_are_scalars && tensor.shape.empty())) { return builder.getTensorType({}, elem_type); } if (!tensor.shape.empty()) { - llvm::SmallVector shape; - for (int32_t i : tensor.shape) { - shape.push_back(int64_t{i}); - } + llvm::SmallVector shape(tensor.shape.begin(), + tensor.shape.end()); return builder.getTensorType(shape, elem_type); } return builder.getTensorType(elem_type); } +StatusOr OpNameForOpCode(const tflite::OperatorCodeT opcode) { + // TODO(krzysd) Support custom ops + if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) { + return errors::Unimplemented("unsupported custom operation: ", + opcode.custom_code); + } + if (opcode.builtin_code == tflite::BuiltinOperator_IF) { + return std::string("tf.If"); + } + if (opcode.builtin_code == tflite::BuiltinOperator_WHILE) { + return std::string("tf.While"); + } + + const char* op_name = tflite::EnumNameBuiltinOperator(opcode.builtin_code); + std::string lowered_name = llvm::StringRef(op_name).lower(); + return llvm::Twine("tfl.", lowered_name).str(); +} + +// 
The buffers in TFLite flatbuffers have their contents stored as a vector of +// bytes that represent little-endian values. +// The read_size parameter is present to allow reading both float16 and float32s +// without a case split. +template +std::vector ReadAsLittleEndian(ArrayRef bytes) { + std::vector ret; + size_t read_size = sizeof(T); + int bytes_len = bytes.size(); + assert(bytes_len % read_size == 0); + + size_t elem_count = bytes_len / read_size; + ret.reserve(elem_count); + + const char* data_ptr = reinterpret_cast(bytes.data()); + for (int i = 0; i < elem_count; i++) { + ret.push_back( + llvm::support::endian::readNext(data_ptr)); + } + return ret; +} + +tensorflow::TensorProto ConvertTfliteConstTensor( + const tflite::TensorT& tensor, const std::vector& buffer) { + tensorflow::TensorProto ret; + ret.set_dtype(TflTypeToTfType(tensor.type)); + + tensorflow::TensorShapeProto* shape = ret.mutable_tensor_shape(); + shape->set_unknown_rank(false); + for (auto dim : tensor.shape) { + shape->add_dim()->set_size(int64_t{dim}); + } + std::string content; + content.assign(reinterpret_cast(buffer.data()), buffer.size()); + ret.set_tensor_content(content); + return ret; +} + +StatusOr ConvertFloatBuffer( + mlir::RankedTensorType shaped_type, mlir::FloatType elem_type, + const std::vector& buffer) { + size_t bytes_len = buffer.size(); + + // The bytes of floats are stored little-endian. + switch (elem_type.getWidth()) { + case 16: { + assert(bytes_len % 2 == 0); + size_t elem_count = bytes_len / 2; + std::vector values; + values.reserve(elem_count); + + const char* data = reinterpret_cast(buffer.data()); + auto& semantics = elem_type.getFloatSemantics(); + + for (int i = 0; i < elem_count; i++) { + uint16_t bit_repr = + llvm::support::endian::readNext(data); + llvm::APInt int_repr(16, bit_repr); + values.emplace_back(semantics, int_repr); + } + + return DenseElementsAttr::get(shaped_type, values); + } + case 32: { + assert(bytes_len % 4 == 0); + size_t elem_count = bytes_len / 4; + std::vector values; + values.reserve(elem_count); + + const char* data = reinterpret_cast(buffer.data()); + + for (int i = 0; i < elem_count; i++) { + uint32_t bit_repr = + llvm::support::endian::readNext(data); + values.push_back(absl::bit_cast(bit_repr)); + } + return DenseElementsAttr::get(shaped_type, ArrayRef(values)); + } + } + return errors::InvalidArgument("unsupported bit width", elem_type.getWidth()); +} + +StatusOr ConvertIntBuffer( + mlir::RankedTensorType shaped_type, mlir::Type elem_type, + const std::vector& buffer) { + unsigned bit_width; + mlir::RankedTensorType buffer_type; + if (auto itype = elem_type.dyn_cast()) { + bit_width = itype.getWidth(); + } else if (auto qtype = elem_type.dyn_cast()) { + bit_width = qtype.getStorageTypeIntegralWidth(); + shaped_type = mlir::RankedTensorType::get(shaped_type.getShape(), + qtype.getStorageType()); + } else { + return errors::InvalidArgument("unsupported integer constant type"); + } + + switch (bit_width) { + case 1: { + // vector doesn't convert to an ArrayRef + llvm::SmallVector values; + values.reserve(buffer.size()); + for (auto b : buffer) { + values.emplace_back(b != 0); + } + return DenseElementsAttr::get(shaped_type, ArrayRef(values)); + } + case 8: { + return DenseElementsAttr::get(shaped_type, ArrayRef(buffer)); + } + case 16: { + auto values = ReadAsLittleEndian(buffer); + return DenseElementsAttr::get(shaped_type, ArrayRef(values)); + } + case 32: { + auto values = ReadAsLittleEndian(buffer); + return DenseElementsAttr::get(shaped_type, 
ArrayRef(values)); + } + case 64: { + auto values = ReadAsLittleEndian(buffer); + return DenseElementsAttr::get(shaped_type, ArrayRef(values)); + } + default: + return errors::Unimplemented("Cannot handle bit width ", bit_width); + } +} + +StatusOr BuildConstOp(const tflite::TensorT& tensor, + const std::vector& buffer, + OpBuilder builder, Location loc) { + TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder, + /*shapeless_are_scalars=*/true, + /*is_constant=*/true)); + auto shaped_type = type.dyn_cast(); + if (!shaped_type) { + return errors::Internal("Constant doesn't have a shape"); + } + + auto elem_type = shaped_type.getElementType(); + + mlir::ElementsAttr value; + if (auto float_type = elem_type.dyn_cast()) { + TF_ASSIGN_OR_RETURN(value, + ConvertFloatBuffer(shaped_type, float_type, buffer)); + } else if (elem_type.isa() || + elem_type.isa()) { + TF_ASSIGN_OR_RETURN(value, + ConvertIntBuffer(shaped_type, elem_type, buffer)); + } else if (elem_type.isa()) { + auto& dialect = elem_type.getDialect(); + tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer); + std::string mangled = tensorflow::mangling_util::MangleTensor(repr); + + value = builder.getOpaqueElementsAttr(&dialect, shaped_type, mangled); + } else { + return errors::Unimplemented("Constant of unsupported type"); + } + + if (IsQuantized(tensor)) { + auto op = builder.create( + loc, builder.getTypeAttr(shaped_type), value); + return op.getOperation(); + } + auto op = builder.create(loc, value); + return op.getOperation(); +} + +llvm::SmallVector ConvertSubgraphIdxsToFunctionAttrs( + tflite::BuiltinOptionsUnion options, + const std::vector& func_names, Builder builder) { + if (auto* opts = options.AsIfOptions()) { + uint32_t then_idx = opts->then_subgraph_index; + auto then_attr = builder.getSymbolRefAttr(func_names.at(then_idx)); + uint32_t else_idx = opts->else_subgraph_index; + auto else_attr = builder.getSymbolRefAttr(func_names.at(else_idx)); + + return {builder.getNamedAttr("then_branch", then_attr), + builder.getNamedAttr("else_branch", else_attr), + // TODO(b/139667752): Analyze statelessness correctly + builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))}; + } + if (auto* opts = options.AsWhileOptions()) { + uint32_t cond_idx = opts->cond_subgraph_index; + auto cond_attr = builder.getSymbolRefAttr(func_names.at(cond_idx)); + uint32_t body_idx = opts->body_subgraph_index; + auto body_attr = builder.getSymbolRefAttr(func_names.at(body_idx)); + + return {builder.getNamedAttr("cond", cond_attr), + builder.getNamedAttr("body", body_attr), + // TODO(b/139667752): Analyze statelessness correctly + builder.getNamedAttr("is_stateless", builder.getBoolAttr(false))}; + } + return {}; +} + +// TODO(krzysd) Handle function calls +StatusOr ConvertOp( + const tflite::OperatorT& op, const std::vector vals_map, + Value* optional_arg_marker, const std::vector& op_names, + const std::vector& func_names, + const std::vector>& tensors, Location loc, + OpBuilder builder) { + llvm::SmallVector operands; + llvm::SmallVector outputTypes; + + if (op.outputs.empty()) { + auto err = errors::InvalidArgument("operator with no outputs"); + return emitError(loc, err.ToString()), err; + } + + const std::string& op_name = op_names.at(op.opcode_index); + OperationState op_state(loc, op_name); + + for (auto input_num : op.inputs) { + if (input_num == -1) { + assert(optional_arg_marker != nullptr); + op_state.addOperands({optional_arg_marker}); + } else { + op_state.addOperands({vals_map.at(input_num)}); + } + } + 
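For reference, the operand loop above relies on TFLite's convention that an input index of `-1` marks an omitted optional operand, which the importer materializes as a single shared `none`-typed marker value. A minimal standalone sketch of that index-to-value mapping follows; `MapOperandIndices` and `none_marker` are hypothetical names, and plain containers stand in for the converter's MLIR value plumbing.

```
// Sketch only: resolves flatbuffer input indices against already-created
// values, substituting a caller-provided marker for omitted (-1) inputs.
#include <cassert>
#include <cstdint>
#include <vector>

template <typename ValueT>
std::vector<ValueT> MapOperandIndices(const std::vector<int32_t>& input_indices,
                                      const std::vector<ValueT>& values,
                                      ValueT none_marker) {
  std::vector<ValueT> operands;
  operands.reserve(input_indices.size());
  for (int32_t idx : input_indices) {
    if (idx == -1) {
      // TFLite encodes "this optional input is absent" as index -1.
      operands.push_back(none_marker);
    } else {
      assert(idx >= 0 && static_cast<size_t>(idx) < values.size());
      operands.push_back(values[idx]);
    }
  }
  return operands;
}
```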
+ for (auto output_num : op.outputs) { + auto& tensor = *tensors.at(output_num); + auto type_or_err = GetTensorType(tensor, builder); + if (!type_or_err.ok()) { + return emitError(loc, type_or_err.status().ToString()), + type_or_err.status(); + } + auto type = type_or_err.ConsumeValueOrDie(); + + // Special case for reshape, which stores its return shape in an option + // that we need to extract from + // Note: UniqueOp is handled by the typing information on its output tensor + if (auto* opts = op.builtin_options.AsReshapeOptions()) { + llvm::SmallVector shape(opts->new_shape.begin(), + opts->new_shape.end()); + type = builder.getTensorType(ArrayRef(shape), + type.getElementType()); + } + + // Special case for quantize: return type must also be in qtype attribute + if (op_name == "tfl.quantize") { + op_state.addAttribute("qtype", builder.getTypeAttr(type)); + } + + op_state.addTypes({type}); + } + + llvm::SmallVector attrs; + mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs); + op_state.addAttributes(attrs); + + // Handle the conversion from subgraph index to functions for If and While + auto function_ref_attrs = ConvertSubgraphIdxsToFunctionAttrs( + op.builtin_options, func_names, builder); + op_state.addAttributes(function_ref_attrs); + + return builder.createOperation(op_state); +} + +// Build a FuncOp from a tflite SubGraph +// The op_names are a mapping from indexes into the TFLite operators array to +// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken +// from the deserialized flatbuffer as we do not have the type information to +// interpret them until this point. The base_loc parameter is the location of +// the flatbuffer as a whole (usually a file). The add_pseudo_input_ops flag +// controls whether we create the dummy ops for input that the TFLite dialect +// has in the main function (and only the main function). +StatusOr ConvertSubgraph( + const tflite::SubGraphT& subgraph, llvm::StringRef name, + const std::vector& op_names, + const std::vector& func_names, + const std::vector>& buffers, + Location base_loc, Builder builder, bool add_pseudo_input_ops = false) { + llvm::SmallVector ret_types; + llvm::SmallVector input_types; + + auto func_loc = mlir::NameLoc::get(builder.getIdentifier(name), base_loc); + + // Construct function type + for (auto input : subgraph.inputs) { + auto& tensor = *subgraph.tensors.at(input); + // TODO(b/138222071) Graph inputs must have static shape per the exporter, + // but we cannot differentiate scalars from unranked tensors. + // Here we reverse the default assumption that shape = [] means unranked. 
+ // when processing main() + auto type_or_err = + GetTensorType(tensor, builder, + /*shapeless_are_scalars=*/add_pseudo_input_ops, + /*is_constant=*/false); + if (!type_or_err.ok()) { + emitError(func_loc, "error reading argument types") + << type_or_err.status().ToString(); + return type_or_err.status(); + } + auto type = type_or_err.ConsumeValueOrDie(); + input_types.push_back(type); + } + + llvm::SmallVector is_op_output(subgraph.tensors.size(), false); + for (auto& op : subgraph.operators) { + for (auto output : op->outputs) { + is_op_output[output] = true; + } + } + + for (auto output : subgraph.outputs) { + bool is_constant = !is_op_output[output]; + auto type_or_err = GetTensorType(*subgraph.tensors.at(output), builder, + /*shapeless_are_scalars=*/is_constant, + /*is_constant=*/is_constant); + if (!type_or_err.ok()) { + emitError(func_loc, "error reading return types") + << type_or_err.status().ToString(); + return type_or_err.status(); + } + auto type = type_or_err.ConsumeValueOrDie(); + ret_types.push_back(type); + } + auto func_type = builder.getFunctionType(input_types, ret_types); + + // Construct function object + auto func = FuncOp::create(func_loc, name, func_type, /* attrs= */ {}); + func.addEntryBlock(); + auto& body = func.getBody(); + OpBuilder op_builder{body}; + + std::vector vals_map(subgraph.tensors.size(), nullptr); + Value* maybe_optional_arg_marker = nullptr; + + // Get or construct MLIR values for each input + for (int i = 0, e = subgraph.inputs.size(); i < e; i++) { + auto input_tensor = subgraph.inputs[i]; + const auto& tensor = *subgraph.tensors.at(input_tensor); + auto loc = TensorLoc(tensor, builder, base_loc); + if (nullptr != vals_map[input_tensor]) { + auto err = errors::FailedPrecondition("duplicate input arguments"); + return emitError(loc, err.ToString()), err; + } + if (add_pseudo_input_ops) { + auto* input = func.getArgument(i); + auto op = op_builder.create(loc, input); + vals_map[input_tensor] = op.output(); + } else { + vals_map[input_tensor] = func.getArgument(i); + } + } + + // Construct MLIR operators from TFLite operators + for (auto& op : subgraph.operators) { + for (auto input_num : op->inputs) { + // The operators in a graph are topologically sorted + // and so if no previous operation has produced a tensor + // it must be a constant. + if (input_num == -1) { + if (maybe_optional_arg_marker == nullptr) { + maybe_optional_arg_marker = + op_builder + .create(base_loc, builder.getNoneType(), + builder.getUnitAttr()) + .getResult(); + } + } else if (nullptr == vals_map.at(input_num)) { + auto& const_tensor = *subgraph.tensors[input_num]; + auto const_loc = TensorLoc(const_tensor, builder, base_loc); + auto op_or_err = + BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data, + op_builder, const_loc); + if (!op_or_err.ok()) { + return emitError(const_loc, op_or_err.status().ToString()), + op_or_err.status(); + } + vals_map[input_num] = op_or_err.ValueOrDie()->getResult(0); + } + } + + // The NameLoc corresponding to the name of the first output tensor + auto op_loc = + op->outputs.empty() + ? 
base_loc + : TensorLoc(*subgraph.tensors[op->outputs[0]], builder, base_loc); + // If there's an optional argument, maybe_optional_arg_marker has been set + // to a valid Value* + TF_ASSIGN_OR_RETURN( + auto* mlir_op, + ConvertOp(*op, vals_map, maybe_optional_arg_marker, op_names, + func_names, subgraph.tensors, op_loc, op_builder)); + for (auto pair : llvm::enumerate(mlir_op->getResults())) { + vals_map[op->outputs[pair.index()]] = pair.value(); + } + } + + // Construct return values + llvm::SmallVector return_operands; + for (auto index : subgraph.outputs) { + if (nullptr == vals_map.at(index)) { + auto& const_tensor = *subgraph.tensors[index]; + auto const_loc = TensorLoc(const_tensor, builder, base_loc); + auto op_or_err = + BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data, + op_builder, const_loc); + if (!op_or_err.ok()) { + return emitError(const_loc, op_or_err.status().ToString()), + op_or_err.status(); + } + vals_map[index] = op_or_err.ValueOrDie()->getResult(0); + } + return_operands.push_back(vals_map[index]); + } + + op_builder.create(base_loc, return_operands); + + return func; +} + +// TFLite subgraphs do not necessarily have names, though MLIR functions must +// have them, so we generate a name for subgraphs that are missing one here. +// Note: in TFLite, the first subgraph is the entry point, and in MLIR that +// represents TFLite, this entry point must be called "main" +// TODO(b/131175224,b/132239787) Support multiple entry points +std::string SubgraphName(unsigned index, const tflite::SubGraphT& subgraph) { + if (subgraph.name.empty()) { + if (index == 0) { + return "main"; + } else { + return llvm::formatv("fn_{0}", index).str(); + } + } else { + return subgraph.name; + } +} } // namespace OwningModuleRef tflite::FlatBufferToMlir(absl::string_view buffer, @@ -117,39 +645,51 @@ OwningModuleRef tflite::FlatBufferToMlir(absl::string_view buffer, std::unique_ptr model(model_ptr->GetModel()->UnPack()); auto builder = Builder(context); - auto module = mlir::ModuleOp::create(base_loc); - // TODO(krzysd): Actually account for the FlatBuffer schema version + std::vector operator_names; + operator_names.reserve(model->operator_codes.size()); + + for (auto& opcode : model->operator_codes) { + auto operator_name_or_error = OpNameForOpCode(*opcode); + if (!operator_name_or_error.ok()) { + return emitError(base_loc, operator_name_or_error.status().ToString()), + nullptr; + } + operator_names.push_back(operator_name_or_error.ConsumeValueOrDie()); + } + + std::vector func_names; + for (auto& subgraph : model->subgraphs) { + func_names.push_back(subgraph->name); + } + + auto module = mlir::ModuleOp::create(base_loc); + // We currently don't use this to make decisions, but we could + // use it in exports or if there are breaking changes module.setAttr("tfl.schema_version", builder.getI32IntegerAttr(model->version)); - - for (auto& subgraph : model->subgraphs) { - llvm::SmallVector ret_types; - llvm::SmallVector input_types; - - for (auto input : subgraph->inputs) { - auto type_or_err = GetTensorType(*subgraph->tensors[input], builder); - if (!type_or_err.ok()) { - return emitError(base_loc, type_or_err.status().ToString()), nullptr; - } - input_types.push_back(type_or_err.ConsumeValueOrDie()); - } - - auto func_type = builder.getFunctionType(input_types, ret_types); - auto func_loc = mlir::NameLoc::get(builder.getIdentifier(subgraph->name), - base_loc, context); - auto func = - FuncOp::create(func_loc, subgraph->name, func_type, /* attrs= */ {}); - func.addEntryBlock(); - - 
// TODO(krzysd): convert TFLite ops to MLIR ops - // Note: EnumNamesBuiltinOperator has the names of the builtin ops in - // uppercase. We will want them in lowercase with a tfl. prefix for MLIR - OpBuilder op_builder{func.getBody()}; - op_builder.create(base_loc); - module.push_back(func); + if (!model->description.empty()) { + module.setAttr("tfl.description", + builder.getStringAttr(model->description)); } + for (auto e : llvm::enumerate(model->subgraphs)) { + auto& subgraph = e.value(); + std::string name = SubgraphName(e.index(), *subgraph); + auto func_or_error = ConvertSubgraph( + *subgraph, name, operator_names, func_names, model->buffers, base_loc, + // Only the entry point needs pseudo_input_ops + // TODO(b/131175224,b/132239787) Support multiple entry points + builder, /* add_pseudo_input_ops = */ e.index() == 0); + if (!func_or_error.ok()) { + return emitError(base_loc, "could not translate function ") + << subgraph->name, + nullptr; + } + module.push_back(func_or_error.ConsumeValueOrDie()); + } + // TFLite subgraphs do not necessarily have names, + return OwningModuleRef(module); } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 6d85f6f19e2..a18e54ac5bb 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -17,7 +17,10 @@ limitations under the License. #include +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSwitch.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -98,6 +101,11 @@ static int ConvertI32AttrForOptionWriter( return i.getSExtValue(); } +static int ConvertPositiveI32AttrForOptionWriter( + llvm::APInt i, flatbuffers::FlatBufferBuilder* builder) { + return ConvertI32AttrForOptionWriter(i, builder); +} + static flatbuffers::Offset> ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray, flatbuffers::FlatBufferBuilder* builder) { @@ -144,5 +152,59 @@ static tflite::LSTMKernelType ConvertTFL_LSTMKernelTypeAttrForOptionWriter( .Case("BASIC", tflite::LSTMKernelType_BASIC); } +static mlir::Attribute BuildBoolAttr(bool value, mlir::Builder builder) { + return builder.getBoolAttr(value); +} + +static mlir::Attribute BuildF32Attr(float value, mlir::Builder builder) { + return builder.getF32FloatAttr(value); +} + +static mlir::Attribute BuildI32Attr(int32_t value, mlir::Builder builder) { + return builder.getI32IntegerAttr(value); +} + +static mlir::Attribute BuildI64ArrayAttr(std::vector value, + mlir::Builder builder) { + std::vector typecast(value.begin(), value.end()); + return builder.getI64ArrayAttr(typecast); +} + +static mlir::Attribute BuildPositiveI32Attr(int32_t value, + mlir::Builder builder) { + return builder.getI32IntegerAttr(value); +} + +static mlir::Attribute BuildTFL_AFAttr(tflite::ActivationFunctionType value, + mlir::Builder builder) { + const char* option_name = tflite::EnumNameActivationFunctionType(value); + return builder.getStringAttr(option_name); +} + +static mlir::Attribute BuildTFL_FullyConnectedOptionsWeightFormatAttr( + tflite::FullyConnectedOptionsWeightsFormat value, mlir::Builder builder) { + const char* option_name = + tflite::EnumNameFullyConnectedOptionsWeightsFormat(value); + return builder.getStringAttr(option_name); +} + +static 
mlir::Attribute BuildTFL_LSTMKernelTypeAttr(tflite::LSTMKernelType value, + mlir::Builder builder) { + const char* option_name = tflite::EnumNameLSTMKernelType(value); + return builder.getStringAttr(option_name); +} + +static mlir::Attribute BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value, + mlir::Builder builder) { + const char* option_name = tflite::EnumNameMirrorPadMode(value); + return builder.getStringAttr(option_name); +} + +static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value, + mlir::Builder builder) { + const char* option_name = tflite::EnumNamePadding(value); + return builder.getStringAttr(option_name); +} + // Pull in FlatBuffer writers for TFLite generated using TableGen -#include "tensorflow/compiler/mlir/lite/operator_writers.inc" +#include "tensorflow/compiler/mlir/lite/operator_converters.inc" diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h index e35780b11ec..35293c1b812 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h @@ -25,6 +25,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "flatbuffers/flatbuffers.h" // TF:flatbuffers #include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir #include "tensorflow/lite/schema/schema_generated.h" @@ -42,6 +45,14 @@ llvm::Optional> CreateFlatBufferOperator( const std::vector &operands, const std::vector &results, flatbuffers::FlatBufferBuilder *fbb); +// Populate the array of mlir::NamedAttributes corresponding to the given +// tflite::FlatbufferOptionsUnion. +// We use an out parameter per LLVM convention +void BuiltinOptionsToAttributes( + tflite::BuiltinOptionsUnion op_union, mlir::Builder builder, + // NOLINTNEXTLINE + llvm::SmallVectorImpl &attributes); + } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_OPERATOR_H_ diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index fca80f836aa..aa57ff7f751 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -40,6 +40,7 @@ limitations under the License. #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/Location.h" // TF:local_config_mlir @@ -48,14 +49,14 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/Support/FileUtilities.h" // TF:local_config_mlir #include "mlir/Translation.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/op_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" -#include "tensorflow/compiler/mlir/tensorflow/utils//convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -89,6 +90,8 @@ using mlir::TranslateFromMLIRRegistration; using mlir::Type; using mlir::UnknownLoc; using mlir::Value; +using tensorflow::OpLocNameMapper; +using tensorflow::OpNameMapper; using tensorflow::Status; using tflite::flex::IsWhitelistedFlexOp; using xla::StatusOr; @@ -105,9 +108,10 @@ using llvm::cl::opt; // These command line flags enable control of the translation implementation. bool emit_builtin_tflite_ops; -bool emit_select_tf_ops; bool emit_custom_ops; +bool emit_select_tf_ops; bool lower_tensor_list_ops; +bool strip_debug_info; // NOLINTNEXTLINE static opt emit_builtin_tflite_ops_flag( @@ -117,7 +121,7 @@ static opt emit_builtin_tflite_ops_flag( llvm::cl::location(emit_builtin_tflite_ops), llvm::cl::init(true)); // NOLINTNEXTLINE -static opt emit_select_tf_Ops_flag( +static opt emit_select_tf_ops_flag( "emit-select-tf-ops", llvm::cl::desc( "Emit Select TF operations (Flex ops) in the generated TFLite model"), @@ -135,6 +139,11 @@ static opt lower_tensor_list_ops_flag( llvm::cl::desc("Lower the TensorList ops within the TFLite dialect"), llvm::cl::location(lower_tensor_list_ops), llvm::cl::init(false)); +// NOLINTNEXTLINE +static opt strip_debug_info_flag( + "strip-debug-info", llvm::cl::desc("Strip debug info during export"), + llvm::cl::location(strip_debug_info), llvm::cl::init(false)); + ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; // Use initial buffer size in flatbuffer builder to be same as the initial size @@ -188,6 +197,10 @@ static StatusOr GetTFLiteType(Type type, auto qtype = type.cast(); return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); } + case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: { + auto qtype = type.cast(); + return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); + } default: // TFLite export fills FLOAT32 for unknown data types. Returning an error // for now for safety and this could be revisited when required. @@ -200,11 +213,13 @@ static bool IsInput(Operation* op) { op->getName().getStringRef() == "tf.Placeholder.input"; } -static bool IsConstOrInput(Operation* op) { - return (isa(op) || isa(op) || - isa(op) || isa(op) || IsInput(op)); +static bool IsConst(Operation* op) { + return isa(op) || isa(op) || + isa(op) || isa(op); } +static bool IsConstOrInput(Operation* op) { return IsConst(op) || IsInput(op); } + template static bool HasValidTFLiteType(Value* value, T& error_handler) { // None type is allowed to represent unspecified operands. 
@@ -222,7 +237,7 @@ static bool HasValidTFLiteType(Value* value, T& error_handler) { return false; } if (auto* inst = value->getDefiningOp()) { - if (IsConstOrInput(inst) && !type.hasStaticShape()) { + if (IsInput(inst) && !type.hasStaticShape()) { return error_handler.emitError("should have static shape, got ") << type.getShape(), false; @@ -306,8 +321,8 @@ static std::unique_ptr<::tensorflow::NodeDef> getTensorFlowNodeDef( // We pass empty string for the original node_def name since Flex runtime // does not care about this being set correctly on node_def. There is no // "easy" (see b/120948529) way yet to get this from MLIR inst. - auto status_or_node_def = - tensorflow::ConvertTFDialectOpToNodeDef(inst, /*name=*/""); + auto status_or_node_def = tensorflow::ConvertTFDialectOpToNodeDef( + inst, /*name=*/"", /*ignore_unregistered_attrs=*/true); if (!status_or_node_def.ok()) { inst->emitOpError( Twine("failed to obtain TensorFlow nodedef with status: " + @@ -328,13 +343,17 @@ class Translator { static Optional Translate(ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops); + bool emit_custom_ops, + OpNameMapper* op_name_mapper); private: enum class OpType : char { kTfliteBuiltin, kSelectTf, kCustomOp }; explicit Translator(ModuleOp module, bool emit_builtin_tflite_ops, - bool emit_select_tf_ops, bool emit_custom_ops) - : module_(module), builder_(kInitialBufferSize) { + bool emit_select_tf_ops, bool emit_custom_ops, + OpNameMapper* op_name_mapper) + : module_(module), + name_mapper_(*op_name_mapper), + builder_(kInitialBufferSize) { // The first buffer must be empty according to the schema definition. empty_buffer_ = tflite::CreateBuffer(builder_); buffers_.push_back(empty_buffer_); @@ -353,10 +372,6 @@ class Translator { Optional TranslateInternal(); - // Returns name that should be used by tensors for values generated by this - // operation. - std::string GetName(Operation* inst); - // Returns TFLite buffer populated with constant value if the operation is // TFLite constant operation. Otherwise, returns an empty buffer. Emits error // and returns llvm::None on failure. @@ -368,9 +383,14 @@ class Translator { const std::string& name, unsigned buffer_idx); - CustomOptionsOffset CreateIfOpCustomOptions(mlir::TF::IfOp op); - - CustomOptionsOffset CreateWhileOpCustomOptions(mlir::TF::WhileOp op); + // TODO(b/137395003): Legalize control flow ops to TFLite dialect, and remove + // these 2 functions here. + BufferOffset BuildIfOperator( + mlir::TF::IfOp op, const std::vector& operands, + const std::vector& results); + BufferOffset BuildWhileOperator( + mlir::TF::WhileOp op, const std::vector& operands, + const std::vector& results); Optional CreateFlexOpCustomOptions( const ::tensorflow::NodeDef& node_def, const mlir::Location& loc); @@ -399,14 +419,17 @@ class Translator { // mapping. void InitializeNamesFromAttribute(FuncOp fn); + // Determines if the specified operation op's operand at operand_index + // is marked as a stateful operand. + bool IsStatefulOperand(mlir::Operation* op, int operand_index); + // Returns a unique name for `op`. std::string UniqueName(mlir::Operation* op); - // Returns a unique name starting with a given prefix. 
- std::string UniqueName(llvm::StringRef prefix); - ModuleOp module_; + tensorflow::OpNameMapper& name_mapper_; + flatbuffers::FlatBufferBuilder builder_; BufferOffset empty_buffer_; @@ -421,55 +444,14 @@ class Translator { absl::flat_hash_map subgraph_index_map_; absl::flat_hash_set enabled_op_types_; - // Maps from op to name. - absl::flat_hash_map op_to_name_; - absl::flat_hash_map name_to_count_; - // Points to TensorFlow and TFLite dialects, respectively. nullptr if the // dialect is not registered. const Dialect* tf_dialect_; const Dialect* tfl_dialect_; - - // Suffix used to generate unique tensor names from operation names. - int name_counter_ = 0; }; -std::string Translator::GetName(Operation* inst) { - if (auto name_loc = inst->getLoc().dyn_cast()) - return name_loc.getName().str(); - - if (auto call_loc = inst->getLoc().dyn_cast()) { - // Return name if CallSiteLoc's callee has a NameLoc (as should be the case - // if imported with DebugInfo), else use the fallback naming scheme below. - if (auto name_loc = call_loc.getCallee().dyn_cast()) - return name_loc.getName().str(); - } - - // If the location is none of the expected types, then simply use name - // generated using the op type. - return inst->getName().getStringRef().str(); -} - -std::string Translator::UniqueName(llvm::StringRef prefix) { - // Keep incrementing the counter until we find a unique name. - std::string name = prefix; - int64_t& prefix_count = name_to_count_[name]; - int64_t val = prefix_count; - while (val != 0) { - name = (prefix + llvm::Twine(prefix_count)).str(); - ++prefix_count; - val = name_to_count_[name]; - } - name_to_count_[name] = 1; - return name; -} - std::string Translator::UniqueName(mlir::Operation* op) { - auto& name = op_to_name_[op]; - if (!name.empty()) return name; - // Update the value in the map with unique name. - name = UniqueName(GetName(op)); - return name; + return name_mapper_.GetUniqueName(op); } Optional> Translator::BuildBuffer( @@ -510,8 +492,18 @@ Optional> Translator::BuildTensor( // However, we output all known shapes for better round-tripping std::vector shape; if (auto* inst = value->getDefiningOp()) { - if (type.hasStaticShape()) { - auto shape_ref = type.getShape(); + if (type.hasStaticShape() || IsConst(inst)) { + // Const op can have a result of dynamic shaped type (e.g. due to constant + // folding), but we can still derive the shape of a constant tensor + // for its attribute type. + llvm::ArrayRef shape_ref; + if (type.hasStaticShape()) { + shape_ref = type.getShape(); + } else { + mlir::Attribute tensor_attr = inst->getAttr("value"); + shape_ref = tensor_attr.getType().cast().getShape(); + } + auto is_out_of_range = [](int64_t dim) { return dim > std::numeric_limits::max(); }; @@ -535,40 +527,65 @@ Optional> Translator::BuildTensor( builder_, /*min=*/0, /*max=*/0, builder_.CreateVector({static_cast(qtype.getScale())}), builder_.CreateVector({qtype.getZeroPoint()})); + } else if (auto qtype = + element_type + .dyn_cast()) { + std::vector scales(qtype.getScales().begin(), + qtype.getScales().end()); + q_params = tflite::CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), + builder_.CreateVector(qtype.getZeroPoints()), + tflite::QuantizationDetails_NONE, /*details=*/0, + qtype.getQuantizedDimension()); } else { q_params = tflite::CreateQuantizationParameters(builder_); } - + // Check if the value's uses includes an op and usage at an operand index + // marked as a stateful. 
If so, set the tensor's is_variable as true + // This is v1 ref variable semantics in the TFLite runtime. + bool is_variable = false; + for (auto& use : value->getUses()) { + is_variable = IsStatefulOperand(use.getOwner(), use.getOperandNumber()); + if (is_variable) { + break; + } + } return tflite::CreateTensor( - builder_, builder_.CreateVector(shape), tflite_element_type, buffer_idx, - builder_.CreateString(name), q_params, /*is_variable=*/false); + builder_, builder_.CreateVector(shape), tflite_element_type, + (is_variable ? 0 : buffer_idx), builder_.CreateString(name), q_params, + /*is_variable=*/is_variable); } -CustomOptionsOffset Translator::CreateIfOpCustomOptions(mlir::TF::IfOp op) { - int then_subgraph_index = subgraph_index_map_.at(op.getThen().str()); - int else_subgraph_index = subgraph_index_map_.at(op.getElse().str()); - - auto flex_builder = absl::make_unique(); - flex_builder->Map([&]() { - flex_builder->Int("then_subgraph_index", then_subgraph_index); - flex_builder->Int("else_subgraph_index", else_subgraph_index); - }); - flex_builder->Finish(); - return builder_.CreateVector(flex_builder->GetBuffer()); +BufferOffset Translator::BuildIfOperator( + mlir::TF::IfOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); + int then_subgraph_index = subgraph_index_map_.at(op.then_branch().str()); + int else_subgraph_index = subgraph_index_map_.at(op.else_branch().str()); + auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, + else_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_IfOptions, + builtin_options); } -CustomOptionsOffset Translator::CreateWhileOpCustomOptions( - mlir::TF::WhileOp op) { - int cond_subgraph_index = subgraph_index_map_.at(op.getCond().str()); - int body_subgraph_index = subgraph_index_map_.at(op.getBody().str()); - - auto flex_builder = absl::make_unique(); - flex_builder->Map([&]() { - flex_builder->Int("cond_subgraph_index", cond_subgraph_index); - flex_builder->Int("body_subgraph_index", body_subgraph_index); - }); - flex_builder->Finish(); - return builder_.CreateVector(flex_builder->GetBuffer()); +BufferOffset Translator::BuildWhileOperator( + mlir::TF::WhileOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("while", tflite::BuiltinOperator_WHILE); + int cond_subgraph_index = subgraph_index_map_.at(op.cond().str()); + int body_subgraph_index = subgraph_index_map_.at(op.body().str()); + auto builtin_options = tflite::CreateWhileOptions( + builder_, cond_subgraph_index, body_subgraph_index) + .Union(); + auto inputs = builder_.CreateVector(operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_WhileOptions, + builtin_options); } Optional Translator::CreateFlexOpCustomOptions( @@ -712,63 +729,60 @@ Optional> Translator::BuildOperator( if (dialect == tf_dialect_) { std::string op_name; + if (auto ifOp = dyn_cast(inst)) { + return BuildIfOperator(ifOp, operands, results); + } else if (auto whileOp = dyn_cast(inst)) { + return BuildWhileOperator(whileOp, operands, results); + } + CustomOptionsOffset custom_options; - if (auto ifOp = dyn_cast(inst)) { - op_name = "Experimental_If"; - custom_options = 
CreateIfOpCustomOptions(ifOp); - } else if (auto whileOp = dyn_cast(inst)) { - op_name = "Experimental_While"; - custom_options = CreateWhileOpCustomOptions(whileOp); - } else { - // Ops in TF dialect can either be custom ops or flex ops. - // The reason we go directly from TensorFlow dialect MLIR to tensorflow - // node instead of going to TF table gen'd ops via generated code is that - // we do not want to restrict custom and flex op conversion support to - // only those TF ops that are currently registered in MLIR. The current - // model is of an open op system. - // - // The following algorithm is followed: - // if flex is enabled and the op is whitelisted as flex - // we emit op as flex. - // if custom is enabled - // we emit the op as custom. - auto node_def = getTensorFlowNodeDef(inst); - if (!node_def) { + // Ops in TF dialect can either be custom ops or flex ops. + // The reason we go directly from TensorFlow dialect MLIR to tensorflow + // node instead of going to TF table gen'd ops via generated code is that + // we do not want to restrict custom and flex op conversion support to + // only those TF ops that are currently registered in MLIR. The current + // model is of an open op system. + // + // The following algorithm is followed: + // if flex is enabled and the op is whitelisted as flex + // we emit op as flex. + // if custom is enabled + // we emit the op as custom. + auto node_def = getTensorFlowNodeDef(inst); + if (!node_def) { + return llvm::None; + } + + // Flex op case + // Eventually, the whitelist will go away and we will rely on some TF op + // trait (e.g. No side effect) to determine if it is a supported "Flex" + // op or not. + if (enabled_op_types_.contains(OpType::kSelectTf) && + IsWhitelistedFlexOp(node_def->op())) { + // Construct ops as flex op encoding TensorFlow node definition + // as custom options. + // Flex ops are named with the kFlexOpNamePrefix prefix to the actual + // TF op name. + op_name = std::string(kFlexOpNamePrefix) + node_def->op(); + if (auto options = CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; + } else { return llvm::None; } - - // Flex op case - // Eventually, the whitelist will go away and we will rely on some TF op - // trait (e.g. No side effect) to determine if it is a supported "Flex" - // op or not. - if (enabled_op_types_.contains(OpType::kSelectTf) && - IsWhitelistedFlexOp(node_def->op())) { - // Construct ops as flex op encoding TensorFlow node definition - // as custom options. - // Flex ops are named with the kFlexOpNamePrefix prefix to the actual - // TF op name. - op_name = std::string(kFlexOpNamePrefix) + node_def->op(); - if (auto options = - CreateFlexOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } - } else if (enabled_op_types_.contains(OpType::kCustomOp)) { - // Generic case of custom ops - write using flex buffers since that - // is the only custom options supported by TFLite today. - op_name = node_def->op(); - if (auto options = - CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { - custom_options = *options; - } else { - return llvm::None; - } + } else if (enabled_op_types_.contains(OpType::kCustomOp)) { + // Generic case of custom ops - write using flex buffers since that + // is the only custom options supported by TFLite today. 
+ op_name = node_def->op(); + if (auto options = + CreateCustomOpCustomOptions(*node_def, inst->getLoc())) { + custom_options = *options; } else { - return inst->emitOpError("is neither a custom op nor a flex op"), - llvm::None; + return llvm::None; } + } else { + return inst->emitOpError("is neither a custom op nor a flex op"), + llvm::None; } uint32_t opcode_index = @@ -804,8 +818,8 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn) { return; } for (auto it : llvm::enumerate(fn.getArguments())) { - op_to_name_[*it.value()->user_begin()] = input_names[it.index()]; - ++name_to_count_[input_names[it.index()].str()]; + name_mapper_.InitOpName(*it.value()->user_begin(), + input_names[it.index()]); } } @@ -825,8 +839,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn) { // insert an op so that we can have a buffer named such. This cannot // currently happen due to pseudo_input nodes. if (auto op = it.value()->getDefiningOp()) { - op_to_name_[op] = output_names[it.index()]; - name_to_count_[output_names[it.index()].str()] = 1; + name_mapper_.InitOpName(op, output_names[it.index()]); } else { fn.emitWarning() << "output is not due to an op and '" << output_names[it.index()] @@ -836,6 +849,27 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn) { } } +bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { + std::vector operand_indices; + // TODO(b/138254427): When the bug is addressed, we'll be able to inspect + // for the presence of a specific OpTrait using mlir::Operation, without + // having to cast it to specific ops like below. + // Until then, when a new RNN/LSTM op is added to TFLite and has stateful + // tensors as operands, they will need to be added here as well. + if (auto tfl = llvm::dyn_cast(op)) { + operand_indices = tfl.GetStatefulOperands(); + } else if (auto tfl = + llvm::dyn_cast(op)) { + operand_indices = tfl.GetStatefulOperands(); + } else if (auto tfl = + llvm::dyn_cast(op)) { + operand_indices = tfl.GetStatefulOperands(); + } else if (auto tfl = llvm::dyn_cast(op)) { + operand_indices = tfl.GetStatefulOperands(); + } + return absl::c_find(operand_indices, operand_index) != operand_indices.end(); +} + Optional> Translator::BuildSubGraph(FuncOp fn) { InitializeNamesFromAttribute(fn); std::vector> tensors; @@ -855,6 +889,10 @@ Optional> Translator::BuildSubGraph(FuncOp fn) { if (!tensor_or) return false; tensors.push_back(*tensor_or); + // TODO(ashwinm): Check if for stateful tensors, if it is also needed to + // make the Buffer empty apart from setting the buffer_idx=0 in the Tensor. + // This does not seem to affect runtime behavior for RNN/LSTM, but would be + // good for reducing memory footprint. 
if (auto* inst = value->getDefiningOp()) { auto buffer_or = BuildBuffer(inst); if (!buffer_or) return false; @@ -942,10 +980,11 @@ Optional> Translator::BuildSubGraph(FuncOp fn) { Optional Translator::Translate(ModuleOp module, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops) { + bool emit_custom_ops, + OpNameMapper* op_name_mapper) { if (!IsValidTFLiteMlirModule(module)) return llvm::None; Translator translator(module, emit_builtin_tflite_ops, emit_select_tf_ops, - emit_custom_ops); + emit_custom_ops, op_name_mapper); return translator.TranslateInternal(); } @@ -979,8 +1018,14 @@ Optional Translator::TranslateInternal() { subgraphs.push_back(*subgraph_or); } + std::string model_description; + if (auto attr = module_.getAttrOfType("tfl.description")) { + model_description = attr.getValue().str(); + } else { + model_description = "MLIR Converted."; + } // Build the model and finish the model building process. - auto description = builder_.CreateString("MLIR Converted."); + auto description = builder_.CreateString(model_description.data()); auto model = tflite::CreateModel( builder_, TFLITE_SCHEMA_VERSION, builder_.CreateVector(opcodes_), builder_.CreateVector(subgraphs), description, @@ -1005,21 +1050,38 @@ Optional Translator::TranslateInternal() { // bool tflite::MlirToFlatBufferTranslateFunction( ModuleOp module, std::string* serialized_flatbuffer, - bool emit_builtin_tflite_ops, bool emit_select_tf_ops, - bool emit_custom_ops) { - auto maybe_translated = Translator::Translate( - module, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops); + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + OpNameMapper* op_name_mapper) { + auto maybe_translated = + Translator::Translate(module, emit_builtin_tflite_ops, emit_select_tf_ops, + emit_custom_ops, op_name_mapper); if (!maybe_translated) return true; *serialized_flatbuffer = std::move(*maybe_translated); return false; } +bool tflite::MlirToFlatBufferTranslateFunction( + ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, + bool emit_custom_ops) { + OpLocNameMapper op_name_mapper; + return MlirToFlatBufferTranslateFunction( + module, serialized_flatbuffer, emit_builtin_tflite_ops, + emit_select_tf_ops, emit_custom_ops, &op_name_mapper); +} + static mlir::LogicalResult MlirToFlatBufferFileTranslateFunction( ModuleOp module, llvm::StringRef filename) { std::string serialized_flatbuffer; + std::unique_ptr op_name_mapper; + if (strip_debug_info) { + op_name_mapper = std::make_unique(); + } else { + op_name_mapper = std::make_unique(); + } if (tflite::MlirToFlatBufferTranslateFunction( module, &serialized_flatbuffer, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops)) + emit_select_tf_ops, emit_custom_ops, op_name_mapper.get())) return mlir::failure(); auto file = openOutputFile(filename); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.h b/tensorflow/compiler/mlir/lite/flatbuffer_translate.h index 820b2697e43..477a477dde6 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "mlir/IR/Module.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/op_name_mapper.h" // These flags are used to control the emission or not of different kinds of ops // during the flatbuffer translation. 
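The overload added to flatbuffer_translate.h in the next hunk (and defined above in flatbuffer_translate.cc) lets callers control how tensor names are derived by passing their own `OpNameMapper`. A hedged usage sketch follows; `SerializeWithLocationNames` and its flag values are hypothetical, and only the translate function and `OpLocNameMapper` come from this change.

```
#include <string>

#include "mlir/IR/Module.h"  // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h"
#include "tensorflow/compiler/mlir/op_name_mapper.h"

// Serializes `module` to a TFLite flatbuffer, deriving tensor names from op
// locations. Mirrors the convention above: returns true on failure.
bool SerializeWithLocationNames(mlir::ModuleOp module, std::string* out) {
  tensorflow::OpLocNameMapper mapper;  // Names come from MLIR debug locations.
  return tflite::MlirToFlatBufferTranslateFunction(
      module, out, /*emit_builtin_tflite_ops=*/true,
      /*emit_select_tf_ops=*/false, /*emit_custom_ops=*/false, &mapper);
}
```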
@@ -27,16 +28,25 @@ extern bool emit_select_tf_ops; extern bool emit_custom_ops; // The flag to control whether to lower tensorlist ops into TF ops. extern bool lower_tensor_list_ops; +// The flag to control whether debug info gets stripped on export. +extern bool strip_debug_info; namespace tflite { // Translates the given MLIR `module` into a FlatBuffer and stores the -// serialized flatbuffer into the string. +// serialized flatbuffer into the string. This uses OpLocNameMapper to convert +// location of the op to name in flatbuffer. bool MlirToFlatBufferTranslateFunction(mlir::ModuleOp module, - std::string *serialized_flatbuffer, + std::string* serialized_flatbuffer, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops); + +// Same as the above but with a custom op name mapper. +bool MlirToFlatBufferTranslateFunction( + mlir::ModuleOp module, std::string* serialized_flatbuffer, + bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, + tensorflow::OpNameMapper* op_name_mapper); } // namespace tflite #endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_TRANSLATE_H_ diff --git a/tensorflow/compiler/mlir/lite/g3doc/tfl_ops.md b/tensorflow/compiler/mlir/lite/g3doc/tfl_ops.md deleted file mode 100755 index 74e4fc47868..00000000000 --- a/tensorflow/compiler/mlir/lite/g3doc/tfl_ops.md +++ /dev/null @@ -1,1606 +0,0 @@ - -# Operation definition -## tfl.abs (TFL::AbsOp) -Absolute value operator - -### Description: - -Given a tensor `x`, this operation returns a tensor containing the absolute -value of each element in `x`. For example, if x is an input element and y is -an output element, this operation computes \\(y = |x|\\). - -### Operands: -1. `x`: tensor of any type values - -### Attributes: - -### Results: -1. `y`: tensor of any type values - -## tfl.add_n (TFL::AddNOp) -add_n operator - -### Description: - -Adds all input tensors element-wise. - -### Operands: -1. `inputs`: tensor of 32-bit float or 32-bit integer values - -### Attributes: - -### Results: -1. `sum`: tensor of 32-bit float or 32-bit integer values - -## tfl.add (TFL::AddOp) -Addition operator - -### Description: - -Element-wise addition operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `fused_activation_function` | `StringAttr` | fused activation enum attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.average_pool_2d (TFL::AveragePool2DOp) -Average_pool_2d operator - -### Description: - -Performs average-pooling operation on input. - -### Operands: -1. `input`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `filter_height` | `IntegerAttr` | 32-bit integer attribute attribute | -| `filter_width` | `IntegerAttr` | 32-bit integer attribute attribute | -| `padding` | `StringAttr` | padding enum attribute | -| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute | -| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute | -| `fused_activation_function` | `StringAttr` | fused activation enum attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.batch_to_space_nd (TFL::BatchToSpaceNdOp) -BatchToSpaceNd operator - -### Description: - -This operation reshapes the "batch" dimension 0 into space dimensions. - -### Operands: -1. 
`input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values -1. `block_shape`: tensor of 32-bit integer values -1. `indices`: tensor of 32-bit integer values - -### Attributes: - -### Results: -1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values - -## tfl.ceil (TFL::CeilOp) -Ceil operator - -### Description: - -Returns element-wise ceil value of the input. - -### Operands: -1. `x`: tensor of floating-point values - -### Attributes: - -### Results: -1. `y`: tensor of floating-point values - -## tfl.concatenation (TFL::ConcatenationOp) -Concatenation operator - -### Description: - -Concatenates tensors along one dimension - -### Operands: -1. `values`: tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `axis` | `IntegerAttr` | 32-bit integer attribute attribute | -| `fused_activation_function` | `StringAttr` | fused activation enum attribute | - -### Results: -1. `output`: tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or quantized type with 8 bits storage type values - -## tfl.pseudo_const (TFL::ConstOp) -Constant pseudo op. - -### Description: - -Represents a constant value in TensorFlow Lite dialect. This is not an -actual operation and it will be lowered to buffer instead. - -The op is allowed to have all the same type of attributes as tf.Const does -(e.g., opaque TF attributes are allowed). - -### Operands: - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `value` | `ElementsAttr` | constant vector/tensor attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.conv_2d (TFL::Conv2DOp) -Convolution operator - -### Description: - -Performs convolution operation on inputs. - -Inputs: - `inputs[0]`: required: the input activation tensor - `inputs[1]`: required: the filter weight tensor - `inputs[2]`: optional: the bias tensor - -### Operands: -1. `input`: tensor of any type values -1. `filter`: tensor of any type values -1. `bias`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `dilation_h_factor` | `IntegerAttr` | 32-bit integer attribute attribute | -| `dilation_w_factor` | `IntegerAttr` | 32-bit integer attribute attribute | -| `fused_activation_function` | `StringAttr` | fused activation enum attribute | -| `padding` | `StringAttr` | padding enum attribute | -| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute | -| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.cos (TFL::CosOp) -Cosine operator - -### Description: - -Computes element-wise Cosine of input - -### Operands: -1. `x`: tensor of floating-point values - -### Attributes: - -### Results: -1. `y`: tensor of floating-point values - -## tfl.depthwise_conv_2d (TFL::DepthwiseConv2DOp) -Depthwise-separable convolution operator - -### Description: - -Performs convolution operation on inputs. - -Inputs: - `inputs[0]`: required: the input activation tensor - `inputs[1]`: required: the filter weight tensor - `inputs[2]`: optional: the bias tensor - -### Operands: -1. `input`: tensor of any type values -1. `filter`: tensor of any type values -1. 
`bias`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `dilation_h_factor` | `IntegerAttr` | 32-bit integer attribute attribute | -| `dilation_w_factor` | `IntegerAttr` | 32-bit integer attribute attribute | -| `fused_activation_function` | `StringAttr` | fused activation enum attribute | -| `padding` | `StringAttr` | padding enum attribute | -| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute | -| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute | -| `depth_multiplier` | `IntegerAttr` | 32-bit integer attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.dequantize (TFL::DequantizeOp) -Dequantize operator - -### Description: - -Converts quantized array of integers to floating-points according to the -quantization parameters. - -### Operands: -1. `input`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of any type values - -## tfl.div (TFL::DivOp) -Division operator - -### Description: - -Element-wise division operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `fused_activation_function` | `StringAttr` | fused activation enum attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.elu (TFL::EluOp) -Exponential Linear Unit operator - -### Description: - -Computes the exponential linear - f(x) -> exp(x) - 1 for x < 0, x for x >= 0. -element-wise. - -### Operands: -1. `x`: tensor of floating-point values - -### Attributes: - -### Results: -1. `y`: tensor of any type values - -## tfl.equal (TFL::EqualOp) -Equal operator - -### Description: - -Returns the truth element of x == y element-wise - -### Operands: -1. `x`: tensor of 1-bit integer or 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values -1. `y`: tensor of 1-bit integer or 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values - -### Attributes: - -### Results: -1. `output`: tensor of 1-bit integer values - -## tfl.exp (TFL::ExpOp) -Natural exponentiation operator - -### Description: - -Performs element-wise natural exponentiation operation on input. - -### Operands: -1. `x`: tensor of any type values - -### Attributes: - -### Results: -1. `y`: tensor of any type values - -## tfl.expand_dims (TFL::ExpandDimsOp) -Inserts a dimension of 1 into a tensor's shape. - -### Description: - -Given a tensor `input`, this operation inserts a dimension of 1 at the -dimension index `axis` of `input`'s shape. The dimension index `axis` starts at -zero; if you specify a negative number for `axis` it is counted backward from -the end. - -This operation is useful if you want to add a batch dimension to a single -element. For example, if you have a single image of shape `[height, width, -channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`, -which will make the shape `[1, height, width, channels]`. 
- -Other examples: - -``` -# 't' is a tensor of shape [2] -shape(expand_dims(t, 0)) ==> [1, 2] -shape(expand_dims(t, 1)) ==> [2, 1] -shape(expand_dims(t, -1)) ==> [2, 1] - -# 't2' is a tensor of shape [2, 3, 5] -shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5] -shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5] -shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1] -``` - -This operation requires that: - -`-1-input.dims() <= dim <= input.dims()` - -This operation is related to `squeeze()`, which removes dimensions of -size 1. - -### Operands: -1. `input`: tensor of any type values -1. `dim`: tensor of any integer type - -### Attributes: - -### Results: -1. `output`: tensor of any type values - -## tfl.fake_quant (TFL::FakeQuantOp) -FakeQuant operator - -### Description: - -Fake-quantize the 'inputs' tensor of type float via float scalars min and -max to 'outputs' tensor of same shape as inputs. - -### Operands: -1. `input`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `minmax` | `ArrayAttr` | min-max range pair attribute | -| `num_bits` | `IntegerAttr` | 32-bit integer attribute attribute | -| `narrow_range` | `BoolAttr` | bool attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.fill (TFL::FillOp) -Fill the tensor with given value. - -### Description: - -Fill the tensor with given value. - -### Operands: -1. `dims`: tensor of 32/64-bit integer values -1. `value`: tensor of any type values - -### Attributes: - -### Results: -1. `res`: tensor of any type values - -## tfl.floor_div (TFL::FloorDivOp) -Floor div operator - -### Description: - -Element-wise floor div operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of any type values - -## tfl.floor_mod (TFL::FloorModOp) -Division reminder - -### Description: - -Element-wise division reminder operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of any type values - -## tfl.floor (TFL::FloorOp) -Floor operator - -### Description: - -Returns element-wise floor value of the input. - -### Operands: -1. `x`: tensor of floating-point values - -### Attributes: - -### Results: -1. `y`: tensor of floating-point values - -## tfl.fully_connected (TFL::FullyConnectedOp) -Fully connected op - -### Description: - - -### Operands: -1. `input`: tensor of 32-bit float values -1. `filter`: tensor of 32-bit float values -1. `bias`: tensor of 32-bit float values or none type - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `fused_activation_function` | `StringAttr` | fused activation enum attribute | -| `weights_format` | `StringAttr` | fully connected options weights format attribute | -| `keep_num_dims` | `BoolAttr` | bool attribute attribute | - -### Results: -1. `output`: tensor of 32-bit float values - -## tfl.gather (TFL::GatherOp) -Gather operator - -### Description: - -Gather slices from `params` axis `axis` according to `indices`. - -### Operands: -1. `params`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer or TFLite string type values -1. 
`indices`: tensor of 32-bit integer or 64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `axis` | `IntegerAttr` | 32-bit integer attribute attribute | - -### Results: -1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer or TFLite string type values - -## tfl.greater_equal (TFL::GreaterEqualOp) -Greater_equal operator - -### Description: - -Element-wise greater_equal operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of 1-bit integer values - -## tfl.greater (TFL::GreaterOp) -Greater operator - -### Description: - -Element-wise greater operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of any type values - -## tfl.pseudo_input (TFL::InputOp) -Input pseudo operator - -### Description: - -Takes one of the function arguments as input and returns it as result. This -is a NOP and is used to attach attributes such as tensor name to function -arguments. - -### Operands: -1. `input`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of any type values - -## tfl.leaky_relu (TFL::LeakyReluOp) -Leaky Relu operator - -### Description: - -Element-wise Leaky ReLU operator - x -> x >= 0 ? x : (alpha * x) - -### Operands: -1. `input`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `alpha` | `FloatAttr` | 32-bit float attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.less_equal (TFL::LessEqualOp) -Less_equal operator - -### Description: - -Element-wise less_equal operation. - -### Operands: -1. `lhs`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values -1. `rhs`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values - -### Attributes: - -### Results: -1. `output`: tensor of 1-bit integer values - -## tfl.less (TFL::LessOp) -Less operator - -### Description: - -Element-wise less operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of 1-bit integer values - -## tfl.log (TFL::LogOp) -Natural logarithm operator - -### Description: - -Performs element-wise natural logarithm operation on input. - -### Operands: -1. `x`: tensor of any type values - -### Attributes: - -### Results: -1. `y`: tensor of any type values - -## tfl.log_softmax (TFL::LogSoftmaxOp) -Log softmax operator - -### Description: - -Computes element-wise log softmax activations with the following formula - - input - log(reduce_sum(exp(input), dim)) - -### Operands: -1. `input`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of any type values - -## tfl.logical_and (TFL::LogicalAndOp) -Logical AND operator - -### Description: - -Element-wise logical AND operation. - -### Operands: -1. `lhs`: tensor of 1-bit integer values -1. `rhs`: tensor of 1-bit integer values - -### Attributes: - -### Results: -1. `output`: tensor of 1-bit integer values - -## tfl.logical_not (TFL::LogicalNotOp) -Logical NOT operator - -### Description: - -Element-wise logical NOT operation. - -### Operands: -1. `lhs`: tensor of 1-bit integer values - -### Attributes: - -### Results: -1. 
`output`: tensor of 1-bit integer values - -## tfl.logical_or (TFL::LogicalOrOp) -Logical OR operator - -### Description: - -Element-wise logical OR operation. - -### Operands: -1. `lhs`: tensor of 1-bit integer values -1. `rhs`: tensor of 1-bit integer values - -### Attributes: - -### Results: -1. `output`: tensor of 1-bit integer values - -## tfl.logistic (TFL::LogisticOp) -Logistic operator - -### Description: - -Computes element-wise Sigmoid of input - -### Operands: -1. `x`: tensor of floating-point values - -### Attributes: - -### Results: -1. `y`: tensor of floating-point values - -## tfl.max_pool_2d (TFL::MaxPool2DOp) -Max Pool 2D op - -### Description: - -Performs max pool 2D on input. - -Inputs: - `inputs[0]`: required: the input tensor - -### Operands: -1. `input`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `padding` | `StringAttr` | padding enum attribute | -| `stride_w` | `IntegerAttr` | 32-bit integer attribute attribute | -| `stride_h` | `IntegerAttr` | 32-bit integer attribute attribute | -| `filter_width` | `IntegerAttr` | 32-bit integer attribute attribute | -| `filter_height` | `IntegerAttr` | 32-bit integer attribute attribute | -| `fused_activation_function` | `StringAttr` | fused activation enum attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.maximum (TFL::MaximumOp) -Max operator - -### Description: - -Element-wise max operation. - -### Operands: -1. `lhs`: tensor of floating-point or 32/64-bit integer values -1. `rhs`: tensor of floating-point or 32/64-bit integer values - -### Attributes: - -### Results: -1. `max`: tensor of floating-point or 32/64-bit integer values - -## tfl.mean (TFL::MeanOp) -Mean operator - -### Description: - -Computes the mean of elements across dimensions of a tensor. -Reduces input_tensor along the dimensions given in axis. -Unless keepdims is true, the rank of the tensor is reduced by 1 for -each entry in axis. If keepdims is true, the reduced dimensions are retained -with length 1. - -### Operands: -1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values -1. `axis`: tensor of 32-bit integer or 64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `keep_dims` | `BoolAttr` | bool attribute attribute | - -### Results: -1. `output`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values - -## tfl.minimum (TFL::MinimumOp) -Min operator - -### Description: - -Element-wise min operation. - -### Operands: -1. `lhs`: tensor of floating-point or 32/64-bit integer values -1. `rhs`: tensor of floating-point or 32/64-bit integer values - -### Attributes: - -### Results: -1. `min`: tensor of floating-point or 32/64-bit integer values - -## tfl.mul (TFL::MulOp) -Multiplication operator - -### Description: - -Element-wise multiplication operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `fused_activation_function` | `StringAttr` | fused activation enum attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.neg (TFL::NegOp) -Negation operator - -### Description: - -Computes element-wise negation of input - -### Operands: -1. `x`: tensor of any type values - -### Attributes: - -### Results: -1. 
`y`: tensor of any type values - -## tfl.not_equal (TFL::NotEqualOp) -Not_equal operator - -### Description: - -Element-wise not_equal operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of 1-bit integer values - -## tfl.pack (TFL::PackOp) -Packs a list of tensors along a dimension into one tensor - -### Description: - -Packs a list of `values_count` rank-`R` tensors into one rank-`(R+1)` -tensor. - -Packs the `values_count` tensors in `values` into a tensor with rank one -higher than each tensor in `values`, by packing them along the `axis` -dimension. - -Given a list of tensors of shape `(A, B, C)`; - -if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`. -if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. -Etc. - -For example: - -``` -# 'x' is [1, 4] -# 'y' is [2, 5] -# 'z' is [3, 6] -pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. -pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] -``` - -This is the opposite of `unpack`. - -### Operands: -1. `values`: tensor of 32-bit float or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `values_count` | `IntegerAttr` | 32-bit integer attribute attribute | -| `axis` | `IntegerAttr` | 32-bit integer attribute attribute | - -### Results: -1. `output`: tensor of 32-bit float or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values - -## tfl.pad (TFL::PadOp) -Padding operator - -### Description: - -This operation pads a `input` with zeros according to the `paddings` you -specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is -the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` -indicates how many zeros to add before the contents of `input` in that -dimension, and `paddings[D, 1]` indicates how many zeros to add after the -contents of `input` in that dimension. - -The padded size of each dimension D of the output is: - - `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` - -For example: - -``` -# 't' is [[1, 1], [2, 2]] -# 'paddings' is [[1, 1], [2, 2]] -# rank of 't' is 2 -pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] - [0, 0, 1, 1, 0, 0] - [0, 0, 2, 2, 0, 0] - [0, 0, 0, 0, 0, 0]] - -### Operands: -1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values -1. `padding`: tensor of 32/64-bit integer values - -### Attributes: - -### Results: -1. `output`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values - -## tfl.padv2 (TFL::PadV2Op) -Padding operator v2 - -### Description: - -This operation pads a `input` according to the `paddings` and -`constant_values` you specify. `paddings` is an integer tensor with shape -`[Dn, 2]`, where n is the rank of `input`. For each dimension D of `input`, -`paddings[D, 0]` indicates how many zeros to add before the contents of -`input` in that dimension, and `paddings[D, 1]` indicates how many zeros to -add after the contents of `input` in that dimension. `constant_values` is a -scalar tensor of the same type as `input` that indicates the value to use -for padding `input`. 
- -The padded size of each dimension D of the output is: - - `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` - -For example: - -``` -# 't' is [[1, 1], [2, 2]] -# 'paddings' is [[1, 1], [2, 2]] -# rank of 't' is 2 -pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] - [0, 0, 1, 1, 0, 0] - [0, 0, 2, 2, 0, 0] - [0, 0, 0, 0, 0, 0]] - -### Operands: -1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values -1. `padding`: tensor of 32/64-bit integer values -1. `constant_values`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values - -### Attributes: - -### Results: -1. `output`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values - -## tfl.pow (TFL::PowOp) -Power operator - -### Description: - -Element-wise power operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of any type values - -## tfl.pseudo_qconst (TFL::QConstOp) -Quantized constant pseudo op - -### Description: - -Represents a quantized constant value in TensorFlow Lite dialect. This is -not an actual operation and it will be lowered to buffer instead. The -quantization parameters are stored as a type attribute in this constant. - -### Operands: - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `qtype` | `TypeAttr` | Tensor type attribute attribute | -| `value` | `ElementsAttr` | constant vector/tensor attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.quantize (TFL::QuantizeOp) -Quantize operator - -### Description: - -Converts floating point tensors to quantized integer tensors according to -the quantization parameters defined in the type attribute. - -### Operands: -1. `input`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `qtype` | `TypeAttr` | Tensor type attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.range (TFL::RangeOp) -Range operator - -### Description: - -Returns a 1D tensor defined by a sequence from `start` to `limit` with -a given `delta`. - -### Operands: -1. `start`: tensor of any type values -1. `limit`: tensor of any type values -1. `delta`: tensor of any type values - -### Attributes: - -### Results: -1. `result`: tensor of any type values - -## tfl.rank (TFL::RankOp) -Rank operator. - -### Description: - -Returns the rank of a tensor. - -### Operands: -1. `input`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of any integer type - -## tfl.reduce_max (TFL::ReduceMaxOp) -Max-reduction operator - -### Description: - -Computes the max reduction along the specified axes - -### Operands: -1. `input`: tensor of any type values -1. `axes`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `keep_dims` | `BoolAttr` | bool attribute attribute | - -### Results: -1. «unnamed»: tensor of any type values - -## tfl.reduce_min (TFL::ReduceMinOp) -Min-reduction operator - -### Description: - -Computes the min reduction along the specified axes - -### Operands: -1. `input`: tensor of any type values -1. 
`axes`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `keep_dims` | `BoolAttr` | bool attribute attribute | - -### Results: -1. «unnamed»: tensor of any type values - -## tfl.relu6 (TFL::Relu6Op) -Relu6 operator - -### Description: - -Element-wise Relu6 operator - x -> max(0, min(6, x)) - -### Operands: -1. `x`: tensor of any type values - -### Attributes: - -### Results: -1. `y`: tensor of any type values - -## tfl.relu (TFL::ReluOp) -Relu operator - -### Description: - -Element-wise Relu operator - x -> max(0, x) - -### Operands: -1. `x`: tensor of any type values - -### Attributes: - -### Results: -1. `y`: tensor of any type values - -## tfl.reshape (TFL::ReshapeOp) -Reshape operator - -### Description: - -Produces a tensor with the same values but different static shape defined -by the output type. - -### Operands: -1. `input`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `new_shape` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.resize_bilinear (TFL::ResizeBilinearOp) -ResizeBilinear Op - -### Description: - -Resize `images` to `size` using bilinear interpolation. - -### Operands: -1. `input`: tensor of 32-bit float or 32-bit integer values -1. `size`: tensor of 32-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `align_corners` | `BoolAttr` | bool attribute attribute | - -### Results: -1. `output`: tensor of 32-bit float values - -## tfl.reverse_v2 (TFL::ReverseV2Op) -ReverseV2 Operator - -### Description: - -Reverses specific dimensions of a tensor. - -Given a tensor, and a int32/int64 tensor axis representing the set -of dimensions of tensor to reverse. -This operation reverses each dimension i for -which there exists j s.t. axis[j] == i. - -Args: - tensor: A Tensor. Must be one of the following types: - int16, int32, int64, float32 Up to 8-D. - - axis: A Tensor. Must be one of the following types: int32, int64. - with only 1 element which is the axis index. - TODO: Add support for multiple elements. - -### Operands: -1. `input`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values -1. `axis`: tensor of 32-bit integer or 64-bit integer values - -### Attributes: - -### Results: -1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer values - -## tfl.rsqrt (TFL::RsqrtOp) -Reciprocal of square root operator - -### Description: - -Computes element-wise reverse square root of input - -### Operands: -1. `x`: tensor of any type values - -### Attributes: - -### Results: -1. `y`: tensor of any type values - -## tfl.select (TFL::SelectOp) -Select operator - -### Description: - -Select values of 'x' if the corresponding value of 'condition' is true or -the value of 'y' if false. There are valid condition input sizes: - -1. Either the same shape (in which case the select is elementwise), or -2. condition must be Rank 1 and match over the first dimension. - -### Operands: -1. `condition`: tensor of 1-bit integer values -1. `x`: tensor of 32-bit float or 1-bit integer or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values -1. 
`y`: tensor of 32-bit float or 1-bit integer or 8-bit integer or 16-bit integer or 32-bit integer or 64-bit integer values - -### Attributes: - -### Results: -1. `output`: tensor of any type values - -## tfl.shape (TFL::ShapeOp) -Shape operator - -### Description: - -Returns the shape of a tensor. - -### Operands: -1. `input`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `out_type` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.sin (TFL::SinOp) -Sine operator - -### Description: - -Computes element-wise Sine of input - -### Operands: -1. `x`: tensor of floating-point values - -### Attributes: - -### Results: -1. `y`: tensor of floating-point values - -## tfl.softmax (TFL::SoftmaxOp) -Softmax operator - -### Description: - -Computes element-wise softmax activiations with the following formula - - exp(input) / tf.reduce_sum(exp(input * beta), dim) - -### Operands: -1. `input`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `beta` | `FloatAttr` | 32-bit float attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.space_to_batch_nd (TFL::SpaceToBatchNdOp) -SpaceToBatchNd operator - -### Description: - -This operation reshapes space dimensions into the "batch" dimension 0 - -### Operands: -1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values -1. `block_shape`: tensor of 32-bit integer values -1. `paddings`: tensor of 32-bit integer values - -### Attributes: - -### Results: -1. `output`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values - -## tfl.split (TFL::SplitOp) -Splits a tensor into `num_split` tensors along one dimension. - -### Description: - -Splits the `value` tensor along `split_dim` into a number of sub-tensors -with same shape as the original one, except for `split_dim`. Same as -tf.Split. - -### Operands: -1. `split_dim`: tensor of 32-bit integer values -1. `value`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `num_splits` | `IntegerAttr` | 32-bit integer attribute attribute | - -### Results: -1. `outputs`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values - -## tfl.split_v (TFL::SplitVOp) -Splits a tensor into `num_split` tensors along one dimension. - -### Description: - -Splits the `value` tensor along `split_dim` into a number of sub-tensors -with same shape as the original one, except for `split_dim`. The grouping -of the resultant sub-tensors is decided by `size-splits`. Same as tf.SplitV. - -### Operands: -1. `value`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values -1. `size_splits`: tensor of 32-bit integer values -1. `split_dim`: tensor of 32-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `num_splits` | `IntegerAttr` | 32-bit integer attribute attribute | - -### Results: -1. `outputs`: tensor of 32-bit float or 16-bit integer or 32-bit integer or 64-bit integer values - -## tfl.sqrt (TFL::SqrtOp) -Square root operator - -### Description: - -Computes element-wise Square root of input - -### Operands: -1. 
`x`: tensor of any type values - -### Attributes: - -### Results: -1. `y`: tensor of any type values - -## tfl.square (TFL::SquareOp) -Square operator - -### Description: - -Computes element-wise Square of input - -### Operands: -1. `x`: tensor of any type values - -### Attributes: - -### Results: -1. `y`: tensor of any type values - -## tfl.squared_difference (TFL::SquaredDifferenceOp) -Squared difference operator - -### Description: - -Element-wise squared difference operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of any type values - -## tfl.squeeze (TFL::SqueezeOp) -Removes dimensions of size 1 from the shape of a tensor. - -### Description: - -Given a tensor `input`, this operation returns a tensor of the same type with -all dimensions of size 1 removed. If you don't want to remove all size 1 -dimensions, you can remove specific size 1 dimensions by specifying -`axis`. - -For example: - -``` -# 't' is a tensor of shape [1, 2, 1, 3, 1, 1] -shape(squeeze(t)) ==> [2, 3] -``` - -Or, to remove specific size 1 dimensions: - -``` -# 't' is a tensor of shape [1, 2, 1, 3, 1, 1] -shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] -``` - -### Operands: -1. `input`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `squeeze_dims` | `ArrayAttr` | 64-bit integer array attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.strided_slice (TFL::StridedSliceOp) -StridedSlice Op - -### Description: - -Return a strided slice from `input`. - -### Operands: -1. `input`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values -1. `begin`: tensor of 32-bit integer values -1. `end`: tensor of 32-bit integer values -1. `strides`: tensor of 32-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `begin_mask` | `IntegerAttr` | 32-bit integer attribute attribute | -| `end_mask` | `IntegerAttr` | 32-bit integer attribute attribute | -| `ellipsis_mask` | `IntegerAttr` | 32-bit integer attribute attribute | -| `new_axis_mask` | `IntegerAttr` | 32-bit integer attribute attribute | -| `shrink_axis_mask` | `IntegerAttr` | 32-bit integer attribute attribute | - -### Results: -1. `output`: tensor of 32-bit float or 32-bit integer or 64-bit integer or 8-bit integer values - -## tfl.sub (TFL::SubOp) -Subtraction operator - -### Description: - -Element-wise subtraction operation. - -### Operands: -1. `lhs`: tensor of any type values -1. `rhs`: tensor of any type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `fused_activation_function` | `StringAttr` | fused activation enum attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.sum (TFL::SumOp) -Sum operator - -### Description: - -Computes the sum reduction along the specified axes - -### Operands: -1. `input`: tensor of any type values -1. `axes`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `keep_dims` | `BoolAttr` | bool attribute attribute | - -### Results: -1. «unnamed»: tensor of any type values - -## tfl.tanh (TFL::TanhOp) -Hyperbolic tangent operator - -### Description: - -Computes element-wise Hyperbolic tangent of input - -### Operands: -1. 
`x`: tensor of any type values - -### Attributes: - -### Results: -1. `y`: tensor of any type values - -## tfl.tile (TFL::TileOp) -Tile operator. - -### Description: - - Constructs a tensor by tiling a given tensor. - -This operation creates a new tensor by replicating input -multiples times. The output tensor's i'th dimension has -input.dims(i) * multiples[i] elements, and the values of input -are replicated multiples[i] times along the 'i'th dimension. -For example, tiling [a b c d] by [2] produces [a b c d a b c d]. - -### Operands: -1. `input`: tensor of any type values -1. `multiples`: tensor of 32/64-bit integer values - -### Attributes: - -### Results: -1. `output`: tensor of any type values - -## tfl.topk_v2 (TFL::TopKV2Op) -TopK operator - -### Description: - -Returns the top `k` largest element along each last dimensional slice of -`input` and the indices of values within the last dimension of the input -tensor. - -### Operands: -1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer or 64-bit integer values -1. `k`: tensor of 32-bit integer values - -### Attributes: - -### Results: -1. `values`: tensor of any type values -1. `indices`: tensor of 32-bit integer values - -## tfl.transpose (TFL::TransposeOp) -Transpose operator - -### Description: - -Returns the Transpose of x - -### Operands: -1. `x`: tensor of any type values -1. `perm`: tensor of any type values - -### Attributes: - -### Results: -1. `y`: tensor of any type values - -## tfl.unidirectional_sequence_lstm (TFL::UnidirectionalSequenceLSTMOp) -Unidirectional sequence lstm operator - -### Description: - -A recurrent neural network specified by an LSTM cell. This Op supports -unrolling the input along the time or batch dimensions, and -implements the following operation for -each element in the sequence s = 1...sequence_length: - outputs[s] = state = activation(LSTMOp(inputs[s])) - -where LSTMOp is LSTM TF Lite Op and the “activation” is the function passed -as the “fused_activation_function” argument (if not “NONE”). - -### Operands: -1. `input`: tensor of 32-bit float or 8-bit integer values -1. `input_to_input_weights`: tensor of 32-bit float or 8-bit integer values or none type -1. `input_to_forget_weights`: tensor of 32-bit float or 8-bit integer values -1. `input_to_cell_weights`: tensor of 32-bit float or 8-bit integer values -1. `input_to_output_weights`: tensor of 32-bit float or 8-bit integer values -1. `recurrent_to_input_weights`: tensor of 32-bit float or 8-bit integer values or none type -1. `recurrent_to_forget_weights`: tensor of 32-bit float or 8-bit integer values -1. `recurrent_to_cell_weights`: tensor of 32-bit float or 8-bit integer values -1. `recurrent_to_output_weights`: tensor of 32-bit float or 8-bit integer values -1. `cell_to_input_weights`: tensor of 32-bit float or 8-bit integer values or none type -1. `cell_to_forget_weights`: tensor of 32-bit float or 8-bit integer values or none type -1. `cell_to_output_weights`: tensor of 32-bit float or 8-bit integer values or none type -1. `input_gate_bias`: tensor of 32-bit float values or none type -1. `forget_gate_bias`: tensor of 32-bit float values -1. `cell_bias`: tensor of 32-bit float values -1. `output_gate_bias`: tensor of 32-bit float values -1. `projection_weights`: tensor of 32-bit float or 8-bit integer values or none type -1. `projection_bias`: tensor of 32-bit float values or none type -1. `input_activation_state`: stateful tensor -1. `input_cell_state`: stateful tensor -1. 
`input_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type -1. `forget_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type -1. `cell_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type -1. `output_layer_norm_coefficients`: tensor of 32-bit float or 8-bit integer values or none type - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `fused_activation_function` | `StringAttr` | fused activation enum attribute | -| `cell_clip` | `FloatAttr` | 32-bit float attribute attribute | -| `proj_clip` | `FloatAttr` | 32-bit float attribute attribute | -| `time_major` | `BoolAttr` | bool attribute attribute | - -### Results: -1. `output`: tensor of any type values - -## tfl.unpack (TFL::UnpackOp) -Unpacks a tensor along a dimension into multiple tensors - -### Description: - -Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. - -Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. -For example, given a tensor of shape `(A, B, C, D)`; - -If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` - and each tensor in `output` will have shape `(B, C, D)`. (Note that the - dimension unpacked along is gone, unlike `split`). - -If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` - and each tensor in `output` will have shape `(A, C, D)`. -Etc. - -This is the opposite of `pack`. - -### Operands: -1. `input`: tensor of 32-bit float or 8-bit integer or 32-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `num` | `IntegerAttr` | 32-bit integer attribute attribute | -| `axis` | `IntegerAttr` | 32-bit integer attribute attribute | - -### Results: -1. `outputs`: tensor of 32-bit float or 8-bit integer or 32-bit integer values - -## tfl.zeros_like (TFL::ZerosLikeOp) -ZerosLike operator - -### Description: - -Returns a tensor of zeros with the same shape and type as the input tensor. - -### Operands: -1. `input`: tensor of any type values - -### Attributes: - -### Results: -1. `output`: tensor of any type values - diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 6c91470da07..c3dd7f5a398 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -15,13 +15,21 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Matchers.h" // TF:local_config_mlir #include "mlir/IR/OpImplementation.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { @@ -54,13 +62,21 @@ inline bool IsTrailingDimensions(ArrayRef a, ArrayRef b) { return std::equal(a.rbegin(), a.rend(), b.rbegin()); } +// Returns true if it is a shaped type of f32 elements. +inline bool IsF32ShapedType(Type t) { + if (auto shaped_type = t.dyn_cast_or_null()) { + return shaped_type.getElementType().isF32(); + } + return false; +} + // Performs const folding `calculate` with broadcast behavior on the two // attributes `operand1` and `operand2` and returns the result if possible. // The two operands are expected to both be scalar values. template > + llvm::function_ref> Attribute ConstFoldBinaryOpScalarScalar(Type result_type, Attribute operand1, Attribute operand2, const CalculationT &calculate) { @@ -75,100 +91,68 @@ Attribute ConstFoldBinaryOpScalarScalar(Type result_type, Attribute operand1, calculate(lhs.getValue(), rhs.getValue())); } -// TODO: We have multiple functions to handle different attriubte kinds in the -// following. Consider add methods to ElementsAttr to unify these functions. - -// Performs const folding `calculate` with broadcast behavior on the two -// attributes `operand1` and `operand2` and returns the result if possible. -// This function assumes that both operands are `AttrElementT` attributes. -template > -Attribute ConstFoldBinaryOpSplatSplat(Type result_type, Attribute operand1, - Attribute operand2, - const CalculationT &calculate) { - auto type = result_type.cast(); - auto elem_type = type.getElementType(); - - auto element_result = ConstFoldBinaryOpScalarScalar( - elem_type, operand1, operand2, calculate); - if (!element_result) return {}; - - return DenseElementsAttr::get(type, element_result); -} - /// Performs const folding `calculate` with broadcast behavior on the two /// attributes `operand1` and `operand2` and returns the result if possible. -/// This function assumes the first operand is a DenseElementsAttr and the -/// second one is a SplatElementsAttr, and both are verified to have value +/// This function assumes the both operands are verified to have value /// attributes of broadcastable types. 
template > -Attribute ConstFoldBinaryOpDenseSplat(Type result_type, Attribute operand1, - Attribute operand2, + llvm::function_ref> +Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs, + DenseElementsAttr rhs, const CalculationT &calculate) { - auto lhs = operand1.cast(); - - // TODO: Support broadcast behavior - if (lhs.getType() != result_type || operand2.getType() != result_type) - return {}; - - auto rhs = operand2.cast().getSplatValue(); auto type = result_type.cast(); - SmallVector new_values; - new_values.reserve(lhs.rawSize()); - - // Add the splat value to each of the values in the dense elements - // attribute. - auto rhs_val = rhs.cast().getValue(); - for (auto old_val : lhs.getValues()) { - new_values.push_back(calculate(old_val, rhs_val)); - } - - return DenseElementsAttr::get(type, new_values); -} - -/// Performs const folding `calculate` with broadcast behavior on the two -/// attributes `operand1` and `operand2` and returns the result if possible. -/// This function assumes the both operands are DenseElementsAttr and verified -/// to have value attributes of broadcastable types. -template > -Attribute ConstFoldBinaryOpDenseDense(Type result_type, Attribute operand1, - Attribute operand2, - const CalculationT &calculate) { - auto lhs = operand1.cast(); - auto rhs = operand2.cast(); - if (lhs.getType() != rhs.getType()) { // We only support the case that one of the operand's dimensions are // a perfect suffix of the other. // TODO: support the general broadcast behavior. auto lhs_shape = lhs.getType().getShape(); auto rhs_shape = rhs.getType().getShape(); - if (!IsTrailingDimensions(lhs_shape, rhs_shape) && - !IsTrailingDimensions(rhs_shape, lhs_shape)) + if (IsTrailingDimensions(lhs_shape, rhs_shape)) { + if (!type.hasStaticShape()) type = rhs.getType(); + } else if (IsTrailingDimensions(rhs_shape, lhs_shape)) { + if (!type.hasStaticShape()) type = lhs.getType(); + } else { return {}; + } + } else if (!type.hasStaticShape()) { + type = lhs.getType(); + } + + const bool rhs_is_splat = rhs.isSplat(); + const bool lhs_is_splat = lhs.isSplat(); + + // If both of them are splat, compute and return. + if (lhs_is_splat && rhs_is_splat) { + auto element_result = AttrElementT::get( + type.getElementType(), calculate(lhs.getSplatValue(), + rhs.getSplatValue())); + if (!element_result) return {}; + + return DenseElementsAttr::get(type, element_result); } auto lhs_num_elements = lhs.getType().getNumElements(); auto rhs_num_elements = rhs.getType().getNumElements(); - - auto type = result_type.cast(); - auto num_elements = type.getNumElements(); + auto num_elements = std::max(lhs_num_elements, rhs_num_elements); // We assume the arguments have broadcast-compatible types. Make sure again. 
assert(std::max(lhs_num_elements, rhs_num_elements) == num_elements); assert(num_elements % std::min(lhs_num_elements, rhs_num_elements) == 0); - SmallVector lhs_old_values(lhs.getValues()); - SmallVector rhs_old_values(rhs.getValues()); + SmallVector lhs_old_values; + SmallVector rhs_old_values; + if (lhs_is_splat) + lhs_old_values.push_back(lhs.getSplatValue()); + else + lhs_old_values = llvm::to_vector<16>(lhs.getValues()); + if (rhs_is_splat) + rhs_old_values.push_back(rhs.getSplatValue()); + else + rhs_old_values = llvm::to_vector<16>(rhs.getValues()); + SmallVector new_values; new_values.reserve(num_elements); @@ -186,8 +170,8 @@ Attribute ConstFoldBinaryOpDenseDense(Type result_type, Attribute operand1, // operand with more elements, since the result has the same number of // elements, we are only going over its elements once. The modulo operation // also works for that. - int lhs_index = i % lhs_num_elements; - int rhs_index = i % rhs_num_elements; + int lhs_index = lhs_is_splat ? 0 : (i % lhs_num_elements); + int rhs_index = rhs_is_splat ? 0 : (i % rhs_num_elements); new_values.push_back( calculate(lhs_old_values[lhs_index], rhs_old_values[rhs_index])); @@ -203,7 +187,7 @@ Attribute ConstFoldBinaryOpDenseDense(Type result_type, Attribute operand1, template > + llvm::function_ref> Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1, Attribute operand2, const CalculationT &calculate, bool is_commutative) { @@ -212,30 +196,11 @@ Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1, if (operand2.dyn_cast_or_null()) return ConstFoldBinaryOpScalarScalar(result_type, operand1, operand2, calculate); - } else if (auto lhs = operand1.dyn_cast_or_null()) { - // Splat op splat case - if (auto rhs = operand2.dyn_cast_or_null()) - return ConstFoldBinaryOpSplatSplat( - result_type, lhs.getSplatValue(), rhs.getSplatValue(), calculate); - - // Splat op dense case - if (auto rhs = operand2.dyn_cast_or_null()) { - if (is_commutative) { - // Swap the two constant values to fall into the following case - return ConstFoldBinaryOpDenseSplat(result_type, operand2, - operand1, calculate); - } - } - } else if (auto lhs = operand1.dyn_cast_or_null()) { - // Dense op splat case - if (auto rhs = operand2.dyn_cast_or_null()) - return ConstFoldBinaryOpDenseSplat(result_type, operand1, - operand2, calculate); - - // Dense op dense case - if (auto rhs = operand2.dyn_cast_or_null()) - return ConstFoldBinaryOpDenseDense(result_type, operand1, - operand2, calculate); + } else if (operand1.dyn_cast_or_null() && + operand2.dyn_cast_or_null()) { + return ConstFoldBinaryOpDenseDense( + result_type, operand1.cast(), + operand2.cast(), calculate); } // TODO: support other attribute kinds @@ -249,8 +214,9 @@ Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1, /// `intCalculate` is chosen to conduct the calculate. Attribute ConstFoldBinaryOp( Type result_type, ArrayRef operands, - std::function float_calculate, - std::function int_calculate, bool is_commutative) { + llvm::function_ref float_calculate, + llvm::function_ref int_calculate, + bool is_commutative) { // Note: All types are wrapped in tensor types in TFlite. E.g., f32 is // represented as tensor. So we are only handling tensor types here. auto type = result_type.dyn_cast(); @@ -269,6 +235,32 @@ Attribute ConstFoldBinaryOp( return {}; } +/// Performs const folding a attributes `operand` and returns the result if +/// possible. +/// The function currently asserts that the `result_type` to be a f32 tensor +/// type. 
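+/// For example, the unary fold hooks below (SqrtOp::fold, SinOp::fold, etc.)
+/// call this helper with a per-element APFloat lambda.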
+/// TODO: Extend this function to handle integral tensor for ops like +/// "tfl.logical_not". +Attribute ConstFoldUnaryOp(Type result_type, Attribute operand, + llvm::function_ref calculate) { + assert(IsF32ShapedType(result_type)); + auto result_shape_type = result_type.cast(); + + if (auto dense_elements = operand.dyn_cast_or_null()) { + SmallVector new_values; + const int num_elements = result_shape_type.getNumElements(); + new_values.reserve(num_elements); + + for (APFloat old_value : dense_elements.getValues()) { + new_values.push_back(calculate(old_value)); + } + + return DenseElementsAttr::get(result_shape_type, new_values); + } + + return {}; +} + void buildComparisonBinOp(Builder *builder, OperationState *result, Value *lhs, Value *rhs) { auto result_type = @@ -410,6 +402,23 @@ static LogicalResult Verify(PackOp op) { if (op.getOperation()->getNumOperands() != op.values_count()) return op.emitOpError("input count should match 'values_count' attribute"); + Value *operand0 = op.getOperand(0); + auto input_type = operand0->getType().cast(); + + // Check axis bounds. + int64_t axis_value = op.axis().getSExtValue(); + if (abs(axis_value) > input_type.getRank()) + return op.emitOpError("op attribute 'axis' is out of bounds, got ") + << axis_value; + + // Make sure all inputs have the same shape and element type. + // TODO(rahulsp): Simplify once b/135032064 is fixed. + for (Value *operand : op.getOperands()) { + auto other_type = operand->getType().cast(); + if (input_type != other_type) + return op.emitOpError("operands should be of the same type"); + } + return success(); } @@ -453,12 +462,87 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { // Remove identity reshape. if (getType() == getOperand()->getType()) return getOperand(); + // Constant folding + assert(operands.size() == 1); + if (auto dense_elements = operands[0].dyn_cast_or_null()) { + auto result_shape_type = getType().cast(); + return dense_elements.reshape(result_shape_type); + } + return nullptr; } void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// SliceOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SliceOp op) { + auto input_type = op.input()->getType().cast(); + auto begin_type = op.begin()->getType().cast(); + auto size_type = op.size()->getType().cast(); + if (input_type.hasStaticShape() && begin_type.hasStaticShape() && + size_type.hasStaticShape()) { + if (input_type.getRank() != begin_type.getNumElements()) { + return op.emitError( + "begin tensor elements size is not equal to input tensor rank"); + } + + if (input_type.getRank() != size_type.getNumElements()) { + return op.emitError( + "size tensor elements size is not equal to input tensor rank"); + } + } + + DenseIntElementsAttr begin; + if (matchPattern(op.begin(), m_Constant(&begin))) { + int axis = 0; + for (auto begin_i : llvm::enumerate(begin)) { + if (begin_i.value().getSExtValue() < 0) { + return op.emitError( + llvm::formatv("begin[{0}] cannot be negative", axis)); + } + axis++; + } + } + + DenseIntElementsAttr size; + if (matchPattern(op.size(), m_Constant(&size))) { + int axis = 0; + for (auto size_i : llvm::enumerate(size)) { + if (size_i.value().getSExtValue() < -1) { + return op.emitError( + llvm::formatv("size[{0}] cannot be negative other than -1", axis)); + 
} + axis++; + } + } + + if (begin && size && input_type.hasStaticShape()) { + const int input_rank = begin.getNumElements(); + for (uint64_t i = 0; i < input_rank; i++) { + int begin_i = + begin.getValue({i}).cast().getValue().getSExtValue(); + int size_i = + size.getValue({i}).cast().getValue().getSExtValue(); + int dim_i = input_type.getShape()[i]; + if (begin_i >= dim_i) { + return op.emitOpError(llvm::formatv( + "begin[{0}] cannot exceed dimension length: {1}", i, dim_i)); + } + if (size_i >= 0 && begin_i + size_i > dim_i) { + return op.emitError(llvm::formatv( + "begin[{0}] + size[{0}] cannot exceed dimension length: {1}", i, + dim_i)); + } + } + } + + return success(); } //===----------------------------------------------------------------------===// @@ -486,7 +570,7 @@ static void BuildTopKOp(Builder *builder, OperationState *result, Value *input, if (matchPattern(k, m_Constant(&cst))) // These casts should all be valid due to how Tensor constants are stored. // TODO(jpienaar): This should use a helper function. - const_k = cst.getValue({}).cast().getValue().getSExtValue(); + const_k = cst.getValue({}).getValue().getSExtValue(); auto val_type = input->getType().cast(); // If value is unranked, then so is results. @@ -543,7 +627,7 @@ struct DropFakeQuant : public RewritePattern { void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.push_back(llvm::make_unique(context)); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -562,12 +646,422 @@ static LogicalResult Verify(UnpackOp op) { return success(); } +//===----------------------------------------------------------------------===// +// SplitOp +//===----------------------------------------------------------------------===// + +// Extracts and returns the signed integer constant in a 0-rank integer tensor +// if 'value' is a constant. +static llvm::Optional ExtractConstantIntFromTensor(Value *value) { + ElementsAttr attr; + if (!matchPattern(value, m_Constant(&attr))) return {}; + + IntegerAttr int_attr = attr.getValue(llvm::None).cast(); + return int_attr.getValue().getSExtValue(); +} + +static LogicalResult Verify(SplitOp op) { + int64_t num_splits = op.num_splits().getSExtValue(); + if (op.getOperation()->getNumResults() != num_splits) + return op.emitOpError("output count should match 'num_splits' attribute"); + + // If 'split_dim' is not a constant, there are no other checks. + llvm::Optional split_dim_opt = + ExtractConstantIntFromTensor(op.split_dim()); + if (!split_dim_opt) return success(); + + // If 'input' is not a ranked tensor, there are no other checks. + auto input_type = op.value()->getType().dyn_cast(); + if (!input_type) return success(); + + int64_t split_dim = split_dim_opt.getValue(); + const int64_t rank = input_type.getRank(); + if (split_dim < 0) split_dim += rank; + if (split_dim < 0 || split_dim >= rank) + return op.emitOpError("'split_dim' should be in [-rank, rank)"); + + // If the 'split_dim' dimension of the 'input' tensor has a dynamic size, + // there are no other checks. + const int64_t dim_size = input_type.getDimSize(split_dim); + if (ShapedType::isDynamic(dim_size)) return success(); + + if (dim_size % num_splits != 0) + return op.emitOpError("'num_splits' should evenly divide 'split_dim' axis"); + + // Creates sliced tensor type. 
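+  // For example, splitting a tensor<2x6xf32> along dimension 1 with
+  // num_splits = 3 requires every result to have type tensor<2x2xf32>.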
+ auto slice_shape = input_type.getShape().vec(); + slice_shape[split_dim] = dim_size / num_splits; + RankedTensorType slice_type = + RankedTensorType::get(slice_shape, input_type.getElementType()); + + // Verifies result tensor types. + for (int64_t i = 0; i < num_splits; ++i) { + Value *result = op.getResult(i); + auto result_type = result->getType().dyn_cast(); + if (!result_type || result_type != slice_type) + return op.emitOpError() << "output #" << i << " should be " << slice_type; + } + + return success(); +} + //===----------------------------------------------------------------------===// // MeanOp //===----------------------------------------------------------------------===// // TODO(b/133854225): Implement shape inference to Mean +//===----------------------------------------------------------------------===// +// LSTMOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(LSTMOp op) { + auto operands = op.GetStatefulOperands(); + if (operands.size() == 2 && operands[0] == 18 && operands[1] == 19) { + return success(); + } + return op.emitError("LSTMOp expected to have two stateful operands"); +} + +//===----------------------------------------------------------------------===// +// UnidirectionalSequenceLSTMOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(UnidirectionalSequenceLSTMOp op) { + auto operands = op.GetStatefulOperands(); + if (operands.size() == 2 && operands[0] == 18 && operands[1] == 19) { + return success(); + } + return op.emitError( + "UnidirectionalSequenceLSTMOp expected to have two stateful operands"); +} + +//===----------------------------------------------------------------------===// +// UnidirectionalSequenceRNNOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(UnidirectionalSequenceRNNOp op) { + auto operands = op.GetStatefulOperands(); + if (operands.size() == 1 && operands[0] == 4) { + return success(); + } + return op.emitError( + "UnidirectionalSequenceRNNOp expected to have one stateful operand"); +} + +//===----------------------------------------------------------------------===// +// SvdfOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SVDFOp op) { + auto operands = op.GetStatefulOperands(); + if (operands.size() == 1 && operands[0] == 4) { + return success(); + } + return op.emitError("SvdfOp expected to have one stateful operand"); +} + +//===----------------------------------------------------------------------===// +// AbsOp +//===----------------------------------------------------------------------===// + +OpFoldResult AbsOp::fold(ArrayRef operands) { + Type result_type = getType(); + // Only constant fold for tensor of f32 is implemented. + if (!IsF32ShapedType(result_type)) return nullptr; + + auto compute = [](APFloat value) -> APFloat { return llvm::abs(value); }; + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + +//===----------------------------------------------------------------------===// +// SinOp +//===----------------------------------------------------------------------===// + +OpFoldResult SinOp::fold(ArrayRef operands) { + Type result_type = getType(); + // Only constant fold for tensor of f32 is implemented. 
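+  // For example, a splat f32 constant 0.0 folds to a splat 0.0 result.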
+ if (!IsF32ShapedType(result_type)) return nullptr; + + auto compute = [](APFloat value) -> APFloat { + float f = value.convertToFloat(); + float result = std::sin(f); + return APFloat(result); + }; + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + +//===----------------------------------------------------------------------===// +// CosOp +//===----------------------------------------------------------------------===// + +OpFoldResult CosOp::fold(ArrayRef operands) { + Type result_type = getType(); + // Only constant fold for tensor of f32 is implemented. + if (!IsF32ShapedType(result_type)) return nullptr; + + auto compute = [](APFloat value) -> APFloat { + float f = value.convertToFloat(); + float result = std::cos(f); + return APFloat(result); + }; + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + +//===----------------------------------------------------------------------===// +// LogOp +//===----------------------------------------------------------------------===// + +OpFoldResult LogOp::fold(ArrayRef operands) { + Type result_type = getType(); + // Only constant fold for tensor of f32 is implemented. + if (!IsF32ShapedType(result_type)) return nullptr; + + auto compute = [](APFloat value) -> APFloat { + float f = value.convertToFloat(); + float result = std::log(f); + return APFloat(result); + }; + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + +//===----------------------------------------------------------------------===// +// SqrtOp +//===----------------------------------------------------------------------===// + +OpFoldResult SqrtOp::fold(ArrayRef operands) { + Type result_type = getType(); + // Only constant fold for tensor of f32 is implemented. + if (!IsF32ShapedType(result_type)) return nullptr; + + auto compute = [](APFloat value) -> APFloat { + float f = value.convertToFloat(); + float result = std::sqrt(f); + return APFloat(result); + }; + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + +//===----------------------------------------------------------------------===// +// RsqrtOp +//===----------------------------------------------------------------------===// + +OpFoldResult RsqrtOp::fold(ArrayRef operands) { + Type result_type = getType(); + // Only constant fold for tensor of f32 is implemented. + if (!IsF32ShapedType(result_type)) return nullptr; + + auto compute = [](APFloat value) -> APFloat { + float f = value.convertToFloat(); + float result = 1.f / std::sqrt(f); + return APFloat(result); + }; + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + +//===----------------------------------------------------------------------===// +// SquareOp +//===----------------------------------------------------------------------===// + +OpFoldResult SquareOp::fold(ArrayRef operands) { + Type result_type = getType(); + // Only constant fold for tensor of f32 is implemented. 
+ if (!IsF32ShapedType(result_type)) return nullptr; + + auto compute = [](APFloat value) -> APFloat { return value * value; }; + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + +//===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +OpFoldResult RankOp::fold(ArrayRef operands) { + assert(operands.size() == 1); + auto result_type = getType().cast(); + if (auto elements_attr = operands[0].dyn_cast_or_null()) { + auto rank = static_cast(elements_attr.getType().getRank()); + return DenseElementsAttr::get(result_type, {rank}); + } + + // Also fold if `input` has a known rank. + auto input_type = input()->getType().cast(); + // Do not fold if rank is zero because the TFLite converter doesn't + // distinguish between unranked input and scalar input due to b/138865275. + // TODO(b/138865275): Remove `input_type.getRank() != 0` in the following + // predicate and fold the op when rank is zero. + if (input_type.hasRank() && input_type.getRank() != 0) { + auto rank = static_cast(input_type.getRank()); + return DenseElementsAttr::get(result_type, {rank}); + } + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// ConstOp +//===----------------------------------------------------------------------===// + +OpFoldResult ConstOp::fold(ArrayRef operands) { + assert(operands.empty() && "constant has no operands"); + + // Return the held attribute value. + return value(); +} + +//===----------------------------------------------------------------------===// +// RangeOp +//===----------------------------------------------------------------------===// + +namespace { + +// Compute the length of a range (1-D) tensor given `start`, `limit`, `delta`. +// Template parameter `FloatOrInt` must be standard C integer or floating-point +// types. +template +int GetLengthOfRange(FloatOrInt start, FloatOrInt limit, FloatOrInt delta) { + // Refer to the implementation in + // tensorflow/lite/kernels/range.cc. + return std::is_integral::value + ? ((std::abs(limit - start) + std::abs(delta) - 1) / + std::abs(delta)) + : std::ceil(std::abs((limit - start) / delta)); +} + +// Builds a constant range tensor of `result_elem_type` elements. +// Template parameter `FloatOrIntAtrr` must be mlir::IntegerAttr or +// mlir::FloatAttr. +template +DenseElementsAttr BuildConstRangeTensor(Type result_elem_type, int num_elements, + FloatOrIntAtrr start_attr, + FloatOrIntAtrr delta_attr) { + using ValueType = typename FloatOrIntAtrr::ValueType; // APInt or APFloat + ValueType start = start_attr.getValue(); + ValueType delta = delta_attr.getValue(); + + SmallVector new_values; + new_values.reserve(num_elements); + ValueType new_value = start; + for (int i = 0; i < num_elements; ++i) { + new_values.push_back(new_value); + new_value = new_value + delta; + } + // Result is always a 1-D tensor. 
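+  // For example, start = 3, delta = 4 and num_elements = 3 give the 1-D
+  // constant [3, 7, 11].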
+ auto new_result_type = + RankedTensorType::get({num_elements}, result_elem_type); + return DenseElementsAttr::get(new_result_type, new_values); +} +} // namespace + +OpFoldResult RangeOp::fold(ArrayRef operands) { + assert(operands.size() == 3); + auto start_tensor = operands[0].dyn_cast_or_null(); + auto limit_tensor = operands[1].dyn_cast_or_null(); + auto delta_tensor = operands[2].dyn_cast_or_null(); + if (start_tensor && limit_tensor && delta_tensor) { + // Operands should all be scalars + assert(start_tensor.getType().getRank() == 0 && + limit_tensor.getType().getRank() == 0 && + delta_tensor.getType().getRank() == 0); + Type elem_type = getType().cast().getElementType(); + if (elem_type.isa()) { + auto start_attr = start_tensor.getValue({}); + auto limit_attr = limit_tensor.getValue({}); + auto delta_attr = delta_tensor.getValue({}); + const int num_elements = GetLengthOfRange( + start_attr.getInt(), limit_attr.getInt(), delta_attr.getInt()); + return BuildConstRangeTensor(elem_type, num_elements, start_attr, + delta_attr); + } else if (elem_type.isa()) { + auto start_attr = start_tensor.getValue({}); + auto limit_attr = limit_tensor.getValue({}); + auto delta_attr = delta_tensor.getValue({}); + const int num_elements = GetLengthOfRange(start_attr.getValueAsDouble(), + limit_attr.getValueAsDouble(), + delta_attr.getValueAsDouble()); + return BuildConstRangeTensor(elem_type, num_elements, start_attr, + delta_attr); + } + } + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +namespace { + +// Computes the permutation of a constant `input_tensor` according to `perm`. +// The function recursively traverses the dimensions of the output tensor in +// a row-major order and writes the value in the output tensor into +// `new_values`. +void ComputePermutation(ElementsAttr input_tensor, ArrayRef perm, + ArrayRef output_shape, int num_dimensions, + int output_axis, std::vector *input_indices, + std::vector *new_values) { + // Refer to the implementation of `Transpose` function in + // tensorflow/lite/kernels/internal/reference/reference_ops.h + assert(output_axis < num_dimensions); + const int input_axis = perm[output_axis]; + for (int i = 0; i < output_shape[output_axis]; ++i) { + // Update the input indices on `input_axis`. + input_indices->at(input_axis) = i; + // Write the value from `input_tensor` if it is the last axis or + // recurse into the next axis. + const bool is_last_axis = output_axis == num_dimensions - 1; + if (is_last_axis) { + new_values->push_back(input_tensor.getValue(*input_indices)); + } else { + ComputePermutation(input_tensor, perm, output_shape, num_dimensions, + output_axis + 1, input_indices, new_values); + } + } +} + +} // namespace + +OpFoldResult TransposeOp::fold(ArrayRef operands) { + assert(operands.size() == 2); + auto input_tensor = operands[0].dyn_cast_or_null(); + auto perm_tensor = operands[1].dyn_cast_or_null(); + if (!input_tensor || !perm_tensor) return nullptr; + + // Do not try to fold elements attr of a quant type because + // DenseElementsAttr does not support it. 
+ if (!getType().cast().getElementType().isIntOrFloat()) + return nullptr; + + assert(perm_tensor.getType().getRank() == 1); + const int num_dimensions = input_tensor.getType().getRank(); + assert(perm_tensor.getType().getNumElements() == num_dimensions); + + ArrayRef input_shape = input_tensor.getType().getShape(); + auto output_type = getType().cast(); + + SmallVector perm; + SmallVector output_shape; + for (int i = 0; i < num_dimensions; ++i) { + perm.push_back( + perm_tensor.getValue({static_cast(i)}).getInt()); + output_shape.push_back(input_shape[perm[i]]); + + // Check that the derived output shape matches the static shape. + assert(!output_type.hasStaticShape() || + output_type.getShape()[i] == output_shape[i]); + } + + std::vector new_values; + new_values.reserve(input_tensor.getType().getNumElements()); + std::vector input_indices(num_dimensions); + ComputePermutation(input_tensor, perm, output_shape, num_dimensions, + /*output_axis=*/0, &input_indices, &new_values); + auto result_type = + RankedTensorType::get(output_shape, output_type.getElementType()); + return DenseElementsAttr::get(result_type, new_values); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// @@ -575,5 +1069,16 @@ static LogicalResult Verify(UnpackOp op) { #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc" +Operation *TensorFlowLiteDialect::materializeConstant(OpBuilder &builder, + Attribute value, + Type type, Location loc) { + // If this is an opaque elements attribute or the result type doesn't match + // the attribute type, then generate a tfl.pseudo_const. + if (value.isa() || + (value.isa() && value.getType() != type)) + return builder.create(loc, type, value.cast()); + return nullptr; +} + } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index 5eac0511ab7..c60a17a24da 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/Support/Functional.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" #include "tensorflow/lite/schema/schema_generated.h" namespace mlir { @@ -36,6 +37,11 @@ namespace TFL { class TensorFlowLiteDialect : public Dialect { public: explicit TensorFlowLiteDialect(MLIRContext *context); + + // Registered hook to materialize a constant operation from a given attribute + // value with the desired resultant type. + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; }; #define GET_OP_CLASSES diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 8c78f7a9dc8..458ff270e91 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -24,7 +24,7 @@ limitations under the License. 
include "mlir/IR/OpBase.td" #endif // OP_BASE -include "mlir/Dialect/QuantOps/QuantPredicates.td" +include "tensorflow/compiler/mlir/lite/quantization/quantization.td" def TFL_Dialect : Dialect { let name = "tfl"; @@ -95,49 +95,6 @@ def TFL_MirrorPaddingAttr : StrEnumAttr<"Padding", "Mirror pad enum", [ TFL_MIRRORPAD_Reflect, TFL_MIRRORPAD_Symmetric ]>; -//===----------------------------------------------------------------------===// -// Min-max range pair definitions. -//===----------------------------------------------------------------------===// - -// A pair of floating point values which defines the min and max of a value -// range for quantization. The attribute is allowed to be empty or -// have 2 elements. -def MinMaxAttr : Attr().size() == 0">, - CPred<"$_self.cast().size() == 2">]>, - "min-max range pair"> { - let storageType = [{ ArrayAttr }]; - let returnType = [{ ArrayRef }]; -} - -//===----------------------------------------------------------------------===// -// QuantizedType definitions. -//===----------------------------------------------------------------------===// - -// The base class of a quantized type. -class TFL_QuantizedType params, bit signed> - : Type()">, - CPred<"$_self.cast()" # - ".getStorageTypeIntegralWidth() == " # !head(params)>]>, - "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { - string name = n; - string asTraitArgsStr = - StrJoinInt.result # !if(signed, ", true", ", false"); -} - -// Uniform quantized types. Two integers "smantissa" and "sexp" are used to -// express the Mantissa and Exponent components of the floating-point scale so -// the scale of the quantized type is "smantissa * 10 ^ sexp". -class TFL_UInt8UniformQuantizedType - : TFL_QuantizedType<"Uniform", - [8, zero_pt, smantissa, sexp, 0, 255], 0>; -class TFL_Int8UniformQuantizedType - : TFL_QuantizedType<"Uniform", - [8, zero_pt, smantissa, sexp, -128, 127], 1>; - -// 8-bits quantized types. The definitions can be used to specify tensor types. -def TFL_QUI8 : TFL_QuantizedType<"Uniform", [8], 0>; -def TFL_QI8 : TFL_QuantizedType<"Uniform", [8], 1>; - //===----------------------------------------------------------------------===// // TensorType attribute definitions. //===----------------------------------------------------------------------===// @@ -163,20 +120,12 @@ def TFL_IntTensor : TypeAlias; // This is used to represent the type of "ref tensors" or tensors that are // used as variables to track state. -// TODO(ashwinm): This is a placeholder until we have first class support -// for variables. def TFL_StatefulTensor : TypeAlias; // Tensor or None type. class TFL_TensorOfOrNone allowedTypes, string description = ""> : AnyTypeOf<[TensorOf, NoneType], description>; -// Type Constraint operand `idx`'s type is NOT `type`. -// TODO(b/131936589): Once this bug is fixed, we should be able to use -// Neg>> and can remove this. -class TFL_TCopIsNot : - NeggetType().isa<" # type # ">()">>; - def TFL_FpOrI32OrI64Tensor : TensorOf<[AnyFloat, TFL_Int32Or64]>; //===----------------------------------------------------------------------===// @@ -258,31 +207,11 @@ def TFL_ComparisonBinaryBuilder : OpBuilder< }]>; //===----------------------------------------------------------------------===// -// TFL native op traits (for quantization). -// -// Ops in this link should have those traits specified: -// https://www.tensorflow.org/lite/performance/quantization_spec -//===----------------------------------------------------------------------===// +// TFL native op trait for stateful operands. 
-// Specify this trait if the op has a fixed output value range. -class TFL_FixedResultScale : NativeOpTrait::Impl")>; +class StatefulOperands operands> + : ParamNativeOpTrait<"TFL::StatefulOperands", StrJoinInt.result>; -// Specify this trait if the op requires same inputs and outputs quantization -// scales. -def TFL_SameOperandsAndResultsScale : NativeOpTrait< - "TFL::SameOperandsAndResultsScale">; - -// Specify this trait if the b-th input of the op is a bias input, which needs -// a scale based on the scales of op1 and op2. -class TFL_AccumulatorUniformScale : NativeOpTrait< - !strconcat("TFL::AccumulatorUniformScale<", - StrJoinInt<[bias, op1, op2]>.result, - ">::Impl")>; - -// Specify this trait if the op doesn't have quantizable ouput. We shouldn't -// apply quantization on this op. -def TFL_NoQuantizableResult : NativeOpTrait<"TFL::NoQuantizableResult">; //===----------------------------------------------------------------------===// // TFL op base class. @@ -310,7 +239,7 @@ class TFL_Op traits = []> : } class TFL_ConvOp : - TFL_Op]> { + TFL_Op]> { let summary = opSummary # " operator"; let description = [{ @@ -325,7 +254,7 @@ class TFL_ConvOp : let arguments = ( ins AnyTensor:$input, AnyTensor:$filter, - AnyTensor:$bias, + TFL_TensorOfOrNone<[AnyType]>:$bias, I32Attr:$dilation_h_factor, I32Attr:$dilation_w_factor, TFL_AFAttr:$fused_activation_function, @@ -355,6 +284,8 @@ an output element, this operation computes \\(y = |x|\\). let arguments = (ins AnyTensor:$x); let results = (outs AnyTensor:$y); + + let hasFolder = 1; } def TFL_AddOp : TFL_Op<"add", [Broadcastable, NoSideEffect, Commutative]> { @@ -400,8 +331,35 @@ def TFL_AddNOp : TFL_Op<"add_n", [Commutative, NoSideEffect]> { ); } +def TFL_ReduceAnyOp : TFL_Op<"reduce_any", [NoSideEffect]> { + let summary = [{ +Computes the "logical or" of elements across dimensions of a tensor. + }]; + + let description = [{ +Reduces `input` along the dimensions given in `axis`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`axis`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + }]; + + let arguments = (ins + I1Tensor:$input, + I32Tensor:$reduction_indices, + + DefaultValuedAttr:$keep_dims + ); + + let results = (outs + I1Tensor:$output + ); + + let hasOptions = 1; + let customOption = "ReducerOptions"; +} + def TFL_AveragePool2DOp: - TFL_Op<"average_pool_2d", [NoSideEffect, TFL_SameOperandsAndResultsScale]> { + TFL_Op<"average_pool_2d", [NoSideEffect, SameOperandsAndResultsScale]> { let summary = "Average_pool_2d operator"; let description = [{ @@ -424,6 +382,32 @@ def TFL_AveragePool2DOp: let customOption = "Pool2DOptions"; } +def TFL_ArgMaxOp : TFL_Op<"arg_max", [NoSideEffect]> { + let summary = "ArgMax operator"; + + let description = [{ + Returns the index with the largest value across dimensions of a tensor. + }]; + + let arguments = ( + // TODO: Add support for uint8. + ins TensorOf<[F32, I32, I8]>:$input, + TFL_I32OrI64Tensor:$dim + ); + + let results = (outs + TFL_I32OrI64Tensor:$output + ); + + let hasOptions = 1; + + DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{ + return getResult()->getType().cast().getElementType(). + cast().getWidth() > 32 ? 
tflite::TensorType_INT64 : + tflite::TensorType_INT32; + }]>; +} + def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> { let summary = "ArgMin operator"; @@ -443,6 +427,14 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [NoSideEffect]> { let results = (outs TFL_I32OrI64Tensor:$output ); + + let hasOptions = 1; + + DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{ + return getResult()->getType().cast().getElementType(). + cast().getWidth() > 32 ? tflite::TensorType_INT64 : + tflite::TensorType_INT32; + }]>; } def TFL_CeilOp: TFL_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> { @@ -462,7 +454,7 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation", NoSideEffect, PredOpTrait<"values and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - TFL_SameOperandsAndResultsScale + SameOperandsAndResultsScale ]> { let summary = "Concatenation operator"; @@ -472,14 +464,14 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation", let arguments = ( ins Variadic>:$values, + [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>>:$values, I32Attr:$axis, TFL_AFAttr:$fused_activation_function ); let results = (outs TensorOf< - [F32, I64, I32, I16, I8, TFL_QI8, TFL_QUI8, TFL_Uint8]>:$output + [F32, I64, I32, I16, I8, QI8, QUI8, TFL_Uint8]>:$output ); let hasOptions = 1; @@ -500,6 +492,8 @@ def TFL_ConstOp : Op; @@ -514,6 +508,8 @@ def TFL_CosOp: TFL_Op<"cos", [NoSideEffect, SameOperandsAndResultType]> { let arguments = (ins TFL_FpTensor:$x); let results = (outs TFL_FpTensor:$y); + + let hasFolder = 1; } def TFL_DepthwiseConv2DOp : @@ -532,13 +528,14 @@ def TFL_FullyConnectedOptionsWeightFormatAttr : // TODO(jpienaar): Update post discussion on semantics of FC OP. // TODO(jpienaar): Include more shape verification. -def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [NoSideEffect]> { +def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ + NoSideEffect, AccumulatorUniformScale<2, 0, 1>]> { let summary = "Fully connected op"; let arguments = (ins - TensorOf<[F32]>:$input, - TensorOf<[F32]>:$filter, - TFL_TensorOfOrNone<[F32]>:$bias, + TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$input, + TensorOf<[F32, QI8, QUI8, QI16, QUI16]>:$filter, + TFL_TensorOfOrNone<[F32, QI32, QUI32]>:$bias, TFL_AFAttr:$fused_activation_function, TFL_FullyConnectedOptionsWeightFormatAttr:$weights_format, @@ -547,7 +544,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [NoSideEffect]> { // Depending on the weights format, this op can have one or two outputs. let results = (outs - Variadic>:$output + Variadic>:$output ); let hasOptions = 1; @@ -555,6 +552,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [NoSideEffect]> { def TFL_GatherOp : TFL_Op<"gather", [ NoSideEffect, + SameOperandsAndResultsScale, TFL_OperandHasAtleastRank<0, 1>, PredOpTrait<"params and output must have same element type", TCresVTEtIsSameAsOp<0, 0>> @@ -566,7 +564,7 @@ def TFL_GatherOp : TFL_Op<"gather", [ }]; let arguments = (ins - TensorOf<[F32, I8, I32, I64, TFL_Str]>:$params, + TensorOf<[F32, I8, I32, I64, TFL_Str, QI8, QUI8]>:$params, TensorOf<[I32, I64]>:$indices, I32Attr:$axis ); @@ -579,7 +577,7 @@ def TFL_GatherOp : TFL_Op<"gather", [ ]; let results = (outs - TensorOf<[F32, I16, I32, I64, TFL_Str]>:$output + TensorOf<[F32, I16, I32, I64, TFL_Str, QI8, QUI8]>:$output ); let hasOptions = 1; @@ -592,19 +590,19 @@ def TFL_GatherNdOp : TFL_Op<"gather_nd", [NoSideEffect]> { Gather slices from `params` into a Tensor with shape specified by `indices`. }]; - // TODO: missing Uint8. 
let arguments = (ins - TensorOf<[F32, I8, I64, I32]>:$params, + TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$params, TFL_I32OrI64Tensor:$indices ); let results = (outs - TensorOf<[F32, I8, I64, I32]>:$output + TensorOf<[F32, I8, I64, I32, TFL_Uint8]>:$output ); } // Same type check of lhs and rhs is handled by the Broadcastable trait. -def TFL_LessEqualOp : TFL_Op<"less_equal", [Broadcastable, NoSideEffect]> { +def TFL_LessEqualOp : TFL_Op<"less_equal", [ + Broadcastable, NoSideEffect, NoQuantizableResult]> { let summary = "Less_equal operator"; let description = [{ @@ -612,8 +610,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [Broadcastable, NoSideEffect]> { }]; let arguments = ( - ins TensorOf<[F32, I32, I64, I8, TFL_Uint8]>:$lhs, - TensorOf<[F32, I32, I64, I8, TFL_Uint8]>:$rhs); + ins TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$lhs, + TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$rhs); let results = (outs TFL_BoolTensor:$output); @@ -645,7 +643,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag }]; let arguments = (ins - TensorOf<[F32]>:$input, + TensorOf<[F32, QI8, QUI8]>:$input, I32Attr:$radius, F32Attr:$bias, F32Attr:$alpha, @@ -653,13 +651,14 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag ); let results = (outs - TensorOf<[F32]>:$output + TensorOf<[F32, QI8, QUI8]>:$output ); let hasOptions = 1; } -def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [Broadcastable, NoSideEffect]> { +def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [ + Broadcastable, NoSideEffect, NoQuantizableResult]> { let summary = "Greater_equal operator"; let description = [{ @@ -682,7 +681,7 @@ def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [Broadcastable, NoSideEffect]> } def TFL_NotEqualOp : TFL_Op<"not_equal", [ - Broadcastable, Commutative, NoSideEffect, TFL_NoQuantizableResult]> { + Broadcastable, Commutative, NoSideEffect, NoQuantizableResult]> { let summary = "Not_equal operator"; let description = [{ @@ -747,8 +746,28 @@ def TFL_EluOp: TFL_Op<"elu", [NoSideEffect, SameOperandsAndResultType]> { let hasOptions = 0; } +def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup", + [NoSideEffect, + PredOpTrait<"value and output must have same element type", + TCresVTEtIsSameAsOp<0, 1>> + ]> { + let summary = "Embedding lookup operator"; + + let description = [{ + Looks up ids in a list of embedding tensors. + }]; + + let arguments = (ins + TensorOf<[I32]>:$lookup, + TensorOf<[F32, I8, TFL_Uint8]>:$value + ); + + let results = (outs TensorOf<[F32, I8, TFL_Uint8]>:$output); +} + def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable, - PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> { + NoQuantizableResult, + PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> { let summary = "Equal operator"; let description = [{ @@ -757,8 +776,8 @@ def TFL_EqualOp: TFL_Op<"equal", [Commutative, Broadcastable, let arguments = ( ins - TensorOf<[I1, F32, I32, I64, I8, TFL_Uint8]>:$x, - TensorOf<[I1, F32, I32, I64, I8, TFL_Uint8]>:$y + TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$x, + TensorOf<[I1, F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$y ); let results = (outs TFL_BoolTensor:$output); @@ -773,9 +792,9 @@ def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> { Performs element-wise natural exponentiation operation on input. 
}]; - let arguments = (ins AnyTensor:$x); + let arguments = (ins TFL_FpTensor:$x); - let results = (outs AnyTensor:$y); + let results = (outs TFL_FpTensor:$y); let hasOptions = 0b1; } @@ -825,7 +844,7 @@ size 1. } def TFL_SqueezeOp: TFL_Op<"squeeze", [NoSideEffect, - TFL_SameOperandsAndResultsScale]> { + SameOperandsAndResultsScale]> { let summary = "Removes dimensions of size 1 from the shape of a tensor."; let description = [{ @@ -917,17 +936,15 @@ def TFL_FloorModOp : TFL_Op<"floor_mod", [Broadcastable, NoSideEffect]> { }]; let arguments = ( - ins AnyTensor:$lhs, - AnyTensor:$rhs); + ins TensorOf<[I32, I64, F32]>:$lhs, + TensorOf<[I32, I64, F32]>:$rhs); - let results = (outs AnyTensor:$output); + let results = (outs TensorOf<[I32, I64, F32]>:$output); - let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }]; - - let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }]; + let builders = [TFL_BroadcastableBinaryBuilder]; } -def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, TFL_NoQuantizableResult]> { +def TFL_GreaterOp : TFL_Op<"greater", [NoSideEffect, NoQuantizableResult]> { let summary = "Greater operator"; let description = [{ @@ -962,6 +979,25 @@ def TFL_InputOp : Op { let results = (outs AnyTensor:$output); } +def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect]> { + let summary = "L2 Normalize Operator"; + + let description = [{ + L2Normalization Op + }]; + + let arguments = (ins + TensorOf<[F32, QUI8, QI8, I8]>:$input, + TFL_AFAttr:$fused_activation_function + ); + + let results = (outs TensorOf<[F32, QUI8, QI8, I8]>:$output); + + let hasOptions = 1; + + let customOption = "L2NormOptions"; +} + def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Leaky Relu operator"; @@ -983,7 +1019,7 @@ def TFL_LeakyReluOp: TFL_Op<"leaky_relu", [NoSideEffect, SameOperandsAndResultTy let hasOptions = 0b1; } -def TFL_LessOp : TFL_Op<"less", [NoSideEffect, TFL_NoQuantizableResult]> { +def TFL_LessOp : TFL_Op<"less", [NoSideEffect, NoQuantizableResult]> { let summary = "Less operator"; let description = [{ @@ -1051,6 +1087,24 @@ def TFL_LogicalOrOp : TFL_Op<"logical_or", [NoSideEffect]> { let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }]; } +def TFL_LogisticOp: TFL_Op<"logistic", [ + NoSideEffect, + SameOperandsAndResultShape, + // zero_point = 0 + // scale = 1. / (max_value + 1) + FixedResultScale>, + FixedResultScale>]> { + let summary = "Logistic operator"; + + let description = [{ + Computes element-wise Sigmoid of input + }]; + + let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8]>:$x); + + let results = (outs TensorOf<[AnyFloat, QI8, QUI8]>:$y); +} + def TFL_LogOp: TFL_Op<"log", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Natural logarithm operator"; @@ -1061,6 +1115,8 @@ def TFL_LogOp: TFL_Op<"log", [NoSideEffect, SameOperandsAndResultType]> { let arguments = (ins AnyTensor:$x); let results = (outs AnyTensor:$y); + + let hasFolder = 1; } // TODO(b/130643170): Adds some constraint for the input/output element types. 
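The FixedResultScale trait applied to ops such as tfl.logistic above expresses the fixed result scale as an (smantissa, sexp) pair, with scale = smantissa * 10 ^ sexp. A minimal sketch of a definition using that encoding follows; the op name and numeric values are illustrative assumptions, not taken from this patch. A result fixed to zero_point = 0 and scale = 1/256 is written with smantissa = 390625 and sexp = -8, since 390625 * 10^-8 = 0.00390625 = 1/256.

def TFL_ExampleFixedScaleOp : TFL_Op<"example_fixed_scale", [
    NoSideEffect,
    SameOperandsAndResultShape,
    // zero_point = 0
    // scale = 1. / (max_value + 1) = 1 / 256 = 390625 * 10 ^ -8
    FixedResultScale<UInt8UniformQuantizedType<0, 390625, -8>>]> {
  let summary = "Example operator with a fixed quantized result scale";

  let arguments = (ins TensorOf<[AnyFloat, QUI8]>:$x);

  let results = (outs TensorOf<[AnyFloat, QUI8]>:$y);
}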
@@ -1069,8 +1125,8 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ SameOperandsAndResultShape, // zero_point = max_value // scale = -log_softmax_output_min / (max_value + 1) - TFL_FixedResultScale>, - TFL_FixedResultScale>]> { + FixedResultScale>, + FixedResultScale>]> { let summary = "Log softmax operator"; let description = [{ @@ -1096,13 +1152,13 @@ def MaxPoolOperandAndResultConstraints : PredOpTrait<"MaxPool2D operand and " And<[ // The input and output tensors should have the same elemental type // and they should be one of the specified types below. - TCopVTEtIs<0, AnyTypeOf<[F32, TFL_QI8, TFL_QUI8]>>, + TCopVTEtIs<0, AnyTypeOf<[F32, QI8, QUI8]>>, TFL_TCresVTEtIsSameAsOp<0, 0>]>>; def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ NoSideEffect, MaxPoolOperandAndResultConstraints, - TFL_SameOperandsAndResultsScale]> { + SameOperandsAndResultsScale]> { let summary = "Max Pool 2D op"; let description = [{ @@ -1129,25 +1185,28 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ let customOption = "Pool2DOptions"; } -def TFL_MaximumOp : TFL_Op<"maximum", [Broadcastable, NoSideEffect, Commutative]> { +def TFL_MaximumOp : TFL_Op<"maximum", [ + Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale]> { let summary = "Max operator"; let description = [{ Element-wise max operation. }]; let arguments = ( - ins TFL_FpOrI32OrI64Tensor:$lhs, - TFL_FpOrI32OrI64Tensor:$rhs + ins TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs, + TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs ); - let results = (outs TFL_FpOrI32OrI64Tensor:$max); + let results = (outs + TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$max + ); let builders = [TFL_BroadcastableBinaryBuilder]; let hasOptions = 0; } -def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> { +def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect, SameOperandsAndResultsScale]> { let summary = "Mean operator"; let description = [{ @@ -1159,12 +1218,13 @@ def TFL_MeanOp : TFL_Op<"mean", [NoSideEffect]> { }]; let arguments = (ins - TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input, + TensorOf<[F32, I8, I32, I64, QI8, QUI8, TFL_Uint8]>:$input, TensorOf<[I32, I64]>:$axis, BoolAttr:$keep_dims ); - let results = (outs TensorOf<[F32, I32, I64, I8]>:$output); + let results = (outs + TensorOf<[F32, I32, I64, I8, QI8, QUI8, TFL_Uint8]>:$output); let hasOptions = 1; let customOption = "ReducerOptions"; @@ -1198,7 +1258,24 @@ def TFL_OneHotOp : TFL_Op<"one_hot", [NoSideEffect]> { let hasOptions = 1; } -def TFL_SliceOp : TFL_Op<"slice", [NoSideEffect]> { +def TFL_RoundOp: TFL_Op<"round", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Round operator"; + + let description = [{ +Rounds the values of a tensor to the nearest integer, element-wise. + }]; + + let arguments = (ins + TensorOf<[F32]>:$x + ); + + let results = (outs + TensorOf<[F32]>:$y + ); +} + +def TFL_SliceOp : TFL_Op<"slice", [ + NoSideEffect, SameOperandsAndResultsScale]> { let summary = "Return a slice from 'input'."; let description = [{ @@ -1206,6 +1283,11 @@ The output tensor is a tensor with dimensions described by 'size' whose values are extracted from 'input' starting at the offsets in 'begin'. +`begin` is zero-based; `size` is one-based. If size[i] is -1, all remaining +elements in dimension i are included in the slice. 
In other words, this is +equivalent to setting: + size[i] = input.dim_size(i) - begin[i] + *Requirements*: 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n) }]; @@ -1219,6 +1301,8 @@ whose values are extracted from 'input' starting at the offsets in let results = (outs AnyTensor:$output ); + + let verifier = [{ return Verify(*this); }]; } def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> { @@ -1230,7 +1314,7 @@ def TFL_SumOp: TFL_Op<"sum", [NoSideEffect]> { let arguments = (ins AnyTensor:$input, - TFL_I32OrI64Tensor:$axes, + I32Tensor:$axes, BoolAttr:$keep_dims ); @@ -1249,7 +1333,7 @@ def TFL_ReduceMinOp: TFL_Op<"reduce_min", [NoSideEffect]> { let arguments = (ins AnyTensor:$input, - TFL_I32OrI64Tensor:$axes, + I32Tensor:$axes, BoolAttr:$keep_dims ); @@ -1268,7 +1352,7 @@ def TFL_ReduceMaxOp: TFL_Op<"reduce_max", [NoSideEffect]> { let arguments = (ins AnyTensor:$input, - TFL_I32OrI64Tensor:$axes, + I32Tensor:$axes, BoolAttr:$keep_dims ); @@ -1287,7 +1371,7 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> { let arguments = (ins TensorOf<[F32, I8, I32, I64]>:$input, - TFL_I32OrI64Tensor:$axes, + I32Tensor:$axes, BoolAttr:$keep_dims ); @@ -1297,18 +1381,21 @@ def TFL_ReduceProdOp: TFL_Op<"reduce_prod", [NoSideEffect]> { let customOption = "ReducerOptions"; } -def TFL_MinimumOp : TFL_Op<"minimum", [Broadcastable, NoSideEffect, Commutative]> { +def TFL_MinimumOp : TFL_Op<"minimum", [ + Broadcastable, NoSideEffect, Commutative, SameOperandsAndResultsScale]> { let summary = "Min operator"; let description = [{ Element-wise min operation. }]; let arguments = ( - ins TFL_FpOrI32OrI64Tensor:$lhs, - TFL_FpOrI32OrI64Tensor:$rhs + ins TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$lhs, + TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$rhs ); - let results = (outs TFL_FpOrI32OrI64Tensor:$min); + let results = (outs + TensorOf<[AnyFloat, TFL_Int32Or64, QI8, QUI8]>:$min + ); let builders = [TFL_BroadcastableBinaryBuilder]; @@ -1402,6 +1489,7 @@ def TFL_PackOp : TFL_Op<"pack", [NoSideEffect]> { def TFL_PadOp : TFL_Op<"pad", [ NoSideEffect, + SameOperandsAndResultsScale, TFL_OperandHasRank<1, 2>, TFL_OperandRankEquals1DimOfOperand<0, 1>]> { let summary = "Padding operator"; @@ -1431,16 +1519,17 @@ def TFL_PadOp : TFL_Op<"pad", [ }]; let arguments = ( - ins TensorOf<[F32, I8, I32, I64]>:$input, + ins TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, TFL_I32OrI64Tensor:$padding); - let results = (outs TensorOf<[F32, I8, I32, I64]>:$output); + let results = (outs TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); let hasOptions = 1; } def TFL_PadV2Op : TFL_Op<"padv2", [ NoSideEffect, + SameOperandsAndResultsScale, TFL_OperandHasRank<1, 2>, TFL_OperandHasRank<2, 0>, TFL_OperandRankEquals1DimOfOperand<0, 1>, @@ -1475,11 +1564,11 @@ def TFL_PadV2Op : TFL_Op<"padv2", [ }]; let arguments = ( - ins TensorOf<[F32, I8, I32, I64]>:$input, + ins TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, TFL_I32OrI64Tensor:$padding, TensorOf<[F32, I8, I32, I64]>:$constant_values); - let results = (outs TensorOf<[F32, I8, I32, I64]>:$output); + let results = (outs TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$output); let hasOptions = 1; } @@ -1511,9 +1600,13 @@ def TFL_RankOp: TFL_Op<"rank", [NoSideEffect]> { let arguments = (ins AnyTensor:$input); let results = (outs TFL_IntTensor:$output); + + let hasFolder = 1; } -def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, SameOperandsAndResultType]> { +def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultsScale]> { let summary = 
"Relu operator"; let description = [{ @@ -1526,7 +1619,9 @@ def TFL_ReluOp: TFL_Op<"relu", [NoSideEffect, SameOperandsAndResultType]> { let results = (outs AnyTensor:$y); } -def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, SameOperandsAndResultType]> { +def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, + SameOperandsAndResultShape, + SameOperandsAndResultsScale]> { let summary = "Relu6 operator"; let description = [{ @@ -1540,7 +1635,7 @@ def TFL_Relu6Op: TFL_Op<"relu6", [NoSideEffect, SameOperandsAndResultType]> { } def TFL_ReshapeOp: TFL_Op<"reshape", [ - NoSideEffect, TFL_SameOperandsAndResultsScale]> { + NoSideEffect, SameOperandsAndResultsScale]> { let summary = "Reshape operator"; let description = [{ @@ -1577,9 +1672,8 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension `seq_dim` reversed. }]; - // Missing Uint8. let arguments = (ins - TensorOf<[F32, I16, I32, I64]>:$input, + TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$input, TFL_I32OrI64Tensor:$seq_lengths, I32Attr:$seq_dim, @@ -1587,7 +1681,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension ); let results = (outs - TensorOf<[F32, I16, I32, I64]>:$output + TensorOf<[F32, I16, I32, I64, TFL_Uint8]>:$output ); let hasOptions = 1; @@ -1603,9 +1697,11 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> { let arguments = (ins AnyTensor:$x); let results = (outs AnyTensor:$y); + + let hasFolder = 1; } -def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect, TFL_NoQuantizableResult]> { +def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect, NoQuantizableResult]> { let summary = "Shape operator"; let description = [{ @@ -1623,20 +1719,12 @@ def TFL_ShapeOp: TFL_Op<"shape", [NoSideEffect, TFL_NoQuantizableResult]> { let hasOptions = 1; } -def TFL_LogisticOp: TFL_Op<"logistic", [NoSideEffect, SameOperandsAndResultType]> { - let summary = "Logistic operator"; - - let description = [{ - Computes element-wise Sigmoid of input - }]; - - let arguments = (ins TFL_FpTensor:$x); - - let results = (outs TFL_FpTensor:$y); -} - // TODO(jpienaar): Flesh this out. -def TFL_RangeOp: TFL_Op<"range", [NoSideEffect]> { +def TFL_RangeOp: TFL_Op<"range", [NoSideEffect, TFL_OperandHasRank<0, 0>, + TFL_OperandHasRank<1, 0>, TFL_OperandHasRank<2, 0>, + PredOpTrait<"operands and output must have same element type", + And<[TCresVTEtIsSameAsOp<0, 0>, TCresVTEtIsSameAsOp<0, 1>, + TCresVTEtIsSameAsOp<0, 2>]>>]> { let summary = "Range operator"; let description = [{ @@ -1650,6 +1738,8 @@ def TFL_RangeOp: TFL_Op<"range", [NoSideEffect]> { AnyTensor:$delta); let results = (outs AnyTensor:$result); + + let hasFolder = 1; } def TFL_ReverseV2Op: TFL_Op<"reverse_v2", @@ -1703,9 +1793,8 @@ def TFL_SelectOp : TFL_Op<"select", [NoSideEffect, let arguments = (ins TFL_BoolTensor:$condition, - // TODO: Missing uint8. - TensorOf<[F32, I1, I8, I16, I32, I64]>:$x, - TensorOf<[F32, I1, I8, I16, I32, I64]>:$y); + TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$x, + TensorOf<[F32, I1, I8, I16, I32, I64, TFL_Uint8]>:$y); let results = (outs AnyTensor:$output); // TODO(jpienaar): autogenerate this. @@ -1730,6 +1819,8 @@ def TFL_SinOp: TFL_Op<"sin", [NoSideEffect, SameOperandsAndResultType]> { let arguments = (ins TFL_FpTensor:$x); let results = (outs TFL_FpTensor:$y); + + let hasFolder = 1; } // TODO(b/130643170): Adds some constraint for the input/output element types. @@ -1738,8 +1829,8 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [ SameOperandsAndResultShape, // zero_point = 0 // scale = 1. 
/ (max_value + 1) - TFL_FixedResultScale>, - TFL_FixedResultScale>]> { + FixedResultScale>, + FixedResultScale>]> { let summary = "Softmax operator"; let description = [{ @@ -1765,9 +1856,11 @@ def TFL_SqrtOp: TFL_Op<"sqrt", [NoSideEffect, SameOperandsAndResultType]> { Computes element-wise Square root of input }]; - let arguments = (ins AnyTensor:$x); + let arguments = (ins TFL_FpTensor:$x); - let results = (outs AnyTensor:$y); + let results = (outs TFL_FpTensor:$y); + + let hasFolder = 1; } def TFL_SquareOp: TFL_Op<"square", [NoSideEffect, SameOperandsAndResultType]> { @@ -1777,11 +1870,13 @@ def TFL_SquareOp: TFL_Op<"square", [NoSideEffect, SameOperandsAndResultType]> { Computes element-wise Square of input }]; - let arguments = (ins AnyTensor:$x); + let arguments = (ins TensorOf<[AnyFloat, QI8, QUI8]>:$x); - let results = (outs AnyTensor:$y); + let results = (outs TensorOf<[AnyFloat, QI8, QUI8]>:$y); let hasOptions = 0b1; + + let hasFolder = 1; } def TFL_SubOp : TFL_Op<"sub", [Broadcastable, NoSideEffect]> { @@ -1833,22 +1928,21 @@ def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [Broadcastable, NoSid def TFL_TanhOp: TFL_Op<"tanh", [ NoSideEffect, - SameOperandsAndResultType, + SameOperandsAndResultShape, // central_value = min_value / 2 + (max_value - 1) / 2 + 1 // zero_point = central_value // scale = 1. / (central_value - min_value) - TFL_FixedResultScale>, - TFL_FixedResultScale>]> { + FixedResultScale>, + FixedResultScale>]> { let summary = "Hyperbolic tangent operator"; let description = [{ Computes element-wise Hyperbolic tangent of input }]; - // TODO(haoliang): missing Uint8. - let arguments = (ins TensorOf<[F32, I16, I8]>:$x); + let arguments = (ins TensorOf<[F32, I16, I8, QI8, QUI8, TFL_Uint8]>:$x); - let results = (outs TensorOf<[F32, I16, I8]>:$y); + let results = (outs TensorOf<[F32, I16, I8, QI8, QUI8, TFL_Uint8]>:$y); } def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, @@ -1865,9 +1959,11 @@ def TFL_TileOp: TFL_Op<"tile", [NoSideEffect, For example, tiling [a b c d] by [2] produces [a b c d a b c d]. }]; - let arguments = (ins AnyTensor:$input, TFL_I32OrI64Tensor:$multiples); + let arguments = (ins + TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$input, + TFL_I32OrI64Tensor:$multiples); - let results = (outs AnyTensor:$output); + let results = (outs TensorOf<[F32, I1, I32, I64, TFL_Uint8]>:$output); let hasOptions = 0; } @@ -1887,8 +1983,7 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>, }]; let arguments = (ins - // TODO: Missing uint8 - TensorOf<[F32, I8, I32, I64]>:$input, + TensorOf<[F32, I8, I32, I64, TFL_Uint8]>:$input, I32Tensor:$k); let results = (outs @@ -1906,11 +2001,13 @@ def TFL_TopKV2Op: TFL_Op<"topk_v2", [NoSideEffect, TFL_OperandHasRank<1,0>, // dimensions. def TFL_TransposeOp : TFL_Op<"transpose", [NoSideEffect, + TFL_OperandHasRank<1,1>, // TODO(jpienaar): these are only true dynamically, change so that it works // with unknowns. 
- // TFL_OperandHasRank<1,1>, // TFL_OperandRankEquals1DimOfOperand<0, 1>, - TFL_SameOperandsAndResultsScale]> { + PredOpTrait<"input and output must have same element type", + TCresVTEtIsSameAsOp<0, 0>>, + SameOperandsAndResultsScale]> { let summary = "Transpose operator"; let description = [{ @@ -1919,12 +2016,14 @@ def TFL_TransposeOp : TFL_Op<"transpose", let arguments = ( ins AnyTensor:$x, - AnyTensor:$perm + TensorOf<[I32]>:$perm ); let results = (outs AnyTensor:$y ); + + let hasFolder = 1; } def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> { @@ -1948,14 +2047,14 @@ def TFL_UnpackOp : TFL_Op<"unpack", [NoSideEffect]> { }]; let arguments = (ins - TensorOf<[F32, I8, I32]>:$input, + TensorOf<[F32, I8, I32, QI8, QUI8]>:$input, I32Attr:$num, I32Attr:$axis ); let results = (outs - Variadic>:$outputs + Variadic>:$outputs ); let verifier = [{ return Verify(*this); }]; @@ -1979,6 +2078,7 @@ def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [NoSideEffect]> { def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [ NoSideEffect, + SameOperandsAndResultsScale, PredOpTrait<"input and output must have same element type", TCresVTEtIsSameAsOp<0, 0>> ]> { @@ -1989,18 +2089,19 @@ def TFL_BatchToSpaceNdOp: TFL_Op<"batch_to_space_nd", [ }]; let arguments = (ins - TensorOf<[F32, I8, I32, I64]>:$input, + TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, TensorOf<[I32]>:$block_shape, TensorOf<[I32]>:$indices ); let results = (outs - TensorOf<[F32, I16, I32, I64]>:$output + TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output ); } def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [ NoSideEffect, + SameOperandsAndResultsScale, PredOpTrait<"input and output must have same element type", TCresVTEtIsSameAsOp<0, 0>> ]> { @@ -2011,17 +2112,76 @@ def TFL_SpaceToBatchNdOp: TFL_Op<"space_to_batch_nd", [ }]; let arguments = (ins - TensorOf<[F32, I8, I32, I64]>:$input, + TensorOf<[F32, I8, I32, I64, QI8, QUI8]>:$input, TensorOf<[I32]>:$block_shape, TensorOf<[I32]>:$paddings ); let results = (outs - TensorOf<[F32, I16, I32, I64]>:$output + TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$output ); } -def TFL_SplitOp : TFL_Op<"split", [NoSideEffect]> { +def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [ + NoSideEffect, + SameOperandsAndResultsScale, + PredOpTrait<"input and output must have same element type", + TCresVTEtIsSameAsOp<0, 0>> + ]> { + let summary = "SpaceToDepth operator"; + + let description = [{ + Rearranges blocks of spatial data, into depth. More specifically, + this op outputs a copy of the input tensor where values from the `height` + and `width` dimensions are moved to the `depth` dimension. + `block_size` indicates the input block size. + }]; + + let arguments = (ins + TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$input, + I32Attr:$block_size + ); + + let results = (outs + TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$output + ); + + let hasOptions = 1; +} + +def TFL_DepthToSpaceOp: TFL_Op<"depth_to_space", [ + NoSideEffect, + SameOperandsAndResultsScale, + PredOpTrait<"input and output must have same element type", + TCresVTEtIsSameAsOp<0, 0>> + ]> { + let summary = "DepthToSpace operator"; + + let description = [{ + Rearranges data from depth into blocks of spatial data. + This is the reverse transformation of SpaceToDepth. More specifically, + this op outputs a copy of the input tensor where values from the `depth` + dimension are moved in spatial blocks to the `height` and `width` + dimensions. The attr `block_size` indicates the input block size and how + the data is moved. 
+ }]; + + let arguments = (ins + TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$input, + I32Attr:$block_size + ); + + let results = (outs + TensorOf<[F32, I8, I32, I64, TFL_Uint8, QUI8]>:$output + ); + + let hasOptions = 1; +} + +def Rank0I32Tensor : Type]>, + "tensor">; + +def TFL_SplitOp : TFL_Op<"split", [NoSideEffect, SameOperandsAndResultsScale]> { let summary = "Splits a tensor into `num_split` tensors along one dimension."; let description = [{ @@ -2031,19 +2191,21 @@ def TFL_SplitOp : TFL_Op<"split", [NoSideEffect]> { }]; let arguments = (ins - I32Tensor:$split_dim, - TensorOf<[F32, I16, I32, I64]>:$value, - I32Attr:$num_splits + Rank0I32Tensor:$split_dim, + TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value, + PositiveI32Attr:$num_splits ); let results = (outs - Variadic>:$outputs + Variadic>:$outputs ); + let verifier = [{ return Verify(*this); }]; + let hasOptions = 1; } -def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect]> { +def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect, SameOperandsAndResultsScale]> { let summary = "Splits a tensor into `num_split` tensors along one dimension."; let description = [{ @@ -2053,20 +2215,21 @@ def TFL_SplitVOp : TFL_Op<"split_v", [NoSideEffect]> { }]; let arguments = (ins - TensorOf<[F32, I16, I32, I64]>:$value, + TensorOf<[F32, I16, I32, I64, QI8, QUI8]>:$value, I32Tensor:$size_splits, I32Tensor:$split_dim, I32Attr:$num_splits ); let results = (outs - Variadic>:$outputs + Variadic>:$outputs ); let hasOptions = 1; } -def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [NoSideEffect]> { +def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [ + NoSideEffect, SameOperandsAndResultsScale]> { let summary = "ResizeBilinear Op"; let description = [{ @@ -2075,22 +2238,82 @@ def TFL_ResizeBilinearOp: TFL_Op<"resize_bilinear", [NoSideEffect]> { let arguments = (ins // TODO(ycling): Support quantized types. - TensorOf<[F32, I32]>:$input, + TensorOf<[F32, I32, QI8, QUI8]>:$input, TensorOf<[I32]>:$size, BoolAttr:$align_corners); let results = (outs - TensorOf<[F32]>:$output + TensorOf<[F32, QI8, QUI8]>:$output ); let hasOptions = 1; } +def TFL_ResizeNearestNeighborOp : TFL_Op<"resize_nearest_neighbor", + [NoSideEffect, + SameOperandsAndResultsScale]> { + let summary = "ResizeNearestNeighbor Op"; + + let description = [{ + Resize `images` to `size` using nearest neighbor interpolation. + }]; + + let arguments = (ins + TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$input, + TensorOf<[I32]>:$size, + BoolAttr:$align_corners + ); + + let results = (outs + TensorOf<[F32, I8, TFL_Uint8, QUI8, QI8]>:$output + ); + + let hasOptions = 1; +} + +def TFL_SparseToDenseOp : TFL_Op<"sparse_to_dense", [NoSideEffect]> { + let summary = "Converts a sparse representation into a dense tensor."; + + let description = [{ +Builds an array `dense` with shape `output_shape` such that + +``` +# If sparse_indices is scalar +dense[i] = (i == sparse_indices ? sparse_values : default_value) + +# If sparse_indices is a vector, then for each i +dense[sparse_indices[i]] = sparse_values[i] + +# If sparse_indices is an n by d matrix, then for each i in [0, n) +dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] +``` + +All other values in `dense` are set to `default_value`. If `sparse_values` is a +scalar, all sparse indices are set to this single value. + +Indices should be sorted in lexicographic order, and indices must not +contain any repeats. If `validate_indices` is true, these properties +are checked during execution. 
+ }]; + + let arguments = (ins + TFL_I32OrI64Tensor:$sparse_indices, + TFL_I32OrI64Tensor:$output_shape, + TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$sparse_values, + TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$default_value + ); + + let results = (outs + TensorOf<[I32, I64, I8, TFL_Uint8, F32]>:$dense + ); +} + def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ NoSideEffect, PredOpTrait<"input and output must have same element type", - TCresVTEtIsSameAsOp<0, 0>> + TFL_TCresVTEtIsSameAsOp<0, 0>>, + SameOperandsAndResultsScale ]> { let summary = "StridedSlice Op"; @@ -2099,7 +2322,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", }]; let arguments = (ins - TensorOf<[F32, I32, I64, I8]>:$input, + TensorOf<[F32, I32, I64, I8, QI8, QUI8]>:$input, TensorOf<[I32]>:$begin, TensorOf<[I32]>:$end, TensorOf<[I32]>:$strides, @@ -2112,7 +2335,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", ); let results = (outs - TensorOf<[F32, I32, I64, I8]>:$output + TensorOf<[F32, I32, I64, I8, QI8, QUI8]>:$output ); let hasOptions = 1; @@ -2207,7 +2430,7 @@ in the unique output `y`. In other words: // Quantization ops. //===----------------------------------------------------------------------===// def TFL_DequantizeOp: TFL_Op<"dequantize", [ - NoSideEffect, TFL_NoQuantizableResult]> { + NoSideEffect, NoQuantizableResult]> { let summary = "Dequantize operator"; let description = [{ @@ -2243,7 +2466,7 @@ def TFL_FakeQuantOp : TFL_Op<"fake_quant", [NoSideEffect]> { } def TFL_QConstOp : Op { + NoSideEffect, FirstAttrDerivedResultType, NoQuantizableResult]> { let summary = "Quantized constant pseudo op"; let description = [{ @@ -2261,7 +2484,7 @@ def TFL_QConstOp : Op { + NoSideEffect, FirstAttrDerivedResultType, NoQuantizableResult]> { let summary = "Quantize operator"; let description = [{ @@ -2305,7 +2528,7 @@ def LstmProjectionWeightBiasConstraint : PredOpTrait< "projection bias must not be specified", Or<[ And<[TCopVTEtIs<16, NoneType>, TCopVTEtIs<17, NoneType>]>, - TFL_TCopIsNot<16, NoneType>]>>; + Neg>]>>; // TODO(b/137798843): Need to add two additional constraints for both LSTM and // UnidirectionalSequenceLstm @@ -2327,7 +2550,8 @@ def TFL_LSTMOp : [LstmMandatoryInputsConstraint, LstmOptionalPeepholeWeightConstraint, LstmProjectionWeightBiasConstraint, - LstmResultConstraint]> { + LstmResultConstraint, + StatefulOperands<[18, 19]>]> { let summary = "The full lstm operator"; let description = [{ @@ -2405,9 +2629,11 @@ Ba et al. “Layer Normalization” let results = (outs AnyTensor:$output); let hasOptions = 1; + + let verifier = [{ return Verify(*this); }]; } -// UnidirectionalSequenceLstm op . +// UnidirectionalSequenceLstm op. // TODO(ashwinm): Add constraint to validate the combination of operands // that are valid for hybrid vs fully quantized vs float only semantics def TFL_UnidirectionalSequenceLSTMOp : @@ -2415,7 +2641,8 @@ def TFL_UnidirectionalSequenceLSTMOp : [LstmMandatoryInputsConstraint, LstmOptionalPeepholeWeightConstraint, LstmProjectionWeightBiasConstraint, - LstmResultConstraint]> { + LstmResultConstraint, + StatefulOperands<[18, 19]>]> { let summary = "Unidirectional sequence lstm operator"; let description = [{ @@ -2482,6 +2709,129 @@ def TFL_UnidirectionalSequenceLSTMOp : let results = (outs AnyTensor:$output); let hasOptions = 1; + + let verifier = [{ return Verify(*this); }]; +} + +def RnnResultConstraint : PredOpTrait< + "the input and result tensor elemental types must be same", + TCresVTEtIsSameAsOp<0, 0>>; + +// UnidirectionalSequenceRNN op. 
+def TFL_UnidirectionalSequenceRNNOp : + TFL_Op<"unidirectional_sequence_rnn", + [RnnResultConstraint, StatefulOperands<[4]>]> { + + let summary = "Unidirectional sequence rnn operator"; + + let description = [{ + A recurrent neural network specified by an RNN cell. This Op takes in input + in a format {batch_size, seq_len, input_size} or + {seq_len, batch_size, input_size} if it's time-majored. + + It implements the following operation for + each element in the sequence s = 1...sequence_length: + outputs[s] = state = activation(RNNOp(inputs[s])) + + where RNNOp is RNNOp TF Lite Op and the “activation” is the function passed + as the “fused_activation_function” argument (if not “NONE”). + }]; + + let arguments = ( + ins TensorOf<[F32, I8]>:$input, + + // Weights + TensorOf<[F32, I8]>:$input_to_input_weights, + + // Recurrent weights + TensorOf<[F32, I8]>:$recurrent_to_input_weights, + + // Bias + TensorOf<[F32]>:$input_gate_bias, + + // Hidden state. + TFL_StatefulTensor:$hidden_state, + + // Attributes + BoolAttr:$time_major, + TFL_AFAttr:$fused_activation_function + ); + + let results = (outs TensorOf<[F32, I8]>:$output); + + let hasOptions = 1; + + let customOption = "SequenceRNNOptions"; + + let verifier = [{ return Verify(*this); }]; +} + +def TFL_WhereOp : TFL_Op<"where", [NoSideEffect]> { + let summary = "Returns locations of nonzero / true values in a tensor."; + + let description = [{ +This operation returns the coordinates of true elements in `condition`. The +coordinates are returned in a 2-D tensor where the first dimension (rows) +represents the number of true elements, and the second dimension (columns) +represents the coordinates of the true elements. Keep in mind, the shape of +the output tensor can vary depending on how many true values there are in +`condition`. Indices are output in row-major order. + }]; + + let arguments = (ins + I1Tensor:$input + ); + + // TODO(haoliang): TF Lite only support I32 output right now, need to fix + // either here or in the kernel. + let results = (outs + TFL_I32OrI64Tensor:$index + ); +} + +def SVDFResultConstraint: PredOpTrait< + "the input and result tensor elemental types must be same", + TCresVTEtIsSameAsOp<0, 0>>; + +// SVDF op. +def TFL_SVDFOp : + TFL_Op<"svdf", + [SVDFResultConstraint, StatefulOperands<[4]>]> { + + let summary = "Single value decomposition filter operator"; + + let description = [{ + The SVDF op is a decomposition of a densely connected op into low rank + filters. + For details: https://research.google.com/pubs/pub43813.html + https://arxiv.org/abs/1812.02802 + }]; + + let arguments = ( + ins TensorOf<[F32, I8]>:$input, + + // Feature Weights. + TensorOf<[F32, I8]>:$feature_weights, + + // Time weights + TensorOf<[F32, I8]>:$time_weights, + + // Bias + TFL_TensorOfOrNone<[F32]>:$input_gate_bias, + + // Activation state. + TFL_StatefulTensor:$activation_state, + + // Attributes + I32Attr:$rank, + TFL_AFAttr:$fused_activation_function + ); + + let results = (outs TensorOf<[F32, I8]>:$output); + + let hasOptions = 1; + + let verifier = [{ return Verify(*this); }]; } #endif // TFL_OPS diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_traits.h b/tensorflow/compiler/mlir/lite/ir/tfl_traits.h index 807c1100b71..af8c707a04e 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_traits.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_traits.h @@ -18,108 +18,32 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_IR_TFL_TRAITS_H_ -#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir +#include "mlir/IR/OpDefinition.h" #include "mlir/Support/LLVM.h" // TF:local_config_mlir -#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h" namespace mlir { namespace OpTrait { namespace TFL { -using QuantizedType = mlir::quant::QuantizedType; -using UniformQuantizedType = mlir::quant::UniformQuantizedType; - -// The base class that all the quantization related OpTrait implements. -template class TraitType> -struct QuantizationSpecTraitBase : public TraitBase { - static bool IsBias(int index) { return false; } - static bool IsQuantizable() { return true; } -}; - -// This class provides the API for TFL ops that requires same input and output -// scale as the quantization results. This is used as a trait like this: -// -// class TransposeOp -// : public Op { -// -template -class SameOperandsAndResultsScale - : public QuantizationSpecTraitBase {}; - -// This class provides the API for TFL ops that has a fixed output value range. +// The trait to specify that the specified operands of the TFL op are stateful. // This is used as a trait like this: // -// class SoftmaxOp -// : public Op::Impl> { +// class LSTMOp +// : public Op::Impl> { // -// TODO(fengliuai): create a better way to epxress floating point scale in the -// template argument list. -template -class FixedResultUniformScale { +template +class StatefulOperands { public: template class Impl - : public QuantizationSpecTraitBase< - ConcreteType, FixedResultUniformScale< - BitWidth, ZeroPoint, ScaleMantissa, ScaleExp, - StorageTypeMin, StorageTypeMax, Sign>::Impl> { + : public TraitBase::Impl> { public: - QuantizedType GetResultQuantizedType(int index) { - auto op = this->getOperation(); - auto result_type = - op->getResult(index)->getType().template cast(); - Builder builder(op->getContext()); - IntegerType storage_type = builder.getIntegerType(BitWidth); - const double scale = static_cast(ScaleMantissa) * - ::pow(10.0, static_cast(ScaleExp)); - return UniformQuantizedType::getChecked( - Sign, storage_type, result_type.getElementType(), scale, ZeroPoint, - StorageTypeMin, StorageTypeMax, builder.getUnknownLoc()); - } - }; -}; - -// This class provides the API for TFL ops that has input as bias. This is used -// as a trait like this: -// -// class Conv2DOp -// : public Op::Impl> { -// -// TODO(fengliuai): supports a configurable accumulator bit width. -template -class AccumulatorUniformScale { - public: - template - class Impl - : public QuantizationSpecTraitBase< - ConcreteType, AccumulatorUniformScale::Impl> { - public: - // Whether the index-th operand is a bias. - static bool IsBias(int index) { return index == Bias; } - - // Returns the indexes of all the non-bias operands. - static std::vector GetAllNonBiasOperands() { + static std::vector GetStatefulOperands() { return std::vector({Operands...}); } }; }; -// This class provides the API for TFL ops that shouldn't be quantized. 
This is -// used as a trait like this: -// -// class LessOp : public Op { -// -template -class NoQuantizableResult - : public QuantizationSpecTraitBase { - public: - static bool IsQuantizable() { return false; } -}; - } // namespace TFL } // namespace OpTrait } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index ff27ad76136..52a8bd35d36 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -28,7 +28,6 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "mlir/IR/Function.h" // TF:local_config_mlir @@ -79,9 +78,7 @@ static std::string TfLiteTensorString(const TfLiteTensor& tensor) { } int main(int argc, char** argv) { - llvm::PrettyStackTraceProgram x(argc, argv); llvm::InitLLVM y(argc, argv); - llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR TFLite runner\n"); auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(inputFileName.c_str()); diff --git a/tensorflow/compiler/mlir/lite/operator_writer_gen.cc b/tensorflow/compiler/mlir/lite/operator_converter_gen.cc similarity index 84% rename from tensorflow/compiler/mlir/lite/operator_writer_gen.cc rename to tensorflow/compiler/mlir/lite/operator_converter_gen.cc index fd8325577d9..5db1aa1a3c0 100644 --- a/tensorflow/compiler/mlir/lite/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/lite/operator_converter_gen.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/InitLLVM.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Main.h" @@ -247,6 +247,51 @@ static void EmitBuildOperator(const std::vector &defs, "}\n"; } +// Emit a function that converts a BuiltinOptionsUnion to a vector of attributes +// Signature: +// void mlir::BuiltinOptionsToAttributes( +// tflite::BuiltinOptionsUnion op_union, +// mlir::Builder builder, +// llvm::SmallVectorImpl &attributes); +static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper, + const std::vector &defs, + raw_ostream *ostream) { + raw_ostream &os = *ostream; + + // Signature + os << "void mlir::BuiltinOptionsToAttributes(" + "tflite::BuiltinOptionsUnion op_union, " + "mlir::Builder builder, " + "llvm::SmallVectorImpl &attributes) {\n"; + + const auto attr_type = record_keeper.getClass("Attr"); + for (const auto *def : defs) { + if (!def->getValueAsBit("hasOptions")) continue; + auto option_name = GetOperatorOptionName(*def); + os << formatv(" if(const auto *op = op_union.As{0}()) {\n", option_name); + + // We only care about options that are in arguments + auto *arg_values = def->getValueAsDag("arguments"); + for (unsigned i = 0, e = arg_values->getNumArgs(); i != e; ++i) { + auto arg = arg_values->getArg(i); + DefInit *arg_def = dyn_cast(arg); + if (!arg_def) continue; + if (arg_def->getDef()->isSubClassOf(attr_type)) { + StringRef arg_name = arg_values->getArgNameStr(i); + StringRef attr_type = mlir::tblgen::Attribute(arg_def).getAttrDefName(); + os << formatv( + " attributes.emplace_back(builder.getNamedAttr(\"{0}\"," + " Build{1}(op->{0}, builder)));\n", + arg_name, attr_type); + } + } + + os << " return;\n"; + os << " 
}\n"; + } + // Fallthrough case is no attributes + os << "}"; +} // The function below has a non-constant reference as that is required by LLVM's // TableGenMain. // NOLINTNEXTLINE @@ -278,15 +323,14 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) { EmitGetBuiltinOpCode(defs, &os); os << "\n\n"; EmitBuildOperator(defs, &os); + os << "\n\n"; + EmitBuiltinOptionsToAttributes(records, defs, &os); return false; } int main(int argc, char **argv) { - llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); - llvm::PrettyStackTraceProgram X(argc, argv); - - llvm::llvm_shutdown_obj Y; + llvm::InitLLVM y(argc, argv); llvm::cl::ParseCommandLineOptions(argc, argv); return TableGenMain(argv[0], &OperatorWritersMain); } diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index 0eab2981a83..5094b015f68 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -15,7 +15,10 @@ cc_library( hdrs = [ "graphdef_to_tfl_flatbuffer.h", ], + copts = ["-std=c++14"], deps = [ + "//tensorflow/compiler/mlir/lite:common", + "//tensorflow/compiler/mlir/lite:tf_tfl_passes", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index 2a60715e13d..b2bca0b4f54 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -17,8 +17,10 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -129,16 +131,26 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags, bool emit_select_tf_ops = toco_flags.enable_select_tf_ops(); bool emit_custom_ops = toco_flags.allow_custom_ops(); specs.prune_unused_nodes = true; + specs.convert_legacy_fed_inputs = true; + specs.graph_as_function = false; WarningUnusedFlags(model_flags, toco_flags); - bool emit_quant_adaptor_ops = false; - bool lower_tensor_list_ops = true; TF_ASSIGN_OR_RETURN( auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context)); - return ConvertTFControlFlowToTFLOrFlatbuffer( + + mlir::PassManager pm; + bool run_quantize = tensorflow::ShouldRunQuantizePasses(module.get()); + mlir::TFL::PassConfig pass_config; + pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; + pass_config.run_quantize = run_quantize; + pass_config.lower_tensor_list_ops = true; + + tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); + + return ConvertTFExecutorToTFLOrFlatbuffer( module.get(), /*export_to_mlir=*/false, emit_builtin_tflite_ops, - emit_select_tf_ops, emit_custom_ops, emit_quant_adaptor_ops, - lower_tensor_list_ops, result); + emit_select_tf_ops, emit_custom_ops, 
/*emit_quant_adaptor_ops=*/false, + /*lower_tensor_list_ops=*/true, result, &pm); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD new file mode 100644 index 00000000000..57b9a48e5de --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -0,0 +1,60 @@ +load("//tensorflow:tensorflow.bzl", "tf_native_cc_binary") + +package( + default_visibility = [ + ":friends", + ], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = ["@local_config_mlir//:subpackages"], + packages = ["//tensorflow/compiler/mlir/..."], +) + +exports_files(["quantization_traits.h"]) + +filegroup( + name = "quantization_td_files", + srcs = [ + "quantization.td", + "@local_config_mlir//:OpBaseTdFiles", + "@local_config_mlir//:QuantizationOpsTdFiles", + ], +) + +cc_library( + name = "quantization_lib", + srcs = [ + "quantization_driver.cc", + "quantization_utils.cc", + ], + hdrs = [ + "quantization_utils.h", + ], + copts = ["-std=c++14"], + deps = [ + "@com_google_absl//absl/memory", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:QuantOps", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + # TODO(fengliuai): remove this dependence. + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/core:lib_proto_parsing", + ], +) + +tf_native_cc_binary( + name = "op_quant_spec_getters_gen", + srcs = [ + "tools/op_quant_spec_getters_gen.cc", + ], + deps = [ + "@llvm//:support", + "@llvm//:tablegen", + "@local_config_mlir//:TableGen", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization.td b/tensorflow/compiler/mlir/lite/quantization/quantization.td new file mode 100644 index 00000000000..24b299ba39b --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/quantization.td @@ -0,0 +1,105 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the quantization definition file for TensorFlow. + +#ifdef TF_Quantization +#else +#define TF_Quantization + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +include "mlir/Dialect/QuantOps/QuantPredicates.td" + + +//===----------------------------------------------------------------------===// +// Min-max range pair definitions. +//===----------------------------------------------------------------------===// + +// A pair of floating point values which defines the min and max of a value +// range for quantization. The attribute is allowed to be empty or +// have 2 elements. +def MinMaxAttr : Attr().size() == 0">, + CPred<"$_self.cast().size() == 2">]>, + "min-max range pair"> { + let storageType = [{ ArrayAttr }]; + let returnType = [{ ArrayRef }]; +} + +//===----------------------------------------------------------------------===// +// QuantizedType definitions. 
+//===----------------------------------------------------------------------===// + +// The base class of a quantized type. +class QuantizedType params, bit signed> + : Type()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # !head(params)>]>, + "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { + string name = n; + string asTraitArgsStr = + StrJoinInt.result # !if(signed, ", true", ", false"); +} + +// Uniform quantized types. Two integers "smantissa" and "sexp" are used to +// express the Mantissa and Exponent components of the floating-point scale so +// the scale of the quantized type is "smantissa * 10 ^ sexp". +class UInt8UniformQuantizedType + : QuantizedType<"Uniform", + [8, zero_pt, smantissa, sexp, 0, 255], 0>; +class Int8UniformQuantizedType + : QuantizedType<"Uniform", + [8, zero_pt, smantissa, sexp, -128, 127], 1>; + +// General uniform quantized types. The definitions can be used to specify +// operand's tensor types. +def QUI8 : QuantizedType<"Uniform", [8], 0>; +def QI8 : QuantizedType<"Uniform", [8], 1>; +def QUI16 : QuantizedType<"Uniform", [16], 0>; +def QI16 : QuantizedType<"Uniform", [16], 1>; +def QUI32 : QuantizedType<"Uniform", [32], 0>; +def QI32 : QuantizedType<"Uniform", [32], 1>; + +//===----------------------------------------------------------------------===// +// TFL native op traits (for quantization). +// +// Ops in this link should have those traits specified: +// https://www.tensorflow.org/lite/performance/quantization_spec +//===----------------------------------------------------------------------===// + +// Specify this trait if the op has a fixed output value range. +class FixedResultScale : NativeOpTrait::Impl")>; + +// Specify this trait if the op requires same inputs and outputs quantization +// scales. +def SameOperandsAndResultsScale : NativeOpTrait< + "quant::SameOperandsAndResultsScale">; + +// Specify this trait if the b-th input of the op is a bias input, which needs +// a scale based on the scales of op1 and op2. +class AccumulatorUniformScale : NativeOpTrait< + !strconcat("quant::AccumulatorUniformScale<", + StrJoinInt<[bias, op1, op2]>.result, + ">::Impl")>; + +// Specify this trait if the op doesn't have quantizable ouput. We shouldn't +// apply quantization on this op. +def NoQuantizableResult : NativeOpTrait<"quant::NoQuantizableResult">; + +#endif // TF_Quantization diff --git a/tensorflow/compiler/mlir/lite/utils/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc similarity index 87% rename from tensorflow/compiler/mlir/lite/utils/quantization_driver.cc rename to tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 1ab00ec3634..63c055c1ac8 100644 --- a/tensorflow/compiler/mlir/lite/utils/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -17,13 +17,14 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir @@ -32,47 +33,15 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_traits.h" -#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/core/platform/logging.h" namespace mlir { namespace TFL { namespace { - -using QuantParams = quant::QuantizedType; -using AccumulatorScaleFunc = - std::function &)>; -using SignedInteger = std::pair; // bitwidth and sign -using QuantParamsForResults = llvm::SmallVector; - -// Quantization specs of ops, driving the TF Lite quantization algorithm. -struct OpQuantSpec { - // Whether the op has quantizable result. This flag is set to false if the op - // has "TFL::NoQuantizableResult" trait. - bool is_quantizable = true; - - // Whether it requires same inputs and result scale. This flag is set to true - // if the op has "TFL::SameOperandsAndResultScale" trait. - bool requires_same_scale = false; - - // Maps the operand index of a bias input to its quantization specifications, - // including the non-bias operand indexes and the method retrieving - // quantization parameters from list of parameters of the non-bias operands. - // This map is empty if the op doesn't havea bias operand. - std::unordered_map, AccumulatorScaleFunc>> - biases_params; - - // Quantization parameters for value restricted outputs. This is the - // "hard-coded" parameters and should be used unconditionally for the - // quantized op. This vector is empty if the op doesn't have value resctricted - // outputs. - llvm::DenseMap restricted_output_params; -}; - static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); } // The state for each op result during the quantization parameters propagation. @@ -125,8 +94,12 @@ struct RequantizeState { // class QuantizationDriver { public: - explicit QuantizationDriver(FuncOp fn, bool is_signed) - : fn_(fn), builder_(fn.getBody()), is_signed_(is_signed) {} + explicit QuantizationDriver(FuncOp fn, bool is_signed, + OpQuantSpecGetter op_quant_spec_getter) + : fn_(fn), + builder_(fn.getBody()), + is_signed_(is_signed), + op_quant_spec_getter_(op_quant_spec_getter) {} // The entry point of the quantization parameters propagation. void Run(); @@ -146,17 +119,19 @@ class QuantizationDriver { // result. void Finalize(); - // Whether the constant is used as a bias input of another op. Here we assume - // bias is used immediately by the user. This assumption is always correct - // after constant folding. - bool UsedAsBias(ConstantOp cst) { - Value *value = cst.getResult(); - for (auto &use : value->getUses()) { - auto biases = GetQuantSpec(use.getOwner())->biases_params; - if (biases.find(use.getOperandNumber()) != biases.end()) return true; - } - return false; - } + // The quantization parameters of bias operand are usually determined by + // other operands, so if a constant is used by different ops as bias, it needs + // to be duplicated, thus each op can assign its own quantization parameter + // for this bias. Also this methods add all the non-bias constants to a set + // for looking up later. + void PreprocessConstantOps(); + + // Setup all the data structures for quantization propagation. 
+ void SetupAllStates(); + + // Whether the constant is a weight, which shouldn't be shared by different + // ops. + bool IsWeight(Operation *cst) { return llvm::is_contained(weights_, cst); } // Returns all the related quantization constraints of the op. std::unique_ptr GetQuantSpec(Operation *op); @@ -294,6 +269,11 @@ class QuantizationDriver { OpBuilder builder_; bool is_signed_; + // We should distinguish weights and bias constants. Biases are specified by + // the quantization spec or are the operands of ops with same scale spec. The + // rest are weights. + llvm::DenseSet weights_; + // All the ops needs to propagate the quantization parameters to. std::vector work_list_; std::unordered_set quantized_; @@ -316,14 +296,13 @@ class QuantizationDriver { // This vector is to preserve the arguments order, so the newly inserted // quantized ops for the arguments are deterministically ordered. llvm::SmallVector args_; -}; -#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc" + OpQuantSpecGetter op_quant_spec_getter_; +}; } // namespace -// TODO(fengliuai): cache the quantization parameters. std::unique_ptr QuantizationDriver::GetQuantSpec(Operation *op) { - return GetOpQuantSpec(op); + return op_quant_spec_getter_(op); } bool QuantizationDriver::IsQuantized(Operation *op) { @@ -354,10 +333,10 @@ bool QuantizationDriver::SetConstantResultParams(Operation *op) { if (!matchPattern(res, m_Constant(&attr))) { return false; } - // TODO(fengliuai): the bit width should be determined by its user. + // TODO(fengliuai): make storage_type_width and narrow_range configurable. auto final_type = - GetUniformQuantizedTypeForElementsAttr( - attr, /*storage_type_width=*/8, is_signed_, /*narrow_range_=*/false) + GetUniformQuantizedTypeForElementsAttr(attr, /*storage_type_width=*/8, + is_signed_, /*narrow_range_=*/true) .dyn_cast_or_null(); if (!final_type) return false; return SetResultParams(op, 0, final_type); @@ -432,6 +411,9 @@ void QuantizationDriver::QuantizeValue(Value *value, QuantParams params, Location loc) { Type expressed_type = value->getType(); Type new_type = params.castFromExpressedType(expressed_type); + // This value isn't an expressed type (float), skip. + if (!new_type) return; + TypeAttr type_attr = builder_.getTypeAttr(new_type); auto quantize = builder_.create(loc, new_type, value, type_attr); @@ -482,10 +464,15 @@ void QuantizationDriver::RequantizeValue(Value *value, RequantizeState *state, } else { Type expressed_type = quant::QuantizedType::castToExpressedType(value->getType()); + if (!expressed_type) return; + // The value needs to be requantized. A Quantize op will be created to use // it as the operand and replace its uses. new_type = state->params.castFromExpressedType(expressed_type); } + // This value isn't an expressed type (float), skip. + if (!new_type) return; + TypeAttr type_attr = builder_.getTypeAttr(new_type); auto requantize_op = builder_.create(loc, new_type, value, type_attr); @@ -560,12 +547,39 @@ QuantParams QuantizationDriver::GetQuantParamsForSameScaleConstraint( return {}; } -// This method scans the operations in the function to setup the initial -// states for quantization parameter propagation. -// TODO(fengliuai): This algorithm assumes there are only one pair of -// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity -// check should be applied. 
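// As a minimal illustration of the bias duplication performed by
// PreprocessConstantOps below (hypothetical shapes; attributes and result
// types elided), a constant feeding the bias operand of two ops,
//
//   %bias = constant dense<0.0> : tensor<16xf32>
//   %0 = "tfl.conv_2d"(%in0, %w0, %bias) ...
//   %1 = "tfl.conv_2d"(%in1, %w1, %bias) ...
//
// keeps the original constant for its first user and gets a copy for each
// additional user, so every op can later assign its own bias quantization
// parameters:
//
//   %bias  = constant dense<0.0> : tensor<16xf32>
//   %bias1 = constant dense<0.0> : tensor<16xf32>
//   %0 = "tfl.conv_2d"(%in0, %w0, %bias) ...
//   %1 = "tfl.conv_2d"(%in1, %w1, %bias1) ...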
-void QuantizationDriver::Initialize() { +void QuantizationDriver::PreprocessConstantOps() { + fn_.walk([&](ConstantOp cst) { + // Non-float tensors are neither weights or require quantization. + if (!cst.getType().cast().getElementType().isa()) { + return; + } + + Value *value = cst.getResult(); + SmallVector, 4> bias_users; + for (auto &use : value->getUses()) { + auto spec = GetQuantSpec(use.getOwner()); + auto biases = spec->biases_params; + Operation *user = use.getOwner(); + int operand_num = use.getOperandNumber(); + + // The user doesn't use this value as a bias operand nor require same + // scale. + if (biases.find(operand_num) == biases.end() && + !spec->requires_same_scale) { + weights_.insert(cst); + } else { + bias_users.push_back({user, operand_num}); + } + } + builder_.setInsertionPoint(cst); + for (int i = 1; i < bias_users.size(); ++i) { + auto copied = builder_.create(cst.getLoc(), cst.getValue()); + bias_users[i].first->setOperand(bias_users[i].second, copied.getResult()); + } + }); +} + +void QuantizationDriver::SetupAllStates() { llvm::DenseMap value_to_state; fn_.walk([&](Operation *op) { @@ -603,6 +617,21 @@ void QuantizationDriver::Initialize() { }); } +// This method scans the operations in the function to setup the initial +// states for quantization parameter propagation. +// TODO(fengliuai): This algorithm assumes there are only one pair of +// tfl.quantize and tfl.dequantize ops between two quantizable ops. A sanity +// check should be applied. +void QuantizationDriver::Initialize() { + // Duplicate the bias constant, so the states can be setup correctly. + // TODO(fengliuai): Function definition should also be duplicated if there are + // multiple call sites. + PreprocessConstantOps(); + + // Setup all the internal states. + SetupAllStates(); +} + bool QuantizationDriver::PropagateParams() { // TODO(fengliuai): uses a typed indicator instead of a bool value. bool changed = false; @@ -610,8 +639,8 @@ bool QuantizationDriver::PropagateParams() { Operation *op = work_list_.back(); work_list_.pop_back(); - // This op has been quantized, so we should consider it again. - if (quantized_.find(op) != quantized_.end()) continue; + // This op has been quantized, so we should not consider it again. + if (llvm::is_contained(quantized_, op)) continue; quantized_.insert(op); auto spec = GetQuantSpec(op); @@ -621,9 +650,8 @@ bool QuantizationDriver::PropagateParams() { if (!spec->is_quantizable) continue; if (auto cst = llvm::dyn_cast(op)) { - // This constant is used as a bias in another op, then the quantization - // parameters are determined by that op. - if (UsedAsBias(cst) || IsQuantized(op)) continue; + // If it isn't a weight or has been quantized, skip. + if (!IsWeight(cst) || IsQuantized(op)) continue; // The quantization parameters are determined by the content of the // constant. @@ -648,11 +676,13 @@ bool QuantizationDriver::PropagateParams() { for (int res = 0, e = op->getNumResults(); res != e; ++res) changed |= SetResultParams(op, res, params); } + // TODO(fengliuai): make the bit width configurable. 
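// The (bit width, sign) pair below keys into restricted_output_params, whose
// entries come from the FixedResultUniformScale op traits; the bit width is
// fixed at 8 until the TODO above is addressed.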
auto key = std::make_pair(8, is_signed_); auto &restricted_outputs = spec->restricted_output_params[key]; - for (int i = 0, e = restricted_outputs.size(); i != e; ++i) + for (int i = 0, e = restricted_outputs.size(); i != e; ++i) { changed |= SetResultParams(op, i, restricted_outputs[i]); + } for (auto &it : spec->biases_params) { auto params = @@ -712,8 +742,9 @@ void QuantizationDriver::Run() { } } -void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed) { - QuantizationDriver(func, is_signed).Run(); +void ApplyQuantizationParamsPropagation( + mlir::FuncOp func, bool is_signed, OpQuantSpecGetter op_quant_spec_getter) { + QuantizationDriver(func, is_signed, op_quant_spec_getter).Run(); } } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h new file mode 100644 index 00000000000..b64776ddee7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_traits.h @@ -0,0 +1,126 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the op traits used in the MLIR TensorFlow Lite dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ + +#include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir + +namespace mlir { +namespace OpTrait { +namespace quant { + +using QuantizedType = mlir::quant::QuantizedType; +using UniformQuantizedType = mlir::quant::UniformQuantizedType; + +// The base class that all the quantization related OpTrait implements. +template class TraitType> +struct QuantizationSpecTraitBase : public TraitBase { + static bool IsBias(int index) { return false; } + static bool IsQuantizable() { return true; } +}; + +// This class provides the API for TFL ops that requires same input and output +// scale as the quantization results. This is used as a trait like this: +// +// class TransposeOp +// : public Op { +// +template +class SameOperandsAndResultsScale + : public QuantizationSpecTraitBase {}; + +// This class provides the API for TFL ops that has a fixed output value range. +// This is used as a trait like this: +// +// class SoftmaxOp +// : public Op::Impl> { +// +// TODO(fengliuai): create a better way to epxress floating point scale in the +// template argument list. 
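// As a concrete, illustrative instantiation (values follow the TFLite
// quantization spec for a signed 8-bit softmax result, scale 1/256 and zero
// point -128; the exact op declaration may differ):
//
//   FixedResultUniformScale<8, -128, 390625, -8, -128, 127, true>
//
// where the scale is encoded as ScaleMantissa * 10^ScaleExp, i.e.
// 390625 * 10^-8 == 0.00390625 == 1/256.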
+template +class FixedResultUniformScale { + public: + template + class Impl + : public QuantizationSpecTraitBase< + ConcreteType, FixedResultUniformScale< + BitWidth, ZeroPoint, ScaleMantissa, ScaleExp, + StorageTypeMin, StorageTypeMax, Sign>::Impl> { + public: + QuantizedType GetResultQuantizedType(int index) { + auto op = this->getOperation(); + auto result_type = + op->getResult(index)->getType().template cast(); + Builder builder(op->getContext()); + IntegerType storage_type = builder.getIntegerType(BitWidth); + const double scale = static_cast(ScaleMantissa) * + ::pow(10.0, static_cast(ScaleExp)); + return UniformQuantizedType::getChecked( + Sign, storage_type, result_type.getElementType(), scale, ZeroPoint, + StorageTypeMin, StorageTypeMax, builder.getUnknownLoc()); + } + }; +}; + +// This class provides the API for TFL ops that has input as bias. This is used +// as a trait like this: +// +// class Conv2DOp +// : public Op::Impl> { +// +// TODO(fengliuai): supports a configurable accumulator bit width. +template +class AccumulatorUniformScale { + public: + template + class Impl + : public QuantizationSpecTraitBase< + ConcreteType, AccumulatorUniformScale::Impl> { + public: + // Whether the index-th operand is a bias. + static bool IsBias(int index) { return index == Bias; } + + // Returns the indexes of all the non-bias operands. + static std::vector GetAllNonBiasOperands() { + return std::vector({Operands...}); + } + }; +}; + +// This class provides the API for TFL ops that shouldn't be quantized. This is +// used as a trait like this: +// +// class LessOp : public Op { +// +template +class NoQuantizableResult + : public QuantizationSpecTraitBase { + public: + static bool IsQuantizable() { return false; } +}; + +} // namespace quant +} // namespace OpTrait +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_TRAITS_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc similarity index 90% rename from tensorflow/compiler/mlir/lite/utils/quantization_utils.cc rename to tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index da797db4cd4..31a7a181124 100644 --- a/tensorflow/compiler/mlir/lite/utils/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "mlir/Dialect/QuantOps/FakeQuantSupport.h" // TF:local_config_mlir #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir @@ -61,6 +61,17 @@ TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min, narrow_range.getValue(), /*is_signed=*/false); } +TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder, + TypeAttr source, Type target) { + if (!source || !source.getValue().isa()) return {}; + auto ele_type = source.getValue().cast().getElementType(); + if (auto quantized_type = ele_type.dyn_cast()) { + Type final_type = quantized_type.castFromExpressedType(target); + if (final_type) return builder.getTypeAttr(final_type); + } + return {}; +} + Type GetUniformQuantizedTypeForElementsAttr(ElementsAttr attr, unsigned storage_type_width, bool is_signed, bool narrow_range) { diff --git a/tensorflow/compiler/mlir/lite/utils/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h similarity index 52% rename from tensorflow/compiler/mlir/lite/utils/quantization_utils.h rename to tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 941ce636bc1..e101893b06d 100644 --- a/tensorflow/compiler/mlir/lite/utils/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -16,18 +16,54 @@ limitations under the License. // This header file defines common utils used by TFLite transformation // passes to work with op attributes. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_ + +#include #include "mlir/Dialect/QuantOps/QuantTypes.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir namespace mlir { namespace TFL { +using QuantParams = quant::QuantizedType; +using SignedInteger = std::pair; // bitwidth and sign +using QuantParamsForResults = llvm::SmallVector; +using AccumulatorScaleFunc = + std::function&)>; + +// Quantization spec of an op, driving the quantization algorithm. +struct OpQuantSpec { + // Whether the op has quantizable result. This flag is set to false if the op + // has "TFL::NoQuantizableResult" trait. + bool is_quantizable = true; + + // Whether it requires same inputs and result scale. This flag is set to true + // if the op has "TFL::SameOperandsAndResultScale" trait. + bool requires_same_scale = false; + + // Maps the operand index of a bias input to its quantization specifications, + // including the non-bias operand indexes and the method retrieving + // quantization parameters from list of parameters of the non-bias operands. + // This map is empty if the op doesn't havea bias operand. + std::unordered_map, AccumulatorScaleFunc>> + biases_params; + + // Quantization parameters for value restricted outputs. This is the + // "hard-coded" parameters and should be used unconditionally for the + // quantized op. 
This vector is empty if the op doesn't have value resctricted + // outputs. + llvm::DenseMap restricted_output_params; +}; + +// A function signature for getting the particular OpQuantSpec for the provided +// op. +typedef std::unique_ptr (*OpQuantSpecGetter)(Operation* op); + // A generic rewrite pattern which matches any N-in-1-out operations with // quantization parameters propagated to all the operands and results values. // The quantization parameters are annotated by the Q/DQ op pairs. Each matched @@ -49,11 +85,11 @@ struct GenericFullQuantizationPattern : public RewritePattern { return matchFailure(); } auto quantize_op = cast(op); - auto quantized_op = quantize_op.input()->getDefiningOp(); + Operation* quantized_op = quantize_op.input()->getDefiningOp(); // If it is a block argument, requantize op, or has more than one result, we // shouldn't rewrite this op. if (!quantized_op || llvm::isa(quantized_op) || - llvm::isa(quantized_op) || quantized_op->getNumResults() != 1) { + llvm::isa(quantized_op)) { return matchFailure(); } @@ -61,21 +97,66 @@ struct GenericFullQuantizationPattern : public RewritePattern { // inputs. SmallVector inputs; inputs.reserve(quantized_op->getNumOperands()); - for (int i = 0, e = quantized_op->getNumOperands(); i != e; ++i) { - auto* operand = quantized_op->getOperand(i); + for (auto operand : quantized_op->getOperands()) { + Type operand_type = operand->getType(); + if (operand_type.isa()) { + inputs.push_back(operand); + continue; + } + auto operand_ele_type = + operand->getType().cast().getElementType(); if (auto op_inst = dyn_cast_or_null(operand->getDefiningOp())) { inputs.push_back(op_inst.input()); + } else if (operand_ele_type.isa()) { + // If the operand is an integer tensor, then it doesn't require the + // DQ op in the pattern. + inputs.push_back(operand); } else { return matchFailure(); } } + + // Collect all the quantized outputs and replace them by the results of the + // new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(quantized_op->getNumResults()); + for (auto enumerated_result : llvm::enumerate(quantized_op->getResults())) { + Value* result = enumerated_result.value(); + Type result_type = result->getType(); + // Add this to the test coverage once we create test ops with none type + // results. + if (result_type.isa()) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + if (!result->hasOneUse()) return matchFailure(); + Type result_ele_type = + result->getType().cast().getElementType(); + if (auto user = dyn_cast_or_null(*result->user_begin())) { + outputs_replaced.insert({user.output(), enumerated_result.index()}); + output_types.push_back(user.getType()); + } else if (result_ele_type.template isa()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_ele_type); + } else { + return matchFailure(); + } + } + // Use OpBuilder so we can use op name to create the new op. 
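// For orientation, a minimal sketch of this rewrite (op name, scale, and zero
// point are illustrative only): the matched float op surrounded by
// dequantize/quantize,
//
//   %dq = "tfl.dequantize"(%qin)
//       : (tensor<4x!quant.uniform<u8:f32, 0.1:128>>) -> tensor<4xf32>
//   %f  = "tfl.relu"(%dq) : (tensor<4xf32>) -> tensor<4xf32>
//   %q  = "tfl.quantize"(%f) {qtype = tensor<4x!quant.uniform<u8:f32, 0.1:128>>}
//       : (tensor<4xf32>) -> tensor<4x!quant.uniform<u8:f32, 0.1:128>>
//
// is replaced by one op built below that consumes and produces the quantized
// types directly:
//
//   %q = "tfl.relu"(%qin)
//       : (tensor<4x!quant.uniform<u8:f32, 0.1:128>>)
//       -> tensor<4x!quant.uniform<u8:f32, 0.1:128>>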
OpBuilder builder(quantized_op); - OperationState new_state( - quantized_op->getLoc(), quantized_op->getName().getStringRef(), inputs, - op->getResult(0)->getType(), quantized_op->getAttrs()); + OperationState new_state(quantized_op->getLoc(), + quantized_op->getName().getStringRef(), inputs, + output_types, quantized_op->getAttrs()); Operation* new_op = builder.createOperation(new_state); - rewriter.replaceOp(op, {new_op->getResult(0)}); + for (auto output : outputs_replaced) { + output.getFirst()->replaceAllUsesWith( + new_op->getResult(output.getSecond())); + } return matchSuccess(); } }; @@ -95,6 +176,16 @@ TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min, Attribute max, IntegerAttr num_bits, BoolAttr narrow_range); +// Casts the `target` type to a quantized type by using the quantization +// parameters from the type in the `source` type attribute. +// Examples: +// f32 -> !quant.uniform +// tensor<4xf32> -> tensor<4x!quant.uniform> +// The result is wrapped by a type attribute. Returns nullptr if the cast isn't +// valid. +TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder, + TypeAttr source, Type target); + // Quantizes the elements in the attribute `real_value` by the quantization // parameters in `tensor_type`. Returns empty Attribute if the // `tensor_type` is not a QuantizedType or the quantization fails. @@ -119,9 +210,10 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( // quantization parameters are stored as adjacent quantize and dequantize ops // and the propagation results are materialized by inserting pairs of quantize // and dequantize ops to this function. -void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed); +void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed, + OpQuantSpecGetter op_quant_spec_getter); } // end namespace TFL } // end namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_QUANTIZATION_UTILS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_QUANTIZATION_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc similarity index 90% rename from tensorflow/compiler/mlir/lite/tools/op_quant_spec_getters_gen.cc rename to tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc index 9be4a0bf9d7..b381a5fa898 100644 --- a/tensorflow/compiler/mlir/lite/tools/op_quant_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc @@ -38,9 +38,10 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { llvm::Regex acc_uniform_trait_regex{"AccumulatorUniformScale<([0-9]*),"}; llvm::Regex fixed_uniform_trait_regex{ "FixedResultUniformScale<([0-9]+).*(true|false)>"}; - emitSourceFileHeader("TensorFlow Lite Ops Quant Spec Getters", os); + emitSourceFileHeader("Generated Ops Quant Spec Getters", os); - // Retrieve all the definitions derived from TFL_Op and sort by record name. + // Retrieve all the definitions derived from Op defintion and sort by record + // name. std::vector defs = records.getAllDerivedDefinitions("Op"); llvm::sort(defs, LessRecord()); @@ -53,9 +54,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { for (const auto t : op.getTraits()) { if (auto opTrait = llvm::dyn_cast(&t)) { auto trait = opTrait->getTrait(); - // We only handle TFL specific native op traits. 
- if (!trait.startswith("TFL::")) continue; - trait.consume_front("TFL::"); + if (!trait.consume_front("OpTrait::quant::")) continue; OUT(2) << "if (auto tfl = llvm::dyn_cast<" << op.getQualCppClassName() << ">(op)) {\n"; @@ -74,7 +73,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { OUT(4) << "for (int i = 0, e = op->getNumResults(); i != e; ++i)\n"; OUT(6) << "spec->restricted_output_params[std::make_pair(" << matches[1] << ", " << matches[2] - << ")].push_back(tfl.OpTrait::TFL::" << trait << "<" + << ")].push_back(tfl.OpTrait::quant::" << trait << "<" << op.getQualCppClassName() << ">::GetResultQuantizedType(i));\n"; matches.clear(); @@ -98,7 +97,6 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { } int main(int argc, char **argv) { - llvm::PrettyStackTraceProgram X(argc, argv); llvm::InitLLVM y(argc, argv); llvm::cl::ParseCommandLineOptions(argc, argv); return TableGenMain(argv[0], &OpQuantSpecWriter); diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index da779c14ea8..68a9fb7bc3e 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -test-constant-fold | FileCheck %s +// RUN: tf-opt %s -test-constant-fold | FileCheck %s --dump-input-on-failure // CHECK-LABEL: @add_float func @add_float() -> (tensor, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) { @@ -109,6 +109,36 @@ func @mul_float() -> (tensor, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) return %5, %6, %7, %8 : tensor, tensor<4xf32>, tensor<4xf32>, tensor<4xf32> } +// CHECK-LABEL: @elementwise_unary_ops +func @elementwise_unary_ops() -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor) { + %0 = constant dense<-1.0> : tensor + %1 = constant dense<1.0> : tensor + %2 = constant dense<1.0> : tensor + %3 = constant dense<1.0> : tensor + %4 = constant dense<4.0> : tensor + %5 = constant dense<4.0> : tensor + %6 = constant dense<2.0> : tensor + + // CHECK-DAG: [[cst0:%.*]] = constant dense<1.000000e+00> : tensor + // CHECK-DAG: [[cst1:%.*]] = constant dense<0.841470957> : tensor + // CHECK-DAG: [[cst2:%.*]] = constant dense<0.540302277> : tensor + // CHECK-DAG: [[cst3:%.*]] = constant dense<0.000000e+00> : tensor + // CHECK-DAG: [[cst4:%.*]] = constant dense<2.000000e+00> : tensor + // CHECK-DAG: [[cst5:%.*]] = constant dense<5.000000e-01> : tensor + // CHECK-DAG: [[cst6:%.*]] = constant dense<4.000000e+00> : tensor + // CHECK: return [[cst0]], [[cst1]], [[cst2]], [[cst3]], [[cst4]], [[cst5]], [[cst6]] + + %7 = "tfl.abs"(%0) : (tensor) -> tensor + %8 = "tfl.sin"(%1) : (tensor) -> tensor + %9 = "tfl.cos"(%2) : (tensor) -> tensor + %10 = "tfl.log"(%3) : (tensor) -> tensor + %11 = "tfl.sqrt"(%4) : (tensor) -> tensor + %12 = "tfl.rsqrt"(%5) : (tensor) -> tensor + %13 = "tfl.square"(%6) : (tensor) -> tensor + + return %7, %8, %9, %10, %11, %12, %13 : tensor, tensor, tensor, tensor, tensor, tensor, tensor +} + // CHECK-LABEL: @mul_int func @mul_int() -> (tensor, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) { %0 = constant dense<8> : tensor @@ -273,3 +303,179 @@ func @add_dense_dense_float_mixfng_1_n() -> tensor<2x2xf32> { // CHECK: %0 = "tfl.add" // CHECK: return %0 } + +// CHECK-LABEL: @rank +func @rank() -> tensor<1xi32> { + %cst = constant dense<[[1], [2]]> : tensor<2x1xi32> + + // CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32> + // CHECK: return [[cst]] + %0 = "tfl.rank"(%cst) : 
(tensor<2x1xi32>) -> tensor<1xi32> + return %0 : tensor<1xi32> +} + +// CHECK-LABEL: @rank_input_known_rank +func @rank_input_known_rank(%arg0 : tensor<2x1xi32>) -> tensor<1xi32> { + // CHECK: [[cst:%.*]] = constant dense<2> : tensor<1xi32> + // CHECK: return [[cst]] + %0 = "tfl.rank"(%arg0) : (tensor<2x1xi32>) -> tensor<1xi32> + return %0 : tensor<1xi32> +} + +// CHECK-LABEL: @reshape +func @reshape() -> tensor<1x2xi32> { + %cst = constant dense<[1, 2]> : tensor<2xi32> + + // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}1, 2]]> : tensor<1x2xi32> + // CHECK: return [[cst]] + %0 = "tfl.reshape"(%cst) : (tensor<2xi32>) -> tensor<1x2xi32> + return %0 : tensor<1x2xi32> +} +// CHECK-LABEL: @pseudo_const +func @pseudo_const() -> tensor { + // CHECK: [[cst:%.*]] = constant dense<1> : tensor + // CHECK: return [[cst]] + %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + return %0 : tensor +} + + +// CHECK-LABEL: @range_int +func @range_int() -> tensor { + %cst = constant dense<0> : tensor + %cst_1 = constant dense<4> : tensor + %cst_2 = constant dense<1> : tensor + + // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor + // CHECK: return [[cst]] + %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @range_float +func @range_float() -> tensor { + %cst = constant dense<0.0> : tensor + %cst_1 = constant dense<4.0> : tensor + %cst_2 = constant dense<1.0> : tensor + + // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>} : () -> tensor + // CHECK: return [[cst]] + %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + + +// CHECK-LABEL: @range_float_neg_delta +func @range_float_neg_delta() -> tensor { + %cst = constant dense<0.0> : tensor + %cst_1 = constant dense<-4.0> : tensor + %cst_2 = constant dense<-1.0> : tensor + + // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[0.000000e+00, -1.000000e+00, -2.000000e+00, -3.000000e+00]> : tensor<4xf32>} : () -> tensor + // CHECK: return [[cst]] + %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @range_float_nonzero_base +func @range_float_nonzero_base() -> tensor { + %cst = constant dense<2.0> : tensor + %cst_1 = constant dense<7.0> : tensor + %cst_2 = constant dense<1.5> : tensor + + // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[2.000000e+00, 3.500000e+00, 5.000000e+00, 6.500000e+00]> : tensor<4xf32>} : () -> tensor + // CHECK: return [[cst]] + %0 = "tfl.range"(%cst, %cst_1, %cst_2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @transpose_no_fold +func @transpose_no_fold(%arg0 : tensor<2xi32>) -> tensor<2x2xi32> { + %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> + + // CHECK: tfl.transpose + %0 = "tfl.transpose"(%cst, %arg0) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// CHECK-LABEL: @transpose_1d +// Basic 1D identity +func @transpose_1d() -> tensor<3xi32> { + %cst = constant dense<[1, 2, 3]> : tensor<3xi32> + %cst_perm = constant dense<0> : tensor<1xi32> + + // CHECK: [[cst:%.*]] = constant dense<{{\[}}1, 2, 3]> : tensor<3xi32> + // CHECK: return [[cst]] + %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +} + +// CHECK-LABEL: @transpose_dynamic 
+func @transpose_dynamic() -> tensor { + %cst = constant dense<[1, 2, 3]> : tensor<3xi32> + %cst_perm = constant dense<0> : tensor<1xi32> + + // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<{{\[}}1, 2, 3]> : tensor<3xi32>} : () -> tensor + // CHECK: return [[cst]] + %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<3xi32>, tensor<1xi32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: @transpose_2d +func @transpose_2d() -> tensor<2x2xi32> { + %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> + %cst_perm = constant dense<[1, 0]> : tensor<2xi32> + + // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 2], {{\[}}1, 3]]> : tensor<2x2xi32> + // CHECK: return [[cst]] + %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// CHECK-LABEL: @transpose_2d_identity +func @transpose_2d_identity() -> tensor<2x2xi32> { + %cst = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> + %cst_perm = constant dense<[0, 1]> : tensor<2xi32> + + // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}0, 1], {{\[}}2, 3]]> : tensor<2x2xi32> + // CHECK: return [[cst]] + %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// CHECK-LABEL: @transpose_3d +// A test case adopted from TransposeTest.Test3DInputConstTensor in +// tensorflow/lite/kernels/transpose_test.cc +func @transpose_3d() -> tensor<4x2x3xi32> { + %cst = constant dense<[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]> : tensor<2x3x4xi32> + %cst_perm = constant dense<[2, 0, 1]> : tensor<3xi32> + + // CHECK: [[cst:%.*]] = constant dense<{{\[\[\[}}0, 4, 8], {{\[}}12, 16, 20]], {{\[\[}}1, 5, 9], {{\[}}13, 17, 21]], {{\[\[}}2, 6, 10], {{\[}}14, 18, 22]], {{\[\[}}3, 7, 11], {{\[}}15, 19, 23]]]> : tensor<4x2x3xi32> + // CHECK: return [[cst]] + %0 = "tfl.transpose"(%cst, %cst_perm) : (tensor<2x3x4xi32>, tensor<3xi32>) -> tensor<4x2x3xi32> + return %0 : tensor<4x2x3xi32> +} + +// CHECK-LABEL: @ConstantFoldBinaryOpDynamicOutput +func @ConstantFoldBinaryOpDynamicOutput() -> tensor { + %cst = constant dense<10> : tensor + %cst_0 = "tfl.pseudo_const"() {value = dense<[5, 10]> : tensor<2xi32>} : () -> tensor + %87 = "tfl.sub"(%cst_0, %cst) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor + return %87 : tensor + + // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[-5, 0]> : tensor<2xi32>} : () -> tensor + // CHECK: return [[cst]] +} + +// CHECK-LABEL: @add_dense_dense_int_same_shape_dynamic +func @add_dense_dense_int_same_shape_dynamic() -> tensor { + %0 = constant dense<[15, 23, -44, -2]> : tensor<4xi32> + %1 = constant dense<[-10, -1, 42, 100]> : tensor<4xi32> + + %2 = "tfl.add"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor + + return %2 : tensor + + // CHECK: [[cst:%.*]] = "tfl.pseudo_const"() {value = dense<[5, 22, -2, 98]> : tensor<4xi32>} : () -> tensor + // CHECK: return [[cst]] +} diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt index 1bf0b075baf..c1bb797ebee 100644 --- a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.line.part.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf_tfl_translate -mlir-pretty-debuginfo 
-tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0 +# RUN: tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0 # CHECK: fake/user/code/file_C.py:27:1: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format diff --git a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt index edad75c4fc2..d3dcbc65719 100644 --- a/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/debuginfo/v1_1.0_224_frozen.wrong_attr.stack.part.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0 +# RUN: tf_tfl_translate -mlir-pretty-debuginfo -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes=1,224,224,3 -tf-output-arrays=MobilenetV1/MobilenetV1/Conv2d_0/BatchNorm/FusedBatchNorm -tf-debug-info=%s.debug %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0 # CHECK: fake/user/code/file_C.py:27:1: error: 'tf.Conv2D' op attribute 'data_format' failed to satisfy constraint: 'NHWC' or 'NCHW' convnet data format # CHECK: fake/user/code/file_D.py:28:1: note: called from diff --git a/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir b/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir new file mode 100644 index 00000000000..5cbcb1e1cb8 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/extract-ophint.mlir @@ -0,0 +1,155 @@ +// RUN: tf-opt -tfl-extract-ophint %s -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: extractSimpleOphint +func @extractSimpleOphint() { +// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @d4b1eb00b81211e99426dc4a3e957995(%0) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> +// CHECK: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + + %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32> + %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation", _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + %3 = 
"tf.Mul"(%2, %1) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + return +} + +// CHECK-LABEL: extractPackedInputOphint +func @extractPackedInputOphint() { +// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32> +// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @47393154b9af11e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_stack", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "47393154b9af11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack-47393154b9af11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + + %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> + %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "47393154b9af11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack-47393154b9af11e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %3 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> + %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack", _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "47393154b9af11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack-47393154b9af11e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %5 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_1"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %6 = "tf.Mul"(%2, %5) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %7 = "tf.Identity"(%6) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_stack", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "47393154b9af11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack-47393154b9af11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + return +} + +// CHECK-LABEL: extractFirstInputOphint +func @extractFirstInputOphint() { +// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b703f0f4b9ec11e99426dc4a3e957995(%0) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = 
"cool_activation_first", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "b703f0f4b9ec11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_first-b703f0f4b9ec11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + + %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> + %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "first", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_first", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b703f0f4b9ec11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_first-b703f0f4b9ec11e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %3 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> + %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "first", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_first", _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b703f0f4b9ec11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_first-b703f0f4b9ec11e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %5 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_1"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %6 = "tf.Mul"(%2, %5) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %7 = "tf.Identity"(%6) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_first", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "b703f0f4b9ec11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_first-b703f0f4b9ec11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + return +} + +// CHECK-LABEL: extractLastInputOphint +func @extractLastInputOphint() { +// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @e31fcf90b9ed11e99426dc4a3e957995(%1) : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_last", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "e31fcf90b9ed11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_last-e31fcf90b9ed11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + + %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> + %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "last", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_last", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "e31fcf90b9ed11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_last-e31fcf90b9ed11e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> 
tensor<1x16x1xf32> + %3 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> + %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "last", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_last", _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "e31fcf90b9ed11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_last-e31fcf90b9ed11e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %5 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_1"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %6 = "tf.Mul"(%2, %5) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %7 = "tf.Identity"(%6) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_last", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "e31fcf90b9ed11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_last-e31fcf90b9ed11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + return +} + +// CHECK-LABEL: extractPackOneInputOphint +func @extractPackOneInputOphint() { +// CHECK: %[[RESHAPE:[0-9]*]] = "tfl.reshape"(%0) : (tensor<1x16x1xf32>) -> tensor<1x1x16x1xf32> +// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @33fab028b9ef11e99426dc4a3e957995(%[[RESHAPE]]) : (tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: %[[OUTPUT:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_pack_input_one", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "33fab028b9ef11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_pack_input_one-33fab028b9ef11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + + %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> + %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_pack_input_one", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "33fab028b9ef11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_pack_input_one-33fab028b9ef11e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %3 = "tf.Identity"(%2) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_pack_input_one", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "33fab028b9ef11e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_pack_input_one-33fab028b9ef11e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + return +} + +// CHECK-LABEL: extractStackInputOutputOphint +func @extractStackInputOutputOphint() { +// CHECK: %[[PACK:[0-9]*]] = "tfl.pack"(%0, %1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<2x1x16x1xf32> +// CHECK: %[[OP_HINT_CALL:[0-9]*]] = call @b92ed354b9f011e99426dc4a3e957995(%[[PACK]]) : (tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32> +// CHECK: %[[UNPACK:[0-9]*]]:2 = 
"tfl.unpack"(%[[OP_HINT_CALL]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) +// CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[UNPACK]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + + %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> + %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %3 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> + %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %5 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_1"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %6 = "tf.Mul"(%2, %5) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %7 = "tf.Identity"(%6) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 0 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-0-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %8 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_2"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %9 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_3"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %10 = "tf.Add"(%8, %9) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<1x16x1xf32>, 
tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %11 = "tf.Identity"(%10) {T = "tfdtype$DT_FLOAT", _tflite_function_aggregate = "stack", _tflite_function_name = "cool_activation_stack_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_sort_index = 1 : i64, _tflite_function_uuid = "b92ed354b9f011e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_stack_input_output-b92ed354b9f011e99426dc4a3e957995-0-1-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + return +} + +// CHECK-LABEL: extractMultipleInputsOutputsOphint +func @extractMultipleInputsOutputsOphint() { +// CHECK: %[[OP_HINT_CALL:[0-9]*]]:2 = call @a6ca45beb9f411e99426dc4a3e957995(%0, %1) : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) +// CHECK: %[[OUTPUT1:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#0) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: %[[OUTPUT2:[0-9]*]] = "tf.Identity"(%[[OP_HINT_CALL]]#1) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-1-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + + %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> + %1 = "tf.Identity"(%0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %3 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder_1", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x1xf32> + %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 1 : i64, _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-1-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %5 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_1"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %6 = "tf.Mul"(%2, %5) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %7 = "tf.Identity"(%6) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-0-None-None"} : (tensor<1x16x1xf32>) -> 
tensor<1x16x1xf32> + %8 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_2"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %9 = "tf.Sigmoid"(%4) {T = "tfdtype$DT_FLOAT", name = "Sigmoid_3"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %10 = "tf.Add"(%8, %9) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + %11 = "tf.Identity"(%10) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation_multiple_input_output", _tflite_function_output_index = 1 : i64, _tflite_function_uuid = "a6ca45beb9f411e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation_multiple_input_output-a6ca45beb9f411e99426dc4a3e957995-1-None-None"} : (tensor<1x16x1xf32>) -> tensor<1x16x1xf32> + return +} + +// CHECK: func @d4b1eb00b81211e99426dc4a3e957995(tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> +// CHECK: attributes {_tflite_function_name = "cool_activation"} +// CHECK: func @47393154b9af11e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: attributes {_tflite_function_name = "cool_activation_stack"} +// CHECK: func @b703f0f4b9ec11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: attributes {_tflite_function_name = "cool_activation_first"} +// CHECK: func @e31fcf90b9ed11e99426dc4a3e957995(tensor<1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: attributes {_tflite_function_name = "cool_activation_last"} +// CHECK: func @33fab028b9ef11e99426dc4a3e957995(tensor<1x1x16x1xf32>) -> tensor<1x16x1xf32> +// CHECK: attributes {_tflite_function_name = "cool_activation_pack_input_one"} +// CHECK: func @b92ed354b9f011e99426dc4a3e957995(tensor<2x1x16x1xf32>) -> tensor<2x1x16x1xf32> +// CHECK: attributes {_tflite_function_name = "cool_activation_stack_input_output"} +// CHECK: func @a6ca45beb9f411e99426dc4a3e957995(tensor<1x16x1xf32>, tensor<1x16x1xf32>) -> (tensor<1x16x1xf32>, tensor<1x16x1xf32>) +// CHECK: attributes {_tflite_function_name = "cool_activation_multiple_input_output"} + + +// ----- + +// expected-error@+1 {{Found malformed ophint regions: missing inputs or outputs.}} +module { +func @extractOphintFailure() { + %0 = "tf.Placeholder"() {dtype = "tfdtype$DT_FLOAT", name = "Placeholder", shape = "tfshape$dim { size: 1 } dim { size: 16 } dim { size: 16 } dim { size: 1 }"} : () -> tensor<1x16x16x1xf32> + %1 = call @AnotherFunc(%0) : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + %2 = "tf.Sigmoid"(%1) {T = "tfdtype$DT_FLOAT", name = "Sigmoid"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + %3 = "tf.Mul"(%2, %1) {T = "tfdtype$DT_FLOAT", name = "mul"} : (tensor<1x16x16x1xf32>, tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + %4 = "tf.Identity"(%3) {T = "tfdtype$DT_FLOAT", _tflite_function_name = "cool_activation", _tflite_function_output_index = 0 : i64, _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "OutputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + return +} + +func @AnotherFunc(%arg0: tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> { + %0 = "tf.Identity"(%arg0) {T = "tfdtype$DT_FLOAT", _tflite_function_input_index = 0 : i64, _tflite_function_name = "cool_activation", _tflite_function_uuid = "d4b1eb00b81211e99426dc4a3e957995", _tflite_ophint_level = 1 : i64, name = "InputHint-cool_activation-d4b1eb00b81211e99426dc4a3e957995-0-None-None"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + return %0 : tensor<1x16x16x1xf32> +} +} diff 
--git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir new file mode 100644 index 00000000000..b6231c050b5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir @@ -0,0 +1,89 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s +// Ensure constants roundtrip exactly + +func @bool() -> tensor<4xi1> { + // CHECK-LABEL: @bool + // CHECK: value = dense<[false, true, true, false]> : tensor<4xi1> + %0 = "tfl.pseudo_const"() { value = dense<[false, true, true, false]> : tensor<4xi1> } : () -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +func @complex64() -> tensor<4x!tf.complex64> { + // CHECK-LABEL: @complex64 + // CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3230303F5C3030305C3030305C303030405C3030305C3030305C303030405C3030305C30303040405C3030305C30303040405C3030305C3030305C323030405C3030305C3030305C3230304022"> : tensor<4x!tf.complex64> + %0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C455836342074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3230303F5C3030305C3030305C3230303F5C3030305C3030305C303030405C3030305C3030305C303030405C3030305C30303040405C3030305C30303040405C3030305C3030305C323030405C3030305C3030305C3230304022"> : tensor<4x!tf.complex64> } : () -> tensor<4x!tf.complex64> + return %0 : tensor<4x!tf.complex64> +} + +// TODO(b/138847107) this should work but doesn't +// func @f16() -> tensor<4xf16> { +// %0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf16> } : () -> tensor<4xf16> +// return %0 : tensor<4xf16> +// } + +func @f32() -> tensor<4xf32> { + // CHECK-LABEL: @f32 + // CHECK: value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf32> + %0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> } : () -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +func @i8() -> tensor<4xi8> { + // CHECK-LABEL: @i8 + // CHECK: value = dense<[1, 2, 3, 4]> : tensor<4xi8> + %0 = "tfl.pseudo_const" () { value = dense<[1, 2, 3, 4]> : tensor<4xi8> } : () -> tensor<4xi8> + return %0 : tensor<4xi8> +} + +func @i16() -> tensor<4xi16> { + // CHECK-LABEL: @i16 + // CHECK: value = dense<[1, 2, 3, 258]> : tensor<4xi16> + %0 = "tfl.pseudo_const" () { value = dense<[1, 2, 3, 258]> : tensor<4xi16> } : () -> tensor<4xi16> + return %0 : tensor<4xi16> +} + +func @i32() -> tensor<4xi32> { + // CHECK-LABEL: @i32 + // CHECK: value = dense<[1, 2, 3, 16909060]> : tensor<4xi32> + // Check bytes come back in the right order + %0 = "tfl.pseudo_const" () { value = dense<[1, 2, 3, 16909060]> : tensor<4xi32> } : () -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +func @i64() -> tensor<4xi64> { + // CHECK-LABEL: @i64 + // CHECK: value = dense<[1, 2, 3, 72623859790382856]> : tensor<4xi64> + %0 = "tfl.pseudo_const" () { value = dense<[1, 2, 3, 72623859790382856]> : tensor<4xi64> } : () -> tensor<4xi64> + return %0 : tensor<4xi64> +} + +// TODO(krzysd) Add a test for strings. 
This isn't too urgent, since they use +// the same sort of opaque round-trip we get for complex64, but it might be good +// to check + +func @uint8() -> tensor<4x!tf.uint8> { + // CHECK-LABEL: @uint8 + // CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8> + %0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F55494E54382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3333365C3235355C3237365C33353722"> : tensor<4x!tf.uint8> } : () -> tensor<4x!tf.uint8> + return %0 : tensor<4x!tf.uint8> +} + +func @qi32_per_axis() -> tensor<3x3x!quant.uniform> { + // CHECK-LABEL: @qi32_per_axis + // CHECK: {qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform> + %0 = "tfl.pseudo_qconst"() { qtype = tensor<3x3x!quant.uniform>, value = dense<1> : tensor<3x3xi32>} : () -> tensor<3x3x!quant.uniform> + return %0 : tensor<3x3x!quant.uniform> +} + +func @qu8() -> tensor<3x!quant.uniform:f32, 1.0>> { + // CHECK-LABEL: @qu8 + // CHECK: {qtype = tensor<3x!quant.uniform:f32, 1.000000e+00>>, value = dense<1> : tensor<3xi8>} : () -> tensor<3x!quant.uniform:f32, 1.000000e+00>> + %0 = "tfl.pseudo_qconst"() { qtype = tensor<3x!quant.uniform:f32, 1.0>>, value = dense<1> : tensor<3xi8>} : () -> tensor<3x!quant.uniform:f32, 1.0>> + return %0 : tensor<3x!quant.uniform:f32, 1.0>> +} + +// Identity function to make the exporter happy +func @main(%arg0: tensor<4xi8>) -> tensor<4xi8> { + %0 = "tfl.pseudo_input"(%arg0) : (tensor<4xi8>) -> tensor<4xi8> + return %0 : tensor<4xi8> +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir new file mode 100644 index 00000000000..3f3cad12b61 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir @@ -0,0 +1,20 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s +// Confirm function references in if ops are preserved +func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { +// CHECK: %{{.*}} = "tf.If"(%{{.*}}, %{{.*}}, %{{.*}}) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %0 = "tfl.pseudo_input"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + %1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32> + %2 = "tfl.less"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> + %3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + return %3 : tensor<1xf32> +} + +func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> + return %0 : tensor<*xf32> +} + +func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> + return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/many_attribute_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/many_attribute_op.mlir new file 
mode 100644 index 00000000000..4cfa8e39969 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/many_attribute_op.mlir @@ -0,0 +1,10 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s + +// Confirm a wide array of attributes survives the round-trip +func @main(tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> { +^bb0(%arg0: tensor<1x6x6x16xf32>): + // CHECK: "tfl.average_pool_2d"(%{{.*}}) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> + %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> loc("Input") + %1 = "tfl.average_pool_2d"(%0) {filter_height = 3 : i32, filter_width = 6 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 3 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16xf32>) -> tensor<1x1x1x16xf32> loc("avgpool") + return %1 : tensor<1x1x1x16xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/math.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/math.mlir index a92e985c668..c9528aed3e2 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/math.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/math.mlir @@ -1,13 +1,16 @@ // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s +// Confirm float constants and operators survive a roundtrip func @main(tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>): - // CHECK: func @main(%arg0: tensor<4xf32>) - // CHECK-NEXT: return - // CHECK-NEXT: } - + // CHECK: [[INPUT:%.*]] = "tfl.pseudo_input"(%arg0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: [[CONST:%.*]] = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<4xf32>} : () -> tensor<4xf32> + // CHECK-NEXT: [[SQDIFF:%.*]] = tfl.squared_difference [[INPUT]], [[CONST]] : tensor<4xf32> + // CHECK-NEXT: %{{.*}} = tfl.mul [[INPUT]], [[SQDIFF]] {fused_activation_function = "NONE"} : tensor<4xf32> %0 = "tfl.pseudo_input" (%arg0) : (tensor<4xf32>) -> tensor<4xf32> loc("Input") %1 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + // Confirm that attributes that cannot be stored in the flatbuffer options + // for a given operator are dropped silently.
%2 = "tfl.squared_difference"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference") %3 = "tfl.mul"(%0, %2) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul") %4 = "tfl.div"(%3, %2) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div") diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional.mlir new file mode 100644 index 00000000000..ce62aa381f1 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/optional.mlir @@ -0,0 +1,13 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s +// Test to make sure optional parameters survive a roundtrip + +func @main(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { +// CHECK: [[NONE:%.*]] = constant unit +// CHECK: "tfl.fully_connected"(%{{.()}}, %{{.*}}, [[NONE]]) +// CHECK-SAME: (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>, tensor<40x40xf32>) + %cst = constant unit + %0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input") + %1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input") + %2:2 = "tfl.fully_connected"(%0, %1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>, tensor<40x40xf32>) + return %2 : tensor<40x40xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/quantization.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/quantization.mlir new file mode 100644 index 00000000000..18e2888dfcd --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/quantization.mlir @@ -0,0 +1,19 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s + +func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> { +// CHECK: %{{.*}} = "tfl.quantize"(%{{.*}}) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> +// The float values here doesn't match exactly because double -> float -> double is lossy +// CHECK-NEXT: %{{.*}} = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678{{[0-9]*}}:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678{{[0-9]*}}:151>> +// CHECK-NEXT: %{{.*}} = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> +// CHECK: %{{.*}} = "tfl.dequantize"(%{{.*}}) : (tensor<1x1001x!quant.uniform>) -> tensor<1x1001xf32> + + %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32> + %1 = "tfl.quantize"(%0) {qtype = tensor<1x224x224x3x!quant.uniform>} : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3x!quant.uniform> + %2 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>> + %3 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<0> : tensor<32xi32>} : () -> tensor<32x!quant.uniform> + %4 = 
"tfl.conv_2d"(%1, %2, %3) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3x!quant.uniform>, tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, tensor<32x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> + %5 = "tfl.reshape"(%4) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x1001x!quant.uniform> + %6 = "tfl.softmax"(%5) {beta = 1.000000e+00 : f32} : (tensor<1x1001x!quant.uniform>) -> tensor<1x1001x!quant.uniform> + %7 = "tfl.dequantize"(%6) : (tensor<1x1001x!quant.uniform>) -> tensor<1x1001xf32> + return %7 : tensor<1x1001xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/reshape.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/reshape.mlir new file mode 100644 index 00000000000..85596169508 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/reshape.mlir @@ -0,0 +1,9 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s +// Confirm we can extract type info from reshape + +func @main() -> tensor<2x2xf32> { + // CHECK: %{{.*}} = "tfl.reshape"(%{{.*}}) : (tensor<4xf32>) -> tensor<2x2xf32> + %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %1 = "tfl.reshape" (%0) : (tensor<4xf32>) -> tensor<2x2xf32> loc("reshape") + return %1 : tensor<2x2xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir index 600c7a02ed5..714027d67d1 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/simple.mlir @@ -1,10 +1,17 @@ // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s +// Check a few basic properties of the import-export, +// including constants retaining their shape +// and the module including the TFLite version. 
func @main(tensor<3x2xi32>) -> tensor<3x2xi32> { ^bb0(%arg0: tensor<3x2xi32>): - // CHECK: func @main(%arg0: tensor<3x2xi32>) { - // CHECK-NEXT: return - // CHECK-NEXT: } + // CHECK: module attributes {tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} + + // CHECK: %{{.*}} = "tfl.pseudo_const"() {value = dense<{{\[\[1, 2\], \[3, 4\], \[5, 6\]\]}}> : tensor<3x2xi32>} + // CHECK-NEXT: [[SUB:%.*]] = tfl.sub %{{.*}}, %{{.*}} {fused_activation_function = "RELU6"} : tensor<3x2xi32> + // CHECK-NEXT: [[SCALAR:%.*]] = "tfl.pseudo_const"() {value = dense<10> : tensor} : () -> tensor + // CHECK-NEXT: [[ADD:%.*]] = "tfl.add"([[SCALAR]], [[SUB]]) {fused_activation_function = "NONE"} : (tensor, tensor<3x2xi32>) -> tensor<3x2xi32> + // CHECK-NEXT: return [[ADD]] : tensor<3x2xi32> %0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input") %1 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> loc("Const") diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/while_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/while_op.mlir new file mode 100644 index 00000000000..141423f9231 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/while_op.mlir @@ -0,0 +1,27 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s +// Check to see if function references in while loops are preserved +func @main(%arg0: tensor, %arg1: tensor<1xf32>) -> tensor<1xf32> { +// TODO(b/138222071) Expect first output to be a scalar +// CHECK: %{{.*}}:2 = "tf.While"(%{{.*}}, %{{.*}}) {body = @body, cond = @cond, is_stateless = false} : (tensor, tensor<1xf32>) -> (tensor<*xi32>, tensor<1xf32>) + %0 = "tfl.pseudo_input"(%arg0) : (tensor) -> tensor + %1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32> + + // While %0 is greater than zero, element-wise add %1 to itself. 
+ %2:2 = "tf.While"(%0, %1) { + cond = @cond, body = @body, is_stateless = false + } : (tensor, tensor<1xf32>) -> (tensor, tensor<1xf32>) + return %2#1 : tensor<1xf32> +} + +func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor { + %0 = "std.constant" () {value = dense<0> : tensor} : () -> tensor loc("Const") + %1 = "tfl.greater"(%arg0, %0) : (tensor<*xi32>, tensor) -> tensor + return %1 : tensor +} + +func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) { + %0 = "std.constant" () {value = dense<1> : tensor} : () -> tensor loc("Const") + %1 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %2 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> + return %1, %2 : tensor<*xi32>, tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-ophint-func-op.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-ophint-func-op.mlir new file mode 100644 index 00000000000..06f304c55ba --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/legalize-ophint-func-op.mlir @@ -0,0 +1,26 @@ +// RUN: tf-opt -tfl-legalize-ophint-func-op %s | FileCheck %s + +module { + // CHECK-LABEL: func @testConvertUnidirectionalSequenceRNN + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<1x3xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<1x3xf32>) + func @testConvertUnidirectionalSequenceRNN(%arg0: tensor<1x3xf32>, %arg1: tensor<1x3xf32>) -> tensor<1x4xf32> { + // CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<1x4xf32> + // CHECK: %[[CST_0:.*]] = constant dense<0.000000e+00> : tensor<4xf32> + // CHECK: %[[CST_1:.*]] = constant dense<0.000000e+00> : tensor<4x3xf32> + // CHECK: %[[CST_2:.*]] = constant dense<0.000000e+00> : tensor<4x4xf32> + // CHECK: %[[PACKED_INPUT:[a-z0-9]*]] = "tfl.pack"(%[[ARG_0]], %[[ARG_1]]) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x1x3xf32> + // CHECK: %[[FUSED_OUTPUT:[a-z0-9]*]] = "tfl.unidirectional_sequence_rnn"(%[[PACKED_INPUT]], %[[CST_1]], %[[CST_2]], %[[CST_0]], %[[CST]]) {fused_activation_function = "TANH", time_major = true} : (tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32> + // CHECK: %[[UNPACK:[0-9]*]]:2 = "tfl.unpack"(%[[FUSED_OUTPUT]]) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xf32>) + + %cst = constant dense<0.000000e+00> : tensor<1x4xf32> + %cst0 = constant dense<0.000000e+00> : tensor<4xf32> + %cst1 = constant dense<0.000000e+00> : tensor<4x3xf32> + %cst2 = constant dense<0.000000e+00> : tensor<4x4xf32> + %2 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<2x1x3xf32> + %3 = call @a9211722c23011e9875cdc4a3e957995(%2, %cst1, %cst2, %cst0, %cst) : (tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32> + %4:2 = "tfl.unpack"(%3) {axis = 0 : i32, num = 2 : i32} : (tensor<2x1x4xf32>) -> (tensor<1x4xf32>, tensor<1x4xf32>) + return %4#0 : tensor<1x4xf32> + } + func @a9211722c23011e9875cdc4a3e957995(tensor<2x1x3xf32>, tensor<4x3xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>) -> tensor<2x1x4xf32> + attributes {_tflite_function_name = "UnidirectionalSequenceRnn"} +} diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 616922ba8d3..9c029bfc1d1 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ 
b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -142,7 +142,7 @@ func @const() -> tensor<2xi32> { return %0: tensor<2xi32> // CHECK-LABEL: @const -// CHECK: %0 = "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32> +// CHECK: "tfl.pseudo_const"() {value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> tensor<2xi32> } func @placeholder(%arg0: tensor) -> tensor { @@ -213,6 +213,20 @@ func @sigmoid(%arg0: tensor) -> tensor { // CHECK: %0 = "tfl.logistic"(%arg0) : (tensor) -> tensor } +func @sqrt(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + %0 = "tf.Sqrt"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +// CHECK-LABEL: sqrt +// CHECK: %0 = "tfl.sqrt"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> +} + +func @square(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + %0 = "tf.Square"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +// CHECK-LABEL: square +// CHECK: %0 = "tfl.square"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> +} + func @log_softmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.LogSoftmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> @@ -289,6 +303,14 @@ func @abs(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { // CHECK: %0 = "tfl.abs"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> } +func @any(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { + %0 = "tf.Any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor + return %0 : tensor + +// CHECK-LABEL:any +// CHECK: %0 = "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor +} + func @ceil(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Ceil"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> @@ -442,12 +464,12 @@ func @less_equal(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x1 // CHECK: return %0 : tensor<8x16xi1> } -func @rank(%arg0: tensor<11x16xf32>) -> tensor<1xi32> { - %0 = "tf.Rank"(%arg0) : (tensor<11x16xf32>) -> tensor<1xi32> +func @rank(%arg0: tensor<*xf32>) -> tensor<1xi32> { + %0 = "tf.Rank"(%arg0) : (tensor<*xf32>) -> tensor<1xi32> return %0 : tensor<1xi32> // CHECK-LABEL:rank -// CHECK: %0 = "tfl.rank"(%arg0) : (tensor<11x16xf32>) -> tensor<1xi32> +// CHECK: %0 = "tfl.rank"(%arg0) : (tensor<*xf32>) -> tensor<1xi32> } func @floor(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { @@ -487,6 +509,15 @@ func @select(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> // CHECK: return %0 : tensor<8xf32> } +func @select_v2(%arg0: tensor<8xi1>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>) -> tensor<8xf32> { + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<8xi1>, tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32> + return %0: tensor<8xf32> + +// CHECK-LABEL: select_v2 +// CHECK: %0 = "tfl.select"(%arg0, %arg1, %arg2) +// CHECK: return %0 : tensor<8xf32> +} + func @sin(%arg0: tensor) -> tensor { %0 = "tf.Sin"(%arg0) : (tensor) -> tensor return %0 : tensor @@ -629,6 +660,17 @@ func @pad(tensor<2x1x3xf32>, 
tensor<3x2xi32>) -> tensor { // CHECK: return %0 : tensor } +func @tile(tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> { +^bb0(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>): + %cst = constant dense<[1, 2]> : tensor<2xi32> + %0 = "tf.Tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> + return %0 : tensor<2x6xf32> + + // CHECK-LABEL: tile + // CHECK: %0 = "tfl.tile"(%arg0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<2x6xf32> + // CHECK: return %0 : tensor<2x6xf32> +} + func @padv2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor { ^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>): %cst = constant dense<2.0> : tensor @@ -782,12 +824,12 @@ func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: // CHECK: %0 = "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor } -func @split(%arg0: tensor<1xi32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32> { - %0:3 = "tf.Split"(%arg0, %arg1) {num_split = 3 : i64} : (tensor<1xi32>, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>) +func @split(%arg0: tensor, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3xf32> { + %0:3 = "tf.Split"(%arg0, %arg1) {num_split = 3 : i64} : (tensor, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>) return %0#0 : tensor<1x4x3xf32> // CHECK-LABEL: split - // CHECK: %0:3 = "tfl.split"(%arg0, %arg1) {num_splits = 3 : i32} : (tensor<1xi32>, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>) + // CHECK: %0:3 = "tfl.split"(%arg0, %arg1) {num_splits = 3 : i32} : (tensor, tensor<1x4x3x3xf32>) -> (tensor<1x4x3xf32>, tensor<1x4x3xf32>, tensor<1x4x3xf32>) } func @splitv(%arg0: tensor<1x4x3x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<1xi32>) -> tensor<1x4x2x3xf32> { @@ -941,3 +983,104 @@ func @OneHot(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3 // CHECK-LABEL: OneHot // CHECK: "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xf32> } + +func @argmax(%arg0: tensor<3xi32>, %arg1: tensor) -> tensor { + %0 = "tf.ArgMax"(%arg0, %arg1) : (tensor<3xi32>, tensor) -> tensor + return %0 : tensor + +// CHECK-LABEL: argmax +// CHECK: %0 = "tfl.arg_max"(%arg0, %arg1) : (tensor<3xi32>, tensor) -> tensor +} + +func @argmax64(%arg0: tensor<3xi32>, %arg1: tensor) -> tensor { + %0 = "tf.ArgMax"(%arg0, %arg1) : (tensor<3xi32>, tensor) -> tensor + return %0 : tensor + +// CHECK-LABEL: argmax64 +// CHECK: %0 = "tfl.arg_max"(%arg0, %arg1) : (tensor<3xi32>, tensor) -> tensor +} + +func @space_to_depth(%arg0: tensor<1x2x2x1xf32>) -> tensor { + %0 = "tf.SpaceToDepth"(%arg0) {block_size = 2: i64, data_format = "NHWC"}: (tensor<1x2x2x1xf32>) -> tensor + return %0 : tensor + + // CHECK-LABEL: space_to_depth + // CHECK: %[[ARG:.*]]: tensor<1x2x2x1xf32> + // CHECK: "tfl.space_to_depth"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x2x2x1xf32>) -> tensor +} + +func @round(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + %0 = "tf.Round"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> + + // CHECK-LABEL: round + // CHECK: %[[ARG:.*]]: tensor<8x16xf32> + // CHECK: %[[RESULT:.*]] = "tfl.round"(%[[ARG]]) : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: return %[[RESULT]] : tensor<8x16xf32> +} + +func @resize_nearest_neighbor(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor { + %0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true} : 
(tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor + return %0 : tensor + // CHECK-LABEL: resize_nearest_neighbor + // CHECK: "tfl.resize_nearest_neighbor"(%arg0, %arg1) {align_corners = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor +} + +// Note: half_pixel_centers isn't supported by TFLite, so it's not legalized. +func @resize_nearest_neighbor_with_half_pixel_centers(%arg0: tensor<1x100x100x3xf32>, %arg1: tensor<4xi32>) -> tensor { + %0 = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor + return %0 : tensor + // CHECK-LABEL: resize_nearest_neighbor_with_half_pixel_centers + // CHECK: "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = true, half_pixel_centers = true} +} + +func @sparse_to_dense_with_scalar_sparse_indices(%arg0: tensor, %arg1: tensor<3xi32>, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = "tf.SparseToDense"(%arg0, %arg1, %arg2, %arg3) {validate_indices = true}: (tensor, tensor<3xi32>, tensor, tensor) -> tensor + return %0 : tensor + // CHECK-LABEL: sparse_to_dense_with_scalar_sparse_indices + // CHECK: "tfl.sparse_to_dense"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<3xi32>, tensor, tensor) -> tensor +} + +func @sparse_to_dense_with_vector_sparse_indices(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>, %arg2: tensor<3xf32>, %arg3: tensor) -> tensor { + %0 = "tf.SparseToDense"(%arg0, %arg1, %arg2, %arg3) {validate_indices = true}: (tensor<3xi32>, tensor<3xi32>, tensor<3xf32>, tensor) -> tensor + return %0 : tensor + // CHECK-LABEL: sparse_to_dense_with_vector_sparse_indices + // CHECK: "tfl.sparse_to_dense"(%arg0, %arg1, %arg2, %arg3) : (tensor<3xi32>, tensor<3xi32>, tensor<3xf32>, tensor) -> tensor +} + +func @sparse_to_dense_with_2d_sparse_indices(%arg0: tensor<3x2xi32>, %arg1: tensor<3xi32>, %arg2: tensor<2xf32>, %arg3: tensor) -> tensor { + %0 = "tf.SparseToDense"(%arg0, %arg1, %arg2, %arg3) {validate_indices = true}: (tensor<3x2xi32>, tensor<3xi32>, tensor<2xf32>, tensor) -> tensor + return %0 : tensor + // CHECK-LABEL: sparse_to_dense_with_2d_sparse_indices + // CHECK: "tfl.sparse_to_dense"(%arg0, %arg1, %arg2, %arg3) : (tensor<3x2xi32>, tensor<3xi32>, tensor<2xf32>, tensor) -> tensor +} + +func @where(%arg0: tensor<3x5xi1>) -> tensor { + %0 = "tf.Where"(%arg0) : (tensor<3x5xi1>) -> tensor + return %0 : tensor + // CHECK-LABEL: where + // CHECK: "tfl.where"(%arg0) : (tensor<3x5xi1>) -> tensor +} + +func @floor_mod(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> { + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> + return %0 : tensor<5xf32> + // CHECK-LABEL: floor_mod + // CHECK: "tfl.floor_mod"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> +} + +func @exp(%arg0: tensor<5xf32>) -> tensor<5xf32> { + %0 = "tf.Exp"(%arg0) : (tensor<5xf32>) -> tensor<5xf32> + return %0 : tensor<5xf32> + // CHECK-LABEL: exp + // CHECK: "tfl.exp"(%arg0) : (tensor<5xf32>) -> tensor<5xf32> +} + +func @depth_to_space(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> { + %0 = "tf.DepthToSpace"(%arg0) {block_size = 2: i64, data_format = "NHWC"}: (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> + return %0 : tensor<1x2x2x1xf32> + + // CHECK-LABEL: depth_to_space + // CHECK: %[[ARG:.*]]: tensor<1x1x1x4xf32> + // CHECK: "tfl.depth_to_space"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir 
b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir index 1fe6757c0c7..817ced79ced 100644 --- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir +++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir @@ -1,6 +1,6 @@ // RUN: tf-opt -tfl-lower-static-tensor-list %s | FileCheck %s --dump-input-on-failure -func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor) -> (tensor<10xf32>, tensor<3x10xf32>) { -^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor): + +func @tensorlistGetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor) -> (tensor<10xf32>, tensor<3x10xf32>) { %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor>> %1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor>>, tensor, tensor<1xi32>) -> tensor<10xf32> %2 = "tf.TensorListStack"(%0, %arg1) : (tensor>>, tensor<1xi32>) -> tensor<3x10xf32> @@ -11,8 +11,7 @@ func @tensorlistGetItem(tensor<3x10xf32>, tensor<1xi32>, tensor) -> (tensor // CHECK: return %0, %arg0 : tensor<10xf32>, tensor<3x10xf32> } -func @tensorlistGetItemWithUnknownRank(tensor<*xf32>, tensor<1xi32>, tensor) -> (tensor<*xf32>, tensor<*xf32>) { -^bb0(%arg0: tensor<*xf32>, %arg1: tensor<1xi32>, %arg2: tensor): +func @tensorlistGetItemWithUnknownRank(%arg0: tensor<*xf32>, %arg1: tensor<1xi32>, %arg2: tensor) -> (tensor<*xf32>, tensor<*xf32>) { %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<*xf32>, tensor<1xi32>) -> tensor>> %1 = "tf.TensorListGetItem"(%0, %arg2, %arg1) : (tensor>>, tensor, tensor<1xi32>) -> tensor<*xf32> %2 = "tf.TensorListStack"(%0, %arg1) : (tensor>>, tensor<1xi32>) -> tensor<*xf32> @@ -23,8 +22,7 @@ func @tensorlistGetItemWithUnknownRank(tensor<*xf32>, tensor<1xi32>, tensor // CHECK: return %0, %arg0 : tensor<*xf32>, tensor<*xf32> } -func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor, tensor<10xf32>) -> tensor<3x10xf32> { -^bb0(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor, %arg3: tensor<10xf32>): +func @tensorlistSetItem(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor, %arg3: tensor<10xf32>) -> tensor<3x10xf32> { %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor>> %1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor>>, tensor, tensor<10xf32>) -> tensor>> %2 = "tf.TensorListStack"(%1, %arg1) : (tensor>>, tensor<1xi32>) -> tensor<3x10xf32> @@ -56,8 +54,7 @@ func @tensorlistSetItem(tensor<3x10xf32>, tensor<1xi32>, tensor, tensor<10x // CHECK: return %15 : tensor<3x10xf32> } -func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor, tensor) -> tensor<5xf32> { -^bb0(%arg0: tensor<5xf32>, %arg1: tensor<0xi32>, %arg2: tensor, %arg3: tensor): +func @tensorlistSetItemWithScalarElements(%arg0: tensor<5xf32>, %arg1: tensor<0xi32>, %arg2: tensor, %arg3: tensor) -> tensor<5xf32> { %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<5xf32>, tensor<0xi32>) -> tensor>> %1 = "tf.TensorListSetItem"(%0, %arg2, %arg3) : (tensor>>, tensor, tensor) -> tensor>> %2 = "tf.TensorListStack"(%1, %arg1) : (tensor>>, tensor<0xi32>) -> tensor<5xf32> @@ -89,24 +86,23 @@ func @tensorlistSetItemWithScalarElements(tensor<5xf32>, tensor<0xi32>, tensor } -func @tensorlistReserve(tensor<3xi32>, tensor, tensor) -> tensor { -^bb0(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor): +func @tensorlistReserve(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor) -> tensor { %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor) -> tensor>> %1 
= "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor>>, tensor, tensor<3xi32>) -> tensor return %1 : tensor // CHECK-LABEL: tensorlistReserve -// CHECK: %cst = constant dense<0> : tensor -// CHECK: %0 = "tf.ExpandDims"(%arg1, %cst) : (tensor, tensor) -> tensor<1xi32> -// CHECK: %1 = "tf.Concat"(%cst, %0, %arg0) {N = 2 : i64} : (tensor, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32> -// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor -// CHECK: %2 = "tf.Fill"(%1, %cst_0) : (tensor<4xi32>, tensor) -> tensor -// CHECK: %3 = "tf.Gather"(%2, %arg2) {validate_indices = true} : (tensor, tensor) -> tensor -// CHECK: return %3 : tensor +// CHECK-DAG: [[ZERO1:%cst.*]] = constant dense<0> : tensor +// CHECK-DAG: [[ZERO2:%cst.*]] = constant dense<0> : tensor +// CHECK-DAG: [[DIM0:%.*]] = "tf.ExpandDims"(%arg1, [[ZERO1]]) : (tensor, tensor) -> tensor<1xi32> +// CHECK-DAG: [[SHAPE:%.*]] = "tf.Concat"([[ZERO2]], [[DIM0]], %arg0) {N = 2 : i64} : (tensor, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32> +// CHECK-DAG: [[VALUES:%.*]] = constant dense<0.000000e+00> : tensor +// CHECK: [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor) -> tensor +// CHECK: [[RESULT:%.*]] = "tf.Gather"([[LIST]], %arg2) {validate_indices = true} : (tensor, tensor) -> tensor +// CHECK: return [[RESULT]] : tensor } -func @tensorlistReserveUnrankedElements(tensor, tensor, tensor) -> tensor<*xf32> { -^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor): +func @tensorlistReserveUnrankedElements(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<*xf32> { %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor, tensor) -> tensor>> %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor>>, tensor, tensor) -> tensor<*xf32> return %1 : tensor<*xf32> @@ -117,13 +113,42 @@ func @tensorlistReserveUnrankedElements(tensor, tensor, tensor) // CHECK: return [[RESULT2]] : tensor<*xf32> } -func @tensorlistWhileLoop(tensor<2x3xf32>) -> tensor<*xf32> { -^bb0(%arg0: tensor<2x3xf32>): +func @EmptyTensorList(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "tf.EmptyTensorList"(%arg0, %arg1) : (tensor<3xi32>, tensor) -> tensor>> + %1 = "tf.TensorListGetItem"(%0, %arg2, %arg0) : (tensor>>, tensor, tensor<3xi32>) -> tensor + return %1 : tensor + +// CHECK-LABEL: EmptyTensorList +// CHECK-SAME: ([[ELEM_SHAPE:%.*]]: tensor<3xi32>, [[MAX_ELEMS:%.*]]: tensor, [[IDX:%.*]]: tensor) +// CHECK-DAG: [[DIM0:%cst.*]] = constant dense<0> : tensor<1xi32> +// CHECK-DAG: [[ZERO:%cst.*]] = constant dense<0> : tensor +// CHECK-DAG: [[SHAPE:%.*]] = "tf.Concat"([[ZERO]], [[DIM0]], [[ELEM_SHAPE]]) {N = 2 : i64} : (tensor, tensor<1xi32>, tensor<3xi32>) -> tensor<4xi32> +// CHECK-DAG: [[VALUES:%.*]] = constant dense<0.000000e+00> : tensor +// CHECK: [[LIST:%.*]] = "tf.Fill"([[SHAPE]], [[VALUES]]) : (tensor<4xi32>, tensor) -> tensor +// CHECK: [[RESULT:%.*]] = "tf.Gather"([[LIST]], [[IDX]]) {validate_indices = true} : (tensor, tensor) -> tensor +// CHECK: return [[RESULT]] : tensor +} + +func @tensorlistPushBack(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<10xf32>) -> tensor { + %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor>> + %1 = "tf.TensorListPushBack"(%0, %arg2) : (tensor>>, tensor<10xf32>) -> tensor>> + %2 = "tf.TensorListStack"(%1, %arg1) : (tensor>>, tensor<1xi32>) -> tensor + return %2 : tensor + +// CHECK-LABEL: tensorlistPushBack +// CHECK-SAME: ([[INPUT:%.*]]: tensor<3x10xf32>, [[ELEM_SHAPE:%.*]]: tensor<1xi32>, [[ITEM:%.*]]: tensor<10xf32>) +// CHECK: 
[[ZERO:%.*]] = constant dense<0> : tensor +// CHECK: [[EXP_ITEM:%.*]] = "tf.ExpandDims"([[ITEM]], [[ZERO]]) {{.*}} -> tensor<1x10xf32> +// CHECK: [[RESULT:%.*]] = "tf.Concat"(%cst, [[INPUT]], [[EXP_ITEM]]) {N = 2 : i64} : {{.*}} -> tensor +// CHECK: return [[RESULT]] : tensor +} + +func @tensorlistWhileLoop(%arg0: tensor<2x3xf32>) -> tensor<*xf32> { %cst = constant dense<3> : tensor<1xi32> %cst_0 = constant dense<0> : tensor %cst_1 = constant dense<-1> : tensor %0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor>> - %1:2 = "tf.While"(%cst_0, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_VARIANT"], body = @tensorlistWhileBody, cond = @tensorlistWhileCond} : (tensor, tensor>>) -> (tensor, tensor>>) + %1:2 = "tf.While"(%cst_0, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_VARIANT"], body = @tensorlistWhileBody, cond = @tensorlistWhileCond, is_stateless = false} : (tensor, tensor>>) -> (tensor, tensor>>) %2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor>>, tensor) -> tensor<*xf32> return %2 : tensor<*xf32> @@ -136,8 +161,7 @@ func @tensorlistWhileLoop(tensor<2x3xf32>) -> tensor<*xf32> { // CHECK: return %0#1 : tensor<*xf32> } -func @tensorlistWhileBody(tensor<*xi32>, tensor) -> (tensor<*xi32>, tensor) { -^bb0(%arg0: tensor<*xi32>, %arg1: tensor): +func @tensorlistWhileBody(%arg0: tensor<*xi32>, %arg1: tensor) -> (tensor<*xi32>, tensor) { %cst = constant dense<1> : tensor %0 = "tf.Add"(%arg0, %cst) : (tensor<*xi32>, tensor) -> tensor<*xi32> %1 = "tf.Identity"(%arg1) : (tensor) -> tensor @@ -151,8 +175,7 @@ func @tensorlistWhileBody(tensor<*xi32>, tensor) -> (tensor<*xi32>, // CHECK: return %0, %1 : tensor<*xi32>, tensor<*xf32> } -func @tensorlistWhileCond(tensor<*xi32>, tensor) -> tensor<*xi1> { -^bb0(%arg0: tensor<*xi32>, %arg1: tensor): +func @tensorlistWhileCond(%arg0: tensor<*xi32>, %arg1: tensor) -> tensor<*xi1> { %cst = constant dense<2> : tensor %0 = "tf.Less"(%arg0, %cst) : (tensor<*xi32>, tensor) -> tensor<*xi1> return %0 : tensor<*xi1> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_builtin.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_builtin.mlir index 6f0882f7260..408fb516dac 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_builtin.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_builtin.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[1]} -eq 1 +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-builtin-tflite-ops=false -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[0]} -ne 0 # CHECK: loc("disable_builtin.mlir":2:1): is a TFLite builtin op but builtin emission is not enabled # CHECK-NEXT: Verification failed. 
diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir index be62118804a..c4dd8b5bacf 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/disable_flex.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[1]} -eq 1 +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s; test ${PIPESTATUS[0]} -ne 0 # CHECK: loc("disable_flex.mlir":96:8): error: 'tf.div' op is a Flex op but Flex ops are not enabled for emission # CHECK-NEXT: Verification failed. diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/dynamic_shape_constant.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/dynamic_shape_constant.mlir new file mode 100644 index 00000000000..1eae96217a5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/dynamic_shape_constant.mlir @@ -0,0 +1,25 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - + +func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %cst = "tfl.pseudo_const"() {value = dense<[1, 2]> : tensor<2xi32>} : () -> tensor + %0 = "tfl.pseudo_input" (%arg0) : (tensor<2xi32>) -> tensor<2xi32> + %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<2xi32>, tensor) -> tensor<2xi32> + return %1 : tensor<2xi32> +} + + +// CHECK: tensors: [ { +// CHECK-NEXT: shape: [ 2 ], +// CHECK-NEXT: type: INT32, +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "tfl.pseudo_const", +// CHECK-NEXT: quantization: { +// CHECK-NEXT: +// CHECK-NEXT: } + +// CHECK: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 1, 0, 0, 0, 2, 0, 0, 0 ] +// CHECK-NEXT: }, { + diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir index 7702045547e..726441876cd 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/if_op.mlir @@ -1,12 +1,12 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s --dump-input-on-failure + // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { // CHECK-NEXT: builtin_code: LESS // CHECK-NEXT: }, { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "Experimental_If" +// CHECK-NEXT: builtin_code: IF // CHECK-NEXT: }, { // CHECK-EMPTY: // CHECK-NEXT: }, { @@ -52,8 +52,12 @@ // CHECK-NEXT: opcode_index: 1, // CHECK-NEXT: inputs: [ 2, 0, 1 ], // CHECK-NEXT: outputs: [ 3 ], -// CHECK-NEXT: custom_options: [ 116, 104, 101, 110, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 101, 108, 115, 101, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ] -// CHECK-NEXT: } ] +// CHECK-NEXT: builtin_options_type: IfOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-NEXT: then_subgraph_index: 1, +// CHECK-NEXT: else_subgraph_index: 2 +// CHECK-NEXT: } +// CHECK-NEXT: } ], // CHECK-NEXT: name: "main" // CHECK-NEXT: }, { // CHECK-NEXT: tensors: [ { @@ -88,7 +92,7 @@ // CHECK-NEXT: builtin_options: { // CHECK-EMPTY: // CHECK-NEXT: 
} -// CHECK-NEXT: } ] +// CHECK-NEXT: } ], // CHECK-NEXT: name: "cond_true" // CHECK-NEXT: }, { // CHECK-NEXT: tensors: [ { @@ -123,7 +127,7 @@ // CHECK-NEXT: builtin_options: { // CHECK-EMPTY: // CHECK-NEXT: } -// CHECK-NEXT: } ] +// CHECK-NEXT: } ], // CHECK-NEXT: name: "cond_false" // CHECK-NEXT: } ], // CHECK-NEXT: description: "MLIR Converted.", @@ -156,7 +160,7 @@ func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { %0 = "tfl.pseudo_input"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> %1 = "tfl.pseudo_input"(%arg1) : (tensor<1xf32>) -> tensor<1xf32> %2 = "tfl.less"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> - %3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> return %3 : tensor<1xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir new file mode 100644 index 00000000000..ddb122f6e37 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/lstm.mlir @@ -0,0 +1,283 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s + +func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> { +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: builtin_code: LSTM +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "tfl.pseudo_input", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "tfl.pseudo_input1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "tfl.pseudo_input2", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 4, +// CHECK-NEXT: name: "tfl.pseudo_input3", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 5, +// CHECK-NEXT: name: "tfl.pseudo_input4", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 6, +// CHECK-NEXT: name: "tfl.pseudo_input5", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 7, +// CHECK-NEXT: name: "tfl.pseudo_input6", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 8, +// CHECK-NEXT: name: "tfl.pseudo_input7", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 9, +// 
CHECK-NEXT: name: "tfl.pseudo_input8", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 10, +// CHECK-NEXT: name: "tfl.pseudo_input9", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 11, +// CHECK-NEXT: name: "tfl.pseudo_input10", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 12, +// CHECK-NEXT: name: "tfl.pseudo_input11", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 13, +// CHECK-NEXT: name: "tfl.pseudo_input12", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 14, +// CHECK-NEXT: name: "tfl.pseudo_input13", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 15, +// CHECK-NEXT: name: "tfl.pseudo_input14", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 16, +// CHECK-NEXT: name: "tfl.pseudo_input15", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 17, +// CHECK-NEXT: name: "tfl.pseudo_input16", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 18, +// CHECK-NEXT: name: "tfl.pseudo_input17", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: name: "Const", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: is_variable: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: name: "Const1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: is_variable: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 21, +// CHECK-NEXT: name: "tfl.pseudo_input18", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 22, +// CHECK-NEXT: name: "tfl.pseudo_input19", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 23, +// CHECK-NEXT: name: "tfl.pseudo_input20", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 24, +// CHECK-NEXT: name: "tfl.pseudo_input21", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 25, +// CHECK-NEXT: name: "tfl.lstm", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22, 23 ], +// CHECK-NEXT: outputs: [ 24 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 ], +// CHECK-NEXT: outputs: [ 24 ], +// CHECK-NEXT: builtin_options_type: LSTMOptions, +// CHECK-NEXT: builtin_options: { +// 
CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: } ] +// CHECK-NEXT: } +// CHECK-EMPTY: + + +^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>, %arg22: tensor<4 x f32>, %arg23: tensor<4 x f32>): + %0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32> + %1 = "tfl.pseudo_input" (%arg1) : (tensor<4 x f32>) -> tensor<4 x f32> + %2 = "tfl.pseudo_input" (%arg2) : (tensor<4 x f32>) -> tensor<4 x f32> + %3 = "tfl.pseudo_input" (%arg3) : (tensor<4 x f32>) -> tensor<4 x f32> + %4 = "tfl.pseudo_input" (%arg4) : (tensor<4 x f32>) -> tensor<4 x f32> + %5 = "tfl.pseudo_input" (%arg5) : (tensor<4 x f32>) -> tensor<4 x f32> + %6 = "tfl.pseudo_input" (%arg6) : (tensor<4 x f32>) -> tensor<4 x f32> + %7 = "tfl.pseudo_input" (%arg7) : (tensor<4 x f32>) -> tensor<4 x f32> + %8 = "tfl.pseudo_input" (%arg8) : (tensor<4 x f32>) -> tensor<4 x f32> + %9 = "tfl.pseudo_input" (%arg9) : (tensor<4 x f32>) -> tensor<4 x f32> + %10 = "tfl.pseudo_input" (%arg10) : (tensor<4 x f32>) -> tensor<4 x f32> + %11 = "tfl.pseudo_input" (%arg11) : (tensor<4 x f32>) -> tensor<4 x f32> + %12 = "tfl.pseudo_input" (%arg12) : (tensor<4 x f32>) -> tensor<4 x f32> + %13 = "tfl.pseudo_input" (%arg13) : (tensor<4 x f32>) -> tensor<4 x f32> + %14 = "tfl.pseudo_input" (%arg14) : (tensor<4 x f32>) -> tensor<4 x f32> + %15 = "tfl.pseudo_input" (%arg15) : (tensor<4 x f32>) -> tensor<4 x f32> + %16 = "tfl.pseudo_input" (%arg16) : (tensor<4 x f32>) -> tensor<4 x f32> + %17 = "tfl.pseudo_input" (%arg17) : (tensor<4 x f32>) -> tensor<4 x f32> + %18 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %19 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %20 = "tfl.pseudo_input" (%arg20) : (tensor<4 x f32>) -> tensor<4 x f32> + %21 = "tfl.pseudo_input" (%arg21) : (tensor<4 x f32>) -> tensor<4 x f32> + %22 
= "tfl.pseudo_input" (%arg22) : (tensor<4 x f32>) -> tensor<4 x f32> + %23 = "tfl.pseudo_input" (%arg23) : (tensor<4 x f32>) -> tensor<4 x f32> + %24 = "tfl.lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %24 : tensor<4xf32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir index eb9119d1c46..43ee98934e0 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/simple.mlir @@ -1,4 +1,5 @@ // RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - -strip-debug-info | flatbuffer_to_string - | FileCheck --dump-input-on-failure %s --check-prefix=STRIP func @main(tensor<3x2xi32>) -> tensor<3x2xi32> attributes {tf.entry_function = {inputs = "input", outputs = "SameNameAsOutput"}} { @@ -16,6 +17,8 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: type: INT32, // CHECK-NEXT: buffer: 1, // CHECK-NEXT: name: "input", +// STRIP: buffer: 1, +// STRIP-NEXT: name: "input", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } @@ -24,6 +27,8 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: type: INT32, // CHECK-NEXT: buffer: 2, // CHECK-NEXT: name: "Const", +// STRIP: buffer: 2, +// STRIP-NEXT: name: "0", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } @@ -32,6 +37,8 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: type: INT32, // CHECK-NEXT: buffer: 3, // CHECK-NEXT: name: "sub", +// STRIP: buffer: 3, +// STRIP-NEXT: name: "1", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } @@ -40,6 +47,8 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: type: INT32, // CHECK-NEXT: buffer: 4, // CHECK-NEXT: name: "SameNameAsOutput1", +// STRIP: buffer: 4, +// STRIP-NEXT: name: "2", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } @@ -48,6 +57,8 @@ func @main(tensor<3x2xi32>) -> tensor<3x2xi32> // CHECK-NEXT: type: INT32, // CHECK-NEXT: buffer: 5, // CHECK-NEXT: name: "SameNameAsOutput", +// STRIP: buffer: 5, +// STRIP-NEXT: name: "SameNameAsOutput", // CHECK-NEXT: quantization: { // CHECK-EMPTY: // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir new file mode 100644 index 00000000000..3ab36f554ae --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/svdf.mlir @@ -0,0 +1,93 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s + +func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> { +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: builtin_code: SVDF +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// 
CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "tfl.pseudo_input", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "tfl.pseudo_input1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "tfl.pseudo_input2", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 4, +// CHECK-NEXT: name: "tfl.pseudo_input3", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: name: "Const", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: is_variable: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 6, +// CHECK-NEXT: name: "tfl.svdf", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0, 1, 2, 3 ], +// CHECK-NEXT: outputs: [ 5 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4 ], +// CHECK-NEXT: outputs: [ 5 ], +// CHECK-NEXT: builtin_options_type: SVDFOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-NEXT: rank: 2, +// CHECK-NEXT: fused_activation_function: RELU +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: } ] +// CHECK-NEXT: } +// CHECK-EMPTY: + +^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>): + %0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32> + %1 = "tfl.pseudo_input" (%arg1) : (tensor<4 x f32>) -> tensor<4 x f32> + %2 = "tfl.pseudo_input" (%arg2) : (tensor<4 x f32>) -> tensor<4 x f32> + %3 = "tfl.pseudo_input" (%arg3) : (tensor<4 x f32>) -> tensor<4 x f32> + %4 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %5 = "tfl.svdf"(%0, %1, %2, %3, %4) {fused_activation_function = "RELU", rank = 2 : i32} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %5 : tensor<4xf32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir new file mode 100644 index 00000000000..e2ffb24a6b3 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_lstm.mlir @@ -0,0 +1,282 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s + +func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 
x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> { +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_LSTM +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "tfl.pseudo_input", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "tfl.pseudo_input1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "tfl.pseudo_input2", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 4, +// CHECK-NEXT: name: "tfl.pseudo_input3", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 5, +// CHECK-NEXT: name: "tfl.pseudo_input4", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 6, +// CHECK-NEXT: name: "tfl.pseudo_input5", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 7, +// CHECK-NEXT: name: "tfl.pseudo_input6", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 8, +// CHECK-NEXT: name: "tfl.pseudo_input7", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 9, +// CHECK-NEXT: name: "tfl.pseudo_input8", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 10, +// CHECK-NEXT: name: "tfl.pseudo_input9", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 11, +// CHECK-NEXT: name: "tfl.pseudo_input10", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 12, +// CHECK-NEXT: name: "tfl.pseudo_input11", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 13, +// CHECK-NEXT: name: "tfl.pseudo_input12", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 14, +// CHECK-NEXT: name: "tfl.pseudo_input13", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 15, +// CHECK-NEXT: name: "tfl.pseudo_input14", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 16, +// CHECK-NEXT: name: "tfl.pseudo_input15", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 17, +// CHECK-NEXT: name: "tfl.pseudo_input16", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 18, +// 
CHECK-NEXT: name: "tfl.pseudo_input17", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: name: "Const", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: is_variable: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: name: "Const1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: is_variable: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 21, +// CHECK-NEXT: name: "tfl.pseudo_input18", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 22, +// CHECK-NEXT: name: "tfl.pseudo_input19", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 23, +// CHECK-NEXT: name: "tfl.pseudo_input20", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 24, +// CHECK-NEXT: name: "tfl.pseudo_input21", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 25, +// CHECK-NEXT: name: "tfl.unidirectional_sequence_lstm", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 20, 21, 22, 23 ], +// CHECK-NEXT: outputs: [ 24 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 ], +// CHECK-NEXT: outputs: [ 24 ], +// CHECK-NEXT: builtin_options_type: UnidirectionalSequenceLSTMOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-NEXT: time_major: true +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: } ] +// CHECK-NEXT: } +// CHECK-EMPTY: + +^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>, %arg4: tensor<4 x f32>, %arg5: tensor<4 x f32>, %arg6: tensor<4 x f32>, %arg7: tensor<4 x f32>, %arg8: tensor<4 x f32>, %arg9: tensor<4 x f32>, %arg10: tensor<4 x f32>, %arg11: tensor<4 x f32>, %arg12: tensor<4 x f32>, %arg13: 
tensor<4 x f32>, %arg14: tensor<4 x f32>, %arg15: tensor<4 x f32>, %arg16: tensor<4 x f32>, %arg17: tensor<4 x f32>, %arg20: tensor<4 x f32>, %arg21: tensor<4 x f32>, %arg22: tensor<4 x f32>, %arg23: tensor<4 x f32>): + %0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32> + %1 = "tfl.pseudo_input" (%arg1) : (tensor<4 x f32>) -> tensor<4 x f32> + %2 = "tfl.pseudo_input" (%arg2) : (tensor<4 x f32>) -> tensor<4 x f32> + %3 = "tfl.pseudo_input" (%arg3) : (tensor<4 x f32>) -> tensor<4 x f32> + %4 = "tfl.pseudo_input" (%arg4) : (tensor<4 x f32>) -> tensor<4 x f32> + %5 = "tfl.pseudo_input" (%arg5) : (tensor<4 x f32>) -> tensor<4 x f32> + %6 = "tfl.pseudo_input" (%arg6) : (tensor<4 x f32>) -> tensor<4 x f32> + %7 = "tfl.pseudo_input" (%arg7) : (tensor<4 x f32>) -> tensor<4 x f32> + %8 = "tfl.pseudo_input" (%arg8) : (tensor<4 x f32>) -> tensor<4 x f32> + %9 = "tfl.pseudo_input" (%arg9) : (tensor<4 x f32>) -> tensor<4 x f32> + %10 = "tfl.pseudo_input" (%arg10) : (tensor<4 x f32>) -> tensor<4 x f32> + %11 = "tfl.pseudo_input" (%arg11) : (tensor<4 x f32>) -> tensor<4 x f32> + %12 = "tfl.pseudo_input" (%arg12) : (tensor<4 x f32>) -> tensor<4 x f32> + %13 = "tfl.pseudo_input" (%arg13) : (tensor<4 x f32>) -> tensor<4 x f32> + %14 = "tfl.pseudo_input" (%arg14) : (tensor<4 x f32>) -> tensor<4 x f32> + %15 = "tfl.pseudo_input" (%arg15) : (tensor<4 x f32>) -> tensor<4 x f32> + %16 = "tfl.pseudo_input" (%arg16) : (tensor<4 x f32>) -> tensor<4 x f32> + %17 = "tfl.pseudo_input" (%arg17) : (tensor<4 x f32>) -> tensor<4 x f32> + %18 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %19 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %20 = "tfl.pseudo_input" (%arg20) : (tensor<4 x f32>) -> tensor<4 x f32> + %21 = "tfl.pseudo_input" (%arg21) : (tensor<4 x f32>) -> tensor<4 x f32> + %22 = "tfl.pseudo_input" (%arg22) : (tensor<4 x f32>) -> tensor<4 x f32> + %23 = "tfl.pseudo_input" (%arg23) : (tensor<4 x f32>) -> tensor<4 x f32> + %24 = "tfl.unidirectional_sequence_lstm"(%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23) {fused_activation_function = "NONE", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %24 : tensor<4xf32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir new file mode 100644 index 00000000000..3d91f66501d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unidirectional_sequence_rnn.mlir @@ -0,0 +1,93 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_to_string - | FileCheck %s + +func @main(tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>, tensor<4 x f32>) -> tensor<4 x f32> { +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: builtin_code: UNIDIRECTIONAL_SEQUENCE_RNN +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 1, +// 
CHECK-NEXT: name: "tfl.pseudo_input", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "tfl.pseudo_input1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "tfl.pseudo_input2", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 4, +// CHECK-NEXT: name: "tfl.pseudo_input3", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: name: "Const", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: }, +// CHECK-NEXT: is_variable: true +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: buffer: 6, +// CHECK-NEXT: name: "tfl.unidirectional_sequence_rnn", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0, 1, 2, 3 ], +// CHECK-NEXT: outputs: [ 5 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0, 1, 2, 3, 4 ], +// CHECK-NEXT: outputs: [ 5 ], +// CHECK-NEXT: builtin_options_type: SequenceRNNOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-NEXT: time_major: true, +// CHECK-NEXT: fused_activation_function: TANH +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: } ] +// CHECK-NEXT: } +// CHECK-EMPTY: + +^bb0(%arg0: tensor<4 x f32>, %arg1: tensor<4 x f32>, %arg2: tensor<4 x f32>, %arg3: tensor<4 x f32>): + %0 = "tfl.pseudo_input" (%arg0) : (tensor<4 x f32>) -> tensor<4 x f32> + %1 = "tfl.pseudo_input" (%arg1) : (tensor<4 x f32>) -> tensor<4 x f32> + %2 = "tfl.pseudo_input" (%arg2) : (tensor<4 x f32>) -> tensor<4 x f32> + %3 = "tfl.pseudo_input" (%arg3) : (tensor<4 x f32>) -> tensor<4 x f32> + %4 = "tfl.pseudo_const" () {value = dense<0.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %5 = "tfl.unidirectional_sequence_rnn"(%0, %1, %2, %3, %4) {fused_activation_function = "TANH", time_major = true} : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %5 : tensor<4xf32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir index 14f8174e9bf..eb20f3759dd 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir @@ -1,4 +1,4 @@ -// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0 +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0 func @main(tensor<3x2xi32>) -> tensor<3x2xi32> { ^bb0(%arg0: tensor<3x2xi32>): diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir index 
fd403aa72c5..bf76f4feae6 100644 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/while_op.mlir @@ -3,8 +3,7 @@ // CHECK: { // CHECK-NEXT: version: 3, // CHECK-NEXT: operator_codes: [ { -// CHECK-NEXT: builtin_code: CUSTOM, -// CHECK-NEXT: custom_code: "Experimental_While" +// CHECK-NEXT: builtin_code: WHILE // CHECK-NEXT: }, { // CHECK-NEXT: builtin_code: GREATER // CHECK-NEXT: }, { @@ -49,8 +48,12 @@ // CHECK-NEXT: operators: [ { // CHECK-NEXT: inputs: [ 0, 1 ], // CHECK-NEXT: outputs: [ 2, 3 ], -// CHECK-NEXT: custom_options: [ 99, 111, 110, 100, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 98, 111, 100, 121, 95, 115, 117, 98, 103, 114, 97, 112, 104, 95, 105, 110, 100, 101, 120, 0, 2, 21, 42, 2, 1, 2, 2, 1, 4, 4, 4, 36, 1 ] -// CHECK-NEXT: } ] +// CHECK-NEXT: builtin_options_type: WhileOptions, +// CHECK-NEXT: builtin_options: { +// CHECK-NEXT: cond_subgraph_index: 1, +// CHECK-NEXT: body_subgraph_index: 2 +// CHECK-NEXT: } +// CHECK-NEXT: } ], // CHECK-NEXT: name: "main" // CHECK-NEXT: }, { // CHECK-NEXT: tensors: [ { @@ -91,7 +94,7 @@ // CHECK-NEXT: opcode_index: 1, // CHECK-NEXT: inputs: [ 0, 2 ], // CHECK-NEXT: outputs: [ 3 ] -// CHECK-NEXT: } ] +// CHECK-NEXT: } ], // CHECK-NEXT: name: "cond" // CHECK-NEXT: }, { // CHECK-NEXT: tensors: [ { @@ -151,7 +154,7 @@ // CHECK-NEXT: builtin_options: { // CHECK-EMPTY: // CHECK-NEXT: } -// CHECK-NEXT: } ] +// CHECK-NEXT: } ], // CHECK-NEXT: name: "body" // CHECK-NEXT: } ], // CHECK-NEXT: description: "MLIR Converted.", @@ -192,7 +195,7 @@ func @main(%arg0: tensor, %arg1: tensor<1xf32>) -> tensor<1xf32> { // While %0 is greater than zero, element wise add %1 with itself. %2:2 = "tf.While"(%0, %1) { - cond = @cond, body = @body + cond = @cond, body = @body, is_stateless = false } : (tensor, tensor<1xf32>) -> (tensor, tensor<1xf32>) return %2#1 : tensor<1xf32> } diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index aaa560c0fd6..fe6dc486822 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -80,7 +80,6 @@ func @testGatherUnsupportedRank(%arg0 : tensor, %arg1 : tensor<1xi32>) -> t return %0 : tensor } - // ----- // CHECK-LABEL: testAbs @@ -155,6 +154,26 @@ func @testSinWithWrongInputType(tensor) -> tensor { // ----- +// test invalid Sqrt input +func @testSqrtWithWrongInputType(tensor) -> tensor { +^bb0(%arg0: tensor): + // expected-error @+1 {{tfl.sqrt' op operand #0 must be tensor of floating-point values}} + %0 = "tfl.sqrt"(%arg0): (tensor) -> tensor + return %0#0 : tensor +} + +// ----- + +// test invalid Square input +func @testSquareWithWrongInputType(tensor) -> tensor { +^bb0(%arg0: tensor): + // expected-error @+1 {{tfl.square' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type values}} + %0 = "tfl.square"(%arg0): (tensor) -> tensor + return %0#0 : tensor +} + +// ----- + // CHECK-LABEL: testSqrt func @testSqrt(tensor) -> tensor { ^bb0(%arg0: tensor): @@ -171,6 +190,18 @@ func @testSquare(tensor) -> tensor { return %0 : tensor } +func @testQuantizedSquare(tensor>) -> tensor> { +^bb0(%arg0: tensor>): + %0 = "tfl.square"(%arg0): (tensor>) -> tensor> + return %0 : tensor> +} + +func @testQuantizedResizeNearestNeighbor(tensor>, tensor) -> tensor> { +^bb0(%arg0: tensor>, %arg1: tensor): + %0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false } : (tensor>, tensor) -> tensor> + return 
%0 : tensor> +} + // CHECK-LABEL: testTanh func @testTanh(tensor) -> tensor { ^bb0(%arg0: tensor): @@ -179,6 +210,18 @@ func @testTanh(tensor) -> tensor { return %0 : tensor } +// CHECK-LABEL: testTanhWithQI8 +func @testTanhWithQI8(%arg0: tensor>) -> tensor> { + %0 = "tfl.tanh"(%arg0): (tensor>) -> tensor> + return %0 : tensor> +} + +// CHECK-LABEL: testTanhWithQUI8 +func @testTanhWithQUI8(%arg0: tensor>) -> tensor> { + %0 = "tfl.tanh"(%arg0): (tensor>) -> tensor> + return %0 : tensor> +} + // CHECK-LABEL: testZerosLike func @testZerosLike(tensor) -> tensor { ^bb0(%arg0: tensor): @@ -287,11 +330,9 @@ func @testFloorDivF32(%arg0: tensor<2 x f32>, %arg1: tensor<2 x i32>) -> tensor< // ----- // CHECK-LABEL: testFloorMod -func @testFloorMod(tensor, tensor) -> tensor { -^bb0(%arg0: tensor, %arg1: tensor): - // CHECK: tfl.floor_mod %arg0, %arg1 - %0 = tfl.floor_mod %arg0, %arg1 : tensor - return %0#0 : tensor +func @testFloorMod(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tfl.floor_mod"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor } // CHECK-LABEL: testPow @@ -310,6 +351,13 @@ func @testConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) return %0 : tensor<256x30x30x16xf32> } + +func @testConv2DNoBias(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: none) -> tensor<256x30x30x16xf32> { + // CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2) + %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "RELU6"} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, none) -> tensor<256x30x30x16xf32> + return %0 : tensor<256x30x30x16xf32> +} + // CHECK-LABEL: testFakeQuant func @testFakeQuant(tensor, f32, f32) -> tensor { ^bb0(%arg0: tensor, %arg1: f32, %arg2: f32): @@ -489,13 +537,22 @@ func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> { // test invalid Logistic input func @testLogisticWithWrongInputType(tensor) -> tensor { ^bb0(%arg0: tensor): - // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point values}} + // expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type values}} %0 = "tfl.logistic"(%arg0): (tensor) -> tensor return %0#0 : tensor } // ----- +// CHECK-LABEL: testUnidirectionalSequenceRnn +func @testUnidirectionalSequenceRnn(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor { + // CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: testUnidirectionalSequenceLstm func @testUnidirectionalSequenceLstm(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor, %arg20: tensor, %arg21: tensor, %arg22: tensor, %arg23: tensor) -> tensor { // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, 
%arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor @@ -768,6 +825,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // ----- +func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> { + // CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} + %0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32> + return %0 : tensor<1x4x2xi32> +} + +// ----- + +func @packNegInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> { + // CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} + %0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32> + return %0 : tensor<2x1x4xi32> +} + +// ----- + func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // expected-error @+1 {{input count should match 'values_count' attribute}} %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 1 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> @@ -776,6 +849,22 @@ func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // ----- +func @pack(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { + // expected-error @+1 {{operands should be of the same type}} + %0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + +func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { + // expected-error @+1 {{op attribute 'axis' is out of bounds, got 3}} + %0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) @@ -785,6 +874,14 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // ----- +func @unpackQuantized(%arg0: tensor<2x3x!quant.uniform>) -> tensor<2x!quant.uniform> { + %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3x!quant.uniform>) -> (tensor<2x!quant.uniform>, tensor<2x!quant.uniform>, tensor<2x!quant.uniform>) + return %0#0 : tensor<2x!quant.uniform> + +} + +// ----- + func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // expected-error @+1 {{output count should match 'num' attribute}} %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) @@ -879,7 +976,7 @@ func @testResizeBilinear(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) // ----- func @testResizeBilinearInvalidOutputType(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) -> tensor { - // expected-error @+1 {{'tfl.resize_bilinear' op result #0 must be tensor of 32-bit float values}} + // expected-error @+1 {{'tfl.resize_bilinear' op result #0 must be tensor of 32-bit float or QI8 type or QUI8 type values}} %0 = 
"tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor return %0 : tensor } @@ -893,6 +990,18 @@ func @testStridedSlice(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: return %0 : tensor<1x2x2x5xf32> } +// CHECK-LABEL: testStridedSliceWithQI8 +func @testStridedSliceWithQI8(%arg0: tensor<12x2x2x5x!quant.uniform>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform> { + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!quant.uniform>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform> + return %0 : tensor<1x2x2x5x!quant.uniform> +} + +// CHECK-LABEL: testStridedSliceWithQUI8 +func @testStridedSliceWithQUI8(%arg0: tensor<12x2x2x5x!quant.uniform>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform> { + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!quant.uniform>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform> + return %0 : tensor<1x2x2x5x!quant.uniform> +} + // ----- func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> { @@ -917,3 +1026,401 @@ func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor, %0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xi8> return %0 : tensor<*xi8> } + +// ----- + +func @testArgMax(%arg0: tensor<3xi32>, %arg1: tensor) -> tensor { + // CHECK: "tfl.arg_max"(%arg0, %arg1) {output_type = 2 : i32} : (tensor<3xi32>, tensor) -> tensor + %0 = "tfl.arg_max"(%arg0, %arg1) {output_type = 2 : i32} : (tensor<3xi32>, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testArgMin(%arg0: tensor<3xi32>, %arg1: tensor) -> tensor { + // CHECK: "tfl.arg_min"(%arg0, %arg1) {output_type = 2 : i32} : (tensor<3xi32>, tensor) -> tensor + %0 = "tfl.arg_min"(%arg0, %arg1) {output_type = 2 : i32} : (tensor<3xi32>, tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: testSpaceToDepth +func @testSpaceToDepthF32(%arg0: tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32> { + // CHECK: %[[ARG:.*]]: tensor<1x2x2x1xf32> + // CHECK: "tfl.space_to_depth"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32> + %0 = "tfl.space_to_depth"(%arg0) {block_size = 2: i32} : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32> + return %0 : tensor<1x1x1x4xf32> +} + +// ----- + +func @testSpaceToDepthInvalidOutputType(%arg0: tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xi32> { + // expected-error @+1 {{'tfl.space_to_depth' op failed to verify that input and output must have same element type}} + %0 = "tfl.space_to_depth"(%arg0) {block_size = 2: i32} : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xi32> + return %0 : tensor<1x1x1x4xi32> +} + +// ----- + +func @testRange(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + %0 = "tfl.range"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testRangeNonScalarTensorInput(%arg0 : tensor<1xi32>, %arg1 : tensor, %arg2 : tensor) -> tensor { + // expected-error @+1 {{op failed to verify 
that operand 0 is 0-D}} + %0 = "tfl.range"(%arg0, %arg1, %arg2) : (tensor<1xi32>, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testRangeOutputTypeMismatch(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // expected-error @+1 {{op failed to verify that operands and output must have same element type}} + %0 = "tfl.range"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @transpose(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xi32>) -> tensor<2x2xi32> { + %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + + +// ----- + +func @transpose_perm_not_i32(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xf32>) -> tensor<2x2xi32> { + // expected-error @+1 {{op operand #1 must be tensor of 32-bit integer values}} + %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2xf32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + + +// ----- + +func @transpose_element_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xi32>) -> tensor<2x2xi32> { + // expected-error @+1 {{input and output must have same element type}} + %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + + +// ----- + +func @transpose_1d_perm(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2x2xi32>) -> tensor<2x2xi32> { + // expected-error @+1 {{op failed to verify that operand 1 is 1-D}} + %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + +func @anyWithI64Axis(%arg0: tensor<2x2xi1>, %arg1: tensor) -> tensor { + // expected-error @+1 {{tfl.reduce_any' op operand #1 must be tensor of 32-bit integer values}} + %0 = "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testRoundInvalidInputType(%arg: tensor) -> tensor { + // expected-error @+1 {{'tfl.round' op operand #0 must be tensor of 32-bit float values}} + %0 = "tfl.round"(%arg) : (tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testSplitWithQuantizedTypes(%arg0 : tensor, %arg1 : tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { + %0 = "tfl.split"(%arg0, %arg1) {num_splits = 1 : i32} : (tensor, tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + return %0 : tensor<10x!quant.uniform> +} + +// ----- + +func @testSplitVWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform>, %arg1 : tensor, %arg2 : tensor) -> tensor<10x!quant.uniform> { + %0 = "tfl.split_v"(%arg0, %arg1, %arg2) {num_splits = 1 : i32} : (tensor<10x!quant.uniform>, tensor, tensor) -> tensor<10x!quant.uniform> + return %0 : tensor<10x!quant.uniform> +} + +// ----- + +func @whereWithI32Input(%arg0: tensor<3x5xi32>) -> tensor { + // expected-error @+1 {{'tfl.where' op operand #0 must be tensor of 1-bit integer values}} + %0 = "tfl.where"(%arg0) : (tensor<3x5xi32>) -> tensor + return %0 : tensor +} + +// ----- + +func @testMinimumWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform>, %arg1 : tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { + %0 = "tfl.minimum"(%arg0, %arg1) : (tensor<10x!quant.uniform>, tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + return %0 : tensor<10x!quant.uniform> +} + +// ----- + +func @testMaximumWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform>, %arg1 : tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { + %0 = "tfl.maximum"(%arg0, %arg1) : (tensor<10x!quant.uniform>, 
tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + return %0 : tensor<10x!quant.uniform> +} + +// ----- + +func @testReluWithQuantizedTypes(%arg0 : tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { + %0 = "tfl.relu"(%arg0) : (tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + return %0 : tensor<10x!quant.uniform> +} + +// ----- + +func @testRelu6WithQuantizedTypes(%arg0 : tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> { + %0 = "tfl.relu6"(%arg0) : (tensor<10x!quant.uniform>) -> tensor<10x!quant.uniform> + return %0 : tensor<10x!quant.uniform> +} + +// ----- + +func @testEmbeddingLookup(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = "tfl.embedding_lookup"(%arg0, %arg1) : (tensor,tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testEmbeddingLookupInvalidResultType(%arg0 : tensor, %arg1 : tensor) -> tensor { + // expected-error @+1 {{'tfl.embedding_lookup' op result #0 must be tensor of 32-bit float or 8-bit integer or TFLite uint8 type values}} + %0 = "tfl.embedding_lookup"(%arg0, %arg1) : (tensor,tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testEmbeddingLookupValueAndResultElementTypeTraitFailed(%arg0 : tensor, %arg1 : tensor) -> tensor { + // expected-error @+1 {{'tfl.embedding_lookup' op failed to verify that value and output must have same element type}} + %0 = "tfl.embedding_lookup"(%arg0, %arg1) : (tensor,tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform>) -> tensor<1x56x56x192x!quant.uniform> { + %0 = "tfl.local_response_normalization"(%arg0) {alpha = 9.99999974E-5 : f32, beta = 5.000000e-01 : f32, bias = 2.000000e+00 : f32, radius = 5 : i32} : (tensor<1x56x56x192x!quant.uniform>) -> tensor<1x56x56x192x!quant.uniform> + return %0 : tensor<1x56x56x192x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: testSvdf +func @testSvdf(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor { + // CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor, tensor, tensor, tensor, tensor) -> tensor + %0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor, tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @testSvdfUnsupportedType(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor { + // expected-error @+1 {{'tfl.svdf' op operand #0 must be tensor of 32-bit float or 8-bit integer values}} + %0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor, tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} +// ----- + +// CHECK-LABEL: testDepthToSpace +func @testDepthToSpaceF32(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> { + // CHECK: %[[ARG:.*]]: tensor<1x1x1x4xf32> + // CHECK: "tfl.depth_to_space"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> + %0 = "tfl.depth_to_space"(%arg0) {block_size = 2: i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> + return %0 : tensor<1x2x2x1xf32> +} + +// ----- + +func @testDepthToSpaceInvalidOutputType(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xi32> { + // expected-error @+1 {{'tfl.depth_to_space' op failed to verify that input and output must have same element type}} + %0 = "tfl.depth_to_space"(%arg0) {block_size = 2: i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xi32> + return %0 : 
tensor<1x2x2x1xi32> +} + +// ----- + +func @testSlice(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor { + %0 = "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor + return %0 : tensor +} + +// ----- + +func @testSliceBadBeginDimension(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2xi32>, %arg2: tensor<3xi32>) -> tensor { + // expected-error @+1 {{begin tensor elements size is not equal to input tensor rank}} + %0 = "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<2xi32>, tensor<3xi32>) -> tensor + return %0 : tensor +} + +// ----- + +func @testSliceBadSizeDimension(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<2xi32>) -> tensor { + // expected-error @+1 {{size tensor elements size is not equal to input tensor rank}} + %0 = "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<2xi32>) -> tensor + return %0 : tensor +} + +// ----- + +func @testSliceBadBegin(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor { + %cst = constant dense<[2, -1, 5]> : tensor<3xi32> + // expected-error @+1 {{begin[1] cannot be negative}} + %0 = "tfl.slice"(%arg0, %cst, %arg1) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor + return %0 : tensor +} + +// ----- + +func @testSliceNegativeSize(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor { + %cst = constant dense<[-2, -1, 5]> : tensor<3xi32> + // expected-error @+1 {{size[0] cannot be negative other than -1}} + %0 = "tfl.slice"(%arg0, %arg1, %cst) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor + return %0 : tensor +} + +// ----- + +func @testSliceSizeOutOfRange(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor { + %cst = constant dense<[2, 1, 5]> : tensor<3xi32> + %cst_1 = constant dense<[0, 1, 1]> : tensor<3xi32> + // expected-error @+1 {{begin[2] + size[2] cannot exceed dimension length: 5}} + %0 = "tfl.slice"(%arg0, %cst_1, %cst) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor + return %0 : tensor +} + +// ----- + +func @testSliceBeginOutOfRange(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor { + %cst = constant dense<[1, 1, 1]> : tensor<3xi32> + %cst_1 = constant dense<[2, 1, 3]> : tensor<3xi32> + // expected-error @+1 {{begin[0] cannot exceed dimension length: 2}} + %0 = "tfl.slice"(%arg0, %cst_1, %cst) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor + return %0 : tensor +} + +// ----- + +func @testSplitOpWithBadNumSplits(%arg0 : tensor<16xf32>) -> () { + %split_dim = constant dense<0> : tensor + // expected-error @+1 {{'tfl.split' op attribute 'num_splits' failed to satisfy constraint: positive 32-bit integer attribute}} + "tfl.split"(%split_dim, %arg0) {num_splits = 0 : i32} : (tensor, tensor<16xf32>) -> () + return +} + +// ----- + +func @testSplitOpWithMismatchedNumResults(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) { + %split_dim = constant dense<0> : tensor + // expected-error @+1 {{'tfl.split' op output count should match 'num_splits' attribute}} + %0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 4 : i32} : (tensor, tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) + return %0, %1 : tensor<8xf32>, tensor<8xf32> +} + +// ----- + +func @testSplitOpWithBadSplitDimTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> { + %split_dim = constant dense<0> : tensor<2x2xi32> + // expected-error @+1 {{'tfl.split' op operand #0 must be tensor}} + %0 = "tfl.split"(%split_dim, %arg0) {num_splits = 1 : i32} : (tensor<2x2xi32>, 
tensor<16x4x4xf32>) -> tensor<16x4x4xf32> + return %0 : tensor<16x4x4xf32> +} + +// ----- + +func @testSplitOpWithBadSplitDimUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %split_dim : tensor) -> tensor<16x4x4xf32> { + // expected-error @+1 {{'tfl.split' op operand #0 must be tensor}} + %0 = "tfl.split"(%split_dim, %arg0) {num_splits = 1 : i32} : (tensor, tensor<16x4x4xf32>) -> tensor<16x4x4xf32> + return %0 : tensor<16x4x4xf32> +} + +// ----- + +func @testSplitOpWithOutOfRangeSplitDim(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) { + %split_dim = constant dense<1> : tensor + // expected-error @+1 {{'tfl.split' op 'split_dim' should be in [-rank, rank)}} + %0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor, tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) + return %0, %1 : tensor<8xf32>, tensor<8xf32> +} + +// ----- + +func @testSplitOpWithOutOfRangeSplitDimTFLConst(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) { + %split_dim = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor + // expected-error @+1 {{'tfl.split' op 'split_dim' should be in [-rank, rank)}} + %0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor, tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) + return %0, %1 : tensor<8xf32>, tensor<8xf32> +} + +// ----- + +func @testSplitOpWithOutOfRangeSplitDimNegative(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) { + %split_dim = constant dense<-2> : tensor + // expected-error @+1 {{'tfl.split' op 'split_dim' should be in [-rank, rank)}} + %0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor, tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) + return %0, %1 : tensor<8xf32>, tensor<8xf32> +} + +// ----- + +func @testSplitOpWithUnevenDivision(%arg0 : tensor<16xf32>) -> (tensor<6xf32>, tensor<5xf32>, tensor<5xf32>) { + %split_dim = constant dense<0> : tensor + // expected-error @+1 {{'tfl.split' op 'num_splits' should evenly divide 'split_dim' axis}} + %0, %1, %2 = "tfl.split"(%split_dim, %arg0) {num_splits = 3 : i32} : (tensor, tensor<16xf32>) -> (tensor<6xf32>, tensor<5xf32>, tensor<5xf32>) + return %0, %1, %2 : tensor<6xf32>, tensor<5xf32>, tensor<5xf32> +} + +// ----- + +func @testSplitOpWithMismatchTensorTypeSplitDimOut0(%arg0 : tensor<16xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %split_dim = constant dense<0> : tensor + // expected-error @+1 {{'tfl.split' op output #0 should be 'tensor<8xf32>'}} + %0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor, tensor<16xf32>) -> (tensor<4xf32>, tensor<4xf32>) + return %0, %1 : tensor<4xf32>, tensor<4xf32> +} + +// ----- + +func @testSplitOpWithMismatchTensorTypeSplitDimOut1(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<4xf32>) { + %split_dim = constant dense<0> : tensor + // expected-error @+1 {{'tfl.split' op output #1 should be 'tensor<8xf32>'}} + %0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor, tensor<16xf32>) -> (tensor<8xf32>, tensor<4xf32>) + return %0, %1 : tensor<8xf32>, tensor<4xf32> +} + +// ----- + +func @testSplitOpWithMismatchTensorTypeNonSplitDim(%arg0 : tensor<16x4xf32>) -> (tensor<8x2xf32>, tensor<8x2xf32>) { + %split_dim = constant dense<0> : tensor + // expected-error @+1 {{'tfl.split' op output #0 should be 'tensor<8x4xf32>'}} + %0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor, tensor<16x4xf32>) -> (tensor<8x2xf32>, tensor<8x2xf32>) + return %0, %1 : tensor<8x2xf32>, tensor<8x2xf32> +} + +// ----- + +func 
@testSplitOpWithValidTensorType(%arg0 : tensor<16x4xf32>) -> (tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>) { + %split_dim_0 = constant dense<0> : tensor + %0, %1 = "tfl.split"(%split_dim_0, %arg0) {num_splits = 2 : i32} : (tensor, tensor<16x4xf32>) -> (tensor<8x4xf32>, tensor<8x4xf32>) + %split_dim_1 = constant dense<1> : tensor + %2, %3 = "tfl.split"(%split_dim_1, %arg0) {num_splits = 2 : i32} : (tensor, tensor<16x4xf32>) -> (tensor<16x2xf32>, tensor<16x2xf32>) + return %0, %1, %2, %3 : tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32> +} + +// ----- + +func @testSplitOpWithValidTensorTypeDynamic(%arg0 : tensor<16x?xf32>) -> (tensor<8x?xf32>, tensor<8x?xf32>) { + %split_dim = constant dense<0> : tensor + %0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor, tensor<16x?xf32>) -> (tensor<8x?xf32>, tensor<8x?xf32>) + return %0, %1 : tensor<8x?xf32>, tensor<8x?xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index e7ebace3a54..15c4898341f 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -96,6 +96,54 @@ func @intermOpUsedTwice(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf } +// CHECK-LABEL: @fuseMulIntoFullyConnected +func @fuseMulIntoFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { + %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %cst1 = constant dense<2.0> : tensor<2xf32> + %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32> + + %0 = "tfl.fully_connected"(%arg0, %cst0, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + + return %1 : tensor<4x2xf32> + +// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> +// CHECK: %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> +// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} +// CHECK: return %0 : tensor<4x2xf32> +} + +// CHECK-LABEL: @fuseMulIntoFullyConnectedBroadcast +func @fuseMulIntoFullyConnectedBroadcast(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> { + %cst0 = constant dense<[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]> : tensor<2x3xf32> + %cst1 = constant dense<2.0> : tensor<2xf32> + %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32> + %0 = "tfl.fully_connected"(%arg0, %cst0, %cst1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<1x2xf32> + // %cst2 isn't broadcast-compatible to %cst0, but tf.Mul is able to fold them. 
+ %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32> + return %1 : tensor<1x2xf32> + +// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [2.000000e+00, 4.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> +// CHECK: %cst_0 = constant dense<[2.000000e+00, 4.000000e+00]> : tensor<2xf32> +// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %cst_0) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} +// CHECK: return %0 : tensor<1x2xf32> +} + +// CHECK-LABEL: @fuseMulIntoFullyConnectedNoBias +func @fuseMulIntoFullyConnectedNoBias(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<4x2xf32> { + %cst0 = constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %cst2 = constant dense<[1.0, 2.0]> : tensor<2xf32> + + %0 = "tfl.fully_connected"(%arg0, %cst0, %arg1) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32> + %1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + + return %1 : tensor<4x2xf32> + +// CHECK: %cst = constant dense<{{\[\[}}1.000000e+00, 4.000000e+00], [3.000000e+00, 8.000000e+00]]> : tensor<2x2xf32> +// CHECK: %0 = "tfl.fully_connected"(%arg0, %cst, %arg1) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32> +// CHECK: return %0 : tensor<4x2xf32> +} + // CHECK-LABEL: @fuseMulIntoDepthwiseConv2d func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> { %cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32> @@ -130,11 +178,11 @@ func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x // CHECK: return %1 } -// CHECK-LABEL: @FuseFullyConnectedAdd -func @FuseFullyConnectedAdd(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { +// CHECK-LABEL: @FuseFullyConnectedAddUnit +func @FuseFullyConnectedAddUnit(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { %cst = constant unit - %0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input") - %1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input") + %0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32> + %1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32> %cst2 = constant dense<2.0> : tensor<40x40xf32> %2 = "tfl.fully_connected" (%0, %1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>) @@ -146,6 +194,37 @@ func @FuseFullyConnectedAdd(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) // CHECK: %0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32> // CHECK: %1 = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32> // CHECK: %2 = "tfl.fully_connected"(%0, %1, %cst) + // CHECK: return %2 +} + +// CHECK-LABEL: @FuseFullyConnectedAddConst +func @FuseFullyConnectedAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { + %cst = constant dense<3.0> : tensor<40x40xf32> + %0 = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input") + %1 = 
"tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32> loc("Input") + %cst2 = constant dense<2.0> : tensor<40x40xf32> + + %2 = "tfl.fully_connected" (%0, %1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>) + %3 = "tfl.add"(%2, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32> + + return %3 : tensor<40x40xf32> + + // CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32> + // CHECK: %[[cst_0:.*]] = "tfl.pseudo_input"(%arg0) : (tensor<40x37xf32>) -> tensor<40x37xf32> + // CHECK: %[[cst_1:.*1]] = "tfl.pseudo_input"(%arg1) : (tensor<40x37xf32>) -> tensor<40x37xf32> + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%[[cst_0]], %[[cst_1]], %[[cst]]) + // CHECK: return %[[fc]] +} + +// CHECK-LABEL: @FuseFullyConnectedRelu +func @FuseFullyConnectedRelu(%arg0: tensor<1x256xf32>, %arg1: tensor<128x256xf32>, %arg2: tensor<128xf32>) -> tensor<1x128xf32> { + %0 = "tfl.fully_connected" (%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x256xf32>, tensor<128x256xf32>, tensor<128xf32>) -> tensor<1x128xf32> + %1 = "tfl.relu"(%0) : (tensor<1x128xf32>) -> tensor<1x128xf32> + return %1 : tensor<1x128xf32> + + // CHECK: %[[RES:[0-9].*]] = "tfl.fully_connected" + // CHECK-SAME: fused_activation_function = "RELU" + // CHECK: return %[[RES]] } // CHECK-LABEL: @NoPadStridedSliceNonNewAxisMask @@ -176,3 +255,53 @@ func @PadStridedSliceNewAxisMask(%arg0: tensor<2x3xf32>) -> tensor<1x2x3x1xf32> // CHECK: %1 = "tfl.reshape"(%0) : (tensor<2x3xf32>) -> tensor<1x2x3x1xf32> // CHECK: %2 = "tfl.strided_slice"(%1, %cst, %cst, %cst_0) {begin_mask = 15 : i32, ellipsis_mask = 0 : i32, end_mask = 15 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<1x2x3x1xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x2x3x1xf32> } + +// CHECK-LABEL: @L2NormalizePattern +func @L2NormalizePattern(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %cst = constant dense<[0]> : tensor<1xi32> + %0 = "tfl.square"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor + %2 = "tfl.rsqrt"(%1) : (tensor) -> tensor + %3 = "tfl.mul"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor) -> tensor<2xf32> + return %3: tensor<2xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: return %[[RES]] +} + +// CHECK-LABEL: @L2NormalizePattern1 +func @L2NormalizePattern1(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %cst = constant dense<[0]> : tensor<1xi32> + %0 = "tfl.square"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor + %2 = "tfl.sqrt"(%1) : (tensor) -> tensor + %3 = "tfl.div"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor) -> tensor<2xf32> + return %3: tensor<2xf32> + // CHECK: %[[RES:[0-9].*]] = "tfl.l2_normalization"([[INPUT:%.*]]) {fused_activation_function = "NONE"} : (tensor<2xf32>) -> tensor<2xf32> + // CHECK: return %[[RES]] +} + +// CHECK-LABEL: @InvalidL2NormalizePattern +// Div and square ops must take the same argument to be eligible. 
+func @InvalidL2NormalizePattern(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf32> { + %cst = constant dense<[0]> : tensor<1xi32> + %0 = "tfl.square"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + %1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor + %2 = "tfl.sqrt"(%1) : (tensor) -> tensor + %3 = "tfl.div"(%arg1, %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor) -> tensor<2xf32> + return %3: tensor<2xf32> + // CHECK: %3 = "tfl.div"([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2xf32>, tensor) -> tensor<2xf32> + // CHECK: return %3 +} + +// CHECK-LABEL: @InvalidL2NormalizePatternMorethan1Dimension +// Input has higher rank, it should be limited to 1D only. +func @InvalidL2NormalizePatternMorethan1Dimension(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<[0]> : tensor<1xi32> + %0 = "tfl.square"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %1 = "tfl.sum"(%0, %cst) {keep_dims = false} : (tensor<2x2xf32>, tensor<1xi32>) -> tensor + %2 = "tfl.sqrt"(%1) : (tensor) -> tensor + %3 = "tfl.div"(%arg0, %2) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + return %3: tensor<2x2xf32> + // CHECK: %3 = "tfl.div"([[INPUT:%.*]], %2) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + // CHECK: return %3 +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir new file mode 100644 index 00000000000..cabbc4d9da5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir @@ -0,0 +1,12 @@ +// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s | FileCheck %s --dump-input-on-failure + +func @foo(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._implements = "embedding_matmul", tf._reference = "mlir"} { + %0 = "tf.Fill" (%arg1, %arg0) : (tensor, tensor) -> tensor + %1 = "tf.MatMul" (%0, %arg0) : (tensor, tensor) -> tensor + return %1 : tensor +} + +// CHECK: func @foo([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor) -> tensor +// CHECK: attributes {tf._implements = "fused_tfl_embedding_lookup", tf._reference = "mlir"} +// CHECK: [[VAL_2:%.*]] = "tfl.embedding_lookup"([[VAL_1]], [[VAL_0]]) : (tensor, tensor) -> tensor +// CHECK: return [[VAL_2]] : tensor \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index a3e7c01ca91..bf695e130d0 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -35,6 +35,27 @@ func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform // CHECK: return %6 } +// CHECK-LABEL: QuantizeFullyConnected +func @QuantizeFullyConnected(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { +^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): + %cst = constant dense<-1.23697901> : tensor<32xf32> + %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32> + %3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>> + %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32> + %5 = "tfl.fully_connected"(%2, %4, %cst) {fused_activation_function = 
"NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> + %6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> + return %6 : tensor<1x112x112x32x!quant.uniform> + +// CHECK: %cst = constant dense<-1.23697901> : tensor<32xf32> +// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<32x!quant.uniform>} +// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>) +// CHECK: %2 = "tfl.dequantize"(%arg0) +// CHECK: %3 = "tfl.pseudo_qconst"() +// CHECK: %4 = "tfl.dequantize"(%3) +// CHECK: %5 = "tfl.fully_connected"(%2, %4, %1) +// CHECK: %6 = "tfl.quantize"(%5) +// CHECK: return %6 +} // CHECK-LABEL: QuantizeDepthwiseConv2D func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { @@ -74,6 +95,80 @@ func @QuantizeAveragePool2D(tensor<1x6x6x16x!quant.uniform } +// CHECK-LABEL: QuantizeMaximum +func @QuantizeMaximum(tensor<1x6x6x16x!quant.uniform>, tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { +^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>, %arg1: tensor<1x6x6x16x!quant.uniform>): + %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> + %1 = "tfl.dequantize"(%arg1) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> + %2 = "tfl.maximum"(%0, %1) : (tensor<1x6x6x16xf32>, tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> + return %2 : tensor<1x6x6x16xf32> + +// CHECK: %0 = "tfl.dequantize"(%arg0) +// CHECK: %1 = "tfl.dequantize"(%arg1) +// CHECK: %2 = "tfl.maximum"(%0, %1) +// CHECK: %3 = "tfl.quantize"(%2) +// CHECK: %4 = "tfl.dequantize"(%3) +// CHECK: return %4 : tensor<1x6x6x16xf32> +} + +// CHECK-LABEL: QuantizeMinimum +func @QuantizeMinimum(tensor<1x6x6x16x!quant.uniform>, tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { +^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>, %arg1: tensor<1x6x6x16x!quant.uniform>): + %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> + %1 = "tfl.dequantize"(%arg1) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> + %2 = "tfl.minimum"(%0, %1) : (tensor<1x6x6x16xf32>, tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> + return %2 : tensor<1x6x6x16xf32> + +// CHECK: %0 = "tfl.dequantize"(%arg0) +// CHECK: %1 = "tfl.dequantize"(%arg1) +// CHECK: %2 = "tfl.minimum"(%0, %1) +// CHECK: %3 = "tfl.quantize"(%2) +// CHECK: %4 = "tfl.dequantize"(%3) +// CHECK: return %4 : tensor<1x6x6x16xf32> +} + +// CHECK-LABEL: QuantizeSlice +func @QuantizeSlice(tensor<2x3x5x!quant.uniform>, tensor<3xi32>, tensor<3xi32>) -> tensor { +^bb0(%arg0: tensor<2x3x5x!quant.uniform>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>): + %0 = "tfl.dequantize"(%arg0) : (tensor<2x3x5x!quant.uniform>) -> tensor<2x3x5xf32> + %1 = "tfl.slice"(%0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor + return %1 : tensor + +// CHECK: %0 = "tfl.dequantize"(%arg0) +// CHECK: %1 = "tfl.slice"(%0, %arg1, %arg2) +// CHECK: %2 = "tfl.quantize"(%1) +// CHECK: %3 = "tfl.dequantize"(%2) +// CHECK: return %3 : tensor +} + +// CHECK-LABEL: QuantizeStridedSlice +func @QuantizeStridedSlice(tensor<12x2x2x5x!quant.uniform>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> { +^bb0(%arg0: tensor<12x2x2x5x!quant.uniform>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>): + %0 = "tfl.dequantize"(%arg0) : (tensor<12x2x2x5x!quant.uniform>) -> 
tensor<12x2x2x5xf32> + %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> + return %1 : tensor<1x2x2x5xf32> + +// CHECK: %0 = "tfl.dequantize"(%arg0) +// CHECK: %1 = "tfl.strided_slice"(%0, %arg1, %arg2, %arg3) +// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x2x2x5x!quant.uniform>} +// CHECK: %3 = "tfl.dequantize"(%2) +// CHECK: return %3 : tensor<1x2x2x5xf32> +} + +// CHECK-LABEL: QuantizePad +func @QuantizePad(tensor<2x1x3x!quant.uniform>, tensor<3x2xi32>) -> tensor { +^bb0(%arg0: tensor<2x1x3x!quant.uniform>, %arg1: tensor<3x2xi32>): + %0 = "tfl.dequantize"(%arg0) : (tensor<2x1x3x!quant.uniform>) -> tensor<2x1x3xf32> + %1 = "tfl.pad"(%0, %arg1) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor + return %1 : tensor + +// CHECK: %0 = "tfl.dequantize"(%arg0) +// CHECK: %1 = "tfl.pad"(%0, %arg1) +// CHECK: %2 = "tfl.quantize"(%1) +// CHECK: %3 = "tfl.dequantize"(%2) +// CHECK: return %3 : tensor +} + // CHECK-LABEL: QuantizeReshape2D func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x36x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -102,6 +197,31 @@ func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform>) // CHECK: return %3 : tensor<1x6x6x16xf32> } +// CHECK-LABEL: QuantizeLogistic +func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { +^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): + %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> + %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> + return %1 : tensor<1x6x6x16xf32> + +// CHECK: %0 = "tfl.dequantize"(%arg0) +// CHECK: %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> +// CHECK: %2 = "tfl.quantize"(%1) {qtype = tensor<1x6x6x16x!quant.uniform>} +// CHECK: %3 = "tfl.dequantize"(%2) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> +// CHECK: return %3 : tensor<1x6x6x16xf32> +} + +// CHECK-LABEL: NotQuantizeConcatConstantOperand +func @NotQuantizeConcatConstantOperand(%arg0: tensor<2xf32>) -> tensor<2x2xf32> { + %0 = constant dense<1.0> : tensor<2xf32> + %1 = "tfl.concatenation"(%arg0, %0) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> + +// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: %[[cc:.*]] = "tfl.concatenation"(%arg0, %[[cst]]) +// CHECK-NEXT: return %[[cc]] +} + // CHECK-LABEL: QuantizeConcatOperand0ToAll func @QuantizeConcatOperand0ToAll(tensor<2x!quant.uniform>, tensor<2xf32>) -> tensor<2x2xf32> { ^bb0(%arg0: tensor<2x!quant.uniform>, %arg1: tensor<2xf32>): @@ -248,7 +368,35 @@ func @QuantizeConstant() -> tensor<2x3xf32> { return %cst : tensor<2x3xf32> // CHECK: %cst = constant dense{{.*}}tensor<2x3xf32> -// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform>} +// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<2x3x!quant.uniform:f32, 0.023622047244094488:128>>} // CHECK: %1 = "tfl.dequantize"(%0) // CHECK: return %1 : tensor<2x3xf32> -} \ No newline at end of file +} + +// CHECK-LABEL: QuantizeSharedBiases +func @QuantizeSharedBiases( + %arg0: tensor<1x224x224x3x!quant.uniform>, + %arg1: tensor<32x3x3x3x!quant.uniform:f32, 1.0>>, + %arg2: tensor<32x3x3x3x!quant.uniform:f32, 2.0>>) -> (tensor<1x56x56x32x!quant.uniform>) { + 
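+  // The bias constant below is shared by two convolutions whose weights have
+  // different scales (1.0 and 2.0), so the prepare-quantize pass is expected
+  // to duplicate it, letting each use pick up its own quantization parameters;
+  // the two separate constant/quantize/dequantize chains in the CHECK lines
+  // verify this.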
%cst = constant dense<1.0> : tensor<32xf32> + %1 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32> + %2 = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform:f32, 1.0>>) -> tensor<32x3x3x3xf32> + %conv1 = "tfl.conv_2d"(%1, %2, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> + %3 = "tfl.quantize"(%conv1) {qtype = tensor<1x112x112x32xf32>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> + + %4 = "tfl.dequantize"(%3) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x112x112x32xf32> + %5 = "tfl.dequantize"(%arg2) : (tensor<32x3x3x3x!quant.uniform:f32, 2.0>>) -> tensor<32x3x3x3xf32> + %conv2 = "tfl.conv_2d"(%4, %5, %cst) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x32xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x56x56x32xf32> + %6 = "tfl.quantize"(%conv2) {qtype = tensor<1x56x56x32x!quant.uniform>} : (tensor<1x56x56x32xf32>) -> tensor<1x56x56x32x!quant.uniform> + + return %6 : tensor<1x56x56x32x!quant.uniform> + +// CHECK: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32> +// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cst]]) +// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) +// CHECK: %[[cst_0:.*]] = constant dense<1.000000e+00> : tensor<32xf32> +// CHECK: %[[q_0:.*]] = "tfl.quantize"(%[[cst_0]]) +// CHECK: %[[dq_0:.*]] = "tfl.dequantize"(%[[q_0]]) +// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]]) +// CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq_0]]) +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 0edb4f40cdc..ad11764851c 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -63,102 +63,239 @@ func @fusedBatchNorm(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8 return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32> // CHECK-LABEL: fusedBatchNorm -// CHECK:%cst = constant dense<1.000000e-03> : tensor +// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03> // variance + epsilon -// CHECK: %0 = "tf.Add"(%arg4, %cst) : (tensor<8xf32>, tensor) -> tensor<8xf32> +// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]]) // rsqrt(variance + epsilon) -// CHECK: %1 = "tf.Rsqrt"(%0) : (tensor<8xf32>) -> tensor<8xf32> +// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]]) // scale * rsqrt(variance + epsilon) -// CHECK: %2 = "tf.Mul"(%arg1, %1) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32> +// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]]) // x * scale * rsqrt(variance + epsilon) -// CHECK: %3 = "tf.Mul"(%arg0, %2) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> +// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]]) // mean * scale * rsqrt(variance + epsilon) -// CHECK: %4 = "tf.Mul"(%arg3, %2) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32> +// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]]) // offset - mean * scale * rsqrt(variance + epsilon) -// CHECK: %5 = "tf.Sub"(%arg2, %4) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32> +// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]]) // x * scale * rsqrt(variance + epsilon) + // offset - mean * scale * rsqrt(variance + 
epsilon) -// CHECK: %6 = "tf.Add"(%3, %5) : (tensor<8x8x8x8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> +// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]]) -// CHECK: %7:5 = "tf.FusedBatchNorm"(%6, %arg1, %arg2, %arg3, %arg4) -// CHECK: %8:5 = "tf.FusedBatchNorm"(%7#0, %arg1, %arg2, %arg3, %arg4) +// CHECK: %[[BATCHNORM1:.*]]:5 = "tf.FusedBatchNorm"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) +// CHECK: {{.*}} = "tf.FusedBatchNorm"(%[[BATCHNORM1]]#0, %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) } -func @fakeQuantNotFollowedByQuant(tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>) { -^bb0(%arg0: tensor<8x8x8x8xf32>): - %arg1 = constant dense<-0.1> : tensor - %arg2 = constant dense<0.2> : tensor - %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor, tensor) -> tensor<8x8x8x8xf32> - return %0 : tensor<8x8x8x8xf32> +func @fusedBatchNormV3(tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>) { +^bb0(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>): + // OK + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // Unsupported training + %1:6 = "tf.FusedBatchNormV3"( %0#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // Use other output + %2:6 = "tf.FusedBatchNormV3"( %1#0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", U = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + + return %2, %2#1 : tensor<8x8x8x8xf32>, tensor<8xf32> + +// CHECK-LABEL: fusedBatchNormV3 +// CHECK: %[[CONSTANT:.*]] = constant dense<1.000000e-03> +// variance + epsilon +// CHECK: %[[ADD1:.*]] = "tf.Add"(%[[ARG4:.*]], %[[CONSTANT]]) +// rsqrt(variance + epsilon) +// CHECK: %[[RSQRT:.*]] = "tf.Rsqrt"(%[[ADD1]]) +// scale * rsqrt(variance + epsilon) +// CHECK: %[[MUL1:.*]] = "tf.Mul"(%[[ARG1:.*]], %[[RSQRT]]) +// x * scale * rsqrt(variance + epsilon) +// CHECK: %[[MUL2:.*]] = "tf.Mul"(%[[ARG0:.*]], %[[MUL1]]) +// mean * scale * rsqrt(variance + epsilon) +// CHECK: %[[MUL3:.*]] = "tf.Mul"(%[[ARG3:.*]], %[[MUL1]]) +// offset - mean * scale * rsqrt(variance + epsilon) +// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[ARG2:.*]], %[[MUL3]]) +// x * scale * rsqrt(variance + epsilon) + +// offset - mean * scale * rsqrt(variance + epsilon) +// CHECK: %[[ADD2:.*]] = "tf.Add"(%[[MUL2]], %[[SUB]]) + +// CHECK: %[[BATCHNORM1:.*]]:6 = "tf.FusedBatchNormV3"(%[[ADD2]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) +// CHECK: %[[BATCHNORM2:.*]]:6 = "tf.FusedBatchNormV3"(%[[BATCHNORM1]]#0, %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) +} + +// CHECK-LABEL: fakeQuantForActivation +func @fakeQuantForActivation(tensor<8xf32>) -> 
(tensor<8xf32>) { +^bb0(%arg0: tensor<8xf32>): + %arg1 = constant dense<0.0> : tensor + %arg2 = constant dense<255.0> : tensor + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor, tensor) -> tensor<8xf32> + return %0 : tensor<8xf32> + +// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) +// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform>} +// CHECK: %2 = "tfl.dequantize"(%1) +// CHECK: return %2 +} + +// CHECK-LABEL: fakeQuantForActivationNoDuplication +func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quant.uniform>) { +^bb0(%arg0: tensor<8xf32>): + %arg1 = constant dense<0.0> : tensor + %arg2 = constant dense<255.0> : tensor + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor, tensor) -> tensor<8xf32> + %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform>} : (tensor<8xf32>) -> tensor<8x!quant.uniform> + return %1 : tensor<8x!quant.uniform> -// CHECK-LABEL: fakeQuantNotFollowedByQuant // CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64} -// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform>} -// CHECK: %2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform>) -// CHECK: return %2 : tensor<8x8x8x8xf32> +// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x!quant.uniform>} +// CHECK: return %1 } -func @fakeQuantFollowedByQuant(tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>) { -^bb0(%arg0: tensor<8x8x8x8xf32>): - %arg1 = constant dense<-0.1> : tensor - %arg2 = constant dense<0.2> : tensor - %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor, tensor) -> tensor<8x8x8x8xf32> - %1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform>} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform> - %2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform>) -> tensor<8x8x8x8xf32> - return %2 : tensor<8x8x8x8xf32> +// CHECK-LABEL: fakeQuantFolded +func @fakeQuantFolded() -> (tensor<8xf32>) { + %in = constant dense<0.0> : tensor<8xf32> + %min = constant dense<0.0> : tensor + %max = constant dense<255.0> : tensor + %mini = "tf.Identity"(%min) : (tensor) -> tensor + %maxi = "tf.Identity"(%max) : (tensor) -> tensor + %rst = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor, tensor) -> tensor<8xf32> + return %rst : tensor<8xf32> -// CHECK-LABEL: fakeQuantFollowedByQuant -// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) {narrow_range = false, num_bits = 3 : i64} -// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<8x8x8x8x!quant.uniform>} -// CHECK: %2 = "tfl.dequantize"(%1) : (tensor<8x8x8x8x!quant.uniform>) -// CHECK: return %2 : tensor<8x8x8x8xf32> +// CHECK: %cst = constant dense<0.000000e+00> : tensor<8xf32> +// CHECK: %0 = "tfl.quantize"(%cst) {qtype = tensor<8x!quant.uniform>} +// CHECK: %1 = "tfl.dequantize"(%0) +// CHECK: return %1 : tensor<8xf32> } -func @fakeQuantVarsNotConst(tensor<8x8x8x8xf32>, tensor, tensor) -> (tensor<8x8x8x8xf32>) { -^bb0(%arg0: tensor<8x8x8x8xf32>, %arg3: tensor, %arg4: tensor): - %1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8x8x8x8xf32>, tensor, tensor) -> tensor<8x8x8x8xf32> - return %1 : tensor<8x8x8x8xf32> +// CHECK-LABEL: fakeQuantNotFolded +func @fakeQuantNotFolded(tensor<8xf32>, tensor, tensor) -> 
(tensor<8xf32>) { +^bb0(%arg0: tensor<8xf32>, %arg3: tensor, %arg4: tensor): + %1 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg3, %arg4) {num_bits = 3, narrow_range = false} : (tensor<8xf32>, tensor, tensor) -> tensor<8xf32> + return %1 : tensor<8xf32> -// CHECK-LABEL: fakeQuantVarsNotConst -// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64} -// CHECK: return %0 : tensor<8x8x8x8xf32> +// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) +// CHECK: return %0 : tensor<8xf32> } -func @fakeQuantFollowedByTranspose(tensor<3x3x3x16xf32>, tensor, tensor) -> (tensor<16x3x3x3xf32>) { -^bb0(%arg0: tensor<3x3x3x16xf32>, %arg1: tensor, %arg2: tensor): - %cst_0 = constant dense<[3, 0, 1, 2]> : tensor<4xi32> - %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor, tensor) -> tensor<3x3x3x16xf32> - %1 = "tf.Transpose"(%0, %cst_0): (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32> - return %1 : tensor<16x3x3x3xf32> - // CHECK-LABEL: fakeQuantFollowedByTranspose -// CHECK: %cst = constant dense<[3, 0, 1, 2]> : tensor<4xi32> -// CHECK: %0 = "tf.Transpose"(%arg0, %cst) : (tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor<16x3x3x3xf32> -// CHECK: %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64} -// CHECK: return %1 : tensor<16x3x3x3xf32> -} +func @fakeQuantFollowedByTranspose(tensor<1x2xf32>, tensor, tensor) -> (tensor<2x1xf32>) { +^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor, %arg2: tensor): + %cst_0 = constant dense<[1, 0]> : tensor<2xi32> + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<1x2xf32>, tensor, tensor) -> tensor<1x2xf32> + %1 = "tf.Transpose"(%0, %cst_0): (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x1xf32> + return %1 : tensor<2x1xf32> -func @fakeQuantFollowedByReshape(tensor<3x3x3x4xf32>, tensor, tensor) -> (tensor<1x3x3x12xf32>) { -^bb0(%arg0: tensor<3x3x3x4xf32>, %arg1: tensor, %arg2: tensor): - %cst_0 = constant dense<[1, 3, 3, 12]> : tensor<4xi64> - %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x4xf32>, tensor, tensor) -> tensor<3x3x3x4xf32> - %1 = "tf.Reshape"(%0, %cst_0) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32> - return %1 : tensor<1x3x3x12xf32> +// CHECK: %cst = constant +// CHECK: %0 = "tf.Transpose"(%arg0, %cst) +// CHECK: %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, %arg2) +// CHECK: return %1 +} // CHECK-LABEL: fakeQuantFollowedByReshape -// CHECK: %cst = constant dense<[1, 3, 3, 12]> : tensor<4xi64> -// CHECK: %0 = "tf.Reshape"(%arg0, %cst) : (tensor<3x3x3x4xf32>, tensor<4xi64>) -> tensor<1x3x3x12xf32> -// CHECK: %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, %arg2) {narrow_range = false, num_bits = 3 : i64} -// CHECK: return %1 : tensor<1x3x3x12xf32> +func @fakeQuantFollowedByReshape(tensor<1x2xf32>, tensor, tensor) -> (tensor<2x1xf32>) { +^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor, %arg2: tensor): + %cst_0 = constant dense<[2, -1]> : tensor<2xi64> + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %arg1, %arg2) {num_bits = 3, narrow_range = false} : (tensor<1x2xf32>, tensor, tensor) -> tensor<1x2xf32> + %1 = "tf.Reshape"(%0, %cst_0) : (tensor<1x2xf32>, tensor<2xi64>) -> tensor<2x1xf32> + return %1 : tensor<2x1xf32> + +// CHECK: %cst = constant +// CHECK: %0 = "tf.Reshape"(%arg0, %cst) +// CHECK-SAME: tensor<2x1xf32> +// CHECK: %1 = "tf.FakeQuantWithMinMaxVars"(%0, %arg1, 
%arg2) +// CHECK: return %1 } -func @identity(tensor<10xi32>) -> tensor<10xi32> { -^bb0(%arg0: tensor<10xi32>): +// CHECK-LABEL: QDQsFollowedByTranspose +func @QDQsFollowedByTranspose(tensor<1x2xf32>) -> (tensor<2x1xf32>) { +^bb0(%arg0: tensor<1x2xf32>): + %cst_0 = constant dense<[1, 0]> : tensor<2xi32> + %0 = "tfl.quantize"(%arg0){qtype = tensor<1x2x!quant.uniform>}: (tensor<1x2xf32>) -> (tensor<1x2x!quant.uniform>) + %1 = "tfl.dequantize"(%0): (tensor<1x2x!quant.uniform>) -> (tensor<1x2xf32>) + %2 = "tf.Transpose"(%1, %cst_0): (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x1xf32> + return %2 : tensor<2x1xf32> + +// CHECK: %cst = constant +// CHECK: %0 = "tf.Transpose" +// CHECK-SAME: -> tensor<2x1xf32> +// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<2x1x!quant.uniform>} +// CHECK-SAME: -> tensor<2x1x!quant.uniform> +// CHECK: %2 = "tfl.dequantize"(%1) +// CHECK-SAME: -> tensor<2x1xf32> +// CHECK: return %2 +} + +// CHECK-LABEL: QDQFollowedByReshape +func @QDQFollowedByReshape(tensor<1x2xf32>) -> (tensor<2x1xf32>) { +^bb0(%arg0: tensor<1x2xf32>): + %cst_0 = constant dense<[2, 1]> : tensor<2xi32> + %0 = "tfl.quantize"(%arg0){qtype = tensor<1x2x!quant.uniform>}: (tensor<1x2xf32>) -> (tensor<1x2x!quant.uniform>) + %1 = "tfl.dequantize"(%0): (tensor<1x2x!quant.uniform>) -> (tensor<1x2xf32>) + %2 = "tf.Reshape"(%1, %cst_0): (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x1xf32> + return %2 : tensor<2x1xf32> + +// CHECK: %cst = constant +// CHECK: %0 = "tf.Reshape" +// CHECK-SAME: -> tensor<2x1xf32> +// CHECK: %1 = "tfl.quantize"(%0) {qtype = tensor<2x1x!quant.uniform>} +// CHECK-SAME: -> tensor<2x1x!quant.uniform> +// CHECK: %2 = "tfl.dequantize"(%1) +// CHECK-SAME: -> tensor<2x1xf32> +// CHECK: return %2 +} + +// CHECK-LABEL: QDQFollowedByRank +func @QDQFollowedByRank(%arg0: tensor<1x2xf32>) -> (tensor) { + %0 = "tfl.quantize"(%arg0){qtype = tensor<1x2x!quant.uniform>}: (tensor<1x2xf32>) -> (tensor<1x2x!quant.uniform>) + %1 = "tfl.dequantize"(%0): (tensor<1x2x!quant.uniform>) -> (tensor<1x2xf32>) + %2 = "tf.Rank"(%1): (tensor<1x2xf32>) -> tensor + return %2 : tensor + +// CHECK-NEXT: %[[R:.*]] = "tf.Rank"(%arg0) +// CHECK-NEXT: return %[[R]] : tensor +} + +// CHECK-LABEL: fakeQuantWithConv2D +func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) { +^bb0(%arg: tensor<256x32x32x3xf32>) : + %in = constant dense<0.0> : tensor<3x3x3x16xf32> + %min = constant dense<0.0> : tensor + %max = constant dense<255.0> : tensor + %mini = "tf.Identity"(%min) : (tensor) -> tensor + %maxi = "tf.Identity"(%max) : (tensor) -> tensor + %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor, tensor) -> tensor<3x3x3x16xf32> + %rst = "tf.Conv2D"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> + return %rst : tensor<256x30x30x16xf32> + +// CHECK: %cst = constant dense<0.000000e+00> : tensor<16xf32> +// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<16x3x3x3xf32> +// CHECK: %0 = "tfl.quantize"(%cst_0) {qtype = tensor<16x3x3x3x!quant.uniform>} +// CHECK: %1 = "tfl.dequantize"(%0) +// CHECK: %2 = "tfl.conv_2d"(%arg0, %1, %cst) +// CHECK: return %2 +} + +// CHECK-LABEL: fakeQuantWithDepthwiseConv2D +func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) { +^bb0(%arg: tensor<256x32x32x3xf32>) : + %in = constant dense<0.0> : 
tensor<3x3x3x16xf32> + %min = constant dense<0.0> : tensor + %max = constant dense<255.0> : tensor + %mini = "tf.Identity"(%min) : (tensor) -> tensor + %maxi = "tf.Identity"(%max) : (tensor) -> tensor + %fq = "tf.FakeQuantWithMinMaxVars"(%in, %mini, %maxi) {num_bits = 3, narrow_range = false} : (tensor<3x3x3x16xf32>, tensor, tensor) -> tensor<3x3x3x16xf32> + %rst = "tf.DepthwiseConv2dNative"(%arg, %fq) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> + return %rst : tensor<256x30x30x16xf32> + +// CHECK: %cst = constant dense<0.000000e+00> : tensor<48xf32> +// CHECK: %cst_0 = constant dense<0.000000e+00> : tensor<1x3x3x48xf32> +// CHECK: %0 = "tfl.quantize"(%cst_0) {qtype = tensor<1x3x3x48x!quant.uniform>} +// CHECK: %1 = "tfl.dequantize"(%0) +// CHECK: %2 = "tfl.depthwise_conv_2d"(%arg0, %1, %cst) +// CHECK: return %2 +} + +func @identity(%arg0: tensor<10xi32>, %arg1: tensor<20xi32>, %arg2: tensor<30xi32>) -> (tensor<10xi32>, tensor<20xi32>, tensor<30xi32>) { %0 = "tf.Identity"(%arg0) : (tensor<10xi32>) -> tensor<10xi32> - return %0: tensor<10xi32> + %1:2 = "tf.IdentityN"(%arg1,%arg2) : (tensor<20xi32>, tensor<30xi32>) -> (tensor<20xi32>, tensor<30xi32>) + return %0, %1#0, %1#1: tensor<10xi32>, tensor<20xi32>, tensor<30xi32> // CHECK-LABEL: identity -// CHECK: return %arg0 +// CHECK: return %arg0, %arg1, %arg2 } @@ -195,3 +332,19 @@ func @matmulNoTransposeB(%arg0: tensor<1x1280xf32>, %arg1: tensor<1280x1000xf32> // CHECK: %7 = "tf.Transpose"(%arg1, %6) : (tensor<1280x1000xf32>, tensor) -> tensor<*xf32> // CHECK: %8 = "tf.MatMul"(%3, %7) {transpose_a = false, transpose_b = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<1x1000xf32> } + +func @snapshot(%arg0: tensor<3xi32>) -> tensor<3xi32> { + %0 = "tf.Snapshot"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> + // Should be converted to Identity and then from Identity to value + // CHECK-LABEL: snapshot + // CHECK: return %arg0 : tensor<3xi32> +} + +func @stop_gradient(%arg0: tensor<3xi32>) -> tensor<3xi32> { + %0 = "tf.StopGradient"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> + // Should be converted to Identity and then from Identity to value + // CHECK-LABEL: stop_gradient + // CHECK: return %arg0 : tensor<3xi32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir index b3b439b2b8a..dc24b1004d7 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir @@ -82,6 +82,35 @@ func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { +^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): + %cst = constant dense<-1.23697901> : tensor<32xf32> + %2 = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform>) -> tensor<1x224x224x3xf32> + %3 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} : () -> tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>> + %4 = "tfl.dequantize"(%3) : (tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>) -> tensor<32x3x3x3xf32> + %5 = "tfl.fully_connected"(%2, %4, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> + 
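+  // Note on the expected rewrite (a sketch of the usual TFLite convention, not
+  // taken verbatim from this patch): the float bias becomes a
+  // "tfl.pseudo_qconst" with i32 storage whose scale is roughly the product of
+  // the input scale and the weight scale (0.021826678373682216 above); with
+  // the scales used in this test, -1.23697901 quantizes to the dense<-7254>
+  // value in the CHECK lines below.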
%6 = "tfl.quantize"(%5) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> + return %6 : tensor<1x112x112x32x!quant.uniform> + +// CHECK: %0 = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<-7254> : tensor<32xi32>} +// CHECK: %1 = "tfl.pseudo_qconst"() {qtype = tensor<32x3x3x3x!quant.uniform:f32, 0.021826678373682216:151>>, value = dense<-76> : tensor<32x3x3x3xi8>} +// CHECK: %2 = "tfl.fully_connected"(%arg0, %1, %0) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} +// CHECK: return %2 +} + +// CHECK-LABEL: QuantizeNoBiasFullyConnected +func @QuantizeNoBiasFullyConnected(%arg0: tensor<3x!quant.uniform>, %arg1: tensor<3x!quant.uniform:f32, 1.0>>, %arg2: none) -> tensor<3x!quant.uniform> { + %0 = "tfl.dequantize"(%arg0) : (tensor<3x!quant.uniform>) -> tensor<3xf32> + %1 = "tfl.dequantize"(%arg1) : (tensor<3x!quant.uniform:f32, 1.0>>) -> tensor<3xf32> + %2 = "tfl.fully_connected"(%0, %1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<3xf32>, tensor<3xf32>, none) -> tensor<3xf32> + %3 = "tfl.quantize"(%2) {qtype = tensor<3x!quant.uniform>} : (tensor<3xf32>) -> tensor<3x!quant.uniform> + return %3 : tensor<3x!quant.uniform> + +// CHECK-NEXT: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %arg2) +// CHECK-NEXT: return %[[fc]] +} + // CHECK-LABEL: QuantizeAveragePool2D func @QuantizeAveragePool2D(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x1x1x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -118,6 +147,18 @@ func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform>) // CHECK: return %1 : tensor<1x6x6x16xf32> } +// CHECK-LABEL: QuantizeLogistic +func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { +^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): + %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> + %1 = "tfl.logistic"(%0) : (tensor<1x6x6x16xf32>) -> tensor<1x6x6x16xf32> + return %1 : tensor<1x6x6x16xf32> + +// CHECK: %0 = "tfl.logistic"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) +// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x6x6x16x!quant.uniform>) +// CHECK: return %1 +} + // CHECK-LABEL: QuantizeAdd func @QuantizeAdd(tensor<1x56x56x24x!quant.uniform>, tensor<1x56x56x24x!quant.uniform>) -> tensor<1x56x56x24x!quant.uniform> { ^bb0(%arg0: tensor<1x56x56x24x!quant.uniform>, %arg1: tensor<1x56x56x24x!quant.uniform>): @@ -167,4 +208,16 @@ func @QuantizeMaxPool2D(tensor<1x6x6x16x!quant.uniform // CHECK: %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x1x1x16x!quant.uniform> // CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x1x1x16x!quant.uniform>) -> tensor<1x1x1x16xf32> // CHECK: return %1 : tensor<1x1x1x16xf32> +} + +// CHECK-LABEL: QuantizeSplit +func @QuantizeSplit(%arg: tensor<4x!quant.uniform>, %cst: tensor) -> (tensor<2x!quant.uniform>,tensor<2x!quant.uniform>) { + %0 = "tfl.dequantize"(%arg) : (tensor<4x!quant.uniform>) -> tensor<4xf32> + %1:2 = "tfl.split"(%cst, %0) {num_splits = 2 : i32} : (tensor, tensor<4xf32>) -> (tensor<2xf32>, tensor<2xf32>) + %2 = "tfl.quantize"(%1#0) {qtype = tensor<2x!quant.uniform>} : (tensor<2xf32>) -> tensor<2x!quant.uniform> + %3 = "tfl.quantize"(%1#1) {qtype = tensor<2x!quant.uniform>} : (tensor<2xf32>) -> tensor<2x!quant.uniform> + 
return %2, %3 : tensor<2x!quant.uniform>, tensor<2x!quant.uniform> + +// CHECK: %0:2 = "tfl.split"(%arg1, %arg0) {num_splits = 2 : i32} : (tensor, tensor<4x!quant.uniform>) +// CHECK: return %0#0, %0#1 } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir new file mode 100644 index 00000000000..95844ccad1c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir @@ -0,0 +1,21 @@ +// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-whitelist="bar,foobar" %s | FileCheck %s --dump-input-on-failure + +func @foo(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> { + %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x4xf32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +func @bar(%arg0: tensor<2x4xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x4xf32> { + %0 = "tfl.pseudo_input"(%arg0) : (tensor<2x4xf32>) -> tensor<2x4xf32> + return %0 : tensor<2x4xf32> +} + +func @foobar(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> { + %0 = "tfl.pseudo_input"(%arg0) : (tensor<1x4xf32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// CHECK-DAG: func @main +// CHECK-DAG: func @foobar +// CHECK-NOT: func @foo +// CHECK-NOT: func @bar \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc new file mode 100644 index 00000000000..25d15614ef6 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" + +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Transforms/Passes.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" + +namespace mlir { +/// Create a pass to convert from the TFExecutor to the TF control dialect. 
+std::unique_ptr CreateTFExecutorToControlDialectConversion(); +} // namespace mlir + +namespace tensorflow { + +bool ShouldRunQuantizePasses(mlir::ModuleOp m) { + if (mlir::FuncOp main_fn = m.lookupSymbol("main")) { + return main_fn.getAttrOfType("tf.quantize") != + mlir::Attribute(); + } + return false; +} + +void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, + mlir::PassManager* pass_manager) { + pass_manager->addPass(mlir::CreateTFExecutorToControlDialectConversion()); + pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass()); + // Ophint extraction will happen after island extraction pass. + pass_manager->addPass(mlir::TFL::CreateExtractOphintPass()); + // Convert composite op pass will happen after ophint extraction pass. + pass_manager->addPass(mlir::TFL::CreateLegalizeOphintFuncOpPass()); + + if (pass_config.lower_tensor_list_ops) { + // Execute this pass before `CanonicalizerPass` in case some TensorList + // ops are constant folded into variant types. + // TODO(b/137125056): Move this pass after `CanonicalizerPass` after we + // handle constant ops that produce `TensorList`. + // TODO(haoliang): Add this pass by default. + pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass()); + } + + // TODO(jpienaar): Revise post dialect constants. + pass_manager->addPass(mlir::TF::CreateDecodeConstantPass()); + // Canonicalization includes const folding, which is utilized here to optimize + // away ops that can't get constant folded after PrepareTF pass. For example, + // tf.Conv2D is split into tf.Transpose and tfl.Conv2D. + pass_manager->addPass(mlir::createCanonicalizerPass()); + + // The below passes only make sense if Builtin TFLite ops are enabled + // for emission. + if (pass_config.emit_builtin_tflite_ops) { + // Prepare for TFLite dialect, rerun canonicalization, and then legalize to + // the TFLite dialect. + pass_manager->addPass(mlir::TFL::CreatePrepareTFPass()); + pass_manager->addPass(mlir::createCanonicalizerPass()); + pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass()); + pass_manager->addPass(mlir::TFL::CreateOptimizePass()); + if (pass_config.run_quantize) { + pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass( + /*quantize_sign=*/false)); + pass_manager->addPass(mlir::TFL::CreateQuantizePass()); + pass_manager->addPass(mlir::TFL::CreatePostQuantizePass( + pass_config.emit_quant_adaptor_ops)); + } + pass_manager->addPass(mlir::createCanonicalizerPass()); + pass_manager->addPass(mlir::createCSEPass()); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.h b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h new file mode 100644 index 00000000000..653e4ec5245 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.h @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ + +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/Pass/PassManager.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" + +namespace tensorflow { + +// Quantization passess will run only when the user specifies a quantized type +// in the `-tf-inference-type` flag, which is converted to the function +// attribute "tf.quantize" by the importer module. +// TODO(fengliuai): switch to the cmd flag once the flags are moved to this +// file with main method. +bool ShouldRunQuantizePasses(mlir::ModuleOp m); + +// Add the TF to TFLite passes, specified in the pass_config, into a +// pass_manager. +void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, + mlir::PassManager* pass_manager); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TFL_PASSES_H_ diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 9656abb1611..33044a63271 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -16,7 +16,6 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/InitLLVM.h" -#include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/IR/Diagnostics.h" // TF:local_config_mlir @@ -24,7 +23,10 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/Support/FileUtilities.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_translate.h" +#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" @@ -37,8 +39,8 @@ using mlir::FuncOp; using mlir::MLIRContext; using mlir::ModuleOp; using stream_executor::port::StatusOr; -using tensorflow::Status; +// Debugging flag to print function mapping in the flatbuffer. // NOLINTNEXTLINE static llvm::cl::opt print_function_result_mapping( "print-function-result-mapping", @@ -99,9 +101,8 @@ static int PrintFunctionResultMapping(const std::string &result, } int main(int argc, char **argv) { - llvm::PrettyStackTraceProgram x(argc, argv); // TODO(jpienaar): Revise the command line option parsing here. - llvm::InitLLVM y(argc, argv); + tensorflow::InitMlir y(&argc, &argv); // TODO(antiagainst): We are pulling in multiple transformations as follows. // Each transformation has its own set of command-line options; options of one @@ -112,14 +113,9 @@ int main(int argc, char **argv) { // We need to disable duplicated ones to provide a cleaner command-line option // interface. That also means we need to relay the value set in one option to // all its aliases. - llvm::cl::ParseCommandLineOptions( argc, argv, "TF GraphDef to TFLite FlatBuffer converter\n"); - // TODO(ashwinm): Enable command line parsing for both sides. 
- int fake_argc = 1; - tensorflow::port::InitMain(argv[0], &fake_argc, &argv); - MLIRContext context; llvm::SourceMgr source_mgr; mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context); @@ -135,11 +131,22 @@ int main(int argc, char **argv) { // message. So we can just return here. if (!module.ok()) return kTrFailure; + mlir::PassManager pm; + bool run_quantize = + tensorflow::ShouldRunQuantizePasses(module.ValueOrDie().get()); + mlir::TFL::PassConfig pass_config; + pass_config.emit_builtin_tflite_ops = emit_builtin_tflite_ops; + pass_config.emit_quant_adaptor_ops = emit_quant_adaptor_ops; + pass_config.lower_tensor_list_ops = lower_tensor_list_ops; + pass_config.run_quantize = run_quantize; + + tensorflow::AddTFToTFLConversionPasses(pass_config, &pm); + std::string result; - auto status = tensorflow::ConvertTFControlFlowToTFLOrFlatbuffer( + auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer( module.ValueOrDie().get(), output_mlir, emit_builtin_tflite_ops, emit_select_tf_ops, emit_custom_ops, emit_quant_adaptor_ops, - lower_tensor_list_ops, &result); + lower_tensor_list_ops, &result, &pm); if (!status.ok()) return kTrFailure; auto output = mlir::openOutputFile(output_file_name); diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index bc2f36beb4d..fed9f1739ad 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -31,6 +31,11 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/errors.h" +namespace mlir { +/// Create a pass to convert from the TFExecutor to the TF control dialect. +std::unique_ptr CreateTFExecutorToControlDialectConversion(); +} // namespace mlir + namespace tensorflow { using mlir::MLIRContext; @@ -79,79 +84,23 @@ StatusOr LoadFromGraphdefOrMlirSource( return tensorflow::GraphdefToSplattedMlirTranslateFunction( input_filename, debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, inference_type, min_values, max_values, - prune_unused_nodes, context); + prune_unused_nodes, /*convert_legacy_fed_inputs=*/true, + /*graph_as_function=*/false, context); } return tensorflow::GraphdefToMlirTranslateFunction( input_filename, debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, inference_type, min_values, max_values, prune_unused_nodes, - context); + /*convert_legacy_fed_inputs=*/true, /*graph_as_function=*/false, context); } -bool ShouldRunQuantizePasses(mlir::ModuleOp m) { - if (mlir::FuncOp main_fn = m.lookupSymbol("main")) { - return main_fn.getAttrOfType("tf.quantize") != - mlir::Attribute(); - } - return false; -} - -void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize, - bool emit_quant_adaptor_ops, - bool lower_tensor_list_ops, - mlir::PassManager *pass_manager) { - pass_manager->addPass(mlir::TFControlFlow::CreateRaiseTFControlFlowPass()); - - if (lower_tensor_list_ops) { - // Execute this pass before `CanonicalizerPass` in case some TensorList - // ops are constant folded into variant types. - // TODO(b/137125056): Move this pass after `CanonicalizerPass` after we - // handle constant ops that produce `TensorList`. - // TODO(haoliang): Add this pass by default. - pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass()); - } - - // TODO(jpienaar): Revise post dialect constants. 
- pass_manager->addPass(mlir::TF::CreateDecodeConstantPass()); - // Canonicalization includes const folding, which is utilized here to optimize - // away ops that can't get constant folded after PrepareTF pass. For example, - // tf.Conv2D is split into tf.Transpose and tfl.Conv2D. - pass_manager->addPass(mlir::createCanonicalizerPass()); - - // The below passes only make sense if Builtin TFLite ops are enabled - // for emission. - if (emit_builtin_tflite_ops) { - // Prepare for TFLite dialect, rerun canonicalization, and then legalize to - // the TFLite dialect. - pass_manager->addPass(mlir::TFL::CreatePrepareTFPass()); - pass_manager->addPass(mlir::createCanonicalizerPass()); - pass_manager->addPass(mlir::TFL::CreateLegalizeTFPass()); - pass_manager->addPass(mlir::TFL::CreateOptimizePass()); - if (run_quantize) { - pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass( - /*quantize_sign=*/false)); - pass_manager->addPass(mlir::TFL::CreateQuantizePass()); - pass_manager->addPass( - mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); - } - pass_manager->addPass(mlir::createCanonicalizerPass()); - pass_manager->addPass(mlir::createCSEPass()); - } -} - -Status ConvertTFControlFlowToTFLOrFlatbuffer( +Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops, - bool lower_tensor_list_ops, std::string *result) { + bool lower_tensor_list_ops, std::string *result, + mlir::PassManager *pass_manager) { mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), /*propagate=*/true); - mlir::PassManager pm; - bool run_quantize = ShouldRunQuantizePasses(module); - - AddTFToTFLConversionPasses(emit_builtin_tflite_ops, run_quantize, - emit_quant_adaptor_ops, lower_tensor_list_ops, - &pm); - - if (failed(pm.run(module))) { + if (failed(pass_manager->run(module))) { return statusHandler.ConsumeStatus(); } diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 68ab674872f..2979e4617b0 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -41,37 +41,16 @@ LoadFromGraphdefOrMlirSource( bool prune_unused_nodes, llvm::SourceMgr* source_mgr, mlir::MLIRContext* context); -// Quantization passess will run only when the user specifies a quantized type -// in the `-tf-inference-type` flag, which is converted to the function -// attribute "tf.quantize" by the importer module. -// TODO(fengliuai): switch to the cmd flag once the flags are moved to this -// file with main method. -bool ShouldRunQuantizePasses(mlir::ModuleOp m); - -// Add the MLIR passes that convert TF control flow dialect to TF Lite dialect -// to a MLIR `pass_manager`. These passes first raise the control flow in the TF -// control flow dialect, decode the constant tensors, and then legalize the -// module to TF Lite dialect with some optimizations afterwards. -// If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be -// added, which produces TF Lite ops. If `run_quantize` is true, quantization -// passes will be added. If `emit_quant_adaptor_ops` is true, Quantize and -// Dequantize ops are added to the inputs and outputs of the quantized model. -// If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic -// TF ops before legalization to TF Lite dialect. 
-void AddTFToTFLConversionPasses(bool emit_builtin_tflite_ops, bool run_quantize, - bool emit_quant_adaptor_ops, - bool lower_tensor_list_ops, - mlir::PassManager* pass_manager); - -// Taking a MLIR module in TF control flow dialect and a set of parameters, +// Taking a MLIR module in TF executor dialect and a set of parameters, // applies a set of passes to convert the module to TF Lite dialect and // serializes the result to a string. Depending on an attribute in the module // main function, Quantization is applied. If `export_to_mlir` is true, the // result is exported in MLIR text format, otherwise exported in flat buffer. -Status ConvertTFControlFlowToTFLOrFlatbuffer( +Status ConvertTFExecutorToTFLOrFlatbuffer( mlir::ModuleOp module, bool export_to_mlir, bool emit_builtin_tflite_ops, bool emit_select_tf_ops, bool emit_custom_ops, bool emit_quant_adaptor_ops, - bool lower_tensor_list_ops, std::string* result); + bool lower_tensor_list_ops, std::string* result, + mlir::PassManager* pass_manager); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_LITE_TF_TO_TFL_FLATBUFFER_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc new file mode 100644 index 00000000000..b6a898e6cda --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/extract_ophint.cc @@ -0,0 +1,595 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Block.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Value.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "mlir/Support/Functional.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" +#include "tensorflow/compiler/mlir/lite/utils/validators.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/core/platform/logging.h" + +namespace mlir { +namespace TFL { +namespace { + +constexpr char kTfLiteFunctionName[] = "_tflite_function_name"; +constexpr char kTfLiteFunctionUUID[] = "_tflite_function_uuid"; +constexpr char kTfLiteFunctionInputIndex[] = "_tflite_function_input_index"; +constexpr char kTfLiteFunctionOutputIndex[] = "_tflite_function_output_index"; +constexpr char kTfLiteFunctionSortIndex[] = "_tflite_function_sort_index"; +constexpr char kTfLiteFunctionAggregate[] = "_tflite_function_aggregate"; + +constexpr char kStrategyNone[] = "None"; +constexpr char kStrategyStack[] = "stack"; +constexpr char kStrategyFirst[] = "first"; +constexpr char kStrategyLast[] = "last"; + +// A Ophinted op typically looks like below" +// +// InputOp1 InputOp2 InputOp3 +// / \ | | +// val1 val2 val3 val4 +// | | | | +// identOp1 identOp2 identOp3 identOp4 +// \ | | / +// \ | | / +// .... a bunch of operations (needs to be fused) ... +// / \ +// / \ +// identOp1 (output) identOp2 (output) +// | | +// Other ops Other ops +// +// +// In this pass, we are trying to convert them into the following format: +// +// || +// || +// \ / +// +// InputOp1 InputOp2 InputOp3 +// / \ | / +// val1 val2 val3 val4 +// \ | | / +// PackOp | / +// \ | | / +// \ | | / +// Call funcOp (fusedOp - name like 'UnidirectionalSequenceRNN') +// (The funcOp will be inserted at the bottom of the module, also +// . note every funcOp will be unique.) +// | +// UnpackOp +// / \ +// / \ +// Other ops Other ops +struct OphintCompositeOp { + // OphintCompositeOp is a conceptually "composite op" which will be converted + // to a "fused op" later. 
+ // + // As a "composite op", it has "inputs" and "outputs", and all the inputs + // and outputs are annotated by special-annotated identity ops. + // + // All inputs and outputs need to be processed based on different strategies, + // See all the different strategies under + // tensorflow/lite/python/op_hint.py + // + // For example, "stack" strategy means we need to pack the inputs together + // or unpack the outputs. + public: + OphintCompositeOp(StringRef uuid, StringRef function_name) + : uuid(uuid), function_name(function_name) {} + + void AddInput(int index, Operation* op, StringRef aggregation, + int sort_index) { + auto it = inputs.find(index); + if (it == inputs.end()) { + AggregatedOperand operand; + operand.aggregation = aggregation; + it = inputs.insert({index, operand}).first; + } + // TODO(renjieliu): check aggregation strategy stays the same. + // Also needs to make sure if aggregation strategy is "None" we should not + // have more than one op. + it->second.ops[sort_index] = op; + } + + void AddOutput(int index, Operation* op, llvm::StringRef aggregation, + int sort_index) { + auto it = outputs.find(index); + if (it == outputs.end()) { + AggregatedOperand operand; + operand.aggregation = aggregation; + it = outputs.insert({index, operand}).first; + } + // TODO(renjieliu): check aggregation strategy stays the same. + // Also needs to make sure if aggregation strategy is "None" we should not + // have more than one op. + it->second.ops[sort_index] = op; + } + + std::vector GetAllInputOps() { + std::vector all_input_ops; + for (const auto& kv : inputs) { + if (kv.second.aggregation == kStrategyFirst) { + all_input_ops.push_back(kv.second.ops.at(0)); + continue; + } + for (const auto& operandKv : kv.second.ops) { + all_input_ops.push_back(operandKv.second); + } + } + return all_input_ops; + } + + std::vector GetAllOutputOps() { + std::vector all_output_ops; + for (const auto& kv : outputs) { + for (const auto& operand_kv : kv.second.ops) { + all_output_ops.push_back(operand_kv.second); + } + } + return all_output_ops; + } + + // This function will process the aggregated inputs based on different + // strategies like "first", "last", "stack". + std::map GetAggregatedInputs(OpBuilder* builder) { + std::map aggregated_inputs; + for (const auto& kv : inputs) { + Value* op_input = nullptr; + const AggregatedOperand& operand = kv.second; + // Dealiong with "stack" strategy: + // This breaks into two parts: + // 1) If the ops only has one element, we only add a reshape op to expand + // the dim. + // 2) If the ops contain more than one element, we need to append a + // pack_op after the input ops. + if (operand.aggregation == kStrategyStack) { + if (operand.ops.size() == 1) { + // If ops size is 1, it will be simply expanding dimensions at dim 0. 
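A concrete illustration of the single-element "stack" case handled just below (shapes are hypothetical):

    // A lone annotated input of type tensor<4x8xf32> is not packed; it is
    // reshaped to tensor<1x4x8xf32>, so consumers see the same leading
    // aggregation dimension they would get from packing N inputs.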
+          Operation* current_identity_op = operand.ops.begin()->second;
+          Value* input = current_identity_op->getOperand(0);
+          RankedTensorType input_type =
+              input->getType().cast<RankedTensorType>();
+          // The Reshape will be {1, (original_shape)}.
+          SmallVector<int64_t, 4> reshape_op_shape;
+          reshape_op_shape.push_back(1);
+          for (const auto& dim : input_type.getShape()) {
+            reshape_op_shape.push_back(dim);
+          }
+          auto reshape_output_type = builder->getTensorType(
+              reshape_op_shape, input_type.getElementType());
+          Operation* first_use = current_identity_op->getNextNode();
+          builder->setInsertionPoint(first_use);
+          Operation* reshape = builder->create<TFL::ReshapeOp>(
+              first_use->getLoc(), reshape_output_type, input);
+          op_input = reshape->getResult(0);
+
+        } else {
+          // Insert a pack op to pack all the inputs together.
+          std::vector<Value*> pack_input_operands;
+          std::vector<Value*> packed_input_consumers;
+          for (int i = 0, e = operand.ops.size(); i < e; ++i) {
+            pack_input_operands.push_back(operand.ops.at(i)->getOperand(0));
+            packed_input_consumers.push_back(operand.ops.at(i)->getResult(0));
+          }
+          // Find the first op that consumes the last value of the aggregated
+          // inputs.
+          Operation* first_use = *(packed_input_consumers.back()->user_begin());
+          // The packed shape will be {N, (original_shape)}.
+          SmallVector<int64_t, 4> pack_shape;
+          pack_shape.push_back(pack_input_operands.size());
+          RankedTensorType type = operand.ops.at(0)
+                                      ->getResult(0)
+                                      ->getType()
+                                      .cast<RankedTensorType>();
+          for (const auto& dim : type.getShape()) {
+            pack_shape.push_back(dim);
+          }
+          auto pack_input_type =
+              builder->getTensorType(pack_shape, type.getElementType());
+          builder->setInsertionPoint(first_use);
+          Operation* pack_op = builder->create<TFL::PackOp>(
+              first_use->getLoc(), pack_input_type, pack_input_operands,
+              builder->getI32IntegerAttr(pack_input_operands.size()),
+              builder->getI32IntegerAttr(0));
+          op_input = pack_op->getResult(0);
+        }
+      } else if (operand.aggregation == kStrategyLast) {
+        // This handles the "last" strategy: simply take the last input.
+        op_input = operand.ops.at(operand.ops.size() - 1)->getOperand(0);
+      } else {
+        // This handles the "first" strategy and the default: simply take the
+        // first input.
+        op_input = operand.ops.at(0)->getOperand(0);
+      }
+      aggregated_inputs[kv.first] = op_input;
+    }
+    return aggregated_inputs;
+  }
+
+  // For now, we just return the location of the first output; the fused op
+  // will be inserted there.
+  Operation* GetFirstOutputOp() { return outputs.begin()->second.ops.at(0); }
+
+  // Since we have different aggregation strategies, e.g., "first", "last",
+  // "stack", we do not actually aggregate the outputs for the funcOp here.
+  // This function simply computes the aggregated RankedTensorType (shape &
+  // element type) for every output.
+  std::map<int, Type> GetAggregatedOuputTypes(OpBuilder* builder) {
+    std::map<int, Type> aggregated_output_types;
+    for (const auto& kv : outputs) {
+      const AggregatedOperand& operand = kv.second;
+      if (operand.aggregation == kStrategyStack) {
+        const int output_number = operand.ops.size();
+        Value* first_output = operand.ops.at(0)->getOperand(0);
+        RankedTensorType first_output_type =
+            first_output->getType().cast<RankedTensorType>();
+        // The aggregated output shape will be {N, original_shape}.
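A worked example of the aggregated types computed below (shapes are hypothetical):

    // Three "stack"-aggregated outputs of type tensor<4x8xf32> aggregate to
    // tensor<3x4x8xf32>; under "last" or "first" the aggregated type is just
    // the type of the chosen output, tensor<4x8xf32>.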
+        SmallVector<int64_t, 4> shape;
+        shape.push_back(output_number);
+        for (const auto& dim : first_output_type.getShape()) {
+          shape.push_back(dim);
+        }
+        aggregated_output_types[kv.first] =
+            builder->getTensorType(shape, first_output_type.getElementType());
+      } else if (operand.aggregation == kStrategyLast) {
+        Value* last_output =
+            operand.ops.at(operand.ops.size() - 1)->getOperand(0);
+        aggregated_output_types[kv.first] = last_output->getType();
+      } else {
+        Value* first_output = operand.ops.at(0)->getOperand(0);
+        aggregated_output_types[kv.first] = first_output->getType();
+      }
+    }
+    return aggregated_output_types;
+  }
+
+  void AggregateAndRewireOutputs(OpBuilder* builder, Operation* fused_op) {
+    // TODO(renjieliu): Consider getting rid of the ophinted identity nodes
+    // here as well, or just rely on the general path to get rid of the
+    // identity nodes.
+    int output_index = 0;
+    for (const auto& kv : outputs) {
+      const AggregatedOperand& operand = kv.second;
+      // This handles the "stack" strategy. It pushes an unpack_op before all
+      // the outputs and makes all the outputs point to the unpack_op.
+      if (operand.aggregation == kStrategyStack) {
+        // TODO(renjieliu): Revisit whether we need to handle the
+        // operand.ops().size() == 1 case. Insert an unpack op to unpack the
+        // outputs.
+        const int output_number = operand.ops.size();
+        // Find the first output.
+        Operation* first_output = operand.ops.at(0);
+        Location insert_loc = first_output->getLoc();
+        SmallVector<Type, 4> unpack_output_types(
+            output_number, first_output->getOperand(0)->getType());
+
+        builder->setInsertionPoint(first_output);
+        Operation* unpack_op = builder->create<TFL::UnpackOp>(
+            insert_loc, unpack_output_types, fused_op->getResult(output_index),
+            builder->getI32IntegerAttr(output_number),
+            builder->getI32IntegerAttr(0));
+        // For every unpack output, make sure it points to the right value.
+        for (int i = 0; i < output_number; ++i) {
+          Operation* to_be_replaced_op = operand.ops.at(i);
+          to_be_replaced_op->replaceUsesOfWith(to_be_replaced_op->getOperand(0),
+                                               unpack_op->getResult(i));
+        }
+      } else if (operand.aggregation == kStrategyLast) {
+        // This handles the "last" strategy: simply take the last output.
+        Operation* op = operand.ops.at(operand.ops.size() - 1);
+        op->replaceUsesOfWith(op->getOperand(0), fused_op->getResult(kv.first));
+      } else {
+        // This handles the "first" strategy and the default: simply take the
+        // first output.
+        Operation* op = operand.ops.at(0);
+        op->replaceUsesOfWith(op->getOperand(0), fused_op->getResult(kv.first));
+      }
+
+      output_index++;
+    }
+  }
+
+  LogicalResult VerifyOphint() const {
+    if (inputs.empty() || outputs.empty()) return failure();
+    return success();
+  }
+
+  StringRef uuid;
+  StringRef function_name;
+
+ private:
+  // The AggregatedOperand is used to hold one "aggregated operand".
+ // For example, this can be + // { + // aggregation = "stack", + // {0: ident_op1, 1: ident_op2, 2: ident_op3} + // } + struct AggregatedOperand { + StringRef aggregation; + std::map ops; + }; + + std::map inputs; + std::map outputs; +}; + +Operation* BuildFusedFuncOp(StringRef func_name, StringRef fused_func_type, + Operation* insert_before_op, + const std::map& inputs, + const std::map& output_types, + OpBuilder* builder, ModuleOp* module_op) { + SmallVector input_types; + SmallVector input_values; + for (const auto& kv : inputs) { + Value* input = kv.second; + input_types.push_back(input->getType()); + input_values.push_back(input); + } + + SmallVector func_output_types; + for (const auto& kv : output_types) { + func_output_types.push_back(kv.second); + } + + FunctionType function_type = + builder->getFunctionType(/*inputs=*/input_types, + /*results=*/func_output_types); + + SmallVector attrs; + attrs.push_back(builder->getNamedAttr( + kTfLiteFunctionName, builder->getStringAttr(fused_func_type))); + FuncOp func_op = FuncOp::create(insert_before_op->getLoc(), func_name, + function_type, llvm::makeArrayRef(attrs)); + module_op->push_back(func_op); + builder->setInsertionPoint(insert_before_op); + return builder->create(insert_before_op->getLoc(), func_op, + input_values); +} + +llvm::StringMap FindAllOphintNodes(Block* bb) { + llvm::StringMap ophint_composite_ops; + for (auto& op : *bb) { + auto nameAttr = op.getAttrOfType(kTfLiteFunctionName); + if (!nameAttr) continue; + StringRef function_name = nameAttr.getValue(); + auto uuidAttr = op.getAttrOfType(kTfLiteFunctionUUID); + if (!uuidAttr) continue; + StringRef uuid = uuidAttr.getValue(); + auto it = ophint_composite_ops.find(uuid); + if (it == ophint_composite_ops.end()) { + OphintCompositeOp ophint_composite_op(uuid, function_name); + it = ophint_composite_ops.insert({uuid, ophint_composite_op}).first; + } + + // The default aggregation strategy is "NONE". + StringRef aggregation = kStrategyNone; + auto aggregationAttr = + op.getAttrOfType(kTfLiteFunctionAggregate); + if (aggregationAttr != nullptr) aggregation = aggregationAttr.getValue(); + + // The default sort index is 0. + int sortIndex = 0; + auto sortIndexAttr = + op.getAttrOfType(kTfLiteFunctionSortIndex); + if (sortIndexAttr != nullptr) sortIndex = sortIndexAttr.getInt(); + + auto inputIndexAttr = + op.getAttrOfType(kTfLiteFunctionInputIndex); + if (inputIndexAttr != nullptr) { + it->second.AddInput(inputIndexAttr.getInt(), &op, aggregation, sortIndex); + } else { + auto outputIndexAttr = + op.getAttrOfType(kTfLiteFunctionOutputIndex); + it->second.AddOutput(outputIndexAttr.getInt(), &op, aggregation, + sortIndex); + } + } + + return ophint_composite_ops; +} + +llvm::DenseSet BfsForReachableOps(ArrayRef input_ops) { + llvm::DenseSet reachable_ops; + std::queue ops_queue; + for (auto& input_op : input_ops) { + for (Value* value : input_op->getOperands()) { + Operation* op = value->getDefiningOp(); + if (op != nullptr) ops_queue.push(op); + } + } + + while (!ops_queue.empty()) { + Operation* current_op = ops_queue.front(); + ops_queue.pop(); + reachable_ops.insert(current_op); + for (Value* value : current_op->getOperands()) { + Operation* upstream_op = value->getDefiningOp(); + // Not visited, put it into the queue. 
+ if (upstream_op != nullptr && + !llvm::is_contained(reachable_ops, upstream_op)) { + ops_queue.emplace(upstream_op); + } + } + } + + return reachable_ops; +} + +// Convert ophint to stub will remove all ops within the ophint region and +// place a new fused op right before the first op. +LogicalResult ConvertOphintToStub(StringRef stub_name, + OphintCompositeOp ophint_composite_op, + OpBuilder* builder, ModuleOp* module_op) { + // Step 1, find all ops reachable by inputs. + const llvm::DenseSet& reachable_by_inputs = + BfsForReachableOps(ophint_composite_op.GetAllInputOps()); + + // Step 2, find all ops reachable by outputs. + const llvm::DenseSet& reachable_by_outputs = + BfsForReachableOps(ophint_composite_op.GetAllOutputOps()); + + // Step 3, deal with inputs aggregation strategies. + const std::map& aggregated_inputs = + ophint_composite_op.GetAggregatedInputs(builder); + + // Step 4, get aggregated output types. + const std::map& aggregated_output_types = + ophint_composite_op.GetAggregatedOuputTypes(builder); + + // Step 5, create & place the fused op and rewire the inputs. + // Here we use a funcOp to represent the fused op. This "funcOp" will be + // coonverted to other ops (like UnidirectionalSequenceRNNOp) in the + // legalization phase. + Operation* inserted_before_op = ophint_composite_op.GetFirstOutputOp(); + Operation* fused_op = BuildFusedFuncOp( + stub_name, ophint_composite_op.function_name, inserted_before_op, + aggregated_inputs, aggregated_output_types, builder, module_op); + + for (const auto& kv : aggregated_inputs) { + Operation* op = kv.second->getDefiningOp(); + if (op == nullptr) return failure(); + op->moveBefore(fused_op); + } + + // Step 6, deal with outputs aggregation strategies and rewire the outputs. + ophint_composite_op.AggregateAndRewireOutputs(builder, fused_op); + + // Step 7, remove all the removable ops where + // (reachable_by_outputs - reachable_by_inputs) as removable and the rest + // ops are not removable. + auto removeRemovableOps = [&](Operation* op) { + if (!llvm::is_contained(reachable_by_inputs, op) && + llvm::is_contained(reachable_by_outputs, op)) { + op->dropAllDefinedValueUses(); + op->dropAllReferences(); + op->erase(); + } + }; + + builder->getBlock()->walk(removeRemovableOps); + return success(); +} + +struct ExtractOphintPass : public ModulePass { + void runOnModule() override; + void Verify(); + + private: + int ophint_composite_ops_count = 0; +}; + +// TODO(renjieliu): Current ophint extraction does not support inputs/outputs +// cross functions, we need to do that. +void ExtractOphintPass::runOnModule() { + ModuleOp module = getModule(); + for (auto function : module.getOps()) { + // Process block by block. + for (auto& bb : function.getBody()) { + // Find ophints. + const llvm::StringMap& ophint_composite_ops = + FindAllOphintNodes(&bb); + if (ophint_composite_ops.empty()) continue; + + // Verify: Make sure all ophint_composite_ops are valid. + for (const auto& kv : ophint_composite_ops) { + if (failed(kv.getValue().VerifyOphint())) { + module.emitError() + << "Found malformed ophint regions: missing inputs or outputs."; + return signalPassFailure(); + } + } + + ophint_composite_ops_count = ophint_composite_ops.size(); + + // Convert. 
+ OpBuilder builder(&bb); + for (const auto& kv : ophint_composite_ops) { + if (failed(ConvertOphintToStub(kv.getKey(), kv.getValue(), &builder, + &module))) { + module.emitError() + << "Convert ophint failed, malformed inputs or outputs."; + return signalPassFailure(); + } + } + } + } +} + +void ExtractOphintPass::Verify() { + ModuleOp module = getModule(); + int ophint_func_op_count = 0; + for (FuncOp func : getModule().getOps()) { + for (const NamedAttribute attr : func.getAttrs()) { + if (attr.first == kTfLiteFunctionName) { + ophint_func_op_count++; + if (func.getNumArguments() == 0) { + module.emitError() << "Ophint function has no inputs."; + return signalPassFailure(); + } + if (func.getType().getNumResults() == 0) { + module.emitError() << "Ophint function has no outputs."; + return signalPassFailure(); + } + } + } + } + if (ophint_func_op_count != ophint_composite_ops_count) { + module.emitError() + << "Ophint converted functions do not match ophint regions founded."; + return signalPassFailure(); + } +} + +} // namespace + +/// Creates an instance of the TensorFlow Lite dialect ExtractOphintPass +/// pass. +std::unique_ptr CreateExtractOphintPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tfl-extract-ophint", "Extract Ophint for TfLite dialect."); + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc new file mode 100644 index 00000000000..2ea5dba3e17 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_ophint_func_op.cc @@ -0,0 +1,209 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/ADT/StringMap.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Block.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Value.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { +namespace { + +constexpr char kTfLiteFunctionName[] = "_tflite_function_name"; +constexpr char kUnidirectionalSequenceRnn[] = "UnidirectionalSequenceRnn"; + +// This pass is used for converting to TFLite composite op like +// UnidirectionalSequenceRNN, UnidirectionalSequenceLSTM or SVDF Op. Currently, +// this pass is only for ophint converted function op only. See below diagram: +// +// InputOp1 InputOp2 ... +// \ / +// \ / +// call funcOp (say UnidirectionalSequenceRNN) +// | +// | +// OutputOp1 +// +// funcOp() { '_tflite_function_name' = 'UnidirectionalSequenceRNN'} +// +// || +// || +// \ / +// +// InputOp1 InputOp2 ... +// \ / +// \ / +// tfl.UnidirectionalSequenceRNN +// | +// | +// OutputOp1 +struct LegalizeOphintFuncOpPass : public ModulePass { + void runOnModule() override; +}; + +llvm::StringMap FindCompositeFuncOps(ModuleOp module) { + llvm::StringMap composite_func_ops; + for (FuncOp func : module.getOps()) { + if (func.getAttr(kTfLiteFunctionName)) + composite_func_ops[func.getName()] = func; + } + return composite_func_ops; +} + +LogicalResult BuildUnidirectionalSequenceRnnOp(FuncOp composite_func_op, + CallOp* call_op, + OpBuilder* builder, + Operation** fused_op) { + // UnidirectionalSequenceRnn takes exactly 5 inputs. + if (composite_func_op.getNumArguments() != 5) return failure(); + if (call_op->getNumOperands() != 5) return failure(); + // UnidirectionalSequenceRnn has exactly 1 input. + if (call_op->getNumResults() != 1) return failure(); + + // Inputs is indexed at 0. + Value* input = call_op->getOperand(0); + // Input_weight is indexed at 1. + Value* weight = call_op->getOperand(1); + // Recurrent_weight is indexed at 2. + Value* recurrent_weight = call_op->getOperand(2); + // Bias is indexed at 3. + Value* bias = call_op->getOperand(3); + // Hidden_state is indexed at 4. + Value* hidden_state = call_op->getOperand(4); + + // Build Output. + auto output_type = call_op->getResult(0)->getType(); + + // Currently, ophinted RNN only supports time_major = True. + const bool time_major = true; + // Activation will always be TanH. 
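Schematically, the rewrite built below replaces the ophint call with a single fused op (textual form abbreviated; operand names are hypothetical):

    // %out = tfl.unidirectional_sequence_rnn(%input, %weight, %recurrent_weight,
    //          %bias, %hidden_state)
    //          {fused_activation_function = "TANH", time_major = true}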
+ StringAttr fused_activation_function = builder->getStringAttr("TANH"); + + builder->setInsertionPoint(call_op->getOperation()); + *fused_op = builder->create( + call_op->getLoc(), output_type, input, weight, recurrent_weight, bias, + hidden_state, builder->getBoolAttr(time_major), + fused_activation_function); + return success(); +} + +LogicalResult ConvertTfLiteFusedOpIfAvaiable(StringRef func_name, + FuncOp composite_func_op, + CallOp* call_op, + OpBuilder* builder) { + Operation* fused_op = nullptr; + if (func_name == kUnidirectionalSequenceRnn) { + // TODO(renjieliu): Validate the func op inputs. + LogicalResult build_fused_op_result = BuildUnidirectionalSequenceRnnOp( + composite_func_op, call_op, builder, &fused_op); + if (failed(build_fused_op_result)) return build_fused_op_result; + } else { // If we support more fused op, we should add the conversion here. + return failure(); + } + + call_op->replaceAllUsesWith(fused_op); + + // Delete call op. + Operation* call = call_op->getOperation(); + call->dropAllDefinedValueUses(); + call->dropAllReferences(); + call->erase(); + return success(); +} + +LogicalResult ConvertCallOps(llvm::StringMap* composite_func_ops, + ModuleOp* module) { + for (auto func : module->getOps()) { + // Ideally it will be much simpler if we can just use walk, but we also + // want to early return if encounter errors. :( + OpBuilder builder(func.getBody()); + // The call_op replacement within this loop works like an in-place + // replacement, so it should be safe to do so. + for (auto call_op : + llvm::make_early_inc_range(builder.getBlock()->getOps())) { + auto it = composite_func_ops->find(call_op.getCallee()); + if (it == composite_func_ops->end()) return failure(); + + // Replace the call op with TfLite fused op. + // Currently it's only handled case by case, but ideally it would be + // much better if we can do this automatically. + FuncOp composite_func_op = it->second; + StringRef func_name = composite_func_op.getAttr(kTfLiteFunctionName) + .cast() + .getValue(); + if (failed(ConvertTfLiteFusedOpIfAvaiable(func_name, composite_func_op, + &call_op, &builder))) + return failure(); + + composite_func_ops->erase(it); + // Delete func op. + Operation* func = composite_func_op.getOperation(); + func->erase(); + } + } + return success(); +} + +void LegalizeOphintFuncOpPass::runOnModule() { + ModuleOp module = getModule(); + // Find all composite funcs, then for every call op inside every func op + // within the module, we go ahead and replace the callop with the tflite + // corresponding op and destroy the func op. This two-phase processing is + // intended: + // + // Every func op is meant to be used exactly once. + // Instead of finding the composite func then loop through the graph and + // convert the call op immediately, we break finding & converting into two + // phases. This changes the complexity from O(op_in_module * + // function_in_module * attr_in_func) to O(op_in_module) * O(map_look_up) + + // O(function_in_module * attr_in_func). O(op_in_module) is the dominant + // factor here and map look up should be very cheap. + llvm::StringMap composite_func_ops = FindCompositeFuncOps(module); + if (composite_func_ops.empty()) return; + if (failed(ConvertCallOps(&composite_func_ops, &module))) { + module.emitError() << "Legalize ophint: ConvertCallOps failed."; + return signalPassFailure(); + } +} + +} // namespace + +/// Creates an instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass +/// pass. 
+std::unique_ptr CreateLegalizeOphintFuncOpPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tfl-legalize-ophint-func-op", "Convert composite op for TfLite dialect."); + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 90ff6713874..94efc7d2719 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -16,7 +16,7 @@ limitations under the License. // TFLite legalization patterns include "mlir/IR/OpBase.td" -include "mlir/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/Ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" @@ -29,7 +29,6 @@ class ExtractI32At : NativeCodeCall< "$_builder.getI32IntegerAttr($_self.cast().getValue()[" # i # "].cast().getInt())">; - // Merge the two Attributes to a ArrayAttr; def Merge2AttrsToArray : NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">; @@ -80,6 +79,7 @@ def : Pat<(TF_AvgPoolOp $value, /*stride_w=*/ExtractI32At<2>:$strides, /*fused_activation_function=*/TFL_AF_None)>; +def : Pat<(TF_ArgMaxOp $input, $dim), (TFL_ArgMaxOp $input, $dim)>; def : Pat<(TF_ArgMinOp $input, $dim), (TFL_ArgMinOp $input, $dim)>; def : Pat<(TF_CeilOp $arg), (TFL_CeilOp $arg)>; @@ -134,10 +134,14 @@ def : Pat<(TF_ReverseSequenceOp $input, $seq_lengths, $seq_dim, $batch_dim), (TFL_ReverseSequenceOp $input, $seq_lengths, (convertIntAttrTo32Bit $seq_dim), (convertIntAttrTo32Bit $batch_dim))>; +def : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>; def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>; +def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>; +def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>; // TODO(jpienaar): this is not true for all selects, TF's select supports rank 0 // condition def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>; +def : Pat<(TF_SelectV2Op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>; def : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>; def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>; def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>; @@ -146,6 +150,7 @@ def : Pat<(TF_SoftmaxOp $arg), (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>; def : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>; def : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>; def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>; +def : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>; def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>; // The following two rules can both match an tf.Placeholder.input node with @@ -250,6 +255,8 @@ def : Pat<(TF_EqualOp $arg0, $arg1), (TFL_EqualOp $arg0, $arg1)>; def : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>; +def : Pat<(TF_TileOp $arg0, $arg1), (TFL_TileOp $arg0, $arg1)>; + def : Pat<(TF_PadV2Op $arg0, $arg1, $cst), (TFL_PadV2Op $arg0, $arg1, $cst)>; def : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2), (TFL_MeanOp $arg0, $arg1, $arg2)>; @@ -265,16 +272,29 @@ def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1 def : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceProdOp $arg0, $arg1, $arg2)>; +def : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims), + (TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>; + def : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>; def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, 
$crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>; def : Pat<(TF_SpaceToBatchNDOp $input, $block_shape, $paddings), (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>; +def : Pat<(TF_SpaceToDepthOp $input, $block_size, IsDataFormatNHWC:$data_format), + (TFL_SpaceToDepthOp $input, (convertIntAttrTo32Bit $block_size))>; + +def : Pat<(TF_DepthToSpaceOp $input, $block_size, IsDataFormatNHWC:$data_format), + (TFL_DepthToSpaceOp $input, (convertIntAttrTo32Bit $block_size))>; + def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners)>; +def : Pat<(TF_ResizeNearestNeighborOp $images, $size, $align_corners, ConstBoolAttrFalse:$half_pixel_centers), (TFL_ResizeNearestNeighborOp $images, $size, $align_corners)>; def : Pat<(TF_MirrorPadOp $arg0, $arg1, $cst), (TFL_MirrorPadOp $arg0, $arg1, $cst)>; +def : Pat<(TF_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values, $default_value, $validate_indices), + (TFL_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values, $default_value)>; + def : Pat< (TF_StridedSliceOp $input, $begin, $end, $strides, $begin_mask, $end_mask, $ellipsis_mask, $new_axis_mask, $shrink_axis_mask), (TFL_StridedSliceOp $input, $begin, $end, $strides, @@ -283,4 +303,7 @@ def : Pat< def : Pat<(TF_UniqueOp $arg0),(TFL_UniqueOp $arg0)>; +def : Pat<(TF_FloorModOp $arg0, $arg1), (TFL_FloorModOp $arg0, $arg1)>; +def : Pat<(TF_ExpOp $arg0), (TFL_ExpOp $arg0)>; + def : Pat<(TF_LRNOp $arg0, $radius, F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta), (TFL_LocalResponseNormalizationOp $arg0, (convertIntAttrTo32Bit $radius), $bias, $alpha, $beta)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index faf80f3acb8..b20af2b4215 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -33,9 +33,9 @@ limitations under the License. #include "mlir/Support/Functional.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" -#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -205,16 +205,18 @@ void LegalizeTF::runOnFunction() { // Add the generated patterns to the list. populateWithGenerated(ctx, &patterns); - RewriteListBuilder::build(patterns, ctx); - applyPatternsGreedily(func, std::move(patterns)); + patterns.insert(ctx); + applyPatternsGreedily(func, patterns); } } // namespace // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. 
-FunctionPassBase* CreateLegalizeTFPass() { return new LegalizeTF(); } +std::unique_ptr CreateLegalizeTFPass() { + return std::make_unique(); +} static PassRegistration pass( "tfl-legalize-tf", "Legalize from TensorFlow to TensorFlow Lite dialect"); diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 44ff796b7cc..716c8216433 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -28,6 +28,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Analysis/LoopAnalysis.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Block.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir @@ -35,15 +36,14 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/Support/Functional.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "mlir/Support/LogicalResult.h" // TF:local_config_mlir -#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" @@ -82,7 +82,7 @@ struct LowerStaticTensorListPass // Changes the function type of `cond_func` and `body_func`, and the result // type of the `WhileOp`. - LogicalResult UpdateWhileFunctionType(TF::WhileOp *while_op); + LogicalResult UpdateWhileFunctionType(TF::WhileOp op); }; Value *CreateI32SplatConst(Operation *op, PatternRewriter *rewriter, @@ -100,10 +100,10 @@ Value *CreateI32SplatTensor(Operation *op, PatternRewriter *rewriter, shape_tensor, scalar_val); } -struct ConvertTFTensorListSetItem : public RewritePattern { +struct ConvertTFTensorListSetItem + : public OpRewritePattern { explicit ConvertTFTensorListSetItem(MLIRContext *context) - : RewritePattern(TF::TensorListSetItemOp::getOperationName(), 1, - context) {} + : OpRewritePattern(context, 1) {} // This function rewrites the original op into a series of slice and concat op // to produce the same result. It first slices the first `$index` rows. 
Then // expands the dimension of the `$item`, followed by another slice of the @@ -116,23 +116,21 @@ struct ConvertTFTensorListSetItem : public RewritePattern { // (Slice $input, [0, 0, ...], (Concat (ExpandDims $index, expand_dim = // 0), [-1, -1, ...])), (ExpandDims $item, expand_dim = 0), (Slice // $input, [$index + 1, 0, 0, ...], [-1, -1, ...]))>; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(TF::TensorListSetItemOp op, PatternRewriter &rewriter) const override { - TF::TensorListSetItemOp tf_op = cast(op); - - auto input = tf_op.input_handle(); + auto input = op.input_handle(); auto shape_dtype = rewriter.getIntegerType(32); auto input_rank = rewriter.create( - op->getLoc(), rewriter.getTensorType({}, shape_dtype), input); - auto item = tf_op.item(); + op.getLoc(), rewriter.getTensorType({}, shape_dtype), input); + auto item = op.item(); auto item_rank = rewriter.create( - op->getLoc(), rewriter.getTensorType({}, shape_dtype), item); + op.getLoc(), rewriter.getTensorType({}, shape_dtype), item); // Prepare the start position for the first slice op, which is [0, 0, .., // 0]. auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0); auto position_shape = rewriter.create( - op->getLoc(), rewriter.getTensorType({1}, shape_dtype), input_rank, + op.getLoc(), rewriter.getTensorType({1}, shape_dtype), input_rank, scalar_zero); // Fill all 0s into the first position tensor. auto first_start_position = @@ -141,33 +139,33 @@ struct ConvertTFTensorListSetItem : public RewritePattern { // Prepare the start position for the second slice op, which is // [index + 1, 0, 0 .. 0]. // Calculate the first dimension, which is index + 1. - auto index = tf_op.index(); + auto index = op.index(); auto vector_type = rewriter.getTensorType({1}, shape_dtype); auto begin = rewriter.create( - op->getLoc(), rewriter.getTensorType(shape_dtype), index, + op.getLoc(), rewriter.getTensorType(shape_dtype), index, CreateI32SplatConst(op, &rewriter, {1}, 1)); // Followed by the first dimension `begin`, are `item_rank` of 0s. auto item_position_shape = rewriter.create( - op->getLoc(), rewriter.getTensorType({1}, shape_dtype), item_rank, + op.getLoc(), rewriter.getTensorType({1}, shape_dtype), item_rank, scalar_zero); auto partial_second_start_position = CreateI32SplatTensor(op, &rewriter, item_position_shape, 0); auto position_type = first_start_position->getType(); // Concatenate `begin` with the remaining 0s. auto second_start_position = rewriter.create( - op->getLoc(), position_type, scalar_zero, + op.getLoc(), position_type, scalar_zero, ArrayRef({begin, partial_second_start_position}), rewriter.getI64IntegerAttr(2)); // Create the size parameter for the first slice op, which is [index, -1, // -1, .., -1]. 
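For instance, with a hypothetical rank-3 list tensor and $index = 2, the decomposition assembled by the next few statements amounts to:

    // slice1 = Slice(input, begin = [0, 0, 0], size = [2, -1, -1])   // rows before $index
    // slice2 = Slice(input, begin = [3, 0, 0], size = [-1, -1, -1])  // rows after $index
    // result = Concat(axis = 0, slice1, ExpandDims(item, 0), slice2)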
auto size1_leading_dim = rewriter.create( - op->getLoc(), vector_type, index, scalar_zero); + op.getLoc(), vector_type, index, scalar_zero); auto partial_size1 = CreateI32SplatTensor(op, &rewriter, item_position_shape, -1); auto size1 = rewriter.create( - op->getLoc(), position_type, scalar_zero, + op.getLoc(), position_type, scalar_zero, ArrayRef({size1_leading_dim, partial_size1}), rewriter.getI64IntegerAttr(2)); @@ -179,14 +177,14 @@ struct ConvertTFTensorListSetItem : public RewritePattern { auto element_type = input->getType().cast().getElementType(); auto unranked_tensor = rewriter.getTensorType(element_type); auto slice1 = rewriter.create( - op->getLoc(), unranked_tensor, input, first_start_position, size1); + op.getLoc(), unranked_tensor, input, first_start_position, size1); auto slice2 = rewriter.create( - op->getLoc(), unranked_tensor, input, second_start_position, size2); + op.getLoc(), unranked_tensor, input, second_start_position, size2); // Expand the dimension of item so that it will have the same rank with // input. auto expanded_item = rewriter.create( - op->getLoc(), unranked_tensor, item, scalar_zero); + op.getLoc(), unranked_tensor, item, scalar_zero); // Concatenate three parts together to generate the final result. rewriter.replaceOpWithNewOp( @@ -198,52 +196,54 @@ struct ConvertTFTensorListSetItem : public RewritePattern { } }; -struct ConvertTFTensorListReserve : public RewritePattern { - explicit ConvertTFTensorListReserve(MLIRContext *context) - : RewritePattern(TF::TensorListReserveOp::getOperationName(), 1, - context) {} +// Rewrites op of the template type initializing a TensorList with a list of ops +// to generate an equivalent raw tensor. Derived classes are required to +// override GetNumElements method. +template +struct ConvertTFTensorListInitOp : public OpRewritePattern { + explicit ConvertTFTensorListInitOp(MLIRContext *context) + : OpRewritePattern(context, 1) {} + + // Create and return a 1-d tensor with exactly one element equal to the number + // of list elements to initialize the output tensor list with. + virtual Value *GetNumElements(OpT op, PatternRewriter *rewriter) const = 0; // Rewrites the original op into `tf.fill`. The result tensor shape is // [num_element, element_shape]. All the values in the result tensor will be // initialized to 0. - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const override { - TF::TensorListReserveOp tf_op = cast(op); - - auto element_shape = tf_op.element_shape(); + auto element_shape = op.element_shape(); auto shape_dtype = getElementTypeOrSelf(element_shape->getType()); - auto num_elements = tf_op.num_elements(); - Type element_dtype = tf_op.element_dtype(); + Type element_dtype = op.element_dtype(); int64_t result_rank = -1; // -1 means unknown result rank. Type result_type = rewriter.getTensorType(element_dtype); - if (auto element_type = tf_op.element_type().dyn_cast()) { + if (auto element_type = + op.element_type().template dyn_cast()) { result_rank = element_type.getRank() + 1; // If element type is ranked, then result type will have unknown leading // dimension and element shape for the following dimensions. // - // Note: leading dim is not inferred here even if num_elements input is a - // constant. + // Note: leading dim is not inferred here even when it is a constant. 
SmallVector result_shape = {-1}; ArrayRef shape = element_type.getShape(); result_shape.append(shape.begin(), shape.end()); result_type = rewriter.getTensorType(result_shape, element_dtype); } - // The output shape of the result tensor should be [num_elements + - // element_shape]. - auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0); - auto leading_dim = rewriter.create( - op->getLoc(), rewriter.getTensorType({1}, shape_dtype), num_elements, - scalar_zero); - // Create a 1-D RankedTensorType for result's shape. Number of elements in // it is equal to the rank of the result, if known. Otherwise, the number of // elements are unknown and represented with -1. In both cases, we can // specify dimension using rank of the result. Type shape_type = rewriter.getTensorType({result_rank}, shape_dtype); + + // Add number of elements as the prefix to the element shape to get shape of + // the output tensor. + auto leading_dim = GetNumElements(op, &rewriter); + auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0); auto list_shape = rewriter.create( - op->getLoc(), shape_type, scalar_zero, + op.getLoc(), shape_type, scalar_zero, ArrayRef({leading_dim, element_shape}), rewriter.getI64IntegerAttr(2)); @@ -251,9 +251,92 @@ struct ConvertTFTensorListReserve : public RewritePattern { // as specified by element_dtype. auto zero_type = rewriter.getTensorType({}, element_dtype); auto zero_attr = rewriter.getZeroAttr(zero_type); - auto zero = rewriter.create(op->getLoc(), zero_type, zero_attr); + auto zero = rewriter.create(op.getLoc(), zero_type, zero_attr); rewriter.replaceOpWithNewOp(op, result_type, list_shape, zero); + return Pattern::matchSuccess(); + } +}; + +struct ConvertTFTensorListReserve + : public ConvertTFTensorListInitOp { + explicit ConvertTFTensorListReserve(MLIRContext *context) + : ConvertTFTensorListInitOp(context) {} + + Value *GetNumElements(TF::TensorListReserveOp op, + PatternRewriter *rewriter) const override { + auto scalar_zero = CreateI32SplatConst(op, rewriter, {}, 0); + auto shape_dtype = getElementTypeOrSelf(op.element_shape()->getType()); + return rewriter->create( + op.getLoc(), rewriter->getTensorType({1}, shape_dtype), + op.num_elements(), scalar_zero); + } +}; + +// TODO(hinsu): Replace with declarative patterns once the RewriterGen infra +// supports patterns involving variadic operand ops. +// +// Note that we ignore the second operand `max_num_elements` as we don't have +// any restrictions on the number of elements we can support. So this may +// have a different behavior compared to TensorFlow in case of errors. +struct ConvertTFEmptyTensorList + : public ConvertTFTensorListInitOp { + explicit ConvertTFEmptyTensorList(MLIRContext *context) + : ConvertTFTensorListInitOp(context) {} + + Value *GetNumElements(TF::EmptyTensorListOp op, + PatternRewriter *rewriter) const override { + return CreateI32SplatConst(op, rewriter, {1}, 0); + } +}; + +struct ConvertTFTensorListPushBack : public RewritePattern { + explicit ConvertTFTensorListPushBack(MLIRContext *context) + : RewritePattern(TF::TensorListPushBackOp::getOperationName(), 1, + context) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + TF::TensorListPushBackOp push_back_op = cast(op); + Value *item = push_back_op.tensor(); + Type dtype = getElementTypeOrSelf(*item); + + // Returns a new type by prepending the specified dimension to the shape of + // the given type if it is a ranked type. 
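A few examples of what the helper defined next produces (types are hypothetical):

    // with_leading_dim(1, tensor<3x4xf32>)   -> tensor<1x3x4xf32>  (expanded item)
    // with_leading_dim(-1, tensor<3x4xf32>)  -> tensor<?x3x4xf32>  (unknown list length)
    // with_leading_dim(d, <unranked type>)   -> unranked tensor of the item's element type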
+ auto with_leading_dim = [&](int64_t dim, Type type) -> Type { + if (RankedTensorType ty = type.dyn_cast()) { + llvm::SmallVector shape = {dim}; + shape.append(ty.getShape().begin(), ty.getShape().end()); + return rewriter.getTensorType(shape, dtype); + } + + return rewriter.getTensorType(dtype); + }; + + // Expand the shape of the item so that it will have rank same as the input + // tensor and it is compatible for the Concat Op. + Type expanded_item_type = with_leading_dim(1, item->getType()); + auto scalar_zero = CreateI32SplatConst(op, &rewriter, {}, 0); + auto expanded_item = rewriter.create( + op->getLoc(), expanded_item_type, item, scalar_zero); + + // If the variant type in the output handle has item shape available, use it + // to derive the output shape by setting unknown leading dimension. + // Otherwise, result type will be of unranked type. + Type handle_type = push_back_op.output_handle()->getType(); + TF::VariantType handle_dtype = + getElementTypeOrSelf(handle_type).cast(); + Type result_type = rewriter.getTensorType(dtype); + if (!handle_dtype.getSubtypes().empty()) { + result_type = with_leading_dim(-1, handle_dtype.getSubtypes()[0]); + } + + // Concatenate tensor stored in the input handle with the expanded item to + // get a tensor equivalent to the TensorList generated by this op. + rewriter.replaceOpWithNewOp( + op, result_type, scalar_zero, + ArrayRef({push_back_op.input_handle(), expanded_item}), + rewriter.getI64IntegerAttr(2)); return matchSuccess(); } }; @@ -267,17 +350,17 @@ namespace { } // namespace TFL LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType( - TF::WhileOp *while_op) { + TF::WhileOp op) { SmallVector unranked_argument_types; - for (const auto &operand : while_op->getOperands()) { + for (const auto &operand : op.getOperands()) { unranked_argument_types.push_back( UnrankedTensorType::get(getElementTypeOrSelf(operand->getType()))); } auto *context = &getContext(); auto module = getModule(); - FuncOp cond_func = module.lookupSymbol(while_op->getCond()); - FuncOp body_func = module.lookupSymbol(while_op->getBody()); + FuncOp cond_func = module.lookupSymbol(op.cond()); + FuncOp body_func = module.lookupSymbol(op.body()); if (cond_func) { // Change `cond_func`'s argument types to `unranked_argument_types`. @@ -313,9 +396,9 @@ LogicalResult LowerStaticTensorListPass::UpdateWhileFunctionType( } } - for (int i = 0; i < while_op->getNumOperands(); ++i) { - auto operand = while_op->getOperand(i); - auto result = while_op->getResult(i); + for (int i = 0; i < op.getNumOperands(); ++i) { + auto operand = op.getOperand(i); + auto result = op.getResult(i); if (getElementTypeOrSelf(result->getType()).isa()) { // If we notice the result type is a DT_VARIANT, we change the // corresponding result type to unranked tensor type. 
@@ -357,7 +440,11 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( } auto c = ConvertTFTensorListReserve(context); rewriter->setInsertionPoint(op); - c.matchAndRewrite(op, *rewriter); + c.matchAndRewrite(tf_op, *rewriter); + } else if (auto tf_op = llvm::dyn_cast(op)) { + auto c = ConvertTFEmptyTensorList(context); + rewriter->setInsertionPoint(op); + c.matchAndRewrite(tf_op, *rewriter); } else if (auto tf_op = llvm::dyn_cast(op)) { auto c = TFL::ConvertTFTensorListGetItem(context); rewriter->setInsertionPoint(op); @@ -365,14 +452,18 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( } else if (auto tf_op = llvm::dyn_cast(op)) { auto c = ConvertTFTensorListSetItem(context); rewriter->setInsertionPoint(op); - c.matchAndRewrite(op, *rewriter); + c.matchAndRewrite(tf_op, *rewriter); } else if (auto tf_op = llvm::dyn_cast(op)) { auto c = TFL::ConvertTFTensorListStack(context); rewriter->setInsertionPoint(op); c.matchAndRewrite(op, *rewriter); + } else if (auto tf_op = llvm::dyn_cast(op)) { + auto c = ConvertTFTensorListPushBack(context); + rewriter->setInsertionPoint(op); + c.matchAndRewrite(op, *rewriter); } else if (auto tf_op = llvm::dyn_cast(op)) { if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context)); - UpdateWhileFunctionType(&tf_op); + UpdateWhileFunctionType(tf_op); } else if (auto tf_op = llvm::dyn_cast(op)) { if (op->getAttr("T")) op->removeAttr(Identifier::get("T", context)); tf_op.getResult()->setType(tf_op.getOperand()->getType()); @@ -408,8 +499,8 @@ void LowerStaticTensorListPass::runOnModule() { /// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList /// pass. -ModulePassBase *TFL::CreateLowerStaticTensorListPass() { - return new LowerStaticTensorListPass(); +std::unique_ptr TFL::CreateLowerStaticTensorListPass() { + return std::make_unique(); } static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 8e3d9690486..33d85b633d5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -21,14 +21,22 @@ limitations under the License. #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Matchers.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Support/Functional.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { namespace TFL { @@ -37,85 +45,166 @@ namespace TFL { // The actual Optimize Pass. namespace { +using ::llvm::cast; + // Optimize TFLite operations in functions. struct Optimize : public FunctionPass { void runOnFunction() override; }; +// Returns whether the given type `a` is broadcast-compatible with `b`. 
+bool IsBroadcastableElementsAttrAndType(Type a, Type b) { + return OpTrait::util::getBroadcastedType(a, b) != Type(); +} + // Returns whether the given `a` and `b` ElementsAttr have broadcast-compatible // types. bool IsBroadcastableElementsAttrs(Attribute a, Attribute b) { - return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type(); + return IsBroadcastableElementsAttrAndType(a.getType(), b.getType()); } #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc" -// Fuse Add with FullyConnected. -// Note that this assumes that the bias in the fullyConnected -// is always None. + +// Fuse Add with proceeding FullyConnected. // TODO(b/136285429): Move to tablegen when variadic is supported -// and add support for bias with noneType type. -struct FuseFullyConnectedAndAdd : public RewritePattern { - explicit FuseFullyConnectedAndAdd(MLIRContext *context) - : RewritePattern(TFL::AddOp::getOperationName(), - {"tfl.fully_connected", "tfl.add", "std.constant"}, 4, - context) {} +struct FuseFullyConnectedAndAdd : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *add_op, + PatternMatchResult matchAndRewrite(TFL::AddOp add_op, PatternRewriter &rewriter) const override { + // Add. + DenseElementsAttr added_value; + Value *constant_val = add_op.rhs(); + if (!matchPattern(constant_val, m_Constant(&added_value))) + return matchFailure(); + // Fully Connected. - Operation *fully_connected = add_op->getOperand(0)->getDefiningOp(); - if (!fully_connected || !isa(fully_connected)) + auto fc_op = + dyn_cast_or_null(add_op.lhs()->getDefiningOp()); + if (!fc_op) return matchFailure(); + + Value *filter = fc_op.filter(); + Value *bias = fc_op.bias(); + ElementsAttr bias_value; + const bool is_none_bias = bias->getType().isa(); + if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value))) return matchFailure(); - TFL::FullyConnectedOp fully_connected_op = - llvm::cast(fully_connected); - Value *input = fully_connected_op.input(); - Value *filter = fully_connected_op.filter(); - - // Make sure the bias is None. - // TODO(karimnosseir): Support non None case. - Operation *bias_op = fully_connected_op.bias()->getDefiningOp(); - if (!bias_op || !isa(bias_op)) return matchFailure(); - if (!fully_connected_op.bias()->getType().isa()) - return matchFailure(); - - auto activation_func = fully_connected_op.getAttrOfType( - "fused_activation_function"); - if (!activation_func) return matchFailure(); - if (activation_func.cast().getValue() != "NONE") - return matchFailure(); - - auto weight_format = - fully_connected_op.getAttrOfType("weights_format"); - if (!weight_format) return matchFailure(); - - auto keep_num_dims = - fully_connected_op.getAttrOfType("keep_num_dims"); - if (!keep_num_dims) return matchFailure(); - - auto constant_op = add_op->getOperand(1)->getDefiningOp(); - if (!constant_op) return matchFailure(); - if (!isa(constant_op)) return matchFailure(); - - auto add_value = constant_op->getAttrOfType("value"); - if (!add_value) return matchFailure(); - if (!((add_value.cast().getType().getElementType().isF32()))) - return matchFailure(); - - auto fused_activation_func = - add_op->getAttrOfType("fused_activation_function"); - if (!fused_activation_func) return matchFailure(); + if (fc_op.fused_activation_function() != "NONE") return matchFailure(); // Rewrite - // TODO(karimnosseir): Check what constraints needed to apply. - // TODO(b/136171362): Check for single output consumer. 
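In effect, the rewrite below folds the addition into the fully-connected bias (illustrative IR; value names are hypothetical):

    // %fc  = tfl.fully_connected(%in, %filter, %bias) {fused_activation_function = "NONE"}
    // %out = tfl.add(%fc, %cst)
    //   becomes
    // %new_bias = tfl.add(%bias, %cst) {fused_activation_function = "NONE"}
    //             (or just %cst when %bias is none)
    // %out = tfl.fully_connected(%in, %filter, %new_bias)
    // The Add's own fused_activation_function carries over to the new
    // fully_connected op.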
+ Location loc = fc_op.getLoc(); + // If bias isn't None, it needs to be added as well. + if (is_none_bias) { + bias = constant_val; + } else { + auto none_af = rewriter.getStringAttr("NONE"); + bias = rewriter.create(loc, bias, constant_val, none_af).output(); + } rewriter.replaceOpWithNewOp( - add_op, add_op->getResult(0)->getType(), - /*input=*/input, + add_op, add_op.getType(), + /*input=*/fc_op.input(), /*filter=*/filter, - /*bias=*/add_op->getOperand(1), - /*fused_activation_function=*/fused_activation_func, - /*weights_format=*/weight_format, - /*keep_num_dims=*/keep_num_dims); + /*bias=*/bias, + /*fused_activation_function=*/ + rewriter.getStringAttr(add_op.fused_activation_function()), + /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()), + /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims())); + + return matchSuccess(); + } +}; + +// TODO(b/136285429): Move to tablegen when variadic is supported. +struct FuseFullyConnectedAndRelu : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TFL::ReluOp relu_op, + PatternRewriter &rewriter) const override { + Operation *input = relu_op.getOperand()->getDefiningOp(); + if (!isa_and_nonnull(input)) return matchFailure(); + auto fully_connected_op = cast(input); + if (fully_connected_op.fused_activation_function() != "NONE") + return matchFailure(); + + auto new_activation_func = rewriter.getStringAttr("RELU"); + auto new_weights_format = + rewriter.getStringAttr(fully_connected_op.weights_format()); + auto new_keep_num_dims = + rewriter.getBoolAttr(fully_connected_op.keep_num_dims()); + rewriter.replaceOpWithNewOp( + relu_op, relu_op.getType(), fully_connected_op.input(), + fully_connected_op.filter(), fully_connected_op.bias(), + new_activation_func, new_weights_format, new_keep_num_dims); + + return matchSuccess(); + } +}; + +// Fuse Mul with proceeding FullyConnected. +// TODO(b/136285429): Move to tablegen when variadic is supported +struct FuseFullyConnectedAndMul : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(TFL::MulOp mul_op, + PatternRewriter &rewriter) const override { + // Mul. + DenseElementsAttr cst; + Value *constant_val = mul_op.rhs(); + if (!matchPattern(constant_val, m_Constant(&cst))) return matchFailure(); + + // Fully Connected. + auto fc_op = + dyn_cast_or_null(mul_op.lhs()->getDefiningOp()); + if (!fc_op) return matchFailure(); + Value *filter = fc_op.filter(); + Value *bias = fc_op.bias(); + ElementsAttr cst_tmp; + if (!matchPattern(filter, m_Constant(&cst_tmp))) return matchFailure(); + if (!bias->getType().isa() && + !matchPattern(bias, m_Constant(&cst_tmp))) + return matchFailure(); + if (fc_op.fused_activation_function().equals("None")) return matchFailure(); + + // Broadcast the constant operand of Mul if it isn't compatible to the + // filter input. We only support broadcasting the operand along the depth + // dimension, when the operand's depth is 1. 
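// Illustrative shapes (assumed): a filter of type tensor<16x4xf32> and a mul
// constant of type tensor<16xf32> are not broadcast-compatible, so the constant
// is reshaped to tensor<16x1xf32>, which broadcasts against the filter along
// the depth (last) dimension; any other shape mismatch still fails the match.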
+ Value *new_const_val = constant_val; + if (!IsBroadcastableElementsAttrAndType(cst.getType(), filter->getType())) { + auto original_shape = cst.getType().getShape(); + llvm::SmallVector normalized_shape(original_shape.begin(), + original_shape.end()); + normalized_shape.push_back(1); + auto new_cst = cst.reshape(rewriter.getTensorType( + normalized_shape, cst.getType().getElementType())); + Type new_type = new_cst.getType(); + if (!IsBroadcastableElementsAttrAndType(new_type, filter->getType())) { + return matchFailure(); + } + auto new_op = + rewriter.create(mul_op.getLoc(), new_type, new_cst); + new_const_val = new_op.getResult(); + } + + // Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands, + // TF::MulOp is used to fold the constant. + // TODO(b/139192933): switch to the TFL constant folding + Location loc = fc_op.getLoc(); + auto new_filter = + rewriter.create(loc, filter, new_const_val).z(); + // If bias isn't None, it needs to be multiplied as well. + if (!bias->getType().isa()) { + bias = rewriter.create(loc, bias, constant_val).z(); + } + + rewriter.replaceOpWithNewOp( + mul_op, mul_op.getType(), + /*input=*/fc_op.input(), + /*filter=*/new_filter, + /*bias=*/bias, + /*fused_activation_function=*/ + rewriter.getStringAttr(mul_op.fused_activation_function()), + /*weights_format=*/rewriter.getStringAttr(fc_op.weights_format()), + /*keep_num_dims=*/rewriter.getBoolAttr(fc_op.keep_num_dims())); return matchSuccess(); } @@ -154,12 +243,12 @@ struct PadStridedSliceDims : public RewritePattern { // Insert a new reshape op. Value *original_input = strided_slice.input(); - const RankedTensorType &original_input_type = - original_input->getType().template cast(); + RankedTensorType original_input_type = + original_input->getType().cast(); const ArrayRef &original_input_shape = original_input_type.getShape(); - const RankedTensorType &begin_type = - strided_slice.begin()->getType().template cast(); + RankedTensorType begin_type = + strided_slice.begin()->getType().cast(); const int dim_size = begin_type.getShape()[0]; SmallVector new_shape; int mask = 1; @@ -204,19 +293,22 @@ struct PadStridedSliceDims : public RewritePattern { void Optimize::runOnFunction() { OwningRewritePatternList patterns; + auto *ctx = &getContext(); auto func = getFunction(); + // Add the generated patterns to the list. - TFL::populateWithGenerated(&getContext(), &patterns); - patterns.push_back( - llvm::make_unique(&getContext())); - patterns.push_back(llvm::make_unique(&getContext())); - applyPatternsGreedily(func, std::move(patterns)); + TFL::populateWithGenerated(ctx, &patterns); + patterns.insert(ctx); + applyPatternsGreedily(func, patterns); } } // namespace // Creates an instance of the TensorFlow Lite dialect Optimize pass. -FunctionPassBase *CreateOptimizePass() { return new Optimize(); } +std::unique_ptr CreateOptimizePass() { + return std::make_unique(); +} static PassRegistration pass( "tfl-optimize", "Optimize within the TensorFlow Lite dialect"); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 6d7e3aa24db..51610832db6 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -16,7 +16,7 @@ limitations under the License. // This is the optimization pattern definition file for TensorFlow Lite. 
include "mlir/IR/OpBase.td" -include "mlir/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/Ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" def F32ElementsAttr : ElementsAttrBase< @@ -110,3 +110,40 @@ def : Pat<(TFL_MulOp (TFL_DepthwiseConv2DOp $input, // with the same scale. We want to remove the redundancy. // TODO(fengliuai): move this to the sanity check of pre-quantize pass. def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>; + +// Constraint that makes sure both operands are the same operands. +def EqualOperands : Constraint>; + +// Checks if the operand has rank == n +class OperandHasRank : Constraint< + CPred<"$0->getType().cast().getRank() == " # n>>; + +// This pattern constructs L2NormalizationOp from +// Mul->Rsqrt->Sum->Square +// Currently L2Normalization doesn't support activation function +// in TFLite. +def : Pat<(TFL_MulOp $operand1, + (TFL_RsqrtOp + (TFL_SumOp + (TFL_SquareOp $square_operand), + (ConstantOp I32ElementsAttr:$constant), + $keep_dims)), + TFL_AF_None), + (TFL_L2NormalizationOp $operand1, TFL_AF_None), + [(EqualOperands $operand1, $square_operand), + (OperandHasRank<1> $operand1)]>; + +// This pattern constructs L2NormalizationOp from +// Div->sqrt->Sum->Square +// Currently L2Normalization doesn't support activation function +// in TFLite. +def : Pat<(TFL_DivOp $operand1, + (TFL_SqrtOp + (TFL_SumOp + (TFL_SquareOp $square_operand), + (ConstantOp I32ElementsAttr:$constant), + $keep_dims)), + TFL_AF_None), + (TFL_L2NormalizationOp $operand1, TFL_AF_None), + [(EqualOperands $operand1, $square_operand), + (OperandHasRank<1> $operand1)]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 561c0de815f..fb01ba0e9c8 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -16,6 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_PASSES_H_ +#include + +#include "llvm/ADT/ArrayRef.h" + namespace mlir { class FunctionPassBase; class ModulePassBase; @@ -23,29 +27,47 @@ class ModulePassBase; namespace TFL { // Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. -FunctionPassBase *CreateLegalizeTFPass(); +std::unique_ptr CreateLegalizeTFPass(); // Creates an instance of the TensorFlow Lite dialect Optimize pass. -FunctionPassBase *CreateOptimizePass(); +std::unique_ptr CreateOptimizePass(); // Creates an instance of the TensorFlow Lite dialect PrepareTF pass. -FunctionPassBase *CreatePrepareTFPass(); +std::unique_ptr CreatePrepareTFPass(); // Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList // pass. -ModulePassBase *CreateLowerStaticTensorListPass(); +std::unique_ptr CreateLowerStaticTensorListPass(); // Creates an instance of the TensorFlow Lite dialect Quantize pass. -FunctionPassBase *CreateQuantizePass(); +std::unique_ptr CreateQuantizePass(); // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass. // When `quantize_sign` is true, constant tensors will use int8 quantization // scheme. // TODO(fengliuai): make the bit width configurable. -FunctionPassBase *CreatePrepareQuantizePass(bool quantize_sign); +std::unique_ptr CreatePrepareQuantizePass(bool quantize_sign); // Creates a instance of the TensorFlow Lite dialect PostQuantize pass. 
-FunctionPassBase *CreatePostQuantizePass(bool emit_quant_adaptor_ops); +std::unique_ptr CreatePostQuantizePass( + bool emit_quant_adaptor_ops); + +// Creates an instance of the TensorFlow Lite dialect TrimFunctions +// pass. +std::unique_ptr CreateTrimFunctionsPass( + llvm::ArrayRef trim_funcs_whitelist); + +// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions +// pass. +std::unique_ptr CreatePrepareCompositeFunctionsPass(); + +// Creates a instance of the TensorFlow Lite dialect ExtractOphint pass. +std::unique_ptr CreateExtractOphintPass(); + +// Creates a instance of the TensorFlow Lite dialect LegalizeOphintFuncOpPass +// pass. The composite op is created from the ophint extraction pass. +std::unique_ptr CreateLegalizeOphintFuncOpPass(); + } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 94c19d27adc..17e715960d9 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -18,8 +18,8 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h" //===----------------------------------------------------------------------===// // The post-quantize Pass. @@ -125,8 +125,9 @@ void PostQuantizePass::runOnFunction() { } // namespace // Creates an instance of the TensorFlow Lite dialect PostQuantize pass. -FunctionPassBase* CreatePostQuantizePass(bool emit_quant_adaptor_ops) { - return new PostQuantizePass(emit_quant_adaptor_ops); +std::unique_ptr CreatePostQuantizePass( + bool emit_quant_adaptor_ops) { + return std::make_unique(emit_quant_adaptor_ops); } static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc new file mode 100644 index 00000000000..58e58c05c4d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -0,0 +1,124 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/Identifier.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" + +namespace mlir { +namespace TFL { +namespace { + +// Abstracts the conversion of the embedded lookup composite function. +class ConvertEmbeddedLookupFunc { + public: + explicit ConvertEmbeddedLookupFunc(FuncOp func) : func_(func) {} + + void RewriteFunc() { + func_.eraseBody(); + func_.addEntryBlock(); + func_.setAttr( + "tf._implements", + StringAttr::get("fused_tfl_embedding_lookup", func_.getContext())); + Value* lookup = func_.getArgument(1); + Value* value = func_.getArgument(0); + auto output_type = func_.getType().getResult(0); + + OpBuilder builder(func_.getBody()); + auto op = builder.create( + func_.getLoc(), output_type, lookup, value); + + builder.create(func_.getLoc(), op.getResult()); + } + + LogicalResult VerifySignature() { + if (func_.getNumArguments() != 2) { + return func_.emitError() + << "Invalid number of arguments in the embedding " + "matmal composite function"; + } + if (func_.getType().getNumResults() != 1) { + return func_.emitError() << "Invalid number of results in the embedding " + "matmal composite function"; + } + return success(); + } + + private: + FuncOp func_; +}; + +// This pass uses mechanisms listed in RFC: +// https://github.com/tensorflow/community/pull/113 +// It prepares composite functions that are attributed to indicate +// a specific interface (LSTM, SVDF, Embedding lookup etc.) by replacing the +// body with the corresponding fused TFLite op. The replacement need not always +// be a fused op, though that is the primary use case. +class PrepareCompositeFunctionsPass + : public FunctionPass { + public: + explicit PrepareCompositeFunctionsPass() {} + + private: + void runOnFunction() override; +}; + +void PrepareCompositeFunctionsPass::runOnFunction() { + // TODO(ashwinm): Explore if we can generalize this pass by simply taking + // a map and doing the transform. This should be + // revisited after we add LSTM composite op to this pass. + auto func = getFunction(); + auto attr = func.getAttrOfType("tf._implements"); + if (!attr || attr.getValue() != "embedding_matmul") return; + // Convert the composite embedding_matmul function body to a + // TFLite fused embedding_lookup op. 
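// Rough before/after sketch (signature and shapes assumed for illustration):
//   func @embedding(%value: tensor<1000x64xf32>, %lookup: tensor<?xi32>)
//       -> tensor<?x64xf32> attributes {tf._implements = "embedding_matmul"}
// has its body replaced by a single fused lookup, roughly
//   %0 = "tfl.embedding_lookup"(%lookup, %value)
//       : (tensor<?xi32>, tensor<1000x64xf32>) -> tensor<?x64xf32>
//   return %0
// and the attribute is rewritten to "fused_tfl_embedding_lookup".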
+ ConvertEmbeddedLookupFunc convert_embedded_lookup(func); + if (failed(convert_embedded_lookup.VerifySignature())) { + return signalPassFailure(); + } + convert_embedded_lookup.RewriteFunc(); +} +} // namespace + +std::unique_ptr CreatePrepareCompositeFunctionsPass() { + return std::unique_ptr(); +} + +static PassRegistration pass( + "tfl-prepare-composite-funcs-tf", + "Prepares composite functions in Tensorflow dialect of MLIR "); + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index 62c3de86e72..e3dabb7a48d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -18,7 +18,7 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" def FalseBoolAttr : AttrConstraint>; -// Converts tf.FusedBatchNorm into a sequence of more primitive arithmetic +// Converts tf.FusedBatchNorm & tf.FusedBatchNormV3 into a sequence of more primitive arithmetic // operations. Specifically, performs the following calculation: // // (x - mean) * scale / sqrt(variance + epsilon) + offset @@ -29,9 +29,9 @@ def FalseBoolAttr : AttrConstraint>; // is then to compute // (x * multiplier) + (offset - mean * multiplier). def : Pattern< - (TF_FusedBatchNormOp $x, $scale, $offset, $mean, $variance, - F32Attr:$epsilon, $data_format, - FalseBoolAttr:$is_training), + (TF_FusedBatchNormOp:$root + $x, $scale, $offset, $mean, $variance, + F32Attr:$epsilon, $data_format, FalseBoolAttr:$is_training), [(TF_AddOp (TF_MulOp $x, @@ -41,21 +41,40 @@ def : Pattern< (TF_AddOp $variance, (TF_ConstOp $epsilon))))), (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))), - /*batch_mean=*/(verifyUnusedValue), - /*batch_variance=*/(verifyUnusedValue), - /*reserve_space_1=*/(verifyUnusedValue), - /*reserve_space_2=*/(verifyUnusedValue) - ]>; + // We already guaranteed that the last four results has no use so it does + // not matter what value we provide here for replacement. + /*batch_mean=*/(replaceWithValue $x), + /*batch_variance=*/(replaceWithValue $x), + /*reserve_space_1=*/(replaceWithValue $x), + /*reserve_space_2=*/(replaceWithValue $x)], + [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), + (HasNoUseOf:$root__3), (HasNoUseOf:$root__4)]>; -// TODO(jpienaar): Move to opbase something more general. -def TFi32ElementsAttr : Attr">, - "scalar int attribute"> { - let storageType = [{ DenseIntElementAttr }]; - let constBuilderCall = "$_builder.getDenseElementsAttr(" - "$_builder.getTensorType({}, $_builder.getIntegerType(32)), " - "{$_builder.getI32IntegerAttr($0)})"; -} -class TFi32 : ConstantAttr(v)>; +def : Pattern< + (TF_FusedBatchNormV3Op:$root + $x, $scale, $offset, $mean, $variance, + F32Attr:$epsilon, $data_format, FalseBoolAttr:$is_training), + [(TF_AddOp + (TF_MulOp + $x, + (TF_MulOp:$multiplier + $scale, + (TF_RsqrtOp + (TF_AddOp $variance, + (TF_ConstOp $epsilon))))), + (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))), + // We already guaranteed that the last five results have no use so it does + // not matter what value we provide here for replacement. 
+ /*batch_mean=*/(replaceWithValue $x), + /*batch_variance=*/(replaceWithValue $x), + /*reserve_space_1=*/(replaceWithValue $x), + /*reserve_space_2=*/(replaceWithValue $x), + /*reserve_space_3=*/(replaceWithValue $x)], + [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), + (HasNoUseOf:$root__3), (HasNoUseOf:$root__4), + (HasNoUseOf:$root__5)]>; + +class TFi32 : ConstantAttr(v)>; // Matmul without transpose on b to matmul with explicit transpose op and // transposed b. @@ -75,10 +94,14 @@ def : Pat<(TF_MatMulOp $a, $b, ConstBoolAttrTrue, $bt), /*delta=*/(ConstantOp TFi32<-1>)), (ConstantOp TFi32<1>))), $b, ConstBoolAttrFalse, $bt)>; +def : Pat<(TF_SnapshotOp $arg), (TF_IdentityOp $arg)>; +def : Pat<(TF_StopGradientOp $arg), (TF_IdentityOp $arg)>; + //===----------------------------------------------------------------------===// // Op removal patterns. //===----------------------------------------------------------------------===// def : Pat<(TF_IdentityOp $arg), (replaceWithValue $arg)>; +def : Pat<(TF_IdentityNOp $arg), (replaceWithValue $arg)>; //===----------------------------------------------------------------------===// // Op quantization pass-through patterns. @@ -98,3 +121,27 @@ def : Pat<(TF_ReshapeOp $shape), (TF_FakeQuantWithMinMaxVarsOp (TF_ReshapeOp $input, $shape), $min, $max, $num_bits, $narrow_range)>; + +// Casts result type of $1 to a quantized type by using the quantization +// parameters from the type in $0. +def UpdateShape : NativeCodeCall< + "CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1->getType())">; + +// When the op is passing-through, the output types of the quantized ops need +// to be updated as well. Since the quantize op manages its own type by the +// "qtype" attribute, we should update the type shape in this attribute. +def : Pat<(TF_TransposeOp:$old_value + (TFL_DequantizeOp (TFL_QuantizeOp $input, $qtype)), $perm), + (TFL_DequantizeOp (TFL_QuantizeOp (TF_TransposeOp $input, $perm), + (UpdateShape $qtype, $old_value)))>; + +def : Pat<(TF_ReshapeOp:$old_value + (TFL_DequantizeOp (TFL_QuantizeOp $input, $qtype)), $shape), + (TFL_DequantizeOp + (TFL_QuantizeOp (TF_ReshapeOp $input, $shape), + (UpdateShape $qtype, $old_value)))>; + +// The Rank op produces result which is independent with the quantization +// parameters of the input, so we can remove the quantization ops. +def : Pat<(TF_RankOp (TFL_DequantizeOp (TFL_QuantizeOp $input, $qtype))), + (TF_RankOp $input)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index c91cdb3df45..9ad26e4d782 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -15,10 +15,12 @@ limitations under the License. // This transformation pass applies quantization propagation on TFLite dialect. +#include "absl/memory/memory.h" #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h" //===----------------------------------------------------------------------===// // The prepare-quantize Pass. @@ -27,6 +29,7 @@ namespace mlir { namespace TFL { namespace { + // Applies prepare quantization on the model in TFL dialect. 
This pass runs // before the quantization pass and propagate the quantization parameters // across ops. This step is necessary for post-training quantization and also @@ -47,15 +50,19 @@ class PrepareQuantizePass : public FunctionPass { bool quantize_sign_; }; +#include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc" + void PrepareQuantizePass::runOnFunction() { - ApplyQuantizationParamsPropagation(getFunction(), quantize_sign_); + ApplyQuantizationParamsPropagation(getFunction(), quantize_sign_, + GetOpQuantSpec); } } // namespace // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass. -FunctionPassBase *CreatePrepareQuantizePass(bool quantize_sign) { - return new PrepareQuantizePass(quantize_sign); +std::unique_ptr CreatePrepareQuantizePass( + bool quantize_sign) { + return std::make_unique(quantize_sign); } static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 6f2e9e6ea1e..7c7983ae254 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -48,9 +48,9 @@ limitations under the License. #include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "mlir/Support/LogicalResult.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" -#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -71,54 +71,79 @@ struct PrepareTFPass : public FunctionPass { }; // TODO(fengliuai): move this rule to PreparePatterns.td -// Inserts a "tfl.quantize" and "tfl.dequantize" op pair after the +// Inserts a "tfl.quantize" and "tfl.dequantize" op pair (QDQs) after the // "tf.FakeQuantWithMinMaxVarsOp" to be constant folded. Since the constant // folding logic will use a "std.constant" op to replace the // "tf.FakeQuantWithMinMaxVarsOp", the "tfl.quantize" op is used to preserve // the quantization parameters as a TypeAttr and "tfl.dequantize" op used to -// convert the output type to the next op. +// convert the output type to the next op. 
Here are the transformations: +// +// input min cst max cst input min cst max cst +// \ | | \ | | +// \ (tf.Identity) (tf.Identity) => \ (tf.Identity) (tf.Identity) +// \ | | \ | | +// tf.FakeQuantWithMinMaxVars tf.FakeQuantWithMinMaxVars +// | | +// tf.quantize +// | +// tf.dequantize +// | +// If the input is a constant, the result pattern will eventually converted to + +// quant-emulated input +// | +// tf.quantize +// | +// tf.dequantize +// | struct InsertTFLQuantOpsAfterTFFakeQuantOp : public RewritePattern { InsertTFLQuantOpsAfterTFFakeQuantOp(MLIRContext *context) - : RewritePattern(TF::FakeQuantWithMinMaxVarsOp::getOperationName(), 1, + : RewritePattern(TF::FakeQuantWithMinMaxVarsOp::getOperationName(), 3, context) {} - struct MatchedState : public PatternState { - FloatAttr min; - FloatAttr max; - APInt num_bits; - bool narrow_range; - }; - - PatternMatchResult match(Operation *op) const override { + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { auto tf_op = cast(op); + // We don't want to insert quantize/dequantize if the quantize op exists. auto res = tf_op.outputs(); if (!res->hasOneUse() || isa(*res->user_begin())) return matchFailure(); - auto state = absl::make_unique(); - ElementsAttr min_value, max_value; - if (!matchPattern(tf_op.min(), m_Constant(&min_value))) - return matchFailure(); - if (!matchPattern(tf_op.max(), m_Constant(&max_value))) - return matchFailure(); - state->min = ExtractSingleElementAsFloat(min_value); - state->max = ExtractSingleElementAsFloat(max_value); - if (!state->min || !state->max) return matchFailure(); - state->num_bits = tf_op.num_bits(); - state->narrow_range = tf_op.narrow_range(); - return matchSuccess(std::move(state)); - } - void rewrite(Operation *op, std::unique_ptr state, - PatternRewriter &rewriter) const override { - auto &s = *static_cast(state.get()); - Location loc = op->getLoc(); - Value *copied = OpBuilder(op).clone(*op)->getResult(0); - Type res_type = copied->getType(); - Type storage_type = rewriter.getIntegerType(s.num_bits.getSExtValue()); - TypeAttr qtype = GetQuantizedTypeAttr(rewriter, res_type, s.min, s.max, - storage_type, s.narrow_range); - Value *quantize_op = - rewriter.create(loc, qtype.getValue(), copied, qtype); - rewriter.replaceOpWithNewOp(op, res_type, quantize_op); + // Extract the min/max constant values from the operands. We also consider + // a special case that there are tf.Identity ops between the min/max + // constants and the tf.FakeQuantWithMinMaxVarsOp. + Value *min = tf_op.min(), *max = tf_op.max(); + ElementsAttr min_value, max_value; + if (auto id1 = dyn_cast_or_null(min->getDefiningOp())) + min = id1.input(); + if (auto id2 = dyn_cast_or_null(max->getDefiningOp())) + max = id2.input(); + if (!matchPattern(min, m_Constant(&min_value))) return matchFailure(); + if (!matchPattern(max, m_Constant(&max_value))) return matchFailure(); + FloatAttr min_attr = ExtractSingleElementAsFloat(min_value); + FloatAttr max_attr = ExtractSingleElementAsFloat(max_value); + if (!min_attr || !max_attr) return matchFailure(); + + // Use the min/max from the operands and the num_bits and narrow_range + // attribute to create the quantization parameter for the new quantize op. 
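// Worked example (values assumed; exact rounding and signedness depend on
// GetQuantizedTypeAttr): with min = -1.0, max = 1.0, num_bits = 8 and
// narrow_range = false, the scale is (max - min) / 255 ~= 0.007843 and the
// zero point lands near the middle of the storage range, giving a quantized
// element type along the lines of !quant.uniform<u8:f32, 0.007843:128>.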
+ rewriter.setInsertionPoint(op->getBlock(), ++Block::iterator(op)); + Type num_bits = rewriter.getIntegerType(tf_op.num_bits().getSExtValue()); + bool narrow_range = tf_op.narrow_range(); + Type res_type = tf_op.getType(); + TypeAttr qtype = GetQuantizedTypeAttr(rewriter, res_type, min_attr, + max_attr, num_bits, narrow_range); + + // Finally, use the quantization parameter to create the quantize and + // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp + // and its users. + Value *value = tf_op.outputs(); + auto quantize = rewriter.create( + op->getLoc(), qtype.getValue(), value, qtype); + auto dequantize = rewriter.create(op->getLoc(), res_type, + quantize.output()); + value->replaceAllUsesWith(dequantize); + quantize.getOperation()->replaceUsesOfWith(dequantize, value); + + return matchSuccess(); } }; @@ -170,7 +195,7 @@ struct ConvertTFConvOp : public RewritePattern { IntegerAttr height, width; if (!TFIntListIs1XY1(op, "strides", &height, &width)) return matchFailure(); - auto state = llvm::make_unique(); + auto state = std::make_unique(); state->stride_height = height; state->stride_width = width; @@ -352,25 +377,34 @@ class ConvertTFDepthwiseConv2dNative void PrepareTFPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); + // This pattern was intented to uses TFL QDQs to preserve the quantization + // parameters from the TF Quant ops, thus this pattern should run with the + // first `applyPatternsGreedily` method, which would otherwise removes the + // TF FakeQuant ops by the constant folding. + patterns.insert(&getContext()); TFL::populateWithGenerated(&getContext(), &patterns); // TODO(karimnosseir): Split to separate pass probably after // deciding on long term plan for this optimization. // This will allow optimizing any TF_Mul->TF_Conv in the graph // and any expanded from FusedBatchNorm. We need to do this // before converting TF_Conv to TFL_Conv - applyPatternsGreedily(func, std::move(patterns)); - patterns.push_back(llvm::make_unique(&getContext())); - patterns.push_back( - llvm::make_unique(&getContext())); - patterns.push_back( - llvm::make_unique(&getContext())); - applyPatternsGreedily(func, std::move(patterns)); + applyPatternsGreedily(func, patterns); + + // Load the generated pattern again, so new quantization pass-through + // will be applied. + patterns.clear(); + TFL::populateWithGenerated(&getContext(), &patterns); + patterns.insert( + &getContext()); + applyPatternsGreedily(func, patterns); } } // namespace // Creates an instance of the TensorFlow Lite dialect PrepareTF pass. -FunctionPassBase *CreatePrepareTFPass() { return new PrepareTFPass(); } +std::unique_ptr CreatePrepareTFPass() { + return std::make_unique(); +} static PassRegistration pass( "tfl-prepare-tf", "Prepare TF for legalization to TensorFlow Lite dialect"); diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index 91bb26a976b..e4029d7f13f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -31,8 +31,8 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Support/Functional.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/lite/utils/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" namespace mlir { @@ -55,14 +55,16 @@ void QuantizePass::runOnFunction() { auto func = getFunction(); auto* ctx = func.getContext(); TFL::populateWithGenerated(ctx, &patterns); - mlir::RewriteListBuilder>::build(patterns, ctx); - applyPatternsGreedily(func, std::move(patterns)); + patterns.insert>(ctx); + applyPatternsGreedily(func, patterns); } } // namespace // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass. -FunctionPassBase* CreateQuantizePass() { return new QuantizePass(); } +std::unique_ptr CreateQuantizePass() { + return std::make_unique(); +} static PassRegistration pass( "tfl-quantize", "Apply quantization on models in TensorFlow Lite dialect"); diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td index 7fcf926d89f..369b5300540 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td @@ -16,7 +16,7 @@ limitations under the License. // This is the quantization pattern definition file for TensorFlow Lite. include "mlir/IR/OpBase.td" -include "mlir/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/Ops.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" // Quantize attribute $0 by using quantization parameter from %1. diff --git a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc new file mode 100644 index 00000000000..1cd4f42810e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc @@ -0,0 +1,133 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Identifier.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/SymbolTable.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" + +// The cmd line flag to specify the whitelist of functions. Rest are trimmed +// after this pass is run. 
+// NOLINTNEXTLINE +static llvm::cl::list trim_funcs_whitelist( + "tfl-trim-funcs-whitelist", llvm::cl::value_desc("list"), + llvm::cl::desc("comma seprarated list of whitelisted functions. The first " + "function specified will be used as main."), + llvm::cl::CommaSeparated); + +namespace mlir { +namespace TFL { +namespace { + +// The pass to trim functions before we legalize to TFL +// dialect using the specified whitelist. +class TrimFunctionsPass : public mlir::ModulePass { + public: + explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {} + explicit TrimFunctionsPass(llvm::ArrayRef trim_funcs_whitelist) + : trim_funcs_whitelist_(trim_funcs_whitelist) {} + + private: + void runOnModule() override; + bool TrimModule(); + void Verify(); + + llvm::ArrayRef trim_funcs_whitelist_; +}; + +void TrimFunctionsPass::runOnModule() { + // trim the functions in the module using the trim_funcs_whitelist_ + // by removing functions not in the whitelist. + if (TrimModule()) { + // verify the updated module is still valid, if not signal the + // pass as failed. + Verify(); + } +} + +bool TrimFunctionsPass::TrimModule() { + // if no trim_funcs_whitelist_ is specified, this pass is a no-op. + if (trim_funcs_whitelist_.empty()) return false; + + llvm::SmallVector funcs_to_trim; + for (auto func : getModule().getOps()) { + if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) { + // If no main is specified in the whitelist, use the 1st func + // in trim_funcs_whitelist as the main. + // TODO(ashwinm): Currently tflite flatbuffer export assumes there is + // always a main. This is strictly not required for TFlite. We need to + // remove that restriction once we have support to attribute the main + // tensorflow function in MLIR TF import using an entry_point attr. + if (!llvm::is_contained(trim_funcs_whitelist_, "main") && + func.getName() == trim_funcs_whitelist_[0]) { + func.setName("main"); + } + } else { + funcs_to_trim.push_back(func); + } + } + + // remove all unexported functions from the module. + for (auto func : funcs_to_trim) { + func.erase(); + } + return true; +} + +// validate that all reachable functions from the remaining functions are +// also in the whitelist. +void TrimFunctionsPass::Verify() { + // TODO(ashwinm): Instead, we should make sure that references to all + // SymbolRefAttrs of all ops are present. + SymbolTable symbol_table = SymbolTable(getModule()); + llvm::SetVector reachable_funcs; + for (auto func : getModule().getOps()) { + func.walk([&](CallOp op) { + if (!symbol_table.lookup(op.getCallee())) { + getModule().emitError() + << func.getName() << " is not in the funcs whitelist"; + return signalPassFailure(); + } + }); + } +} + +} // namespace + +// Creates an instance of the TensorFlow Lite dialect TrimFunctions +/// pass. 
+std::unique_ptr CreateTrimFunctionsPass( + llvm::ArrayRef trim_funcs_whitelist) { + return std::make_unique(trim_funcs_whitelist); +} + +static PassRegistration pass( + "tfl-trim-funcs-tf", + "Trim functions to restrict them to a specified whitelist prior to " + "legalization to TensorFlow lite dialect"); + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc index a1a427a0381..33da9929711 100644 --- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc @@ -25,7 +25,7 @@ FloatAttr ExtractSingleElementAsFloat(ElementsAttr attr) { return {}; } SmallVector index(attr.getType().getRank(), 0); - return attr.getValue(index).cast(); + return attr.getValue(index); } FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) { @@ -42,7 +42,7 @@ IntegerAttr ExtractSingleElementAsInteger(ElementsAttr attr) { return {}; } SmallVector index(attr.getType().getRank(), 0); - return attr.getValue(index).cast(); + return attr.getValue(index); } } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.h b/tensorflow/compiler/mlir/lite/utils/attribute_utils.h index efa782ce4e8..263a0a8dc93 100644 --- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.h @@ -19,7 +19,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_ATTRIBUTE_UTILS_H_ -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc new file mode 100644 index 00000000000..5dcd40aab6b --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -0,0 +1,77 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" + +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace tflite { + +mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { + switch (type) { + case tflite::TensorType_FLOAT32: + return builder.getF32Type(); + case tflite::TensorType_FLOAT16: + return builder.getF16Type(); + case tflite::TensorType_INT32: + return builder.getIntegerType(32); + case tflite::TensorType_UINT8: + return mlir::TF::Uint8Type::get(builder.getContext()); + case tflite::TensorType_INT64: + return builder.getIntegerType(64); + case tflite::TensorType_STRING: + return mlir::TF::StringType::get(builder.getContext()); + case tflite::TensorType_BOOL: + return builder.getI1Type(); + case tflite::TensorType_INT16: + return builder.getIntegerType(16); + case tflite::TensorType_COMPLEX64: + return mlir::TF::Complex64Type::get(builder.getContext()); + case tflite::TensorType_INT8: + return builder.getIntegerType(8); + } +} + +tensorflow::DataType TflTypeToTfType(tflite::TensorType type) { + switch (type) { + case tflite::TensorType_BOOL: + return tensorflow::DT_BOOL; + case tflite::TensorType_COMPLEX64: + return tensorflow::DT_COMPLEX64; + case tflite::TensorType_FLOAT16: + return tensorflow::DT_HALF; + case tflite::TensorType_FLOAT32: + return tensorflow::DT_FLOAT; + case tflite::TensorType_INT8: + return tensorflow::DT_INT8; + case tflite::TensorType_INT16: + return tensorflow::DT_INT16; + case tflite::TensorType_INT32: + return tensorflow::DT_INT32; + case tflite::TensorType_INT64: + return tensorflow::DT_INT64; + case tflite::TensorType_STRING: + return tensorflow::DT_STRING; + case tflite::TensorType_UINT8: + return tensorflow::DT_UINT8; + } +} + +} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.h b/tensorflow/compiler/mlir/lite/utils/convert_type.h new file mode 100644 index 00000000000..ff4ccb325a8 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.h @@ -0,0 +1,36 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ + +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "tensorflow/core/framework/types.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace mlir { +class Builder; +} + +namespace tflite { +// Convert the scalar type of a TFlite tensor to the corresponding MLIR type. 
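// For example (from the mapping above): TensorType_FLOAT32 -> f32,
// TensorType_INT64 -> i64, TensorType_BOOL -> i1, TensorType_UINT8 -> the
// TensorFlow dialect's uint8 type, and TensorType_STRING -> the TensorFlow
// dialect's string type.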
+mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder); + +// Convert the scalar type of a TFLite tensor to the corresponding +// Tensorflow type +tensorflow::DataType TflTypeToTfType(tflite::TensorType type); + +} // namespace tflite +#endif // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONVERT_TYPE_H_ diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h index 8cd375a61f7..c68cd0e8605 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.h +++ b/tensorflow/compiler/mlir/lite/utils/validators.h @@ -19,8 +19,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/op_name_mapper.cc b/tensorflow/compiler/mlir/op_name_mapper.cc new file mode 100644 index 00000000000..cd0bc0d3e02 --- /dev/null +++ b/tensorflow/compiler/mlir/op_name_mapper.cc @@ -0,0 +1,86 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/op_name_mapper.h" + +#include "llvm/ADT/APInt.h" + +namespace tensorflow { + +using llvm::StringRef; +using mlir::Operation; + +OpNameMapper::~OpNameMapper() {} + +std::string OpNameMapper::GetUniqueName(llvm::StringRef prefix) { + std::string name = prefix; + auto& val = name_to_count_[name]; + if (!val) { + ++val; + return name; + } + + llvm::SmallString<64> probe_name(prefix); + while (true) { + probe_name.resize(prefix.size()); + // TODO(jpienaar): Subtract one so that the initial suffix is 0 instead + // of 1. + // TODO(jpienaar): Switch to radix 36 and update tests. + llvm::APInt(32, val++).toString(probe_name, /*Radix=*/10, + /*Signed=*/false); + if (!name_to_count_.count(probe_name)) { + name = llvm::StringRef(probe_name); + break; + } + } + return name; +} + +const std::string& OpNameMapper::GetUniqueName(Operation* op) { + auto& name = op_to_name_[op]; + if (!name.empty()) return name; + // Update the value in the map with unique name. + name = GetUniqueName(GetName(op)); + return name; +} + +int OpNameMapper::InitOpName(mlir::Operation* op, llvm::StringRef name) { + op_to_name_[op] = name; + return name_to_count_[name]++; +} + +std::string OpLocNameMapper::GetName(Operation* op) { + if (auto name_loc = op->getLoc().dyn_cast()) + return name_loc.getName().str(); + + if (auto call_loc = op->getLoc().dyn_cast()) { + // Return name if CallSiteLoc's callee has a NameLoc (as should be the case + // if imported with DebugInfo), else use the fallback naming scheme below. 
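// Illustrative: an op whose callee location wraps a NameLoc such as
// "model/dense/BiasAdd" (name assumed) maps back to that string; any other
// location falls through to the fallback below, which uses the op type string
// (e.g. "tf.BiasAdd").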
+ if (auto name_loc = call_loc.getCallee().dyn_cast()) + return name_loc.getName().str(); + } + + // If the location is none of the expected types, then simply use name + // generated using the op type. + return op->getName().getStringRef(); +} + +std::string OpStripNameMapper::GetName(Operation* op) { + return llvm::APInt(32, count_++) + .toString(/*Radix=*/36, + /*Signed=*/false); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/op_name_mapper.h b/tensorflow/compiler/mlir/op_name_mapper.h new file mode 100644 index 00000000000..2232ce2a80f --- /dev/null +++ b/tensorflow/compiler/mlir/op_name_mapper.h @@ -0,0 +1,73 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_OP_NAME_MAPPER_H_ +#define TENSORFLOW_COMPILER_MLIR_OP_NAME_MAPPER_H_ + +#include "absl/container/flat_hash_map.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Operation.h" // TF:local_config_mlir + +namespace tensorflow { + +// Mapper from operation to name. +class OpNameMapper { + public: + // Returns unique name for the operation. + const std::string& GetUniqueName(mlir::Operation* op); + + // Returns unique name for the given prefix. + std::string GetUniqueName(llvm::StringRef prefix); + + // Initializes operation to map to name. Returns number of operations already + // named 'name' which should be 0 else GetUniqueName could return the same + // names for different ops. + // Note: its up to the caller to decide the behavior when assigning two ops + // to the same name. + int InitOpName(mlir::Operation* op, llvm::StringRef name); + + virtual ~OpNameMapper(); + + private: + // Returns name from the location of the operation. + virtual std::string GetName(mlir::Operation* op) = 0; + + // Maps from op to name. + llvm::StringMap name_to_count_; + absl::flat_hash_map op_to_name_; +}; + +// OpNameMapper that returns, for ops not initialized to a specific name, a name +// based on the location of the operation. +class OpLocNameMapper : public OpNameMapper { + private: + std::string GetName(mlir::Operation* op) override; +}; + +// OpNameMapper that returns, for ops not initialized to a specific name, a +// short name. +class OpStripNameMapper : public OpNameMapper { + private: + std::string GetName(mlir::Operation* op) override; + + // Number of ops mapped. 
+ int count_ = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_OP_NAME_MAPPER_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index abe8df63b20..f696eab4d44 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1,5 +1,5 @@ load("@local_config_mlir//:tblgen.bzl", "gentbl") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_native_cc_binary") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary") package( default_visibility = [":friends"], @@ -10,8 +10,6 @@ package_group( name = "friends", includes = ["@local_config_mlir//:subpackages"], packages = [ - "//learning/brain/experimental/mlir/...", - "//learning/brain/google/xla/...", "//tensorflow/compiler/mlir/...", "//tensorflow/python/...", ], @@ -70,7 +68,31 @@ gentbl( td_file = "ir/tf_executor_ops.td", td_srcs = [ "@local_config_mlir//:include/mlir/IR/OpBase.td", - "@local_config_mlir//:include/mlir/StandardOps/Ops.td", + "@local_config_mlir//:include/mlir/Dialect/StandardOps/Ops.td", + ], +) + +gentbl( + name = "tensorflow_device_ops_inc_gen", + tbl_outs = [ + ( + "-gen-op-decls", + "ir/tf_device.h.inc", + ), + ( + "-gen-op-defs", + "ir/tf_device.cc.inc", + ), + ( + "-gen-op-doc", + "g3doc/tf_device.md", + ), + ], + tblgen = "@local_config_mlir//:mlir-tblgen", + td_file = "ir/tf_device_ops.td", + td_srcs = [ + "@local_config_mlir//:include/mlir/IR/OpBase.td", + "@local_config_mlir//:include/mlir/Dialect/StandardOps/Ops.td", ], ) @@ -93,30 +115,41 @@ cc_library( name = "tensorflow", srcs = [ "ir/control_flow_ops.cc", + "ir/tf_device.cc", "ir/tf_executor.cc", "ir/tf_executor.cc.inc", "ir/tf_executor.h.inc", "ir/tf_ops.cc", "ir/tf_ops.cc.inc", "ir/tf_ops.h.inc", + "transforms/cluster_formation.cc", + "transforms/cluster_outlining.cc", + "transforms/executor_island_coarsening.cc", "transforms/functional_control_flow_to_cfg.cc", "transforms/generated_canonicalize.inc", "transforms/generated_optimize.inc", + "transforms/graph_pruning.cc", "transforms/optimize.cc", "transforms/raise_control_flow.cc", + "transforms/tpu_rewrite_pass.cc", "translate/control_to_executor_dialect.cc", + "translate/executor_to_control_dialect.cc", ], hdrs = [ "ir/control_flow_ops.h", + "ir/tf_device.h", "ir/tf_executor.h", "ir/tf_ops.h", + "ir/tf_traits.h", "ir/tf_types.def", "ir/tf_types.h", "transforms/passes.h", ], + copts = ["-std=c++14"], includes = ["include"], deps = [ ":tensorflow_canonicalize_inc_gen", + ":tensorflow_device_ops_inc_gen", ":tensorflow_executor_inc_gen", ":tensorflow_ops_inc_gen", ":tensorflow_optimize_inc_gen", @@ -131,7 +164,6 @@ cc_library( "@local_config_mlir//:StandardOps", "@local_config_mlir//:Support", "@local_config_mlir//:TransformUtils", - "@local_config_mlir//:TypeUtilities", ], # TODO(jpienaar): Merge in the dialect registration. 
alwayslink = 1, @@ -141,6 +173,7 @@ cc_library( cc_library( name = "tensorflow_dialect_registration", srcs = ["ir/dialect_registration.cc"], + copts = ["-std=c++14"], deps = [ ":tensorflow", "@local_config_mlir//:IR", @@ -152,12 +185,13 @@ cc_library( name = "convert_graphdef", srcs = [ "translate/export_graphdef.cc", - "translate/import_graphdef.cc", + "translate/import_model.cc", ], hdrs = [ "translate/export_graphdef.h", - "translate/import_graphdef.h", + "translate/import_model.h", ], + copts = ["-std=c++14"], deps = [ ":convert_tensor", ":convert_type", @@ -166,6 +200,7 @@ cc_library( ":mangling_util", ":mlir_roundtrip_flags", ":tensorflow", + "//tensorflow/cc/saved_model:loader_lite", "//tensorflow/compiler/jit:shape_inference_helpers", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", @@ -173,6 +208,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_proto_cc", + "//tensorflow/core/platform:types", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -182,6 +218,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@llvm//:support", "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", "@local_config_mlir//:StandardDialectRegistration", "@local_config_mlir//:StandardOps", "@local_config_mlir//:Support", @@ -196,6 +233,7 @@ cc_library( hdrs = [ "utils/import_utils.h", ], + copts = ["-std=c++14"], deps = [ ":error_util", "//tensorflow/core:lib", @@ -213,6 +251,7 @@ cc_library( hdrs = [ "utils/export_utils.h", ], + copts = ["-std=c++14"], deps = [ ":convert_tensor", ":convert_type", @@ -244,6 +283,7 @@ cc_library( hdrs = [ "translate/export_tf_dialect_op.h", ], + copts = ["-std=c++14"], deps = [ ":convert_type", ":export_utils", @@ -252,6 +292,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_proto_cc", "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", "@llvm//:support", "@local_config_mlir//:IR", ], @@ -260,6 +302,7 @@ cc_library( cc_library( name = "translate_tf_dialect_op", srcs = ["translate/translate_tf_dialect_op.cc"], + copts = ["-std=c++14"], deps = [ ":export_tf_dialect_op", "@llvm//:support", @@ -274,6 +317,7 @@ cc_library( name = "mlir_roundtrip_pass", srcs = ["translate/mlir_roundtrip_pass.cc"], hdrs = ["translate/mlir_roundtrip_pass.h"], + copts = ["-std=c++14"], deps = [ ":convert_graphdef", ":mlir_roundtrip_flags", @@ -281,15 +325,18 @@ cc_library( "//tensorflow/core:core_cpu_lib", "//tensorflow/core:lib", "//tensorflow/core:protos_all_proto_cc", + "@local_config_mlir//:Analysis", "@local_config_mlir//:IR", "@local_config_mlir//:StandardOps", ], + alwayslink = 1, ) cc_library( name = "mlir_roundtrip_flags", srcs = ["translate/mlir_roundtrip_flags.cc"], hdrs = ["translate/mlir_roundtrip_flags.h"], + copts = ["-std=c++14"], deps = [ "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:framework", @@ -307,6 +354,7 @@ cc_library( name = "convert_type", srcs = ["utils/convert_type.cc"], hdrs = ["utils/convert_type.h"], + copts = ["-std=c++14"], deps = [ ":tensorflow", ":tensorflow_dialect_registration", @@ -325,6 +373,7 @@ cc_library( name = "convert_tensor", srcs = ["utils/convert_tensor.cc"], hdrs = ["utils/convert_tensor.h"], + copts = ["-std=c++14"], deps = [ ":convert_type", ":mangling_util", @@ -344,6 +393,7 @@ cc_library( name = "mangling_util", srcs = ["utils/mangling_util.cc"], hdrs = 
["utils/mangling_util.h"], + copts = ["-std=c++14"], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -356,6 +406,7 @@ cc_library( name = "error_util", srcs = ["utils/error_util.cc"], hdrs = ["utils/error_util.h"], + copts = ["-std=c++14"], deps = [ "//tensorflow/core:lib", "//tensorflow/stream_executor/lib", @@ -375,11 +426,11 @@ cc_library( "transforms/constant_fold.h", "transforms/decode_constant.h", ], + copts = ["-std=c++14"], deps = [ ":convert_tensor", ":eval_util", ":tensorflow", - ":tf_graph_optimization_pass", "//tensorflow/c:tf_status", "//tensorflow/c/eager:c_api", "//tensorflow/core:framework", @@ -396,6 +447,7 @@ cc_library( cc_library( name = "tf_dialect_lib", + copts = ["-std=c++14"], deps = [ ":tensorflow_dialect_registration", ":tf_dialect_passes", @@ -406,9 +458,11 @@ cc_library( cc_library( name = "tf_graph_optimization_pass", srcs = ["transforms/tf_graph_optimization_pass.cc"], + copts = ["-std=c++14"], deps = [ ":convert_graphdef", ":mlir_roundtrip_flags", + ":mlir_roundtrip_pass", "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -428,6 +482,7 @@ cc_library( name = "eval_util", srcs = ["utils/eval_util.cc"], hdrs = ["utils/eval_util.h"], + copts = ["-std=c++14"], deps = [ ":convert_tensor", ":convert_type", @@ -460,6 +515,7 @@ cc_library( hdrs = [ "translate/tf_mlir_translate.h", ], + copts = ["-std=c++14"], deps = [ ":convert_graphdef", ":error_util", @@ -486,6 +542,7 @@ cc_library( hdrs = [ "translate/tf_mlir_translate_cl.h", ], + copts = ["-std=c++14"], deps = [ "@llvm//:support", ], @@ -497,6 +554,7 @@ cc_library( srcs = [ "translate/tf_mlir_translate_registration.cc", ], + copts = ["-std=c++14"], deps = [ ":convert_graphdef", ":mlir_roundtrip_flags", diff --git a/tensorflow/compiler/mlir/tensorflow/g3doc/tf_ops.md b/tensorflow/compiler/mlir/tensorflow/g3doc/tf_ops.md deleted file mode 100755 index cedeba5dae1..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/g3doc/tf_ops.md +++ /dev/null @@ -1,2761 +0,0 @@ - -# Operation definition -## tf.Abs (TF::AbsOp) -Computes the absolute value of a tensor. - -### Description: - -Given a tensor `x`, this operation returns a tensor containing the absolute -value of each element in `x`. For example, if x is an input element and y is -an output element, this operation computes \\(y = |x|\\). - -### Operands: -1. `x`: tensor of floating-point or 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of floating-point or 32/64-bit integer values - -## tf.AddN (TF::AddNOp) -Add all input tensors element wise. - -### Description: - - -### Operands: -1. `inputs`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow variant type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. 
`sum`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow variant type values - -## tf.Add (TF::AddOp) -Returns x + y element-wise. - -### Description: - -*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of number or TensorFlow string type values -1. `y`: tensor of number or TensorFlow string type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of number or TensorFlow string type values - -## tf.AddV2 (TF::AddV2Op) -Returns x + y element-wise. - -### Description: - -*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of number values -1. `y`: tensor of number values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of number values - -## tf.AvgPool (TF::AvgPoolOp) -Performs average pooling on the input. - -### Description: - -Each entry in `output` is the mean of the corresponding size `ksize` -window in `value`. - -### Operands: -1. `value`: tensor of floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `ksize` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute | -| `strides` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute | -| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID attribute | -| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of floating-point values - -## tf.BatchToSpaceND (TF::BatchToSpaceNDOp) -BatchToSpace for N-D tensors of type T. - -### Description: - -This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of shape -`block_shape + [batch]`, interleaves these blocks back into the grid defined by -the spatial dimensions `[1, ..., M]`, to obtain a result with the same rank as -the input. The spatial dimensions of this intermediate result are then -optionally cropped according to `crops` to produce the output. This is the -reverse of SpaceToBatch. See below for a precise description. - -### Operands: -1. `input`: tensor of tf.dtype values -1. `block_shape`: tensor of 32/64-bit integer values -1. `crops`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `Tcrops` | `Attribute` | derived attribute attribute | -| `Tblock_shape` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.BiasAdd (TF::BiasAddOp) -Adds `bias` to `value`. - -### Description: - -This is a special case of `tf.add` where `bias` is restricted to be 1-D. -Broadcasting is supported, so `value` may have any number of dimensions. - -### Operands: -1. `value`: tensor of number values -1. 
`bias`: tensor of number values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of number values - -## tf.Bitcast (TF::BitcastOp) - -Bitcasts a tensor from one type to another without copying data. - - -### Description: - -Given a tensor `input`, this operation returns a tensor that has the same buffer -data as `input` with datatype `type`. - -If the input datatype `T` is larger than the output datatype `type` then the -shape changes from [...] to [..., sizeof(`T`)/sizeof(`type`)]. - -If `T` is smaller than `type`, the operator requires that the rightmost -dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from -[..., sizeof(`type`)/sizeof(`T`)] to [...]. - -tf.bitcast() and tf.cast() work differently when real dtype is casted as a complex dtype -(e.g. tf.complex64 or tf.complex128) as tf.cast() make imaginary part 0 while tf.bitcast() -gives module error. -For example, - -Example 1: -```python ->>> a = [1., 2., 3.] ->>> equality_bitcast = tf.bitcast(a,tf.complex128) -tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot bitcast from float to complex128: shape [3] [Op:Bitcast] ->>> equality_cast = tf.cast(a,tf.complex128) ->>> print(equality_cast) -tf.Tensor([1.+0.j 2.+0.j 3.+0.j], shape=(3,), dtype=complex128) -``` -Example 2: -```python ->>> tf.bitcast(tf.constant(0xffffffff, dtype=tf.uint32), tf.uint8) - -``` -Example 3: -```python ->>> x = [1., 2., 3.] ->>> y = [0., 2., 3.] ->>> equality= tf.equal(x,y) ->>> equality_cast = tf.cast(equality,tf.float32) ->>> equality_bitcast = tf.bitcast(equality_cast,tf.uint8) ->>> print(equality) -tf.Tensor([False True True], shape=(3,), dtype=bool) ->>> print(equality_cast) -tf.Tensor([0. 1. 1.], shape=(3,), dtype=float32) ->>> print(equality_bitcast) -tf.Tensor( -[[ 0 0 0 0] - [ 0 0 128 63] - [ 0 0 128 63]], shape=(3, 4), dtype=uint8) -``` - -*NOTE*: Bitcast is implemented as a low-level cast, so machines with different -endian orderings will give different results. - -### Operands: -1. `input`: tensor of number values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `type` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of number values - -## tf.BroadcastTo (TF::BroadcastToOp) -Broadcast an array for a compatible shape. - -### Description: - -Broadcasting is the process of making arrays to have compatible shapes -for arithmetic operations. Two shapes are compatible if for each -dimension pair they are either equal or one of them is one. When trying -to broadcast a Tensor to a shape, it starts with the trailing dimensions, -and works its way forward. - -For example, - -```python ->>> x = tf.constant([1, 2, 3]) ->>> y = tf.broadcast_to(x, [3, 3]) ->>> sess.run(y) -array([[1, 2, 3], - [1, 2, 3], - [1, 2, 3]], dtype=int32) -``` - -In the above example, the input Tensor with the shape of `[1, 3]` -is broadcasted to output Tensor with shape of `[3, 3]`. - -### Operands: -1. `input`: tensor of tf.dtype values -1. 
`shape`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `Tidx` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Cast (TF::CastOp) -Cast x of type SrcT to y of DstT. - -### Description: - - -### Operands: -1. `x`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `Truncate` | `BoolAttr` | bool attribute attribute | -| `SrcT` | `Attribute` | derived attribute attribute | -| `DstT` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of tf.dtype values - -## tf.Ceil (TF::CeilOp) -Returns element-wise smallest integer not less than x. - -### Description: - - -### Operands: -1. `x`: tensor of floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of floating-point values - -## tf.Concat (TF::ConcatOp) -Concatenates tensors along one dimension. - -### Description: - - -### Operands: -1. `concat_dim`: tensor of 32-bit integer values -1. `values`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 2 attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.ConcatV2 (TF::ConcatV2Op) -Concatenates tensors along one dimension. - -### Description: - - -### Operands: -1. `values`: tensor of tf.dtype values -1. `axis`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 2 attribute | -| `T` | `Attribute` | derived attribute attribute | -| `Tidx` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Conj (TF::ConjOp) -Returns the complex conjugate of a complex number. - -### Description: - -Given a tensor `input` of complex numbers, this operation returns a tensor of -complex numbers that are the complex conjugate of each element in `input`. The -complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the -real part and *b* is the imaginary part. - -The complex conjugate returned by this operation is of the form \\(a - bj\\). - -For example: - -``` -# tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] -tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] -``` - -### Operands: -1. `input`: tensor of complex128 type or complex64 type or TensorFlow variant type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of complex128 type or complex64 type or TensorFlow variant type values - -## tf.Const (TF::ConstOp) -Constant tensor op - -### Description: - - -### Operands: - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `value` | `ElementsAttr` | constant vector/tensor attribute attribute | -| `dtype` | `Attribute` | derived attribute attribute | - -### Results: -1. 
`output`: tensor of tf.dtype values - -## tf.Conv2D (TF::Conv2DOp) - -Computes a 2-D convolution given 4-D `input` and `filter` tensors. - - -### Description: - -Given an input tensor of shape `[batch, in_height, in_width, in_channels]` -and a filter / kernel tensor of shape -`[filter_height, filter_width, in_channels, out_channels]`, this op -performs the following: - -1. Flattens the filter to a 2-D matrix with shape - `[filter_height * filter_width * in_channels, output_channels]`. -2. Extracts image patches from the input tensor to form a *virtual* - tensor of shape `[batch, out_height, out_width, - filter_height * filter_width * in_channels]`. -3. For each patch, right-multiplies the filter matrix and the image patch - vector. - -In detail, with the default NHWC format, - - output[b, i, j, k] = - sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] * - filter[di, dj, q, k] - -Must have `strides[0] = strides[3] = 1`. For the most common case of the same -horizontal and vertices strides, `strides = [1, stride, stride, 1]`. - -### Operands: -1. `input`: tensor of floating-point values -1. `filter`: tensor of floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `strides` | `ArrayAttr` | 64-bit integer array attribute attribute | -| `use_cudnn_on_gpu` | `BoolAttr` | bool attribute attribute | -| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID, or EXPLICIT attribute | -| `explicit_paddings` | `ArrayAttr` | 64-bit integer array attribute attribute | -| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute | -| `dilations` | `ArrayAttr` | 64-bit integer array attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of floating-point values - -## tf.Cos (TF::CosOp) -Computes cos of x element-wise. - -### Description: - - -### Operands: -1. `x`: tensor of floating-point or 64/128-bit complex type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of floating-point or 64/128-bit complex type values - -## tf.DepthwiseConv2dNative (TF::DepthwiseConv2dNativeOp) - -Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. - - -### Description: - -Given an input tensor of shape `[batch, in_height, in_width, in_channels]` -and a filter / kernel tensor of shape -`[filter_height, filter_width, in_channels, channel_multiplier]`, containing -`in_channels` convolutional filters of depth 1, `depthwise_conv2d` applies -a different filter to each input channel (expanding from 1 channel to -`channel_multiplier` channels for each), then concatenates the results -together. Thus, the output has `in_channels * channel_multiplier` channels. - -``` -for k in 0..in_channels-1 - for q in 0..channel_multiplier-1 - output[b, i, j, k * channel_multiplier + q] = - sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] * - filter[di, dj, k, q] -``` - -Must have `strides[0] = strides[3] = 1`. For the most common case of the same -horizontal and vertices strides, `strides = [1, stride, stride, 1]`. - -### Operands: -1. `input`: tensor of floating-point values -1. 
`filter`: tensor of floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `strides` | `ArrayAttr` | 64-bit integer array attribute attribute | -| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID attribute | -| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute | -| `dilations` | `ArrayAttr` | 64-bit integer array attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of floating-point values - -## tf.Div (TF::DivOp) -Returns x / y element-wise. - -### Description: - -*NOTE*: `Div` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of number values -1. `y`: tensor of number values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of number values - -## tf.Elu (TF::EluOp) - -Computes exponential linear: `exp(features) - 1` if < 0, `features` otherwise. - - -### Description: - -See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) -](http://arxiv.org/abs/1511.07289) - -### Operands: -1. `features`: tensor of floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `activations`: tensor of floating-point values - -## tf.Equal (TF::EqualOp) -Returns the truth value of (x == y) element-wise. - -### Description: - -*NOTE*: `Equal` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -```python -x = tf.constant([2, 4]) -y = tf.constant(2) -tf.math.equal(x, y) ==> array([True, False]) - -x = tf.constant([2, 4]) -y = tf.constant([2, 4]) -tf.math.equal(x, y) ==> array([True, True]) -``` - -### Operands: -1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values -1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of 1-bit integer values - -## tf.ExpandDims (TF::ExpandDimsOp) -Inserts a dimension of 1 into a tensor's shape. - -### Description: - -Given a tensor `input`, this operation inserts a dimension of 1 at the -dimension index `axis` of `input`'s shape. The dimension index `axis` starts at -zero; if you specify a negative number for `axis` it is counted backward from -the end. - -This operation is useful if you want to add a batch dimension to a single -element. For example, if you have a single image of shape `[height, width, -channels]`, you can make it a batch of 1 image with `expand_dims(image, 0)`, -which will make the shape `[1, height, width, channels]`. 
- -Other examples: - -``` -# 't' is a tensor of shape [2] -shape(expand_dims(t, 0)) ==> [1, 2] -shape(expand_dims(t, 1)) ==> [2, 1] -shape(expand_dims(t, -1)) ==> [2, 1] - -# 't2' is a tensor of shape [2, 3, 5] -shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5] -shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5] -shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1] -``` - -This operation requires that: - -`-1-input.dims() <= dim <= input.dims()` - -This operation is related to `squeeze()`, which removes dimensions of -size 1. - -### Operands: -1. `input`: tensor of tf.dtype values -1. `dim`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `Tdim` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.FakeQuantWithMinMaxArgs (TF::FakeQuantWithMinMaxArgsOp) - -Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. - - -### Description: - -Attributes `[min; max]` define the clamping range for the `inputs` data. -`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` -when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and -then de-quantized and output as floats in `[min; max]` interval. -`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. - -Before quantization, `min` and `max` values are adjusted with the following -logic. -It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values, -the behavior can be unexpected: -If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`. -If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`. -If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, -`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. - -Quantization is called fake since the output is still in floating point. - -### Operands: -1. `inputs`: tensor of 32-bit float values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `min` | `FloatAttr` | 32-bit float attribute attribute | -| `max` | `FloatAttr` | 32-bit float attribute attribute | -| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute | -| `narrow_range` | `BoolAttr` | bool attribute attribute | - -### Results: -1. `outputs`: tensor of 32-bit float values - -## tf.FakeQuantWithMinMaxVars (TF::FakeQuantWithMinMaxVarsOp) - -Fake-quantize the 'inputs' tensor of type float via global float scalars `min` - - -### Description: - -and `max` to 'outputs' tensor of same shape as `inputs`. - -`[min; max]` define the clamping range for the `inputs` data. -`inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` -when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and -then de-quantized and output as floats in `[min; max]` interval. -`num_bits` is the bitwidth of the quantization; between 2 and 16, inclusive. - -Before quantization, `min` and `max` values are adjusted with the following -logic. -It is suggested to have `min <= 0 <= max`. If `0` is not in the range of values, -the behavior can be unexpected: -If `0 < min < max`: `min_adj = 0` and `max_adj = max - min`. -If `min < max < 0`: `min_adj = min - max` and `max_adj = 0`. -If `min <= 0 <= max`: `scale = (max - min) / (2^num_bits - 1) `, -`min_adj = scale * round(min / scale)` and `max_adj = max + min_adj - min`. 
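For the common `min <= 0 <= max` case, the adjustment above can be written out as a short Python sketch (the sample values and `num_bits = 8` are illustrative, not taken from the op definition):

```python
# Sketch of the documented adjustment for the `min <= 0 <= max` case.
def adjust_range(range_min, range_max, num_bits=8):
    scale = (range_max - range_min) / (2 ** num_bits - 1)
    min_adj = scale * round(range_min / scale)   # snap min onto the quantization grid
    max_adj = range_max + min_adj - range_min    # shift max by the same amount
    return min_adj, max_adj

print(adjust_range(-0.9, 1.0))  # approx (-0.9016, 0.9984); the width max - min is preserved
```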
- -This operation has a gradient and thus allows for training `min` and `max` -values. - -### Operands: -1. `inputs`: tensor of 32-bit float values -1. `min`: tensor of 32-bit float values -1. `max`: tensor of 32-bit float values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute | -| `narrow_range` | `BoolAttr` | bool attribute attribute | - -### Results: -1. `outputs`: tensor of 32-bit float values - -## tf.Fill (TF::FillOp) -Creates a tensor filled with a scalar value. - -### Description: - -This operation creates a tensor of shape `dims` and fills it with `value`. - -For example: - -``` -# Output tensor has shape [2, 3]. -fill([2, 3], 9) ==> [[9, 9, 9] - [9, 9, 9]] -``` - -`tf.fill` differs from `tf.constant` in a few ways: - -* `tf.fill` only supports scalar contents, whereas `tf.constant` supports - Tensor values. -* `tf.fill` creates an Op in the computation graph that constructs the actual - Tensor value at runtime. This is in contrast to `tf.constant` which embeds - the entire Tensor into the graph with a `Const` node. -* Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes - based on other runtime Tensors, unlike `tf.constant`. - -### Operands: -1. `dims`: tensor of 32/64-bit integer values -1. `value`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `index_type` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.FloorDiv (TF::FloorDivOp) -Returns x // y element-wise. - -### Description: - -*NOTE*: `FloorDiv` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of number values -1. `y`: tensor of number values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of number values - -## tf.Floor (TF::FloorOp) -Returns element-wise largest integer not greater than x. - -### Description: - - -### Operands: -1. `x`: tensor of floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of floating-point values - -## tf.FusedBatchNorm (TF::FusedBatchNormOp) -Batch normalization. - -### Description: - -Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -The size of 1D Tensors matches the dimension C of the 4D Tensors. - -### Operands: -1. `x`: tensor of 32-bit float values -1. `scale`: tensor of 32-bit float values -1. `offset`: tensor of 32-bit float values -1. `mean`: tensor of 32-bit float values -1. `variance`: tensor of 32-bit float values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `epsilon` | `FloatAttr` | 32-bit float attribute attribute | -| `data_format` | `StringAttr` | 'NHWC' or 'NCHW' convnet data format attribute | -| `is_training` | `BoolAttr` | bool attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of 32-bit float values -1. `batch_mean`: tensor of 32-bit float values -1. `batch_variance`: tensor of 32-bit float values -1. 
`reserve_space_1`: tensor of 32-bit float values -1. `reserve_space_2`: tensor of 32-bit float values - -## tf.Gather (TF::GatherOp) -Gather slices from `params` according to `indices`. - -### Description: - -`indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -Produces an output tensor with shape `indices.shape + params.shape[1:]` where: - -```python - # Scalar indices - output[:, ..., :] = params[indices, :, ... :] - - # Vector indices - output[i, :, ..., :] = params[indices[i], :, ... :] - - # Higher rank indices - output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :] -``` - -If `indices` is a permutation and `len(indices) == params.shape[0]` then -this operation will permute `params` accordingly. - -`validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in -`indices` are always validated to be within range. If assigned to GPU, -out-of-bound indices result in safe but unspecified behavior, which may include -raising an error. - -
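The indexing rule above can be mirrored with plain Python lists (values are illustrative; `tf.gather` applies the same rule along the first dimension of a tensor):

```python
# output[i, ...] = params[indices[i], ...]
params = [[1, 2], [3, 4], [5, 6]]    # "shape" [3, 2]
indices = [2, 0, 2]                  # each entry indexes params along dimension 0

output = [params[i] for i in indices]
print(output)  # [[5, 6], [1, 2], [5, 6]] -> indices.shape + params.shape[1:] == [3, 2]
```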
- -### Operands: -1. `params`: tensor of tf.dtype values -1. `indices`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `validate_indices` | `BoolAttr` | bool attribute attribute | -| `Tindices` | `Attribute` | derived attribute attribute | -| `Tparams` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.GatherV2 (TF::GatherV2Op) - -Gather slices from `params` axis `axis` according to `indices`. - - -### Description: - -`indices` must be an integer tensor of any dimension (usually 0-D or 1-D). -Produces an output tensor with shape `params.shape[:axis] + indices.shape + -params.shape[axis + 1:]` where: - -```python - # Scalar indices (output is rank(params) - 1). - output[a_0, ..., a_n, b_0, ..., b_n] = - params[a_0, ..., a_n, indices, b_0, ..., b_n] - - # Vector indices (output is rank(params)). - output[a_0, ..., a_n, i, b_0, ..., b_n] = - params[a_0, ..., a_n, indices[i], b_0, ..., b_n] - - # Higher rank indices (output is rank(params) + rank(indices) - 1). - output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] = - params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n] -``` - -
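A minimal sketch of the `axis` behaviour via the Python-level `tf.gather` wrapper (illustrative values; assumes a TensorFlow build in which `tf.gather` exposes the `axis` argument):

```python
import tensorflow as tf

params = tf.constant([[1, 2, 3],
                      [4, 5, 6]])
# axis=1 gathers columns: output shape is params.shape[:1] + indices.shape.
cols = tf.gather(params, indices=[2, 0], axis=1)
# cols == [[3, 1], [6, 4]]
```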
- -Note that on CPU, if an out of bound index is found, an error is returned. -On GPU, if an out of bound index is found, a 0 is stored in the -corresponding output value. - -See also `tf.batch_gather` and `tf.gather_nd`. - -### Operands: -1. `params`: tensor of tf.dtype values -1. `indices`: tensor of 32/64-bit integer values -1. `axis`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `batch_dims` | `IntegerAttr` | 64-bit integer attribute attribute | -| `Tindices` | `Attribute` | derived attribute attribute | -| `Tparams` | `Attribute` | derived attribute attribute | -| `Taxis` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.GreaterEqual (TF::GreaterEqualOp) -Returns the truth value of (x >= y) element-wise. - -### Description: - -*NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of 8/16/32/64-bit integer or floating-point values -1. `y`: tensor of 8/16/32/64-bit integer or floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of 1-bit integer values - -## tf.Greater (TF::GreaterOp) -Returns the truth value of (x > y) element-wise. - -### Description: - -*NOTE*: `Greater` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of 8/16/32/64-bit integer or floating-point values -1. `y`: tensor of 8/16/32/64-bit integer or floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of 1-bit integer values - -## tf.IdentityN (TF::IdentityNOp) - -Returns a list of tensors with the same shapes and contents as the input - - -### Description: - -tensors. - -This op can be used to override the gradient for complicated functions. For -example, suppose y = f(x) and we wish to apply a custom function g for backprop -such that dx = g(dy). In Python, - -```python -with tf.get_default_graph().gradient_override_map( - {'IdentityN': 'OverrideGradientWithG'}): - y, _ = identity_n([f(x), x]) - -@tf.RegisterGradient('OverrideGradientWithG') -def ApplyG(op, dy, _): - return [None, g(dy)] # Do not backprop to f(x). -``` - -### Operands: -1. `input`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Identity (TF::IdentityOp) -Identity op - -### Description: - -Returns a tensor with the same shape and contents as input. - -### Operands: -1. `input`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Invert (TF::InvertOp) - -Invert (flip) each bit of supported types; for example, type `uint8` value 01010101 becomes 10101010. - - -### Description: - -Flip each bit of supported types. 
For example, type `int8` (decimal 2) binary 00000010 becomes (decimal -3) binary 11111101. -This operation is performed on each element of the tensor argument `x`. - -### Operands: -1. `x`: tensor of 8/16/32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of 8/16/32/64-bit integer values - -## tf.LeakyRelu (TF::LeakyReluOp) -Computes rectified linear: `max(features, features * alpha)`. - -### Description: - - -### Operands: -1. `features`: tensor of floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `alpha` | `FloatAttr` | 32-bit float attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `activations`: tensor of floating-point values - -## tf.LessEqual (TF::LessEqualOp) -Returns the truth value of (x <= y) element-wise. - -### Description: - -*NOTE*: `LessEqual` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of 8/16/32/64-bit integer or floating-point values -1. `y`: tensor of 8/16/32/64-bit integer or floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of 1-bit integer values - -## tf.Less (TF::LessOp) -Returns the truth value of (x < y) element-wise. - -### Description: - -*NOTE*: `Less` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of 8/16/32/64-bit integer or floating-point values -1. `y`: tensor of 8/16/32/64-bit integer or floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of 1-bit integer values - -## tf.Log (TF::LogOp) -Computes natural logarithm of x element-wise. - -### Description: - -I.e., \\(y = \log_e x\\). - -### Operands: -1. `x`: tensor of floating-point or 64/128-bit complex type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of floating-point or 64/128-bit complex type values - -## tf.LogSoftmax (TF::LogSoftmaxOp) -Computes log softmax activations. - -### Description: - -For each batch `i` and class `j` we have - - logsoftmax[i, j] = logits[i, j] - log(sum(exp(logits[i]))) - -### Operands: -1. `logits`: tensor of floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `logsoftmax`: tensor of floating-point values - -## tf.LogicalAnd (TF::LogicalAndOp) -Returns the truth value of x AND y element-wise. - -### Description: - -*NOTE*: `LogicalAnd` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of 1-bit integer values -1. `y`: tensor of 1-bit integer values - -### Attributes: - -### Results: -1. 
`z`: tensor of 1-bit integer values - -## tf.LogicalNot (TF::LogicalNotOp) -Returns the truth value of NOT x element-wise. - -### Description: - - -### Operands: -1. `x`: tensor of 1-bit integer values - -### Attributes: - -### Results: -1. `y`: tensor of 1-bit integer values - -## tf.LogicalOr (TF::LogicalOrOp) -Returns the truth value of x OR y element-wise. - -### Description: - -*NOTE*: `LogicalOr` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of 1-bit integer values -1. `y`: tensor of 1-bit integer values - -### Attributes: - -### Results: -1. `z`: tensor of 1-bit integer values - -## tf.MatMul (TF::MatMulOp) - -Multiply the matrix "a" by the matrix "b". - - -### Description: - -The inputs must be two-dimensional matrices and the inner dimension of -"a" (after being transposed if transpose_a is true) must match the -outer dimension of "b" (after being transposed if transposed_b is -true). - -*Note*: The default kernel implementation for MatMul on GPUs uses -cublas. - -### Operands: -1. `a`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values -1. `b`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `transpose_a` | `BoolAttr` | bool attribute attribute | -| `transpose_b` | `BoolAttr` | bool attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `product`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values - -## tf.Max (TF::MaxOp) - -Computes the maximum of elements across dimensions of a tensor. - - -### Description: - -Reduces `input` along the dimensions given in `axis`. Unless -`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -`axis`. If `keep_dims` is true, the reduced dimensions are -retained with length 1. - -### Operands: -1. `input`: tensor of number values -1. `reduction_indices`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `keep_dims` | `BoolAttr` | bool attribute attribute | -| `T` | `Attribute` | derived attribute attribute | -| `Tidx` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of number values - -## tf.MaxPool (TF::MaxPoolOp) -Performs max pooling on the input. - -### Description: - - -### Operands: -1. `input`: tensor of 8/16/32/64-bit integer or floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `ksize` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute | -| `strides` | `ArrayAttr` | 64-bit integer array attribute with at least 4 elements attribute | -| `padding` | `StringAttr` | string attribute whose value is SAME, or VALID attribute | -| `data_format` | `StringAttr` | string attribute whose value is NHWC, or NCHW, or NCHW_VECT_C attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of 8/16/32/64-bit integer or floating-point values - -## tf.Maximum (TF::MaximumOp) -Returns the max of x and y (i.e. x > y ? 
x : y) element-wise. - -### Description: - -*NOTE*: `Maximum` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of floating-point or 32/64-bit integer values -1. `y`: tensor of floating-point or 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of floating-point or 32/64-bit integer values - -## tf.Mean (TF::MeanOp) -Computes the mean of elements across dimensions of a tensor. - -### Description: - -Reduces `input` along the dimensions given in `axis`. Unless -`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -`axis`. If `keep_dims` is true, the reduced dimensions are -retained with length 1. - -### Operands: -1. `input`: tensor of number values -1. `reduction_indices`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `keep_dims` | `BoolAttr` | bool attribute attribute | -| `T` | `Attribute` | derived attribute attribute | -| `Tidx` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of number values - -## tf.Min (TF::MinOp) - -Computes the minimum of elements across dimensions of a tensor. - - -### Description: - -Reduces `input` along the dimensions given in `axis`. Unless -`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -`axis`. If `keep_dims` is true, the reduced dimensions are -retained with length 1. - -### Operands: -1. `input`: tensor of number values -1. `reduction_indices`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `keep_dims` | `BoolAttr` | bool attribute attribute | -| `T` | `Attribute` | derived attribute attribute | -| `Tidx` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of number values - -## tf.Minimum (TF::MinimumOp) -Returns the min of x and y (i.e. x < y ? x : y) element-wise. - -### Description: - -*NOTE*: `Minimum` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of floating-point or 32/64-bit integer values -1. `y`: tensor of floating-point or 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of floating-point or 32/64-bit integer values - -## tf.MulNoNan (TF::MulNoNanOp) - -Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or NaN. - - -### Description: - -*NOTE*: `MulNoNan` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values -1. `y`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. 
`z`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values - -## tf.Mul (TF::MulOp) -Returns x * y element-wise. - -### Description: - -*NOTE*: `Multiply` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of number values -1. `y`: tensor of number values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of number values - -## tf.Neg (TF::NegOp) -Computes numerical negative value element-wise. - -### Description: - -I.e., \\(y = -x\\). - -### Operands: -1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values - -## tf.NoOp (TF::NoOp) -Does nothing. Only useful as a placeholder for control edges. - -### Description: - - -### Operands: - -### Attributes: - -### Results: - -## tf.NotEqual (TF::NotEqualOp) -Returns the truth value of (x != y) element-wise. - -### Description: - -*NOTE*: `NotEqual` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values -1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of 1-bit integer values - -## tf.Pack (TF::PackOp) - -Packs a list of `N` rank-`R` tensors into one rank-`(R+1)` tensor. - - -### Description: - -Packs the `N` tensors in `values` into a tensor with rank one higher than each -tensor in `values`, by packing them along the `axis` dimension. -Given a list of tensors of shape `(A, B, C)`; - -if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`. -if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`. -Etc. - -For example: - -``` -# 'x' is [1, 4] -# 'y' is [2, 5] -# 'z' is [3, 6] -pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim. -pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]] -``` - -This is the opposite of `unpack`. - -### Operands: -1. `values`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `N` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute | -| `axis` | `IntegerAttr` | 64-bit integer attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Pad (TF::PadOp) -Pads a tensor with zeros. 
- -### Description: - -This operation pads a `input` with zeros according to the `paddings` you -specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the -rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -how many zeros to add before the contents of `input` in that dimension, and -`paddings[D, 1]` indicates how many zeros to add after the contents of `input` -in that dimension. - -The padded size of each dimension D of the output is: - -`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` - -For example: - -``` -# 't' is [[1, 1], [2, 2]] -# 'paddings' is [[1, 1], [2, 2]] -# rank of 't' is 2 -pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] - [0, 0, 1, 1, 0, 0] - [0, 0, 2, 2, 0, 0] - [0, 0, 0, 0, 0, 0]] -``` - -### Operands: -1. `input`: tensor of tf.dtype values -1. `paddings`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `Tpaddings` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.PadV2 (TF::PadV2Op) -Pads a tensor. - -### Description: - -This operation pads `input` according to the `paddings` and `constant_values` -you specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is -the rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates -how many padding values to add before the contents of `input` in that dimension, -and `paddings[D, 1]` indicates how many padding values to add after the contents -of `input` in that dimension. `constant_values` is a scalar tensor of the same -type as `input` that indicates the value to use for padding `input`. - -The padded size of each dimension D of the output is: - -`paddings(D, 0) + input.dim_size(D) + paddings(D, 1)` - -For example: - -``` -# 't' is [[1, 1], [2, 2]] -# 'paddings' is [[1, 1], [2, 2]] -# 'constant_values' is 0 -# rank of 't' is 2 -pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] - [0, 0, 1, 1, 0, 0] - [0, 0, 2, 2, 0, 0] - [0, 0, 0, 0, 0, 0]] -``` - -### Operands: -1. `input`: tensor of tf.dtype values -1. `paddings`: tensor of 32/64-bit integer values -1. `constant_values`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `Tpaddings` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Placeholder.input (TF::PlaceholderInputOp) -PlaceholderInput op - -### Description: - -Inserts a placeholder for a tensor that will be always fed. - -### Operands: -1. `arg`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `min` | `FloatAttr` | 32-bit float attribute attribute | -| `max` | `FloatAttr` | 32-bit float attribute attribute | -| `type` | `TypeAttr` | integer type attribute | -| `dtype` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Placeholder (TF::PlaceholderOp) -Placeholder op - -### Description: - -Inserts a placeholder for a tensor that will be always fed. - -### Operands: - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `dtype` | `Attribute` | derived attribute attribute | - -### Results: -1. 
`output`: tensor of tf.dtype values - -## tf.QuantizeAndDequantize (TF::QuantizeAndDequantizeOp) -Use QuantizeAndDequantizeV2 instead. - -### Description: - - -### Operands: -1. `input`: tensor of floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `signed_input` | `BoolAttr` | bool attribute attribute | -| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute | -| `range_given` | `BoolAttr` | bool attribute attribute | -| `input_min` | `FloatAttr` | 32-bit float attribute attribute | -| `input_max` | `FloatAttr` | 32-bit float attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of floating-point values - -## tf.QuantizeAndDequantizeV2 (TF::QuantizeAndDequantizeV2Op) -Quantizes then dequantizes a tensor. - -### Description: - -This op simulates the precision loss from the quantized forward pass by: - -1. Quantizing the tensor to fixed point numbers, which should match the target - quantization method when it is used in inference. -2. Dequantizing it back to floating point numbers for the following ops, most - likely matmul. - -There are different ways to quantize. This version uses only scaling, so 0.0 -maps to 0. - -From the specified 'num_bits' in the quantized output type, it determines -minimum and maximum representable quantized values. - -e.g. - -* [-128, 127] for signed, num_bits = 8, or -* [0, 255] for unsigned, num_bits = 8. - -If range_given == False, the initial input_min, input_max will be determined -automatically as the minimum and maximum values in the input tensor, otherwise -the specified values of input_min, input_max are used. - -Note: If the input_min, input_max are specified, they do not need to equal the -actual minimum and maximum values in the tensor. e.g. in some cases it may be -beneficial to specify these values such that the low probability extremes of the -input distribution are clipped. - -This op determines the maximum scale_factor that would map the initial -[input_min, input_max] range to a range that lies within the representable -quantized range. - -It determines the scale from one of input_min and input_max, then updates the -other one to maximize the respresentable range. - -e.g. - -* if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, - 5.0]: it would use a scale_factor of -128 / -10.0 = 12.8 In this case, it - would update input_max to be 127 / 12.8 = 9.921875 -* if the output is signed, num_bits = 8, [input_min, input_max] = [-10.0, - 10.0]: it would use a scale_factor of 127 / 10.0 = 12.7 In this case, it - would update input_min to be 128.0 / 12.7 = -10.07874 -* if the output is unsigned, input_min is forced to be 0, and only the - specified input_max is used. - -After determining the scale_factor and updating the input range, it applies the -following to each value in the 'input' tensor. - -output = round(clamp(value, input_min, input_max) * scale_factor) / scale_factor. - -The above round function rounds the value based on the given round_mode. - -### Operands: -1. `input`: tensor of floating-point values -1. `input_min`: tensor of floating-point values -1. 
`input_max`: tensor of floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `signed_input` | `BoolAttr` | bool attribute attribute | -| `num_bits` | `IntegerAttr` | 64-bit integer attribute attribute | -| `range_given` | `BoolAttr` | bool attribute attribute | -| `round_mode` | `StringAttr` | string attribute whose value is HALF_TO_EVEN, or HALF_UP attribute | -| `narrow_range` | `BoolAttr` | bool attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of floating-point values - -## tf.QuantizeAndDequantizeV3 (TF::QuantizeAndDequantizeV3Op) -Quantizes then dequantizes a tensor. - -### Description: - -This is almost identical to QuantizeAndDequantizeV2, except that num_bits is a -tensor, so its value can change during training. - -### Operands: -1. `input`: tensor of floating-point values -1. `input_min`: tensor of floating-point values -1. `input_max`: tensor of floating-point values -1. `num_bits`: tensor of 32-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `signed_input` | `BoolAttr` | bool attribute attribute | -| `range_given` | `BoolAttr` | bool attribute attribute | -| `narrow_range` | `BoolAttr` | bool attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of floating-point values - -## tf.RandomUniform (TF::RandomUniformOp) -Outputs random values from a uniform distribution. - -### Description: - -The generated values follow a uniform distribution in the range `[0, 1)`. The -lower bound 0 is included in the range, while the upper bound 1 is excluded. - -### Operands: -1. `shape`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `seed` | `IntegerAttr` | 64-bit integer attribute attribute | -| `seed2` | `IntegerAttr` | 64-bit integer attribute attribute | -| `T` | `Attribute` | derived attribute attribute | -| `dtype` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of floating-point values - -## tf.Range (TF::RangeOp) -Creates a sequence of numbers. - -### Description: - -This operation creates a sequence of numbers that begins at `start` and -extends by increments of `delta` up to but not including `limit`. - -For example: - -``` -# 'start' is 3 -# 'limit' is 18 -# 'delta' is 3 -tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] -``` - -### Operands: -1. `start`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values -1. `limit`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values -1. `delta`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `Tidx` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of bfloat16 type or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer values - -## tf.Rank (TF::RankOp) -Returns the rank of a tensor. - -### Description: - -This operation returns an integer representing the rank of `input`. 
- -For example: - -``` -# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -# shape of tensor 't' is [2, 2, 3] -rank(t) ==> 3 -``` - -**Note**: The rank of a tensor is not the same as the rank of a matrix. The rank -of a tensor is the number of indices required to uniquely select each element -of the tensor. Rank is also known as "order", "degree", or "ndims." - -### Operands: -1. `input`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of 32-bit integer values - -## tf.RealDiv (TF::RealDivOp) -Returns x / y element-wise for real types. - -### Description: - -If `x` and `y` are reals, this will return the floating-point division. - -*NOTE*: `Div` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of number values -1. `y`: tensor of number values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of number values - -## tf.Reciprocal (TF::ReciprocalOp) -Computes the reciprocal of x element-wise. - -### Description: - -I.e., \\(y = 1 / x\\). - -### Operands: -1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values - -## tf.Relu6 (TF::Relu6Op) -Computes rectified linear 6: `min(max(features, 0), 6)`. - -### Description: - - -### Operands: -1. `features`: tensor of 8/16/32/64-bit integer or floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `activations`: tensor of 8/16/32/64-bit integer or floating-point values - -## tf.Relu (TF::ReluOp) -Computes rectified linear: `max(features, 0)`. - -### Description: - - -### Operands: -1. `features`: tensor of 8/16/32/64-bit integer or floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `activations`: tensor of 8/16/32/64-bit integer or floating-point values - -## tf.Reshape (TF::ReshapeOp) -Reshapes a tensor. - -### Description: - -Given `tensor`, this operation returns a tensor that has the same values -as `tensor` with shape `shape`. - -If one component of `shape` is the special value -1, the size of that dimension -is computed so that the total size remains constant. In particular, a `shape` -of `[-1]` flattens into 1-D. At most one component of `shape` can be -1. - -If `shape` is 1-D or higher, then the operation returns a tensor with shape -`shape` filled with the values of `tensor`. In this case, the number of elements -implied by `shape` must be the same as the number of elements in `tensor`. 
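The `-1` inference described above is just a division: the total element count divided by the product of the explicitly given dimensions. A hedged sketch of that bookkeeping in plain Python (illustrative only, not the kernel implementation):

```python
from functools import reduce
from operator import mul

def resolve_reshape(input_shape, target_shape):
    """Resolve a single -1 entry in `target_shape` for a tensor of `input_shape`."""
    assert target_shape.count(-1) == 1, "exactly one component must be -1 here"
    total = reduce(mul, input_shape, 1)
    known = reduce(mul, (d for d in target_shape if d != -1), 1)
    assert total % known == 0, "element counts must be compatible"
    return [total // known if d == -1 else d for d in target_shape]

print(resolve_reshape([3, 2, 3], [2, -1]))  # [2, 9]  (-1 inferred to be 9)
print(resolve_reshape([3, 2, 3], [-1, 9]))  # [2, 9]  (-1 inferred to be 2)
```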
- -For example: - -``` -# tensor 't' is [1, 2, 3, 4, 5, 6, 7, 8, 9] -# tensor 't' has shape [9] -reshape(t, [3, 3]) ==> [[1, 2, 3], - [4, 5, 6], - [7, 8, 9]] - -# tensor 't' is [[[1, 1], [2, 2]], -# [[3, 3], [4, 4]]] -# tensor 't' has shape [2, 2, 2] -reshape(t, [2, 4]) ==> [[1, 1, 2, 2], - [3, 3, 4, 4]] - -# tensor 't' is [[[1, 1, 1], -# [2, 2, 2]], -# [[3, 3, 3], -# [4, 4, 4]], -# [[5, 5, 5], -# [6, 6, 6]]] -# tensor 't' has shape [3, 2, 3] -# pass '[-1]' to flatten 't' -reshape(t, [-1]) ==> [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6] - -# -1 can also be used to infer the shape - -# -1 is inferred to be 9: -reshape(t, [2, -1]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], - [4, 4, 4, 5, 5, 5, 6, 6, 6]] -# -1 is inferred to be 2: -reshape(t, [-1, 9]) ==> [[1, 1, 1, 2, 2, 2, 3, 3, 3], - [4, 4, 4, 5, 5, 5, 6, 6, 6]] -# -1 is inferred to be 3: -reshape(t, [ 2, -1, 3]) ==> [[[1, 1, 1], - [2, 2, 2], - [3, 3, 3]], - [[4, 4, 4], - [5, 5, 5], - [6, 6, 6]]] - -# tensor 't' is [7] -# shape `[]` reshapes to a scalar -reshape(t, []) ==> 7 -``` - -### Operands: -1. `tensor`: tensor of tf.dtype values -1. `shape`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `Tshape` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.ResizeBilinear (TF::ResizeBilinearOp) -Resize `images` to `size` using bilinear interpolation. - -### Description: - -Input images can be of different types but output images are always float. - -### Operands: -1. `images`: tensor of 8/16/32/64-bit integer or floating-point values -1. `size`: tensor of 32-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `align_corners` | `BoolAttr` | bool attribute attribute | -| `half_pixel_centers` | `BoolAttr` | bool attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `resized_images`: tensor of 32-bit float values - -## tf.ReverseV2 (TF::ReverseV2Op) -Reverses specific dimensions of a tensor. - -### Description: - -NOTE `tf.reverse` has now changed behavior in preparation for 1.0. -`tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0. - -Given a `tensor`, and a `int32` tensor `axis` representing the set of -dimensions of `tensor` to reverse. This operation reverses each dimension -`i` for which there exists `j` s.t. `axis[j] == i`. - -`tensor` can have up to 8 dimensions. The number of dimensions specified -in `axis` may be 0 or more entries. If an index is specified more than -once, a InvalidArgument error is raised. - -For example: - -``` -# tensor 't' is [[[[ 0, 1, 2, 3], -# [ 4, 5, 6, 7], -# [ 8, 9, 10, 11]], -# [[12, 13, 14, 15], -# [16, 17, 18, 19], -# [20, 21, 22, 23]]]] -# tensor 't' shape is [1, 2, 3, 4] - -# 'dims' is [3] or 'dims' is [-1] -reverse(t, dims) ==> [[[[ 3, 2, 1, 0], - [ 7, 6, 5, 4], - [ 11, 10, 9, 8]], - [[15, 14, 13, 12], - [19, 18, 17, 16], - [23, 22, 21, 20]]]] - -# 'dims' is '[1]' (or 'dims' is '[-3]') -reverse(t, dims) ==> [[[[12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23] - [[ 0, 1, 2, 3], - [ 4, 5, 6, 7], - [ 8, 9, 10, 11]]]] - -# 'dims' is '[2]' (or 'dims' is '[-2]') -reverse(t, dims) ==> [[[[8, 9, 10, 11], - [4, 5, 6, 7], - [0, 1, 2, 3]] - [[20, 21, 22, 23], - [16, 17, 18, 19], - [12, 13, 14, 15]]]] -``` - -### Operands: -1. 
`tensor`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values -1. `axis`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `Tidx` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 1-bit integer or 16-bit integer or 32-bit integer or 64-bit integer or 8-bit integer or complex128 type or complex64 type or TensorFlow string type values - -## tf.Rsqrt (TF::RsqrtOp) -Computes reciprocal of square root of x element-wise. - -### Description: - -I.e., \\(y = 1 / \sqrt{x}\\). - -### Operands: -1. `x`: tensor of floating-point or 64/128-bit complex type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of floating-point or 64/128-bit complex type values - -## tf.Select (TF::SelectOp) -Selects elements from `x` or `y`, depending on `condition`. - -### Description: - -The `x`, and `y` tensors must all have the same shape, and the -output will also have that shape. - -The `condition` tensor must be a scalar if `x` and `y` are scalars. -If `x` and `y` are vectors or higher rank, then `condition` must be either a -scalar, a vector with size matching the first dimension of `x`, or must have -the same shape as `x`. - -The `condition` tensor acts as a mask that chooses, based on the value at each -element, whether the corresponding element / row in the output should be -taken from `x` (if true) or `y` (if false). - -If `condition` is a vector and `x` and `y` are higher rank matrices, then -it chooses which row (outer dimension) to copy from `x` and `y`. -If `condition` has the same shape as `x` and `y`, then it chooses which -element to copy from `x` and `y`. - -For example: - -```python -# 'condition' tensor is [[True, False] -# [False, True]] -# 't' is [[1, 2], -# [3, 4]] -# 'e' is [[5, 6], -# [7, 8]] -select(condition, t, e) # => [[1, 6], [7, 4]] - - -# 'condition' tensor is [True, False] -# 't' is [[1, 2], -# [3, 4]] -# 'e' is [[5, 6], -# [7, 8]] -select(condition, t, e) ==> [[1, 2], - [7, 8]] - -``` - -### Operands: -1. `condition`: tensor of 1-bit integer values -1. `t`: tensor of tf.dtype values -1. `e`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Shape (TF::ShapeOp) -Returns the shape of a tensor. - -### Description: - -This operation returns a 1-D integer tensor representing the shape of `input`. - -For example: - -``` -# 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] -shape(t) ==> [2, 2, 3] -``` - -### Operands: -1. `input`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `out_type` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of 32/64-bit integer values - -## tf.Sigmoid (TF::SigmoidOp) -Computes sigmoid of `x` element-wise. 
- -### Description: - -Specifically, `y = 1 / (1 + exp(-x))`. - -### Operands: -1. `x`: tensor of floating-point or 64/128-bit complex type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of floating-point or 64/128-bit complex type values - -## tf.Sin (TF::SinOp) -Computes sin of x element-wise. - -### Description: - - -### Operands: -1. `x`: tensor of floating-point or 64/128-bit complex type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of floating-point or 64/128-bit complex type values - -## tf.Slice (TF::SliceOp) -Return a slice from 'input'. - -### Description: - -The output tensor is a tensor with dimensions described by 'size' -whose values are extracted from 'input' starting at the offsets in -'begin'. - -*Requirements*: - 0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n) - -### Operands: -1. `input`: tensor of tf.dtype values -1. `begin`: tensor of 32/64-bit integer values -1. `size`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `Index` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Softmax (TF::SoftmaxOp) -Computes softmax activations. - -### Description: - -For each batch `i` and class `j` we have - - $$softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))$$ - -### Operands: -1. `logits`: tensor of floating-point values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `softmax`: tensor of floating-point values - -## tf.SpaceToBatchND (TF::SpaceToBatchNDOp) -SpaceToBatch for N-D tensors of type T. - -### Description: - -This operation divides "spatial" dimensions `[1, ..., M]` of the input into a -grid of blocks of shape `block_shape`, and interleaves these blocks with the -"batch" dimension (0) such that in the output, the spatial dimensions -`[1, ..., M]` correspond to the position within the grid, and the batch -dimension combines both the position within a spatial block and the original -batch position. Prior to division into blocks, the spatial dimensions of the -input are optionally zero padded according to `paddings`. See below for a -precise description. - -### Operands: -1. `input`: tensor of tf.dtype values -1. `block_shape`: tensor of 32/64-bit integer values -1. `paddings`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `Tpaddings` | `Attribute` | derived attribute attribute | -| `Tblock_shape` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Split (TF::SplitOp) -Splits a tensor into `num_split` tensors along one dimension. - -### Description: - - -### Operands: -1. `split_dim`: tensor of 32-bit integer values -1. 
`value`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `num_split` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.SplitV (TF::SplitVOp) -Splits a tensor into `num_split` tensors along one dimension. - -### Description: - - -### Operands: -1. `value`: tensor of tf.dtype values -1. `size_splits`: tensor of 32/64-bit integer values -1. `split_dim`: tensor of 32-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `num_split` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 1 attribute | -| `Tlen` | `Attribute` | derived attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Sqrt (TF::SqrtOp) -Computes square root of x element-wise. - -### Description: - -I.e., \\(y = \sqrt{x} = x^{1/2}\\). - -### Operands: -1. `x`: tensor of floating-point or 64/128-bit complex type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of floating-point or 64/128-bit complex type values - -## tf.Square (TF::SquareOp) -Computes square of x element-wise. - -### Description: - -I.e., \\(y = x * x = x^2\\). - -### Operands: -1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values - -## tf.SquaredDifference (TF::SquaredDifferenceOp) -Returns (x - y)(x - y) element-wise. - -### Description: - -*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values -1. `y`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of bfloat16 type or 16-bit float or 32-bit float or 64-bit float or 32-bit integer or 64-bit integer or complex128 type or complex64 type values - -## tf.Squeeze (TF::SqueezeOp) -Removes dimensions of size 1 from the shape of a tensor. - -### Description: - -Given a tensor `input`, this operation returns a tensor of the same type with -all dimensions of size 1 removed. If you don't want to remove all size 1 -dimensions, you can remove specific size 1 dimensions by specifying -`axis`. 
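The shapes in the example below can be reproduced directly through the eager Python API (a minimal sketch, assuming TensorFlow 2.x; `tf.squeeze` is the Python wrapper for this op):

```python
import tensorflow as tf

t = tf.zeros([1, 2, 1, 3, 1, 1])

print(tf.squeeze(t).shape)          # (2, 3)        -- every size-1 dimension removed
print(tf.squeeze(t, [2, 4]).shape)  # (1, 2, 3, 1)  -- only dimensions 2 and 4 removed
```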
- -For example: - -``` -# 't' is a tensor of shape [1, 2, 1, 3, 1, 1] -shape(squeeze(t)) ==> [2, 3] -``` - -Or, to remove specific size 1 dimensions: - -``` -# 't' is a tensor of shape [1, 2, 1, 3, 1, 1] -shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] -``` - -### Operands: -1. `input`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `squeeze_dims` | `ArrayAttr` | 64-bit integer array attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.StridedSlice (TF::StridedSliceOp) -Return a strided slice from `input`. - -### Description: - -Note, most python users will want to use the Python `Tensor.__getitem__` -or `Variable.__getitem__` rather than this op directly. - -The goal of this op is to produce a new tensor with a subset of -the elements from the `n` dimensional `input` tensor. The subset is chosen using -a sequence of `m` sparse range specifications encoded into the arguments -of this function. Note, in some cases -`m` could be equal to `n`, but this need not be the case. Each -range specification entry can be one of the following: - -- An ellipsis (...). Ellipses are used to imply zero or more - dimensions of full-dimension selection and are produced using - `ellipsis_mask`. For example, `foo[...]` is the identity slice. - -- A new axis. This is used to insert a new shape=1 dimension and is - produced using `new_axis_mask`. For example, `foo[:, ...]` where - `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. - - -- A range `begin:end:stride`. This is used to specify how much to choose from - a given dimension. `stride` can be any integer but 0. `begin` is an integer - which represents the index of the first value to select while `end` represents - the index of the last value to select. The number of values selected in each - dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. - `begin` and `end` can be negative where `-1` is the last element, `-2` is - the second to last. `begin_mask` controls whether to replace the explicitly - given `begin` with an implicit effective value of `0` if `stride > 0` and - `-1` if `stride < 0`. `end_mask` is analogous but produces the number - required to create the largest open interval. For example, given a shape - `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do - not assume this is equivalent to `foo[0:-1]` which has an effective `begin` - and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the - first dimension of a tensor while dropping the last two (in the original - order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. - -- A single index. This is used to keep only elements that have a given - index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a - shape `(6,)` tensor. This is encoded in `begin` and `end` and - `shrink_axis_mask`. - -Each conceptual range specification is encoded in the op's argument. This -encoding is best understand by considering a non-trivial example. 
In -particular, -`foo[1, 2:4, None, ..., :-3:-1, :]` will be encoded as - -``` -begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0) -end = [2, 4, x, x, -3, x] -strides = [1, 1, x, x, -1, 1] -begin_mask = 1<<4 | 1 << 5 = 48 -end_mask = 1<<5 = 32 -ellipsis_mask = 1<<3 = 8 -new_axis_mask = 1<<2 4 -shrink_axis_mask = 1<<0 -``` - -In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of -the slice becomes (2, 1, 5, 5, 2, 5). -Let us walk step by step through each argument specification. - -1. The first argument in the example slice is turned into `begin = 1` and -`end = begin + 1 = 2`. To disambiguate from the original spec `2:4` we -also set the appropriate bit in `shrink_axis_mask`. - -2. `2:4` is contributes 2, 4, 1 to begin, end, and stride. All masks have -zero bits contributed. - -3. None is a synonym for `tf.newaxis`. This means insert a dimension of size 1 -dimension in the final shape. Dummy values are contributed to begin, -end and stride, while the new_axis_mask bit is set. - -4. `...` grab the full ranges from as many dimensions as needed to -fully specify a slice for every dimension of the input shape. - -5. `:-3:-1` shows the use of negative indices. A negative index `i` associated -with a dimension that has shape `s` is converted to a positive index -`s + i`. So `-1` becomes `s-1` (i.e. the last element). This conversion -is done internally so begin, end and strides receive x, -3, and -1. -The appropriate begin_mask bit is set to indicate the start range is the -full range (ignoring the x). - -6. `:` indicates that the entire contents of the corresponding dimension -is selected. This is equivalent to `::` or `0::1`. begin, end, and strides -receive 0, 0, and 1, respectively. The appropriate bits in `begin_mask` and -`end_mask` are also set. - -*Requirements*: - `0 != strides[i] for i in [0, m)` - `ellipsis_mask must be a power of two (only one ellipsis)` - -### Operands: -1. `input`: tensor of tf.dtype values -1. `begin`: tensor of 32/64-bit integer values -1. `end`: tensor of 32/64-bit integer values -1. `strides`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `begin_mask` | `IntegerAttr` | 64-bit integer attribute attribute | -| `end_mask` | `IntegerAttr` | 64-bit integer attribute attribute | -| `ellipsis_mask` | `IntegerAttr` | 64-bit integer attribute attribute | -| `new_axis_mask` | `IntegerAttr` | 64-bit integer attribute attribute | -| `shrink_axis_mask` | `IntegerAttr` | 64-bit integer attribute attribute | -| `T` | `Attribute` | derived attribute attribute | -| `Index` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Sub (TF::SubOp) -Returns x - y element-wise. - -### Description: - -*NOTE*: `Subtract` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of number values -1. `y`: tensor of number values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of number values - -## tf.Sum (TF::SumOp) -Computes the sum of elements across dimensions of a tensor. - -### Description: - -Reduces `input` along the dimensions given in `axis`. Unless -`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -`axis`. 
If `keep_dims` is true, the reduced dimensions are -retained with length 1. - -### Operands: -1. `input`: tensor of number values -1. `reduction_indices`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `keep_dims` | `BoolAttr` | bool attribute attribute | -| `T` | `Attribute` | derived attribute attribute | -| `Tidx` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of number values - -## tf.TensorListFromTensor (TF::TensorListFromTensorOp) - -Creates a TensorList which, when stacked, has the value of `tensor`. - - -### Description: - -Each tensor in the result list corresponds to one row of the input tensor. - -tensor: The input tensor. -output_handle: The list. - -### Operands: -1. `tensor`: tensor of tf.dtype values -1. `element_shape`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `shape_type` | `Attribute` | derived attribute attribute | -| `element_dtype` | `Attribute` | derived attribute attribute | - -### Results: -1. `output_handle`: tensor of TensorFlow variant type values - -## tf.TensorListGetItem (TF::TensorListGetItemOp) - - -### Description: - - -### Operands: -1. `input_handle`: tensor of TensorFlow variant type values -1. `index`: tensor of 32-bit integer values -1. `element_shape`: tensor of 32-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `element_dtype` | `Attribute` | derived attribute attribute | - -### Results: -1. `item`: tensor of tf.dtype values - -## tf.TensorListReserve (TF::TensorListReserveOp) -List of the given size with empty elements. - -### Description: - -element_shape: the shape of the future elements of the list -num_elements: the number of elements to reserve -handle: the output list -element_dtype: the desired type of elements in the list. - -### Operands: -1. `element_shape`: tensor of 32/64-bit integer values -1. `num_elements`: tensor of 32-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `element_dtype` | `TypeAttr` | any type attribute attribute | -| `shape_type` | `Attribute` | derived attribute attribute | - -### Results: -1. `handle`: tensor of TensorFlow variant type values - -## tf.TensorListSetItem (TF::TensorListSetItemOp) - - -### Description: - - -### Operands: -1. `input_handle`: tensor of TensorFlow variant type values -1. `index`: tensor of 32-bit integer values -1. `item`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `element_dtype` | `Attribute` | derived attribute attribute | - -### Results: -1. `output_handle`: tensor of TensorFlow variant type values - -## tf.TensorListStack (TF::TensorListStackOp) -Stacks all tensors in the list. - -### Description: - -Requires that all tensors have the same shape. - -input_handle: the input list -tensor: the gathered result -num_elements: optional. If not -1, the number of elements in the list. - -### Operands: -1. `input_handle`: tensor of TensorFlow variant type values -1. 
`element_shape`: tensor of 32-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `num_elements` | `IntegerAttr` | 64-bit integer attribute attribute | -| `element_dtype` | `Attribute` | derived attribute attribute | - -### Results: -1. `tensor`: tensor of tf.dtype values - -## tf.TopKV2 (TF::TopKV2Op) - -Finds values and indices of the `k` largest elements for the last dimension. - - -### Description: - -If the input is a vector (rank-1), finds the `k` largest entries in the vector -and outputs their values and indices as vectors. Thus `values[j]` is the -`j`-th largest entry in `input`, and its index is `indices[j]`. - -For matrices (resp. higher rank input), computes the top `k` entries in each -row (resp. vector along the last dimension). Thus, - - values.shape = indices.shape = input.shape[:-1] + [k] - -If two elements are equal, the lower-index element appears first. - -### Operands: -1. `input`: tensor of 8/16/32/64-bit integer or floating-point values -1. `k`: tensor of 32-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `sorted` | `BoolAttr` | bool attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `values`: tensor of 8/16/32/64-bit integer or floating-point values -1. `indices`: tensor of 32-bit integer values - -## tf.Transpose (TF::TransposeOp) -Shuffle dimensions of x according to a permutation. - -### Description: - -The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy: - `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]` - -### Operands: -1. `x`: tensor of tf.dtype values -1. `perm`: tensor of 32/64-bit integer values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | -| `Tperm` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of tf.dtype values - -## tf.TruncateDiv (TF::TruncateDivOp) -Returns x / y element-wise for integer types. - -### Description: - -Truncation designates that negative numbers will round fractional quantities -toward zero. I.e. -7 / 5 = -1. This matches C semantics but it is different -than Python semantics. See `FloorDiv` for a division function that matches -Python Semantics. - -*NOTE*: `TruncateDiv` supports broadcasting. More about broadcasting -[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) - -### Operands: -1. `x`: tensor of number values -1. `y`: tensor of number values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of number values - -## tf.Unpack (TF::UnpackOp) - -Unpacks a given dimension of a rank-`R` tensor into `num` rank-`(R-1)` tensors. - - -### Description: - -Unpacks `num` tensors from `value` by chipping it along the `axis` dimension. -For example, given a tensor of shape `(A, B, C, D)`; - -If `axis == 0` then the i'th tensor in `output` is the slice `value[i, :, :, :]` - and each tensor in `output` will have shape `(B, C, D)`. (Note that the - dimension unpacked along is gone, unlike `split`). - -If `axis == 1` then the i'th tensor in `output` is the slice `value[:, i, :, :]` - and each tensor in `output` will have shape `(A, C, D)`. -Etc. - -This is the opposite of `pack`. - -### Operands: -1. 
`value`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `num` | `IntegerAttr` | 64-bit integer attribute whose minimal value is 0 attribute | -| `axis` | `IntegerAttr` | 64-bit integer attribute attribute | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `output`: tensor of tf.dtype values - -## tf.Xdivy (TF::XdivyOp) -Returns 0 if x == 0, and x / y otherwise, elementwise. - -### Description: - - -### Operands: -1. `x`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values -1. `y`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `z`: tensor of 16-bit float or 32-bit float or 64-bit float or complex128 type or complex64 type values - -## tf.ZerosLike (TF::ZerosLikeOp) -Returns a tensor of zeros with the same shape and type as x. - -### Description: - - -### Operands: -1. `x`: tensor of tf.dtype values - -### Attributes: -| Attribute | MLIR Type | Description | -| :-------: | :-------: | ----------- | -| `T` | `Attribute` | derived attribute attribute | - -### Results: -1. `y`: tensor of tf.dtype values - diff --git a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h index 2756b4c0885..4bf7029421e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h @@ -65,7 +65,7 @@ class TFControlType : public Type::TypeBase { // tensor needs its own _tf.Enter to be made available inside the while loop. // // More details can be found in Tensorflow Controlflow white paper: -// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf +// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf // // This is defined in Tensorflow as: // @@ -100,7 +100,7 @@ class EnterOp // of the operand type along with the index of the first match encountered. // // More details can be found in Tensorflow Controlflow white paper: -// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf +// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf // // This is defined in TensorFlow as: // @@ -130,7 +130,7 @@ class MergeOp : public Op::Impl, // outside of loop. Each returned tensor needs its own _tf.Exit. // // More details can be found in Tensorflow Controlflow white paper: -// http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf +// https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf // // This is defined in Tensorflow as: // diff --git a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc index 333711f52f6..235980e05c3 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/dialect_registration.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -25,5 +26,7 @@ static DialectRegistration static DialectRegistration tf_ops; static DialectRegistration tf_excutor_dialect; +static DialectRegistration + tf_device_dialect; } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc new file mode 100644 index 00000000000..cac27164ef7 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" + +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir + +namespace mlir { +namespace tf_device { + +TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext *context) + : Dialect(/*name=*/"tf_device", context) { + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc" + +} // namespace tf_device +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h new file mode 100644 index 00000000000..91370bc6501 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the tf_device dialect: it contains operations that model +// TensorFlow's actions to launch computations on accelerator devices. 
+ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ + +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Dialect.h" // TF:local_config_mlir + +namespace mlir { +namespace tf_device { + +// The TensorFlow Device dialect. +// +// This dialect contains operations to describe/launch computations on devices. +// These operations do not map 1-1 to TensorFlow ops and requires a lowering +// pass later to transform them into Compile/Run op pairs, like XlaCompile and +// XlaRun. +class TensorFlowDeviceDialect : public Dialect { + public: + // Constructing TensorFlowDevice dialect under an non-null MLIRContext. + explicit TensorFlowDeviceDialect(MLIRContext *context); +}; + +// Declares the operations for this dialect using the generated header. +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h.inc" + +} // namespace tf_device +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td new file mode 100644 index 00000000000..3220f0f98dc --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -0,0 +1,129 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the definition file for the TensorFlow Device Dialect. + +#ifdef TF_DEVICE_DIALECT +#else +#define TF_DEVICE_DIALECT + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +//===----------------------------------------------------------------------===// +// TensorFlow Device Dialect definitions +//===----------------------------------------------------------------------===// + +def TfDevice_Dialect : Dialect { + let name = "tf_device"; + + let description = [{ + The TensorFlow Device dialect. + + This dialect contains operations to describe/launch computations on devices. + These operations do not map 1-1 to TensorFlow ops and requires a lowering + pass later to transform them into Compile/Run op pairs, like XlaCompile and + XlaRun. +}]; + + let cppNamespace = "tf_device"; +} + +//===----------------------------------------------------------------------===// +// TensorFlow Device Dialect Ops definitions +//===----------------------------------------------------------------------===// + +// Base class for the operation in this dialect. 
+class TfDevice_Op traits = []> : + Op { } + +def TfDevice_LaunchOp : TfDevice_Op<"launch", + [SingleBlockImplicitTerminator<"ReturnOp">]> +{ + let summary = [{The `tf_device.launch` op captures all needed live-in values + and launches containing operations on target device.}]; + + let arguments = (ins + StrAttr:$device + ); + + let results = (outs + Variadic:$results + ); + + let regions = (region SizedRegion<1>:$body); + + let extraClassDeclaration = [{ + Block &GetBody() { return getOperation()->getRegion(0).front(); } + StringRef getDevice() { return device(); } + }]; + + let builders = [ + OpBuilder<[{Builder *builder, OperationState *result, + StringAttr device, ArrayRef result_types}], + [{ + result->addAttribute("device", device); + result->addTypes(result_types); + result->addRegion(); + }] + > + ]; +} + +def TfDevice_ReturnOp : TfDevice_Op<"return", + [Terminator, HasParent<"LaunchOp">]> { + let summary = [{ + The `tf_device.return` operation terminates and returns values from + `tf_device.launch` operation; + }]; + + let arguments = (ins + Variadic:$results + ); + + let builders = [OpBuilder< + "Builder *builder, OperationState *result", + [{ + build(builder, result, {}); + }]> + ]; + + let verifier = ?; +} + +def TfDevice_LaunchFuncOp : TfDevice_Op<"launch_func", []> { + let summary = [{ + The `tf_device.launch_func` launches a function on target device. + }]; + + let arguments = (ins + StrAttr:$device, + SymbolRefAttr:$func, + Variadic:$operands); + + let results = (outs + Variadic:$results + ); + + let extraClassDeclaration = [{ + StringRef getFunc() { return func(); } + StringRef getDevice() { return device(); } + FunctionType getFuncType(); + }]; +} + +#endif // TF_DEVICE_DIALECT diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 29d73a71ad9..77d412f02c9 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -16,27 +16,52 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include +#include +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/Dialect/Traits.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Matchers.h" // TF:local_config_mlir +#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir #include "mlir/IR/OpImplementation.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { namespace tf_executor { +namespace { + +// If the given tensor has elements of type variant, then returns a new type +// after dropping subtypes info. Otherwise, returns the original type as is. 
+Type DropVariantSubTypes(Type ty) { + ShapedType shaped_ty = ty.cast(); + Type element_ty = shaped_ty.getElementType(); + if (!element_ty.isa()) return ty; + + Type variant_ty = TF::VariantType::get(ty.getContext()); + if (shaped_ty.hasRank()) { + return RankedTensorType::get(shaped_ty.getShape(), variant_ty); + } + + return UnrankedTensorType::get(variant_ty); +} + +} // namespace //===----------------------------------------------------------------------===// // TF Executor Dialect @@ -77,21 +102,6 @@ void TensorFlowExecutorDialect::printType(Type type, raw_ostream &os) const { namespace { -// Inserts `tf_executor.Terminator` at the end of the region's only block if it -// does not have a terminator already. If the region is empty, insert a new -// block first. -template -void EnsureExecutorTerminator(Region *region, Builder *builder, Location loc) { - if (region->empty()) region->push_back(new Block); - - Block &block = region->back(); - if (!block.empty() && block.back().isKnownTerminator()) return; - - OperationState terminator_state(loc, Terminator::getOperationName()); - Terminator::build(builder, &terminator_state, {}); - block.push_back(Operation::create(terminator_state)); -} - // Verifies that every control operands are at the end of the list. // Used by the constraint `ControlOperandsAfterAllData` in ODS. LogicalResult VerifyControlOperandsAfterAllData(Operation *op) { @@ -108,10 +118,16 @@ LogicalResult VerifyControlOperandsAfterAllData(Operation *op) { return success(); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.graph //===----------------------------------------------------------------------===// +FetchOp GraphOp::GetFetch() { return llvm::cast(GetBody().back()); } + +namespace { + LogicalResult Verify(GraphOp graph) { auto *executorDialect = graph.getDialect(); @@ -123,6 +139,9 @@ LogicalResult Verify(GraphOp graph) { for (Operation &op : graph.GetBody()) { if (op.getDialect() != executorDialect) return op.emitOpError() << "unallowed inside a tf_executor.graph region"; + if (isa(op)) + return op.emitOpError() + << "unallowed directly inside another tf_executor.graph"; } Operation &fetch = graph.GetBody().back(); @@ -174,8 +193,7 @@ ParseResult ParseGraphOp(OpAsmParser *parser, OperationState *result) { // Ensure that the region is well formed: it contains at least a block with // a FetchOp terminator. - EnsureExecutorTerminator(&body, &parser->getBuilder(), - result->location); + GraphOp::ensureTerminator(body, parser->getBuilder(), result->location); // Get the results type from the terminator type inside the graph. 
Operation &fetch = body.back().back(); @@ -196,10 +214,14 @@ ParseResult ParseGraphOp(OpAsmParser *parser, OperationState *result) { return success(); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.fetch //===----------------------------------------------------------------------===// +namespace { + void Print(FetchOp fetch, OpAsmPrinter *p) { *p << fetch.getOperationName(); if (fetch.getNumOperands() > 0) { @@ -224,10 +246,16 @@ ParseResult ParseFetchOp(OpAsmParser *parser, OperationState *result) { ); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.island //===----------------------------------------------------------------------===// +YieldOp IslandOp::GetYield() { return llvm::cast(GetBody().back()); } + +namespace { + LogicalResult Verify(IslandOp island) { if (island.GetBody().empty()) return island.emitOpError() << "expects a non-empty body"; @@ -281,8 +309,7 @@ ParseResult ParseIslandOp(OpAsmParser *parser, OperationState *result) { if (parser->parseOperandList(op_infos, OpAsmParser::Delimiter::OptionalParen)) return failure(); if (!op_infos.empty()) { - SmallVector types; - types.push_back(control_type); + SmallVector types(op_infos.size(), control_type); parser->resolveOperands(op_infos, types, loc, result->operands); } @@ -301,8 +328,7 @@ ParseResult ParseIslandOp(OpAsmParser *parser, OperationState *result) { if (parser->parseRegion(body, llvm::None, llvm::None)) return failure(); - EnsureExecutorTerminator(&body, &parser->getBuilder(), - result->location); + IslandOp::ensureTerminator(body, parser->getBuilder(), result->location); // Get the results type for the island from the terminator operands. Operation &yield = body.back().back(); @@ -315,10 +341,14 @@ ParseResult ParseIslandOp(OpAsmParser *parser, OperationState *result) { return success(); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.yield //===----------------------------------------------------------------------===// +namespace { + void Print(YieldOp yield, OpAsmPrinter *p) { *p << yield.getOperationName(); if (yield.getNumOperands() > 0) { @@ -341,10 +371,14 @@ ParseResult ParseYieldOp(OpAsmParser *parser, OperationState *result) { parser->parseOptionalAttributeDict(result->attributes)); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.Switch //===----------------------------------------------------------------------===// +namespace { + ParseResult ParseSwitchOp(OpAsmParser *parser, OperationState *result) { SmallVector op_infos; SmallVector types; @@ -398,10 +432,14 @@ void Print(SwitchOp switch_op, OpAsmPrinter *p) { p->printOptionalAttrDict(switch_op.getAttrs()); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.SwitchN //===----------------------------------------------------------------------===// +namespace { + LogicalResult Verify(SwitchNOp switchn) { IntegerAttr num_outs = switchn.getAttrOfType("num_outs"); if (!num_outs) @@ -467,8 +505,9 @@ ParseResult ParseSwitchNOp(OpAsmParser *parser, OperationState *result) { // `types` already contains the type for the data, add an i32 for the // output_index, and then the optional control inputs. 
- types.push_back(parser->getBuilder().getIntegerType(32)); - Type control_type = ControlType::get(parser->getBuilder().getContext()); + auto builder = parser->getBuilder(); + types.push_back(builder.getTensorType({}, builder.getIntegerType(32))); + Type control_type = ControlType::get(builder.getContext()); types.append(op_infos.size() - 2, control_type); if (parser->resolveOperands(op_infos, types, loc, result->operands)) @@ -481,10 +520,14 @@ ParseResult ParseSwitchNOp(OpAsmParser *parser, OperationState *result) { return parser->parseOptionalAttributeDict(result->attributes); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.Merge //===----------------------------------------------------------------------===// +namespace { + LogicalResult Verify(MergeOp merge) { if (!merge.getNumOperands()) return merge.emitOpError() << "expects at least one operand"; @@ -498,8 +541,17 @@ LogicalResult Verify(MergeOp merge) { Type broadcasted_type = merge.output()->getType(); for (Type operand_type : merge.getOperandTypes()) { if (operand_type.isa()) break; + + // TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this + // constraint. + if (!operand_type.isa()) + return merge.emitOpError("expects data operands to have tensor type"); + + // Variant types may have opaque subtypes information that need not match + // between the two types so drop them before computing the broadcasted type. Type new_broadcasted_type = - OpTrait::util::getBroadcastedType(broadcasted_type, operand_type); + OpTrait::util::getBroadcastedType(DropVariantSubTypes(broadcasted_type), + DropVariantSubTypes(operand_type)); if (!new_broadcasted_type) return merge.emitOpError() << "expects all operands to be broadcastable" @@ -508,10 +560,8 @@ LogicalResult Verify(MergeOp merge) { // This is because for example starting with a result of tensor<4xf32>, if // the first operand is unranked, the broadcasted type will be unranked. // Then any tensor operand will be broadcastable to this unranked type. - if ((broadcasted_type.isa() && - !broadcasted_type.cast().hasRank()) || - (new_broadcasted_type.isa() && - new_broadcasted_type.cast().hasRank())) + if (!broadcasted_type.cast().hasRank() || + new_broadcasted_type.cast().hasRank()) broadcasted_type = new_broadcasted_type; } @@ -519,11 +569,33 @@ LogicalResult Verify(MergeOp merge) { } void Print(MergeOp merge, OpAsmPrinter *p) { + // Use short form only when there are exactly two data operands and their + // type matches the output type. Otherwise, use the generic printer. + bool use_short_form = true; + int num_data_operands = 0; + + Type output_type = merge.output()->getType(); + for (Type operand_type : merge.getOperandTypes()) { + if (operand_type.isa()) break; + num_data_operands++; + + if (operand_type != output_type) { + use_short_form = false; + break; + } + } + *p << merge.getOperationName() << ' '; p->printOperands(merge.getOperands()); // Print the type signature of the operation. - *p << " : " << merge.getType(0); + *p << " : "; + if (!use_short_form || num_data_operands != 2) { + p->printFunctionalType(merge.getOperation()); + } else { + *p << output_type; + } + p->printOptionalAttrDict(merge.getAttrs()); } @@ -537,25 +609,38 @@ ParseResult ParseMergeOp(OpAsmParser *parser, OperationState *result) { return parser->emitError(parser->getNameLoc()) << " expects only a single data type"; - // Expect the type once, but use it for both operands. 
- types.push_back(types.front()); - // Extra operands are expected to be control inputs. - Type control_type = ControlType::get(parser->getBuilder().getContext()); - types.append(op_infos.size() - 2, control_type); + // Support parsing either a functional type (in which case all the types are + // fully qualified) or a short form with a single type (in which case the data + // inputs and the output are all using this type). + if (FunctionType type = types.front().dyn_cast()) { + result->types.assign(type.getResults().begin(), type.getResults().end()); + types.assign(type.getInputs().begin(), type.getInputs().end()); + } else { + // In case of the short form, use the parsed type for both the operands and + // the remaining operands are expected to be control inputs. + types.push_back(types.front()); + Type control_type = ControlType::get(parser->getBuilder().getContext()); + types.append(op_infos.size() - 2, control_type); + + RankedTensorType i32_tensor = + RankedTensorType::get({}, parser->getBuilder().getIntegerType(32)); + result->types = {types.front(), i32_tensor, control_type}; + } if (parser->resolveOperands(op_infos, types, loc, result->operands)) return failure(); - RankedTensorType i32_tensor = - RankedTensorType::get({}, parser->getBuilder().getIntegerType(32)); - result->types = {types.front(), i32_tensor, control_type}; return parser->parseOptionalAttributeDict(result->attributes); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.Enter //===----------------------------------------------------------------------===// +namespace { + // Default number for the parallel_iterations attributes on Enter nodes. constexpr int kDefaultParallelIterations = 10; @@ -638,10 +723,14 @@ ParseResult ParseEnterOp(OpAsmParser *parser, OperationState *result) { return parser->parseOptionalAttributeDict(result->attributes); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.NextIteration.Source //===----------------------------------------------------------------------===// +namespace { + LogicalResult Verify(NextIterationSourceOp source) { Value *token = source.token(); if (!token->hasOneUse()) @@ -668,10 +757,14 @@ ParseResult ParseNextIterationSourceOp(OpAsmParser *parser, return parser->parseOptionalAttributeDict(result->attributes); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.NextIteration.Sink //===----------------------------------------------------------------------===// +namespace { + LogicalResult Verify(NextIterationSinkOp sink) { Value *token = sink.token(); Operation *definingOp = token->getDefiningOp(); @@ -720,10 +813,14 @@ ParseResult ParseNextIterationSinkOp(OpAsmParser *parser, return parser->parseOptionalAttributeDict(result->attributes); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.Exit //===----------------------------------------------------------------------===// +namespace { + void Print(ExitOp exit, OpAsmPrinter *p) { *p << exit.getOperationName() << ' '; p->printOperands(exit.getOperands()); @@ -748,10 +845,14 @@ ParseResult ParseExitOp(OpAsmParser *parser, OperationState *result) { return parser->parseOptionalAttributeDict(result->attributes); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // 
tf_executor.ControlTrigger //===----------------------------------------------------------------------===// +namespace { + void Print(ControlTriggerOp trigger, OpAsmPrinter *p) { *p << trigger.getOperationName() << ' '; p->printOperands(trigger.getOperands()); @@ -774,10 +875,14 @@ ParseResult ParseControlTriggerOp(OpAsmParser *parser, OperationState *result) { return parser->parseOptionalAttributeDict(result->attributes); } +} // anonymous namespace + //===----------------------------------------------------------------------===// // tf_executor.LoopCond //===----------------------------------------------------------------------===// +namespace { + void Print(LoopCondOp loop_cond, OpAsmPrinter *p) { *p << loop_cond.getOperationName() << ' '; p->printOperands(loop_cond.getOperands()); @@ -832,6 +937,179 @@ ParseResult ParseLoopCondOp(OpAsmParser *parser, OperationState *result) { } // namespace +//===----------------------------------------------------------------------===// +// Canonicalization patterns +//===----------------------------------------------------------------------===// + +// TODO(lyandy): Add canonicalization for dedupping control inputs. + +//===----------------------------------------------------------------------===// +// tf_executor.graph +//===----------------------------------------------------------------------===// + +namespace { +// Finds in a block if the op of type `InnerOpT` is the first operation and +// optionally followed by a terminator. +template +bool HasSingleOpInBlock(Block *block) { + if (block->empty()) return false; + if (!llvm::isa(block->front())) return false; + // Either InnerOpT is the only instruction in the block, or there is a + // possible terminator. + return std::next(block->begin()) == block->end() || + std::next(block->begin(), 2) == block->end(); +} + +// This pattern matches GraphOps with only one FetchOp (empty) and remaps the +// results of the GraphOp to the operands of the FetchOp. +struct DropEmptyGraph : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(GraphOp op, + PatternRewriter &rewriter) const override { + Block &block = op.GetBody(); + // Check if graph only has one fetch. + if (&block.front() != &block.back()) return matchFailure(); + + // Map graph results to fetch operands. + llvm::SmallVector new_rets(op.GetFetch().fetches()); + rewriter.replaceOp(op, new_rets); + + return matchSuccess(); + } +}; + +// This pattern matches GraphOps with only one island, pulls out all inner ops +// of the island to the block containing the GraphOp, and then removes the +// GraphOp. +struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(GraphOp op, + PatternRewriter &rewriter) const override { + Block &block = op.GetBody(); + // Check if graph only has one island. + if (!HasSingleOpInBlock(&block)) return matchFailure(); + + FetchOp fetch_op = op.GetFetch(); + auto island_op = llvm::cast(block.front()); + YieldOp yield_op = island_op.GetYield(); + + // Map graph results to inner ops results of single island. + llvm::SmallVector new_rets; + for (Value *operand : fetch_op.fetches()) { + // Control results should not be propagated out. + if (operand->getType().isa()) break; + + if (operand->getDefiningOp() != island_op) { + // Operand is not from island, simply propagate it out. + new_rets.push_back(operand); + } else { + // Lookup yield operand in island for inner op result. 
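+ // The result number of the island result gives the index of the yield
+ // operand that produced the value inside the island.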
+ auto result = llvm::cast(operand); + new_rets.push_back(yield_op.getOperand(result->getResultNumber())); + } + } + + // Move inner ops from island to block containing graph. + auto &island_body = island_op.GetBody().getOperations(); + Operation *operation = op.getOperation(); + operation->getBlock()->getOperations().splice( + operation->getIterator(), island_body, island_body.begin(), + std::prev(island_body.end())); + rewriter.replaceOp(op, new_rets); + + return matchSuccess(); + } +}; +} // anonymous namespace + +void GraphOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// tf_executor.island +//===----------------------------------------------------------------------===// + +namespace { +// This pattern matches and removes IslandOps with no inner ops, no control +// operands and no data results. Control result users will have their relevant +// operands removed. +struct DropEmptyIslandNoOperandNoDataResult + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(IslandOp op, + PatternRewriter &rewriter) const override { + if (op.getNumOperands() != 0 || op.getNumResults() != 1 || + !HasSingleOpInBlock(&op.GetBody())) + return matchFailure(); + + for (auto &use : llvm::make_early_inc_range(op.control()->getUses())) + use.getOwner()->eraseOperand(use.getOperandNumber()); + + rewriter.replaceOp(op, {nullptr}); + + return matchSuccess(); + } +}; + +// This pattern matches and removes IslandOps with no inner ops, no control +// operands, one data result and no control result user. The single data result +// (from YieldOps first operand) is forwarded to the IslandOp single data result +// users. +struct DropEmptyIslandNoOperandOneDataResult + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(IslandOp op, + PatternRewriter &rewriter) const override { + if (op.getNumOperands() != 0 || op.getNumResults() != 2 || + !op.control()->use_empty() || + !HasSingleOpInBlock(&op.GetBody())) + return matchFailure(); + + rewriter.replaceOp(op, {op.GetYield().getOperand(0), nullptr}); + + return matchSuccess(); + } +}; + +// TODO(lyandy): Add canonicalization for empty IslandOps with more than one +// control operand and no data results. + +} // anonymous namespace + +void IslandOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// Folders +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// tf_executor.island +//===----------------------------------------------------------------------===// + +LogicalResult IslandOp::fold(llvm::ArrayRef operands, + llvm::SmallVectorImpl &results) { + // This folds IslandOps with no inner ops, one control operand and no data + // results. The single control operand is forwarded to the IslandOp control + // result users. 
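+ // In effect such an island only forwards control, so its control result can
+ // be replaced directly by the incoming control operand.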
+ if (getNumOperands() != 1 || getNumResults() != 1 || + !HasSingleOpInBlock(&GetBody())) + return failure(); + + results.emplace_back(getOperand(0)); + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index 125ef1bfda6..50412544460 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -55,12 +55,14 @@ def TfeControlType : Type()">, "control">; // Token type. def TfeTokenType : Type()">, "token">; +// TODO(hinsu): Define and use TensorType instead of AnyType for data operands +// and results. For example, MergeOp output type. + //===----------------------------------------------------------------------===// // TensorFlow Executor Type Constraint //===----------------------------------------------------------------------===// -// Predicate to verify that the opId'th operand can be broadcasted to the type -// of the resId'th result. +// Predicate to verify all control inputs appear after any non-control inputs. def ControlOperandsAfterAllData : PredOpTrait<"all control inputs must appear after any non-control input", CPred<"succeeded(VerifyControlOperandsAfterAllData(&$_op))">>; @@ -79,7 +81,8 @@ class TfExecutor_Op traits = []> : let parser = [{ return Parse$cppClass(parser, result); }]; } -def TfExecutor_GraphOp : TfExecutor_Op<"graph", []> { +def TfExecutor_GraphOp : TfExecutor_Op<"graph", + [SingleBlockImplicitTerminator<"FetchOp">]> { let summary = [{The `tf_executor.graph` operation contains a region with a single block that lists the operations in a TensorFlow graph.}]; @@ -120,10 +123,14 @@ def TfExecutor_GraphOp : TfExecutor_Op<"graph", []> { let extraClassDeclaration = [{ Block &GetBody() { return getOperation()->getRegion(0).front(); } + FetchOp GetFetch(); }]; + + let hasCanonicalizer = 1; } -def TfExecutor_FetchOp : TfExecutor_Op<"fetch", [Terminator, ControlOperandsAfterAllData]> { +def TfExecutor_FetchOp : TfExecutor_Op<"fetch", + [Terminator, ControlOperandsAfterAllData, HasParent<"GraphOp">]> { let summary = [{ The `tf_executor.fetch` operation terminates the graph and returns values"; }]; @@ -137,10 +144,18 @@ def TfExecutor_FetchOp : TfExecutor_Op<"fetch", [Terminator, ControlOperandsAfte Variadic:$fetches ); + let builders = [OpBuilder< + "Builder *builder, OperationState *result", + [{ + build(builder, result, {}); + }]> + ]; + let verifier = ?; } -def TfExecutor_IslandOp : TfExecutor_Op<"island", []> { +def TfExecutor_IslandOp : TfExecutor_Op<"island", + [HasParent<"GraphOp">, SingleBlockImplicitTerminator<"YieldOp">]> { let summary = [{ The `tf_executor.island` operation is a wrapper for operations in other dialects to be nested in a `tf_executor.graph`. 
@@ -190,11 +205,16 @@ def TfExecutor_IslandOp : TfExecutor_Op<"island", []> { let extraClassDeclaration = [{ Block &GetBody() { return getOperation()->getRegion(0).front(); } + YieldOp GetYield(); }]; + + let hasCanonicalizer = 1; + + let hasFolder = 1; } -def TfExecutor_YieldOp : - TfExecutor_Op<"yield", [Terminator, ControlOperandsAfterAllData]> { +def TfExecutor_YieldOp : TfExecutor_Op<"yield", + [Terminator, ControlOperandsAfterAllData, HasParent<"IslandOp">]> { let summary = [{ The `tf_executor.yield` operation terminates and returns values for the `tf_executor.island` operation. @@ -204,11 +224,18 @@ def TfExecutor_YieldOp : Variadic:$fetches ); + let builders = [OpBuilder< + "Builder *builder, OperationState *result", + [{ + build(builder, result, {}); + }]> + ]; + let verifier = ?; } def TfExecutor_SwitchOp : TfExecutor_Op<"Switch", - [NoSideEffect, ControlOperandsAfterAllData, + [NoSideEffect, ControlOperandsAfterAllData, HasParent<"GraphOp">, PredOpTrait<"data operand must be broadcastable to true result", TCOpIsBroadcastableToRes<0, 0>>, PredOpTrait<"data operand must be broadcastable to false result", @@ -221,7 +248,7 @@ def TfExecutor_SwitchOp : TfExecutor_Op<"Switch", let description = [{ More details can be found in Tensorflow Control Flow white paper: - http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf + https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf This is defined in TensorFlow as: @@ -253,8 +280,8 @@ def TfExecutor_SwitchOp : TfExecutor_Op<"Switch", let verifier = ?; } -def TfExecutor_SwitchNOp : - TfExecutor_Op<"SwitchN", [NoSideEffect, ControlOperandsAfterAllData]> { +def TfExecutor_SwitchNOp : TfExecutor_Op<"SwitchN", + [NoSideEffect, ControlOperandsAfterAllData, HasParent<"GraphOp">]> { let summary = [{ The "tf_executor.SwitchN" operation takes two inputs, `data` and `index` and an integer attribute `num_outs` indicating the number of outputs. The `data` @@ -282,7 +309,7 @@ def TfExecutor_SwitchNOp : let arguments = (ins AnyType:$data, - I32:$index, + TensorOf<[I32]>:$index, // Optional extra control inputs. Variadic:$controlInputs, I64Attr:$num_outs @@ -294,7 +321,8 @@ def TfExecutor_SwitchNOp : ); } -def TfExecutor_MergeOp : TfExecutor_Op<"Merge", [NoSideEffect, ControlOperandsAfterAllData]> { +def TfExecutor_MergeOp : TfExecutor_Op<"Merge", + [NoSideEffect, ControlOperandsAfterAllData, HasParent<"GraphOp">]> { let summary = [{ The "tf_executor.Merge" operation takes a list of input operands and returns a value of the operand type along with the index of the first match encountered. 
@@ -302,7 +330,7 @@ def TfExecutor_MergeOp : TfExecutor_Op<"Merge", [NoSideEffect, ControlOperandsAf let description = [{ More details can be found in Tensorflow Control Flow white paper: - http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf + https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf This is defined in TensorFlow as: @@ -322,14 +350,14 @@ def TfExecutor_MergeOp : TfExecutor_Op<"Merge", [NoSideEffect, ControlOperandsAf ); let results = (outs - AnyType:$output, + AnyTensor:$output, TensorOf<[I32]>:$valueIndex, TfeControlType:$control ); } def TfExecutor_EnterOp : TfExecutor_Op<"Enter", - [NoSideEffect, ControlOperandsAfterAllData, + [NoSideEffect, ControlOperandsAfterAllData, HasParent<"GraphOp">, PredOpTrait<"data operand must be broadcastable to result", TCOpIsBroadcastableToRes<0, 0>>]>{ let summary = [{ @@ -339,7 +367,7 @@ def TfExecutor_EnterOp : TfExecutor_Op<"Enter", let description = [{ More details can be found in Tensorflow Control Flow white paper: - http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf + https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf Each tensor needs its own tf_executor.Enter to be made available inside a while loop. @@ -378,7 +406,8 @@ def TfExecutor_EnterOp : TfExecutor_Op<"Enter", let verifier = ?; } -def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", [NoSideEffect]> { +def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", + [NoSideEffect, HasParent<"GraphOp">]> { let summary = [{ The "tf_executor.NextIteration.Source" is paired with a "tf_executor.NextIteration.sink" to represent NextIteration op in @@ -390,7 +419,7 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", [No of a while loop. Each loop variable needs its own NextIteration op. More details can be found in Tensorflow Control Flow white paper: - http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf + https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf In the TF executor dialect, the NextIteration op is broken into tf_executor.NextIteration.sink and tf_executor.NextIteration.source because @@ -415,10 +444,6 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", [No Note: Additional result corresponds to the control output. }]; - let arguments = (ins - Variadic:$controlInputs - ); - let results = (outs AnyType:$output, // The NextIteration.Source operation returns an extra token consumed by the sink. 
@@ -428,19 +453,26 @@ def TfExecutor_NextIterationSourceOp : TfExecutor_Op<"NextIteration.Source", [No let builders = [OpBuilder< "Builder *builder, OperationState *result, Type result_type, " - "ArrayRef control_inputs = {}, ArrayRef attributes = {}", + "ArrayRef attributes = {}", [{ Type token_type = TokenType::get(builder->getContext()); Type control_type = ControlType::get(builder->getContext()); result->types = { result_type, token_type, control_type }; - result->operands.append(control_inputs.begin(), control_inputs.end()); result->attributes.append(attributes.begin(), attributes.end()); }]> ]; + + let extraClassDeclaration = [{ + NextIterationSinkOp GetSink() { + return cast(*token()->user_begin()); + } + }]; + } -def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink"> { +def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink", + [HasParent<"GraphOp">]> { let summary = [{ The "tf_executor.NextIteration.Sink" is paired with a "tf_executor.NextIteration.source" to represent NextIteration op in @@ -452,7 +484,7 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink"> { of a while loop. Each loop variable needs its own NextIteration op. More details can be found in Tensorflow Control Flow white paper: - http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf + https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf In the TF executor dialect, the NextIteration op is broken into tf_executor.NextIteration.sink and tf_executor.NextIteration.source because @@ -500,7 +532,7 @@ def TfExecutor_NextIterationSinkOp : TfExecutor_Op<"NextIteration.Sink"> { } def TfExecutor_ExitOp : TfExecutor_Op<"Exit", - [NoSideEffect, + [NoSideEffect, HasParent<"GraphOp">, PredOpTrait<"data operand must be broadcastable to result", TCOpIsBroadcastableToRes<0, 0>>]>{ @@ -512,7 +544,7 @@ def TfExecutor_ExitOp : TfExecutor_Op<"Exit", let description = [{ More details can be found in Tensorflow Control Flow white paper: - http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf + https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf This is defined in Tensorflow as: @@ -540,7 +572,8 @@ def TfExecutor_ExitOp : TfExecutor_Op<"Exit", let verifier = ?; } -def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", [NoSideEffect]> { +def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", + [NoSideEffect, HasParent<"GraphOp">]> { let summary = [{ The `tf_executor.ControlTrigger` operation is similar to a no-op except that it always produces a valid output even when inputs are dead. @@ -576,7 +609,8 @@ def TfExecutor_ControlTriggerOp : TfExecutor_Op<"ControlTrigger", [NoSideEffect] ]; } -def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", [NoSideEffect]> { +def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", + [NoSideEffect, HasParent<"GraphOp">]> { let summary = [{ The "tf_executor.LoopCond" operation forwards a boolean value as loop condition of Tensorflow while loops. 
@@ -584,7 +618,7 @@ def TfExecutor_LoopCondOp : TfExecutor_Op<"LoopCond", [NoSideEffect]> { let description = [{ More details can be found in Tensorflow Control Flow white paper: - http://download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf + https://storage.googleapis.com/download.tensorflow.org/paper/white_paper_tf_control_flow_implementation_2017_11_1.pdf This is defined in Tensorflow as: diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 9c256034c2b..f7311f61985 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -123,6 +123,65 @@ def TF_AddV2Op : TF_Op<"AddV2", [Broadcastable, Commutative, NoSideEffect]>, let hasCanonicalizer = 1; } +def TF_AnyOp : TF_Op<"Any", [NoSideEffect]> { + let summary = [{ +Computes the "logical or" of elements across dimensions of a tensor. + }]; + + let description = [{ +Reduces `input` along the dimensions given in `axis`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`axis`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. + }]; + + let arguments = (ins + I1Tensor:$input, + TF_I32OrI64Tensor:$reduction_indices, + + DefaultValuedAttr:$keep_dims + ); + + let results = (outs + I1Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; +} + +def TF_ArgMaxOp : TF_Op<"ArgMax", [NoSideEffect]> { + let summary = [{ +Returns the index with the largest value across dimensions of a tensor. + }]; + + let description = [{ +Note that in case of ties the identity of the return value is not guaranteed. + +Usage: + ```python + import tensorflow as tf + a = [1, 10, 26.9, 2.8, 166.32, 62.3] + b = tf.math.argmax(input = a) + c = tf.keras.backend.eval(b) + # c = 4 + # here a[4] = 166.32 which is the largest element of a across axis 0 + ``` + }]; + + let arguments = (ins + TF_NumberTensor:$input, + TF_I32OrI64Tensor:$dimension + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr output_type = TF_DerivedResultTypeAttr<0>; +} + def TF_ArgMinOp : TF_Op<"ArgMin", [NoSideEffect]> { let summary = [{ Returns the index with the smallest value across dimensions of a tensor. @@ -136,7 +195,7 @@ Usage: import tensorflow as tf a = [1, 10, 26.9, 2.8, 166.32, 62.3] b = tf.math.argmin(input = a) - c = tf.keras.backend.eval(b) + c = tf.keras.backend.eval(b) # c = 0 # here a[0] = 1 which is the smallest element of a across axis 0 ``` @@ -156,6 +215,28 @@ Usage: TF_DerivedResultTypeAttr output_type = TF_DerivedResultTypeAttr<0>; } +def TF_AssertOp : TF_Op<"Assert", []> { + let summary = "Asserts that the given condition is true."; + + let description = [{ +If `condition` evaluates to false, print the list of tensors in `data`. +`summarize` determines how many entries of the tensors to print. 
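+
+A minimal usage sketch, assuming the standard Python wrapper
+`tf.debugging.Assert` (the wrapper itself is not part of this op definition):
+
+  ```python
+  import tensorflow as tf
+  # Illustrative sketch: raises an error and prints at most `summarize`
+  # entries of `x`, because not all elements of x are positive.
+  x = tf.constant([-1.0, 2.0])
+  tf.debugging.Assert(tf.reduce_all(x > 0), [x], summarize=2)
+  ```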
+ }]; + + let arguments = (ins + I1Tensor:$condition, + Variadic:$data, + + DefaultValuedAttr:$summarize + ); + + let results = (outs); + + TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<1>; + + let hasCanonicalizer = 1; +} + def TF_AvgPoolOp : TF_Op<"AvgPool", [NoSideEffect]> { let summary = "Performs average pooling on the input."; @@ -528,6 +609,115 @@ Given an input tensor, this function computes cosine of every TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_DepthToSpaceOp : TF_Op<"DepthToSpace", [NoSideEffect]> { + let summary = "DepthToSpace for tensors of type T."; + + let description = [{ +Rearranges data from depth into blocks of spatial data. +This is the reverse transformation of SpaceToDepth. More specifically, +this op outputs a copy of the input tensor where values from the `depth` +dimension are moved in spatial blocks to the `height` and `width` dimensions. +The attr `block_size` indicates the input block size and how the data is moved. + + * Chunks of data of size `block_size * block_size` from depth are rearranged + into non-overlapping blocks of size `block_size x block_size` + * The width the output tensor is `input_depth * block_size`, whereas the + height is `input_height * block_size`. + * The Y, X coordinates within each block of the output image are determined + by the high order component of the input channel index. + * The depth of the input tensor must be divisible by + `block_size * block_size`. + +The `data_format` attr specifies the layout of the input and output tensors +with the following options: + "NHWC": `[ batch, height, width, channels ]` + "NCHW": `[ batch, channels, height, width ]` + "NCHW_VECT_C": + `qint8 [ batch, channels / 4, height, width, 4 ]` + +It is useful to consider the operation as transforming a 6-D Tensor. +e.g. for data_format = NHWC, + Each element in the input tensor can be specified via 6 coordinates, + ordered by decreasing memory layout significance as: + n,iY,iX,bY,bX,oC (where n=batch index, iX, iY means X or Y coordinates + within the input image, bX, bY means coordinates + within the output block, oC means output channels). + The output would be the input transposed to the following layout: + n,iY,bY,iX,bX,oC + +This operation is useful for resizing the activations between convolutions +(but keeping all data), e.g. instead of pooling. It is also useful for training +purely convolutional models. + +For example, given an input of shape `[1, 1, 1, 4]`, data_format = "NHWC" and +block_size = 2: + +``` +x = [[[[1, 2, 3, 4]]]] + +``` + +This operation will output a tensor of shape `[1, 2, 2, 1]`: + +``` + [[[[1], [2]], + [[3], [4]]]] +``` + +Here, the input has a batch of 1 and each batch element has shape `[1, 1, 4]`, +the corresponding output will have 2x2 elements and will have a depth of +1 channel (1 = `4 / (block_size * block_size)`). +The output element shape is `[2, 2, 1]`. + +For an input tensor with larger depth, here of shape `[1, 1, 1, 12]`, e.g. 
+ +``` +x = [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] +``` + +This operation, for block size of 2, will return the following tensor of shape +`[1, 2, 2, 3]` + +``` + [[[[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]]]] + +``` + +Similarly, for the following input of shape `[1 2 2 4]`, and a block size of 2: + +``` +x = [[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]] +``` + +the operator will return the following tensor of shape `[1 4 4 1]`: + +``` +x = [[[ [1], [2], [5], [6]], + [ [3], [4], [7], [8]], + [ [9], [10], [13], [14]], + [ [11], [12], [15], [16]]]] + +``` + }]; + + let arguments = (ins + TF_Tensor:$input, + + Confined]>:$block_size, + DefaultValuedAttr, "NHWC">:$data_format + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_DepthwiseConv2dNativeOp : TF_Op<"DepthwiseConv2dNative", [NoSideEffect]> { let summary = [{ Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors. @@ -646,6 +836,51 @@ tf.math.equal(x, y) ==> array([True, True]) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ExpOp : TF_Op<"Exp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Computes exponential of x element-wise. \\(y = e^x\\). + }]; + + let description = [{ +This function computes the exponential of every element in the input tensor. + i.e. `exp(x)` or `e^(x)`, where `x` is the input tensor. + `e` denotes Euler's number and is approximately equal to 2.718281. + Output is positive for any real input. + + ```python + x = tf.constant(2.0) + tf.math.exp(x) ==> 7.389056 + + x = tf.constant([2.0, 8.0]) + tf.math.exp(x) ==> array([7.389056, 2980.958], dtype=float32) + ``` + + For complex numbers, the exponential value is calculated as follows: + + ``` + e^(x+iy) = e^x * e^iy = e^x * (cos y + i sin y) + ``` + + Let's consider complex number 1+1j as an example. + e^1 * (cos 1 + i sin 1) = 2.7182818284590 * (0.54030230586+0.8414709848j) + + ```python + x = tf.constant(1 + 1j) + tf.math.exp(x) ==> 1.4686939399158851+2.2873552871788423j + ``` + }]; + + let arguments = (ins + TF_FpOrComplexTensor:$x + ); + + let results = (outs + TF_FpOrComplexTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ExpandDimsOp : TF_Op<"ExpandDims", [NoSideEffect]> { let summary = "Inserts a dimension of 1 into a tensor's shape."; @@ -858,6 +1093,32 @@ def TF_FloorDivOp : TF_Op<"FloorDiv", [Broadcastable, NoSideEffect]>, TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_FloorModOp : TF_Op<"FloorMod", [Broadcastable, NoSideEffect]>, + WithBroadcastableBinOpBuilder { + let summary = [{ +Returns element-wise remainder of division. When `x < 0` xor `y < 0` is + }]; + + let description = [{ +true, this follows Python semantics in that the result here is consistent +with a flooring divide. E.g. `floor(x / y) * y + mod(x, y) = x`. + +*NOTE*: `FloorMod` supports broadcasting. More about broadcasting +[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) + }]; + + let arguments = (ins + TF_FpOrI32OrI64Tensor:$x, + TF_FpOrI32OrI64Tensor:$y + ); + + let results = (outs + TF_FpOrI32OrI64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_FusedBatchNormOp : TF_Op<"FusedBatchNorm", [NoSideEffect]> { let summary = "Batch normalization."; @@ -893,6 +1154,39 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. 
}]; } +def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect]> { + let summary = "Batch normalization."; + + let description = [{ +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32]>:$x, + F32Tensor:$scale, + F32Tensor:$offset, + F32Tensor:$mean, + F32Tensor:$variance, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[BF16, F16, F32]>:$y, + F32Tensor:$batch_mean, + F32Tensor:$batch_variance, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2, + F32Tensor:$reserve_space_3 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; +} + def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> { let summary = "Gather slices from `params` according to `indices`."; @@ -945,13 +1239,13 @@ Gather slices from `params` into a Tensor with shape specified by `indices`. }]; let description = [{ -`indices` is an K-dimensional integer tensor, best thought of as a +`indices` is a K-dimensional integer tensor, best thought of as a (K-1)-dimensional tensor of indices into `params`, where each element defines a slice of `params`: output[\\(i_0, ..., i_{K-2}\\)] = params[indices[\\(i_0, ..., i_{K-2}\\)]] -Whereas in `tf.gather` `indices` defines slices into the first +Whereas in `tf.gather` `indices` defines slices into the `axis` dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the first `N` dimensions of `params`, where `N = indices.shape[-1]`. @@ -1224,10 +1518,10 @@ for dtype in dtype_list: input_tensor, bitwise_ops.invert(input_tensor)), bitwise_ops.invert( tf.constant(0, dtype=dtype))] - + expected = tf.constant([0, 0, 0, 0], dtype=tf.float32) tf.assert_equal(tf.cast(not_a_and_a, tf.float32), expected) - + expected = tf.cast([not_0] * 4, tf.float32) tf.assert_equal(tf.cast(not_a_or_a, tf.float32), expected) @@ -2402,6 +2696,29 @@ Input images can be of different types but output images are always float. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ResizeNearestNeighborOp : TF_Op<"ResizeNearestNeighbor", [NoSideEffect]> { + let summary = [{ +Resize `images` to `size` using nearest neighbor interpolation. + }]; + + let description = [{ + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, I16, I32, I64, I8]>:$images, + I32Tensor:$size, + + DefaultValuedAttr:$align_corners, + DefaultValuedAttr:$half_pixel_centers + ); + + let results = (outs + TensorOf<[F16, F32, F64, I16, I32, I64, I8]>:$resized_images + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ReverseSequenceOp : TF_Op<"ReverseSequence", [NoSideEffect]> { let summary = "Reverses variable length slices."; @@ -2543,6 +2860,27 @@ reverse(t, dims) ==> [[[[8, 9, 10, 11], TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; } +def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +Rounds the values of a tensor to the nearest integer, element-wise. + }]; + + let description = [{ +Rounds half to even. Also known as bankers rounding. If you want to round +according to the current system rounding mode use std::cint. 
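+
+A minimal sketch of the tie-breaking behaviour, assuming the standard Python
+wrapper `tf.round`:
+
+  ```python
+  import tensorflow as tf
+  # Illustrative sketch: ties round to the nearest even value.
+  tf.round(tf.constant([0.5, 1.5, 2.5, -0.5]))  # ==> [0., 2., 2., -0.]
+  ```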
+ }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_RsqrtOp : TF_Op<"Rsqrt", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes reciprocal of square root of x element-wise."; @@ -2618,6 +2956,25 @@ select(condition, t, e) ==> [[1, 2], TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; } +def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect]> { + let summary = ""; + + let description = [{ + }]; + + let arguments = (ins + I1Tensor:$condition, + TF_Tensor:$t, + TF_Tensor:$e + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + def TF_ShapeOp : TF_Op<"Shape", [NoSideEffect]> { let summary = "Returns the shape of a tensor."; @@ -2650,6 +3007,31 @@ shape(t) ==> [2, 2, 3] let hasFolder = 1; } +def TF_ShapeNOp : TF_Op<"ShapeN", [NoSideEffect]> { + let summary = "Returns shape of tensors."; + + let description = [{ +This operation returns N 1-D integer tensors representing shape of `input[i]s`. + }]; + + let arguments = (ins + Variadic:$input, + + Confined]>:$N + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>; + + let verifier = [{ + return Verify(*this); + }]; +} + def TF_SigmoidOp : TF_Op<"Sigmoid", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes sigmoid of `x` element-wise."; @@ -2719,6 +3101,23 @@ whose values are extracted from 'input' starting at the offsets in TF_DerivedOperandTypeAttr Index = TF_DerivedOperandTypeAttr<1>; } +def TF_SnapshotOp : TF_Op<"Snapshot", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Returns a copy of the input tensor."; + + let description = [{ + }]; + + let arguments = (ins + TF_Tensor:$input + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_SoftmaxOp : TF_Op<"Softmax", [NoSideEffect]> { let summary = "Computes softmax activations."; @@ -2772,6 +3171,151 @@ precise description. TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; } +def TF_SpaceToDepthOp : TF_Op<"SpaceToDepth", [NoSideEffect]> { + let summary = "SpaceToDepth for tensors of type T."; + + let description = [{ +Rearranges blocks of spatial data, into depth. More specifically, +this op outputs a copy of the input tensor where values from the `height` +and `width` dimensions are moved to the `depth` dimension. +The attr `block_size` indicates the input block size. + + * Non-overlapping blocks of size `block_size x block size` are rearranged + into depth at each location. + * The depth of the output tensor is `block_size * block_size * input_depth`. + * The Y, X coordinates within each block of the input become the high order + component of the output channel index. + * The input tensor's height and width must be divisible by block_size. + +The `data_format` attr specifies the layout of the input and output tensors +with the following options: + "NHWC": `[ batch, height, width, channels ]` + "NCHW": `[ batch, channels, height, width ]` + "NCHW_VECT_C": + `qint8 [ batch, channels / 4, height, width, 4 ]` + +It is useful to consider the operation as transforming a 6-D Tensor. +e.g. 
for data_format = NHWC, + Each element in the input tensor can be specified via 6 coordinates, + ordered by decreasing memory layout significance as: + n,oY,bY,oX,bX,iC (where n=batch index, oX, oY means X or Y coordinates + within the output image, bX, bY means coordinates + within the input block, iC means input channels). + The output would be a transpose to the following layout: + n,oY,oX,bY,bX,iC + +This operation is useful for resizing the activations between convolutions +(but keeping all data), e.g. instead of pooling. It is also useful for training +purely convolutional models. + +For example, given an input of shape `[1, 2, 2, 1]`, data_format = "NHWC" and +block_size = 2: + +``` +x = [[[[1], [2]], + [[3], [4]]]] +``` + +This operation will output a tensor of shape `[1, 1, 1, 4]`: + +``` +[[[[1, 2, 3, 4]]]] +``` + +Here, the input has a batch of 1 and each batch element has shape `[2, 2, 1]`, +the corresponding output will have a single element (i.e. width and height are +both 1) and will have a depth of 4 channels (1 * block_size * block_size). +The output element shape is `[1, 1, 4]`. + +For an input tensor with larger depth, here of shape `[1, 2, 2, 3]`, e.g. + +``` +x = [[[[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]]]] +``` + +This operation, for block_size of 2, will return the following tensor of shape +`[1, 1, 1, 12]` + +``` +[[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]] +``` + +Similarly, for the following input of shape `[1 4 4 1]`, and a block size of 2: + +``` +x = [[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]] +``` + +the operator will return the following tensor of shape `[1 2 2 4]`: + +``` +x = [[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]] +``` + }]; + + let arguments = (ins + TF_Tensor:$input, + + Confined]>:$block_size, + DefaultValuedAttr, "NHWC">:$data_format + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_SparseToDenseOp : TF_Op<"SparseToDense", [NoSideEffect]> { + let summary = "Converts a sparse representation into a dense tensor."; + + let description = [{ +Builds an array `dense` with shape `output_shape` such that + +``` +# If sparse_indices is scalar +dense[i] = (i == sparse_indices ? sparse_values : default_value) + +# If sparse_indices is a vector, then for each i +dense[sparse_indices[i]] = sparse_values[i] + +# If sparse_indices is an n by d matrix, then for each i in [0, n) +dense[sparse_indices[i][0], ..., sparse_indices[i][d-1]] = sparse_values[i] +``` + +All other values in `dense` are set to `default_value`. If `sparse_values` is a +scalar, all sparse indices are set to this single value. + +Indices should be sorted in lexicographic order, and indices must not +contain any repeats. If `validate_indices` is true, these properties +are checked during execution. 
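+
+A minimal sketch, assuming the raw op binding `tf.raw_ops.SparseToDense` is
+available:
+
+  ```python
+  import tensorflow as tf
+  # Illustrative sketch: place 7 at index 0 and 9 at index 2 of a length-4
+  # vector, filling the rest with the default value.
+  tf.raw_ops.SparseToDense(sparse_indices=[[0], [2]], output_shape=[4],
+                           sparse_values=[7, 9], default_value=0)
+  # ==> [7, 0, 9, 0]
+  ```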
+ }]; + + let arguments = (ins + TF_I32OrI64Tensor:$sparse_indices, + TF_I32OrI64Tensor:$output_shape, + TF_Tensor:$sparse_values, + TF_Tensor:$default_value, + + DefaultValuedAttr:$validate_indices + ); + + let results = (outs + TF_Tensor:$dense + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<2>; +} + def TF_SplitOp : TF_Op<"Split", [NoSideEffect]> { let summary = "Splits a tensor into `num_split` tensors along one dimension."; @@ -2910,6 +3454,42 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Stops gradient computation."; + + let description = [{ +When executed in a graph, this op outputs its input tensor as-is. + +When building ops to compute gradients, this op prevents the contribution of +its inputs to be taken into account. Normally, the gradient generator adds ops +to a graph to compute the derivatives of a specified 'loss' by recursively +finding out inputs that contributed to its computation. If you insert this op +in the graph it inputs are masked from the gradient generator. They are not +taken into account for computing gradients. + +This is useful any time you want to compute a value with TensorFlow but need +to pretend that the value was a constant. Some examples include: + +* The *EM* algorithm where the *M-step* should not involve backpropagation + through the output of the *E-step*. +* Contrastive divergence training of Boltzmann machines where, when + differentiating the energy function, the training must not backpropagate + through the graph that generated the samples from the model. +* Adversarial training, where no backprop should happen through the adversarial + example generation process. + }]; + + let arguments = (ins + TF_Tensor:$input + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_StridedSliceOp : TF_Op<"StridedSlice", [NoSideEffect]> { let summary = "Return a strided slice from `input`."; @@ -3143,6 +3723,31 @@ def TF_TensorListGetItemOp : TF_Op<"TensorListGetItem", [NoSideEffect]> { TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>; } +def TF_TensorListPushBackOp : TF_Op<"TensorListPushBack", [NoSideEffect]> { + let summary = [{ +Returns a list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`. + }]; + + let description = [{ +tensor: The tensor to put on the list. +input_handle: The old list. +output_handle: A list with the elements of the old list followed by tensor. +element_dtype: the type of elements in the list. +element_shape: a shape compatible with that of elements in the list. + }]; + + let arguments = (ins + TF_VariantTensor:$input_handle, + TF_Tensor:$tensor + ); + + let results = (outs + TF_VariantTensor:$output_handle + ); + + TF_DerivedOperandTypeAttr element_dtype = TF_DerivedOperandTypeAttr<1>; +} + def TF_TensorListSetItemOp : TF_Op<"TensorListSetItem", [NoSideEffect]> { let summary = ""; @@ -3187,6 +3792,30 @@ num_elements: optional. If not -1, the number of elements in the list. 
TF_DerivedResultTypeAttr element_dtype = TF_DerivedResultTypeAttr<0>; } +def TF_TileOp : TF_Op<"Tile", [NoSideEffect]> { + let summary = "Constructs a tensor by tiling a given tensor."; + + let description = [{ +This operation creates a new tensor by replicating `input` `multiples` times. +The output tensor's i'th dimension has `input.dims(i) * multiples[i]` elements, +and the values of `input` are replicated `multiples[i]` times along the 'i'th +dimension. For example, tiling `[a b c d]` by `[2]` produces +`[a b c d a b c d]`. + }]; + + let arguments = (ins + TF_Tensor:$input, + TF_I32OrI64Tensor:$multiples + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tmultiples = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> { let summary = [{ Finds values and indices of the `k` largest elements for the last dimension. @@ -3346,6 +3975,82 @@ This is the opposite of `pack`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_WhereOp : TF_Op<"Where", [NoSideEffect]> { + let summary = "Returns locations of nonzero / true values in a tensor."; + + let description = [{ +This operation returns the coordinates of true elements in `condition`. The +coordinates are returned in a 2-D tensor where the first dimension (rows) +represents the number of true elements, and the second dimension (columns) +represents the coordinates of the true elements. Keep in mind, the shape of +the output tensor can vary depending on how many true values there are in +`condition`. Indices are output in row-major order. + +For example: + +``` +# 'input' tensor is [[True, False] +# [True, False]] +# 'input' has two true values, so output has two coordinates. +# 'input' has rank of 2, so coordinates have two indices. +where(input) ==> [[0, 0], + [1, 0]] + +# `condition` tensor is [[[True, False] +# [True, False]] +# [[False, True] +# [False, True]] +# [[False, False] +# [False, True]]] +# 'input' has 5 true values, so output has 5 coordinates. +# 'input' has rank of 3, so coordinates have three indices. +where(input) ==> [[0, 0, 0], + [0, 1, 0], + [1, 0, 1], + [1, 1, 1], + [2, 1, 1]] + +# `condition` tensor is [[[1.5, 0.0] +# [-0.5, 0.0]] +# [[0.0, 0.25] +# [0.0, 0.75]] +# [[0.0, 0.0] +# [0.0, 0.01]]] +# 'input' has 5 nonzero values, so output has 5 coordinates. +# 'input' has rank of 3, so coordinates have three indices. +where(input) ==> [[0, 0, 0], + [0, 1, 0], + [1, 0, 1], + [1, 1, 1], + [2, 1, 1]] + +# `condition` tensor is [[[1.5 + 0.0j, 0.0 + 0.0j] +# [0.0 + 0.5j, 0.0 + 0.0j]] +# [[0.0 + 0.0j, 0.25 + 1.5j] +# [0.0 + 0.0j, 0.75 + 0.0j]] +# [[0.0 + 0.0j, 0.0 + 0.0j] +# [0.0 + 0.0j, 0.01 + 0.0j]]] +# 'input' has 5 nonzero magnitude values, so output has 5 coordinates. +# 'input' has rank of 3, so coordinates have three indices. 
+where(input) ==> [[0, 0, 0], + [0, 1, 0], + [1, 0, 1], + [1, 1, 1], + [2, 1, 1]] +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$input + ); + + let results = (outs + I64Tensor:$index + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_XdivyOp : TF_Op<"Xdivy", [Broadcastable, NoSideEffect]>, WithBroadcastableBinOpBuilder { let summary = "Returns 0 if x == 0, and x / y otherwise, elementwise."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index f374b6b0b77..080e78042a7 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -16,7 +16,8 @@ limitations under the License. // This is the base operation definition file for TensorFlow. // // This file includes the definition for the TensorFlow dialect, base TensorFlow -// op, and various commonly used TensorFlow types, attributes, and builders. +// op, and various commonly used TensorFlow traits, types, attributes, and +// builders. #ifdef TF_OP_BASE #else @@ -50,6 +51,16 @@ TODO: Make invariants more structured so that we can reference them in ops. let cppNamespace = "TF"; } +//===----------------------------------------------------------------------===// +// TensorFlow traits +//===----------------------------------------------------------------------===// + +// Specify this trait if the op requires all outputs to have the same type and +// the inputs either have the same type as result or a ref type corresponding to +// the result type. +def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait< + "TF::OperandsSameAsResultsTypeOrRef">; + //===----------------------------------------------------------------------===// // TensorFlow op definitions //===----------------------------------------------------------------------===// @@ -65,6 +76,12 @@ class TF_Op traits = []> : def TF_TFDialectType : Type()">, "TensorFlow type">; +// Class for any TensorFlow dialect specific type +class TF_TensorFlowType : + Type()">, + "TensorFlow " # description # " type">, + BuildableType<"getType()">; + // Any tensor element type allowed in TensorFlow ops def TF_ElementType : Type, @@ -80,11 +97,34 @@ def TF_I32Or64 : IntOfWidths<[32, 64]>; def TF_I32OrI64Tensor : TensorOf<[TF_I32Or64]>; -def TF_Int : IntOfWidths<[8, 16, 32, 64]>; +def TF_Uint8 : TF_TensorFlowType<"Uint8", "uint8">; +def TF_Uint16 : TF_TensorFlowType<"Uint16", "uint16">; +def TF_Uint32 : TF_TensorFlowType<"Uint32", "uint32">; +def TF_Uint64 : TF_TensorFlowType<"Uint64", "uint64">; + +// Any unsigned integer type +def TF_UInt : AnyTypeOf<[TF_Uint8, TF_Uint16, TF_Uint32, TF_Uint64]>; + +// Any signed integer type +def TF_SInt : IntOfWidths<[8, 16, 32, 64]>; + +// Any integer type +def TF_Int : AnyTypeOf<[TF_SInt, TF_UInt]>; // Any integer tensor types def TF_IntTensor : TensorOf<[TF_Int]>; +//===----------------------------------------------------------------------===// +// Quantized types +def TF_Qint8 : TF_TensorFlowType<"Qint8", "qint8">; +def TF_Qint16 : TF_TensorFlowType<"Qint16", "qint16">; +def TF_Qint32 : TF_TensorFlowType<"Qint32", "qint32">; +def TF_Quint8 : TF_TensorFlowType<"Quint8", "quint8">; +def TF_Quint16 : TF_TensorFlowType<"Quint16", "quint16">; + +// Any quantized type +def TF_AnyQuantized : AnyTypeOf<[TF_Qint8, TF_Qint16, TF_Qint32, TF_Quint8, + TF_Quint16]>; 
//===----------------------------------------------------------------------===// // Floating-point types @@ -98,12 +138,10 @@ def TF_FpTensor : TensorOf<[AnyFloat]>; //===----------------------------------------------------------------------===// // Complex types -def TF_Complex64 : - Type()">, "complex64 type">; +def TF_Complex64 : TF_TensorFlowType<"Complex64", "complex64">; def TF_Complex64Tensor : TensorOf<[TF_Complex64]>; -def TF_Complex128 : - Type()">, "complex128 type">; +def TF_Complex128 : TF_TensorFlowType<"Complex128", "complex128">; def TF_Complex128Tensor : TensorOf<[TF_Complex128]>; def TF_AnyComplex : AnyTypeOf<[TF_Complex64, TF_Complex128], @@ -114,19 +152,13 @@ def TF_ComplexTensor : TensorOf<[TF_AnyComplex]>; //===----------------------------------------------------------------------===// // String/variant/resource types -def TF_Str : Type()">, - "TensorFlow string type">, - BuildableType<"getType()">; +def TF_Str : TF_TensorFlowType<"String", "string">; def TF_StrTensor : TensorOf<[TF_Str]>; -def TF_Variant : Type()">, - "TensorFlow variant type">, - BuildableType<"getType()">; +def TF_Variant : TF_TensorFlowType<"Variant", "variant">; def TF_VariantTensor : TensorOf<[TF_Variant]>; -def TF_Resource : Type()">, - "TensorFlow variant type">, - BuildableType<"getType()">; +def TF_Resource : TF_TensorFlowType<"Resource", "resource">; def TF_ResourceTensor : TensorOf<[TF_Resource]>; //===----------------------------------------------------------------------===// @@ -141,7 +173,8 @@ def TF_IntOrFpTensor : TensorOf<[TF_Int, AnyFloat]>; def TF_FpOrComplexTensor : TensorOf<[AnyFloat, TF_AnyComplex]>; -def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyComplex], "number">; +def TF_AnyNumber : AnyTypeOf<[TF_Int, AnyFloat, TF_AnyQuantized, TF_AnyComplex], + "number">; def TF_NumberTensor : TensorOf<[TF_AnyNumber]>; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index e39a6768ea4..587849c6a95 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -16,11 +16,16 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include +#include +#include +#include +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Diagnostics.h" // TF:local_config_mlir @@ -29,23 +34,19 @@ limitations under the License. 
#include "mlir/IR/Matchers.h" // TF:local_config_mlir #include "mlir/IR/OpImplementation.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir #include "mlir/Parser.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir #include "mlir/Support/STLExtras.h" // TF:local_config_mlir -#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/platform/logging.h" namespace mlir { namespace TF { -namespace { -#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" -} // namespace - //===----------------------------------------------------------------------===// // TF op helper functions //===----------------------------------------------------------------------===// @@ -75,10 +76,11 @@ static inline bool HasRankAtLeast(Value *value, int64_t rank) { return ranked_type.getRank() >= rank; return type.isa(); } + // Returns true if the given pair of TensorFlow types can be cast to one // another. In other words, a single run-time value is legal for both the types. // For example, tensor<*xf32> and tensor<3xf32> are cast compatible. -bool AreCastCompatible(Type a, Type b) { +static bool AreCastCompatible(Type a, Type b) { if (TensorCastOp::areCastCompatible(a, b)) return true; // Variant types may optionally contain subtypes information that need not @@ -89,13 +91,21 @@ bool AreCastCompatible(Type a, Type b) { getElementTypeOrSelf(b).getKind() == TensorFlowTypes::VARIANT; } +static bool IsUnknownDimOrRank(int64_t dim_or_rank) { + return dim_or_rank == -1; +} + +namespace { +#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" +} // namespace + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -104,7 +114,36 @@ void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// AssertOp +//===----------------------------------------------------------------------===// + +namespace { + +// Removes Assert with constant true predicate. 
+struct AssertWithTrue : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AssertOp op, + PatternRewriter &rewriter) const override { + ElementsAttr cst; + if (matchPattern(op.condition(), m_Constant(&cst))) { + if (cst.getValue({}).getValue()) { + rewriter.replaceOp(op, llvm::None); + return matchSuccess(); + } + } + return matchFailure(); + } +}; +} // namespace + +void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); } //===----------------------------------------------------------------------===// @@ -113,7 +152,7 @@ void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -134,7 +173,7 @@ static LogicalResult Verify(BroadcastToOp op) { void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -143,7 +182,7 @@ void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -199,7 +238,23 @@ void ConstOp::build(Builder *builder, OperationState *result, Type type, void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// EmptyTensorListOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(EmptyTensorListOp op) { + if (!IsOfRankOrUnranked(op.element_shape(), 0) && + !IsOfRankOrUnranked(op.element_shape(), 1)) { + return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); + } + + if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) { + return op.emitOpError("requires max_num_elements operand to be 0D tensor"); + } + return success(); } //===----------------------------------------------------------------------===// @@ -282,45 +337,39 @@ static LogicalResult Verify(FusedBatchNormOp op) { // IfOp //===----------------------------------------------------------------------===// -LogicalResult IfOp::verify() { - auto thenAttr = getAttrOfType("then_branch"); - if (!thenAttr) return emitOpError("requires then_branch attribute"); - - auto elseAttr = getAttrOfType("else_branch"); - if (!elseAttr) return emitOpError("requires else_branch attribute"); - - auto module = getParentOfType(); - auto thenFn = module.lookupSymbol(thenAttr.getValue()); +static LogicalResult Verify(IfOp op) { + auto module = op.getParentOfType(); + auto thenFn = module.lookupSymbol(op.then_branch()); if (!thenFn) - return emitOpError("then_branch refers to an undefined function : ") - << thenAttr; - auto elseFn = module.lookupSymbol(elseAttr.getValue()); + return op.emitOpError("then_branch refers to an undefined function : ") + << op.then_branch(); + auto elseFn = 
module.lookupSymbol(op.else_branch()); if (!elseFn) - return emitOpError("else_branch refers to an undefined function : ") - << elseAttr; + return op.emitOpError("else_branch refers to an undefined function : ") + << op.else_branch(); auto thenFuncType = thenFn.getType(); auto elseFuncType = elseFn.getType(); // Non-conditional operands starting with the second operand are passed to // branches and should be pair-wise compatible with branches' inputs. - unsigned expectedNumInputs = getNumOperands() - 1; + unsigned expectedNumInputs = op.getNumOperands() - 1; if (thenFuncType.getNumInputs() != expectedNumInputs || elseFuncType.getNumInputs() != expectedNumInputs) - return emitError("branches should have " + Twine(expectedNumInputs) + - " inputs"); + return op.emitError("branches should have " + Twine(expectedNumInputs) + + " inputs"); for (unsigned i = 0; i < expectedNumInputs; ++i) { - auto operandType = getOperand(i + 1)->getType().cast(); + auto operandType = op.getOperand(i + 1)->getType().cast(); auto thenInputType = thenFuncType.getInput(i).cast(); if (!AreCastCompatible(operandType, thenInputType)) - return emitError( + return op.emitError( llvm::formatv("then branch input type {0} is incompatible with " "operand type {1} at index {2}", thenInputType, operandType, i)); auto elseInputType = elseFuncType.getInput(i).cast(); if (!AreCastCompatible(operandType, elseInputType)) - return emitError( + return op.emitError( llvm::formatv("else branch input type {0} is incompatible with " "operand type {1} at index {2}", elseInputType, operandType, i)); @@ -328,30 +377,30 @@ LogicalResult IfOp::verify() { // If branches have incompatible input types that means that no tensor can // serve as input to both the functions. Hence, the op is invalid. if (!AreCastCompatible(thenInputType, elseInputType)) - return emitError(llvm::formatv( + return op.emitError(llvm::formatv( "branches inputs have incompatible types {0} and {1} at index {2}", thenInputType, elseInputType, i)); } // Branches' results should be pair-wise compatible with the op results. 
- unsigned expectedNumResults = getNumResults(); + unsigned expectedNumResults = op.getNumResults(); if (thenFuncType.getNumResults() != expectedNumResults || elseFuncType.getNumResults() != expectedNumResults) - return emitError("branches should have " + Twine(expectedNumResults) + - " results"); + return op.emitError("branches should have " + Twine(expectedNumResults) + + " results"); for (unsigned i = 0; i < expectedNumResults; ++i) { - auto resultType = getResult(i)->getType().cast(); + auto resultType = op.getResult(i)->getType().cast(); auto thenResultType = thenFuncType.getResult(i).cast(); if (!AreCastCompatible(thenResultType, resultType)) - return emitError( + return op.emitError( llvm::formatv("then branch result type {0} is incompatible with op " "result type {1} at index {2}", thenResultType, resultType, i)); auto elseResultType = elseFuncType.getResult(i).cast(); if (!AreCastCompatible(elseResultType, resultType)) - return emitError( + return op.emitError( llvm::formatv("else branch result type {0} is incompatible with op " "result type {1} at index {2}", elseResultType, resultType, i)); @@ -365,7 +414,7 @@ LogicalResult IfOp::verify() { void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -399,7 +448,7 @@ OpFoldResult LeakyReluOp::fold(ArrayRef operands) { void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -408,10 +457,9 @@ void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void LogicalNotOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, - context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -420,7 +468,7 @@ void LogicalNotOp::getCanonicalizationPatterns( void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -429,7 +477,7 @@ void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void ReciprocalOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -488,7 +536,7 @@ void RankOp::build(Builder *builder, OperationState *result, Value *input) { void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -499,12 +547,13 @@ void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // m_Constant. 
static LogicalResult Verify(ReshapeOp op) { auto shapeType = op.shape()->getType().cast(); + if (!shapeType.hasRank()) return success(); if (shapeType.getRank() != 1) return op.emitOpError("shape must be 1D tensor"); auto rankByShape = shapeType.getShape()[0]; auto typeOfTensor = op.tensor()->getType().cast(); // No compile time verification for unknown sized shape. - if (rankByShape == -1 || !typeOfTensor.hasRank()) return success(); + if (rankByShape == -1 || !typeOfTensor.hasStaticShape()) return success(); // Check values if constant shape. No compiling time verification for // non-constant shape. auto *shapeOp = op.shape()->getDefiningOp(); @@ -529,7 +578,7 @@ static LogicalResult Verify(ReshapeOp op) { unsigned numByShape = 1; unsigned unknownDimCount = 0; for (int i = 0, e = rankByShape; i != e; ++i) { - auto num = shapeCstAttr.getValue(i).cast().getInt(); + auto num = shapeCstAttr.getValue(i).getInt(); // The dimension size value can be -1, and that the real size needs to // be computed so that the total size remains constant. At most one // component of shape can be -1. @@ -561,53 +610,105 @@ static LogicalResult Verify(ReshapeOp op) { void ReshapeOp::build(Builder *builder, OperationState *result, Value *tensor, Value *shape) { - auto etype = tensor->getType().cast().getElementType(); + auto ttype = tensor->getType().cast(); + auto etype = ttype.getElementType(); + + auto unranked = [builder, etype, result, shape, tensor]() { + return ReshapeOp::build(builder, result, builder->getTensorType(etype), + tensor, shape); + }; + + // If tensor is unranked then we have no info about output of shape. + if (!ttype.hasRank()) return unranked(); + DenseIntElementsAttr attr_shape; if (matchPattern(shape, m_Constant(&attr_shape))) { llvm::SmallVector const_shape; - if (attr_shape.isSplat()) { - const_shape.assign(attr_shape.getType().getNumElements(), - (*attr_shape.begin()).getSExtValue()); - } else { - const_shape.reserve(attr_shape.getType().getNumElements()); - for (auto dim : attr_shape) const_shape.push_back(dim.getSExtValue()); + const_shape.reserve(attr_shape.getNumElements()); + + // Detect if reshape output shape is folded. + bool flatten = false; + int unknown_index = -1; + // The product of constant shape argument excluding unknown dimension. + int64_t product_cshape = 1; + for (auto e : llvm::enumerate(attr_shape)) { + int64_t val = e.value().getSExtValue(); + if (IsUnknownDimOrRank(val)) { + if (flatten) { + mlir::emitError(result->location) + << "only one unknown dimension allowed"; + return; + } + flatten = true; + unknown_index = e.index(); + } else { + product_cshape *= val; + } + const_shape.push_back(val); + } + + // Compute the value of the uknown dimension. + if (flatten) { + // Compute number of elements in tensor shape. + auto tshape = ttype.getShape(); + int64_t product_tshape = std::accumulate(tshape.begin(), tshape.end(), 1, + std::multiplies()); + // Set the unknown dimension such that total number of elements remain + // constant. + // Note: The case where the ratio is not integral, and so the total size + // of reshape not constant, is checked in verify function. 
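The assignment that follows performs exactly the division described in the comment above: the single `-1` entry is replaced so that the total element count is preserved. A hedged standalone rendering of that arithmetic (plain C++, no MLIR types; `foldReshapeShape` is an invented name):

```cpp
// Sketch: given a fully static input shape and a requested shape with at most
// one -1 entry, replace the -1 with total_elements / product_of_known_dims,
// mirroring the computation in ReshapeOp::build. The non-integral case is left
// to the verifier, as the original comment notes.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

std::vector<int64_t> foldReshapeShape(const std::vector<int64_t> &input_shape,
                                      std::vector<int64_t> requested) {
  int64_t total = std::accumulate(input_shape.begin(), input_shape.end(),
                                  int64_t{1}, std::multiplies<int64_t>());
  int64_t known_product = 1;
  int unknown_index = -1;
  for (int i = 0, e = static_cast<int>(requested.size()); i != e; ++i) {
    if (requested[i] == -1)
      unknown_index = i;  // at most one -1 is allowed
    else
      known_product *= requested[i];
  }
  if (unknown_index >= 0) requested[unknown_index] = total / known_product;
  return requested;
}

int main() {
  // Reshaping a 2x3x4 tensor to {4, -1} yields {4, 6}.
  for (int64_t d : foldReshapeShape({2, 3, 4}, {4, -1})) std::cout << d << " ";
  std::cout << "\n";
  return 0;
}
```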
+ const_shape[unknown_index] = product_tshape / product_cshape; } return ReshapeOp::build(builder, result, builder->getTensorType(const_shape, etype), tensor, shape); } - return ReshapeOp::build(builder, result, builder->getTensorType(etype), - tensor, shape); + return unranked(); } //===----------------------------------------------------------------------===// // ShapeOp //===----------------------------------------------------------------------===// -static LogicalResult Verify(ShapeOp op) { - auto inputType = op.input()->getType(); - auto resultType = op.getType().dyn_cast(); - if (!resultType || resultType.getShape().size() != 1) - return op.emitOpError("requires 1D result type"); +namespace { +// Validates Shape/ShapeN operand and associated result types. +LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, + Type result_type, + int variadic_idx = -1) { + std::string variadic_idx_str = + variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str(); - auto rankedTensorType = inputType.dyn_cast(); - if (rankedTensorType) { + auto result_ranked_type = result_type.dyn_cast(); + if (!result_ranked_type || result_ranked_type.getShape().size() != 1) + return op->emitOpError("requires 1D type for result") << variadic_idx_str; + + auto operand_ranked_type = operand_type.dyn_cast(); + if (operand_ranked_type) { // The operand is a ranked tensor. - if (resultType.hasStaticShape()) { - if ((!rankedTensorType.getShape().empty() && - resultType.getDimSize(0) != rankedTensorType.getShape().size())) - return op.emitOpError( - "requires dimension size of result to match rank of operand"); - } - } else { + if (result_ranked_type.hasStaticShape() && + !operand_ranked_type.getShape().empty() && + result_ranked_type.getDimSize(0) != + operand_ranked_type.getShape().size()) + return op->emitOpError("requires dimension size of result") + << variadic_idx_str << " to match rank of operand" + << variadic_idx_str; + } else if (result_ranked_type.hasStaticShape()) { // The operand is an unranked tensor, verify that the result is dynamic. 
- if (resultType.hasStaticShape()) - return op.emitOpError("requires dynamic shape result for unranked input"); + return op->emitOpError("requires dynamic shape result") + << variadic_idx_str << " for unranked operand" << variadic_idx_str; } - Type elt = op.getType().cast().getElementType(); - if (elt.isInteger(32) || elt.isInteger(64)) return success(); - return op.emitOpError("requires int32 or int64 return type"); + Type element_type = result_ranked_type.getElementType(); + if (!element_type.isInteger(32) && !element_type.isInteger(64)) + return op->emitOpError("requires int32 or int64 return type for result") + << variadic_idx_str; + + return success(); +} +} // anonymous namespace + +static LogicalResult Verify(ShapeOp op) { + return VerifyShapeOperandAndResult(op, op.input()->getType(), op.getType()); } OpFoldResult ShapeOp::fold(ArrayRef operands) { @@ -630,6 +731,30 @@ OpFoldResult ShapeOp::fold(ArrayRef operands) { return b.getDenseElementsAttr(resultType, dimensions); } +//===----------------------------------------------------------------------===// +// ShapeNOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ShapeNOp op) { + const uint64_t n_attr = op.N().getZExtValue(); + + if (op.getNumOperands() != n_attr) + return op.emitOpError() << "requires " << n_attr << " operand(s), got " + << op.getNumOperands() << " operand(s)"; + + if (op.getNumResults() != n_attr) + return op.emitOpError() << "requires " << n_attr << " result(s), got " + << op.getNumResults() << " result(s)"; + + for (auto i : llvm::seq(0, n_attr)) { + auto verification = VerifyShapeOperandAndResult( + op, op.getOperand(i)->getType(), op.getResult(i)->getType(), i); + if (failed(verification)) return verification; + } + + return success(); +} + //===----------------------------------------------------------------------===// // SoftmaxOp //===----------------------------------------------------------------------===// @@ -647,7 +772,7 @@ static LogicalResult Verify(SoftmaxOp op) { void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -656,7 +781,7 @@ void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results, void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -707,10 +832,10 @@ void TransposeOp::build(Builder *builder, OperationState *result, Value *x, llvm::SmallVector const_shape; if (attr_shape.isSplat()) { const_shape.assign( - attr_shape.getType().getNumElements(), + attr_shape.getNumElements(), x_type.getDimSize((*attr_shape.begin()).getSExtValue())); } else { - const_shape.reserve(attr_shape.getType().getNumElements()); + const_shape.reserve(attr_shape.getNumElements()); for (auto dim : attr_shape) const_shape.push_back(x_type.getDimSize(dim.getSExtValue())); } @@ -727,32 +852,35 @@ void TransposeOp::build(Builder *builder, OperationState *result, Value *x, void TruncateDivOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// // 
WhileOp //===----------------------------------------------------------------------===// -LogicalResult WhileOp::verify() { - auto condAttr = getAttrOfType("cond"); - if (!condAttr) return emitOpError("requires cond attribute"); +static LogicalResult Verify(WhileOp op) { + auto module = op.getParentOfType(); + auto condFn = module.lookupSymbol(op.cond()); + auto bodyFn = module.lookupSymbol(op.body()); + if (!condFn) { + return op.emitOpError("cond refers to an undefined function : ") + << op.cond(); + } + if (!bodyFn) { + return op.emitOpError("body refers to an undefined function : ") + << op.body(); + } - auto module = getParentOfType(); - auto condFn = module.lookupSymbol(condAttr.getValue()); auto condFuncType = condFn.getType(); + auto bodyFuncType = bodyFn.getType(); // Verify that the cond function has exactly one result. if (condFuncType.getNumResults() != 1) - return emitOpError("requires cond function to have exactly one result"); + return op.emitOpError("requires cond function to have exactly one result"); - auto bodyAttr = getAttrOfType("body"); - if (!bodyAttr) return emitOpError("requires body attribute"); - auto bodyFn = module.lookupSymbol(bodyAttr.getValue()); - auto bodyFuncType = bodyFn.getType(); - - SmallVector operands(getOperandTypes()); - SmallVector results(getResultTypes()); + SmallVector operands(op.getOperandTypes()); + SmallVector results(op.getResultTypes()); // Collect all the type lists for the op so that different pairs of type lists // can be compared for the compatibility. @@ -796,7 +924,7 @@ LogicalResult WhileOp::verify() { int aSize = a.second.size(); if (aSize != b.second.size()) - return emitOpError( + return op.emitOpError( llvm::formatv("requires the number of {0}s to be equal to the " "number of {1}s. Found {2} and {3}, respectively", a.first, b.first, aSize, b.second.size())); @@ -806,7 +934,7 @@ LogicalResult WhileOp::verify() { auto bType = b.second[idx]; if (!AreCastCompatible(aType, bType)) - return emitError(llvm::formatv( + return op.emitError(llvm::formatv( "{0} type {1} is incompatible with {2} type {3} at index {4}", a.first, aType, b.first, bType, idx)); } @@ -821,7 +949,7 @@ LogicalResult WhileOp::verify() { void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - RewriteListBuilder::build(results, context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -840,7 +968,7 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context) addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc.inc" - , IfOp, WhileOp>(); + >(); addTypes< #define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type, #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type @@ -954,27 +1082,5 @@ Operation *TensorFlowDialect::materializeConstant(OpBuilder &builder, return nullptr; } -// Verifies that the Op is a well-formed TensorFlow op, checking that all inputs -// and results are Tensor or other TensorFlow types, etc. 
-LogicalResult verifyTensorFlowOp(Operation *op) { - if (op->getName().getDialect() != "tf") - return op->emitError("TensorFlow op ") - << op->getName() << " should start with 'tf.'"; - - for (Type type : op->getOperandTypes()) { - if (!IsValidTFTensorType(type)) - return op->emitOpError( - "requires operands to have a valid TensorFlow tensor type"); - } - - for (Type type : op->getResultTypes()) { - if (!IsValidTFTensorType(type)) - return op->emitOpError( - "requires results to have a valid TensorFlow tensor type"); - } - - return success(); -} - } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index 723aa67c6c4..8a2fa9dd7fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -27,7 +27,8 @@ limitations under the License. #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/OpDefinition.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { @@ -64,20 +65,6 @@ class TensorFlowDialect : public Dialect { Location loc) override; }; -// This verifies that the Op is a well-formed TensorFlow op, checking -// that all inputs and results are Tensor or other TensorFlow types, etc. -static LogicalResult verifyTensorFlowOp(Operation *op); - -// This Trait should be used by all TensorFlow Ops. -// -template -class TensorFlowOp : public OpTrait::TraitBase { - public: - static LogicalResult verifyTrait(Operation *op) { - return verifyTensorFlowOp(op); - } -}; - // TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose // purpose is to catch bug on `tensorflow::mutex_lock`. We don't use // `tensorflow::mutex_lock` here but we have ops (`tf.MutexLock` and @@ -89,88 +76,6 @@ class TensorFlowOp : public OpTrait::TraitBase { #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc" -// The "tf.If" operation takes a condition operand, a list of inputs, and a -// function attribute for the then/else branches. The condition operand -// doesn't have to be a boolean tensor. It is handled according to these -// rules, quoting the TensorFlow op definition: -// -// If the tensor is a scalar of non-boolean type, the scalar is converted to -// a boolean according to the following rule: if the scalar is a numerical -// value, non-zero means True and zero means False; if the scalar is a -// string, non-empty means True and empty means False. If the tensor is not a -// scalar, being empty means False and being non-empty means True. -// -// This is defined in TensorFlow as: -// -// REGISTER_OP("If") -// .Input("cond: Tcond") -// .Input("input: Tin") -// .Output("output: Tout") -// .Attr("Tcond: type") -// .Attr("Tin: list(type) >= 0") -// .Attr("Tout: list(type) >= 0") -// .Attr("then_branch: func") -// .Attr("else_branch: func") -// -// Note: Additional result corresponds to the control output. 
-class IfOp : public Op::Impl, - OpTrait::VariadicResults> { - public: - using Op::Op; - static StringRef getOperationName() { return "tf.If"; } - - Value *getCondition() { return getOperand(0); } - - // TODO(b/132271680): This is not following Google naming style - StringRef getThen() { - return getAttrOfType("then_branch").getValue(); - } - - StringRef getElse() { - return getAttrOfType("else_branch").getValue(); - } - - LogicalResult verify(); -}; - -// The "tf.While" operation takes a list of inputs and function attributes for -// the loop condition and body. Inputs are updated repeatedly by the body -// function while the loop condition with the tensors evaluates to true. The -// condition result doesn't have to be a boolean tensor. It is handled -// according to these rules, quoting the TensorFlow op definition: -// -// If the tensor is a scalar of non-boolean type, the scalar is converted to -// a boolean according to the following rule: if the scalar is a numerical -// value, non-zero means True and zero means False; if the scalar is a -// string, non-empty means True and empty means False. If the tensor is not a -// scalar, being empty means False and being non-empty means True. -// -// This is defined in TensorFlow as: -// -// REGISTER_OP("While") -// .Input("input: T") -// .Output("output: T") -// .Attr("T: list(type) >= 0") -// .Attr("cond: func") -// .Attr("body: func") -// .Attr("output_shapes: list(shape) = []") -// -class WhileOp : public Op { - public: - using Op::Op; - static StringRef getOperationName() { return "tf.While"; } - - StringRef getCond() { - return getAttrOfType("cond").getValue(); - } - StringRef getBody() { - return getAttrOfType("body").getValue(); - } - - LogicalResult verify(); -}; - } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index b2fcb01c2d5..d889a5d038a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -30,6 +30,37 @@ limitations under the License. include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td" +class TF_TensorListInitOp : TF_Op { + let results = (outs + TF_VariantTensor:$handle + ); + + TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>; + + let verifier = [{ + if (handle_dtype().getSubtypes().size() != 1) { + return emitOpError( + "must have exactly one subtype in the result variant type"); + } + + return Verify(*this); + }]; + + DerivedTypeAttr element_dtype = DerivedTypeAttr< + "return getElementTypeOrSelf(element_type());">; + + let extraClassDeclaration = [{ + // Returns type of the TensorList element produced by this op. + TensorType element_type() { return handle_dtype().getSubtypes()[0]; } + + // Returns data type of the result handle. Returned type contains type of + // the TensorList element as a subtype. + VariantType handle_dtype() { + return getElementTypeOrSelf(handle()->getType()).cast(); + } + }]; +} + // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with // its type encoding the tensor's shape and data type. 
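The new `TF_TensorListInitOp` base class above shares one verifier clause across `EmptyTensorList` and `TensorListReserve`: the result variant type must carry exactly one subtype, which then backs the derived `element_dtype` attribute. A toy sketch of that single check, with `VariantHandle` and `verifyListInitResult` as invented stand-ins for the dialect types:

```cpp
// Sketch: a "variant" handle may carry subtypes; the TensorList init ops
// require exactly one subtype, which is then exposed as the TensorList element
// type. These structures are illustrative only, not the dialect's real types.
#include <iostream>
#include <string>
#include <vector>

struct VariantHandle {
  std::vector<std::string> subtypes;  // e.g. {"tensor<2xf32>"}
};

bool verifyListInitResult(const VariantHandle &handle, std::string *error) {
  if (handle.subtypes.size() != 1) {
    *error = "must have exactly one subtype in the result variant type";
    return false;
  }
  return true;
}

int main() {
  std::string error;
  VariantHandle ok{{"tensor<2xf32>"}};
  VariantHandle bad{{}};  // no subtype: the element type would be unknown
  std::cout << verifyListInitResult(ok, &error) << "\n";   // 1
  std::cout << verifyListInitResult(bad, &error) << "\n";  // 0
  std::cout << error << "\n";
  return 0;
}
```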
def TF_ConstOp : TF_Op<"Const", [NoSideEffect]> { @@ -55,12 +86,30 @@ def TF_ConstOp : TF_Op<"Const", [NoSideEffect]> { let hasFolder = 1; } +def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> { + let summary = "Creates and returns an empty tensor list."; + + let description = [{ +All list elements must be tensors of dtype element_dtype and shape compatible +with element_shape. + +handle: an empty tensor list. +element_dtype: the type of elements in the list. +element_shape: a shape compatible with that of elements in the list. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$element_shape, + I32Tensor:$max_num_elements + ); +} + // TODO(fengliuai): The tf.Identity is side-effect free and it doesn't change // the status of the system during the execution. However it shouldn't be folded // in general if it used to serve for caching and some other invariant checks, // so we removed the side-effect free property in the op definition. This is a // hack, and we should fix it if we have a better way to model it. -def TF_IdentityOp : TF_Op<"Identity", [SameOperandsAndResultType]> { +def TF_IdentityOp : TF_Op<"Identity", [TF_OperandsSameAsResultsTypeOrRef]> { let summary = "Identity op"; let description = [{ @@ -78,6 +127,50 @@ Returns a tensor with the same shape and contents as input. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_IfOp : TF_Op<"If", []> { + let summary = "output = cond ? then_branch(input) : else_branch(input)"; + + let description = [{ +output = cond ? then_branch(input) : else_branch(input) + +cond: A Tensor. If the tensor is a scalar of non-boolean type, the + scalar is converted to a boolean according to the + following rule: if the scalar is a numerical value, non-zero means + True and zero means False; if the scalar is a string, non-empty + means True and empty means False. If the tensor is not a scalar, + being empty means False and being non-empty means True. +input: A list of input tensors. +then_branch: A function that takes 'inputs' and returns a list of + tensors, whose types are the same as what else_branch returns. +else_branch: A function that takes 'inputs' and returns a list of + tensors. whose types are the same as what then_branch returns. + }]; + + let arguments = (ins + TF_Tensor:$cond, + Variadic:$input, + + SymbolRefAttr:$then_branch, + SymbolRefAttr:$else_branch, + DefaultValuedAttr:$output_shapes, + + // Used to map StatelessIf and If op defined in TensorFlow to a common op. + BoolAttr:$is_stateless + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; + + let verifier = [{ + return Verify(*this); + }]; +} + def TF_MeanOp : TF_Op<"Mean", [NoSideEffect]> { let summary = "Computes the mean of elements across dimensions of a tensor."; @@ -147,7 +240,53 @@ Inserts a placeholder for a tensor that will be always fed. TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } -def TF_TensorListReserveOp : TF_Op<"TensorListReserve", [NoSideEffect]> { +def TF_WhileOp : TF_Op<"While", []> { + let summary = [{ +output = input; While (Cond(output)) { output = Body(output) } + }]; + + let description = [{ +output = input; While (Cond(output)) { output = Body(output) } + +input: A list of input tensors whose types are T. +output: A list of output tensors whose types are T. 
+cond: A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. +body: A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified + by T. + }]; + + let arguments = (ins + Variadic:$input, + + SymbolRefAttr:$cond, + SymbolRefAttr:$body, + DefaultValuedAttr:$output_shapes, + DefaultValuedAttr:$parallel_iterations, + + // Used to map StatelessWhile and While op defined in TensorFlow to a common + // op. + BoolAttr:$is_stateless + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; + + let verifier = [{ + return Verify(*this); + }]; +} + +def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> { let summary = "List of the given size with empty elements."; let description = [{ @@ -161,35 +300,6 @@ element_dtype: the desired type of elements in the list. TF_I32OrI64Tensor:$element_shape, I32Tensor:$num_elements ); - - let results = (outs - TF_VariantTensor:$handle - ); - - TF_DerivedOperandTypeAttr shape_type = TF_DerivedOperandTypeAttr<0>; - - let verifier = [{ - if (handle_dtype().getSubtypes().size() != 1) { - return emitOpError( - "must have exactly one subtype in the result variant type"); - } - - return Verify(*this); - }]; - - DerivedTypeAttr element_dtype = DerivedTypeAttr< - "return getElementTypeOrSelf(element_type());">; - - let extraClassDeclaration = [{ - // Returns type of the TensorList element produced by this op. - TensorType element_type() { return handle_dtype().getSubtypes()[0]; } - - // Returns data type of the result handle. Returned type contains type of - // the TensorList element as a subtype. - VariantType handle_dtype() { - return getElementTypeOrSelf(handle()->getType()).cast(); - } - }]; } #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h new file mode 100644 index 00000000000..b96026c8189 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -0,0 +1,109 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the op traits used in the MLIR TensorFlow dialect. 
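The trait introduced in this new header accepts operands whose element type is either the result element type or its legacy reference ("ref") counterpart. A small sketch of that pairing using a plain enum; `DType` and `refOf` are illustrative, whereas the real check in `VerifyRefTypeMatch` below switches over MLIR type kinds:

```cpp
// Sketch: the OperandsSameAsResultsTypeOrRef trait accepts an operand whose
// element type is either the result element type itself or the matching *_REF
// type. The pairing is modelled here with a plain enum and a lookup.
#include <iostream>

enum class DType { F32, F32_REF, I32, I32_REF, BOOL, BOOL_REF };

// Returns the ref counterpart of a non-ref dtype.
DType refOf(DType t) {
  switch (t) {
    case DType::F32:  return DType::F32_REF;
    case DType::I32:  return DType::I32_REF;
    case DType::BOOL: return DType::BOOL_REF;
    default:          return t;  // already a ref type
  }
}

// An operand type is acceptable if it equals the result type or its ref variant.
bool operandMatchesResult(DType operand, DType result) {
  return operand == result || operand == refOf(result);
}

int main() {
  std::cout << operandMatchesResult(DType::F32, DType::F32) << "\n";      // 1
  std::cout << operandMatchesResult(DType::F32_REF, DType::F32) << "\n";  // 1
  std::cout << operandMatchesResult(DType::I32, DType::F32) << "\n";      // 0
  return 0;
}
```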
+ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_ + +#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace OpTrait { +namespace TF { + +// Verifies if 'ref_type' is a REF type corresponding to 'type'. +static inline LogicalResult VerifyRefTypeMatch(mlir::Type type, + mlir::Type ref_type) { + auto ref_type_kind = ref_type.getKind(); + switch (type.getKind()) { + case mlir::StandardTypes::F16: + return success(ref_type_kind == mlir::TF::TensorFlowTypes::HALF_REF); + case mlir::StandardTypes::F32: + return success(ref_type_kind == mlir::TF::TensorFlowTypes::FLOAT_REF); + case mlir::StandardTypes::F64: + return success(ref_type_kind == mlir::TF::TensorFlowTypes::DOUBLE_REF); + case mlir::StandardTypes::BF16: + return success(ref_type_kind == mlir::TF::TensorFlowTypes::BFLOAT16_REF); + case mlir::StandardTypes::Integer: { + const auto& itype = type.cast(); + switch (itype.getWidth()) { + case 1: + return success(ref_type_kind == mlir::TF::TensorFlowTypes::BOOL_REF); + case 8: + return success(ref_type_kind == mlir::TF::TensorFlowTypes::INT8_REF); + case 16: + return success(ref_type_kind == mlir::TF::TensorFlowTypes::INT16_REF); + case 32: + return success(ref_type_kind == mlir::TF::TensorFlowTypes::INT32_REF); + case 64: + return success(ref_type_kind == mlir::TF::TensorFlowTypes::INT64_REF); + default: + return failure(); + } + } +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + case mlir::TF::TensorFlowTypes::enumerant: \ + return success(ref_type_kind == mlir::TF::TensorFlowTypes::enumerant##_REF); + +#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) +// NOLINTNEXTLINE +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" + default: + return failure(); + } +} + +// This class provides verification for ops that are known to have the same +// result types and all operands are either of the same type as result or a REF +// type corresponding to the result type. +template +class OperandsSameAsResultsTypeOrRef + : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + LogicalResult shapeMatch = impl::verifySameOperandsAndResultShape(op); + if (failed(shapeMatch)) return shapeMatch; + + auto type = getElementTypeOrSelf(op->getResult(0)->getType()); + + // Verify that the first result type is same as the rest of the results. + // We skip the comparison against itself. 
+ for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) { + resultType = getElementTypeOrSelf(resultType); + if (resultType != type) + return op->emitOpError() << "requires the same type for all results"; + } + + for (auto opType : op->getOperandTypes()) { + opType = getElementTypeOrSelf(opType); + if (opType != type && failed(VerifyRefTypeMatch(type, opType))) { + return op->emitError() << "requires all operands to be either same " + "as or ref type of results"; + } + } + return success(); + } +}; + +} // namespace TF +} // namespace OpTrait +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_TRAITS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def index 9f1154b84f1..e5041d0ab99 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.def @@ -32,28 +32,33 @@ HANDLE_TF_TYPE(String, STRING, "string") HANDLE_TF_TYPE(Resource, RESOURCE, "resource") HANDLE_TF_TYPE(Complex64, COMPLEX64, "complex64") HANDLE_TF_TYPE(Complex128, COMPLEX128, "complex128") -HANDLE_TF_TYPE(FloatRef, FLOAT_REF, "f32ref") -HANDLE_TF_TYPE(DoubleRef, DOUBLE_REF, "f64ref") -HANDLE_TF_TYPE(Uint8Ref, UINT8_REF, "uint8ref") -HANDLE_TF_TYPE(Int8Ref, INT8_REF, "int8ref") -HANDLE_TF_TYPE(Uint16Ref, UINT16_REF, "uint16ref") -HANDLE_TF_TYPE(Int16Ref, INT16_REF, "int16ref") -HANDLE_TF_TYPE(Uint32Ref, UINT32_REF, "uint32ref") -HANDLE_TF_TYPE(Int32Ref, INT32_REF, "int32ref") -HANDLE_TF_TYPE(Uint64Ref, UINT64_REF, "uint64ref") -HANDLE_TF_TYPE(Int64Ref, INT64_REF, "int64ref") -HANDLE_TF_TYPE(StringRef, STRING_REF, "stringref") -HANDLE_TF_TYPE(BoolRef, BOOL_REF, "boolref") -HANDLE_TF_TYPE(Quint8Ref, QUINT8_REF, "quint8ref") -HANDLE_TF_TYPE(Qint8Ref, QINT8_REF, "qint8ref") -HANDLE_TF_TYPE(Quint16Ref, QUINT16_REF, "quint16ref") -HANDLE_TF_TYPE(Qint16Ref, QINT16_REF, "qint16ref") -HANDLE_TF_TYPE(Qint32Ref, QINT32_REF, "qint32ref") -HANDLE_TF_TYPE(Bfloat16Ref, BFLOAT16_REF, "bfloat16ref") -HANDLE_TF_TYPE(Complex64Ref, COMPLEX64_REF, "complex64ref") -HANDLE_TF_TYPE(Complex128Ref, COMPLEX128_REF, "complex128ref") -HANDLE_TF_TYPE(HalfRef, HALF_REF, "halfref") -HANDLE_TF_TYPE(ResourceRef, RESOURCE_REF, "resourceref") + +#ifndef HANDLE_TF_REF_TYPE +#define HANDLE_TF_REF_TYPE(class, enumerant, name) \ + HANDLE_TF_TYPE(class, enumerant, name) +#endif +HANDLE_TF_REF_TYPE(FloatRef, FLOAT_REF, "f32ref") +HANDLE_TF_REF_TYPE(DoubleRef, DOUBLE_REF, "f64ref") +HANDLE_TF_REF_TYPE(Uint8Ref, UINT8_REF, "uint8ref") +HANDLE_TF_REF_TYPE(Int8Ref, INT8_REF, "int8ref") +HANDLE_TF_REF_TYPE(Uint16Ref, UINT16_REF, "uint16ref") +HANDLE_TF_REF_TYPE(Int16Ref, INT16_REF, "int16ref") +HANDLE_TF_REF_TYPE(Uint32Ref, UINT32_REF, "uint32ref") +HANDLE_TF_REF_TYPE(Int32Ref, INT32_REF, "int32ref") +HANDLE_TF_REF_TYPE(Uint64Ref, UINT64_REF, "uint64ref") +HANDLE_TF_REF_TYPE(Int64Ref, INT64_REF, "int64ref") +HANDLE_TF_REF_TYPE(StringRef, STRING_REF, "stringref") +HANDLE_TF_REF_TYPE(BoolRef, BOOL_REF, "boolref") +HANDLE_TF_REF_TYPE(Quint8Ref, QUINT8_REF, "quint8ref") +HANDLE_TF_REF_TYPE(Qint8Ref, QINT8_REF, "qint8ref") +HANDLE_TF_REF_TYPE(Quint16Ref, QUINT16_REF, "quint16ref") +HANDLE_TF_REF_TYPE(Qint16Ref, QINT16_REF, "qint16ref") +HANDLE_TF_REF_TYPE(Qint32Ref, QINT32_REF, "qint32ref") +HANDLE_TF_REF_TYPE(Bfloat16Ref, BFLOAT16_REF, "bfloat16ref") +HANDLE_TF_REF_TYPE(Complex64Ref, COMPLEX64_REF, "complex64ref") +HANDLE_TF_REF_TYPE(Complex128Ref, COMPLEX128_REF, "complex128ref") +HANDLE_TF_REF_TYPE(HalfRef, 
HALF_REF, "halfref") +HANDLE_TF_REF_TYPE(ResourceRef, RESOURCE_REF, "resourceref") #ifndef HANDLE_CUSTOM_TF_TYPE #define HANDLE_CUSTOM_TF_TYPE(class, enumerant, name) \ @@ -64,10 +69,11 @@ HANDLE_CUSTOM_TF_TYPE(Variant, VARIANT, "variant") #ifndef HANDLE_LAST_TF_TYPE #define HANDLE_LAST_TF_TYPE(class, enumerant, name) \ - HANDLE_TF_TYPE(class, enumerant, name) + HANDLE_TF_REF_TYPE(class, enumerant, name) #endif HANDLE_LAST_TF_TYPE(VariantRef, VARIANT_REF, "variantref") #undef HANDLE_LAST_TF_TYPE +#undef HANDLE_TF_REF_TYPE #undef HANDLE_TF_TYPE #endif diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index ffd6bee1e37..65feaa8b84c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -1,5 +1,21 @@ // RUN: tf-opt %s -canonicalize | FileCheck %s +// CHECK-LABEL: func @tfAssertTrue +func @tfAssertTrue(%arg0: tensor<1x1x6x2xf32>) { + %t = constant dense : tensor + // CHECK-NOT: tf.Assert + "tf.Assert"(%t, %arg0) {summarize = 3} : (tensor, tensor<1x1x6x2xf32>) -> () + return +} + +// CHECK-LABEL: func @tfAssertFalse +func @tfAssertFalse(%arg0: tensor<1x1x6x2xf32>) { + %f = constant dense : tensor + // CHECK: tf.Assert + "tf.Assert"(%f, %arg0) {summarize = 3} : (tensor, tensor<1x1x6x2xf32>) -> () + return +} + // CHECK-LABEL: func @testLeakyRelu func @testLeakyRelu(%arg0 : tensor<16xf32>) -> (tensor<16xf32>) { %2 = "tf.LeakyRelu"(%arg0) {alpha = 1.0 : f32} : (tensor<16xf32>) -> tensor<16xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir new file mode 100644 index 00000000000..9e2fdcc1ee5 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_formation.mlir @@ -0,0 +1,292 @@ +// RUN: tf-opt %s -split-input-file -tf-device-cluster-formation | FileCheck %s + +// Simple case, single device cluster. + +module { + // CHECK-LABEL: func @singlecluster + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) + func @singlecluster(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor) -> tensor + %3 = "tf.B"(%2) {device = "tpu0"} : (tensor) -> tensor + + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[A_OUTPUT]], %[[B_OUTPUT]]) : (tensor, tensor) -> tensor + %4 = "tf.C"(%2, %3) {device = "tpu0"} : (tensor, tensor) -> tensor + + // CHECK: "tf_device.return"(%[[C_OUTPUT]]) + // CHECK: {device = "tpu0"} : () -> tensor + + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[TPU0_OUTPUT]]) + %5 = "tf.D"(%4) : (tensor) -> tensor + tf_executor.yield %5 : tensor + } + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor + } +} + +// ----- + +// Single device cluster, live-in value comes directly from function argument. 
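The cluster-formation tests in this file check that ops inside an island carrying the same `device` attribute are grouped into a single `tf_device.launch`, with ops lacking a device left outside (and, as later cases show, other ops may be moved so a cluster can stay together). Before the live-in test case below, a toy sketch of the grouping step; `Op` and `clusterByDevice` are invented names, not the pass's API:

```cpp
// Sketch: walk an island's ops in order and group runs that share the same
// non-empty device attribute into clusters, which the real pass then wraps in
// tf_device.launch regions. Reordering and live-in handling are omitted.
#include <iostream>
#include <string>
#include <vector>

struct Op {
  std::string name;
  std::string device;  // empty means "no device attribute"
};

std::vector<std::vector<Op>> clusterByDevice(const std::vector<Op> &ops) {
  std::vector<std::vector<Op>> clusters;
  for (const Op &op : ops) {
    if (op.device.empty()) continue;  // stays outside any launch
    if (clusters.empty() || clusters.back().back().device != op.device)
      clusters.push_back({});
    clusters.back().push_back(op);
  }
  return clusters;
}

int main() {
  std::vector<Op> island = {{"tf.A", ""},     {"tf.B", "tpu0"},
                            {"tf.C", "tpu0"}, {"tf.D", ""}};
  for (const auto &cluster : clusterByDevice(island)) {
    for (const Op &op : cluster) std::cout << op.name << " ";
    std::cout << "-> one launch on " << cluster.front().device << "\n";
  }
  return 0;
}
```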
+ +module { + // CHECK-LABEL: func @arglivein + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) + func @arglivein(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + + // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) : (tensor) -> tensor + %3 = "tf.A"(%arg0) {device = "tpu0"} : (tensor) -> tensor + + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]], %[[ARG_0]]) : (tensor, tensor) -> tensor + %4 = "tf.B"(%3, %arg0) {device = "tpu0"} : (tensor, tensor) -> tensor + + // CHECK: "tf_device.return"(%[[B_OUTPUT]]) + // CHECK: {device = "tpu0"} : () -> tensor + + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[TPU0_OUTPUT]]) + %5 = "tf.C"(%4) : (tensor) -> tensor + tf_executor.yield %5 : tensor + } + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor + } +} + +// ----- + +// Single device cluster, live-in value comes from other islands. + +module { + // CHECK-LABEL: func @argliveinotherislands + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) + func @argliveinotherislands(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + // CHECK: %[[OTHER_ISLAND_OUTPUT:[0-9]*]]:2 = tf_executor.island { + %1:2 = tf_executor.island { + %3 = "tf.D"(%arg0) : (tensor) -> tensor + tf_executor.yield %3 : tensor + } + + %2:2 = tf_executor.island { + // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) : (tensor) -> tensor + %3 = "tf.A"(%arg0) {device = "tpu0"} : (tensor) -> tensor + + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]], %[[OTHER_ISLAND_OUTPUT]]#0) : (tensor, tensor) -> tensor + %4 = "tf.B"(%3, %1#0) {device = "tpu0"} : (tensor, tensor) -> tensor + + // CHECK: "tf_device.return"(%[[B_OUTPUT]]) + // CHECK: {device = "tpu0"} : () -> tensor + + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[TPU0_OUTPUT]]) + %5 = "tf.C"(%4) : (tensor) -> tensor + tf_executor.yield %5 : tensor + } + + tf_executor.fetch %2#0 : tensor + } + return %0 : tensor + } +} + +// ----- + +// Single device cluster, no live-in values. + +module { + // CHECK-LABEL: func @nolivein + func @nolivein() -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + + // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"() : () -> tensor + %3 = "tf.A"() {device = "tpu0"} : () -> tensor + + // CHECK: "tf_device.return"(%[[A_OUTPUT]]) + // CHECK: {device = "tpu0"} : () -> tensor + + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_OUTPUT]]) + %4 = "tf.B"(%3) : (tensor) -> tensor + tf_executor.yield %4 : tensor + } + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor + } +} + +// ----- + +// Multiple clusters of different devices. Clusters depend on each other. 
+ +module { + // CHECK-LABEL: func @multiplerelatedclusters + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) + func @multiplerelatedclusters(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor) -> tensor + %3 = "tf.B"(%2) {device = "tpu0"} : (tensor) -> tensor + + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[A_OUTPUT]], %[[B_OUTPUT]]) : (tensor, tensor) -> tensor + %4 = "tf.C"(%2, %3) {device = "tpu0"} : (tensor, tensor) -> tensor + + // CHECK: "tf_device.return"(%[[C_OUTPUT]]) + // CHECK: {device = "tpu0"} : () -> tensor + + // CHECK: %[[GPU0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[TPU0_OUTPUT]]) : (tensor) -> tensor + %5 = "tf.D"(%4) {device = "gpu0"} : (tensor) -> tensor + // CHECK: "tf_device.return"(%[[D_OUTPUT]]) + + // CHECK: tf_executor.yield %[[GPU0_OUTPUT]] + tf_executor.yield %5 : tensor + } + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor + } +} + +// ----- + +// Multiple clusters of different devices. Clusters do not depend on each other. + +module { + // CHECK-LABEL: func @multipleunrelatedclusters + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) + func @multipleunrelatedclusters(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor) -> tensor + %3 = "tf.B"(%2) {device = "tpu0"} : (tensor) -> tensor + + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[A_OUTPUT]], %[[B_OUTPUT]]) : (tensor, tensor) -> tensor + %4 = "tf.C"(%2, %3) {device = "tpu0"} : (tensor, tensor) -> tensor + + // CHECK: "tf_device.return"(%[[C_OUTPUT]]) + // CHECK: {device = "tpu0"} : () -> tensor + + // CHECK: %[[GPU0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[A_OUTPUT]]) : (tensor) -> tensor + %5 = "tf.D"(%2) {device = "gpu0"} : (tensor) -> tensor + // CHECK: "tf_device.return"(%[[D_OUTPUT]]) + + // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[TPU0_OUTPUT]], %[[GPU0_OUTPUT]]) : (tensor, tensor) -> tensor + %6 = "tf.E"(%4, %5) : (tensor, tensor) -> tensor + + // CHECK: tf_executor.yield %[[E_OUTPUT]] + tf_executor.yield %6 : tensor + } + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor + } +} + +// ----- + +// Single device with non-continous instructions in original block. + +module { + // CHECK-LABEL: func @noncontinoussinglecluster + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) + func @noncontinoussinglecluster(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor + + // Note that tf.C is moved before tf_device.launch. 
+ // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[ARG_0]]) : (tensor) -> tensor + + // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor) -> tensor + %3 = "tf.B"(%2) {device = "tpu0"} : (tensor) -> tensor + + %4 = "tf.C"(%arg0) : (tensor) -> tensor + + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[A_OUTPUT]], %[[B_OUTPUT]]) : (tensor, tensor) -> tensor + %5 = "tf.D"(%2, %3) {device = "tpu0"} : (tensor, tensor) -> tensor + + // CHECK: "tf_device.return"(%[[D_OUTPUT]]) + // CHECK: {device = "tpu0"} : () -> tensor + + // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[C_OUTPUT]], %[[TPU0_OUTPUT]]) : (tensor, tensor) -> tensor + %6 = "tf.E"(%4, %5) : (tensor, tensor) -> tensor + + // CHECK: tf_executor.yield %[[E_OUTPUT]] + tf_executor.yield %6 : tensor + } + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor + } +} + +// ----- + +// Multiple device clusters with intertwined instructions in original block. + +module { + // CHECK-LABEL: func @intertwinedclusters + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) + func @intertwinedclusters(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[GPU0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[ARG_0]]) : (tensor) -> tensor + // CHECK: "tf_device.return"(%[[C_OUTPUT]]) + // CHECK: {device = "gpu0"} : () -> tensor + + // CHECK: %[[TPU0_OUTPUT:[0-9]*]] = "tf_device.launch" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[A_OUTPUT]]) : (tensor) -> tensor + %3 = "tf.B"(%2) {device = "tpu0"} : (tensor) -> tensor + + %4 = "tf.C"(%arg0) {device = "gpu0"} : (tensor) -> tensor + + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[A_OUTPUT]], %[[B_OUTPUT]]) : (tensor, tensor) -> tensor + %5 = "tf.D"(%2, %3) {device = "tpu0"} : (tensor, tensor) -> tensor + + // CHECK: "tf_device.return"(%[[D_OUTPUT]]) + // CHECK: {device = "tpu0"} : () -> tensor + + // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[GPU0_OUTPUT]], %[[TPU0_OUTPUT]]) : (tensor, tensor) -> tensor + %6 = "tf.E"(%4, %5) : (tensor, tensor) -> tensor + + // CHECK: tf_executor.yield %[[E_OUTPUT]] + tf_executor.yield %6 : tensor + } + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor + } +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir new file mode 100644 index 00000000000..f8797678231 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/cluster_outlining.mlir @@ -0,0 +1,112 @@ +// RUN: tf-opt %s -split-input-file -tf-device-cluster-outlining | FileCheck %s + +// Tests simple case of a single `tf_device.launch`. 
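The outlining tests that follow check that each `tf_device.launch` body is moved into a fresh function and the launch is replaced by a `tf_device.launch_func` whose operands are the values the body used from outside (its live-ins). A toy sketch of that live-in bookkeeping, with `ToyOp`, `OutlinedFunc`, and `outline` as invented names rather than the pass's API:

```cpp
// Sketch: outline a launch "region" (a list of ops) into a standalone function.
// Values used inside the region but defined outside become the new function's
// arguments and the operands of the replacement launch_func.
#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct ToyOp {
  std::string name;
  std::vector<std::string> operands;  // SSA value names
  std::string result;
};

struct OutlinedFunc {
  std::string name;
  std::vector<std::string> args;  // live-in values, in first-use order
  std::vector<ToyOp> body;
};

OutlinedFunc outline(const std::string &func_name,
                     const std::vector<ToyOp> &region) {
  std::set<std::string> defined_inside;
  for (const ToyOp &op : region) defined_inside.insert(op.result);

  OutlinedFunc fn{func_name, {}, region};
  for (const ToyOp &op : region)
    for (const std::string &v : op.operands)
      if (!defined_inside.count(v) &&
          std::find(fn.args.begin(), fn.args.end(), v) == fn.args.end())
        fn.args.push_back(v);  // live-in: becomes an argument / launch_func operand
  return fn;
}

int main() {
  // Body of a launch: %b = tf.B(%a); %c = tf.C(%a, %b). The only live-in is %a.
  std::vector<ToyOp> region = {{"tf.B", {"%a"}, "%b"},
                               {"tf.C", {"%a", "%b"}, "%c"}};
  OutlinedFunc fn = outline("tpu0_func", region);
  std::cout << "func @" << fn.name << "(";
  for (const std::string &arg : fn.args) std::cout << arg << " ";
  std::cout << ")\n";
  return 0;
}
```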
+ +module { + // CHECK-LABEL: func @multiplelaunches + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) + func @multiplelaunches(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) {device = "tpu0", func = @tpu0_func} + %3 = "tf_device.launch"() ( { + %4 = "tf.B"(%2) : (tensor) -> tensor + "tf_device.return"(%4) : (tensor) -> () + }) {device = "tpu0"} : () -> tensor + + // CHECK: tf_executor.yield %[[C_OUTPUT]] + tf_executor.yield %3 : tensor + } + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor + } + +// CHECK-LABEL: func @tpu0_func +// CHECK-SAME: (%[[TPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor) -> tensor +// CHECK: %[[TPU0_FUNC_B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_FUNC_ARG_0]]) +// CHECK: return %[[TPU0_FUNC_B_OUTPUT]] +} + +// ----- + +// Tests that multiple `tf_device.launch` that depend on each other are +// correctly handled. + +module { + // CHECK-LABEL: func @multiplelaunches + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) + func @multiplelaunches(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"(%[[ARG_0]]) + %2 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[A_OUTPUT]]) {device = "tpu0", func = @tpu0_func} + %3 = "tf_device.launch"() ( { + %6 = "tf.B"(%2) : (tensor) -> tensor + "tf_device.return"(%6) : (tensor) -> () + }) {device = "tpu0"} : () -> tensor + + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[C_OUTPUT]]) + %4 = "tf.D"(%3) : (tensor) -> tensor + + // CHECK: %[[E_OUTPUT:[0-9]*]] = "tf_device.launch_func"(%[[C_OUTPUT]], %[[D_OUTPUT]]) {device = "gpu0", func = @gpu0_func} + %5 = "tf_device.launch"() ( { + %6 = "tf.E"(%3) : (tensor) -> tensor + %7 = "tf.F"(%4, %6) : (tensor, tensor) -> tensor + "tf_device.return"(%7) : (tensor) -> () + }) {device = "gpu0"} : () -> tensor + + // CHECK: tf_executor.yield %[[E_OUTPUT]] + tf_executor.yield %5 : tensor + } + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor + } + +// CHECK-LABEL: func @tpu0_func +// CHECK-SAME: (%[[TPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor) -> tensor +// CHECK: %[[TPU0_FUNC_B_OUTPUT:[0-9]*]] = "tf.B"(%[[TPU0_FUNC_ARG_0]]) +// CHECK: return %[[TPU0_FUNC_B_OUTPUT]] + +// CHECK-LABEL: func @gpu0_func +// CHECK-SAME: (%[[GPU0_FUNC_ARG_0:[a-z0-9]*]]: tensor, %[[GPU0_FUNC_ARG_1:[a-z0-9]*]]: tensor) -> tensor +// CHECK: %[[GPU0_FUNC_E_OUTPUT:[0-9]*]] = "tf.E"(%[[GPU0_FUNC_ARG_0]]) +// CHECK: %[[GPU0_FUNC_F_OUTPUT:[0-9]*]] = "tf.F"(%[[GPU0_FUNC_ARG_1]], %[[GPU0_FUNC_E_OUTPUT]]) +// CHECK: return %[[GPU0_FUNC_F_OUTPUT]] +} + +// ----- + +// Tests outlining launches with no live-in values. 
+ +module { + // CHECK-LABEL: func @multiplelaunches + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) + func @multiplelaunches(%arg0: tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf_device.launch_func"() {device = "tpu0", func = @tpu0_func} + %2 = "tf_device.launch"() ( { + %3 = "tf.A"() : () -> tensor + "tf_device.return"(%3) : (tensor) -> () + }) {device = "tpu0"} : () -> tensor + + // CHECK: tf_executor.yield %[[A_OUTPUT]] + tf_executor.yield %2 : tensor + } + tf_executor.fetch %1#0 : tensor + } + return %0 : tensor + } + +// CHECK-LABEL: func @tpu0_func +// CHECK-SAME: () -> tensor +// CHECK: %[[TPU0_FUNC_A_OUTPUT:[0-9]*]] = "tf.A"() +// CHECK: return %[[TPU0_FUNC_A_OUTPUT]] +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index 51aaf6edad4..115d39d7701 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -42,3 +42,38 @@ func @tfConst() -> (tensor<4xf32>, tensor<1x1x6x2xf32>) { // CHECK-DAG: constant dense<0.242886767> : tensor<1x1x6x2xf32> return %0, %21 : tensor<4xf32>, tensor<1x1x6x2xf32> } + +// CHECK-LABEL: func @testAdd() -> tensor<2x2xi32> +func @testAdd() -> tensor<2x2xi32> { +^bb0: + %0 = constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32> + %1 = constant dense<1> : tensor<2xi32> + %2 = "tf.Add"(%0, %1) {device = "", name = "add"} : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + // CHECK: [[cst:%.*]] = constant dense<{{\[\[}}1, 2], {{\[}}3, 4]]> : tensor<2x2xi32> + // CHECK-NEXT: return [[cst]] : tensor<2x2xi32> + return %2: tensor<2x2xi32> +} + +// Ops with side effects should not get constant folded. +// CHECK-LABEL: func @testSideEffectOp() -> tensor<3xf32> +func @testSideEffectOp() -> tensor<3xf32> { + %0 = constant dense<[3]> : tensor<1xi32> + %1 = "tf.RandomUniform"(%0) {device = "", seed = 3 : i64, seed2 = 5 : i64} : (tensor<1xi32>) -> tensor<3xf32> + // CHECK: %[[random:.*]] = "tf.RandomUniform" + // CHECK: return %[[random]] + return %1: tensor<3xf32> +} + +// Ops with unimplemnted attributes which couldn't be added to the TFE_Op. 
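Taken together, the constant-fold tests show that folding only fires when the op is side-effect free and every operand is already constant: `tf.Add` of two constants folds, while `tf.RandomUniform` is left untouched. A hedged sketch of that gating predicate; `FoldCandidate` and `canFold` are illustrative, not the real folding hook:

```cpp
// Sketch: fold an op to a constant only if it is known to be side-effect free
// and all operands are already constants, which is the behaviour the
// constant-fold tests exercise (tf.Add folds, tf.RandomUniform does not).
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct FoldCandidate {
  std::string op_name;
  std::vector<bool> operand_is_constant;
};

bool canFold(const FoldCandidate &op) {
  // Illustrative deny-list; the real pass consults op side-effect information.
  static const std::set<std::string> kSideEffecting = {"tf.RandomUniform",
                                                       "tf.Assert"};
  if (kSideEffecting.count(op.op_name)) return false;
  for (bool is_const : op.operand_is_constant)
    if (!is_const) return false;
  return true;
}

int main() {
  std::cout << canFold({"tf.Add", {true, true}}) << "\n";      // 1: folds
  std::cout << canFold({"tf.RandomUniform", {true}}) << "\n";  // 0: side effects
  std::cout << canFold({"tf.Add", {true, false}}) << "\n";     // 0: non-constant input
  return 0;
}
```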
+// CHECK-LABEL: func @testUnimplementedOp() -> (tensor, tensor) +func @testUnimplementedOp() -> (tensor, tensor) { + %0 = constant dense<1> : tensor + %1 = constant dense<2> : tensor + %2 = "tf.Maximum"(%0, %1) {_output_shapes = ["tfshape$"]} : (tensor, tensor) -> tensor + %3 = "tf.Minimum"(%0, %1) {random_attr = "hello"} : (tensor, tensor) -> tensor + return %2, %3: tensor, tensor + +// CHECK-NEXT: %[[CST:.*]] = constant +// CHECK-NEXT: %[[CST1:.*]] = constant +// CHECK-NEXT: return %[[CST]], %[[CST1]] +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir b/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir index b1a9dd71fc7..48f4c8f77df 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/control_to_executor_dialect.mlir @@ -79,7 +79,7 @@ func @LoopTest() { // CHECK-NEXT: %{{[0-9]*}} = "tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor) -> tensor<*xi32> // CHECK-NEXT: tf_executor.yield %{{[0-9]*}} : tensor<*xi32> // CHECK-NEXT: } -// CHECK-NEXT: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %2, %12#1, %9#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} +// CHECK-NEXT: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[NOOP]], %[[ADD]]#1, %[[EXIT]]#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} // CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC]]#1] %[[ADD]]#0, %[[CT]] : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} // CHECK-NEXT: tf_executor.fetch // CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/empty-main.mlir b/tensorflow/compiler/mlir/tensorflow/tests/empty-main.mlir new file mode 100644 index 00000000000..4a4aa277067 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/empty-main.mlir @@ -0,0 +1,15 @@ +// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --check-prefix=CONTROL --dump-input=fail +// RUN: tf-opt -tf-control-to-executor-conversion %s | FileCheck %s --check-prefix=EXECUTOR --dump-input=fail + +// CONTROL-LABEL: func @main +// CONTROL-NEXT: return + +// EXECUTOR-LABEL: func @main +// EXECUTOR-NEXT: tf_executor.graph { +// EXECUTOR-NEXT: tf_executor.fetch +// EXECUTOR-NEXT: } +// EXECUTOR-NEXT: return + +func @main() { + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir new file mode 100644 index 00000000000..5b4e8e16cbb --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_canonicalize.mlir @@ -0,0 +1,349 @@ +// RUN: tf-opt %s -canonicalize | FileCheck %s --dump-input=fail + + +// Test single graph with no outputs and one island is folded away. +// CHECK-LABEL: func @graph_with_no_outputs +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @graph_with_no_outputs(%arg0 : tensor) { + tf_executor.graph { + %1:2 = tf_executor.island { + %3 = "tf.opA"(%arg0) : (tensor) -> tensor + %4 = "tf.opB"(%3) : (tensor) -> tensor + tf_executor.yield %3 : tensor + } + tf_executor.fetch + } + return +} + +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) +// CHECK-NEXT: "tf.opB"(%[[OP_A]]) +// CHECK-NEXT: return + + +// Test single graph with some outputs and one island is folded away. 
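The executor canonicalization cases check that a `tf_executor.graph` folds away when its body is nothing but a single island (or is empty) and the fetch simply forwards values; any other executor op, such as `tf_executor.LoopCond`, keeps the graph intact. A toy summary of that condition, with `GraphSummary` and `canInlineGraph` as invented names:

```cpp
// Sketch: a tf_executor.graph can be inlined into its parent when its body
// holds at most one island and no other executor-specific ops; the island's
// ops are hoisted out and the fetch operands are forwarded directly.
#include <iostream>

struct GraphSummary {
  int num_islands;             // islands directly in the graph body
  int num_other_executor_ops;  // LoopCond, ControlTrigger, NextIteration, ...
};

bool canInlineGraph(const GraphSummary &g) {
  return g.num_islands <= 1 && g.num_other_executor_ops == 0;
}

int main() {
  std::cout << canInlineGraph({1, 0}) << "\n";  // single island: folds away
  std::cout << canInlineGraph({0, 0}) << "\n";  // empty graph: folds away
  std::cout << canInlineGraph({1, 1}) << "\n";  // LoopCond present: kept
  std::cout << canInlineGraph({2, 0}) << "\n";  // multiple islands: kept
  return 0;
}
```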
+// CHECK-LABEL: func @graph_with_outputs +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @graph_with_outputs(%arg0 : tensor) -> (tensor, tensor) { + %0:3 = tf_executor.graph { + %1:4 = tf_executor.island { + %3 = "tf.opA"(%arg0) : (tensor) -> tensor + %4 = "tf.opB"(%3) : (tensor) -> tensor + %5 = "tf.opC"(%4) : (tensor) -> tensor + tf_executor.yield %3, %5, %4 : tensor, tensor, tensor + } + tf_executor.fetch %1#1, %1#0, %1#2 : tensor, tensor, tensor + } + return %0#2, %0#1 : tensor, tensor +} + +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]]) +// CHECK-NEXT: "tf.opC"(%[[OP_B]]) +// CHECK-NEXT: return %[[OP_B]], %[[OP_A]] : tensor, tensor + + +// Test nested graphs and islands. +// CHECK-LABEL: func @nested_graph +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @nested_graph(%arg0 : tensor) -> (tensor, tensor) { + %0:3 = tf_executor.graph { + %1:4 = tf_executor.island { + %2:3 = tf_executor.graph { + %3:4 = tf_executor.island { + %4 = "tf.opA"(%arg0) : (tensor) -> tensor + %5 = "tf.opB"(%4) : (tensor) -> tensor + %6 = "tf.opC"(%5) : (tensor) -> tensor + tf_executor.yield %4, %6, %5 : tensor, tensor, tensor + } + tf_executor.fetch %3#2, %3#0, %3#1 : tensor, tensor, tensor + } + tf_executor.yield %2#1, %2#1, %2#0 : tensor, tensor, tensor + } + tf_executor.fetch %1#1, %1#0, %1#2 : tensor, tensor, tensor + } + return %0#2, %0#1 : tensor, tensor +} + +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]]) +// CHECK-NEXT: "tf.opC"(%[[OP_B]]) +// CHECK-NEXT: return %[[OP_B]], %[[OP_A]] : tensor, tensor + + +// Test single graph with multiple islands is unmodified. +// CHECK-LABEL: func @graph_with_multiple_islands +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @graph_with_multiple_islands(%arg0 : tensor) -> (tensor, tensor) { + %0:3 = tf_executor.graph { + %1:4 = tf_executor.island { + %3 = "tf.opA"(%arg0) : (tensor) -> tensor + %4 = "tf.opB"(%3) : (tensor) -> tensor + %5 = "tf.opC"(%4) : (tensor) -> tensor + tf_executor.yield %3, %5, %4 : tensor, tensor, tensor + } + %6:3 = tf_executor.island { + %7 = "tf.opD"(%arg0) : (tensor) -> tensor + %8 = "tf.opE"(%7) : (tensor) -> tensor + tf_executor.yield %8, %7 : tensor, tensor + } + tf_executor.fetch %1#1, %1#0, %6#0 : tensor, tensor, tensor + } + return %0#2, %0#1 : tensor, tensor +} + +// CHECK-NEXT: %[[GRAPH:[0-9]*]]:3 = tf_executor.graph { +// CHECK-NEXT: %[[ISLAND_0:[0-9]*]]:4 = tf_executor.island { +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]]) +// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_B]]) +// CHECK-NEXT: tf_executor.yield %[[OP_A]], %[[OP_C]], %[[OP_B]] : tensor, tensor, tensor +// CHECK: %[[ISLAND_1:[0-9]*]]:3 = tf_executor.island { +// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]]) +// CHECK-NEXT: tf_executor.yield %[[OP_E]], %[[OP_D]] : tensor, tensor +// CHECK: tf_executor.fetch %[[ISLAND_0]]#1, %[[ISLAND_0]]#0, %[[ISLAND_1]]#0 : tensor, tensor, tensor +// CHECK: return %[[GRAPH]]#2, %[[GRAPH]]#1 : tensor, tensor + + +// Test single graph with an island and executor ops is unmodified. 
+// CHECK-LABEL: func @graph_with_island_and_executor_op +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @graph_with_island_and_executor_op(%arg0 : tensor) -> (tensor, tensor) { + %0:3 = tf_executor.graph { + %1:4 = tf_executor.island { + %3 = "tf.opA"(%arg0) : (tensor) -> tensor + %4 = "tf.opB"(%3) : (tensor) -> tensor + %5 = "tf.opC"(%4) : (tensor) -> tensor + tf_executor.yield %3, %5, %4 : tensor, tensor, tensor + } + %6:2 = tf_executor.LoopCond %1#0 : tensor + tf_executor.fetch %1#1, %1#0, %6#0 : tensor, tensor, tensor + } + return %0#2, %0#1 : tensor, tensor +} + +// CHECK-NEXT: %[[GRAPH:[0-9]*]]:3 = tf_executor.graph { +// CHECK-NEXT: %[[ISLAND:[0-9]*]]:4 = tf_executor.island { +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]]) +// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_B]]) +// CHECK-NEXT: tf_executor.yield %[[OP_A]], %[[OP_C]], %[[OP_B]] : tensor, tensor, tensor +// CHECK: %[[LOOP_COND:[0-9]*]]:2 = tf_executor.LoopCond %[[ISLAND]]#0 +// CHECK-NEXT: tf_executor.fetch %[[ISLAND]]#1, %[[ISLAND]]#0, %[[LOOP_COND]]#0 : tensor, tensor, tensor +// CHECK: return %[[GRAPH]]#2, %[[GRAPH]]#1 : tensor, tensor + + +// Test multiple graphs collapsed. +// CHECK-LABEL: func @multiple_graphs +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @multiple_graphs(%arg0 : tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) { + %0:4 = tf_executor.graph { + %2:4 = tf_executor.island { + %3 = "tf.opA"(%arg0) : (tensor) -> tensor + %4 = "tf.opB"(%3) : (tensor) -> tensor + %5 = "tf.opC"(%4) : (tensor) -> tensor + tf_executor.yield %3, %5, %4 : tensor, tensor, tensor + } + tf_executor.fetch %arg0, %2#0, %2#1, %2#2 : tensor, tensor, tensor, tensor + } + %1:3 = tf_executor.graph { + %6:3 = tf_executor.island { + %7 = "tf.opD"(%arg0) : (tensor) -> tensor + %8 = "tf.opE"(%7) : (tensor) -> tensor + tf_executor.yield %8, %7 : tensor, tensor + } + tf_executor.fetch %arg0, %6#0, %6#1 : tensor, tensor, tensor + } + return %1#1, %1#0, %1#2, %0#1, %0#0, %0#3 : tensor, tensor, tensor, tensor, tensor, tensor +} + +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]]) +// CHECK-NEXT: "tf.opC"(%[[OP_B]]) +// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]]) +// CHECK-NEXT: return %[[OP_E]], %[[ARG_0]], %[[OP_D]], %[[OP_A]], %[[ARG_0]], %[[OP_B]] : tensor, tensor, tensor, tensor, tensor, tensor + + +// Test empty graph with no outputs. +// CHECK-LABEL: func @empty_graph_with_no_outputs +func @empty_graph_with_no_outputs() { + tf_executor.graph { + tf_executor.fetch + } + return +} + +// CHECK-NEXT: return + + +// Test empty graph with some outputs. +// CHECK-LABEL: func @empty_graph_with_outputs +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor, %[[ARG_1:[a-z0-9]*]]: tensor) +func @empty_graph_with_outputs(%arg0 : tensor, %arg1 : tensor) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + tf_executor.fetch %arg1, %arg0 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// CHECK-NEXT: return %[[ARG_1]], %[[ARG_0]] : tensor, tensor + + +// Test multiple empty graphs. 
+// CHECK-LABEL: func @empty_graphs +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor, %[[ARG_1:[a-z0-9]*]]: tensor) +func @empty_graphs(%arg0 : tensor, %arg1 : tensor) -> (tensor, tensor) { + %0 = tf_executor.graph { + tf_executor.fetch %arg1 : tensor + } + tf_executor.graph { + tf_executor.fetch + } + %1 = tf_executor.graph { + tf_executor.fetch %arg0 : tensor + } + return %0, %1 : tensor, tensor +} + +// CHECK-NEXT: return %[[ARG_1]], %[[ARG_0]] : tensor, tensor + + +// Test empty graphs and graphs with a single island. +// CHECK-LABEL: func @empty_and_filled_graphs +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @empty_and_filled_graphs(%arg0 : tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) { + %0:4 = tf_executor.graph { + %2:4 = tf_executor.island { + %3 = "tf.opA"(%arg0) : (tensor) -> tensor + %4 = "tf.opB"(%3) : (tensor) -> tensor + %5 = "tf.opC"(%4) : (tensor) -> tensor + tf_executor.yield %3, %5, %4 : tensor, tensor, tensor + } + tf_executor.fetch %arg0, %2#0, %2#1, %2#2 : tensor, tensor, tensor, tensor + } + tf_executor.graph { + tf_executor.fetch + } + %1:3 = tf_executor.graph { + %6:3 = tf_executor.island { + %7 = "tf.opD"(%arg0) : (tensor) -> tensor + %8 = "tf.opE"(%7) : (tensor) -> tensor + tf_executor.yield %8, %7 : tensor, tensor + } + tf_executor.fetch %arg0, %6#0, %6#1 : tensor, tensor, tensor + } + %9 = tf_executor.graph { + tf_executor.fetch %arg0 : tensor + } + return %1#1, %1#0, %9, %0#1, %0#0, %0#3 : tensor, tensor, tensor, tensor, tensor, tensor +} + +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]]) +// CHECK-NEXT: "tf.opC"(%[[OP_B]]) +// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]]) +// CHECK-NEXT: return %[[OP_E]], %[[ARG_0]], %[[ARG_0]], %[[OP_A]], %[[ARG_0]], %[[OP_B]] : tensor, tensor, tensor, tensor, tensor, tensor + + +// Test single empty island in graph with control output in graph fetch results +// in graph being removed. +// CHECK-LABEL: func @single_empty_island_single_graph_control +func @single_empty_island_single_graph_control() { + tf_executor.graph { + %0 = tf_executor.island { + tf_executor.yield + } + tf_executor.fetch %0 : !tf_executor.control + } + return +} + +// CHECK-NEXT: return + + +// Test empty island with no operands and no data result user is removed. +// Control result users should also have their respective operands removed. +// CHECK-LABEL: func @empty_island_no_operand_no_data_result +func @empty_island_no_operand_no_data_result() { + tf_executor.graph { + %0 = tf_executor.island { + tf_executor.yield + } + %1 = tf_executor.island(%0) { + %3 = "tf.opA"() : () -> tensor + tf_executor.yield + } + %2 = tf_executor.island(%0, %1) { + %4 = "tf.opB"() : () -> tensor + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: %[[ISLAND_0:[0-9]*]] = tf_executor.island { +// CHECK-NEXT: "tf.opA" +// CHECK: tf_executor.island(%[[ISLAND_0]]) { +// CHECK-NEXT: "tf.opB" +// CHECK-NOT: tf_executor.island + + +// Test empty island with one operand and no data results is removed and the +// operand is forwarded to its control result users. 
+// CHECK-LABEL: func @empty_island_one_operand_no_data_result +func @empty_island_one_operand_no_data_result() { + tf_executor.graph { + %0 = tf_executor.island { + %3 = "tf.opA"() : () -> tensor + tf_executor.yield + } + %1 = tf_executor.island(%0) { + tf_executor.yield + } + %2 = tf_executor.island(%1) { + %4 = "tf.opB"() : () -> tensor + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: %[[ISLAND_1:[0-9]*]] = tf_executor.island { +// CHECK-NEXT: "tf.opA" +// CHECK: tf_executor.island(%[[ISLAND_1]]) { +// CHECK-NEXT: "tf.opB" +// CHECK-NOT: tf_executor.island + + +// Test empty island with no operands, one data result and no control result +// users is removed and its data result forwarded to its users. +// CHECK-LABEL: func @empty_island_no_operand_one_data_no_control_result +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @empty_island_no_operand_one_data_no_control_result(%arg0 : tensor) { + tf_executor.graph { + %0:2 = tf_executor.island() { + tf_executor.yield %arg0 : tensor + } + %1 = tf_executor.island { + %3 = "tf.opA"(%0#0) : (tensor) -> tensor + tf_executor.yield + } + %2 = tf_executor.island() { + %4 = "tf.opB"(%0#0) : (tensor) -> tensor + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: tf_executor.island { +// CHECK-NEXT: "tf.opA"(%[[ARG_0]]) +// CHECK: tf_executor.island { +// CHECK-NEXT: "tf.opB"(%[[ARG_0]]) +// CHECK-NOT: tf_executor.island diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir new file mode 100644 index 00000000000..a9e83dd006c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_island_coarsening.mlir @@ -0,0 +1,460 @@ +// RUN: tf-opt %s -tf-executor-island-coarsening | FileCheck %s --dump-input=fail + + +// Test that islands linked by a control dependency are merged. +// CHECK-LABEL: func @control_input +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @control_input(%arg0 : tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + %3 = "tf.opA"(%arg0) : (tensor) -> tensor + tf_executor.yield %3 : tensor + } + %2:2 = tf_executor.island(%1#1) { + %4 = "tf.opB"() : () -> tensor + tf_executor.yield %4 : tensor + } + tf_executor.fetch %2#0 : tensor + } + return %0 : tensor +} + +// CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island { +// CHECK-NEXT: "tf.opA"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB" +// CHECK-NEXT: tf_executor.yield %[[OP_B]] : tensor +// CHECK: tf_executor.fetch %[[ISLAND]]#0 : tensor + + +// Test that islands linked by a data dependency are merged. +// CHECK-LABEL: func @data_input +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @data_input(%arg0 : tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + %3 = "tf.opA"(%arg0) : (tensor) -> tensor + tf_executor.yield %3 : tensor + } + %2:2 = tf_executor.island { + %4 = "tf.opB"(%1#0) : (tensor) -> tensor + tf_executor.yield %4 : tensor + } + tf_executor.fetch %2#0 : tensor + } + return %0 : tensor +} + +// CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island { +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]]) +// CHECK-NEXT: tf_executor.yield %[[OP_B]] : tensor +// CHECK: tf_executor.fetch %[[ISLAND]]#0 : tensor + + +// Test empty/trivial islands are merged. 
+// CHECK-LABEL: func @empty_islands +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor, %[[ARG_1:[a-z0-9]*]]: tensor) +func @empty_islands(%arg0 : tensor, %arg1 : tensor) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + %1:2 = tf_executor.island { + tf_executor.yield %arg1 : tensor + } + %2:2 = tf_executor.island { + tf_executor.yield %arg0 : tensor + } + %3:2 = tf_executor.island { + tf_executor.yield %1#0 : tensor + } + %4:2 = tf_executor.island { + tf_executor.yield %2#0 : tensor + } + %5:3 = tf_executor.island { + %10:2 = "tf.opA"(%3#0, %4#0) : (tensor, tensor) -> (tensor, tensor) + tf_executor.yield %10#0, %10#1 : tensor, tensor + } + %6:2 = tf_executor.island { + tf_executor.yield %5#0 : tensor + } + %7:2 = tf_executor.island { + tf_executor.yield %5#1 : tensor + } + %8:3 = tf_executor.island { + tf_executor.yield %6#0, %7#0 : tensor, tensor + } + %9 = tf_executor.island(%8#2) { + tf_executor.yield + } + tf_executor.fetch %8#0, %8#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// CHECK: %[[ISLAND:[0-9]*]]:3 = tf_executor.island { +// CHECK-NEXT: %[[OP_A:[0-9]*]]:2 = "tf.opA"(%[[ARG_1]], %[[ARG_0]]) +// CHECK-NEXT: tf_executor.yield %[[OP_A]]#0, %[[OP_A]]#1 : tensor, tensor +// CHECK: tf_executor.fetch %[[ISLAND]]#0, %[[ISLAND]]#1 : tensor, tensor + + +// Test merging islands handle merging results. +// CHECK-LABEL: func @multiple_outputs +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor, %[[ARG_1:[a-z0-9]*]]: tensor) +func @multiple_outputs(%arg0 : tensor, %arg1 : tensor) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + %1:2 = tf_executor.island { + %3 = "tf.opA"(%arg0) : (tensor) -> tensor + tf_executor.yield %3 : tensor + } + %2:2 = tf_executor.island(%1#1) { + %4 = "tf.opB"(%arg1) : (tensor) -> tensor + tf_executor.yield %4 : tensor + } + tf_executor.fetch %1#0, %2#0 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// CHECK: %[[ISLAND:[0-9]*]]:3 = tf_executor.island { +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[ARG_1]]) +// CHECK-NEXT: tf_executor.yield %[[OP_A]], %[[OP_B]] : tensor, tensor +// CHECK: tf_executor.fetch %[[ISLAND]]#0, %[[ISLAND]]#1 : tensor, tensor + + +// Test merging islands with multiple inner ops. +// CHECK-LABEL: func @multi_op_regions +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor, %[[ARG_1:[a-z0-9]*]]: tensor) +func @multi_op_regions(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + %2 = "tf.opA"(%arg0, %arg0) : (tensor, tensor) -> tensor + %3 = "tf.opB"(%2, %arg0) : (tensor, tensor) -> tensor + tf_executor.yield %3 : tensor + } + %4:2 = tf_executor.island { + %5 = "tf.opC"(%1#0, %arg1) : (tensor, tensor) -> tensor + %6 = "tf.opD"(%5, %arg0) : (tensor, tensor) -> tensor + tf_executor.yield %6 : tensor + } + tf_executor.fetch %4#0 : tensor + } + return %0 : tensor +} + +// CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island { +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]], %[[ARG_0]]) +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]], %[[ARG_0]]) +// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_B]], %[[ARG_1]]) +// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]], %[[ARG_0]]) +// CHECK-NEXT: tf_executor.yield %[[OP_D]] : tensor +// CHECK: tf_executor.fetch %[[ISLAND]]#0 : tensor + + +// Test merging multiple islands with multiple inner ops preserves order. 
+// CHECK-LABEL: func @transitive_preserve_order +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor, %[[ARG_1:[a-z0-9]*]]: tensor) +func @transitive_preserve_order(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + %2 = "tf.opA"(%arg0, %arg0) : (tensor, tensor) -> tensor + %3 = "tf.opB"(%2, %arg0) : (tensor, tensor) -> tensor + tf_executor.yield %3 : tensor + } + %4:2 = tf_executor.island { + %5 = "tf.opC"(%1#0, %arg1) : (tensor, tensor) -> tensor + %6 = "tf.opD"(%5, %arg0) : (tensor, tensor) -> tensor + tf_executor.yield %6 : tensor + } + %7:2 = tf_executor.island { + %8 = "tf.opE"(%4#0, %1#0) : (tensor, tensor) -> tensor + %9 = "tf.opF"(%8, %8) : (tensor, tensor) -> tensor + tf_executor.yield %9 : tensor + } + tf_executor.fetch %7#0 : tensor + } + return %0 : tensor +} + +// CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island { +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]], %[[ARG_0]]) +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]], %[[ARG_0]]) +// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_B]], %[[ARG_1]]) +// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]], %[[ARG_0]]) +// CHECK-NEXT: %[[OP_E:[0-9]*]] = "tf.opE"(%[[OP_D]], %[[OP_B]]) +// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[OP_E]], %[[OP_E]]) +// CHECK-NEXT: tf_executor.yield %[[OP_F]] : tensor +// CHECK: tf_executor.fetch %[[ISLAND]]#0 : tensor + + +// Test if islands can be merged when non dependent islands are interleaved. +// CHECK-LABEL: func @islands_interleaved +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor, %[[ARG_1:[a-z0-9]*]]: tensor) +func @islands_interleaved(%arg0 : tensor, %arg1 : tensor) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + %1:2 = tf_executor.island { + %7 = "tf.opA"(%arg0) : (tensor) -> tensor + tf_executor.yield %7 : tensor + } + %2:2 = tf_executor.island { + %8 = "tf.opB"(%arg1) : (tensor) -> tensor + tf_executor.yield %8 : tensor + } + %3:2 = tf_executor.island { + %9 = "tf.opC"(%1#0) : (tensor) -> tensor + tf_executor.yield %9 : tensor + } + %4:2 = tf_executor.island { + %10 = "tf.opD"(%2#0) : (tensor) -> tensor + tf_executor.yield %10 : tensor + } + %5:2 = tf_executor.island(%3#1) { + %11 = "tf.opE"(%arg0) : (tensor) -> tensor + tf_executor.yield %11 : tensor + } + %6:2 = tf_executor.island { + %12 = "tf.opF"(%arg1) : (tensor) -> tensor + tf_executor.yield %12 : tensor + } + tf_executor.fetch %4#0, %3#0 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island { +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB"(%[[ARG_1]]) +// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[OP_B]]) +// CHECK-NEXT: tf_executor.yield %[[OP_D]] : tensor +// CHECK: %[[ISLAND_1:[0-9]*]]:2 = tf_executor.island { +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA"(%[[ARG_0]]) +// CHECK-NEXT: %[[OP_C:[0-9]*]] = "tf.opC"(%[[OP_A]]) +// CHECK-NEXT: %{{[0-9]*}} = "tf.opE"(%[[ARG_0]]) +// CHECK-NEXT: tf_executor.yield %[[OP_C]] : tensor +// CHECK: tf_executor.island { +// CHECK-NEXT: %[[OP_F:[0-9]*]] = "tf.opF"(%[[ARG_1]]) +// CHECK-NEXT: tf_executor.yield %[[OP_F]] : tensor +// CHECK: tf_executor.fetch %[[ISLAND_0]]#0, %[[ISLAND_1]]#0 : tensor, tensor + + +// Test only islands are merged when other tf_executor ops are interleaved. 
+// CHECK-LABEL: func @merge_islands_only +func @merge_islands_only() { + tf_executor.graph { + %0:2 = tf_executor.island { + %14 = "tf.opA"() : () -> tensor + tf_executor.yield %14 : tensor + } + %1:2 = tf_executor.Enter %0#0 frame "while/while_context" : (tensor) -> (tensor<*xi32>, !tf_executor.control) + %2 = tf_executor.island { + "tf.opB"() : () -> () + tf_executor.yield + } + %3:3 = tf_executor.NextIteration.Source : tensor<*xi32> + %4:3 = tf_executor.Merge %3#0, %1#0 : tensor<*xi32> + %5:2 = tf_executor.island(%4#2) { + %15 = "tf.opC"() : () -> tensor + tf_executor.yield %15 : tensor + } + %6:2 = tf_executor.island { + %16 = "tf.opD"(%4#0, %5#0) : (tensor<*xi32>, tensor) -> tensor<*xi1> + tf_executor.yield %16 : tensor<*xi1> + } + %7:2 = tf_executor.LoopCond %6#0 : (tensor<*xi1>) -> (tensor, !tf_executor.control) + %8:3 = tf_executor.Switch %4#0, %7#0 : tensor<*xi32> + %9:2 = tf_executor.Exit %8#0 : tensor<*xi32> + %10:2 = tf_executor.island { + %17 = "tf.opE"(%8#1) : (tensor<*xi32>) -> tensor<*xi32> + tf_executor.yield %17 : tensor<*xi32> + } + %11:2 = tf_executor.island(%10#1) { + %18 = "tf.opF"() : () -> tensor + tf_executor.yield %18 : tensor + } + %12:2 = tf_executor.island { + %19 = "tf.opG"(%10#0, %11#0) : (tensor<*xi32>, tensor) -> tensor<*xi32> + tf_executor.yield %19 : tensor<*xi32> + } + %13 = tf_executor.ControlTrigger %2, %12#1, %9#1 + tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> + tf_executor.fetch + } + return +} + +// CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island { +// CHECK-NEXT: %[[OP_A:.*]] = "tf.opA" +// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor +// CHECK: %[[ENTER:[0-9]*]]:2 = tf_executor.Enter %[[ISLAND_0]]#0 +// CHECK-NEXT: %[[ISLAND_1:[0-9]*]] = tf_executor.island { +// CHECK-NEXT: "tf.opB"() +// CHECK-NEXT: tf_executor.yield +// CHECK: %[[NEXTIT_SRC:[0-9]*]]:3 = tf_executor.NextIteration.Source +// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = tf_executor.Merge %[[NEXTIT_SRC]]#0, %[[ENTER]]#0 +// CHECK-NEXT: %[[ISLAND_2:[0-9]*]]:2 = tf_executor.island(%[[MERGE]]#2) { +// CHECK-NEXT: %[[OP_C:.*]] = "tf.opC" +// CHECK-NEXT: %[[OP_D:[0-9]*]] = "tf.opD"(%[[MERGE]]#0, %[[OP_C]]) +// CHECK-NEXT: tf_executor.yield %[[OP_D]] : tensor<*xi1> +// CHECK: %[[COND:[0-9]*]]:2 = tf_executor.LoopCond %[[ISLAND_2:[0-9]*]]#0 +// CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = tf_executor.Switch %[[MERGE]]#0, %[[COND]]#0 +// CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = tf_executor.Exit %[[SWITCH]]#0 +// CHECK-NEXT: %[[ISLAND_3:[0-9]*]]:2 = tf_executor.island { +// CHECK-NEXT: %[[OP_E:[0-9]*]] = "tf.opE"(%[[SWITCH]]#1) +// CHECK-NEXT: %[[OP_F:.*]] = "tf.opF" +// CHECK-NEXT: %[[OP_G:[0-9]*]] = "tf.opG"(%[[OP_E]], %[[OP_F]]) +// CHECK-NEXT: tf_executor.yield %[[OP_G]] : tensor<*xi32> +// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[ISLAND_1]], %[[ISLAND_3]]#1, %[[EXIT]]#1 +// CHECK-NEXT: tf_executor.NextIteration.Sink [%[[NEXTIT_SRC]]#1] %[[ISLAND_3]]#0, %[[CT]] + + +// Test no merging took place as cycle would be formed otherwise. 
+// CHECK-LABEL: func @simple_potential_cycle +func @simple_potential_cycle() { + tf_executor.graph { + %0:2 = tf_executor.island { + %3 = "tf.opA"() : () -> tensor<1xf32> + tf_executor.yield %3 : tensor<1xf32> + } + %1 = tf_executor.ControlTrigger %0#1 + %2:3 = tf_executor.island(%1) { + %4 = "tf.opB"() : () -> tensor<1xf32> + tf_executor.yield %0#0, %4 : tensor<1xf32>, tensor<1xf32> + } + tf_executor.fetch + } + return +} + +// CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island { +// CHECK-NEXT: %[[OP_A:[0-9]*]] = "tf.opA" +// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32> +// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger %[[ISLAND]]#1 +// CHECK-NEXT: tf_executor.island(%[[CT]]) { +// CHECK-NEXT: %[[OP_B:[0-9]*]] = "tf.opB" +// CHECK-NEXT: tf_executor.yield %[[ISLAND]]#0, %[[OP_B]] : tensor<1xf32>, tensor<1xf32> + + +// Test if island was merged into its result. +// CHECK-LABEL: func @merge_into_result +func @merge_into_result() { + tf_executor.graph { + %0:2 = tf_executor.island { + %3 = "tf.opA"() : () -> tensor<1xf32> + tf_executor.yield %3 : tensor<1xf32> + } + %1 = tf_executor.ControlTrigger {} + %2:3 = tf_executor.island(%1) { + %4 = "tf.opB"() : () -> tensor<1xf32> + tf_executor.yield %0#0, %4 : tensor<1xf32>, tensor<1xf32> + } + tf_executor.fetch + } + return +} + +// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger +// CHECK-NEXT: tf_executor.island(%[[CT]]) { +// CHECK-NEXT: "tf.opA" +// CHECK-NEXT: "tf.opB" +// CHECK-NEXT: tf_executor.yield + + +// Test merging island into data result nested in a graph of another island. +// CHECK-LABEL: func @merge_into_nested_data_result +func @merge_into_nested_data_result() { + tf_executor.graph { + %0:2 = tf_executor.island { + %1 = "tf.opA"() : () -> tensor<1xf32> + tf_executor.yield %1 : tensor<1xf32> + } + %2:2 = tf_executor.island { + %3 = tf_executor.graph { + %4 = tf_executor.ControlTrigger {} + %5:2 = tf_executor.island(%4) { + %6 = "tf.opB"(%0#0) : (tensor<1xf32>) -> tensor<1xf32> + tf_executor.yield %6 : tensor<1xf32> + } + tf_executor.fetch %5#0 : tensor<1xf32> + } + tf_executor.yield %3 : tensor<1xf32> + } + tf_executor.fetch + } + return +} + +// CHECK: tf_executor.island { +// CHECK-NEXT: [[OP_A:[0-9*]]] = "tf.opA" +// CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph { +// CHECK-NEXT: [[CT:[0-9]*]] = tf_executor.ControlTrigger +// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island(%[[CT]]) { +// CHECK-NEXT: [[OP_B:[0-9]*]] = "tf.opB"(%[[OP_A]]) +// CHECK-NEXT: tf_executor.yield %[[OP_B]] : tensor<1xf32> +// CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32> +// CHECK: tf_executor.yield + + +// Test merging islands in a nested graph. 
+// CHECK-LABEL: func @merge_islands_inner_graph +func @merge_islands_inner_graph() { + tf_executor.graph { + %0:2 = tf_executor.island { + %1 = "tf.opA"() : () -> tensor<1xf32> + tf_executor.yield %1 : tensor<1xf32> + } + %2:2 = tf_executor.island { + %3 = tf_executor.graph { + %4:2 = tf_executor.island { + %5 = "tf.opB"() : () -> tensor<1xf32> + tf_executor.yield %5 : tensor<1xf32> + } + %6:2 = tf_executor.island { + %7 = "tf.opC"() : () -> tensor<1xf32> + tf_executor.yield %7 : tensor<1xf32> + } + %8:2 = tf_executor.island(%4#1) { + %9 = "tf.opD"(%6#0) : (tensor<1xf32>) -> tensor<1xf32> + tf_executor.yield %9 : tensor<1xf32> + } + tf_executor.fetch %8#0 : tensor<1xf32> + } + tf_executor.yield %3 : tensor<1xf32> + } + tf_executor.fetch + } + return +} + +// CHECK: tf_executor.island { +// CHECK-NEXT: [[OP_A:[0-9*]]] = "tf.opA" +// CHECK-NEXT: tf_executor.yield %[[OP_A]] : tensor<1xf32> +// CHECK: tf_executor.island { +// CHECK-NEXT: [[INNER_GRAPH:[0-9]*]] = tf_executor.graph { +// CHECK-NEXT: [[ISLAND_1:[0-9]*]]:2 = tf_executor.island { +// CHECK-NEXT: "tf.opB" +// CHECK-NEXT: [[OP_C:[0-9]*]] = "tf.opC" +// CHECK-NEXT: [[OP_D:[0-9]*]] = "tf.opD"(%[[OP_C]]) +// CHECK-NEXT: tf_executor.yield %[[OP_D]] : tensor<1xf32> +// CHECK: tf_executor.fetch %[[ISLAND_1]]#0 : tensor<1xf32> +// CHECK: tf_executor.yield %[[INNER_GRAPH]] : tensor<1xf32> + + +// Test merging islands with control island operands and island results only if +// they are the closest ones. +// CHECK-LABEL: func @merge_islands_closest_control +func @merge_islands_closest_control() { + tf_executor.graph { + %0 = tf_executor.island { + tf_executor.yield + } + %1 = tf_executor.ControlTrigger %0 + %2 = tf_executor.ControlTrigger {} + %3 = tf_executor.island(%0, %2) { + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: %[[ISLAND:[0-9]*]] = tf_executor.island { +// CHECK: tf_executor.ControlTrigger %[[ISLAND]] +// CHECK: %[[CT:[0-9]*]] = tf_executor.ControlTrigger +// CHECK: tf_executor.island(%[[ISLAND]], %[[CT]]) { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir new file mode 100644 index 00000000000..11b9b1a564d --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_to_control_dialect.mlir @@ -0,0 +1,99 @@ +// RUN: tf-opt -tf-executor-to-control-conversion %s | FileCheck %s --dump-input=fail + +// CHECK-LABEL: func @LoopTest() { +func @LoopTest() { + tf_executor.graph { + %0:2 = tf_executor.island { + %cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor} : () -> tensor + tf_executor.yield %cst : tensor + } + %1:2 = tf_executor.Enter %0#0 frame "while/while_context" : (tensor) -> (tensor<*xi32>, !tf_executor.control) {T = "tfdtype$DT_INT32", device = "", name = "while/Enter"} + %2 = tf_executor.island { + "tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> () + tf_executor.yield + } + %3:3 = tf_executor.NextIteration.Source : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} + %4:3 = tf_executor.Merge %3#0, %1#0 : tensor<*xi32> {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} + %5:2 = tf_executor.island(%4#2) { + %cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor} : () -> tensor + tf_executor.yield %cst : tensor + } + %6:2 = tf_executor.island { + %14 = "tf.Less"(%4#0, %5#0) {T = 
"tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor) -> tensor<*xi1> + tf_executor.yield %14 : tensor<*xi1> + } + %7:2 = tf_executor.LoopCond %6#0 : (tensor<*xi1>) -> (tensor, !tf_executor.control) {device = "", name = "while/LoopCond"} + %8:3 = tf_executor.Switch %4#0, %7#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"} + %9:2 = tf_executor.Exit %8#0 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} + %10:2 = tf_executor.island { + %14 = "tf.Identity"(%8#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> tensor<*xi32> + tf_executor.yield %14 : tensor<*xi32> + } + %11:2 = tf_executor.island(%10#1) { + %cst = "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor} : () -> tensor + tf_executor.yield %cst : tensor + } + %12:2 = tf_executor.island { + %14 = "tf.Add"(%10#0, %11#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor) -> tensor<*xi32> + tf_executor.yield %14 : tensor<*xi32> + } + %13 = tf_executor.ControlTrigger %2, %12#1, %9#1 {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} + tf_executor.NextIteration.Sink [%3#1] %12#0, %13 : tensor<*xi32> {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} + tf_executor.fetch + } + return +} + +// CHECK-NEXT: %[[CONST:[0-9]*]]:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<1> : tensor} : () -> (tensor, !_tf.control) +// CHECK-NEXT: %[[ENTER:[0-9]*]]:2 = "_tf.Enter"(%[[CONST]]#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while/while_context", is_constant = false, name = "while/Enter", parallel_iterations = 10 : i64} : (tensor) -> (tensor<*xi32>, !_tf.control) +// CHECK-NEXT: %[[NOOP:[0-9]*]] = "_tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> !_tf.control +// CHECK-NEXT: %[[SOURCE:[0-9]*]]:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : () -> (tensor<*xi32>, !_tf.control) +// CHECK-NEXT: %[[MERGE:[0-9]*]]:3 = "_tf.Merge"(%[[SOURCE]]#0, %[[ENTER]]#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor, !_tf.control) +// CHECK-NEXT: %[[CONST_LESS:[0-9]*]]:2 = "_tf.Const"(%[[MERGE]]#2) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<2> : tensor} : (!_tf.control) -> (tensor, !_tf.control) +// CHECK-NEXT: %[[LESS:[0-9]*]]:2 = "_tf.Less"(%[[MERGE]]#0, %[[CONST_LESS]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor) -> (tensor<*xi1>, !_tf.control) +// CHECK-NEXT: %[[COND:[0-9]*]]:2 = "_tf.LoopCond"(%[[LESS]]#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor, !_tf.control) +// CHECK-NEXT: %[[SWITCH:[0-9]*]]:3 = "_tf.Switch"(%[[MERGE]]#0, %[[COND]]#0) {T = "tfdtype$DT_INT32", _class = ["loc = @while/Merge"], device = "", name = "while/Switch"} : (tensor<*xi32>, tensor) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control) +// CHECK-NEXT: %[[EXIT:[0-9]*]]:2 = "_tf.Exit"(%[[SWITCH]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) +// CHECK-NEXT: %[[IDENTITY:[0-9]*]]:2 = "_tf.Identity"(%[[SWITCH]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> 
(tensor<*xi32>, !_tf.control) +// CHECK-NEXT: %[[CONST_ADD:[0-9]*]]:2 = "_tf.Const"(%[[IDENTITY]]#1) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<3> : tensor} : (!_tf.control) -> (tensor, !_tf.control) +// CHECK-NEXT: %[[ADD:[0-9]*]]:2 = "_tf.Add"(%[[IDENTITY]]#0, %[[CONST_ADD]]#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor) -> (tensor<*xi32>, !_tf.control) +// CHECK-NEXT: %[[CT:[0-9]*]] = "_tf.ControlTrigger"(%[[NOOP]], %[[ADD]]#1, %[[EXIT]]#1) {_tpu_replicate = "cluster", device = "", name = "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync"} : (!_tf.control, !_tf.control, !_tf.control) -> !_tf.control +// CHECK-NEXT: %[[SINK:[0-9]*]] = "_tf.NextIteration.sink"(%[[ADD]]#0, %[[CT]]) {T = "tfdtype$DT_INT32", device = "", id = 0 : i64, name = "while/NextIteration"} : (tensor<*xi32>, !_tf.control) -> !_tf.control +// CHECK-NEXT: return + +// CHECK-LABEL: func @multiple_ops_region +func @multiple_ops_region(%arg0 : tensor<*xi32>, %arg1 : tensor) { + tf_executor.graph { + %0:2 = tf_executor.island { + // The 4 operations are independent, but the current conversion will add + // control dependencies conservatively. + %1 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %2 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %3 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor) -> tensor<*xi32> + %4 = "tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor) -> tensor<*xi32> + tf_executor.yield %4 : tensor<*xi32> + } + tf_executor.fetch + } + return +} + +// CHECK-NEXT: %[[ADD1:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add1"} : (tensor<*xi32>, tensor) -> (tensor<*xi32>, !_tf.control) +// CHECK-NEXT: %[[ADD2:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD1]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add2"} : (tensor<*xi32>, tensor, !_tf.control) -> (tensor<*xi32>, !_tf.control) +// CHECK-NEXT: %[[ADD3:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD2]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add3"} : (tensor<*xi32>, tensor, !_tf.control) -> (tensor<*xi32>, !_tf.control) +// CHECK-NEXT: %[[ADD4:[0-9]*]]:2 = "_tf.Add"(%arg0, %arg1, %[[ADD3]]#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Add4"} : (tensor<*xi32>, tensor, !_tf.control) -> (tensor<*xi32>, !_tf.control) + +// CHECK-LABEL: func @switchN( +func @switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %fetches = tf_executor.graph { + +// CHECK: [[S1:%.*]]:6 = "_tf._SwitchN"(%arg1, %arg0) {num_outs = 5 : i64} + %1:6 = tf_executor.SwitchN %arg1, %arg0 of 5 : tensor<*xf32> + +// CHECK: "_tf._SwitchN"(%arg1, %arg0, [[S1]]#5) {num_outs = 12 : i64} + %2:13 = tf_executor.SwitchN %arg1, %arg0 of 12 (%1#5) : tensor<*xf32> + + tf_executor.fetch %2#0 : tensor<*xf32> + } + return %fetches : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir index 82fc0171fa6..2a0434b69e0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-cfg.mlir @@ -7,7 +7,7 @@ func 
@testIf1Else(tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> func @testIf1Result(tensor, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> { ^bb0(%arg0: tensor, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>): %1 = "tf.If"(%arg0, %arg1, %arg2) { - then_branch = @testIf1Then, else_branch = @testIf1Else + then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false } : (tensor, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> // CHECK: %0 = extract_element %arg0[] : tensor @@ -31,7 +31,7 @@ func @testIf3Else(tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16> func @testIf3Result(tensor, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>) { ^bb0(%arg0: tensor, %arg1: tensor<*xf32>): %1:3 = "tf.If"(%arg0, %arg1) { - then_branch = @testIf3Then, else_branch = @testIf3Else + then_branch = @testIf3Then, else_branch = @testIf3Else, is_stateless = false } : (tensor, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>) // CHECK: %0 = extract_element %arg0[] : tensor @@ -57,7 +57,7 @@ func @testIf1Casts(tensor, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32 ^bb0(%arg0: tensor, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>): %1 = "tf.If"(%arg0, %arg1, %arg2) { - then_branch = @testIf1Then, else_branch = @testIf1Else + then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false } : (tensor, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32> // CHECK: %0 = extract_element %arg0[] : tensor @@ -97,7 +97,7 @@ func @testIf1x4(tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> { // expected-error @+1 {{only supports zero-D bool tensors now}} %1 = "tf.If"(%arg0, %arg1, %arg2) { - then_branch = @testIf1Then, else_branch = @testIf1Else + then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false } : (tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %1 : tensor<*xf32> @@ -113,7 +113,7 @@ func @testWhile2Body(tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf func @testWhile2Result(tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { ^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>): %1:2 = "tf.While"(%arg0, %arg1) { - cond = @testWhile2Cond, body = @testWhile2Body + cond = @testWhile2Cond, body = @testWhile2Body, is_stateless = false } : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) // CHECK: br ^bb1(%arg0, %arg1 : tensor<*xf32>, tensor<*xf32>) @@ -138,7 +138,7 @@ func @testWhile0Body() -> () func @testWhile0Result() { ^bb0: - "tf.While"() { cond = @testWhile0Cond, body = @testWhile0Body } : () -> () + "tf.While"() { cond = @testWhile0Cond, body = @testWhile0Body, is_stateless = false } : () -> () // CHECK: br ^bb1 // CHECK: ^bb1: // CHECK: %0 = call @testWhile0Cond() : () -> tensor @@ -162,7 +162,7 @@ func @testComplexWhile1Result(tensor<*xf32>) -> (tensor<*xf32>) { ^bb1(%0: tensor<*xf32>, %1: tensor<*xf32>): %2 = addf %0, %1 : tensor<*xf32> %3:2 = "tf.While"(%0, %2) { - cond = @testWhile2Cond, body = @testWhile2Body + cond = @testWhile2Cond, body = @testWhile2Body, is_stateless = false } : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) // CHECK: br ^bb2(%0, %2 : tensor<*xf32>, tensor<*xf32>) @@ -194,7 +194,7 @@ func @testWhileBody(tensor<*xf32>) -> (tensor) // CHECK-LABEL: func @testWhileCasts(%arg0: tensor<1x3xf32>) func @testWhileCasts(%arg0: tensor<1x3xf32>) -> (tensor) { %0 = "tf.While"(%arg0) { - cond = @testWhileCond, body = @testWhileBody + cond = @testWhileCond, body = @testWhileBody, is_stateless = false } : (tensor<1x3xf32>) -> (tensor) // 
CHECK: %0 = tensor_cast %arg0 : tensor<1x3xf32> to tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir index 779fe9011ff..e13d5584c7f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if-fail.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0 +// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0 // CHECK: FunctionalizeControlFlowPass: Graph contains node with inputs predicated on incompatible predicates: {s(Cond:0,then)} and {s(Cond:0,else)} // CHECK-NEXT: for node {{[{][{]node Add[}][}]}} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir index d3b2d835c27..0d40a4d383c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/functionalize-if.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass | FileCheck %s +// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=FunctionalizeControlFlowPass | FileCheck %s --dump-input-on-failure func @main() { %0 = "_tf._TPUReplicate"() {computation = @foo, Tinputs = [], Tbroadcast_inputs = [], NumVariables = 0, Tguaranteed_constants = [], output_types = []} : () -> !_tf.control loc("_TPUReplicate") @@ -17,18 +17,18 @@ func @foo() { // Match the name of the cloned function with functionalized control-flow at call site // CHECK: func @main() -// CHECK-NEXT: computation = @[[FUNCTIONALIZE_FUNC:[A-Za-z0-9_]*]] +// CHECK: computation = @[[FUNCTIONALIZE_FUNC:[A-Za-z0-9_]*]] // In the newly cloned function, check that we have a _tf.If operation and capture the then and else branch. // CHECK: func @[[FUNCTIONALIZE_FUNC]] -// CHECK: "_tf.If" +// CHECK: "tf.If" // CHECK-SAME: else_branch = @[[ELSE_FUNC:[A-Za-z0-9_]*]] // CHECK-SAME: then_branch = @[[THEN_FUNC:[A-Za-z0-9_]*]] // We expect the _tf.Add in the else func and the _tf.Mul in the then func // CHECK: func @[[ELSE_FUNC]] -// CHECK: "_tf.Add" +// CHECK: "tf.Add" // CHECK: func @[[THEN_FUNC]] -// CHECK: "_tf.Mul" +// CHECK: "tf.Mul" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir new file mode 100644 index 00000000000..bd10512ff72 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graph_pruning.mlir @@ -0,0 +1,86 @@ +// RUN: tf-opt %s -tf-executor-graph-pruning | FileCheck %s --dump-input=fail + +// Two islands chained by data-flow contributing to the graph return are +// preserved. +// CHECK-LABEL: func @chained_islands( +func @chained_islands(%arg0 : i32) -> i32 { +// CHECK: island +// CHECK: island + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + tf_executor.yield %arg0 : i32 + } + %2:2 = tf_executor.island { + tf_executor.yield %1#0 : i32 + } + tf_executor.fetch %2#0 : i32 + } + return %0 : i32 +} + +// Check that empty islands that don't contribute to the fetch are removed. 
+// CHECK-LABEL: func @empty_islands( +func @empty_islands() { +// CHECK-NOT: tf_executor.island + tf_executor.graph { + %0 = tf_executor.island { + tf_executor.yield + } + %1 = tf_executor.island { + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// Check that an unused island that doesn't contribute to the fetch is removed. +// CHECK-LABEL: func @dead_island( +func @dead_island(%arg0 : i32) -> i32 { +// CHECK: tf_executor.island +// CHECK-NOT: tf_executor.island + %0 = tf_executor.graph { + %1:2 = tf_executor.island { + %a = "op.A"(%arg0) : (i32) -> i32 + %b = "op.B"(%a) : (i32) -> i32 + tf_executor.yield %b : i32 + } + %2:2 = tf_executor.island { + %a = "op.A"(%1#0) : (i32) -> i32 + tf_executor.yield %a : i32 + } + tf_executor.fetch %1#0 : i32 + } + return %0 : i32 +} + + +// Check that NextIteration.sink node isn't deleted when the source is still +// used, even though it does not have any result. +// CHECK-LABEL: func @nextiteration_sink_preserved( +func @nextiteration_sink_preserved(%arg0 : i32) -> i32 { +// CHECK: tf_executor.NextIteration.Source +// CHECK: tf_executor.NextIteration.Sink + %0 = tf_executor.graph { + %1:3 = tf_executor.NextIteration.Source : i32 + tf_executor.NextIteration.Sink[%1#1] %1#0 : i32 + tf_executor.fetch %1#0 : i32 + } + return %0 : i32 +} + +// Check that NextIteration.sink node is deleted when the source does not have +// any user other than the sink. +// CHECK-LABEL: func @nextiteration_deleted( +func @nextiteration_deleted(%arg0 : i32) -> i32 { +// CHECK-NOT: tf_executor.NextIteration.Source +// CHECK-NOT: tf_executor.NextIteration.Sink + %0 = tf_executor.graph { + %1:3 = tf_executor.NextIteration.Source : i32 + // intentionally take an output dependency on the source here. + tf_executor.NextIteration.Sink[%1#1] %1#0 : i32 + tf_executor.fetch %arg0 : i32 + } + return %0 : i32 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt index ffbd84c7ee7..a2b9efff36b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/add.pbtxt @@ -38,8 +38,14 @@ versions { # CHECK: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>) -> tensor<10xi32> # CHECK: attributes {tf.entry_function = {inputs = "input0, input1", outputs = "Add"}} { -# CHECK: %0:2 = "_tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_INT32", name = "input0", shape = "tfshape$dim {\0A size: 10\0A}\0A"} : (tensor<10xi32>) -> (tensor<10xi32>, !_tf.control) -# CHECK: %1:2 = "_tf.Placeholder.input"(%arg1) {device = "", dtype = "tfdtype$DT_INT32", name = "input1", shape = "tfshape$dim {\0A size: 10\0A}\0A"} : (tensor<10xi32>) -> (tensor<10xi32>, !_tf.control) -# CHECK: %2:2 = "_tf.Add"(%0#0, %1#0) {T = "tfdtype$DT_INT32", device = "", name = "Add"} : (tensor<10xi32>, tensor<10xi32>) -> (tensor<10xi32>, !_tf.control) -# CHECK: return %2#0 : tensor<10xi32> -# CHECK: } + +# CHECK: %[[INPUT0:[0-9]+]]:2 = tf_executor.island +# CHECK-NEXT: "tf.Placeholder.input"(%arg0) + +# CHECK: %[[INPUT1:[0-9]+]]:2 = tf_executor.island +# CHECK-NEXT: "tf.Placeholder.input"(%arg1) + +# CHECK: %[[add:[0-9]+]]:2 = tf_executor.island +# CHECK-NEXT: "tf.Add"(%[[INPUT0]]#0, %[[INPUT1]]#0) + +# CHECK: fetch %[[add]]#0 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt index da77c16ca64..74adc38d87d 100644 
--- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/arg-control-dep.pbtxt @@ -40,7 +40,9 @@ library { } } # Drop the control dependency on arg for the node "test" - # CHECK: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "test", value = dense<0> : tensor} : () -> (tensor, !_tf.control) + # CHECK-LABEL: func @foo + # CHECK: tf_executor.island { + # CHECK-NEXT: "tf.Const"() node_def { name: "test" op: "Const" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/const-values.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/const-values.pbtxt new file mode 100644 index 00000000000..019deaf4df4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/const-values.pbtxt @@ -0,0 +1,90 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s + +node { + name: "x" + op: "Const" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 2 + } + dim { + size: 3 + } + } + tensor_content: "\x00\x00\x80\x3F\x00\x00\x00\x40\x00\x00\x40\x40\x00\x00\x80\x40\x00\x00\xA0\x40\x00\x00\xC0\x40" + # CHECK: value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> + } + } + } +} +node { + name: "y" + op: "Const" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + dim { + size: 2 + } + dim { + size: 3 + } + } + tensor_content: "\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00\x00\x00\x00\x00" + # CHECK: value = dense<{{\[\[}}1, 3, 2], [5, 4, 7]]> : tensor<2x3xi64> + } + } + } +} +node { + name: "z" + op: "Const" + device: "/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + dim { + size: 3 + } + } + tensor_content: "\x01\x00\x00\x00\x03\x00\x00\x00\x02\x00\x00\x00\x05\x00\x00\x00\x04\x00\x00\x00\x07\x00\x00\x00" + # CHECK: value = dense<{{\[\[}}1, 3, 2], [5, 4, 7]]> : tensor<2x3xi32> + } + } + } +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt index 81466e6d937..93a2f602c65 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/empty-value-attr.pbtxt @@ -75,8 +75,8 @@ versions { } # Match partitioned call in main and capture the callee name. -# CHECK: func @main -# CHECK-NEXT: _tf.PartitionedCall +# CHECK-LABEL: func @main +# CHECK: tf.PartitionedCall # CHECK-SAME: f = @[[FUNCTION:[a-zA-Z0-9_]*]] # Verify that callee has the unit attribute tf._input_shapes. 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt new file mode 100644 index 00000000000..cbfa973fd64 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-if-ops.pbtxt @@ -0,0 +1,256 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulIf,StatelessIf -o - | FileCheck %s + +# Verify that TensorFlow If and StatelessIf ops are mapped to the +# composite If op in MLIR with is_stateless attribute set accordingly to +# distinguish between them. + +# CHECK-DAG: "tf.If"{{.*}} is_stateless = false, name = "StatefulIf" +# CHECK-DAG: "tf.If"{{.*}} is_stateless = true, name = "StatelessIf" + +node { + name: "tf.Less" + op: "Less" + input: "a" + input: "b" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "StatefulIf" + op: "If" + input: "tf.Less" + input: "a" + input: "b" + attr { + key: "Tcond" + value { + type: DT_BOOL + } + } + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "else_branch" + value { + func { + name: "cond_false" + } + } + } + attr { + key: "then_branch" + value { + func { + name: "cond_true" + } + } + } + experimental_debug_info { + } +} +node { + name: "StatelessIf" + op: "StatelessIf" + input: "tf.Less" + input: "a" + input: "b" + attr { + key: "Tcond" + value { + type: DT_BOOL + } + } + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "else_branch" + value { + func { + name: "cond_false" + } + } + } + attr { + key: "then_branch" + value { + func { + name: "cond_true" + } + } + } + experimental_debug_info { + } +} +node { + name: "main" + op: "_Retval" + input: "StatefulIf" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "main1" + op: "_Retval" + input: "StatelessIf" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 1 + } + } +} +node { + name: "a" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +node { + name: "b" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +library { + function { + signature { + name: "cond_true" + input_arg { + name: "cond_true" + type: DT_FLOAT + } + input_arg { + name: "cond_true1" + type: DT_FLOAT + } + output_arg { + name: "cond_true2" + type: DT_FLOAT + } + } + node_def { + name: "tf.Add" + op: "Add" + input: "cond_true" + input: "cond_true1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + original_node_names: "tf.Add" + } + } + ret { + key: "cond_true2" + value: "tf.Add:z:0" + } + } + function { + signature { + name: "cond_false" + input_arg { + name: "cond_false" + type: DT_FLOAT + } + input_arg { + name: "cond_false1" + type: DT_FLOAT + } + output_arg { + name: "cond_false2" + type: DT_FLOAT + } + } + node_def { + name: "tf.Mul" + op: "Mul" + input: "cond_false" + input: "cond_false1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + original_node_names: "tf.Mul" + } + } + ret 
{ + key: "cond_false2" + value: "tf.Mul:z:0" + } + } +} +versions { + producer: 115 + min_consumer: 12 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt new file mode 100644 index 00000000000..953f83a9f68 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/functional-while-ops.pbtxt @@ -0,0 +1,283 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=iter,val -tf-input-data-types=DT_INT32,DT_FLOAT -tf-input-shapes=':' -tf-output-arrays=StatefulWhile:1,StatelessWhile:1 -o - | FileCheck %s + +# Verify that TensorFlow While and StatelessWhile ops are mapped to the +# composite While op in MLIR with is_stateless attribute set accordingly to +# distinguish between them. + +# CHECK-DAG: "tf.While"{{.*}} is_stateless = false, name = "StatefulWhile" +# CHECK-DAG: "tf.While"{{.*}} is_stateless = true, name = "StatelessWhile" + +node { + name: "StatefulWhile" + op: "While" + input: "iter" + input: "val" + attr { + key: "T" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "body" + value { + func { + name: "body" + } + } + } + attr { + key: "cond" + value { + func { + name: "cond" + } + } + } + experimental_debug_info { + } +} +node { + name: "StatelessWhile" + op: "StatelessWhile" + input: "iter" + input: "val" + attr { + key: "T" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "body" + value { + func { + name: "body" + } + } + } + attr { + key: "cond" + value { + func { + name: "cond" + } + } + } + experimental_debug_info { + } +} +node { + name: "main" + op: "_Retval" + input: "StatefulWhile:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "main1" + op: "_Retval" + input: "StatelessWhile:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 1 + } + } +} +node { + name: "iter" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + experimental_debug_info { + } +} +node { + name: "val" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + } +} +library { + function { + signature { + name: "cond" + input_arg { + name: "cond" + type: DT_INT32 + } + input_arg { + name: "cond1" + type: DT_FLOAT + } + output_arg { + name: "cond2" + type: DT_BOOL + } + } + node_def { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + experimental_debug_info { + original_node_names: "Const" + } + } + node_def { + name: "tf.Greater" + op: "Greater" + input: "cond" + input: "Const:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + experimental_debug_info { + original_node_names: "tf.Greater" + } + } + ret { + key: "cond2" + value: "tf.Greater:z:0" + } + } + function { + signature { + name: "body" + input_arg { + name: "body" + type: DT_INT32 + } + input_arg { + name: "body1" + type: DT_FLOAT + } + output_arg { + name: "body2" + type: DT_INT32 + } + output_arg { + name: "body3" + type: DT_FLOAT + } + } + node_def { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + experimental_debug_info 
{ + original_node_names: "Const" + } + } + node_def { + name: "tf.Sub" + op: "Sub" + input: "body" + input: "Const:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + experimental_debug_info { + original_node_names: "tf.Sub" + } + } + node_def { + name: "tf.Add" + op: "Add" + input: "body1" + input: "body1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + experimental_debug_info { + original_node_names: "tf.Add" + } + } + ret { + key: "body2" + value: "tf.Sub:z:0" + } + ret { + key: "body3" + value: "tf.Add:z:0" + } + } +} +versions { + producer: 115 + min_consumer: 12 +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-11c8752c150e5643.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-11c8752c150e5643.pbtxt deleted file mode 100644 index ae252ef83dd..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-11c8752c150e5643.pbtxt +++ /dev/null @@ -1,99 +0,0 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s - -node { - name: "Empty/shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:TPU:0" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\200\000\000\000\200\000\000\000" - } - } - } -} -node { - name: "Empty" - op: "Empty" - input: "Empty/shape" - device: "/job:localhost/replica:0/task:0/device:TPU:0" - attr { - key: "dtype" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "init" - value { - b: false - } - } -} -node { - name: "Empty/_0" - op: "_Send" - input: "Empty" - device: "/job:localhost/replica:0/task:0/device:TPU:0" - attr { - key: "T" - value { - type: DT_BFLOAT16 - } - } - attr { - key: "client_terminated" - value { - b: false - } - } - attr { - key: "recv_device" - value { - s: "/job:localhost/replica:0/task:0/device:CPU:0" - } - } - attr { - key: "send_device" - value { - s: "/job:localhost/replica:0/task:0/device:TPU:0" - } - } - attr { - key: "send_device_incarnation" - value { - i: 1 - } - } - attr { - key: "tensor_name" - value { - s: "edge_5_Empty" - } - } -} -library { -} -versions { - producer: 26 -} - -# CHECK: func @main() { -# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", dtype = "tfdtype$DT_INT32", name = "Empty/shape", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F494E5433320A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20320A20207D0A7D0A74656E736F725F636F6E74656E743A20225C3230305C3030305C3030305C3030305C3230305C3030305C3030305C303030220A"> : tensor<2xi32>} : () -> (tensor<2xi32>, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Empty"(%0#0) {device = "/job:localhost/replica:0/task:0/device:TPU:0", dtype = "tfdtype$DT_BFLOAT16", init = false, name = "Empty"} : (tensor<2xi32>) -> (tensor<128x128xbf16>, !_tf.control) -# CHECK-NEXT: %2 = "_tf._Send"(%1#0) {T = "tfdtype$DT_BFLOAT16", client_terminated = false, device = "/job:localhost/replica:0/task:0/device:TPU:0", name = "Empty/_0", recv_device = "/job:localhost/replica:0/task:0/device:CPU:0", send_device = "/job:localhost/replica:0/task:0/device:TPU:0", send_device_incarnation = 1 : i64, tensor_name = "edge_5_Empty"} : (tensor<128x128xbf16>) -> !_tf.control -# CHECK-NEXT: return -# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-1383300d74bd0b22.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-1383300d74bd0b22.pbtxt 
deleted file mode 100644 index 0333193be8d..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-1383300d74bd0b22.pbtxt +++ /dev/null @@ -1,1550 +0,0 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s - -node { - name: "placeholder_0_arg" - op: "_Arg" - device: "/device:TPU_REPLICATED_CORE:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "index" - value { - i: 0 - } - } -} -node { - name: "tpu/tpu/Shape" - op: "Shape" - input: "placeholder_0_arg" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } -} -node { - name: "tpu/tpu/strided_slice/stack" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "tpu/tpu/strided_slice/stack_1" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "tpu/tpu/strided_slice/stack_2" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "tpu/tpu/strided_slice" - op: "StridedSlice" - input: "tpu/tpu/Shape" - input: "tpu/tpu/strided_slice/stack" - input: "tpu/tpu/strided_slice/stack_1" - input: "tpu/tpu/strided_slice/stack_2" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "begin_mask" - value { - i: 0 - } - } - attr { - key: "ellipsis_mask" - value { - i: 0 - } - } - attr { - key: "end_mask" - value { - i: 0 - } - } - attr { - key: "new_axis_mask" - value { - i: 0 - } - } - attr { - key: "shrink_axis_mask" - value { - i: 1 - } - } -} -node { - name: "tpu/tpu/Plus1RNNCellZeroState/ExpandDims/dim" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "tpu/tpu/Plus1RNNCellZeroState/ExpandDims" - op: "ExpandDims" - input: "tpu/tpu/strided_slice" - input: "tpu/tpu/Plus1RNNCellZeroState/ExpandDims/dim" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } -} -node { - name: "tpu/tpu/Plus1RNNCellZeroState/Const" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "tpu/tpu/Plus1RNNCellZeroState/concat/axis" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "tpu/tpu/Plus1RNNCellZeroState/concat" - op: "ConcatV2" - input: "tpu/tpu/Plus1RNNCellZeroState/ExpandDims" - input: 
"tpu/tpu/Plus1RNNCellZeroState/Const" - input: "tpu/tpu/Plus1RNNCellZeroState/concat/axis" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - name: "tpu/tpu/Plus1RNNCellZeroState/zeros/Const" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0 - } - } - } -} -node { - name: "tpu/tpu/Plus1RNNCellZeroState/zeros" - op: "Fill" - input: "tpu/tpu/Plus1RNNCellZeroState/concat" - input: "tpu/tpu/Plus1RNNCellZeroState/zeros/Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "tpu/tpu/Plus1RNNCellZeroState/ExpandDims_1/dim" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "tpu/tpu/Plus1RNNCellZeroState/ExpandDims_1" - op: "ExpandDims" - input: "tpu/tpu/strided_slice" - input: "tpu/tpu/Plus1RNNCellZeroState/ExpandDims_1/dim" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } -} -node { - name: "tpu/tpu/Plus1RNNCellZeroState/Const_1" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "tpu/tpu/sequence_length" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "tpu/tpu/ExpandDims/dim" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "tpu/tpu/ExpandDims" - op: "ExpandDims" - input: "tpu/tpu/strided_slice" - input: "tpu/tpu/ExpandDims/dim" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } -} -node { - name: "tpu/tpu/Const" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "tpu/tpu/concat/axis" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "tpu/tpu/concat" - op: "ConcatV2" - input: "tpu/tpu/ExpandDims" - input: "tpu/tpu/Const" - input: "tpu/tpu/concat/axis" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } -} -node { - 
name: "tpu/tpu/zeros/Const" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0 - } - } - } -} -node { - name: "tpu/tpu/zeros" - op: "Fill" - input: "tpu/tpu/concat" - input: "tpu/tpu/zeros/Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "tpu/tpu/Const_1" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "tpu/tpu/Const_2" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "tpu/tpu/Min" - op: "Min" - input: "tpu/tpu/sequence_length" - input: "tpu/tpu/Const_2" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "tpu/tpu/Const_3" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "tpu/tpu/Max" - op: "Max" - input: "tpu/tpu/sequence_length" - input: "tpu/tpu/Const_3" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "tpu/tpu/LessEqual/y" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "tpu/tpu/LessEqual" - op: "LessEqual" - input: "tpu/tpu/sequence_length" - input: "tpu/tpu/LessEqual/y" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_INT32 - } - } -} -node { - name: "tpu/tpu/LessEqual_1/y" - op: "Const" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "tpu/tpu/LessEqual_1" - op: "LessEqual" - input: "tpu/tpu/Max" - input: "tpu/tpu/LessEqual_1/y" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_INT32 - } - } -} -node { - name: "tpu/tpu/cond/Switch" - op: "Switch" - input: "tpu/tpu/LessEqual_1" - input: "tpu/tpu/LessEqual_1" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_BOOL - } - } -} -node { - name: "tpu/tpu/cond/switch_t" - op: "Identity" - input: "tpu/tpu/cond/Switch:1" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_BOOL - } - } -} -node { - name: "tpu/tpu/cond/switch_f" - op: "Identity" - input: "tpu/tpu/cond/Switch" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_BOOL - } - } -} -node { - name: 
"tpu/tpu/cond/pred_id" - op: "Identity" - input: "tpu/tpu/LessEqual_1" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_BOOL - } - } -} -node { - name: "tpu/tpu/cond/Switch_1" - op: "Switch" - input: "tpu/tpu/zeros" - input: "tpu/tpu/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/zeros" - } - } - } -} -node { - name: "tpu/tpu/cond/Switch_2" - op: "Switch" - input: "tpu/tpu/Plus1RNNCellZeroState/zeros" - input: "tpu/tpu/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/Plus1RNNCellZeroState/zeros" - } - } - } -} -node { - name: "tpu/tpu/cond/add/y" - op: "Const" - input: "^tpu/tpu/cond/switch_f" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "tpu/tpu/cond/add/Switch" - op: "Switch" - input: "placeholder_0_arg" - input: "tpu/tpu/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@Placeholder" - } - } - } -} -node { - name: "tpu/tpu/cond/add" - op: "Add" - input: "tpu/tpu/cond/add/Switch" - input: "tpu/tpu/cond/add/y" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "tpu/tpu/cond/add_1/y" - op: "Const" - input: "^tpu/tpu/cond/switch_f" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "tpu/tpu/cond/add_1/Switch" - op: "Switch" - input: "tpu/tpu/Plus1RNNCellZeroState/zeros" - input: "tpu/tpu/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/Plus1RNNCellZeroState/zeros" - } - } - } -} -node { - name: "tpu/tpu/cond/add_1" - op: "Add" - input: "tpu/tpu/cond/add_1/Switch" - input: "tpu/tpu/cond/add_1/y" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "tpu/tpu/cond/Greater/y" - op: "Const" - input: "^tpu/tpu/cond/switch_f" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "tpu/tpu/cond/Greater/Switch" - op: "Switch" - input: "tpu/tpu/Min" - input: "tpu/tpu/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/Min" - } - } - } -} -node { - name: "tpu/tpu/cond/Greater" - op: "Greater" - input: "tpu/tpu/cond/Greater/Switch" - input: "tpu/tpu/cond/Greater/y" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_INT32 - } - } -} -node { - name: "tpu/tpu/cond/cond/Switch" - op: "Switch" - input: "tpu/tpu/cond/Greater" - input: "tpu/tpu/cond/Greater" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_BOOL - } - } -} -node { - name: "tpu/tpu/cond/cond/switch_t" - op: 
"Identity" - input: "tpu/tpu/cond/cond/Switch:1" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_BOOL - } - } -} -node { - name: "tpu/tpu/cond/cond/switch_f" - op: "Identity" - input: "tpu/tpu/cond/cond/Switch" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_BOOL - } - } -} -node { - name: "tpu/tpu/cond/cond/pred_id" - op: "Identity" - input: "tpu/tpu/cond/Greater" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_BOOL - } - } -} -node { - name: "tpu/tpu/cond/cond/Switch_1" - op: "Switch" - input: "tpu/tpu/cond/add" - input: "tpu/tpu/cond/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/cond/add" - } - } - } -} -node { - name: "tpu/tpu/cond/cond/Switch_2" - op: "Switch" - input: "tpu/tpu/cond/add_1" - input: "tpu/tpu/cond/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/cond/add_1" - } - } - } -} -node { - name: "tpu/tpu/cond/cond/Select/Switch" - op: "Switch" - input: "tpu/tpu/LessEqual" - input: "tpu/tpu/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/LessEqual" - } - } - } -} -node { - name: "tpu/tpu/cond/cond/Select/Switch_1" - op: "Switch" - input: "tpu/tpu/cond/cond/Select/Switch" - input: "tpu/tpu/cond/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_BOOL - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/LessEqual" - } - } - } -} -node { - name: "tpu/tpu/cond/cond/Select/Switch_2" - op: "Switch" - input: "tpu/tpu/zeros" - input: "tpu/tpu/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/zeros" - } - } - } -} -node { - name: "tpu/tpu/cond/cond/Select/Switch_3" - op: "Switch" - input: "tpu/tpu/cond/cond/Select/Switch_2" - input: "tpu/tpu/cond/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/zeros" - } - } - } -} -node { - name: "tpu/tpu/cond/cond/Select/Switch_4" - op: "Switch" - input: "tpu/tpu/cond/add" - input: "tpu/tpu/cond/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/cond/add" - } - } - } -} -node { - name: "tpu/tpu/cond/cond/Select" - op: "Select" - input: "tpu/tpu/cond/cond/Select/Switch_1" - input: "tpu/tpu/cond/cond/Select/Switch_3" - input: "tpu/tpu/cond/cond/Select/Switch_4" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/cond/add" - } - } - } -} -node { - name: "tpu/tpu/cond/cond/Select_1/Switch" - op: "Switch" - input: "tpu/tpu/cond/add_1/Switch" - input: "tpu/tpu/cond/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/Plus1RNNCellZeroState/zeros" - } - } - } -} -node { - name: "tpu/tpu/cond/cond/Select_1/Switch_1" - op: "Switch" - input: 
"tpu/tpu/cond/add_1" - input: "tpu/tpu/cond/cond/pred_id" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/cond/add_1" - } - } - } -} -node { - name: "tpu/tpu/cond/cond/Select_1" - op: "Select" - input: "tpu/tpu/cond/cond/Select/Switch_1" - input: "tpu/tpu/cond/cond/Select_1/Switch" - input: "tpu/tpu/cond/cond/Select_1/Switch_1" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@tpu/tpu/cond/add_1" - } - } - } -} -node { - name: "tpu/tpu/cond/cond/Merge" - op: "Merge" - input: "tpu/tpu/cond/cond/Select" - input: "tpu/tpu/cond/cond/Switch_1:1" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "tpu/tpu/cond/cond/Merge_1" - op: "Merge" - input: "tpu/tpu/cond/cond/Select_1" - input: "tpu/tpu/cond/cond/Switch_2:1" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "tpu/tpu/cond/Merge" - op: "Merge" - input: "tpu/tpu/cond/cond/Merge" - input: "tpu/tpu/cond/Switch_1:1" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "tpu/tpu/cond/Merge_1" - op: "Merge" - input: "tpu/tpu/cond/cond/Merge_1" - input: "tpu/tpu/cond/Switch_2:1" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "tpu/NoOp" - op: "NoOp" - device: "/device:TPU_REPLICATED_CORE" -} -node { - name: "tpu/packed" - op: "Pack" - input: "tpu/tpu/cond/Merge" - device: "/device:TPU_REPLICATED_CORE:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "tpu/Identity" - op: "Identity" - input: "tpu/packed" - device: "/device:TPU_REPLICATED_CORE:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "tpu/Identity_1" - op: "Identity" - input: "tpu/tpu/cond/Merge_1" - device: "/device:TPU_REPLICATED_CORE" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "tpu_identity_0_retval_RetVal" - op: "_Retval" - input: "tpu/Identity" - device: "/device:TPU_REPLICATED_CORE:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "index" - value { - i: 0 - } - } -} -node { - name: "tpu_identity_1_0_retval_RetVal" - op: "_Retval" - input: "tpu/Identity_1" - device: "/device:TPU_REPLICATED_CORE:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "index" - value { - i: 1 - } - } -} -library { -} -versions { - producer: 26 -} - -# CHECK: func @main() { -# CHECK-NEXT: %0:2 = "_tf._Arg"() {T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE:0", index = 0 : i64, name = "placeholder_0_arg"} : () -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Shape"(%0#0) {T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/Shape", out_type = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> (tensor, !_tf.control) -# CHECK-NEXT: %2 = "_tf.NoOp"() {device = "/device:TPU_REPLICATED_CORE", name = "tpu/NoOp"} : () -> !_tf.control -# CHECK-NEXT: %3:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = 
"tfdtype$DT_INT32", name = "tpu/tpu/Const", value = dense<1> : tensor<1xi32>} : () -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %4:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/Const_1", value = dense<1> : tensor<1xi32>} : () -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %5:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/Const_2", value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %6:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/Const_3", value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %7:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/ExpandDims/dim", value = dense<0> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %8:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/LessEqual/y", value = dense<0> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %9:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/LessEqual_1/y", value = dense<0> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %10:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/Plus1RNNCellZeroState/Const", value = dense<1> : tensor<1xi32>} : () -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %11:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/Plus1RNNCellZeroState/Const_1", value = dense<1> : tensor<1xi32>} : () -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %12:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/Plus1RNNCellZeroState/ExpandDims/dim", value = dense<0> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %13:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/Plus1RNNCellZeroState/ExpandDims_1/dim", value = dense<0> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %14:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/Plus1RNNCellZeroState/concat/axis", value = dense<0> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %15:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_FLOAT", name = "tpu/tpu/Plus1RNNCellZeroState/zeros/Const", value = dense<0.000000e+00> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %16:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/concat/axis", value = dense<0> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %17:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/sequence_length", value = dense<1> : tensor<1xi32>} : () -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %18:2 = "_tf.LessEqual"(%17#0, %8#0) {T = "tfdtype$DT_INT32", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/LessEqual"} : (tensor<1xi32>, tensor) -> (tensor<1xi1>, !_tf.control) -# CHECK-NEXT: %19:2 = "_tf.Max"(%17#0, %6#0) {T = "tfdtype$DT_INT32", Tidx = "tfdtype$DT_INT32", device = "/device:TPU_REPLICATED_CORE", keep_dims = false, name = "tpu/tpu/Max"} : (tensor<1xi32>, tensor<1xi32>) -> (tensor, !_tf.control) -# CHECK-NEXT: %20:2 = "_tf.LessEqual"(%19#0, %9#0) {T = 
"tfdtype$DT_INT32", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/LessEqual_1"} : (tensor, tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %21:3 = "_tf.Switch"(%20#0, %20#0) {T = "tfdtype$DT_BOOL", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/Switch"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %22:2 = "_tf.Identity"(%21#0) {T = "tfdtype$DT_BOOL", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/switch_f"} : (tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %23:2 = "_tf.Const"(%22#1) {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/cond/Greater/y", value = dense<0> : tensor} : (!_tf.control) -> (tensor, !_tf.control) -# CHECK-NEXT: %24:2 = "_tf.Const"(%22#1) {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_FLOAT", name = "tpu/tpu/cond/add/y", value = dense<1.000000e+00> : tensor} : (!_tf.control) -> (tensor, !_tf.control) -# CHECK-NEXT: %25:2 = "_tf.Const"(%22#1) {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_FLOAT", name = "tpu/tpu/cond/add_1/y", value = dense<1.000000e+00> : tensor} : (!_tf.control) -> (tensor, !_tf.control) -# CHECK-NEXT: %26:2 = "_tf.Identity"(%21#1) {T = "tfdtype$DT_BOOL", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/switch_t"} : (tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %27:2 = "_tf.Identity"(%20#0) {T = "tfdtype$DT_BOOL", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/pred_id"} : (tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %28:3 = "_tf.Switch"(%0#0, %27#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@Placeholder"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/add/Switch"} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %29:2 = "_tf.Add"(%28#0, %24#0) {T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/add"} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %30:3 = "_tf.Switch"(%18#0, %27#0) {T = "tfdtype$DT_BOOL", _class = ["loc:@tpu/tpu/LessEqual"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Select/Switch"} : (tensor<1xi1>, tensor) -> (tensor<1xi1>, tensor<1xi1>, !_tf.control) -# CHECK-NEXT: %31:2 = "_tf.Min"(%17#0, %5#0) {T = "tfdtype$DT_INT32", Tidx = "tfdtype$DT_INT32", device = "/device:TPU_REPLICATED_CORE", keep_dims = false, name = "tpu/tpu/Min"} : (tensor<1xi32>, tensor<1xi32>) -> (tensor, !_tf.control) -# CHECK-NEXT: %32:3 = "_tf.Switch"(%31#0, %27#0) {T = "tfdtype$DT_INT32", _class = ["loc:@tpu/tpu/Min"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/Greater/Switch"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %33:2 = "_tf.Greater"(%32#0, %23#0) {T = "tfdtype$DT_INT32", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/Greater"} : (tensor, tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %34:3 = "_tf.Switch"(%33#0, %33#0) {T = "tfdtype$DT_BOOL", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Switch"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %35:2 = "_tf.Identity"(%34#0) {T = "tfdtype$DT_BOOL", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/switch_f"} : (tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %36:2 = "_tf.Identity"(%34#1) {T = "tfdtype$DT_BOOL", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/switch_t"} : (tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %37:2 = "_tf.Identity"(%33#0) {T = "tfdtype$DT_BOOL", device = 
"/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/pred_id"} : (tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %38:3 = "_tf.Switch"(%30#0, %37#0) {T = "tfdtype$DT_BOOL", _class = ["loc:@tpu/tpu/LessEqual"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Select/Switch_1"} : (tensor<1xi1>, tensor) -> (tensor<1xi1>, tensor<1xi1>, !_tf.control) -# CHECK-NEXT: %39:3 = "_tf.Switch"(%29#0, %37#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@tpu/tpu/cond/add"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Select/Switch_4"} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %40:3 = "_tf.Switch"(%29#0, %37#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@tpu/tpu/cond/add"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Switch_1"} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %41:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/strided_slice/stack", value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %42:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/strided_slice/stack_1", value = dense<1> : tensor<1xi32>} : () -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %43:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_INT32", name = "tpu/tpu/strided_slice/stack_2", value = dense<1> : tensor<1xi32>} : () -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %44:2 = "_tf.StridedSlice"(%1#0, %41#0, %42#0, %43#0) {Index = "tfdtype$DT_INT32", T = "tfdtype$DT_INT32", begin_mask = 0 : i64, device = "/device:TPU_REPLICATED_CORE", ellipsis_mask = 0 : i64, end_mask = 0 : i64, name = "tpu/tpu/strided_slice", new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> (tensor, !_tf.control) -# CHECK-NEXT: %45:2 = "_tf.ExpandDims"(%44#0, %7#0) {T = "tfdtype$DT_INT32", Tdim = "tfdtype$DT_INT32", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/ExpandDims"} : (tensor, tensor) -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %46:2 = "_tf.ConcatV2"(%45#0, %3#0, %16#0) {N = 2 : i64, T = "tfdtype$DT_INT32", Tidx = "tfdtype$DT_INT32", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/concat"} : (tensor<1xi32>, tensor<1xi32>, tensor) -> (tensor<2xi32>, !_tf.control) -# CHECK-NEXT: %47:2 = "_tf.ExpandDims"(%44#0, %12#0) {T = "tfdtype$DT_INT32", Tdim = "tfdtype$DT_INT32", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/Plus1RNNCellZeroState/ExpandDims"} : (tensor, tensor) -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %48:2 = "_tf.ConcatV2"(%47#0, %10#0, %14#0) {N = 2 : i64, T = "tfdtype$DT_INT32", Tidx = "tfdtype$DT_INT32", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/Plus1RNNCellZeroState/concat"} : (tensor<1xi32>, tensor<1xi32>, tensor) -> (tensor<2xi32>, !_tf.control) -# CHECK-NEXT: %49:2 = "_tf.Fill"(%48#0, %15#0) {T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE", index_type = "tfdtype$DT_INT32", name = "tpu/tpu/Plus1RNNCellZeroState/zeros"} : (tensor<2xi32>, tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %50:3 = "_tf.Switch"(%49#0, %27#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@tpu/tpu/Plus1RNNCellZeroState/zeros"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/Switch_2"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %51:3 = "_tf.Switch"(%49#0, %27#0) {T = "tfdtype$DT_FLOAT", _class = 
["loc:@tpu/tpu/Plus1RNNCellZeroState/zeros"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/add_1/Switch"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %52:2 = "_tf.Add"(%51#0, %25#0) {T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/add_1"} : (tensor, tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %53:3 = "_tf.Switch"(%52#0, %37#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@tpu/tpu/cond/add_1"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Select_1/Switch_1"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %54:3 = "_tf.Switch"(%52#0, %37#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@tpu/tpu/cond/add_1"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Switch_2"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %55:3 = "_tf.Switch"(%51#0, %37#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@tpu/tpu/Plus1RNNCellZeroState/zeros"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Select_1/Switch"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %56:2 = "_tf.Select"(%38#0, %55#0, %53#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@tpu/tpu/cond/add_1"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Select_1"} : (tensor<1xi1>, tensor, tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %57:3 = "_tf.Merge"(%56#0, %54#1) {N = 2 : i64, T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Merge_1"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %58:3 = "_tf.Merge"(%57#0, %50#1) {N = 2 : i64, T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/Merge_1"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %59:2 = "_tf.Identity"(%58#0) {T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE", name = "tpu/Identity_1"} : (tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %60 = "_tf._Retval"(%59#0) {T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE:0", index = 1 : i64, name = "tpu_identity_1_0_retval_RetVal"} : (tensor) -> !_tf.control -# CHECK-NEXT: %61:2 = "_tf.ExpandDims"(%44#0, %13#0) {T = "tfdtype$DT_INT32", Tdim = "tfdtype$DT_INT32", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/Plus1RNNCellZeroState/ExpandDims_1"} : (tensor, tensor) -> (tensor<1xi32>, !_tf.control) -# CHECK-NEXT: %62:2 = "_tf.Const"() {device = "/device:TPU_REPLICATED_CORE", dtype = "tfdtype$DT_FLOAT", name = "tpu/tpu/zeros/Const", value = dense<0.000000e+00> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %63:2 = "_tf.Fill"(%46#0, %62#0) {T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE", index_type = "tfdtype$DT_INT32", name = "tpu/tpu/zeros"} : (tensor<2xi32>, tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %64:3 = "_tf.Switch"(%63#0, %27#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@tpu/tpu/zeros"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/Switch_1"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %65:3 = "_tf.Switch"(%63#0, %27#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@tpu/tpu/zeros"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Select/Switch_2"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: %66:3 = "_tf.Switch"(%65#0, %37#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@tpu/tpu/zeros"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Select/Switch_3"} : (tensor, tensor) -> (tensor, 
tensor, !_tf.control) -# CHECK-NEXT: %67:2 = "_tf.Select"(%38#0, %66#0, %39#0) {T = "tfdtype$DT_FLOAT", _class = ["loc:@tpu/tpu/cond/add"], device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Select"} : (tensor<1xi1>, tensor, tensor<*xf32>) -> (tensor, !_tf.control) -# CHECK-NEXT: %68:3 = "_tf.Merge"(%67#0, %40#1) {N = 2 : i64, T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/cond/Merge"} : (tensor, tensor<*xf32>) -> (tensor<*xf32>, tensor, !_tf.control) -# CHECK-NEXT: %69:3 = "_tf.Merge"(%68#0, %64#1) {N = 2 : i64, T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE", name = "tpu/tpu/cond/Merge"} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor, !_tf.control) -# CHECK-NEXT: %70:2 = "_tf.Pack"(%69#0) {N = 1 : i64, T = "tfdtype$DT_FLOAT", axis = 0 : i64, device = "/device:TPU_REPLICATED_CORE:0", name = "tpu/packed"} : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %71:2 = "_tf.Identity"(%70#0) {T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE:0", name = "tpu/Identity"} : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %72 = "_tf._Retval"(%71#0) {T = "tfdtype$DT_FLOAT", device = "/device:TPU_REPLICATED_CORE:0", index = 0 : i64, name = "tpu_identity_0_retval_RetVal"} : (tensor<*xf32>) -> !_tf.control -# CHECK-NEXT: return -# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt new file mode 100644 index 00000000000..1bf5037a75f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-as-function.pbtxt @@ -0,0 +1,254 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-graph-as-function -o - | FileCheck %s --dump-input=fail + +# Verify main graph was converted to a function, args/rets are mapped correctly, +# and ops in the main graph are retained. In addition, check if subsequent +# functions are converted. 
+ +# CHECK: func @main(%arg0: tensor<*x!tf.resource>, %arg1: tensor<*x!tf.resource>) -> (tensor, tensor) +# CHECK: attributes {tf.entry_function = {inputs = "args_0, args_1", outputs = "rets_0_RetVal, rets_1_RetVal"}} { +# CHECK: %[[ISLAND_0:[0-9]]]:2 = tf_executor.island { +# CHECK: "tf.Const" +# CHECK: %[[ISLAND_1:[0-9]]]:2 = tf_executor.island { +# CHECK: "tf.Identity"(%[[ISLAND_0]]#0) +# CHECK: %[[ISLAND_2:[0-9]]]:2 = tf_executor.island { +# CHECK: "tf.StatefulPartitionedCall" +# CHECK-SAME: f = @[[FUNC:[a-z0-9]*]] +# CHECK: tf_executor.fetch %[[ISLAND_1]]#0, %[[ISLAND_2]]#0 : tensor, tensor +# CHECK: func @[[FUNC]](%arg0: tensor<*xf32>, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32> + +node { + name: "args_0" + op: "_Arg" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "args_1" + op: "_Arg" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "_handle_dtypes" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "_handle_shapes" + value { + list { + shape { + dim { + size: 3 + } + dim { + size: 3 + } + dim { + size: 1 + } + dim { + size: 32 + } + } + } + } + } + attr { + key: "index" + value { + i: 1 + } + } +} +node { + name: "const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "statefulpartitionedcall" + op: "StatefulPartitionedCall" + input: "const" + input: "args_1" + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + type: DT_RESOURCE + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "_gradient_op_type" + value { + s: "PartitionedCall-1205" + } + } + attr { + key: "config" + value { + s: "" + } + } + attr { + key: "config_proto" + value { + s: "\n\007\n\003GPU\020\000\n\007\n\003CPU\020\0012\002J\0008\001" + } + } + attr { + key: "executor_type" + value { + s: "" + } + } + attr { + key: "f" + value { + func { + name: "function" + } + } + } +} +node { + name: "identity" + op: "Identity" + input: "const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "rets_0" + op: "_Retval" + input: "identity" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +node { + name: "rets_1" + op: "_Retval" + input: "statefulpartitionedcall" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 1 + } + } +} +library { + function { + signature { + name: "function" + input_arg { + name: "inputs" + type: DT_FLOAT + } + input_arg { + name: "statefulpartitionedcall_args_1" + type: DT_RESOURCE + } + output_arg { + name: "identity" + type: DT_FLOAT + } + is_stateful: true + } + node_def { + name: "Identity" + op: "Identity" + input: "inputs" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_user_specified_name" + value { + s: "inputs" + } + } + } + } + arg_attr { + key: 1 + value { + } + } + } +} +versions { + producer: 121 +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt index 82146716fff..9ce15315832 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt +++ 
b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt @@ -1,209 +1,8 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s node { - name: "Placeholder" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - unknown_rank: true - } - } - } -} -node { - name: "Placeholder_1" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - unknown_rank: true - } - } - } -} -node { - name: "input0" - op: "TPUReplicatedInput" - input: "Placeholder" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "input1" - op: "TPUReplicatedInput" - input: "Placeholder_1" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "cluster/pivot" - op: "NoOp" -} -node { - name: "TPUReplicateMetadata" - op: "TPUReplicateMetadata" - input: "^cluster/pivot" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "computation_shape" - value { - list { - } - } - } - attr { - key: "device_assignment" - value { - list { - } - } - } - attr { - key: "host_compute_core" - value { - list { - } - } - } - attr { - key: "num_cores_per_replica" - value { - i: 1 - } - } - attr { - key: "num_replicas" - value { - i: 1 - } - } - attr { - key: "topology" - value { - s: "" - } - } - attr { - key: "use_tpu" - value { - b: true - } - } -} -node { - name: "replicated_input_0" - op: "Identity" - input: "input0" - input: "^TPUReplicateMetadata" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "replicated_input_1" - op: "Identity" - input: "input1" - input: "^TPUReplicateMetadata" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/maximum_iterations" + name: "Constant" op: "Const" - input: "^TPUReplicateMetadata" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 10 - } - } - } -} -node { - name: "while/iteration_counter" - op: "Const" - input: "^TPUReplicateMetadata" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } attr { key: "dtype" value { @@ -223,1968 +22,37 @@ node { } } node { - name: "while/Enter" - op: "Enter" - input: "while/iteration_counter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while/Enter_1" - op: "Enter" - input: "replicated_input_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while/Enter_2" - op: "Enter" - input: "replicated_input_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - 
} - } - attr { - key: "frame_name" - value { - s: "while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while/Merge" - op: "Merge" - input: "while/Enter" - input: "while/NextIteration" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Merge_1" - op: "Merge" - input: "while/Enter_1" - input: "while/NextIteration_1" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Merge_2" - op: "Merge" - input: "while/Enter_2" - input: "while/NextIteration_2" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Less/Enter" - op: "Enter" - input: "while/maximum_iterations" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "while/Less" - op: "Less" - input: "while/Merge" - input: "while/Less/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/less_than_5_If8q4vKg9jA" - op: "less_than_5_If8q4vKg9jA" - input: "while/Merge_1" - input: "^while/Merge" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/LogicalAnd" - op: "LogicalAnd" - input: "while/Less" - input: "while/less_than_5_If8q4vKg9jA" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/LoopCond" - op: "LoopCond" - input: "while/LogicalAnd" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Switch" - op: "Switch" - input: "while/Merge" - input: "while/LoopCond" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while/Merge" - } - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Switch_1" - op: "Switch" - input: "while/Merge_1" - input: "while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while/Merge_1" - } - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Switch_2" - op: "Switch" - input: "while/Merge_2" - input: "while/LoopCond" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while/Merge_2" - } - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Identity" - op: "Identity" - input: "while/Switch:1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Identity_1" - op: "Identity" - input: "while/Switch_1:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Identity_2" - op: "Identity" - input: 
"while/Switch_2:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/add/y" - op: "Const" - input: "^while/Identity" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "while/add" - op: "Add" - input: "while/Identity" - input: "while/add/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/add_1/y" - op: "Const" - input: "^while/Identity" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "while/add_1" - op: "Add" - input: "while/Identity_1" - input: "while/add_1/y" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/mul_2_Da30D05wlPU" - op: "mul_2_Da30D05wlPU" - input: "while/Identity_1" - input: "while/Identity_2" - input: "^while/Identity" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/NextIteration" - op: "NextIteration" - input: "while/add" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/NextIteration_1" - op: "NextIteration" - input: "while/add_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/NextIteration_2" - op: "NextIteration" - input: "while/mul_2_Da30D05wlPU" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Exit" - op: "Exit" - input: "while/Switch" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Exit_1" - op: "Exit" - input: "while/Switch_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "while/Exit_2" - op: "Exit" - input: "while/Switch_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/Shape" - op: "Shape" - input: "while/Exit_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "out_type" - value { - type: DT_INT32 - } - } -} -node { - name: "gradients/grad_ys_0" - op: "Const" - input: "^TPUReplicateMetadata" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "gradients/Fill" - op: "Fill" - input: "gradients/Shape" - input: "gradients/grad_ys_0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "index_type" - value { - type: DT_INT32 - } - } -} -node { - name: "gradients/f_count" - op: "Const" - 
input: "^TPUReplicateMetadata" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "gradients/f_count_1" - op: "Enter" - input: "gradients/f_count" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "gradients/Merge" - op: "Merge" - input: "gradients/f_count_1" - input: "gradients/NextIteration" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/Switch" - op: "Switch" - input: "gradients/Merge" - input: "while/LoopCond" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/Add/y" - op: "Const" - input: "^while/Identity" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "gradients/Add" - op: "Add" - input: "gradients/Switch:1" - input: "gradients/Add/y" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/f_count_2" - op: "Exit" - input: "gradients/Switch" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/b_count" - op: "Const" - input: "^TPUReplicateMetadata" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "gradients/b_count_1" - op: "Enter" - input: "gradients/f_count_2" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "gradients/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "gradients/Merge_1" - op: "Merge" - input: "gradients/b_count_1" - input: "gradients/NextIteration_1" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/GreaterEqual/Enter" - op: "Enter" - input: "gradients/b_count" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "gradients/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "gradients/GreaterEqual" - op: "GreaterEqual" - input: "gradients/Merge_1" - input: "gradients/GreaterEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr 
{ - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/b_count_2" - op: "LoopCond" - input: "gradients/GreaterEqual" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/Switch_1" - op: "Switch" - input: "gradients/Merge_1" - input: "gradients/b_count_2" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/Sub" - op: "Sub" - input: "gradients/Switch_1:1" - input: "gradients/GreaterEqual/Enter" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/b_count_3" - op: "Exit" - input: "gradients/Switch_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/zeros_like" - op: "ZerosLike" - input: "while/Exit_1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/while/Exit_2_grad/b_exit" - op: "Enter" - input: "gradients/Fill" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "gradients/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "gradients/while/Exit_1_grad/b_exit" - op: "Enter" - input: "gradients/zeros_like" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "gradients/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: false - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "gradients/while/Switch_2_grad/b_switch" - op: "Merge" - input: "gradients/while/Exit_2_grad/b_exit" - input: "gradients/while/Switch_2_grad_1/NextIteration" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/while/Merge_2_grad/Switch" - op: "Switch" - input: "gradients/while/Switch_2_grad/b_switch" - input: "gradients/b_count_2" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/while/Switch_2_grad/b_switch" - } - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/while/Enter_2_grad/Exit" - op: "Exit" - input: "gradients/while/Merge_2_grad/Switch" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/Const" - op: "Const" - input: "^cluster/pivot" - attr { - key: "_class" - value { - list { - s: "loc:@while/Identity_1" - } - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/mul" - op: "Mul" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/Const" - input: 
"while/maximum_iterations" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while/Identity_1" - } - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/f_acc" - op: "StackV2" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/mul" - attr { - key: "_class" - value { - list { - s: "loc:@while/Identity_1" - } - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "elem_type" - value { - type: DT_FLOAT - } - } - attr { - key: "stack_name" - value { - s: "" - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/Enter" - op: "Enter" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/f_acc" - attr { - key: "T" - value { - type: DT_RESOURCE - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPushV2" - op: "StackPushV2" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/Enter" - input: "while/Identity_1" - input: "^gradients/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "swap_memory" - value { - b: false - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPopV2/Enter" - op: "Enter" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/f_acc" - attr { - key: "T" - value { - type: DT_RESOURCE - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "gradients/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPopV2" - op: "StackPopV2" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPopV2/Enter" - input: "^gradients/Sub" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "elem_type" - value { - type: DT_FLOAT - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/Const_1" - op: "Const" - input: "^cluster/pivot" - attr { - key: "_class" - value { - list { - s: "loc:@while/Identity_2" - } - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/mul_1" - op: "Mul" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/Const_1" - input: "while/maximum_iterations" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_class" - value { - list { - s: "loc:@while/Identity_2" - } - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/f_acc_1" - op: "StackV2" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/mul_1" - attr { - key: "_class" - value { - list { - s: "loc:@while/Identity_2" - } - } - } - 
attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "elem_type" - value { - type: DT_FLOAT - } - } - attr { - key: "stack_name" - value { - s: "" - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/Enter_1" - op: "Enter" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/f_acc_1" - attr { - key: "T" - value { - type: DT_RESOURCE - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPushV2_1" - op: "StackPushV2" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/Enter_1" - input: "while/Identity_2" - input: "^gradients/Add" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "swap_memory" - value { - b: false - } - } -} -node { - name: "gradients/NextIteration" - op: "NextIteration" - input: "gradients/Add" - input: "^gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPushV2" - input: "^gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPushV2_1" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPopV2_1/Enter" - op: "Enter" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/f_acc_1" - attr { - key: "T" - value { - type: DT_RESOURCE - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "frame_name" - value { - s: "gradients/while/while_context" - } - } - attr { - key: "is_constant" - value { - b: true - } - } - attr { - key: "parallel_iterations" - value { - i: 10 - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPopV2_1" - op: "StackPopV2" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPopV2_1/Enter" - input: "^gradients/Sub" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "elem_type" - value { - type: DT_FLOAT - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient" - op: "SymbolicGradient" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPopV2" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPopV2_1" - input: "gradients/while/Merge_2_grad/Switch:1" - input: "^gradients/Sub" - attr { - key: "Tin" - value { - list { - type: DT_FLOAT - type: DT_FLOAT - type: DT_FLOAT - } - } - } - attr { - key: "Tout" - value { - list { - type: DT_FLOAT - type: DT_FLOAT - } - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - attr { - key: "f" - value { - func { - name: "mul_2_Da30D05wlPU" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } - } - } - } -} -node { - name: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync" - op: "ControlTrigger" - input: "^cluster/pivot" - input: "^gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPopV2" - input: "^gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/StackPopV2_1" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/NextIteration_1" - op: "NextIteration" - input: "gradients/Sub" - input: 
"^gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient/b_sync" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "gradients/while/Switch_2_grad_1/NextIteration" - op: "NextIteration" - input: "gradients/while/mul_2_Da30D05wlPU_grad/SymbolicGradient:1" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "NoOp" - op: "NoOp" - input: "^cluster/pivot" - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "Identity" - op: "Identity" - input: "gradients/while/Enter_2_grad/Exit" - device: "/device:TPU_REPLICATED_CORE:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_tpu_replicate" - value { - s: "cluster" - } - } -} -node { - name: "output0" - op: "TPUReplicatedOutput" - input: "Identity" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "num_replicas" - value { - i: 1 - } - } -} -node { - name: "TPUCompilationResult" - op: "TPUCompilationResult" - input: "^TPUReplicateMetadata" - attr { - key: "_tpu_compilation_status" - value { - s: "cluster" - } - } -} -node { - name: "output_0_shard_0" - op: "Identity" - input: "output0" - input: "^NoOp" - attr { - key: "T" - value { - type: DT_FLOAT - } - } -} -node { - name: "ConfigureDistributedTPU" - op: "ConfigureDistributedTPU" - device: "/device:TPU_SYSTEM:0" - attr { - key: "embedding_config" - value { - s: "" - } - } - attr { - key: "is_global_init" - value { - b: false - } - } - attr { - key: "tpu_embedding_config" - value { - s: "" - } - } + name: "_tf.foo" + op: "foo" + input: "Constant" } library { function { signature { - name: "mul_2_Da30D05wlPU" + name: "foo" input_arg { - name: "mul_2_da30d05wlpu" - type: DT_FLOAT - } - input_arg { - name: "mul_2_da30d05wlpu1" - type: DT_FLOAT + name: "arg" + type: DT_INT32 } output_arg { - name: "mul_2_da30d05wlpu2" - type: DT_FLOAT - } - } - node_def { - name: "mul/y" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - dim { - size: 1 - } - dim { - size: 1 - } - } - float_val: 2 - } - } - } - } - node_def { - name: "mul_0" - op: "Mul" - input: "mul_2_da30d05wlpu1" - input: "mul/y:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } + name: "return_value" + type: DT_INT32 } } ret { - key: "mul_2_da30d05wlpu2" - value: "mul_0:z:0" - } - attr { - key: "_noinline" - value { - b: true - } - } - } - function { - signature { - name: "less_than_5_If8q4vKg9jA" - input_arg { - name: "less_than_5_if8q4vkg9ja" - type: DT_FLOAT - } - output_arg { - name: "less_than_5_if8q4vkg9ja1" - type: DT_BOOL - } - } - node_def { - name: "Less/y" - op: "Const" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 5 - } - } - } - } - node_def { - name: "Less" - op: "Less" - input: "less_than_5_if8q4vkg9ja" - input: "Less/y:output:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - } - ret { - key: "less_than_5_if8q4vkg9ja1" - value: "Less:z:0" - } - attr { - key: "_noinline" - value { - b: true - } + key: "return_value" + value: "arg" } } } versions { - producer: 27 + producer: 62 min_consumer: 12 } -# CHECK: func @main() { -# CHECK: %30:2 = "_tf.less_than_5_If8q4vKg9jA0"(%23#0, %29#2) {_tpu_replicate = "cluster", device = "", name = 
"while/less_than_5_If8q4vKg9jA"} : (tensor<*xf32>, !_tf.control) -> (tensor<*xi1>, !_tf.control) -# CHECK: %73:2 = "_tf.mul_2_Da30D05wlPU0"(%58#0, %72#0, %47#1) {_tpu_replicate = "cluster", device = "", name = "while/mul_2_Da30D05wlPU"} : (tensor<*xf32>, tensor<*xf32>, !_tf.control) -> (tensor<*xf32>, !_tf.control) -# CHECK: return -# CHECK-NEXT: } -# CHECK: func @less_than_5_If8q4vKg9jA0(%arg0: tensor<*xf32>) -> tensor<*xi1> -# CHECK-NEXT: attributes {tf._noinline = true} { -# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Less/y", value = dense<5.000000e+00> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Less"(%arg0, %0#0) {T = "tfdtype$DT_FLOAT", device = "", name = "Less"} : (tensor<*xf32>, tensor) -> (tensor<*xi1>, !_tf.control) -# CHECK-NEXT: return %1#0 : tensor<*xi1> -# CHECK-NEXT: } -# CHECK: func @mul_2_Da30D05wlPU0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> -# CHECK-NEXT: attributes {tf._noinline = true} { -# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "mul/y", value = dense<2.000000e+00> : tensor<1x1xf32>} : () -> (tensor<1x1xf32>, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Mul"(%arg1, %0#0) {T = "tfdtype$DT_FLOAT", device = "", name = "mul_0"} : (tensor<*xf32>, tensor<1x1xf32>) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: return %1#0 : tensor<*xf32> -# CHECK-NEXT: } +# Verify that we can import a custom operation that maps to a function and that +# the names are matching between the function definition and the uses / call +# site (a numerical suffix may be appended). + +# CHECK: "tf.foo0"( +# CHECK: func @foo0 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt index 46682ab866e..b26d7e7f2ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-default-attr.pbtxt @@ -1,7 +1,15 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s -# CHECK: %3:2 = "_tf.Conv2D"(%2#0, %1#0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], name = "MobilenetV1/MobilenetV1/Conv2d_0/Conv2D", padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} -# CHECK-NEXT: %4:2 = "_tf.MaxPool"(%3#0) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", device = "", ksize = [1, 2, 2, 1], name = "MaxPool", padding = "SAME", strides = [1, 2, 2, 1]} +# Verify that the data_format attributes is pulled from the default value in the +# registry when not present in the GraphDef +# CHECK: tf.Conv2D +# CHECK-SAME: data_format = "NHWC" + +# Verify that we can also pull some attributes that are needed to be able to +# create a Graph in memory, like `T`. 
+# CHECK: tf.MaxPool +# CHECK-SAME: T = "tfdtype$DT_FLOAT" + node { name: "input" op: "Placeholder" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-device-retval.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-device-retval.pbtxt index fcd0e62ab63..157db7d5331 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-device-retval.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-device-retval.pbtxt @@ -74,6 +74,9 @@ library { } # The attribute "experimental_ints_on_device" and the return type INT32 # ensure that kDeviceRetOp is used instead of kRetOp + # CHECK-LABEL: func @foo + # CHECK: tf.experimental_ints_on_device = true + # CHECK: return %{{.*}} tensor attr { key: "experimental_ints_on_device" value { @@ -87,13 +90,3 @@ versions { min_consumer: 12 } -# CHECK: func @main() { -# CHECK-NEXT: %0:2 = "_tf.PartitionedCall"() {Tin = [], Tout = ["tfdtype$DT_INT32"], config = "", config_proto = "", device = "", executor_type = "", f = @foo0, name = "PartitionedCall"} : () -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK: func @foo0() -> tensor -# CHECK-NEXT: attributes {tf.experimental_ints_on_device = true} { -# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<5> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Identity"(%0#0) {T = "tfdtype$DT_INT32", device = "", name = "Identity"} : (tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: return %1#0 : tensor -# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt index 441eca84e7e..12d05c1195f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-empty-tensor-content.pbtxt @@ -1,6 +1,9 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s -# CHECK: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Const", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F464C4F41540A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20310A20207D0A7D0A"> : tensor<1xf32>} : () -> (tensor<1xf32>, !_tf.control) +# This test is intended to verify the tensor_content field on import of an empty +# tensor. 
+# CHECK: tf.Const +# CHECK-SAME: value = dense<0.000000e+00> node { name: "Const" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-func-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-func-attr.pbtxt index e8b9ce86ddb..0176edb4b21 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-func-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-func-attr.pbtxt @@ -1,5 +1,13 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# CHECK-LABEL: func @main() { + +# Verify that the NameAttrList is properly turned into reference to functions on import +# CHECK: tf.Case +# CHECK-SAME: branches = [@[[FOO:[a-z0-9]+]], @[[BAR:[a-z0-9]+]]] +# CHECK-DAG: func @[[FOO]]() +# CHECK-DAG: func @[[BAR]]() + node { name: "predicate" op: "Const" @@ -152,16 +160,3 @@ versions { min_consumer: 12 } -# CHECK: func @main() { -# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "predicate", value = dense<0> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Case"(%0#0) {Tin = [], Tout = ["tfdtype$DT_FLOAT"], branches = [@foo0, @bar0], device = "", name = "Case", output_shapes = []} : (tensor) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK: func @foo0() -> tensor<10xf32> { -# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "const_1", value = dense<1.000000e+00> : tensor<10xf32>} : () -> (tensor<10xf32>, !_tf.control) -# CHECK-NEXT: return %0#0 : tensor<10xf32> -# CHECK-NEXT: } -# CHECK: func @bar0() -> tensor<10xf32> { -# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "const_2", value = dense<2.000000e+00> : tensor<10xf32>} : () -> (tensor<10xf32>, !_tf.control) -# CHECK-NEXT: return %0#0 : tensor<10xf32> -# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-diff-island.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-diff-island.pbtxt new file mode 100644 index 00000000000..9238ea92a20 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-diff-island.pbtxt @@ -0,0 +1,111 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s --dump-input=fail + +# Verify for functions with control return values, the island with only a +# consumed control return value has its control output added to the GraphOps +# FetchOp. 
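The control_output / control_ret entries exercised by these new tests typically come from automatic control dependencies: a stateful op traced inside a tf.function is recorded as a control output of the function so it cannot be pruned. The Python sketch below is illustrative only (it assumes TF 2.x tf.function; the variable and function names are made up and do not come from this patch).

# Illustrative only: the stateful assign becomes a control return of the
# serialized FunctionDef, analogous to "must_execute" in the test above.
import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def test_fn(a):
    v.assign_add(a)        # stateful: must execute even though its result is unused
    return tf.identity(a)  # ordinary data output

cf = test_fn.get_concrete_function(tf.TensorSpec([], tf.float32))
# The traced graph contains the stateful AssignAddVariableOp; when the function
# is exported, that op is listed under control_ret so importers (like the MLIR
# importer tested here) know it has to be fetched.
print(sorted(op.type for op in cf.graph.get_operations()))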
+ +# Match the island containing the "tf.Neg", capture the output +# CHECK: %[[ISLAND_0:[0-9]*]]:2 = tf_executor.island {{.*[[:space:]].*}} "tf.Neg" + +# Check that the tf.Neg control is passed to the fetch +# CHECK: tf_executor.fetch {{.*}} %[[ISLAND_0]]#1 : tensor<*xf32>, !tf_executor.control + +node { + name: "const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "test_fn_call" + op: "StatefulPartitionedCall" + input: "const" + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "f" + value { + func { + name: "test_fn" + } + } + } +} +library { + function { + signature { + name: "test_fn" + input_arg { + name: "a" + type: DT_FLOAT + } + output_arg { + name: "d" + type: DT_FLOAT + } + control_output: "must_execute" + } + node_def { + name: "b" + op: "Neg" + input: "a" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "c" + op: "Identity" + input: "a" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "d" + value: "c:output:0" + } + control_ret { + key: "must_execute" + value: "b" + } + } +} +versions { + producer: 121 +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-same-island.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-same-island.pbtxt new file mode 100644 index 00000000000..adad8b109b6 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-control-ret-same-island.pbtxt @@ -0,0 +1,100 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s --dump-input=fail + +# Verify for functions with control return values, the island with a consumed +# data output and a consumed control has both its outputs added to the GraphOps +# FetchOp. 
+ +# Match the island containing the "tf.Neg", capture the output +# CHECK: %[[ISLAND:[0-9]*]]:2 = tf_executor.island {{.*[[:space:]].*}} "tf.Neg" + +# Check that the tf.Neg data output and control are passed to the fetch +# CHECK: tf_executor.fetch %[[ISLAND]]#0, %[[ISLAND]]#1 : tensor<*xf32>, !tf_executor.control + +node { + name: "const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0 + } + } + } +} +node { + name: "test_fn_call" + op: "StatefulPartitionedCall" + input: "const" + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "f" + value { + func { + name: "test_fn" + } + } + } +} +library { + function { + signature { + name: "test_fn" + input_arg { + name: "a" + type: DT_FLOAT + } + output_arg { + name: "c" + type: DT_FLOAT + } + control_output: "must_execute" + } + node_def { + name: "b" + op: "Neg" + input: "a" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "c" + value: "b:y:0" + } + control_ret { + key: "must_execute" + value: "b" + } + } +} +versions { + producer: 121 +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-defs.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-defs.pbtxt index 40392a6954a..6a2a411d115 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-defs.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-defs.pbtxt @@ -1,5 +1,11 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# Verify that we properly import call site function attributes. 
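The call-site attributes checked below (then_branch, plus the hand-written then_branch.how_many / then_branch.ping extras, which have no public Python API) hang off the NameAttrList that a functional If node uses to reference its branch functions. A hypothetical sketch of where such an If/StatelessIf node comes from is shown here; it assumes TF 2.x control flow v2 and is not taken from this patch.

# Illustrative only: tf.cond inside tf.function produces an If/StatelessIf node
# whose then_branch / else_branch attributes reference the branch functions.
import tensorflow as tf

@tf.function
def f(pred, x, y):
    return tf.cond(pred, lambda: x + y, lambda: x * y)

graph = f.get_concrete_function(
    tf.TensorSpec([], tf.bool),
    tf.TensorSpec([], tf.float32),
    tf.TensorSpec([], tf.float32)).graph
if_op = next(op for op in graph.get_operations() if op.type in ("If", "StatelessIf"))
print(if_op.type, if_op.get_attr("then_branch").name)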
+# CHECK: tf.If +# CHECK-SAME: then_branch = @ +# CHECK-SAME: then_branch.how_many = 32 +# CHECK-SAME: then_branch.ping = "ack" + node { name: "Placeholder" op: "Placeholder" @@ -503,36 +509,3 @@ versions { producer: 27 min_consumer: 12 } - -# CHECK: func @main() { -# CHECK-NEXT: %0:2 = "_tf.ConfigureDistributedTPU"() {device = "/device:TPU_SYSTEM:0", embedding_config = "", is_global_init = false, name = "ConfigureDistributedTPU", tpu_embedding_config = ""} : () -> (tensor<*x!tf.string>, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_INT32", name = "Placeholder", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %2:2 = "_tf.TPUReplicatedInput"(%1#0) {N = 1 : i64, T = "tfdtype$DT_INT32", device = "", name = "input0"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %3:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_INT32", name = "Placeholder_1", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %4:2 = "_tf.TPUReplicatedInput"(%3#0) {N = 1 : i64, T = "tfdtype$DT_INT32", device = "", name = "input1"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %5 = "_tf.NoOp"() {device = "", name = "cluster/pivot"} : () -> !_tf.control -# CHECK-NEXT: %6 = "_tf.NoOp"(%5) {_tpu_replicate = "cluster", device = "", name = "NoOp"} : (!_tf.control) -> !_tf.control -# CHECK-NEXT: %7 = "_tf.TPUReplicateMetadata"(%5) {_tpu_replicate = "cluster", computation_shape = [], device = "", device_assignment = [], host_compute_core = [], name = "TPUReplicateMetadata", num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", use_tpu = true} : (!_tf.control) -> !_tf.control -# CHECK-NEXT: %8:2 = "_tf.TPUCompilationResult"(%7) {_tpu_compilation_status = "cluster", device = "", name = "TPUCompilationResult"} : (!_tf.control) -> (tensor, !_tf.control) -# CHECK-NEXT: %9:2 = "_tf.Identity"(%2#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %10:2 = "_tf.Identity"(%4#0, %7) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "replicated_input_1"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %11:2 = "_tf.Less"(%9#0, %10#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "", name = "Less"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi1>, !_tf.control) -# CHECK-NEXT: %12:3 = "_tf.If"(%11#0, %10#0, %9#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32", "tfdtype$DT_INT32"], _tpu_replicate = "cluster", device = "", else_branch = @cond_false0, name = "cond", output_shapes = ["tfshape$unknown_rank: true\0A", "tfshape$unknown_rank: true\0A"], then_branch = @cond_true0, then_branch.how_many = 32 : i64, then_branch.ping = "ack"} : (tensor<*xi1>, tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %13:2 = "_tf.Identity"(%12#0) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "/device:TPU_REPLICATED_CORE:0", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %14:2 = "_tf.TPUReplicatedOutput"(%13#0) {T = "tfdtype$DT_INT32", device = "", name = "output0", num_replicas = 1 : i64} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %15:2 = 
"_tf.Identity"(%14#0, %6) {T = "tfdtype$DT_INT32", device = "", name = "output_0_shard_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %16:2 = "_tf.Identity"(%12#1) {T = "tfdtype$DT_INT32", _tpu_replicate = "cluster", device = "/device:TPU_REPLICATED_CORE:0", name = "Identity_1"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %17:2 = "_tf.TPUReplicatedOutput"(%16#0) {T = "tfdtype$DT_INT32", device = "", name = "output1", num_replicas = 1 : i64} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %18:2 = "_tf.Identity"(%17#0, %6) {T = "tfdtype$DT_INT32", device = "", name = "output_1_shard_0"} : (tensor<*xi32>, !_tf.control) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK: func @cond_false0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) { -# CHECK-NEXT: %0:2 = "_tf.Identity"(%arg0) {T = "tfdtype$DT_INT32", device = "", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Identity"(%arg1) {T = "tfdtype$DT_INT32", device = "", name = "Identity_1"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: return %1#0, %0#0 : tensor<*xi32>, tensor<*xi32> -# CHECK-NEXT: } -# CHECK: func @cond_true0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) { -# CHECK-NEXT: %0:2 = "_tf.Identity"(%arg0) {T = "tfdtype$DT_INT32", device = "", name = "Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Identity"(%arg1) {T = "tfdtype$DT_INT32", device = "", name = "Identity_1"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: return %0#0, %1#0 : tensor<*xi32>, tensor<*xi32> -# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-static-output.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-static-output.pbtxt index 41107cfbff4..e0e60c04865 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-static-output.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-static-output.pbtxt @@ -1,5 +1,9 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# Verify that the return type of the functions is properly inferred +#CHECK: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> +#CHECK: func @identity0(%arg0: tensor<*xi32>) -> tensor<*xi32> + node { name: "Placeholder" op: "Placeholder" @@ -139,16 +143,3 @@ versions { min_consumer: 12 } -#CHECK: func @main() { -#CHECK-NEXT: %0:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_BOOL", name = "Placeholder", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi1>, !_tf.control) -#CHECK-NEXT: %1:2 = "_tf.Placeholder"() {device = "", dtype = "tfdtype$DT_INT32", name = "Placeholder_1", shape = "tfshape$unknown_rank: true\0A"} : () -> (tensor<*xi32>, !_tf.control) -#CHECK-NEXT: %2:2 = "_tf.If"(%0#0, %1#0) {Tcond = "tfdtype$DT_BOOL", Tin = ["tfdtype$DT_INT32"], Tout = ["tfdtype$DT_INT32"], device = "", else_branch = @get_zeros0, name = "If", output_shapes = [], then_branch = @identity0} : (tensor<*xi1>, tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) -#CHECK-NEXT: return -#CHECK-NEXT: } -#CHECK: func @get_zeros0(%arg0: tensor<*xi32>) -> tensor<2xi32> { -#CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "const", value = dense<[1, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>, !_tf.control) -#CHECK-NEXT: return %0#0 : 
tensor<2xi32> -#CHECK-NEXT: } -#CHECK: func @identity0(%arg0: tensor<*xi32>) -> tensor<*xi32> { -#CHECK-NEXT: return %arg0 : tensor<*xi32> -#CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-gradient-def.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-gradient-def.pbtxt index c1045bf19af..b7179ae1dcc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-gradient-def.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-gradient-def.pbtxt @@ -1,5 +1,12 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# In GraphDef custom gradient functions are modeled using GradientDef which +# links the function and its gradient. In MLIR a TF ops gradient function is +# added to its list of function attributes. + +# CHECK: func @foo0( +# CHECK-NEXT: tf.gradient = @foo_grad + node { name: "Const" op: "Const" @@ -269,26 +276,3 @@ versions { producer: 29 min_consumer: 12 } - -# CHECK: func @main() { -# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "Const", value = dense<2.500000e-01> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.foo0"(%0#0) {_disable_call_shape_inference = true, device = "", name = "foo"} : (tensor) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %2:2 = "_tf.Shape"(%1#0) {T = "tfdtype$DT_FLOAT", device = "", name = "gradients/Shape", out_type = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> (tensor, !_tf.control) -# CHECK-NEXT: %3:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "gradients/grad_ys_0", value = dense<1.000000e+00> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %4:2 = "_tf.Fill"(%2#0, %3#0) {T = "tfdtype$DT_FLOAT", device = "", index_type = "tfdtype$DT_INT32", name = "gradients/Fill"} : (tensor, tensor) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %5:2 = "_tf.SymbolicGradient"(%0#0, %4#0) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], device = "", f = @foo0, f._disable_call_shape_inference = true, name = "gradients/foo_grad/SymbolicGradient"} : (tensor, tensor<*xf32>) -> (tensor, !_tf.control) -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK: func @foo_grad0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> -# CHECK-NEXT: attributes {tf._disable_call_shape_inference = true} { -# CHECK-NEXT: %0:2 = "_tf.Mul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", device = "", name = "mul_0"} : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: return %0#0 : tensor<*xf32> -# CHECK-NEXT: } -# CHECK: func @foo0(%arg0: tensor<*xf32>) -> tensor<*xf32> -# CHECK-NEXT: attributes {tf._disable_call_shape_inference = true, tf.gradient = @foo_grad0} { -# CHECK-NEXT: %0:2 = "_tf.Exp"(%arg0) {T = "tfdtype$DT_FLOAT", device = "", name = "Exp"} : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Neg"(%arg0) {T = "tfdtype$DT_FLOAT", device = "", name = "Neg"} : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %2:2 = "_tf.Exp"(%1#0) {T = "tfdtype$DT_FLOAT", device = "", name = "Exp_1"} : (tensor<*xf32>) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: %3:2 = "_tf.Sub"(%0#0, %2#0) {T = "tfdtype$DT_FLOAT", device = "", name = "sub_0"} : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, !_tf.control) -# CHECK-NEXT: return %3#0 : tensor<*xf32> -# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-functional-while-loop.pbtxt 
b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-input-func-arg-name-collision.pbtxt similarity index 64% rename from tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-functional-while-loop.pbtxt rename to tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-input-func-arg-name-collision.pbtxt index 456bf4951bd..ba94c600cf2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-functional-while-loop.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-input-func-arg-name-collision.pbtxt @@ -1,5 +1,12 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_INT32 -tf-input-shapes='' -tf-output-arrays=while:2 -o - | FileCheck %s +# This check that we don't error out when importing GraphDef containing +# functions with arg name that are the same as the graph input name + +# CHECK: func @main(%arg0: tensor) -> tensor +# CHECK: func @while_body +# CHECK: func @while_cond + node { name: "input" op: "Placeholder" @@ -295,23 +302,3 @@ versions { min_consumer: 12 } -# CHECK: func @main(%arg0: tensor) -> tensor -# CHECK-NEXT: attributes {tf.entry_function = {inputs = "input", outputs = "while"}} { -# CHECK-NEXT: %0:2 = "_tf.Placeholder.input"(%arg0) {_user_specified_name = "input", device = "", dtype = "tfdtype$DT_INT32", name = "input", shape = "tfshape$"} : (tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/loop_counter", value = dense<0> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %2:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "while/maximum_iterations", value = dense<-1> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %3:4 = "_tf.While"(%1#0, %2#0, %0#0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_INT32", "tfdtype$DT_INT32"], _lower_using_switch_merge = true, body = @while_body_60, cond = @while_cond_50, device = "", name = "while", output_shapes = ["tfshape$", "tfshape$", "tfshape$"], parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor, !_tf.control) -# CHECK-NEXT: return %3#2 : tensor -# CHECK-NEXT: } -# CHECK: func @while_body_60(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) { -# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Add/y", value = dense<1> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "add_1/y", value = dense<1> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %2:2 = "_tf.Add"(%arg2, %0#0) {T = "tfdtype$DT_INT32", device = "", name = "Add"} : (tensor<*xi32>, tensor) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %3:2 = "_tf.Add"(%arg0, %1#0) {T = "tfdtype$DT_INT32", device = "", name = "add_1"} : (tensor<*xi32>, tensor) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: return %3#0, %arg1, %2#0 : tensor<*xi32>, tensor<*xi32>, tensor<*xi32> -# CHECK-NEXT: } -# CHECK: func @while_cond_50(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>) -> tensor<*xi1> { -# CHECK-NEXT: %0:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Less/y", value = dense<10> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Less"(%arg2, %0#0) {T = "tfdtype$DT_INT32", device = "", name = "Less"} : (tensor<*xi32>, tensor) -> (tensor<*xi1>, !_tf.control) -# CHECK-NEXT: return %1#0 : tensor<*xi1> -# CHECK-NEXT: } diff 
--git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt index 83ca4466869..17b2655aa5d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt @@ -36,15 +36,13 @@ versions { min_consumer: 12 } -# CHECK: func @main() { -# CHECK-NEXT: %0 = "_tf.foo0"() {device = "", name = "unnamed"} : () -> !_tf.control -# CHECK-NEXT: %1 = "_tf.bar0"() {device = "", name = "unnamed1"} : () -> !_tf.control -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK: func @foo0() { -# CHECK-NEXT: %0 = "_tf.bar0"() {device = "", name = "unnamed"} : () -> !_tf.control -# CHECK-NEXT: return -# CHECK-NEXT: } -# CHECK: func @bar0() { -# CHECK-NEXT: return -# CHECK-NEXT: } +# Verify that functions from the library are properly imported. + +# CHECK-LABEL: func @main() { +# CHECK: "tf.foo0"() +# CHECK: "tf.bar0"() + +# CHECK-LABEL: func @foo0() { +# CHECK: "tf.bar0"() + +# CHECK-LABEL: func @bar0() { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt index 97e22256495..0a5aba285dc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-malformed.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0 +# RUN: tf-mlir-translate -graphdef-to-mlir %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0 this is not a valid graph def diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt index daef0054fd6..37f7a876814 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-scalar-input.pbtxt @@ -1,5 +1,16 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=out:1,out -o - | FileCheck %s +# Verify that we match correctly the input / output when they are scalar. 
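The -tf-output-arrays=out:1,out flag in the RUN line uses TensorFlow's usual "node_name:output_index" convention. The sketch below is illustrative only and assumes the tf.compat.v1 API; it rebuilds the shape of this test's graph (scalar Placeholder -> Relu -> IdentityN named "out") so that both "out:0" and "out:1" are addressable outputs.

# Illustrative only: addressing individual outputs of a multi-output node by
# "name:index", the same convention the translator flags use.
import tensorflow.compat.v1 as tf

g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.float32, shape=[], name="input")
    r = tf.nn.relu(x)
    tf.identity_n([r, r], name="out")  # two outputs: out:0 and out:1

print(g.get_tensor_by_name("out:0"), g.get_tensor_by_name("out:1"))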
+ +# CHECK: func @main(%arg0: tensor) -> (tensor, tensor) +# CHECK-NEXT: attributes {tf.entry_function = {inputs = "input", outputs = "out"}} { +# CHECK: "tf.Placeholder.input"(%arg0) + +# CHECK: tf.Relu +# CHECK: %[[IDENTITY:[0-9]+]]:3 = tf_executor.island +# CHECK-NEXT: tf.Identity +# CHECK: fetch %[[IDENTITY]]#1, %[[IDENTITY]]#0 : tensor, tensor + node { name: "input" op: "Placeholder" @@ -52,11 +63,3 @@ node { versions { producer: 27 } - -# CHECK: func @main(%arg0: tensor) -> (tensor, tensor) -# CHECK-NEXT: attributes {tf.entry_function = {inputs = "input", outputs = "out"}} { -# CHECK-NEXT: %0:2 = "_tf.Placeholder.input"(%arg0) {device = "/device:CPU:0", dtype = "tfdtype$DT_FLOAT", name = "input", shape = "tfshape$"} : (tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Relu"(%0#0) {T = "tfdtype$DT_FLOAT", device = "/device:CPU:0", name = "Relu"} : (tensor) -> (tensor, !_tf.control) -# CHECK-NEXT: %2:3 = "_tf.IdentityN"(%1#0, %1#0) {T = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], device = "", name = "out"} : (tensor, tensor) -> (tensor, tensor, !_tf.control) -# CHECK-NEXT: return %2#1, %2#0 : tensor, tensor -# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt index 32b816f5e39..9ae5601fa57 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-uint8-return.pbtxt @@ -104,8 +104,8 @@ versions { } # CHECK: func @main -# CHECK: "_tf.PartitionedCall"() +# CHECK: "tf.PartitionedCall"() # CHECK-SAME: Tout = ["tfdtype$DT_UINT8"] # CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]] # CHECK: func @[[FUNCTION]]() -> tensor -# CHECK: return {{%[0-9]*#[0-9]*}} : tensor +# CHECK: return {{.*}} : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt index 4fa8407c0dd..6816088322d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-undefined-output.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=NotANodeInTheGraph -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0 +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=NotANodeInTheGraph -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0 # CHECK: Graph import failed: Invalid argument: Output NotANodeInTheGraph was not found in graph diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt index 5f8e7854161..20bf33d7fb2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-version-info.pbtxt @@ -29,7 +29,6 @@ node { size: 2 } } - tensor_content: "\350\251\242>\276\335r?" 
} } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt index ac84234e4ac..4ada2f6f71c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-while-loop.pbtxt @@ -1,5 +1,14 @@ # RUN: tf-mlir-translate -graphdef-to-mlir -mlir-print-debuginfo %s -o - | FileCheck %s +# Verify that importing a Graph with a backedge leads to two NextIteration nodes +# to break the cycle. + +# CHECK-LABEL: func @main() +# CHECK: %[[NEXTITERATION:[0-9]+]]:3 = tf_executor.NextIteration.Source +# CHECK: tf_executor.Merge {{.*}} %[[NEXTITERATION]]#0 + +# CHECK: tf_executor.NextIteration.Sink [%[[NEXTITERATION]]#1] + node { name: "Const" op: "Const" @@ -203,20 +212,3 @@ versions { producer: 27 } -# CHECK: func @main() { -# CHECK-NEXT: %0:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", name = "while/NextIteration"} : () -> (tensor<*xi32>, !_tf.control) loc("while/NextIteration") -# CHECK-NEXT: %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<0> : tensor} : () -> (tensor, !_tf.control) loc("Const") -# CHECK-NEXT: %2:2 = "_tf.Enter"(%1#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while/while_context", is_constant = false, name = "while/Enter", parallel_iterations = 10 : i64} : (tensor) -> (tensor<*xi32>, !_tf.control) loc("while/Enter") -# CHECK-NEXT: %3:3 = "_tf.Merge"(%2#0, %0#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "while/Merge"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor, !_tf.control) loc("while/Merge") -# CHECK-NEXT: %4:2 = "_tf.Const"(%3#2) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Less/y", value = dense<10> : tensor} : (!_tf.control) -> (tensor, !_tf.control) loc("while/Less/y") -# CHECK-NEXT: %5:2 = "_tf.Less"(%3#0, %4#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Less"} : (tensor<*xi32>, tensor) -> (tensor<*xi1>, !_tf.control) loc("while/Less") -# CHECK-NEXT: %6:2 = "_tf.LoopCond"(%5#0) {device = "", name = "while/LoopCond"} : (tensor<*xi1>) -> (tensor, !_tf.control) loc("while/LoopCond") -# CHECK-NEXT: %7:3 = "_tf.Switch"(%3#0, %6#0) {T = "tfdtype$DT_INT32", _class = ["loc:@while/Merge"], device = "", name = "while/Switch"} : (tensor<*xi32>, tensor) -> (tensor<*xi32>, tensor<*xi32>, !_tf.control) loc("while/Switch") -# CHECK-NEXT: %8:2 = "_tf.Exit"(%7#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Exit"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) loc("while/Exit") -# CHECK-NEXT: %9:2 = "_tf.Identity"(%7#1) {T = "tfdtype$DT_INT32", device = "", name = "while/Identity"} : (tensor<*xi32>) -> (tensor<*xi32>, !_tf.control) loc("while/Identity") -# CHECK-NEXT: %10:2 = "_tf.Const"(%9#1) {device = "", dtype = "tfdtype$DT_INT32", name = "while/Add/y", value = dense<1> : tensor} : (!_tf.control) -> (tensor, !_tf.control) loc("while/Add/y") -# CHECK-NEXT: %11:2 = "_tf.Add"(%9#0, %10#0) {T = "tfdtype$DT_INT32", device = "", name = "while/Add"} : (tensor<*xi32>, tensor) -> (tensor<*xi32>, !_tf.control) loc("while/Add") -# CHECK-NEXT: %12 = "_tf.NextIteration.sink"(%11#0) {T = "tfdtype$DT_INT32", device = "", name = "while/NextIteration"} : (tensor<*xi32>) -> !_tf.control loc("while/NextIteration") -# CHECK-NEXT: return loc(unknown) -# CHECK-NEXT: } - diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt 
b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt new file mode 100644 index 00000000000..6fec080be58 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/invalid-output-index.pbtxt @@ -0,0 +1,14 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input:1 -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0 + +# CHECK: Graph import failed: Invalid argument: Invalid output index 1 specified for node: input + +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/legacy-fed-input-without-inputs.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/legacy-fed-input-without-inputs.pbtxt new file mode 100644 index 00000000000..c6d00a6f337 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/legacy-fed-input-without-inputs.pbtxt @@ -0,0 +1,30 @@ +# RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-input-arrays=input -tf-input-data-types=DT_FLOAT -tf-input-shapes='' -tf-output-arrays=input -tf-convert-legacy-fed-inputs -o - | FileCheck %s + +# Verify that invalid LegacyFedInput ops without any inputs are replaced with +# Placeholder ops. + +# CHECK-NOT: LegacyFedInput +# CHECK: tf.Placeholder.input{{.*}}(tensor) -> tensor +# CHECK-NOT: LegacyFedInput + +node { + name: "input" + op: "LegacyFedInput" + attr { + key: "input_def" + value { + s: "name: \"batch_1\"\n[dist_belief.ImageInputDef.ext] {\n num_rows: 128\n num_cols: 128\n mean_value: 128\n std_value: 128\n colorspace: RGB\n}\n" + } + } + attr { + key: "output_types" + value { + list { + type: DT_FLOAT + } + } + } +} +versions { + producer: 27 +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt index 6baa4973407..09a900e8917 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/multiple-use-next-iteration.pbtxt @@ -1,5 +1,13 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s -o - | FileCheck %s +# Verify that a NextIteration node feeding two different merge nodes is properly +# Imported. 
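The raw Enter/Merge/Switch/NextIteration/Exit nodes these import tests consume are what a v1-style while loop lowers to in graph mode. The sketch below is illustrative only; it assumes tf.compat.v1 with control flow v2 disabled and is not part of this patch.

# Illustrative only: building a v1 while loop and listing the low-level
# control-flow ops it produces, including the NextIteration backedge.
import tensorflow.compat.v1 as tf

tf.disable_control_flow_v2()  # keep the low-level control-flow ops
g = tf.Graph()
with g.as_default():
    i = tf.constant(0)
    tf.while_loop(lambda i: i < 10, lambda i: i + 1, [i])

print(sorted({op.type for op in g.get_operations()}))
# expected to include 'Enter', 'Merge', 'Switch', 'NextIteration', 'LoopCond', 'Exit'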
+ +# CHECK-LABEL: func @main() +# CHECK: %[[NEXTITERATION:[0-9]+]]:3 = tf_executor.NextIteration.Source +# CHECK: tf_executor.Merge {{.*}}, %[[NEXTITERATION]]#0 +# CHECK: tf_executor.Merge {{.*}}, %[[NEXTITERATION]]#0 + node { name: "Const" op: "Const" @@ -137,14 +145,3 @@ versions { producer: 62 } -# CHECK: func @main() { -# CHECK-NEXT: %0:2 = "_tf.NextIteration.source"() {T = "tfdtype$DT_INT32", device = "", name = "NextIteration"} : () -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Add/y", value = dense<1> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %2:2 = "_tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Const", value = dense<0> : tensor} : () -> (tensor, !_tf.control) -# CHECK-NEXT: %3:2 = "_tf.Enter"(%2#0) {T = "tfdtype$DT_INT32", device = "", frame_name = "while_context", is_constant = false, name = "Enter", parallel_iterations = 10 : i64} : (tensor) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %4:3 = "_tf.Merge"(%3#0, %0#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "Merge"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor, !_tf.control) -# CHECK-NEXT: %5:2 = "_tf.Add"(%4#0, %1#0) {T = "tfdtype$DT_INT32", device = "", name = "Add"} : (tensor<*xi32>, tensor) -> (tensor<*xi32>, !_tf.control) -# CHECK-NEXT: %6 = "_tf.NextIteration.sink"(%5#0) {T = "tfdtype$DT_INT32", device = "", name = "NextIteration"} : (tensor<*xi32>) -> !_tf.control -# CHECK-NEXT: %7:3 = "_tf.Merge"(%3#0, %0#0) {N = 2 : i64, T = "tfdtype$DT_INT32", device = "", name = "Use_NextIteration_Again"} : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor, !_tf.control) -# CHECK-NEXT: return -# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/prune_unused_nodes.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/prune_unused_nodes.pbtxt index a745cf302e9..7715a0eb9df 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/prune_unused_nodes.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/prune_unused_nodes.pbtxt @@ -1,5 +1,10 @@ # RUN: tf-mlir-translate -graphdef-to-mlir %s -tf-prune-unused-nodes -tf-input-arrays=input0,input1,unused_input -tf-input-data-types=DT_INT32,DT_INT32,DT_INT32 -tf-input-shapes=10:10:10 -tf-output-arrays=Add -o - | FileCheck %s +# Verify that an unused Node (here named "Prune") isn't converted when we +# request pruning on import. 
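Pruning on import mirrors what tf.compat.v1.graph_util.extract_sub_graph does on a GraphDef: nodes unreachable from the requested outputs, like "Prune" in this test, are dropped. The sketch below is illustrative only; it builds a tiny GraphDef by hand and reuses the node names from the test.

# Illustrative only: extracting the subgraph needed for "Add" drops the
# unreachable "Prune" node, analogous to -tf-prune-unused-nodes.
import tensorflow.compat.v1 as tf

gd = tf.GraphDef()
for name in ("Prune", "input0", "input1"):
    n = gd.node.add()
    n.name = name
    n.op = "Placeholder"
    n.attr["dtype"].type = tf.int32.as_datatype_enum
add = gd.node.add()
add.name = "Add"
add.op = "Add"
add.input.extend(["input0", "input1"])
add.attr["T"].type = tf.int32.as_datatype_enum

pruned = tf.graph_util.extract_sub_graph(gd, ["Add"])
print([n.name for n in pruned.node])  # "Prune" does not appear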
+# CHECK-LABEL: func @main +# CHECK-NOT: Prune + node { name: "Prune" op: "Const" @@ -66,13 +71,3 @@ node { versions { producer: 27 } - -# CHECK: func @main(%arg0: tensor<10xi32>, %arg1: tensor<10xi32>, %arg2: tensor<10xi32>) -> tensor<10xi32> -# CHECK-NEXT: attributes {tf.entry_function = {inputs = "input0, input1, unused_input", outputs = "Add"}} { -# CHECK-NEXT: %0:2 = "_tf.Placeholder.input"(%arg0) {device = "", dtype = "tfdtype$DT_INT32", name = "input0", shape = "tfshape$dim {\0A size: 10\0A}\0A"} : (tensor<10xi32>) -> (tensor<10xi32>, !_tf.control) -# CHECK-NEXT: %1:2 = "_tf.Placeholder.input"(%arg1) {device = "", dtype = "tfdtype$DT_INT32", name = "input1", shape = "tfshape$dim {\0A size: 10\0A}\0A"} : (tensor<10xi32>) -> (tensor<10xi32>, !_tf.control) -# CHECK-NEXT: %2:2 = "_tf.Add"(%0#0, %1#0) {T = "tfdtype$DT_INT32", device = "", name = "Add"} : (tensor<10xi32>, tensor<10xi32>) -> (tensor<10xi32>, !_tf.control) -# CHECK-NEXT: %3:2 = "_tf.Placeholder.input"(%arg2) {device = "", dtype = "tfdtype$DT_INT32", name = "unused_input", shape = "tfshape$dim {\0A size: 10\0A}\0A"} : (tensor<10xi32>) -> (tensor<10xi32>, !_tf.control) -# CHECK-NEXT: return %2#0 : tensor<10xi32> -# CHECK-NEXT: } - diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt index 096264737da..748bc996f36 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/quint8-const.pbtxt @@ -27,6 +27,6 @@ versions { producer: 70 } -# CHECK: "_tf.Const"() +# CHECK: tf.Const # CHECK-SAME: name = "Quantized_Constant" # CHECK-SAME: value = opaque<"tf", "{{0[xX][0-9a-fA-F]*}}"> : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt index 32007150bcd..54877e873e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/stateful-attribute.pbtxt @@ -82,7 +82,7 @@ versions { # Find PartitionedCall ops in main and match the callee name. # CHECK: func @main -# CHECK: "_tf.PartitionedCall" +# CHECK: "tf.PartitionedCall" # CHECK-SAME: f = @[[FUNCTION_FOO:[a-zA-Z0-9_]*]] # Find callee and verify it has the stateful attribute set. 
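A hypothetical sketch of where a StatefulPartitionedCall (the call op the stateful-attribute test matches) comes from on the Python side: a tf.function callee that touches state is invoked through StatefulPartitionedCall, while a pure callee can use PartitionedCall. This assumes TF 2.x tracing and is not taken from this patch.

# Illustrative only: a stateful callee shows up as a StatefulPartitionedCall
# node in the caller's traced graph.
import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def callee(x):
    v.assign(x)      # stateful body
    return x + 1.0

@tf.function
def caller(x):
    return callee(x)

graph = caller.get_concrete_function(tf.TensorSpec([], tf.float32)).graph
print({op.type for op in graph.get_operations()})
# expected to include 'StatefulPartitionedCall'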
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt index 790fb0c7334..707b04473f3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/string-attr.pbtxt @@ -1,4 +1,9 @@ # RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s + +# CHECK: tf.Const +# CHECK-SAME: _output_shapes = ["tfshape$dim { size: 3 }"] +# CHECK-SAME: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E472074656E736F725F7368617065207B2064696D207B2073697A653A2033207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C30303022"> : tensor<3x!tf.string> + node { name: "save/SaveV2/shape_and_slices" op: "Const" @@ -40,8 +45,3 @@ node { versions { producer: 74 } - -# CHECK: func @main() { -# CHECK-NEXT: %0:2 = "_tf.Const"() {_output_shapes = ["tfshape$dim {\0A size: 3\0A}\0A"], device = "", dtype = "tfdtype$DT_STRING", name = "save/SaveV2/shape_and_slices", value = opaque<"tf", "0x746674656E736F722464747970653A2044545F535452494E470A74656E736F725F7368617065207B0A202064696D207B0A2020202073697A653A20330A20207D0A7D0A737472696E675F76616C3A2022220A737472696E675F76616C3A2022220A737472696E675F76616C3A2022220A"> : tensor<3x!tf.string>} : () -> (tensor<3x!tf.string>, !_tf.control) -# CHECK-NEXT: return -# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt new file mode 100644 index 00000000000..ea3b143d63e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/switch_n.pbtxt @@ -0,0 +1,270 @@ +# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s --dump-input-on-failure + +# CHECK: tf_executor.SwitchN +# CHECK-SAME: of 3 : tensor +# CHECK-SAME: T = "tfdtype$DT_INT32" +# CHECK-SAME: name = "Case/branch_index/_3" + +node { + name: "Case/branch_index" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "Case/input_0" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 1 + } + } + } +} +node { + name: "Case/branch_index/_3" + op: "_SwitchN" + input: "Case/branch_index" + input: "Case/branch_index" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "num_outs" + value { + i: 3 + } + } +} +node { + name: "Case/Case/input_0/_7" + op: "_SwitchN" + input: "Case/input_0" + input: "Case/branch_index" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "Case/input_0" + } + } + } + attr { + key: "num_outs" + value { + i: 3 + } + } +} +node { + name: "Case/pivot_0/_4" + op: "Identity" + input: "Case/branch_index/_3" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "Case/pivot_1/_5" + op: "Identity" + input: "Case/branch_index/_3:1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "Case/pivot_2/_6" + op: "Identity" + input: "Case/branch_index/_3:2" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "Case/branch0/_0/mul/y" + op: "Const" + input: "^Case/pivot_0/_4" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: 
"value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 2 + } + } + } +} +node { + name: "Case/branch1/_1/mul/y" + op: "Const" + input: "^Case/pivot_1/_5" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 3 + } + } + } +} +node { + name: "Case/branch2/_2/mul/y" + op: "Const" + input: "^Case/pivot_2/_6" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 4 + } + } + } +} +node { + name: "Case/branch0/_0/mul_0" + op: "Mul" + input: "Case/Case/input_0/_7" + input: "Case/branch0/_0/mul/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "Case/branch1/_1/mul_0" + op: "Mul" + input: "Case/Case/input_0/_7:1" + input: "Case/branch1/_1/mul/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "Case/branch2/_2/mul_0" + op: "Mul" + input: "Case/Case/input_0/_7:2" + input: "Case/branch2/_2/mul/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "Case/merge/_9" + op: "Merge" + input: "Case/branch0/_0/mul_0" + input: "Case/branch1/_1/mul_0" + input: "Case/branch2/_2/mul_0" + attr { + key: "N" + value { + i: 3 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "__inference_run_240_RetVal" + op: "_Retval" + input: "Case/merge/_9" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "index" + value { + i: 0 + } + } +} +library { +} +versions { + producer: 126 +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt index a8802a99456..cc24caae6e8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/tensor-list.pbtxt @@ -209,10 +209,10 @@ versions { } # Verify that list element shape and dtype are expected. -# CHECK: _tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor) -> (tensor>>, !_tf.control) +# CHECK: tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor) -> tensor>> # Nested variant type. -# CHECK: _tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor) -> (tensor>>, !_tf.control) +# CHECK: tf.TensorListReserve{{.*}}(tensor<2xi32>, tensor) -> tensor>> -# CHECK: _tf.TensorListSetItem{{.*}}(tensor>>, tensor, tensor<2x2xf32>) -> (tensor>>, !_tf.control) -# CHECK: _tf.TensorListStack{{.*}}(tensor>>, tensor) -> (tensor, !_tf.control) +# CHECK: tf.TensorListSetItem{{.*}}(tensor>>, tensor, tensor<2x2xf32>) -> tensor>> +# CHECK: tf.TensorListStack{{.*}}(tensor>>, tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir b/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir index 2259d301dc8..4566ffb507c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/isolate-placer.mlir @@ -13,15 +13,19 @@ func @foo(%arg0: tensor) -> tensor { // The IsolatePlacerInspectionRequiredOpsPass adds Identities for each input/output of function-calling ops. // Capture the result of input to function call. 
-// CHECK: [[VARIABLE_REG:%[0-9]*]]:2 = "_tf.VarHandleOp"() +// CHECK: [[VARIABLE_REG:%[0-9]*]]:2 = tf_executor.island +// CHECK-NEXT: "tf.VarHandleOp"() // Test for the presence of Identity op between input and function call. -// CHECK-NEXT: [[IDENTITY_REG:%[0-9]*]]:2 = "_tf.Identity"([[VARIABLE_REG]]#0) -// CHECK-NEXT: [[CALL_RESULT_REG:%[0-9]*]]:2 = "_tf.StatefulPartitionedCall"([[IDENTITY_REG]]#0) +// CHECK: [[IDENTITY_REG:%[0-9]*]]:2 = tf_executor.island +// CHECK-NEXT: "tf.Identity"([[VARIABLE_REG]]#0) + +// CHECK: [[CALL_RESULT_REG:%[0-9]*]]:2 = tf_executor.island +// CHECK-NEXT: "tf.StatefulPartitionedCall"([[IDENTITY_REG]]#0) // CHECK-SAME: f = @[[FUNCTION:[a-zA-Z0-9_]*]] // Match the inserted Identity op for call output. -// CHECK-NEXT: "_tf.Identity"([[CALL_RESULT_REG]]#0) +// CHECK: "tf.Identity"([[CALL_RESULT_REG]]#0) // Match the function name // CHECK: func @[[FUNCTION]] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir new file mode 100644 index 00000000000..52e4c529815 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/convert_tensor.mlir @@ -0,0 +1,16 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +func @main() -> (tensor<1x2xf16>, tensor<2xf16>) { + %0:2 = "_tf.Const"() {device = "", name = "foo", dtype = "tfdtype$DT_HALF", value = dense<1.0> : tensor<1x2xf16>} : () -> (tensor<1x2xf16>, !_tf.control) + %1:2 = "_tf.Const"() {device = "", name = "bar", dtype = "tfdtype$DT_HALF", value = dense<[1.0, 2.0]> : tensor<2xf16>} : () -> (tensor<2xf16>, !_tf.control) + return %0#0, %1#0 : tensor<1x2xf16>, tensor<2xf16> + +// CHECK: node { +// CHECK-NEXT: name: "foo" +// CHECK-NEXT: op: "Const" +// CHECK: half_val: 15360 +// CHECK: name: "bar" +// CHECK-NEXT: op: "Const" +// CHECK: half_val: 15360 +// CHECK: half_val: 16384 +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir new file mode 100644 index 00000000000..ccd058842a9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir @@ -0,0 +1,34 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +func @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = "tf.Placeholder.input"(%arg0) : (tensor) -> tensor + %1 = "tf.Placeholder.input"(%arg1) : (tensor) -> tensor + %2 = "tf.Less"(%0, %1) : (tensor, tensor) -> tensor + %3 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor, tensor, tensor) -> tensor loc("StatefulIf") + %4 = "tf.If"(%2, %0, %1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = true} : (tensor, tensor, tensor) -> tensor loc("StatelessIf") + return %3, %4 : tensor, tensor +} + +func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.Mul"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// Verify that If op is mapped to TensorFlow StatelessIf op if the is_stateless +// attribute is present and otherwise it is mapped to TensorFlow If op. 
In both +// cases, the additional attribute should be dropped. + +// CHECK: name: "StatefulIf" +// CHECK-NOT: name: +// CHECK: op: "If" +// CHECK-NOT: is_stateless + +// CHECK: name: "StatelessIf" +// CHECK-NOT: name: +// CHECK: op: "StatelessIf" +// CHECK-NOT: is_stateless diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir new file mode 100644 index 00000000000..0009c7a4dc4 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir @@ -0,0 +1,43 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +func @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %iter = "tf.Placeholder.input"(%arg0) : (tensor) -> tensor loc("iter") + %val = "tf.Placeholder.input"(%arg1) : (tensor) -> tensor loc("val") + + // Element wise add `val` with itself for `iter` number of times. + %2:2 = "tf.While"(%iter, %val) { + cond = @cond, body = @body, is_stateless = false + } : (tensor, tensor) -> (tensor, tensor) loc("StatefulWhile") + %3:2 = "tf.While"(%iter, %val) { + cond = @cond, body = @body, is_stateless = true + } : (tensor, tensor) -> (tensor, tensor) loc("StatelessWhile") + + return %2#1, %3#1 : tensor, tensor +} + +func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor { + %0 = "tf.Const" () {value = dense<0> : tensor} : () -> tensor loc("Const") + %1 = "tf.Greater"(%arg0, %0) : (tensor<*xi32>, tensor) -> tensor + return %1 : tensor +} + +func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor<*xf32>) { + %0 = "tf.Const" () {value = dense<1> : tensor} : () -> tensor loc("Const") + %1 = "tf.Sub"(%arg0, %0) : (tensor<*xi32>, tensor) -> tensor<*xi32> + %2 = "tf.Add"(%arg1, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %1, %2 : tensor<*xi32>, tensor<*xf32> +} + +// Verify that While op is mapped to TensorFlow StatelessWhile op if the +// is_stateless attribute is present and otherwise it is mapped to TensorFlow +// While op. In both cases, the additional attribute should be dropped. 
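+// As a rough sketch of the exported GraphDef (attributes abbreviated, only the
+// fields relevant to the CHECK lines below are shown):
+//   node { name: "StatefulWhile"  op: "While"          attr { key: "body" ... } attr { key: "cond" ... } }
+//   node { name: "StatelessWhile" op: "StatelessWhile" attr { key: "body" ... } attr { key: "cond" ... } }
+// Neither node is expected to carry an `is_stateless` attribute.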
+ +// CHECK: name: "StatefulWhile" +// CHECK-NOT: name: +// CHECK: op: "While" +// CHECK-NOT: is_stateless + +// CHECK: name: "StatelessWhile" +// CHECK-NOT: name: +// CHECK: op: "StatelessWhile" +// CHECK-NOT: is_stateless diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/missing-main.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/missing-main.mlir index 041be4b9fe0..f73e93369d5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/missing-main.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/missing-main.mlir @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[1]} -eq 0 +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - 2>&1 | FileCheck %s; test ${PIPESTATUS[0]} -ne 0 // CHECK: Graph export failed: Failed precondition: entry function `main` must be present diff --git a/tensorflow/compiler/mlir/tensorflow/tests/optimize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/optimize.mlir new file mode 100644 index 00000000000..a3c2d18c671 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/optimize.mlir @@ -0,0 +1,35 @@ +// RUN: tf-opt -tf-optimize %s | FileCheck %s + +// CHECK-LABEL: convbiasaddmul +func @convbiasaddmul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> { + %filter = constant dense<2.0> : tensor<3x3x3x16xf32> + %bias = constant dense<3.0> : tensor<16xf32> + %value = constant dense<4.0> : tensor<16xf32> + %0 = "tf.Conv2D"(%arg, %filter) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> + %1 = "tf.BiasAdd"(%0, %bias) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"}: (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + %2 = "tf.Mul"(%1, %value) {T = "tfdtype$DT_FLOAT"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + return %2 : tensor<256x30x30x16xf32> + +// CHECK-NEXT: %[[cst:.*]] = constant dense<8.000000e+00> : tensor<3x3x3x16xf32> +// CHECK-NEXT: %[[cst_0:.*]] = constant dense<1.200000e+01> : tensor<16xf32> +// CHECK-NEXT: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]]) +// CHECK-NEXT: %[[bias:.*]] = "tf.BiasAdd"(%[[conv]], %[[cst_0]]) +// CHECK-NEXT: return %[[bias]] : tensor<256x30x30x16xf32> +} + +// CHECK-LABEL: convaddv2mul +func @convaddv2mul(%arg: tensor<256x32x32x3xf32>) -> tensor<256x30x30x16xf32> { + %filter = constant dense<2.0> : tensor<3x3x3x16xf32> + %bias = constant dense<3.0> : tensor<16xf32> + %value = constant dense<4.0> : tensor<16xf32> + %0 = "tf.Conv2D"(%arg, %filter) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x30x30x16xf32> + %1 = "tf.AddV2"(%0, %bias) {T = "tfdtype$DT_FLOAT"}: (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + %2 = "tf.Mul"(%1, %value) {T = "tfdtype$DT_FLOAT"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> + return %2 : tensor<256x30x30x16xf32> + +// CHECK-NEXT: %[[cst:.*]] = constant dense<8.000000e+00> : tensor<3x3x3x16xf32> +// CHECK-NEXT: %[[cst_0:.*]] = constant dense<1.200000e+01> : tensor<16xf32> +// CHECK-NEXT: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]]) +// CHECK-NEXT: %[[add:.*]] = "tf.AddV2"(%[[conv]], %[[cst_0]]) +// CHECK-NEXT: return %[[add]] : tensor<256x30x30x16xf32> +} \ No newline at end of file diff --git 
a/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-control.mlir b/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-control.mlir new file mode 100644 index 00000000000..271b6ec92f9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-control.mlir @@ -0,0 +1,12 @@ +// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=MlirRoundtripPass | FileCheck %s --dump-input-on-failure + +// The test uses the tf_graph_optimization_pass to run the MlirRoundtripPass. +// We convert mlir -> Graph -> mlir -> Graph -> mlir + +func @main() { + "_tf.NoOp"() {} : () -> () loc("X") + return +} + +// Check for the presence of tf.NoOp in the final output. +// CHECK: tf.NoOp \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-executor.mlir b/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-executor.mlir new file mode 100644 index 00000000000..6b245236d35 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/roundtrip-tf-executor.mlir @@ -0,0 +1,19 @@ +// RUN: tf-opt %s --run-tf-graph-optimization --graph-passes=MlirRoundtripPass | FileCheck %s --dump-input-on-failure + +module { + func @main() { + tf_executor.graph { + %0 = tf_executor.island { + "tf.NoOp"() {} : () -> () loc("X") + tf_executor.yield + } + tf_executor.fetch + } + return + } +} + +// The test uses the tf_graph_optimization_pass to run the MlirRoundtripPass. +// We convert mlir -> Graph -> mlir -> Graph -> mlir +// Check for the presence of tf.NoOp in the final output. +// CHECK: tf.NoOp diff --git a/tensorflow/compiler/mlir/tensorflow/tests/savedmodel2mlir/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/savedmodel2mlir/BUILD new file mode 100644 index 00000000000..ff3b70a22c9 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/savedmodel2mlir/BUILD @@ -0,0 +1,19 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +package(licenses = ["notice"]) + +tf_cc_test( + name = "half_plus_two", + srcs = ["half_plus_two.cc"], + data = [ + "//tensorflow/cc/saved_model:saved_model_half_plus_two", + ], + deps = [ + "//tensorflow/cc/saved_model:tag_constants", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/savedmodel2mlir/half_plus_two.cc b/tensorflow/compiler/mlir/tensorflow/tests/savedmodel2mlir/half_plus_two.cc new file mode 100644 index 00000000000..b18e6c0b188 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/savedmodel2mlir/half_plus_two.cc @@ -0,0 +1,41 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/cc/saved_model/tag_constants.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" + +// TODO(silvasean): Add a FileCheck based testing harness for SavedModel to +// replace the following. The source should be TensorFlow Python code. Then we +// can generate SavedModel directories on the fly and import them. Check +// directives can be embedded into the same file as the source. +TEST(SavedModel, HalfPlusTwo) { + const char kSavedModel[] = "cc/saved_model/testdata/half_plus_two/00000123"; + const std::string saved_model_dir = tensorflow::io::JoinPath( + tensorflow::testing::TensorFlowSrcRoot(), kSavedModel); + std::unordered_set tags{tensorflow::kSavedModelTagServe}; + + mlir::MLIRContext context; + auto module = tensorflow::SavedModelToMlirImport( + saved_model_dir, tags, /*debug_info_file=*/"", &context); + auto* block = module->getBody(); + + // testdata/half_plus_two does not use any functions. So we only have the + // mandatory module terminator op inside its block. + EXPECT_TRUE(std::next(block->begin()) == block->end()); +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 3b21c528c90..dd6d77f7816 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -verify-diagnostics | FileCheck %s +// RUN: tf-opt %s -split-input-file -verify-diagnostics | FileCheck %s --dump-input=fail //===--------------------------------------------------------------------===// // Test TF opaque attributes @@ -65,6 +65,32 @@ func @testTFComplex(tensor<*x!tf.complex64>, tensor<*x!tf.complex128>) -> (!tf.c // ----- +// CHECK-LABEL: func @testIdentity +func @testIdentity(%arg0: tensor<4x2x!tf.stringref>) -> tensor<4x2x!tf.string> { + // CHECK: tf.Identity + %0 = "tf.Identity"(%arg0) : (tensor<4x2x!tf.stringref>) -> tensor<4x2x!tf.string> + return %0 : tensor<4x2x!tf.string> +} + +// ----- + +// CHECK-LABEL: func @testBitcast +func @testBitcast(%arg0: tensor<3x4x!tf.uint16>) -> tensor<3x4x!tf.quint16> { + // CHECK: tf.Bitcast + %0 = "tf.Bitcast"(%arg0) : (tensor<3x4x!tf.uint16>) -> tensor<3x4x!tf.quint16> + return %0 : tensor<3x4x!tf.quint16> +} + +// ----- + +func @testIdentityWrongType(%arg0: tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref> { + // expected-error @+1 {{requires all operands to be either same as or ref type of results}} + %0 = "tf.Identity"(%arg0) : (tensor<4x2x!tf.string>) -> tensor<4x2x!tf.stringref> + return %0 : tensor<4x2x!tf.stringref> +} + +// ----- + // TODO(hinsu): Move this to MLIR core once the test dialect have a custom type. 
// Check that broadcastable trait accepts TF specific element type @@ -133,9 +159,18 @@ func @testLeakyWrongAlphaType(tensor<16xf32>) -> tensor<16xf32> { } // ----- -// CHECK-LABEL: func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>) -func @testReshape(tensor<*xf32>, tensor<*xf32>, tensor<10000xf32>) -> (tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>) { -^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>): + +// CHECK-LABEL: func @testMul +func @testMul(%arg0: tensor<2x!tf.uint16>) -> (tensor<2x!tf.uint16>) { + // CHECK: tf.Mul + %0 = "tf.Mul"(%arg0, %arg0) {T = "tfdtype$DT_UINT16", device = "/device:CPU:0", name = "Mul"} : (tensor<2x!tf.uint16>, tensor<2x!tf.uint16>) -> tensor<2x!tf.uint16> + return %0 : tensor<2x!tf.uint16> +} + +// ----- + +// CHECK-LABEL: func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) +func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) -> (tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32>) { // CHECK: %cst = constant dense<100> : tensor<2xi32> %shape1 = constant dense<100> : tensor<2xi32> // CHECK: %0 = "tf.Reshape"(%arg0, %cst) : (tensor<*xf32>, tensor<2xi32>) -> tensor<100x100xf32> @@ -150,7 +185,11 @@ func @testReshape(tensor<*xf32>, tensor<*xf32>, tensor<10000xf32>) -> (tensor<10 %shape3 = constant dense<[-1, 100]> : tensor<2xi32> // CHECK: %4 = "tf.Reshape"(%arg2, %cst_0) {T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32", device = "", name = "Reshape_1"} : (tensor<10000xf32>, tensor<2xi32>) -> tensor<100x100xf32> %r4 = "tf.Reshape"(%arg2, %shape3) {device = "", name = "Reshape_1", T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<10000xf32>, tensor<2xi32>) -> (tensor<100x100xf32>) - return %r1, %r2, %r3, %r4: tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32> + // CHECK: "tf.Reshape"(%arg0, %arg3) + %r5 = "tf.Reshape"(%arg0, %arg3) {T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<*xf32>, tensor<*xi32>) -> (tensor<*xf32>) + // CHECK: "tf.Reshape"(%arg2, %arg3) + %r6 = "tf.Reshape"(%arg2, %arg3) {T = "tfdtype$DT_FLOAT", Tshape = "tfdtype$DT_INT32"} : (tensor<10000xf32>, tensor<*xi32>) -> (tensor<*xf32>) + return %r1, %r2, %r3, %r4, %r5, %r6: tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32> } // ----- @@ -190,6 +229,14 @@ func @testReshape(%arg0: tensor<10x10x10xf32>) -> tensor<100x100xf32> { return %r1 : tensor<100x100xf32> } +// ----- +// tf.Reshape with a first operand that has non-static shape. 
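+// With a dynamic dimension in the operand, the total element count is unknown
+// at verification time, so the verifier is expected to accept the op instead
+// of comparing element counts against the static result shape.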
+func @testReshape(%arg0: tensor<10x10x?xf32>) -> tensor<10x10xf32> { + %shape1 = constant dense<[10, 10]> : tensor<2xi32> + %r1 = "tf.Reshape" (%arg0, %shape1) : (tensor<10x10x?xf32>, tensor<2xi32>) -> (tensor<10x10xf32>) + return %r1 : tensor<10x10xf32> +} + // ----- // CHECK-LABEL: func @testValidAvgPool @@ -478,7 +525,7 @@ func @testIfElse(tensor<*xf32>) -> tensor<*xf32> func @testValidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<2xf32>): %1 = "tf.If"(%arg0, %arg1) { - then_branch = @testIfThen, else_branch = @testIfElse + then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false } : (tensor, tensor<2xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> @@ -492,10 +539,11 @@ func @testIfElse(f32) -> f32 // Test invalid tf.If operation func @testInvalidIfOp(tensor, f32) -> f32 { ^bb0(%arg0: tensor, %arg1: f32): - // expected-error @+1 {{requires operands to have a valid TensorFlow tensor type}} + // expected-error @+1 {{operand #1 must be tensor of tf.dtype values}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, - else_branch = @testIfElse + else_branch = @testIfElse, + is_stateless = false } : (tensor, f32) -> f32 return %1 : f32 @@ -508,9 +556,9 @@ func @testIfElse(tensor<2xf32>) -> tensor<2xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<2xf32>): - // expected-error @+1 {{requires then_branch attribute}} + // expected-error @+1 {{requires attribute 'then_branch'}} %1 = "tf.If"(%arg0, %arg1) { - else_branch = @testIfElse + else_branch = @testIfElse, is_stateless = false } : (tensor, tensor<2xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> @@ -527,7 +575,8 @@ func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { // expected-error @+1 {{branches should have 1 inputs}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, - else_branch = @testIfElse + else_branch = @testIfElse, + is_stateless = false } : (tensor, tensor<2xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> @@ -544,7 +593,8 @@ func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { // expected-error @+1 {{then branch input type tensor<*xf16> is incompatible with operand type tensor<2xf32>}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, - else_branch = @testIfElse + else_branch = @testIfElse, + is_stateless = false } : (tensor, tensor<2xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> @@ -561,7 +611,8 @@ func @testInvalidIfOp(tensor, tensor<*xf32>) -> tensor<2xf32> { // expected-error @+1 {{branches inputs have incompatible types tensor<2xf32> and tensor<3xf32>}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, - else_branch = @testIfElse + else_branch = @testIfElse, + is_stateless = false } : (tensor, tensor<*xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> @@ -578,7 +629,8 @@ func @testInvalidIfOp(tensor, tensor<*xf32>) -> tensor<2xf32> { // expected-error @+1 {{else branch result type tensor<3xf32> is incompatible with op result type tensor<2xf32>}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, - else_branch = @testIfElse + else_branch = @testIfElse, + is_stateless = false } : (tensor, tensor<*xf32>) -> tensor<2xf32> return %1 : tensor<2xf32> @@ -615,12 +667,31 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { ^bb0(%arg0: tensor<*xf32>): %1 = "tf.While"(%arg0) { cond = @testWhileCond, - body = @testWhileBody + body = @testWhileBody, + is_stateless = false } : (tensor<*xf32>) -> (tensor<*xf32>) return %1 : tensor<*xf32> } 
+// ----- +func @testWhileUndefinedCond(%arg0: tensor, %arg1: tensor) -> tensor { + // expected-error @+1 {{cond refers to an undefined function : undefined_func}} + %0 = "tf.While"(%arg0, %arg1) {cond = @undefined_func, body = @body, is_stateless = false} : (tensor, tensor) -> (tensor) + return %0 : tensor +} + +func @body(%arg0: tensor, %arg1: tensor) -> tensor + +// ----- +func @testWhileUndefinedBody(%arg0: tensor, %arg1: tensor) -> tensor { + // expected-error @+1 {{body refers to an undefined function : undefined_func}} + %0 = "tf.While"(%arg0, %arg1) {cond = @cond, body = @undefined_func, is_stateless = false} : (tensor, tensor) -> (tensor) + return %0 : tensor +} + +func @cond(%arg0: tensor, %arg1: tensor) -> tensor + // ----- func @testWhileCond(tensor<*xf32>) -> () @@ -632,7 +703,8 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { // expected-error @+1 {{requires cond function to have exactly one result}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, - body = @testWhileBody + body = @testWhileBody, + is_stateless = false } : (tensor<*xf32>) -> (tensor<*xf32>) return %1 : tensor<*xf32> @@ -649,7 +721,8 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xi32>) { // expected-error @+1 {{operand type tensor<*xf32> is incompatible with result type}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, - body = @testWhileBody + body = @testWhileBody, + is_stateless = false } : (tensor<*xf32>) -> (tensor<*xi32>) return %1 : tensor<*xi32> @@ -666,7 +739,8 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { // expected-error @+1 {{operand type tensor<*xf32> is incompatible with cond function input type}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, - body = @testWhileBody + body = @testWhileBody, + is_stateless = false } : (tensor<*xf32>) -> (tensor<*xf32>) return %1 : tensor<*xf32> @@ -683,7 +757,8 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { // expected-error @+1 {{requires the number of operands to be equal to the number of body function inputs. 
Found 1 and 2, respectively}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, - body = @testWhileBody + body = @testWhileBody, + is_stateless = false } : (tensor<*xf32>) -> (tensor<*xf32>) return %1 : tensor<*xf32> @@ -700,7 +775,8 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { // expected-error @+1 {{body function result type tensor<*xi32> is incompatible with result type}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, - body = @testWhileBody + body = @testWhileBody, + is_stateless = false } : (tensor<*xf32>) -> (tensor<*xf32>) return %1 : tensor<*xf32> @@ -717,7 +793,8 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { // expected-error @+1 {{cond function input type tensor<3xf32> is incompatible with body function input type}} %1 = "tf.While"(%arg0) { cond = @testWhileCond, - body = @testWhileBody + body = @testWhileBody, + is_stateless = false } : (tensor<*xf32>) -> (tensor<*xf32>) return %1 : tensor<*xf32> @@ -747,7 +824,7 @@ func @testShapeWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf3 func @testShapeWrongResultDim(tensor<1x32x32x16xf32>) -> tensor<*xi32> { ^bb0(%arg0: tensor<1x32x32x16xf32>): - // expected-error @+1 {{requires 1D result type}} + // expected-error @+1 {{requires 1D type for result}} %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<1x32x32x16xf32>) -> tensor<*xi32> return %0 : tensor<*xi32> } @@ -763,15 +840,77 @@ func @testShapeMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> { // ----- -func @testShapeWrongResultDim(tensor<*xf32>) -> tensor<2xi32> { +func @testShapeWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { ^bb0(%arg0: tensor<*xf32>): - // expected-error @+1 {{requires dynamic shape result for unranked input}} + // expected-error @+1 {{requires dynamic shape result for unranked operand}} %0 = "tf.Shape"(%arg0) {T = "tfdtype$DT_FLOAT", output = "tfdtype$DT_INT32"} : (tensor<*xf32>) -> tensor<2xi32> return %0 : tensor<2xi32> } // ----- +// CHECK-LABEL: func @testValidShapeN +func @testValidShapeN(%arg0 : tensor<1x32x32x16xf32>, %arg1 : tensor<*xf32>) -> (tensor<4xi32>, tensor) { + // CHECK-NEXT: "tf.ShapeN" + %0:2 = "tf.ShapeN"(%arg0, %arg1) {N = 2 : i64} : (tensor<1x32x32x16xf32>, tensor<*xf32>) -> (tensor<4xi32>, tensor) + return %0#0, %0#1 : tensor<4xi32>, tensor +} + +// ----- + +func @testShapeNWrongResultElemType(%arg0: tensor<1x32x32x16xf32>) -> tensor<4xf32> { + // expected-error @+1 {{result #1 must be tensor of 32/64-bit integer values}} + %0:2 = "tf.ShapeN"(%arg0, %arg0) {N = 2 : i64} : (tensor<1x32x32x16xf32>, tensor<1x32x32x16xf32>) -> (tensor<4xi32>, tensor<4xf32>) + return %0#1 : tensor<4xf32> +} + +// ----- + +func @testShapeNWrongResultDim(tensor<1x32x32x16xf32>) -> tensor<*xi32> { +^bb0(%arg0: tensor<1x32x32x16xf32>): + // expected-error @+1 {{requires 1D type for result #1}} + %0:2 = "tf.ShapeN"(%arg0, %arg0) {N = 2 : i64} : (tensor<1x32x32x16xf32>, tensor<1x32x32x16xf32>) -> (tensor<4xi32>, tensor<*xi32>) + return %0#1 : tensor<*xi32> +} + +// ----- + +func @testShapeNMismatchDim(tensor<1x32x32x16xf32>) -> tensor<2xi32> { +^bb0(%arg0: tensor<1x32x32x16xf32>): + // expected-error @+1 {{requires dimension size of result #1 to match rank of operand #1}} + %0:2 = "tf.ShapeN"(%arg0, %arg0) {N = 2 : i64} : (tensor<1x32x32x16xf32>, tensor<1x32x32x16xf32>) -> (tensor<4xi32>, tensor<2xi32>) + return %0#1 : tensor<2xi32> +} + +// ----- + +func @testShapeNWrongResultDimDynamic(tensor<*xf32>) -> tensor<2xi32> { +^bb0(%arg0: tensor<*xf32>): + // expected-error 
@+1 {{requires dynamic shape result #1 for unranked operand #1}} + %0:2 = "tf.ShapeN"(%arg0, %arg0) {N = 2 : i64} : (tensor<*xf32>, tensor<*xf32>) -> (tensor, tensor<2xi32>) + return %0#1 : tensor<2xi32> +} + +// ----- + +func @testShapeNWrongNumOperands(tensor<*xf32>) { +^bb0(%arg0: tensor<*xf32>): + // expected-error @+1 {{requires 3 operand(s), got 2 operand(s)}} + %0:3 = "tf.ShapeN"(%arg0, %arg0) {N = 3 : i64} : (tensor<*xf32>, tensor<*xf32>) -> (tensor, tensor, tensor) + return +} + +// ----- + +func @testShapeNWrongNumResults(tensor<*xf32>) { +^bb0(%arg0: tensor<*xf32>): + // expected-error @+1 {{requires 3 result(s), got 2 result(s)}} + %0:2 = "tf.ShapeN"(%arg0, %arg0, %arg0) {N = 3 : i64} : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> (tensor, tensor) + return +} + +// ----- + // Test invalid tf.Const func @testConst() -> tensor { // expected-error @+1 {{attribute 'value' failed to satisfy constraint: constant vector/tensor}} @@ -837,3 +976,4 @@ func @testConcatV2(%arg: tensor<8x16xf32>, %axis: tensor<1xi32>) -> tensor, tensor<1xi32>) -> tensor return %0 : tensor } + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir index 510aaccb26a..2890656c013 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops.mlir @@ -68,6 +68,30 @@ func @simpleIsland_with_attributes(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } +// CHECK-LABEL: func @simpleIsland_with_multiple_control_inputs(%arg0: tensor<*xf32>) +func @simpleIsland_with_multiple_control_inputs(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = tf_executor.graph { + %1 = tf_executor.island { + tf_executor.yield + } + %2 = tf_executor.island { + tf_executor.yield + } + %3:2 = tf_executor.island(%1, %2) { + tf_executor.yield %arg0 : tensor<*xf32> + } + tf_executor.fetch %3#0 : tensor<*xf32> + } +// CHECK: %[[ISLAND0:[0-9]*]] = tf_executor.island { +// CHECK-NEXT: tf_executor.yield +// CHECK: %[[ISLAND1:[0-9]*]] = tf_executor.island { +// CHECK-NEXT: tf_executor.yield +// CHECK: %[[ISLAND2:[0-9]*]]:2 = tf_executor.island(%[[ISLAND0]], %[[ISLAND1]]) { +// CHECK: tf_executor.fetch %[[ISLAND2]]#0 : tensor<*xf32> + + return %0 : tensor<*xf32> +} + // CHECK-LABEL: func @fetchWithControlDep(%arg0: tensor<*xf32>) func @fetchWithControlDep(%arg0: tensor<*xf32>) -> tensor<*xf32> { %result = tf_executor.graph { @@ -153,8 +177,8 @@ func @switch_with_attributes(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor< return %result : tensor<*xf32> } -// CHECK-LABEL: func @switchN(%arg0: i32, %arg1: tensor<*xf32>) -> tensor<*xf32> { -func @switchN(%arg0: i32, %arg1: tensor<*xf32>) -> tensor<*xf32> { +// CHECK-LABEL: func @switchN( +func @switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %fetches = tf_executor.graph { // CHECK: %1:6 = tf_executor.SwitchN %arg1, %arg0 of 5 : tensor<*xf32> @@ -210,6 +234,43 @@ func @switch_merge_with_attributes(%arg0: tensor<*xf32>, %arg1: tensor) -> t return %result : tensor<*xf32> } +// Verify that long form printing is used when operand types do not match the +// result type and then it can be parsed again correctly. 
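+// For reference, the short form prints as, e.g.:
+//   %value, %idx, %ctl = tf_executor.Merge %a, %b : tensor<4xf32>
+// and is only used when there are at least two data operands that all share
+// the single result type; the long (functional) form below spells out every
+// operand and result type instead.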
+// CHECK-LABEL: func @merge_different_operand_types +func @merge_different_operand_types(%arg0: tensor<*xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %result = tf_executor.graph { + +// CHECK: tf_executor.Merge{{.*}}(tensor<*xf32>, tensor<4xf32>) -> (tensor<4xf32>, tensor, !tf_executor.control) + %value, %idx, %ctlMerge = tf_executor.Merge %arg0, %arg1 : (tensor<*xf32>, tensor<4xf32>) -> (tensor<4xf32>, tensor, !tf_executor.control) + tf_executor.fetch %value : tensor<4xf32> + } + return %result : tensor<4xf32> +} + +// Verify that long form printing is used when there is only one data operand +// and then it can be parsed again correctly. +// CHECK-LABEL: func @merge_one_data_operand +func @merge_one_data_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %result = tf_executor.graph { + +// CHECK: tf_executor.Merge{{.*}}(tensor<*xf32>) -> (tensor<*xf32>, tensor, !tf_executor.control) + %value, %idx, %ctlMerge = tf_executor.Merge %arg0 : (tensor<*xf32>) -> (tensor<*xf32>, tensor, !tf_executor.control) + tf_executor.fetch %value : tensor<*xf32> + } + return %result : tensor<*xf32> +} + +// CHECK-LABEL: func @merge_with_variant_type +func @merge_with_variant_type(%arg0: tensor, %arg1: tensor>>) -> tensor>> { + %result = tf_executor.graph { + +// CHECK: tf_executor.Merge{{.*}}(tensor, tensor>>) -> (tensor>>, tensor, !tf_executor.control) + %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor, tensor>>) -> (tensor>>, tensor, !tf_executor.control) + tf_executor.fetch %value : tensor>> + } + return %result : tensor>> +} + // CHECK-LABEL: func @enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { func @enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { %result = tf_executor.graph { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir index 366cd825f65..ee3d2b91732 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir @@ -27,7 +27,7 @@ func @empty_graph() { // Check that an empty graph is invalid (it needs a region). func @empty_graph() { "tf_executor.graph" () ({ -// expected-error@-1 {{'tf_executor.graph' op expects a non-empty body}} +// expected-error@-1 {{'tf_executor.graph' op expects a non-empty block}} ^entry: }) : () -> () return @@ -47,6 +47,17 @@ func @graph_with_invalid_op(%arg0: tensor<*xf32>) -> tensor<*xf32> { // ----- +// Check that tf_executor.graph can't be nested directly in a tf_executor.graph. +func @nested_graph() { + tf_executor.graph { + tf_executor.graph {} +// expected-error@-1 {{'tf_executor.graph' op unallowed directly inside another tf_executor.graph}} + } + return +} + +// ----- + // Check that a tf_executor.fetch is terminating a tf_executor.graph (custom parser) func @graph_with_invalid_terminator(%arg0: tensor<*xf32>) -> tensor<*xf32> { tf_executor.graph { @@ -58,11 +69,23 @@ func @graph_with_invalid_terminator(%arg0: tensor<*xf32>) -> tensor<*xf32> { // ----- +// Check that a tf_executor.fetch parent is a graph. 
+func @parent_is_graph() { + "some.op"() ({ + tf_executor.fetch +// expected-error@-1 {{'tf_executor.fetch' op expects parent op 'tf_executor.graph'}} + }) : () -> () + return +} + +// ----- + // Check that a tf_executor.fetch is terminating a tf_executor.graph (verifier) func @graph_with_invalid_terminator(%arg0: tensor<*xf32>) -> tensor<*xf32> { +// expected-error@+2 {{'tf_executor.graph' op expects regions to end with 'tf_executor.fetch', found 'tf_executor.yield'}} +// expected-note@+1 {{in custom textual format, the absence of terminator implies 'tf_executor.fetch'}} "tf_executor.graph" () ({ tf_executor.yield -// expected-error@-1 {{'tf_executor.yield' op invalid tf_executor.graph terminator, fetch expected}} }) : () -> () return %arg0 : tensor<*xf32> } @@ -149,6 +172,17 @@ func @invalid_fetch(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) -> tensor< // ----- +// Check that a tf_executor.island parent is a graph. +func @parent_is_graph() { + "some.op"() ({ + %ctl = tf_executor.island {} +// expected-error@-1 {{'tf_executor.island' op expects parent op 'tf_executor.graph'}} + }) : () -> () + return +} + +// ----- + // Check that an island can't have other operands than controls. func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) { tf_executor.graph { @@ -189,7 +223,7 @@ func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) { func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) { tf_executor.graph { "tf_executor.island"() ({ -// expected-error@-1 {{'tf_executor.island' op expects a non-empty body}} +// expected-error@-1 {{'tf_executor.island' op expects a non-empty block}} ^entry: }) : () -> (!tf_executor.control) } @@ -202,8 +236,9 @@ func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) { func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) { tf_executor.graph { "tf_executor.island"() ({ +// expected-error@-1 {{'tf_executor.island' op expects regions to end with 'tf_executor.yield', found 'std.return'}} +// expected-note@-2 {{in custom textual format, the absence of terminator implies 'tf_executor.yield'}} return -// expected-error@-1 {{'std.return' op invalid tf_executor.island terminator, yield expected}} }) : () -> (!tf_executor.control) } return @@ -211,6 +246,17 @@ func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) { // ----- +// Check that a tf_executor.yield parent is a tf_executor.island. +func @parent_is_island() { + "some.op"() ({ + tf_executor.yield +// expected-error@-1 {{'tf_executor.yield' op expects parent op 'tf_executor.island'}} + }) : () -> () + return +} + +// ----- + // Check that an island yield matches the island results. func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) { tf_executor.graph { @@ -276,6 +322,17 @@ func @invalid_yield(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) { // ----- +// Check that a tf_executor.Switch parent is a graph. +func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor) { + "some.op"() ({ + %true, %false, %ctlSwitch = tf_executor.Switch %arg0, %arg1 : tensor<*xf32> +// expected-error@-1 {{'tf_executor.Switch' op expects parent op 'tf_executor.graph'}} + }) : () -> () + return +} + +// ----- + // Check that a switch always takes two arguments. func @invalid_switch(%arg0: tensor<*xf32>) { tf_executor.graph { @@ -335,11 +392,22 @@ func @invalid_switch(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { // ----- +// Check that a tf_executor.SwitchN parent is a graph. 
+func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: tensor) { + "some.op"() ({ + %1:6 = tf_executor.SwitchN %arg0, %arg1 of 5 : tensor<*xf32> +// expected-error@-1 {{'tf_executor.SwitchN' op expects parent op 'tf_executor.graph'}} + }) : () -> () + return +} + +// ----- + // Check that switchN result numbers matches the num_out attribute. -func @invalid_switchN(%arg0: i32, %arg1: tensor<*xf32>) -> tensor<*xf32> { +func @invalid_switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %fetches = tf_executor.graph { - %1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 5} : (tensor<*xf32>, i32) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) + %1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 5} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor<*xf32>, !tf_executor.control) // expected-error@-1 {{'tf_executor.SwitchN' op expect `num_outs` (5) results but got 2}} tf_executor.fetch %1#0 : tensor<*xf32> @@ -350,10 +418,10 @@ func @invalid_switchN(%arg0: i32, %arg1: tensor<*xf32>) -> tensor<*xf32> { // ----- // Check that switchN result type matches the input type. -func @invalid_switchN(%arg0: i32, %arg1: tensor<*xf32>) -> tensor<*xf32> { +func @invalid_switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %fetches = tf_executor.graph { - %1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 2} : (tensor<*xf32>, i32) -> (tensor<*xf32>, i32, !tf_executor.control) + %1:3 = "tf_executor.SwitchN"(%arg1, %arg0) {num_outs = 2} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, i32, !tf_executor.control) // expected-error@-1 {{'tf_executor.SwitchN' op type mismatch between data operand and result: 'tensor<*xf32>' vs 'i32'}} tf_executor.fetch %1#0 : tensor<*xf32> @@ -364,7 +432,7 @@ func @invalid_switchN(%arg0: i32, %arg1: tensor<*xf32>) -> tensor<*xf32> { // ----- // Check that switchN custom type has a single entry. -func @invalid_switchN(%arg0: i32, %arg1: tensor<*xf32>) -> tensor<*xf32> { +func @invalid_switchN(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %fetches = tf_executor.graph { %1:3 = tf_executor.SwitchN %arg1, %arg0 of 2 : tensor<*xf32>, i32 @@ -377,6 +445,17 @@ func @invalid_switchN(%arg0: i32, %arg1: tensor<*xf32>) -> tensor<*xf32> { // ----- +// Check that a tf_executor.Merge parent is a graph. +func @parent_is_graph(%arg0: tensor<*xf32>) { + "some.op"() ({ + %value, %idx, %ctlMerge = tf_executor.Merge %arg0, %arg0 : tensor<*xf32> +// expected-error@-1 {{'tf_executor.Merge' op expects parent op 'tf_executor.graph'}} + }) : () -> () + return +} + +// ----- + // Check that merge has at least one operand. func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { %result = tf_executor.graph { @@ -431,6 +510,18 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor<4xf32>) -> tensor<8xf32> // ----- +// Check that merge data inputs of variant type are broadcastable to the output +func @invalid_merge(%arg0: tensor<*x!tf.variant>, %arg1: tensor<4x!tf.variant>) -> tensor<8x!tf.variant> { + %result = tf_executor.graph { + %value, %idx, %ctlMerge = "tf_executor.Merge"(%arg0, %arg1) : (tensor<*x!tf.variant>, tensor<4x!tf.variant>) -> (tensor<8x!tf.variant>, tensor, !tf_executor.control) +// expected-error@-1 {{'tf_executor.Merge' op expects all operands to be broadcastable but got 'tensor<8x!tf.variant>' vs 'tensor<4x!tf.variant>'}} + tf_executor.fetch %value : tensor<8x!tf.variant> + } + return %result : tensor<8x!tf.variant> +} + +// ----- + // Check that merge data inputs can't appear after control input. 
func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { %result = tf_executor.graph { @@ -446,6 +537,17 @@ func @invalid_merge(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { // ----- +// Check that a tf_executor.Enter parent is a graph. +func @parent_is_graph(%arg0: tensor<*xf32>) { + "some.op"() ({ + %res:2 = tf_executor.Enter %arg0 frame "some/fra\"me" : tensor<*xf32> +// expected-error@-1 {{'tf_executor.Enter' op expects parent op 'tf_executor.graph'}} + }) : () -> () + return +} + +// ----- + // Check that Enter return value is the same type as the input. func @invalid_enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { %result = tf_executor.graph { @@ -458,6 +560,28 @@ func @invalid_enter(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { // ----- +// Check that a tf_executor.NextIteration.Sink parent is a graph. +func @parent_is_graph(%arg0: tensor<*xf32>, %arg1: !tf_executor.token) { + "some.op"() ({ + tf_executor.NextIteration.Sink[%arg1] %arg0 : tensor<*xf32> +// expected-error@-1 {{'tf_executor.NextIteration.Sink' op expects parent op 'tf_executor.graph'}} + }) : () -> () + return +} + +// ----- + +// Check that a tf_executor.NextIteration.Source parent is a graph. +func @parent_is_graph() { + "some.op"() ({ + %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> +// expected-error@-1 {{'tf_executor.NextIteration.Source' op expects parent op 'tf_executor.graph'}} + }) : () -> () + return +} + +// ----- + func @invalid_nextiteration(%arg0: tensor<*xf32>, %arg1: !tf_executor.token) -> tensor<*xf32> { %0 = tf_executor.graph { %1:3 = tf_executor.NextIteration.Source : tensor<*xf32> @@ -521,6 +645,17 @@ func @invalid_nextiteration(%arg0: tensor<*xf32>, %arg1: i1) -> tensor<*xf32> { // ----- +// Check that a tf_executor.Exit parent is a graph. +func @parent_is_graph(%arg0: tensor<*xf32>) { + "some.op"() ({ + %1:2 = tf_executor.Exit %arg0 : tensor<*xf32> +// expected-error@-1 {{'tf_executor.Exit' op expects parent op 'tf_executor.graph'}} + }) : () -> () + return +} + +// ----- + func @exit(%arg0: tensor<*xi32>) -> tensor<*xf32> { %0 = tf_executor.graph { %1:2 = "tf_executor.Exit"(%arg0) : (tensor<*xi32>) -> (tensor<*xf32>, !tf_executor.control) @@ -529,3 +664,25 @@ func @exit(%arg0: tensor<*xi32>) -> tensor<*xf32> { } return %0 : tensor<*xf32> } + +// ----- + +// Check that a tf_executor.ControlTrigger parent is a graph. +func @parent_is_graph(%arg0: !tf_executor.control, %arg1: !tf_executor.control) { + "some.op"() ({ + %0 = tf_executor.ControlTrigger %arg0, %arg1 +// expected-error@-1 {{'tf_executor.ControlTrigger' op expects parent op 'tf_executor.graph'}} + }) : () -> () + return +} + +// ----- + +// Check that a tf_executor.LoopCond parent is a graph. +func @parent_is_graph(%arg0: tensor, %arg1: !tf_executor.control) { + "some.op"() ({ + %1:2 = tf_executor.LoopCond %arg0, %arg1 : tensor +// expected-error@-1 {{'tf_executor.LoopCond' op expects parent op 'tf_executor.graph'}} + }) : () -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir new file mode 100644 index 00000000000..dc2f60b6441 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -0,0 +1,406 @@ +// RUN: tf-opt %s -split-input-file -tf-tpu-rewrite | FileCheck %s + +// Tests simple case of `tf_device.launch_func` on TPU with single input and +// single output. 
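+// In rough terms the rewrite turns each `tf_device.launch_func` that carries a
+// `_tpu_replicate` attribute into a compile/execute pair: the operand shapes
+// feed compilation, and the compiled program plus the original operands feed
+// execution. Sketched (value names illustrative):
+//   %shape  = "tf.Shape"(%input)
+//   %prog:2 = "tf.MLIRCompileToTPU"(%shape)    with the serialized @tpu0_func module attached
+//   %result = "tf.TPUExecute"(%input, %prog#1)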
+ +module { + // CHECK-LABEL: func @single_tpu_launch_func + func @single_tpu_launch_func(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor + // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: _tpu_replicate = "cluster0" + // CHECK-SAME: module + // CHECK-SAME: func @main + // CHECK-SAME: tf.B + // CHECK-NOT: func = @tpu0_func + // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1) + // CHECK-SAME: Targs = [tensor] + // CHECK-SAME: Tresults = [tensor] + + %2 = "tf.C"(%1) : (tensor) -> tensor + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]]) + + return %2 : tensor + // CHECK: return %[[C_OUTPUT]] + } + + func @tpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} + +// ----- + +// Tests that launch_func without _tpu_replicate attribute is ignored. + +module { + // CHECK-LABEL: func @single_gpu_launch_func + func @single_gpu_launch_func(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + %1 = "tf_device.launch_func"(%0) {device = "gpu0", func = @gpu0_func} : (tensor) -> tensor + // CHECK: tf_device.launch_func + // CHECK-SAME: {device = "gpu0", func = @gpu0_func} + + return %1 : tensor + } + + func @gpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} + +// ----- + +// Tests of `tf_device.launch_func` on TPU with nested function calls. + +module { + // CHECK-LABEL: func @with_nested_func + func @with_nested_func(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor + // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: _tpu_replicate = "cluster0" + // CHECK-SAME: module + // CHECK-SAME: func @main + // CHECK-SAME: tf.B + // CHECK-SAME: func @nested_func + // CHECK-SAME: tf.D + // CHECK-NOT: func = @tpu0_func + // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1) + // CHECK-SAME: Targs = [tensor] + // CHECK-SAME: Tresults = [tensor] + + %2 = "tf.C"(%1) : (tensor) -> tensor + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]]) + + return %2 : tensor + // CHECK: return %[[C_OUTPUT]] + } + + func @tpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) : (tensor) -> tensor + %1 = call @nested_func(%0) : (tensor) -> tensor + return %1 : tensor + } + + func @nested_func(%arg0: tensor) -> tensor { + %0 = "tf.D"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} + +// ----- + +// Tests of `tf_device.launch_func` on TPU with referenced function that's not +// via a standard call op. 
+ +module { + // CHECK-LABEL: func @with_referenced_func + func @with_referenced_func(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor + // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: _tpu_replicate = "cluster0" + // CHECK-SAME: module + // CHECK-SAME: func @main + // CHECK-SAME: tf.B + // CHECK-SAME: func @referenced_func + // CHECK-SAME: tf.D + // CHECK-NOT: func = @tpu0_func + // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1) + // CHECK-SAME: Targs = [tensor] + // CHECK-SAME: Tresults = [tensor] + + %2 = "tf.C"(%1) : (tensor) -> tensor + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]]) + + return %2 : tensor + // CHECK: return %[[C_OUTPUT]] + } + + func @tpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) {body = @referenced_func} : (tensor) -> tensor + return %0 : tensor + } + + func @referenced_func(%arg0: tensor) -> tensor { + %0 = "tf.D"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} + +// ----- + +// Tests rewriting `tf_device.launch_func` on TPU with a chain of referenced +// functions. + +module { + // CHECK-LABEL: func @with_referenced_func_chain + func @with_referenced_func_chain(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor + // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: _tpu_replicate = "cluster0" + // CHECK-SAME: module + // CHECK-SAME: func @main + // CHECK-SAME: tf.B + // CHECK-SAME: @referenced_func1 + // CHECK-SAME: tf.D + // CHECK-SAME: @referenced_func2 + // CHECK-SAME: tf.E + // CHECK-NOT: func = @tpu0_func + // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1) + // CHECK-SAME: Targs = [tensor] + // CHECK-SAME: Tresults = [tensor] + + %2 = "tf.C"(%1) : (tensor) -> tensor + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]]) + + return %2 : tensor + // CHECK: return %[[C_OUTPUT]] + } + + func @tpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) {body = @referenced_func1} : (tensor) -> tensor + return %0 : tensor + } + + func @referenced_func1(%arg0: tensor) -> tensor { + %0 = "tf.D"(%arg0) : (tensor) -> tensor + %1 = call @referenced_func2(%0) : (tensor) -> tensor + return %1 : tensor + } + + func @referenced_func2(%arg0: tensor) -> tensor { + %0 = "tf.E"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} + +// ----- + +// Tests rewriting `tf_device.launch_func` on TPU with multiple calls to same +// function. 
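+// The interesting property in this case is deduplication: both call sites are
+// expected to survive in the serialized module (CHECK-COUNT-2) while the body
+// of @referenced_func is expected to be copied in exactly once (CHECK-COUNT-1).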
+ +module { + // CHECK-LABEL: func @with_multiple_call_same_referenced_func + func @with_multiple_call_same_referenced_func(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor + // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: _tpu_replicate = "cluster0" + // CHECK-SAME: module + // CHECK-SAME: func @main + // CHECK-SAME: tf.B + // CHECK-COUNT-2: call @referenced_func + // CHECK-COUNT-1: func @referenced_func + // CHECK-SAME: tf.D + // CHECK-NOT: func = @tpu0_func + // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1) + // CHECK-SAME: Targs = [tensor] + // CHECK-SAME: Tresults = [tensor] + + %2 = "tf.C"(%1) : (tensor) -> tensor + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]]) + + return %2 : tensor + // CHECK: return %[[C_OUTPUT]] + } + + func @tpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) {body = @referenced_func1} : (tensor) -> tensor + %1 = call @referenced_func(%0) : (tensor) -> tensor + %2 = call @referenced_func(%1) : (tensor) -> tensor + return %2 : tensor + } + + func @referenced_func(%arg0: tensor) -> tensor { + %1 = "tf.D"(%arg0) : (tensor) -> tensor + return %1 : tensor + } +} + +// ----- + +// Tests multiple `tf_device.launch_func` on TPU with different computation. + +module { + // CHECK-LABEL: func @multiple_launch_different_func + func @multiple_launch_different_func(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func0} : (tensor) -> tensor + // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) + // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: _tpu_replicate = "cluster0" + // CHECK-SAME: module + // CHECK-SAME: func @main + // CHECK-SAME: tf.B + // CHECK-NOT: func = @tpu0_func0 + // CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1) + // CHECK-SAME: Targs = [tensor] + // CHECK-SAME: Tresults = [tensor] + + %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "tpu0", func = @tpu0_func1} : (tensor) -> tensor + // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) + // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[EXECUTE0_SHAPE_OUTPUT]]) + // CHECK-SAME: _tpu_replicate = "cluster1" + // CHECK-SAME: module + // CHECK-SAME: func @main + // CHECK-SAME: tf.D + // CHECK-NOT: func = @tpu0_func1 + // CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1) + // CHECK-SAME: Targs = [tensor] + // CHECK-SAME: Tresults = [tensor] + + %3 = "tf.C"(%2) : (tensor) -> tensor + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE1_OUTPUT]]) + + return %3 : tensor + // CHECK: return %[[C_OUTPUT]] + } + + func @tpu0_func0(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) : (tensor) -> tensor + return %0 : tensor + } + + func @tpu0_func1(%arg0: tensor) -> tensor { + %0 = "tf.D"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} + +// ----- + +// Tests multiple `tf_device.launch_func` on TPU with same computation. 
+ +module { + // CHECK-LABEL: func @multiple_launch_same_func + func @multiple_launch_same_func(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor + // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) + // CHECK: %[[COMPILE0_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: _tpu_replicate = "cluster0" + // CHECK-SAME: module + // CHECK-SAME: func @main + // CHECK-SAME: tf.B + // CHECK-NOT: func = @tpu0_func + // CHECK: %[[EXECUTE0_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE0_OUTPUT]]#1) + // CHECK-SAME: Targs = [tensor] + // CHECK-SAME: Tresults = [tensor] + + %2 = "tf_device.launch_func"(%1) {_tpu_replicate = "cluster1", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor + // CHECK: %[[EXECUTE0_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[EXECUTE0_OUTPUT]]) + // CHECK: %[[COMPILE1_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[EXECUTE0_SHAPE_OUTPUT]]) + // CHECK-SAME: _tpu_replicate = "cluster1" + // CHECK-SAME: module + // CHECK-SAME: func @main + // CHECK-SAME: tf.B + // CHECK-NOT: func = @tpu0_func + // CHECK: %[[EXECUTE1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[EXECUTE0_OUTPUT]], %[[COMPILE1_OUTPUT]]#1) + // CHECK-SAME: Targs = [tensor] + // CHECK-SAME: Tresults = [tensor] + + %3 = "tf.C"(%2) : (tensor) -> tensor + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE1_OUTPUT]]) + + return %3 : tensor + // CHECK: return %[[C_OUTPUT]] + } + + func @tpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) : (tensor) -> tensor + return %0 : tensor + } +} + +// ----- + +// Tests Functions referenced by TPU function via SymbolRefAttr nested in +// ArrayAttr and DictionaryAttr. 
+ +module { + // CHECK-LABEL: func @single_tpu_launch_func + func @single_tpu_launch_func(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + + %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster0", device = "tpu0", func = @tpu0_func} : (tensor) -> tensor + // CHECK: %[[A_SHAPE_OUTPUT:[0-9]*]] = "tf.Shape"(%[[A_OUTPUT]]) + // CHECK: %[[COMPILE_OUTPUT:[0-9]*]]:2 = "tf.MLIRCompileToTPU"(%[[A_SHAPE_OUTPUT]]) + // CHECK-SAME: _tpu_replicate = "cluster0" + // CHECK-SAME: module + // CHECK-SAME: func @main + // CHECK-SAME: tf.B + // CHECK-SAME: func @referenced_func2 + // CHECK-SAME: tf.H + // CHECK-SAME: func @referenced_func3 + // CHECK-SAME: tf.I + // CHECK-SAME: func @referenced_func0 + // CHECK-SAME: tf.F + // CHECK-SAME: func @referenced_func1 + // CHECK-SAME: tf.G + // CHECK: %[[EXECUTE_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[A_OUTPUT]], %[[COMPILE_OUTPUT]]#1) + // CHECK-SAME: Targs = [tensor] + // CHECK-SAME: Tresults = [tensor] + + %2 = "tf.C"(%1) : (tensor) -> tensor + // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[EXECUTE_OUTPUT]]) + + return %2 : tensor + // CHECK: return %[[C_OUTPUT]] + } + + func @tpu0_func(%arg0: tensor) -> tensor { + %0 = "tf.B"(%arg0) : (tensor) -> tensor + %1 = "tf.D"(%0) {array_attr_funcs = [@referenced_func0, @referenced_func1]} : (tensor) -> tensor + %2 = "tf.E"(%1) {dictionary_attr_funcs = {fn1 = @referenced_func2, fn2 = @referenced_func3}} : (tensor) -> tensor + return %0 : tensor + } + + func @referenced_func0(%arg0: tensor) -> tensor { + %1 = "tf.F"(%arg0) : (tensor) -> tensor + return %1 : tensor + } + + func @referenced_func1(%arg0: tensor) -> tensor { + %1 = "tf.G"(%arg0) : (tensor) -> tensor + return %1 : tensor + } + + func @referenced_func2(%arg0: tensor) -> tensor { + %1 = "tf.H"(%arg0) : (tensor) -> tensor + return %1 : tensor + } + + func @referenced_func3(%arg0: tensor) -> tensor { + %1 = "tf.I"(%arg0) : (tensor) -> tensor + return %1 : tensor + } +} + + diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index 473f69f87e7..0653c1d109e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -20,9 +20,7 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" /// TODO(b/130756570): Support OpBase constraints in PatternRewrites. def SingleResultAndOperandHaveSameElementType : Constraint< - CPred<"$0->getResult(0)->getType().cast()" - ".getElementType() == " - "$1->getType().cast().getElementType()">>; + CPred<"getElementTypeOrSelf($0) == getElementTypeOrSelf($1)">>; //===----------------------------------------------------------------------===// // Add op patterns. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc new file mode 100644 index 00000000000..2511ff2fdb3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc @@ -0,0 +1,232 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This transformation forms clusters from instructions that are in the same
+// island and assigned to the same device. Clusters are represented as regions.
+// Note that side-effecting ops are not correctly handled yet.
+
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/Attributes.h"  // TF:local_config_mlir
+#include "mlir/IR/Block.h"  // TF:local_config_mlir
+#include "mlir/IR/BlockAndValueMapping.h"  // TF:local_config_mlir
+#include "mlir/IR/Builders.h"  // TF:local_config_mlir
+#include "mlir/IR/Operation.h"  // TF:local_config_mlir
+#include "mlir/Pass/Pass.h"  // TF:local_config_mlir
+#include "mlir/Pass/PassRegistry.h"  // TF:local_config_mlir
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace mlir {
+namespace TFDevice {
+
+namespace {
+
+struct ClusterFormationPass : public FunctionPass {
+  void runOnFunction() override;
+};
+
+// Cluster structure captures all the operations that are assigned to the same
+// device and can form a legal strict cluster.
+// Ops must follow the same ordering in their parent block. We rely on this
+// assumption to perform analysis.
+struct Cluster {
+  llvm::SmallVector ops;
+  StringRef device;
+};
+
+StringRef GetDevice(Operation* op) {
+  auto device_attr = op->getAttrOfType("device");
+  return device_attr ? device_attr.getValue() : "";
+}
+
+// An op can be merged into a cluster if each of its operands is one of the
+// following:
+//  1) A block argument
+//  2) A value produced by other islands
+//  3) A value defined before the cluster
+//  4) A value defined by an operation in the cluster
+// TODO(ycao): This is not optimal as it doesn't consider the situation where
+// defining_op's operands all meet the requirements above. In that case,
+// defining_op could be moved and the to_merge op would be legal to absorb.
+// TODO(ycao): Take op side effects into consideration since they cannot be
+// re-ordered, but forming clusters of non-contiguous ops is effectively
+// re-ordering them.
+bool CanMergeIntoCluster(const Cluster& c, Operation* to_merge) {
+  return llvm::all_of(to_merge->getOperands(), [&](Value* operand) {
+    // Block arguments.
+    if (isa(operand)) return true;
+
+    Operation* defining_op = operand->getDefiningOp();
+
+    // Operand produced by other islands.
+    if (defining_op->getBlock() != c.ops.front()->getBlock()) return true;
+
+    // Defining op is before the cluster.
+    if (defining_op->isBeforeInBlock(c.ops.front())) return true;
+
+    // Defining op is between the first and last operation in the cluster. Note
+    // that the cluster may contain operations that are non-contiguous in their
+    // original block, so we also need to check that defining_op is assigned to
+    // the cluster's device. This is a faster check than linearly searching
+    // through all ops in the cluster.
+ if (defining_op->isBeforeInBlock(c.ops.back()->getNextNode()) && + GetDevice(defining_op) == c.device) + return true; + + // Other cases, operand is generated after or outside the cluster, this + // means it is illegal to merge operation. + return false; + }); +} + +void ReplaceLiveOutExternalUses(llvm::ArrayRef live_outs, + tf_device::LaunchOp launch_op) { + Region* launch_op_region = &launch_op.body(); + for (const auto& p : llvm::zip(live_outs, launch_op.getResults())) { + Value* from = std::get<0>(p); + for (auto& use : from->getUses()) { + if (launch_op_region->isAncestor(use.getOwner()->getParentRegion())) + continue; + use.set(std::get<1>(p)); + } + } +} + +// Get all escaped live-out values of a region. +void GetLiveOuts(Region* region, llvm::SmallVectorImpl* live_outs) { + live_outs->clear(); + + for (Operation& op : region->front()) { + for (Value* v : op.getResults()) { + // A value is live-out if any of its users are not inside value producer's + // region. + bool is_live_out = llvm::any_of(v->getUsers(), [&](Operation* user) { + return !region->isAncestor(user->getParentRegion()); + }); + + if (is_live_out) live_outs->emplace_back(v); + } + } +} + +// Build a `tf_device.launch` op with a region that contains all the operations +// in given cluster. Then all ops in cluster are replaced by `tf_device.launch`. +void BuildLaunchForCluster(const Cluster& c, OpBuilder* builder) { + // Set insertion point to right after all operations in cluster. + builder->setInsertionPoint(c.ops.back()->getNextNode()); + + // Create a stand-alone region to hold all instructions in the cluster. + Region region; + region.push_back(new Block); + + // Move all operations in cluster to newly created region, stripping their + // "device" attribute since launch op already carries device information. + Block* block = ®ion.front(); + for (Operation* op : c.ops) { + op->moveBefore(block, block->end()); + op->removeAttr(builder->getIdentifier("device")); + } + + // Get all escaped live-out values of region, they are used later to determine + // return values and types of launch op. + llvm::SmallVector live_outs; + GetLiveOuts(®ion, &live_outs); + + // Build a `tf_device.return` op at end of region, with all live-out values + // as operand. + OpBuilder return_builder(builder->getContext()); + return_builder.setInsertionPointToEnd(block); + return_builder.create(return_builder.getUnknownLoc(), + live_outs); + + llvm::SmallVector live_out_types; + live_out_types.reserve(live_outs.size()); + for (Value* v : live_outs) { + live_out_types.emplace_back(v->getType()); + } + + tf_device::LaunchOp launch_op = builder->create( + builder->getUnknownLoc(), builder->getStringAttr(c.device), + live_out_types); + + // Attach the region to launch_op. + launch_op.body().takeBody(region); + + // Replace any external uses of live-out values with return values of launch + // op. So live-out values no longer escape the region. + ReplaceLiveOutExternalUses(live_outs, launch_op); +} + +void ClusterFormationPass::runOnFunction() { + OpBuilder builder(getFunction().getContext()); + getFunction().walk([&](tf_executor::IslandOp island) { + // Iteratively find clusters of different devices within an island. + // Whenever we see an operation that is assigned to an accelerator device + // (ie. device != ""), we try to merge it into the last cluster of same + // device. If that is infeasible (say because of violating def-before-use), + // create a new cluster with that operation and move on. 
+ llvm::MapVector nearest_clusters; + for (Operation& op : llvm::make_early_inc_range(island.GetBody())) { + auto device = GetDevice(&op); + if (device == "") continue; + + // If no cluster of same device has been formed yet, create a new cluster + // with op alone. + auto it = nearest_clusters.find(device); + if (it == nearest_clusters.end()) { + nearest_clusters[device] = Cluster{{&op}, device}; + continue; + } + + // Check if it is legal to merge op into nearest cluster of same device. + // If positive, update cluster and move on to next operation. + Cluster& nearest_cluster = it->second; + if (CanMergeIntoCluster(nearest_cluster, &op)) { + nearest_cluster.ops.emplace_back(&op); + continue; + } + + // If nearest cluster of same device can not absorb `op`, then that + // cluster needs to be finalized by building a `tf_device.launch` op with + // a region that contains all operations in clusters. + BuildLaunchForCluster(nearest_cluster, &builder); + + // Create a new cluster to hold op alone and update nearest_clusters. + nearest_clusters[device] = Cluster{{&op}, device}; + } + + // At the end, there might be left-over found clusters that need to be + // built. + for (auto& device_cluster : nearest_clusters) + BuildLaunchForCluster(device_cluster.second, &builder); + }); +} + +} // namespace + +std::unique_ptr CreateClusterFormationPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-device-cluster-formation", + "Form clusters from instructions assigned to same device"); + +} // namespace TFDevice +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc new file mode 100644 index 00000000000..414b4a0d161 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -0,0 +1,140 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This pass outlines regions of `tf_device.launch` into functions and replaces +// `tf_device.launch` with equivalent `tf_device.launch_func` operations. 
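The cluster-formation logic that closes above (ClusterFormationPass::runOnFunction) walks each island once and keeps, per device, only the most recently started, or "nearest", cluster, flushing it into a tf_device.launch whenever the next op on that device cannot legally join it. The stand-alone sketch below models only that greedy strategy; FakeOp, Cluster, and FormClusters are hypothetical stand-ins for the MLIR types, and the can_merge callback abstracts the CanMergeIntoCluster check.

#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an MLIR operation: only its device string matters
// here ("" means the op is not assigned to an accelerator device).
struct FakeOp {
  std::string device;
};

using Cluster = std::vector<const FakeOp*>;
using CanMergeFn = bool (*)(const Cluster&, const FakeOp&);

// Greedy single pass in program order: each device keeps one open cluster; an
// op either joins it, or that cluster is finalized and a fresh one starts.
std::vector<Cluster> FormClusters(const std::vector<FakeOp>& block,
                                  CanMergeFn can_merge) {
  std::vector<Cluster> finalized;
  std::map<std::string, Cluster> nearest_clusters;  // device -> open cluster
  for (const FakeOp& op : block) {
    if (op.device.empty()) continue;
    auto it = nearest_clusters.find(op.device);
    if (it == nearest_clusters.end()) {
      nearest_clusters[op.device] = Cluster{&op};
      continue;
    }
    if (can_merge(it->second, op)) {
      it->second.push_back(&op);
      continue;
    }
    finalized.push_back(std::move(it->second));  // would become a tf_device.launch
    it->second = Cluster{&op};
  }
  for (auto& entry : nearest_clusters)
    finalized.push_back(std::move(entry.second));  // flush leftover open clusters
  return finalized;
}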
+ +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Block.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +namespace mlir { +namespace TFDevice { + +namespace { + +struct ClusterOutliningPass : public ModulePass { + void runOnModule() override; +}; + +void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op, + OpBuilder* builder) { + llvm::SmallVector operands(launch_return_op.getOperands()); + builder->create(launch_return_op.getLoc(), operands); + launch_return_op.erase(); +} + +// Builds a function that outlines region attached to launch_op and inserts +// built function into given module. +FuncOp BuildFunction(StringRef device, llvm::ArrayRef live_ins, + tf_device::LaunchOp launch_op, + ModuleManager* module_manager, OpBuilder* builder) { + llvm::SmallVector operand_types; + operand_types.reserve(live_ins.size()); + for (Value* v : live_ins) operand_types.emplace_back(v->getType()); + + llvm::SmallVector result_types(launch_op.getResultTypes()); + + auto func_type = + FunctionType::get(operand_types, result_types, builder->getContext()); + + std::string func_name_prefix = Twine(device, "_func").str(); + FuncOp outlined_func = + FuncOp::create(launch_op.getLoc(), func_name_prefix, func_type); + + // Create function body. + Block* outlined_func_block = outlined_func.addEntryBlock(); + + // Replace uses of live-in values within launch_op region with function + // arguments. + Region& launch_op_region = launch_op.body(); + for (const auto& p : + llvm::zip(live_ins, outlined_func_block->getArguments())) { + replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p), + launch_op_region); + } + + // Move all instructions in launch_op into outlined_function's only block. + auto& launch_op_body = launch_op_region.front().getOperations(); + outlined_func_block->getOperations().splice( + outlined_func_block->end(), launch_op_body, launch_op_body.begin(), + launch_op_body.end()); + + // Replace `tf_device.launch_return` terminator with `std.return` in function + // body. + auto launch_return_op = + cast(outlined_func_block->getTerminator()); + builder->setInsertionPoint(launch_return_op); + ReplaceLaunchReturnWithReturn(launch_return_op, builder); + + module_manager->insert(outlined_func); + return outlined_func; +} + +// Outlines body of `tf_device.launch` into a function and create a +// `tf_device.launch_func` to invoke that function. 
`tf_device.launch` is +// removed afterwards.` +void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager, + OpBuilder* builder) { + llvm::SetVector live_ins; + getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins); + + StringRef device = launch_op.getAttrOfType("device").getValue(); + + FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(), + launch_op, module_manager, builder); + builder->setInsertionPoint(launch_op); + tf_device::LaunchFuncOp launch_func_op = + builder->create( + launch_op.getLoc(), outlined_func.getType().getResults(), + builder->getStringAttr(device), + builder->getSymbolRefAttr(outlined_func.getName()), + live_ins.getArrayRef()); + + launch_op.replaceAllUsesWith(launch_func_op); + launch_op.erase(); +} + +void ClusterOutliningPass::runOnModule() { + ModuleOp m = getModule(); + ModuleManager module_manager(m); + OpBuilder builder(m.getContext()); + m.walk([&](tf_device::LaunchOp launch) { + OutlineLaunch(launch, &module_manager, &builder); + }); +} + +} // namespace + +std::unique_ptr CreateClusterOutliningPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-device-cluster-outlining", + "Outline regions of tf_device.launch operations."); + +} // namespace TFDevice +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc index 6ce5233cb1e..3e6e2a6058e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.cc @@ -64,7 +64,9 @@ struct DecodeConstant : public FunctionPass { } // namespace -FunctionPassBase *CreateDecodeConstantPass() { return new DecodeConstant(); } +std::unique_ptr CreateDecodeConstantPass() { + return std::make_unique(); +} static PassRegistration pass( "tf-decode-constant", "Decode opaque constant into human-readable ones"); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h index a0cd77b393f..2e66de0c4d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h @@ -23,7 +23,7 @@ namespace TF { // Creates a pass to decode and reset opaque values in constant ops into // readable values. // Note that this pass assumes RaiseTFControlFlow pass has already been run. -FunctionPassBase *CreateDecodeConstantPass(); +std::unique_ptr CreateDecodeConstantPass(); } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc new file mode 100644 index 00000000000..496e99e4ff7 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc @@ -0,0 +1,329 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
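The cluster-outlining pass above turns each tf_device.launch body into a standalone FuncOp whose arguments are the region's live-ins, that is, the values used inside the region but defined outside it (collected by getUsedValuesDefinedAbove in OutlineLaunch). A minimal model of that live-in computation, with Stmt and LiveIns as hypothetical stand-ins for MLIR operations and values:

#include <set>
#include <string>
#include <vector>

// Hypothetical statement: defines one value and uses some others, all by name.
struct Stmt {
  std::string result;
  std::vector<std::string> operands;
};

// Values used in the region but not defined by it; in the real pass these
// become both the outlined function's arguments and the operands of the
// replacement tf_device.launch_func.
std::vector<std::string> LiveIns(const std::vector<Stmt>& region) {
  std::set<std::string> defined;
  for (const Stmt& s : region) defined.insert(s.result);

  std::set<std::string> seen;
  std::vector<std::string> live_ins;  // kept in first-use order
  for (const Stmt& s : region)
    for (const std::string& use : s.operands)
      if (!defined.count(use) && seen.insert(use).second) live_ins.push_back(use);
  return live_ins;
}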
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation pass takes TFExecutor dialect IslandOps and merges them. +// Note, this currently does not handle TensorFlow V1 style control flow/frames +// or side effecting ops yet. + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Block.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/core/platform/logging.h" + +namespace mlir { +namespace tf_executor { + +namespace { + +// IslandType is an enum representing if an island is the island (parent) +// merging another island or is the island (child) being being merged. +enum IslandType { kParentIsland, kChildIsland }; + +// Output is a helper struct holding a result index and island type (parent or +// child). +struct Output { + Output(IslandType island_type, int result_index) + : island_type(island_type), result_index(result_index) {} + + IslandType island_type; + int result_index; +}; + +struct ExecutorIslandCoarsening + : public FunctionPass { + void runOnFunction() override; + + private: + void MergeIslands(IslandOp parent, IslandOp child, + IslandType insert_position); + bool MergeIslandWithOperand(IslandOp child); + bool MergeIslandWithResult(IslandOp parent); +}; + +// Finds the operation leading to an island that the island can be merged with. +// This looks for the operation, either control input or data input to an op, +// that is closest to the island in the graph. If no candidate can be found or +// the op found is not an island, an empty optional is returned. +llvm::Optional GetOperandCandidateToMergeWith(IslandOp island) { + Operation* graph_op = island.getParentOp(); + Operation* candidate = nullptr; + + // Check island control operands. + for (Value* input : island.controlInputs()) { + Operation* def = input->getDefiningOp(); + DCHECK_EQ(def->getParentOp(), graph_op); + if (!candidate || candidate->isBeforeInBlock(def)) candidate = def; + } + + // Check island data operands. + island.walk([graph_op, &candidate](Operation* op) { + for (Value* input : op->getOperands()) { + Operation* def = input->getDefiningOp(); + if (!def || def->getParentOp() != graph_op) continue; + if (!candidate || candidate->isBeforeInBlock(def)) candidate = def; + } + }); + + if (!candidate || !llvm::isa(candidate)) return llvm::None; + + return llvm::Optional(llvm::cast(candidate)); +} + +// Finds the operation leading from an island that the island can be merged +// with. This looks for the operation, either control output or data output to +// an op, that is closest to the island in the graph. If no candidate can be +// found or the op found is not an island, an empty optional is returned. 
+llvm::Optional GetResultCandidateToMergeWith(IslandOp island) { + Operation* graph_op = island.getParentOp(); + Operation* candidate = nullptr; + + // Check island control results. + for (Operation* user : island.control()->getUsers()) { + DCHECK_EQ(user->getParentOp(), graph_op); + if (!candidate || user->isBeforeInBlock(candidate)) candidate = user; + } + + // Check island data results. + Block& graph_body = llvm::cast(graph_op).GetBody(); + for (Value* result : island.outputs()) { + for (Operation* user : result->getUsers()) { + Operation* def = graph_body.findAncestorInstInBlock(*user); + DCHECK_NE(def, nullptr); + if (!candidate || def->isBeforeInBlock(candidate)) candidate = def; + } + } + + if (!candidate || !llvm::isa(candidate)) return llvm::None; + + return llvm::Optional(llvm::cast(candidate)); +} + +// Collects the operands for the new island by collecting all control inputs of +// the islands being merged. +llvm::SmallSetVector GetNewIslandOperands(IslandOp parent, + IslandOp child) { + llvm::SmallSetVector operands; + operands.insert(parent.getOperands().begin(), parent.getOperands().end()); + operands.insert(child.getOperands().begin(), child.getOperands().end()); + operands.remove(parent.control()); + return operands; +} + +// Collects the results for the new island by going through each data output of +// the islands being merged. Unused results outside of the merged island to be +// formed are pruned. If the child island inner ops consume the parent island +// control output, the child island inner ops will have that respective control +// input pruned. Results of the parent island that are consumed by the child +// island are replaced by the respective inner ops output from the parent +// island. +llvm::SmallVector GetNewIslandResultsAndForwardOutputs( + mlir::MLIRContext* context, IslandOp parent, IslandOp child, + llvm::SmallVector* result_types) { + llvm::SmallVector results; + + YieldOp yield_op = parent.GetYield(); + Block& child_body = child.GetBody(); + for (auto& ret_and_idx : llvm::enumerate(parent.outputs())) { + bool output_captured = false; + Value* yield_input = yield_op.getOperand(ret_and_idx.index()); + for (auto& use : + llvm::make_early_inc_range(ret_and_idx.value()->getUses())) { + if (child_body.findAncestorInstInBlock(*use.getOwner())) { + // Forward output from inner op. + use.set(yield_input); + } else if (!output_captured) { + results.push_back( + Output(IslandType::kParentIsland, ret_and_idx.index())); + result_types->push_back(ret_and_idx.value()->getType()); + output_captured = true; + } + } + } + + for (auto& ret_and_idx : llvm::enumerate(child.outputs())) { + if (!ret_and_idx.value()->use_empty()) { + results.push_back(Output(IslandType::kChildIsland, ret_and_idx.index())); + result_types->push_back(ret_and_idx.value()->getType()); + } + } + + // IslandOps always have a control output. + result_types->push_back(ControlType::get(context)); + + return results; +} + +// Creates the new merged island. +IslandOp CreateNewIsland(Operation* old_island, + llvm::ArrayRef result_types, + llvm::ArrayRef operands) { + OpBuilder builder(old_island); + auto new_island = builder.create( + old_island->getLoc(), result_types, operands, ArrayRef{}); + new_island.body().push_back(new Block); + return new_island; +} + +// Creates respective YieldOp for the new merged island. 
+YieldOp CreateNewIslandYieldOp(IslandOp new_island, + llvm::ArrayRef results, IslandOp parent, + IslandOp child) { + llvm::SmallVector yield_operands; + yield_operands.reserve(results.size()); + for (auto ret_vals : llvm::zip(results, new_island.outputs())) { + // Get consumed output (island type and result index). + const auto& output = std::get<0>(ret_vals); + IslandOp& output_island = + output.island_type == IslandType::kParentIsland ? parent : child; + Value* result = output_island.getResult(output.result_index); + // Replace original result with new island result. + result->replaceAllUsesWith(std::get<1>(ret_vals)); + // Find YieldOp in original island, grab the associated operand (inner op + // output) and add it as a operand to the YieldOp of the merged island. + yield_operands.push_back( + output_island.GetYield().getOperand(output.result_index)); + } + + // Create YieldOp for the new island. + OpBuilder builder(&new_island.GetBody(), new_island.GetBody().end()); + return builder.create(new_island.getLoc(), yield_operands); +} + +// Moves inner ops (excluding last op/YieldOp) from islands being merged into +// the new merged island. +void MoveInnerOpsToNewIsland(IslandOp parent, IslandOp child, + Operation* new_yield_op) { + Block* block = new_yield_op->getBlock(); + + auto move_inner_ops = [block, new_yield_op](IslandOp island) { + auto& island_body = island.GetBody().getOperations(); + block->getOperations().splice(new_yield_op->getIterator(), island_body, + island_body.begin(), + std::prev(island_body.end())); + }; + + move_inner_ops(parent); + move_inner_ops(child); +} + +// Merges two islands and places new merged island before parent or child. +void ExecutorIslandCoarsening::MergeIslands(IslandOp parent, IslandOp child, + IslandType insert_position) { + // Collect operands for the new merged island. + llvm::SmallSetVector operands = + GetNewIslandOperands(parent, child); + + // Collect results and result types for the new merged island. + llvm::SmallVector result_types; + llvm::SmallVector results = GetNewIslandResultsAndForwardOutputs( + &getContext(), parent, child, &result_types); + + // Create the new merged island. + IslandOp new_island = CreateNewIsland( + insert_position == IslandType::kParentIsland ? parent : child, + result_types, operands.getArrayRef()); + + // Create associated YieldOp for the new merged island. + YieldOp new_yield_op = + CreateNewIslandYieldOp(new_island, results, parent, child); + + // Move inner ops from original islands into the new island. + MoveInnerOpsToNewIsland(parent, child, new_yield_op.getOperation()); + + // Update control inputs to point to the new merged island. + child.control()->replaceAllUsesWith(new_island.control()); + parent.control()->replaceAllUsesWith(new_island.control()); + + // Remove merged islands. + child.erase(); + parent.erase(); +} + +// Merges island with the operand closest to the island in the graph. The +// operand must be another IslandOp for merging to take place. A new island is +// created and the islands being merged are removed if a merge took place. +// Returns true if the island was merged with its operand. +bool ExecutorIslandCoarsening::MergeIslandWithOperand(IslandOp child) { + // Find candidate operand to merge island with. 
+ llvm::Optional candidate = GetOperandCandidateToMergeWith(child); + if (!candidate.hasValue()) return false; + auto& parent = candidate.getValue(); + MergeIslands(parent, child, IslandType::kParentIsland); + return true; +} + +// Merges island with the result closest to the island in the graph. The result +// must be another IslandOp for merging to take place. A new island is created +// and the islands being merged are removed if a merge took place. Returns true +// if the island was merged with its result. +bool ExecutorIslandCoarsening::MergeIslandWithResult(IslandOp parent) { + // Find candidate result to merge island with. + llvm::Optional candidate = GetResultCandidateToMergeWith(parent); + if (!candidate.hasValue()) return false; + auto& child = candidate.getValue(); + MergeIslands(parent, child, IslandType::kChildIsland); + return false; +} + +void ExecutorIslandCoarsening::runOnFunction() { + getFunction().walk([this](GraphOp graph) { + Block& graph_body = graph.GetBody(); + + bool updated = false; + do { + updated = false; + + auto reversed = llvm::reverse(graph_body); + for (Operation& operation : llvm::make_early_inc_range(reversed)) { + auto island = llvm::dyn_cast(operation); + if (!island) continue; + updated |= MergeIslandWithResult(island); + } + + for (Operation& operation : llvm::make_early_inc_range(graph_body)) { + auto island = llvm::dyn_cast(operation); + if (!island) continue; + updated |= MergeIslandWithOperand(island); + } + } while (updated); + }); +} + +} // namespace + +std::unique_ptr CreateTFExecutorIslandCoarseningPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-executor-island-coarsening", "Merges TFExecutor dialect IslandOps"); + +} // namespace tf_executor +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index af3e1e05ade..ade8cc17032 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -16,13 +16,13 @@ limitations under the License. // This transformation pass transforms functional control flow operations in the // standard TensorFlow dialect to MLIR Control Flow Graph (CFG) form. +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir -#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -150,12 +150,12 @@ static LogicalResult LowerIfOp(IfOp op) { OpBuilder builder(op_inst); // Lower the condition to a boolean value (i1). 
- Value* cond_i1 = LowerCondition(loc, op.getCondition(), &builder); + Value* cond_i1 = LowerCondition(loc, op.cond(), &builder); if (!cond_i1) return failure(); auto module = op_inst->getParentOfType(); - auto then_fn = module.lookupSymbol(op.getThen()); - auto else_fn = module.lookupSymbol(op.getElse()); + auto then_fn = module.lookupSymbol(op.then_branch()); + auto else_fn = module.lookupSymbol(op.else_branch()); // Split the basic block before the 'if'. The new dest will be our merge // point. @@ -211,8 +211,8 @@ static LogicalResult LowerWhileOp(WhileOp op) { OpBuilder builder(op_inst); auto module = op_inst->getParentOfType(); - auto cond_fn = module.lookupSymbol(op.getCond()); - auto body_fn = module.lookupSymbol(op.getBody()); + auto cond_fn = module.lookupSymbol(op.cond()); + auto body_fn = module.lookupSymbol(op.body()); // Split the block containing the While op into two blocks. One containing // operations before the While op and other containing the rest. Create two @@ -331,8 +331,8 @@ void FunctionalControlFlowToCFG::runOnFunction() { } // namespace -FunctionPassBase* CreateTFFunctionalControlFlowToCFG() { - return new FunctionalControlFlowToCFG(); +std::unique_ptr CreateTFFunctionalControlFlowToCFG() { + return std::make_unique(); } static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc new file mode 100644 index 00000000000..5d3c612e5cd --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc @@ -0,0 +1,87 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "mlir/IR/Block.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +namespace mlir { +namespace tf_executor { + +// Prunes a TF graph eliminating dead nodes. +void prune_graph(GraphOp graph) { + // A graph has a single block which forms a DAG: nodes that aren't reachable + // from the `fetch` operands can be eliminated. + + // Delete unreachable node from the graph. We traverse it in reverse order so + // that we just have to check that a node does not have any users to delete + // it. + for (Operation &op : llvm::make_early_inc_range( + llvm::drop_begin(llvm::reverse(graph.GetBody()), 1))) { + // NextIteration.Sink operation are handled specially: they are live if the + // source is live, and removed when the source is processed. 
+ if (auto sinkOp = dyn_cast(op)) continue; + + // For NextIteration.Source, we just check that the source does not have any + // other user than the sink. + if (auto sourceOp = dyn_cast(op)) { + Operation *sink = sourceOp.GetSink().getOperation(); + if (llvm::any_of(sourceOp.getResults(), [sink](Value *result) { + return llvm::any_of(result->getUsers(), [sink](Operation *user) { + return user != sink; + }); + })) + continue; + + // No other users than the sink, erase the pair! + sink->erase(); + sourceOp.erase(); + continue; + } + + // General case. + if (op.use_empty()) op.erase(); + } +} + +namespace { + +// This transformation pass prunes a TF graph eliminating dead-nodes. +struct GraphPruning : public FunctionPass { + void runOnFunction() override { + getFunction().walk( + [](tf_executor::GraphOp graph) { prune_graph(graph); }); + } +}; + +} // namespace + +FunctionPassBase *CreateTFExecutorGraphPruningPass() { + return new GraphPruning(); +} + +static PassRegistration pass( + "tf-executor-graph-pruning", "Prune a TensorFlow Graph from dead nodes."); + +} // namespace tf_executor +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index 72775d078f9..5e0e961cc46 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -35,13 +35,15 @@ struct TFOptimizePass : public FunctionPass { OwningRewritePatternList patterns; auto func = getFunction(); populateWithGenerated(&getContext(), &patterns); - applyPatternsGreedily(func, std::move(patterns)); + applyPatternsGreedily(func, patterns); } }; } // namespace -FunctionPassBase* CreateTFOptimizePass() { return new TFOptimizePass(); } +std::unique_ptr CreateTFOptimizePass() { + return std::make_unique(); +} static PassRegistration pass("tf-optimize", "Optimizes TF."); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td index 7dcf7c3819f..49793f43cf3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
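The reverse walk in graph_pruning.cc above is what makes a single pass sufficient: users always appear after their producers in the block, so by the time a producer is visited, every dead consumer has already been erased and has released its use of that producer. The toy sketch below illustrates that invariant with explicit use counts; Node, PruneDead, and the index-based encoding are illustrative only, and the real pass additionally special-cases NextIteration source/sink pairs and never erases the fetch.

#include <vector>

// Toy node: producers are listed by index (producers precede consumers), and
// use_count tracks how many not-yet-erased nodes still consume its results.
struct Node {
  std::vector<int> producers;
  int use_count = 0;
  bool erased = false;
};

// Reverse walk: a node with no remaining users is dead; erasing it releases
// its producers' use counts before those producers are themselves visited.
void PruneDead(std::vector<Node>& graph, int fetch_index) {
  for (int i = static_cast<int>(graph.size()) - 1; i >= 0; --i) {
    if (i == fetch_index) continue;
    if (graph[i].use_count != 0) continue;
    graph[i].erased = true;
    for (int p : graph[i].producers) --graph[p].use_count;
  }
}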
==============================================================================*/ include "mlir/IR/OpBase.td" -include "mlir/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" def IsDataFormatNHWC : ConstantAttr; @@ -21,6 +21,7 @@ def BroadcastableElements : Constraint>; def F32ElementsAttr : ElementsAttrBase< CPred<"$_self.cast().getType().getElementType().isF32()">, "float constant tensor">; +def DefinedByConv2D : Constraint($0->getDefiningOp())">>; // If we see a Conv2D op followed by Mul, then multiply the filter // with the value in Mul. @@ -41,3 +42,40 @@ def FuseMulAndConv2D : Pat<(TF_MulOp (TF_Conv2DOp $input, $padding, $explicit_padding, $data_format, $dilations), [(BroadcastableElements $filter, $value)]>; + +// This rule does the following pattern match and rewrite: +// +// input bias input value bias value +// | / => \ / \ / +// BiasAdd value Mul Mul +// \ / \ / +// Mul BiasAdd +// This is to enable the FuseMulAndConv2D pattern. +def PassthroughMulAndBiasAdd : + Pat<(TF_MulOp + (TF_BiasAddOp $input, + (ConstantOp F32ElementsAttr:$bias), IsDataFormatNHWC:$same_format), + (ConstantOp F32ElementsAttr:$value)), + (TF_BiasAddOp + (TF_MulOp $input, (ConstantOp $value)), + (TF_MulOp (ConstantOp $bias), (ConstantOp $value)), + $same_format), + [(DefinedByConv2D $input)]>; + + +// This rule does the following pattern match and rewrite: +// +// input bias input value bias value +// | / => \ / \ / +// AddV2 value Mul Mul +// \ / \ / +// Mul AddV2 +// This is to enable the FuseMulAndConv2D pattern. +def PassthroughMulAndAddV2 : + Pat<(TF_MulOp + (TF_AddV2Op $input, (ConstantOp F32ElementsAttr:$bias)), + (ConstantOp F32ElementsAttr:$value)), + (TF_AddV2Op + (TF_MulOp $input, (ConstantOp $value)), + (TF_MulOp (ConstantOp $bias), (ConstantOp $value))), + [(DefinedByConv2D $input)]>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 1202d4d432c..e66fd89eb8b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -22,19 +22,49 @@ namespace mlir { namespace TF { // Transforms functional control flow operations in the standard TensorFlow // dialect to MLIR Control Flow Graph (CFG) form. -FunctionPassBase *CreateTFFunctionalControlFlowToCFG(); +std::unique_ptr CreateTFFunctionalControlFlowToCFG(); // Optimizes Tensorflow graph. -FunctionPassBase *CreateTFOptimizePass(); +std::unique_ptr CreateTFOptimizePass(); } // namespace TF namespace TFControlFlow { // Raises from the "TensorFlow Control Flow" dialect to the standard TensorFlow // dialect. -FunctionPassBase *CreateRaiseTFControlFlowPass(); +std::unique_ptr CreateRaiseTFControlFlowPass(); } // namespace TFControlFlow + +namespace tf_executor { +class GraphOp; + +// Create a pass to merge IslandOps from TFExecutor dialect. +std::unique_ptr CreateTFExecutorIslandCoarseningPass(); + +// Create a pass to prune tf_executor.graph from dead nodes. +FunctionPassBase* CreateTFExecutorGraphPruningPass(); + +// Prune a tf_executor.graph operation from dead nodes. +void prune_graph(GraphOp graph); + +} // namespace tf_executor + +namespace TFDevice { +// Creates a pass that forms clusters from instructions that are assigned to +// same device. +std::unique_ptr CreateClusterFormationPass(); + +// Creates a pass that outlines regions of tf_device.launch operations. 
+std::unique_ptr CreateClusterOutliningPass(); +} // namespace TFDevice + +namespace TFTPU { +// Creates a pass that rewrites `tf_device.launch_func` on TPUs into TPU runtime +// ops +std::unique_ptr CreateTPURewritePass(); +} // namespace TFTPU + } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc b/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc index 3e058127fe2..69bfd75e1e6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/raise_control_flow.cc @@ -145,8 +145,8 @@ void RaiseTFControlFlow::rewriteOps() { } // namespace -FunctionPassBase *CreateRaiseTFControlFlowPass() { - return new RaiseTFControlFlow(); +std::unique_ptr CreateRaiseTFControlFlowPass() { + return std::make_unique(); } static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 60f7ed35a0b..c5f21fa3029 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -19,7 +19,7 @@ limitations under the License. #include "mlir/IR/Location.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc new file mode 100644 index 00000000000..84d2690f787 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -0,0 +1,275 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
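The PassthroughMulAndBiasAdd and PassthroughMulAndAddV2 patterns added to optimize.td above are sound because multiplying by a constant distributes over the addition performed by BiasAdd/AddV2: (x + b) * c == x * c + b * c, which is what lets the Mul sink below the add and later fuse into the convolution. A scalar spot-check of that identity (the real patterns operate on f32 constant tensors, where rounding can make the two sides differ slightly in general):

#include <cassert>

int main() {
  const float x = 2.5f, b = 0.75f, c = 4.0f;
  const float before = (x + b) * c;   // Mul(BiasAdd(input, bias), value)
  const float after = x * c + b * c;  // BiasAdd(Mul(input, value), Mul(bias, value))
  assert(before == after);            // exact for these values: 13.0f on both sides
  return 0;
}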
+==============================================================================*/ + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +namespace mlir { +namespace TFTPU { + +// Rewrites `tf_device.launch_func` operations assigned to TPU into actual TPU +// jit-compile runtime ops. +// +// For example: +// %1 = "tf_device.launch_func"(%0) {_tpu_replicate = "cluster", func = +// @tpu_func} +// %2 = "tf.SomeOp"(%1) +// +// Would become following ops (unimportant attributes, types are omitted): +// %1 = "tf.Shape"(%0) +// %2:2 = "tf.MLIRCompileToTPU"(%1) {module = ""} +// "tf.TPUCompileSucceededAssert"(%2#0) +// %3 = "tf.TPUExecute"(%0, %2#1) +// %4 = "tf.SomeOp"(%3) + +namespace { +struct TPURewritePass : public ModulePass { + void runOnModule() override; +}; + +// Recursively visits all attributes of `op` to find any Attribute of type +// `SymbolRefAttr`. +llvm::SmallVector GetAllSymbolRefAttrs(Operation* op) { + llvm::SmallVector symbol_ref_attrs; + + llvm::SmallVector worklist; + for (auto named_attr : op->getAttrs()) { + worklist.push_back(named_attr.second); + } + + while (!worklist.empty()) { + Attribute attr = worklist.pop_back_val(); + + if (SymbolRefAttr symbol_ref_attr = attr.dyn_cast()) { + // Found a SymbolRefAttr, add it to result list. + symbol_ref_attrs.push_back(symbol_ref_attr); + } else if (ArrayAttr array_attr = attr.dyn_cast()) { + // Found an ArrayAttr, add its nested Attributes to worklist for further + // inspection. + worklist.append(array_attr.begin(), array_attr.end()); + } else if (DictionaryAttr dict_attr = attr.dyn_cast()) { + // Found a DictionaryAttr, add its nested value Attributes to worklist for + // further inspection. + for (NamedAttribute named_attr : dict_attr.getValue()) { + worklist.push_back(named_attr.second); + } + } + } + + return symbol_ref_attrs; +} + +// Creates a new self-contained module that contains `entry_func` and all +// referenced functions in `entry_func`. entry_func is renamed to "main". +// Return value is serialized text formate of newly-created module. +std::string EncapsulateFuncAndSerialize(FuncOp entry_func) { + ModuleOp module = entry_func.getParentOfType(); + llvm::SmallVector referenced({entry_func}); + + // Create a new module to hold func and all referenced functions. + OwningModuleRef module_for_func = + ModuleOp::create(mlir::UnknownLoc::get(entry_func.getContext())); + ModuleManager module_manager(module_for_func.get()); + + while (!referenced.empty()) { + auto func = referenced.pop_back_val(); + + // Skip functions that have already been cloned into new module. + if (module_manager.lookupSymbol(func.getName())) continue; + + // Find any SymbolRefAttr in func that maps to a FuncOp. We need to clone + // all found FuncOps to new_module to make sure new_module is + // self-contained. 
+ func.walk([&](Operation* op) { + for (auto symbol_ref_attr : GetAllSymbolRefAttrs(op)) { + FuncOp referenced_func = + module.lookupSymbol(symbol_ref_attr.getValue()); + + // Skip Symbols that do not map to a function. + if (!referenced_func) continue; + + referenced.emplace_back(referenced_func); + } + }); + + auto clone = func.clone(); + if (clone.getName() == entry_func.getName()) { + // We can simply change name of TPU program's main function because there + // should be no other reference to it. + clone.setName("main"); + } + module_manager.insert(clone); + } + + // Serialize module and return. + std::string txt_module; + { + llvm::raw_string_ostream os(txt_module); + module_for_func.get().print(os); + } + return txt_module; +} + +// Create a `tf.MLIRCompileToTPU` that contains a MLIR module that is +// functionally equivalent to the function referenced by launch_func. +Operation* BuildCompileOp(tf_device::LaunchFuncOp launch_func, + OpBuilder* builder) { + // TODO(b/139377366): Use tf_tpu.compile build method when it is defined. + OperationState compile_op_state(launch_func.getLoc(), "tf.MLIRCompileToTPU"); + + // Build a shape op for each input to launch_func. + // TODO(b/139377366): When shape inference is ready, we can use compile time + // shape inference to get inputs that have static shapes and only use shape + // ops for the rest. + llvm::SmallVector compile_op_operands; + compile_op_operands.reserve(launch_func.getNumOperands()); + + for (Value* v : launch_func.getOperands()) { + auto shape_op = builder->create( + launch_func.getLoc(), + builder->getTensorType({-1}, builder->getIntegerType(64)), v); + compile_op_operands.emplace_back(shape_op.getResult()); + } + compile_op_state.addOperands(compile_op_operands); + + SymbolRefAttr func_attr = launch_func.getAttrOfType("func"); + if (!func_attr) { + launch_func.emitOpError("does not have `func` attribute"); + return nullptr; + } + FuncOp func = launch_func.getParentOfType().lookupSymbol( + func_attr.getValue()); + + std::string txt_module = EncapsulateFuncAndSerialize(func); + compile_op_state.addAttribute("module", builder->getStringAttr(txt_module)); + + // Copy all launch_func attributes other than `func`. + for (auto attr : launch_func.getAttrs()) { + if (attr.first == "func") continue; + compile_op_state.attributes.emplace_back(attr); + } + + // Result #0 is a string indicating whether compilation is successful or not. + compile_op_state.addTypes( + builder->getTensorType({}, builder->getType())); + + // Result #1 is key to look up executable binary in compilation cache. + compile_op_state.addTypes( + builder->getTensorType({}, builder->getType())); + + return builder->createOperation(compile_op_state); +} + +// Creates a `tf.TPUExecute` op that executes TPU program generated by +// `compile_op`. +Operation* BuildExecuteOp(Operation* compile_op, + tf_device::LaunchFuncOp launch_func, + OpBuilder* builder) { + // TODO(b/139377366): Use tf.TPUExecute build method when it is defined. + OperationState execute_op_state(launch_func.getLoc(), "tf.TPUExecute"); + + // TPUExecute inherits all launch_func inputs. + llvm::SmallVector tensor_inputs(launch_func.getOperands()); + execute_op_state.addOperands(tensor_inputs); + + // TODO(b/139377366): Need to snapshot all resource variable inputs in + // follow-up CLs. + + // Set Targs of TPUExecute according to launch_func input types. 
+ llvm::SmallVector tensor_input_types_attrs; + tensor_input_types_attrs.reserve(tensor_inputs.size()); + for (Value* v : tensor_inputs) { + tensor_input_types_attrs.emplace_back(builder->getTypeAttr(v->getType())); + } + execute_op_state.addAttribute( + "Targs", builder->getArrayAttr(tensor_input_types_attrs)); + + // TPUExecute takes an additional input for compilation cache key. + execute_op_state.addOperands(compile_op->getResult(1)); + + // Set Tresults of TPUExecute according to launch_func results types. + llvm::SmallVector output_types_attrs; + output_types_attrs.reserve(launch_func.getNumResults()); + for (Value* v : launch_func.getResults()) { + output_types_attrs.emplace_back(builder->getTypeAttr(v->getType())); + } + execute_op_state.addAttribute("Tresults", + builder->getArrayAttr(output_types_attrs)); + + // TPUExecute has same output types as launch_func. + llvm::SmallVector output_types(launch_func.getResultTypes()); + execute_op_state.addTypes(output_types); + + return builder->createOperation(execute_op_state); +} + +// Creates a `tf.TPUCompileSucceededAssert` operation that parses compilation +// status of `compile_op` to check whether compilation is successful. +void BuildTPUCompileSucceededAssertOp(Operation* compile_op, + OpBuilder* builder) { + OperationState assert_op_state(compile_op->getLoc(), + "tf.TPUCompileSucceededAssert"); + assert_op_state.addOperands(compile_op->getResult(0)); + builder->createOperation(assert_op_state); +} + +// Rewrites a `tf_device.launch_func` operation into a set of TPU Runtime +// Operations that jit-compiles and executes function in `tf_device.launch_func` +// on TPU. +void Rewrite(tf_device::LaunchFuncOp launch_func, OpBuilder* builder) { + builder->setInsertionPoint(launch_func); + Operation* compile_op = BuildCompileOp(launch_func, builder); + BuildTPUCompileSucceededAssertOp(compile_op, builder); + // TODO(ycao): Right now we only support single-core case. The right thing to + // do is to read from launch_func attributes to determine how many execute + // ops to build. + Operation* execute_op = BuildExecuteOp(compile_op, launch_func, builder); + launch_func.replaceAllUsesWith(execute_op); + launch_func.erase(); +} + +void TPURewritePass::runOnModule() { + OpBuilder builder(&getContext()); + getModule().walk([&](tf_device::LaunchFuncOp op) { + // Skip non-tpu device launch_func. + if (!op.getAttrOfType("_tpu_replicate")) return; + Rewrite(op, &builder); + }); + + // TODO(b/139377366): Remove functions that are no longer needed. +} + +} // namespace + +std::unique_ptr CreateTPURewritePass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-rewrite", + "Rewriting `tf_device.launch_func` on TPUs into TPU runtime ops"); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc index 4d9b3ca7ab7..1b48d92171e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/control_to_executor_dialect.cc @@ -22,12 +22,12 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/Support/Debug.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/Value.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/Support/LLVM.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -147,7 +147,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { if (op.getName().getStringRef() == "_tf.Switch") { replacement = builder.create( loc, types, operands, ArrayRef{}); - } else if (op.getName().getStringRef() == "_tf.SwitchN") { + } else if (op.getName().getStringRef() == "_tf._SwitchN") { replacement = builder.create( loc, types, operands, ArrayRef{}); } else if (op.getName().getStringRef() == "_tf.Merge") { @@ -155,7 +155,7 @@ void ControlToExecutorDialectConversion::runOnFunction() { loc, types, operands, ArrayRef{}); } else if (op.getName().getStringRef() == "_tf.NextIteration.source") { replacement = builder.create( - loc, op.getResult(0)->getType(), operands); + loc, op.getResult(0)->getType()); // Record a mapping of the name to the nextiteration.source so that when // we convert the sink we can get the token. StringAttr frame = op.getAttrOfType("name"); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc b/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc index 0c265da11f2..2b076e3d5f2 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator_gen.cc @@ -16,7 +16,7 @@ limitations under the License. #include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/InitLLVM.h" #include "llvm/Support/Signals.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TableGen/Error.h" @@ -129,10 +129,7 @@ static bool DerivedAttrWritersMain(raw_ostream &os, RecordKeeper &records) { } int main(int argc, char **argv) { - llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); - llvm::PrettyStackTraceProgram X(argc, argv); - - llvm::llvm_shutdown_obj Y; + llvm::InitLLVM y(argc, argv); llvm::cl::ParseCommandLineOptions(argc, argv); return TableGenMain(argv[0], &DerivedAttrWritersMain); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc new file mode 100644 index 00000000000..2d906d84db3 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/translate/executor_to_control_dialect.cc @@ -0,0 +1,210 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation pass transforms from TF executor dialect to MLIR TF +// contol dialect. + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/Value.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir +#include "mlir/Support/LLVM.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +#define DEBUG_TYPE "tf-executor-to-ctl" + +namespace mlir { + +namespace { +struct ExecutorToControlDialectConversion + : public FunctionPass { + void runOnFunction() override; +}; +} // end anonymous namespace + +static bool HasSingleGraph(FuncOp function) { + // We expect the function has only one region with one block, + if (function.getBlocks().size() != 1) return false; + auto &block = function.front(); + // and the block contains two ops, + if (std::next(block.begin()) == block.end()) return false; + // one GraphOp, + if (!isa(block.begin())) return false; + // followed by a terminator. + if (!std::next(block.begin())->isKnownTerminator()) return false; + return true; +} + +void ExecutorToControlDialectConversion::runOnFunction() { + if (!HasSingleGraph(getFunction())) { + LLVM_DEBUG(llvm::dbgs() + << "Expect a Function with a single block and a single graph op," + " skip tf_executor dialect conversion\n"); + return; + } + Type control_type = TFControlFlow::TFControlType::get(&getContext()); + + Block &body = getFunction().front(); + OpBuilder builder(&body, body.begin()); + auto graph = cast(body.front()); + SmallString<64> new_op_name; + for (auto &op : llvm::make_early_inc_range(graph.GetBody())) { + LLVM_DEBUG(llvm::dbgs() << "Process: " << op.getName() << "\n"); + if (auto fetch = dyn_cast(op)) { + // Replace all the operands of the fetch op with the uses of the graph + // results, the graph op will then be removed. + for (auto ops_and_ret_vals : + llvm::zip(graph.getResults(), fetch.getOperands())) + std::get<0>(ops_and_ret_vals) + ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); + continue; + } + if (auto island = dyn_cast(op)) { + Value *ctl_sequence = nullptr; + Operation *last_replaced_op = nullptr; + for (Operation &wrapped_op : island.GetBody()) { + LLVM_DEBUG(llvm::dbgs() + << " In island: " << wrapped_op.getName() << "\n"); + if (isa(wrapped_op)) { + for (auto ops_and_ret_vals : + llvm::zip(island.getResults(), wrapped_op.getOperands())) + std::get<0>(ops_and_ret_vals) + ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); + break; + } + // Add a leading _ off the name. + new_op_name = "_"; + new_op_name += wrapped_op.getName().getStringRef(); + OperationState state(wrapped_op.getLoc(), new_op_name); + + // Add an operand for each non-control input we find. 
Collect control + // values separately to add them to the island operands + state.operands.append(wrapped_op.getOperands().begin(), + wrapped_op.getOperands().end()); + + // Chain operations through a control dependency, except for the first + // operations in the sequence that carry the control dependencies held + // by the island itself. + if (ctl_sequence) { + state.operands.push_back(ctl_sequence); + } else { + for (Value *ctl_operand : island.getOperands()) + state.operands.push_back(ctl_operand); + } + + // Add a result type for each result + state.types.append(wrapped_op.getResultTypes().begin(), + wrapped_op.getResultTypes().end()); + state.types.push_back(control_type); + + // Create the replacement operation. + auto *replacement = builder.createOperation(state); + replacement->setAttrs(wrapped_op.getAttrList()); + + for (auto ops_and_ret_vals : + llvm::zip(wrapped_op.getResults(), replacement->getResults())) + std::get<0>(ops_and_ret_vals) + ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); + + ctl_sequence = replacement->getResult(replacement->getNumResults() - 1); + last_replaced_op = replacement; + } + for (Value *island_ctl : island.getResults()) + island_ctl->replaceAllUsesWith( + last_replaced_op->getResult(last_replaced_op->getNumResults() - 1)); + op.erase(); + continue; + } + + new_op_name.clear(); + if (isa(op)) { + new_op_name = "_tf.Switch"; + } else if (isa(op)) { + new_op_name = "_tf._SwitchN"; + } else if (isa(op)) { + new_op_name = "_tf.Merge"; + } else if (isa(op)) { + new_op_name = "_tf.NextIteration.source"; + } else if (isa(op)) { + new_op_name = "_tf.NextIteration.sink"; + } else if (isa(op)) { + new_op_name = "_tf.LoopCond"; + } else if (isa(op)) { + new_op_name = "_tf.Enter"; + } else if (isa(op)) { + new_op_name = "_tf.Exit"; + } else if (isa(op)) { + new_op_name = "_tf.ControlTrigger"; + } else { + op.emitOpError() << "unhandled op in tf_executor to _tf conversion"; + return signalPassFailure(); + } + OperationState state(op.getLoc(), new_op_name); + // Token results are dropped when we process the source op, the operand + // becomes nullptr by the time we process the sink op, filter it out here. + auto non_null_operands = + llvm::make_filter_range(op.getOperands(), [](Value *v) { return v; }); + state.operands.append(non_null_operands.begin(), non_null_operands.end()); + for (Type result_type : op.getResultTypes()) { + // Filter out TokenType, they don't exist in the control dialect. + if (result_type.isa()) continue; + if (!result_type.isa()) + state.types.push_back(result_type); + else + state.types.push_back(control_type); + } + // The control dialect has a control result for the sink operation. + if (isa(op)) + state.types.push_back(control_type); + + // Create the replacement operation. 
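// NOTE (editorial sketch, not part of the original patch): assuming an island
// that wraps a single "tf.Add", the OperationState assembled above is expected
// to materialize a control-dialect op along these lines:
//
//   %sum, %ctl = "_tf.Add"(%a, %b, %incoming_ctl)
//       : (tensor<f32>, tensor<f32>, !_tf.control) -> (tensor<f32>, !_tf.control)
//
// i.e. control values are appended after the data operands and one
// !_tf.control result is appended after the data results.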
+ auto *replacement = builder.createOperation(state); + replacement->setAttrs(op.getAttrList()); + + if (auto next_iteration = + dyn_cast(op)) { + next_iteration.output()->replaceAllUsesWith(replacement->getResult(0)); + next_iteration.token()->dropAllUses(); + next_iteration.control()->replaceAllUsesWith(replacement->getResult(1)); + } else { + for (auto ops_and_ret_vals : + llvm::zip(op.getResults(), replacement->getResults())) + std::get<0>(ops_and_ret_vals) + ->replaceAllUsesWith(std::get<1>(ops_and_ret_vals)); + } + op.erase(); + } + graph.erase(); +} + +std::unique_ptr CreateTFExecutorToControlDialectConversion() { + return std::make_unique(); +} + +} // namespace mlir + +static mlir::PassRegistration pass( + "tf-executor-to-control-conversion", + "Convert from TF executor dialect to TF control dialect"); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 3d98cdf4ea4..9868c4a4ac5 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir @@ -34,8 +35,10 @@ limitations under the License. #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Pass/PassManager.h" // TF:local_config_mlir #include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" @@ -55,6 +58,11 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +namespace mlir { +/// Create a pass to convert from the TFExecutor to the TF control dialect. +std::unique_ptr CreateTFExecutorToControlDialectConversion(); +} // namespace mlir + namespace tensorflow { using llvm::cast; using llvm::dyn_cast; @@ -201,10 +209,8 @@ std::string Exporter::UniqueName(mlir::Operation* op) { StatusOr> Exporter::GetArgumentNode( mlir::BlockArgument* arg, unsigned index) { auto node_def = absl::make_unique(); - node_def->set_name(UniqueName(arg->getContainingRegion() - ->getParentOfType() - .getName() - .str())); + node_def->set_name(UniqueName( + arg->getParentRegion()->getParentOfType().getName().str())); node_def->set_op(FunctionLibraryDefinition::kArgOp); DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType( @@ -294,13 +300,17 @@ Status Exporter::AddInstructionNode(mlir::Operation* inst) { // check is too conservative given we could use a OpDef. 
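// NOTE (editorial sketch, not part of the original patch; schematic only): a
// control-dialect op such as
//   %0:2 = "_tf.Add"(%a, %b) {name = "add", ...}
// is expected to export back to a NodeDef along the lines of
//   name: "add"  op: "Add"  input: "a"  input: "b"
// where "name" (and "device") populate NodeDef fields rather than attrs.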
if (auto abstract_op = inst->getAbstractOperation()) { if (&abstract_op->dialect == tf_dialect_) { - TF_ASSIGN_OR_RETURN(node_def, ConvertTFDialectOpToNodeDef(inst, name)); + TF_ASSIGN_OR_RETURN( + node_def, ConvertTFDialectOpToNodeDef( + inst, name, /*ignore_unregistered_attrs=*/false)); } } // Convert TF control flow dialect ops. if (!node_def) { - TF_ASSIGN_OR_RETURN(node_def, - GetOperationNodeDef(inst, name.c_str(), getTFOpName)); + absl::flat_hash_set attrs_to_ignore; + TF_ASSIGN_OR_RETURN( + node_def, GetOperationNodeDef(attrs_to_ignore, inst, name.c_str(), + getTFOpName)); } Node* node = graph_->AddNode(*node_def, &status); TF_RETURN_IF_ERROR(status); @@ -326,7 +336,7 @@ Status Exporter::AddArgumentNode(mlir::BlockArgument* arg, unsigned index) { // is an input node. We recover the original input node and skip adding the // argument node. The new input node will be handled as normal in the // following steps. - if (arg->getContainingRegion()->getParentOfType().getName() == + if (arg->getParentRegion()->getParentOfType().getName() == "main") { if (!arg->hasOneUse()) { return errors::FailedPrecondition( @@ -556,7 +566,8 @@ Status Exporter::ConvertLibFunction(const ExporterConfigs& configs, // Ignore the gradient and is_stateful attribute on the function as they have // been handled above. - absl::flat_hash_set attrs_to_ignore = {grad_string, stateful_string}; + absl::flat_hash_set attrs_to_ignore = { + grad_string.data(), stateful_string.data()}; llvm::SmallVector funcAttrs( function.getDialectAttrs()); TF_RETURN_IF_ERROR( @@ -604,6 +615,12 @@ Status Exporter::Convert(mlir::ModuleOp module, const ExporterConfigs& configs, Status ConvertMlirToGraph(mlir::ModuleOp module, const ExporterConfigs& confs, std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + mlir::PassManager pass_manager; + pass_manager.addPass(mlir::CreateTFExecutorToControlDialectConversion()); + if (mlir::failed(pass_manager.run(module))) { + return errors::FailedPrecondition( + "Failed to convert TFExecutor Dialect to Control Dialect."); + } return Exporter::Convert(module, confs, graph, flib_def); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc index c2caf3f18f9..993a44452ea 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringSet.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" @@ -65,7 +68,7 @@ Status SetAttribute(absl::string_view name, ContainerT types, // definitions and isn't a header file. #include "tensorflow/compiler/mlir/tensorflow/translate/derived_attr_populator.inc" -static StatusOr getTensorFlowOpName(llvm::StringRef op_name) { +StatusOr getTensorFlowOpName(llvm::StringRef op_name) { if (!op_name.consume_front("tf.")) { return errors::FailedPrecondition("op name not prefixed with 'tf.': " + op_name.str()); @@ -73,12 +76,54 @@ static StatusOr getTensorFlowOpName(llvm::StringRef op_name) { return op_name.str(); } +// Collect all the unregistered attributes for an TF dialect operation. 
+// Attributes "name" and "device" are not included because they are not part +// of an TF op attributes. +Status GetUnregisteredAttrs( + mlir::Operation* inst, + absl::flat_hash_set* attrs_to_ignore) { + TF_ASSIGN_OR_RETURN(auto op_name, + getTensorFlowOpName(inst->getName().getStringRef())); + + const tensorflow::OpRegistrationData* op_reg_data; + auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); + if (!status.ok()) { + // This is likely a function call node, so we should continue. + VLOG(1) << status.ToString(); + return Status::OK(); + } + + // Collect all the registered attributes. + llvm::DenseSet registered_attrs; + registered_attrs.insert("name"); + registered_attrs.insert("device"); + for (const auto& attr_def : op_reg_data->op_def.attr()) { + registered_attrs.insert(attr_def.name()); + } + // Attributes are not in the registered attributes set will be ignored. + for (auto& attr : inst->getAttrs()) { + auto attr_name = attr.first.c_str(); + if (registered_attrs.find(attr_name) == registered_attrs.end()) { + attrs_to_ignore->insert(attr_name); + } + } + return Status::OK(); +} + } // namespace StatusOr> ConvertTFDialectOpToNodeDef( - mlir::Operation* inst, llvm::StringRef name) { - TF_ASSIGN_OR_RETURN(auto node_def, - GetOperationNodeDef(inst, name, getTensorFlowOpName)); + mlir::Operation* inst, llvm::StringRef name, + bool ignore_unregistered_attrs) { + // The elements are owned by the MLIRContext. + absl::flat_hash_set attrs_to_ignore; + if (ignore_unregistered_attrs) { + TF_RETURN_IF_ERROR(GetUnregisteredAttrs(inst, &attrs_to_ignore)); + } + + TF_ASSIGN_OR_RETURN( + auto node_def, + GetOperationNodeDef(attrs_to_ignore, inst, name, getTensorFlowOpName)); // Use auto generated function to populate derived attribute. // diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h index 6d32a318a30..26e84d631a0 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h @@ -24,9 +24,13 @@ limitations under the License. namespace tensorflow { // Converts an MLIR operation to TensorFlow NodeDef with given node name. This -// name should be unique to the graph it is being inserted to. +// name should be unique to the graph it is being inserted to. If the +// `ignore_unregistered_attrs` argument is set to true, the attributes which are +// not in the op registry will be ignored. Set it to true if the returned +// NodeDef will be excuted by the linked TF Eager runtime. stream_executor::port::StatusOr> -ConvertTFDialectOpToNodeDef(mlir::Operation* inst, llvm::StringRef name); +ConvertTFDialectOpToNodeDef(mlir::Operation* inst, llvm::StringRef name, + bool ignore_unregistered_attrs); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc similarity index 61% rename from tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc rename to tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 2ac09e3540d..34cdc609164 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" + +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" @@ -28,6 +30,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir @@ -36,9 +39,9 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" @@ -49,8 +52,11 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -66,49 +72,37 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" +static inline absl::string_view StringRefToView(llvm::StringRef ref) { + return {ref.data(), ref.size()}; +} + namespace tensorflow { using stream_executor::port::StatusOr; namespace { -// Stateful helper class to import a GraphDef into an MLIR Module. The nodes -// defined in the graph is converted to a function called "main". All the -// library function definitions are converted to MLIR functions in the module. -class Importer { - public: - // Main entry point: converts the given graph to an MLIR Module. - static StatusOr Convert( - mlir::MLIRContext* context, const Graph& graph, - const GraphDebugInfo& debug_info, - const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs); - - private: - // Most types with subtypes have only one subtype. - using ElementSubtypes = llvm::SmallVector; - - explicit Importer( +// Stateful helper class to import a TensorFlow model into an MLIR Module. +// +// This is the base class that contains common utilties shared between the +// GraphDef importer and SavedModel importer. +// +// A subclass is expected to call `PrepareConvert` first to perform necessary +// preparation over the graph and also certain internal bookkeeping data. +// Afterwards the other protected methods can be called. 
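// NOTE (editorial sketch, not part of the original patch; it only re-uses the
// protected methods declared below): a subclass is expected to drive the
// conversion roughly as
//
//   TF_RETURN_IF_ERROR(PrepareConvert(graph));
//   // ... infer the mlir::FunctionType and collect arg/ret/control-ret nodes ...
//   TF_RETURN_IF_ERROR(Convert(func_name, func_type, arg_nodes, ret_nodes,
//                              control_ret_nodes, attrs));
//
// PrepareConvert must run first so that backedge removal and shape inference
// are done before any node is converted.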
+class ImporterBase { + protected: + explicit ImporterBase( const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, const NodeSpecs& specs, mlir::ModuleOp module, std::unordered_map* tf_name_to_mlir_name) - : module_(module), + : builder_(module.getContext()), + module_(module), context_(module.getContext()), tf_name_to_mlir_name_(tf_name_to_mlir_name), graph_flib_(flib), specs_(specs), debug_info_(debug_info) {} - // Prepares converting the graph to an MLIR module. This step removes the - // backedges of the graph, orders the nodes and infers the shapes. - Status PrepareConvert(const Graph& graph); - - // Returns the function signature of the main function of converted MLIR - // module, the input nodes and output nodes. The type and shape information - // for the function arguments are read from the specs_, but the type and shape - // information for the function returns are inferred by the shape_refiner_. - StatusOr InferMainFunctionType( - absl::InlinedVector* arg_nodes, - absl::InlinedVector* ret_nodes); - // Returns the inferred function signature of the given function body. Input // types are unranked tensor of the respective datatype in the function and // result types are inferred by the shape_refiner_. Result types need not be @@ -116,25 +110,54 @@ class Importer { // depends on an op with static output shape like tf.Const. StatusOr InferLibFunctionType(const FunctionBody& fbody); + // Extracts arg and ret nodes from FunctionBody. + void GetArgsAndRetsFromFunctionBody( + const FunctionBody& fbody, + absl::InlinedVector* arg_nodes, + absl::InlinedVector* ret_nodes, + absl::InlinedVector* control_ret_nodes); + + // Prepares converting the graph to an MLIR module. This step removes the + // backedges of the graph, orders the nodes and infers the shapes. + Status PrepareConvert(const Graph& graph); + // Converts the prepared graph to a Function and adds it to the module. A set // of nodes from the graph are given to converted to the arguments and returns // of the function. Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type, const absl::InlinedVector& arg_nodes, const absl::InlinedVector& ret_nodes, + const absl::InlinedVector& control_ret_nodes, llvm::ArrayRef attrs); + // Finds out the function definition for the given function name from the + // graph and converts it to a function of the module. This method is called + // on demand because the graph flib_def does not provide an iterator + // interface. + Status ConvertLibFunction(llvm::StringRef func_name); + + // Returns the list of nodes in the graph. Nodes are presented in the reverse + // order of a post-order depth-first visit starting from the graph's source + // nodes. + llvm::ArrayRef GetOrderedNodes() const { return ordered_nodes_; } + + // Returns the inferred output type at index `idx` of the `node` in the + // context. + StatusOr InferOutputType(const Node& node, int idx, + mlir::Builder builder); + + private: + // Most types with subtypes have only one subtype. + using ElementSubtypes = llvm::SmallVector; + // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all // data type and shape information is maintained by the shape_refiner_. Status AddNodesToShapeRefiner(); - // Returns the inferred input type at index `idx` of the node in the context. - StatusOr InferInputType( - ExtendedInferenceContext* shape_context, int idx, mlir::Builder builder); - - // Returns the inferred output type at index `idx` of the node in the context. 
- StatusOr InferOutputType( - ExtendedInferenceContext* shape_context, int idx, mlir::Builder builder); + // Returns the inferred input type at index `idx` of the `node` in the + // context. + StatusOr InferInputType(const Node& node, int idx, + mlir::Builder builder); // Converts the inferred shape referred to by 'handle' in 'context', with // given element type, and returns an MLIR tensor type. @@ -157,7 +180,7 @@ class Importer { // Converts the tensor proto into an MLIR elements attribute. StatusOr ConvertTensorProto(const TensorProto& value) { - return ::tensorflow::ConvertTensorProto(value, builder_.get()); + return ::tensorflow::ConvertTensorProto(value, &builder_); } // Converts func name in graphdef to mlir::SymbolRefAttribute. @@ -176,6 +199,13 @@ class Importer { const std::string& base_name, const AttrValue& value, llvm::SmallVector* attributes); + // Helper to create either a tf_executor operation or a TF operation wrapped + // in an island. + mlir::Operation* createOperation( + const Node& node, llvm::StringRef op_name, + const mlir::OperationState& result, + const llvm::SmallVectorImpl& control_operands); + // Converts one NodeDef from the input GraphDef into an Operation and // inserts it into the MLIR module using builder_. Status ConvertNode(const Node& node); @@ -200,25 +230,15 @@ class Importer { Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst, int dst_input); - // Gets the "source" of a NextIteration operation. If it doesn't exist, - // creates and inserts it to the front of the basic block. - mlir::Operation* GetOrCreateNextIterationSource(mlir::Operation* sink, - mlir::Operation* dst); - - // Finds out the function definition for the given function name from the - // graph and converts it to a function of the module. This method is called - // on demand because the graph flib_def does not provide an iterator - // interface. The consequence is that only the referred functions are added to - // the MLIR module. - Status ConvertLibFunction(const std::string& func_name); - // Adds the input arguments and return operation to the function. The // arguments are added as basic block argument. Also the argument types and // the id of the nodes from the input graph needs to be specified. Status ConvertFunctionArgAndRets( - mlir::Block* bb, llvm::ArrayRef arg_types, + mlir::Block* bb, mlir::tf_executor::GraphOp graph_op, + llvm::ArrayRef arg_types, const absl::InlinedVector& arg_nodes, - const absl::InlinedVector& ret_nodes); + const absl::InlinedVector& ret_nodes, + const absl::InlinedVector& control_ret_nodes); // Gets the location information of the given node. It uses the // "original_node_name" in the NodeDef to get the corresponding file location @@ -257,13 +277,12 @@ class Importer { // All nodes and version information about the (copied) imported graph. std::unique_ptr graph_; - const VersionDef* graph_versions_; std::vector ordered_nodes_; // Maps from a Node ID to a MLIR value. using NodeValueMap = absl::flat_hash_map; - std::unique_ptr builder_; + mlir::OpBuilder builder_; mlir::ModuleOp module_; mlir::MLIRContext* context_; std::unordered_map* tf_name_to_mlir_name_; @@ -274,11 +293,67 @@ class Importer { std::unique_ptr shape_refiner_; }; -// Adds the default attributes to each node def if they are missing from the -// GraphDef. -Status AddDefaultsToNodeDef(GraphDef* graph_def) { +// Returns true if the node with given name has a non primary output that is +// used by some other node as an input. 
Returns false if no outputs are in use +// or only the first output is in use. +bool HasNonPrimaryOutputInUse(const GraphDef& graph_def, + const std::string& node) { + for (const auto& node_def : graph_def.node()) { + for (const auto& input : node_def.input()) { + if (absl::StartsWith(input, node + ":") && input != node + ":0") { + return true; + } + } + } + return false; +} + +// Updates the given LegacyFedInput node with Placeholder node if it is one of +// the inputs. Returns an error if non primary output of the LegacyFedInput node +// is in use and therefore can not be replaced by the Placeholder node that only +// has a single output. +Status UpdateLegacyFedInputNode(const GraphDef& graph_def, + const NodeSpecs::InputArrays& inputs, + NodeDef* node) { + const std::string& node_name = node->name(); + auto it = inputs.find(node_name); + + // Node is not an input. + if (it == inputs.end()) return Status::OK(); + + if (HasNonPrimaryOutputInUse(graph_def, node_name)) { + return errors::InvalidArgument( + "LegacyFedInput node ", node->name(), + " has non primary output in use and can not be replaced with " + "Placeholder node"); + } + + // Update op name, drop inputs and set attributes required by the Placeholder + // op. + *node->mutable_op() = "Placeholder"; + node->clear_attr(); + node->clear_input(); + AddNodeAttr("dtype", it->second.imported_dtype, node); + AddNodeAttr("shape", it->second.shape, node); + return Status::OK(); +} + +// Preprocesses GraphDef before it can be converted to Graph by, +// - Adding the default attributes to each node def if they are missing from +// the GraphDef. +// - Replacing LegacyFedInput nodes with Placeholder nodes if +// convert_legacy_fed_inputs option is enabled. +Status PreprocessGraphDef(const NodeSpecs* specs, GraphDef* graph_def) { const tensorflow::OpRegistrationData* op_reg_data; for (auto& node_def : *graph_def->mutable_node()) { + // TODO(hinsu): Completely deprecate support for LegacyFedInput ops. One + // solution could be have a tool to let users upgrade old serialized graphs. + if (specs && specs->convert_legacy_fed_inputs && + node_def.op() == "LegacyFedInput") { + TF_RETURN_IF_ERROR( + UpdateLegacyFedInputNode(*graph_def, specs->inputs, &node_def)); + } + auto status = tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data); if (!status.ok()) { @@ -291,7 +366,7 @@ Status AddDefaultsToNodeDef(GraphDef* graph_def) { return Status::OK(); } -Status Importer::RemoveBackedges(const Graph& graph) { +Status ImporterBase::RemoveBackedges(const Graph& graph) { // TODO(fengliuai): Converting to GraphDef and back is the easiest way to // clone a graph. // TODO(fengliuai): clone the graph without going to graph_def first. @@ -300,8 +375,8 @@ Status Importer::RemoveBackedges(const Graph& graph) { graph_ = absl::make_unique(graph.flib_def()); GraphConstructorOptions opts; opts.allow_internal_ops = true; - TF_RETURN_IF_ERROR( - ::tensorflow::ConvertGraphDefToGraph(opts, graph_def, graph_.get())); + TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph( + opts, std::move(graph_def), graph_.get())); // Remove all the backedges. So the nodes can be added to the shape refiner. 
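// NOTE (editorial clarification, not part of the original patch): the only
// cycles in a TensorFlow graph come from while loops, where NextIteration
// feeds back into Merge:
//
//   Enter -> Merge -> ... -> NextIteration
//             ^                   |
//             +-------------------+   <- edge removed here
//
// Removing that edge turns the graph into a DAG so shape refinement can visit
// nodes in a topological order; AddBackedges() re-creates it after conversion.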
TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get())); @@ -330,7 +405,7 @@ Status Importer::RemoveBackedges(const Graph& graph) { return Status::OK(); } -StatusOr Importer::ReplaceWithPlaceholderNode( +StatusOr ImporterBase::ReplaceWithPlaceholderNode( const TensorShapeProto& shape, DataType dtype, Node* input_node) { Node* placeholder_node; NodeBuilder builder(input_node->name(), "Placeholder"); @@ -351,7 +426,8 @@ StatusOr Importer::ReplaceWithPlaceholderNode( return placeholder_node; } -Status Importer::GetInputOutputNodes(std::unordered_set* nodes) { +Status ImporterBase::GetInputOutputNodes( + std::unordered_set* nodes) { auto node_name_map = graph_->BuildNodeNameIndex(); auto add_node = [&](const string& name) { auto it = node_name_map.find(name); @@ -375,9 +451,9 @@ Status Importer::GetInputOutputNodes(std::unordered_set* nodes) { } // TODO(fengliuai): Replace the iterative algorithm by an one pass propagation -Status Importer::AddNodesToShapeRefiner() { - shape_refiner_ = - absl::make_unique(*graph_versions_, graph_->op_registry()); +Status ImporterBase::AddNodesToShapeRefiner() { + shape_refiner_ = absl::make_unique(graph_->versions(), + graph_->op_registry()); // Some operations (for example "TPUExecute") don't have shape inference // function defined, so we should set this to false for adding nodes with // these types of operations. @@ -527,8 +603,11 @@ Status Importer::AddNodesToShapeRefiner() { return Status::OK(); } -StatusOr Importer::InferInputType( - ExtendedInferenceContext* shape_context, int idx, mlir::Builder builder) { +StatusOr ImporterBase::InferInputType(const Node& node, + int idx, + mlir::Builder builder) { + ExtendedInferenceContext* shape_context = + shape_refiner_->GetExtendedContext(&node); DataType dtype = shape_context->input_type(idx); auto* context = shape_context->get_context(); return ConvertDataTypeAndShape(dtype, context->input(idx), @@ -536,8 +615,10 @@ StatusOr Importer::InferInputType( context, builder); } -StatusOr Importer::InferOutputType( - ExtendedInferenceContext* shape_context, int idx, mlir::Builder builder) { +StatusOr ImporterBase::InferOutputType( + const Node& node, int idx, mlir::Builder builder) { + ExtendedInferenceContext* shape_context = + shape_refiner_->GetExtendedContext(&node); DataType dtype = shape_context->output_type(idx); auto* context = shape_context->get_context(); return ConvertDataTypeAndShape(dtype, context->output(idx), @@ -545,7 +626,7 @@ StatusOr Importer::InferOutputType( context, builder); } -StatusOr Importer::ConvertDataTypeAndShape( +StatusOr ImporterBase::ConvertDataTypeAndShape( DataType dtype, const shape_inference::ShapeHandle& handle, const std::vector* handle_subtypes, shape_inference::InferenceContext* context, mlir::Builder builder) { @@ -564,7 +645,7 @@ StatusOr Importer::ConvertDataTypeAndShape( return ConvertElementTypeAndShape(element_type, handle, context, builder); } -StatusOr Importer::ConvertElementTypeAndShape( +StatusOr ImporterBase::ConvertElementTypeAndShape( mlir::Type element_type, const shape_inference::ShapeHandle& handle, shape_inference::InferenceContext* context, mlir::Builder builder) { if (!context->RankKnown(handle)) { @@ -591,7 +672,7 @@ StatusOr Importer::ConvertElementTypeAndShape( llvm::makeArrayRef(dimensions.begin(), dimensions.end()), element_type); } -StatusOr Importer::ConvertSubtypes( +StatusOr ImporterBase::ConvertSubtypes( const std::vector* handle_subtypes, shape_inference::InferenceContext* context, mlir::Builder builder) { ElementSubtypes subtypes; @@ -610,64 
+691,64 @@ StatusOr Importer::ConvertSubtypes( return subtypes; } -Status Importer::ConvertFunctionCallAttribute( +Status ImporterBase::ConvertFunctionCallAttribute( const std::string& base_name, const AttrValue& value, llvm::SmallVector* attributes) { TF_ASSIGN_OR_RETURN(auto func_attr, ConvertFunctionCallName(value.func().name())); - attributes->push_back(builder_->getNamedAttr(base_name, func_attr)); + attributes->push_back(builder_.getNamedAttr(base_name, func_attr)); for (const auto& it : value.func().attr()) { auto name = absl::StrCat(base_name, ".", it.first); TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second)); - attributes->push_back(builder_->getNamedAttr(name, value)); + attributes->push_back(builder_.getNamedAttr(name, value)); } return Status::OK(); } -StatusOr Importer::ConvertFunctionCallName( +StatusOr ImporterBase::ConvertFunctionCallName( const std::string& func_name) { TF_RETURN_IF_ERROR(ConvertLibFunction(func_name)); auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name]; auto func = module_.lookupSymbol(mlir_func_name); - return builder_->getSymbolRefAttr(func); + return builder_.getSymbolRefAttr(func); } -StatusOr Importer::ConvertAttributeValue( +StatusOr ImporterBase::ConvertAttributeValue( const AttrValue& value) { switch (value.value_case()) { case AttrValue::kI: - return builder_->getI64IntegerAttr(value.i()); + return builder_.getI64IntegerAttr(value.i()); case AttrValue::kS: - return builder_->getStringAttr(value.s()); + return builder_.getStringAttr(value.s()); case AttrValue::kF: - return builder_->getFloatAttr(builder_->getF32Type(), value.f()); + return builder_.getFloatAttr(builder_.getF32Type(), value.f()); case AttrValue::kB: - return builder_->getBoolAttr(value.b()); + return builder_.getBoolAttr(value.b()); case AttrValue::kType: - return builder_->getStringAttr( + return builder_.getStringAttr( mangling_util::MangleDataType(value.type())); case AttrValue::kShape: - return builder_->getStringAttr(mangling_util::MangleShape(value.shape())); + return builder_.getStringAttr(mangling_util::MangleShape(value.shape())); case AttrValue::kTensor: return ConvertTensorProto(value.tensor()); case AttrValue::kList: { absl::InlinedVector attrs; for (const auto& item : value.list().i()) - attrs.push_back(builder_->getI64IntegerAttr(item)); + attrs.push_back(builder_.getI64IntegerAttr(item)); for (const auto& item : value.list().s()) - attrs.push_back(builder_->getStringAttr(item)); + attrs.push_back(builder_.getStringAttr(item)); for (const auto& item : value.list().f()) - attrs.push_back(builder_->getFloatAttr(builder_->getF32Type(), item)); + attrs.push_back(builder_.getFloatAttr(builder_.getF32Type(), item)); for (const auto& item : value.list().b()) - attrs.push_back(builder_->getBoolAttr(item)); + attrs.push_back(builder_.getBoolAttr(item)); for (const auto& item : value.list().type()) { - attrs.push_back(builder_->getStringAttr( + attrs.push_back(builder_.getStringAttr( mangling_util::MangleDataType(static_cast(item)))); } for (const auto& item : value.list().shape()) { attrs.push_back( - builder_->getStringAttr(mangling_util::MangleShape(item))); + builder_.getStringAttr(mangling_util::MangleShape(item))); } for (const auto& item : value.list().tensor()) { TF_ASSIGN_OR_RETURN(auto attr, ConvertTensorProto(item)); @@ -680,13 +761,13 @@ StatusOr Importer::ConvertAttributeValue( "func attributes with non-zero attr.size()"); attrs.push_back(attr); } - return builder_->getArrayAttr( + return builder_.getArrayAttr( 
llvm::makeArrayRef(attrs.begin(), attrs.end())); } case AttrValue::kFunc: return errors::Unknown("kFunc type should be handled separately!"); case AttrValue::VALUE_NOT_SET: - return builder_->getUnitAttr(); + return builder_.getUnitAttr(); // kPlaceholder is not implemented. default: return errors::Unimplemented( @@ -694,20 +775,36 @@ StatusOr Importer::ConvertAttributeValue( } } -Status Importer::ConvertLibFunction(const std::string& func_name) { +void ImporterBase::GetArgsAndRetsFromFunctionBody( + const FunctionBody& fbody, absl::InlinedVector* arg_nodes, + absl::InlinedVector* ret_nodes, + absl::InlinedVector* control_ret_nodes) { + arg_nodes->reserve(fbody.arg_nodes.size()); + ret_nodes->reserve(fbody.ret_nodes.size()); + for (auto arg : fbody.arg_nodes) { + arg_nodes->emplace_back(arg, 0); + } + for (auto ret : fbody.ret_nodes) { + ret_nodes->emplace_back(ret, 0); + } + *control_ret_nodes = fbody.control_ret_nodes; +} + +Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { // If the library function has been converted already, nothing needs to be // done. if (tf_name_to_mlir_name_->find(func_name) != tf_name_to_mlir_name_->end()) return Status::OK(); - std::string mlir_func_name = graph_flib_.UniqueFunctionName(func_name); + std::string mlir_func_name = + graph_flib_.UniqueFunctionName(StringRefToView(func_name)); (*tf_name_to_mlir_name_)[func_name] = mlir_func_name; const auto& func_lib = graph_flib_; const auto* func_def = func_lib.Find(func_name); if (func_def == nullptr) { return errors::FailedPrecondition( - absl::StrCat("Failed to find function '", func_name, + absl::StrCat("Failed to find function '", StringRefToView(func_name), "'. The imported TensorFlow GraphDef is ill-formed.")); } @@ -726,14 +823,14 @@ Status Importer::ConvertLibFunction(const std::string& func_name) { ConvertAttributeValue(name_and_value.second)); std::string attr_name = mangling_util::MangleAttributeName(name_and_value.first); - attributes.push_back(builder_->getNamedAttr(attr_name, attr)); + attributes.push_back(builder_.getNamedAttr(attr_name, attr)); } // Checks opdef stateful attribute and import that as Function Attribute if (func_def->signature().is_stateful()) { auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName(); attributes.push_back( - builder_->getNamedAttr(stateful_str, builder_->getUnitAttr())); + builder_.getNamedAttr(stateful_str, builder_.getUnitAttr())); } // Checks for an associated custom gradient function. Adds it to the attribute @@ -743,99 +840,135 @@ Status Importer::ConvertLibFunction(const std::string& func_name) { TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name)); auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name]; auto grad_func = module_.lookupSymbol(mlir_grad_func_name); - auto gradient_attr = builder_->getSymbolRefAttr(grad_func); + auto gradient_attr = builder_.getSymbolRefAttr(grad_func); auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName(); - attributes.push_back(builder_->getNamedAttr(grad_string, gradient_attr)); + attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr)); } // Converts the graph to a MLIR function and adds it to the module. Uses the // default node spec without any inputs or outputs as the function graph has // special '_Arg' and '_Retval' ops for argument and return values. 
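// NOTE (editorial sketch, not part of the original patch; the IR below is
// schematic): a trivial library function foo(x) = Identity(x) is expected to
// import roughly as
//
//   func @foo(%arg0: tensor<*xf32>) -> tensor<*xf32> {
//     %res = tf_executor.graph {
//       %id, %ctl = tf_executor.island {
//         %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
//         tf_executor.yield %0 : tensor<*xf32>
//       }
//       tf_executor.fetch %id : tensor<*xf32>
//     }
//     return %res : tensor<*xf32>
//   }
//
// with the _Arg/_Retval nodes folded into the block arguments and the return.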
NodeSpecs specs; - Importer child_importer(graph_flib_, debug_info_, specs, module_, - tf_name_to_mlir_name_); + ImporterBase child_importer(graph_flib_, debug_info_, specs, module_, + tf_name_to_mlir_name_); TF_RETURN_IF_ERROR(child_importer.PrepareConvert(*fbody->graph)); TF_ASSIGN_OR_RETURN(auto func_type, child_importer.InferLibFunctionType(*fbody)); absl::InlinedVector arg_nodes; - arg_nodes.reserve(fbody->arg_nodes.size()); absl::InlinedVector ret_nodes; - ret_nodes.reserve(fbody->ret_nodes.size()); - for (auto arg : fbody->arg_nodes) { - arg_nodes.emplace_back(arg, 0); - } - for (auto ret : fbody->ret_nodes) { - ret_nodes.emplace_back(ret, 0); - } + absl::InlinedVector control_ret_nodes; + GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes, + &control_ret_nodes); TF_RETURN_IF_ERROR(child_importer.Convert( - mlir_func_name, func_type, arg_nodes, ret_nodes, + mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, llvm::makeArrayRef(attributes.begin(), attributes.end()))); return Status::OK(); } -Status Importer::PrepareConvert(const Graph& graph) { - graph_versions_ = &graph.versions(); +Status ImporterBase::PrepareConvert(const Graph& graph) { TF_RETURN_IF_ERROR(RemoveBackedges(graph)); TF_RETURN_IF_ERROR(AddNodesToShapeRefiner()); return Status::OK(); } -Status Importer::ConvertFunctionArgAndRets( - mlir::Block* bb, llvm::ArrayRef arg_types, +Status ImporterBase::Convert( + llvm::StringRef func_name, mlir::FunctionType func_type, const absl::InlinedVector& arg_nodes, - const absl::InlinedVector& ret_nodes) { + const absl::InlinedVector& ret_nodes, + const absl::InlinedVector& control_ret_nodes, + llvm::ArrayRef attrs) { + // TODO(b/122040776): Uses debug info for FunctionDef. + auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_), + func_name, func_type, attrs); + + module_.push_back(function); + // Seeds the builder with an initial block. + function.addEntryBlock(); + builder_ = mlir::OpBuilder(function.getBody()); + auto* bb = &function.front(); + + // Create the graph operation in which we will convert the individual nodes. + auto graph = builder_.create( + function.getLoc(), func_type.getResults()); + builder_.createBlock(&graph.body()); + + for (const Node* node : ordered_nodes_) { + TF_RETURN_IF_ERROR(ConvertNode(*node)); + } + + // Adds the backedges back to the function by creating the source and sink + // pairs. + TF_RETURN_IF_ERROR(AddBackedges()); + + return ConvertFunctionArgAndRets(bb, graph, func_type.getInputs(), arg_nodes, + ret_nodes, control_ret_nodes); +} + +Status ImporterBase::ConvertFunctionArgAndRets( + mlir::Block* bb, mlir::tf_executor::GraphOp graph_op, + llvm::ArrayRef arg_types, + const absl::InlinedVector& arg_nodes, + const absl::InlinedVector& ret_nodes, + const absl::InlinedVector& control_ret_nodes) { for (int i = 0, e = arg_types.size(); i < e; ++i) { - auto* inst = node_values_[arg_nodes[i].node->id()]; - auto* bb_arg = bb->addArgument(arg_types[i]); + // The lookup can't fail here: otherwise some nodes in the function haven't + // be converted to mlir operations and don't have a mapping. + mlir::Operation* island = + node_values_.find(arg_nodes[i].node->id())->second; + // We are looking for the instruction inside the island + mlir::Block& body = island->getRegion(0).front(); + mlir::Operation* inst = &body.front(); + + auto* bb_arg = bb->getArgument(i); mlir::Value* arg_def = bb_arg; - // If this is an input node add argument to the operation operands by - // creating a new input operation. 
- if (StringPiece(arg_nodes[i].node->type_string()) != - FunctionLibraryDefinition::kArgOp) { - auto inst_name = inst->getName().getStringRef(); - mlir::OperationState state(inst->getLoc(), - inst_name.str().append(".input")); - state.attributes.append(inst->getAttrs().begin(), inst->getAttrs().end()); - - // If there are quantization specifications, add them as the attributes - auto name = inst->getAttrOfType("name").getValue(); - auto input_spec_it = specs_.inputs.find(name.str()); - if (input_spec_it != specs_.inputs.end()) { - auto input_spec = input_spec_it->second; - if (IsQuantizationType(input_spec.final_dtype)) { - // Uses the MLIR built-in type so it can be handled easily later. - auto final_type = mlir::IntegerType::get( - GetQuantizationTypeWidth(input_spec.final_dtype), context_); - state.attributes.push_back(builder_->getNamedAttr( - "min", builder_->getF32FloatAttr(input_spec.min_value))); - state.attributes.push_back(builder_->getNamedAttr( - "max", builder_->getF32FloatAttr(input_spec.max_value))); - state.attributes.push_back(builder_->getNamedAttr( - "type", builder_->getTypeAttr(final_type))); - inst->getParentOfType().setAttr( - "tf.quantize", builder_->getUnitAttr()); - } - } - - for (auto* r : inst->getResults()) state.types.push_back(r->getType()); - - state.operands.append(inst->getOperands().begin(), - inst->getOperands().end()); - state.operands.push_back(bb_arg); - builder_->setInsertionPoint(inst); - auto* input = builder_->createOperation(state); - arg_def = input->getResult(arg_nodes[i].index); - // Verify on the equivalent TF op would have failed, but catching this - // earlier for now as this exposed a bug. TODO(jpienaar): remove post - // dialect refactoring. - DCHECK(input->getResult(0)->getType() == input->getOperand(0)->getType()) - << "invalid placeholder_input constructed"; + // If this is an arg node, just forward the entry block argument + if (arg_nodes[i].node->IsArg()) { + island->getResult(0)->replaceAllUsesWith(arg_def); + island->dropAllReferences(); + island->erase(); + continue; } + // This is an input node, we'll create a new input operation by suffixing + // the existing one with .input. + auto inst_name = inst->getName().getStringRef(); + mlir::OperationState state(inst->getLoc(), + inst_name.str().append(".input")); + state.attributes.append(inst->getAttrs().begin(), inst->getAttrs().end()); + + // If there are quantization specifications, add them as the attributes + auto name = inst->getAttrOfType("name").getValue(); + auto input_spec_it = specs_.inputs.find(name.str()); + if (input_spec_it != specs_.inputs.end()) { + auto input_spec = input_spec_it->second; + if (IsQuantizationType(input_spec.final_dtype)) { + // Uses the MLIR built-in type so it can be handled easily later. 
+ auto final_type = mlir::IntegerType::get( + GetQuantizationTypeWidth(input_spec.final_dtype), context_); + state.attributes.push_back(builder_.getNamedAttr( + "min", builder_.getF32FloatAttr(input_spec.min_value))); + state.attributes.push_back(builder_.getNamedAttr( + "max", builder_.getF32FloatAttr(input_spec.max_value))); + state.attributes.push_back( + builder_.getNamedAttr("type", builder_.getTypeAttr(final_type))); + inst->getParentOfType().setAttr("tf.quantize", + builder_.getUnitAttr()); + } + } + + for (auto* r : inst->getResults()) state.types.push_back(r->getType()); + + state.operands.append(inst->getOperands().begin(), + inst->getOperands().end()); + state.operands.push_back(bb_arg); + builder_.setInsertionPoint(inst); + auto* input = builder_.createOperation(state); + arg_def = input->getResult(arg_nodes[i].index); + for (auto index = 0; index < inst->getNumResults(); index++) { inst->getResult(index)->replaceAllUsesWith(arg_def); } @@ -843,32 +976,47 @@ Status Importer::ConvertFunctionArgAndRets( inst->erase(); } - absl::InlinedVector inst_to_returned; + llvm::SmallVector inst_to_return; for (const auto& ret : ret_nodes) { auto* inst = node_values_[ret.node->id()]; auto op = absl::string_view(ret.node->type_string()); if (op == FunctionLibraryDefinition::kRetOp || op == FunctionLibraryDefinition::kDeviceRetOp) { + // Lookup the instruction inside the island + auto island_op = llvm::cast(inst); + mlir::Operation* inner_op = &island_op.GetBody().front(); // Remove kRetOp or kDeviceRetOp operation and return its operand. // kRetOp and kDeviceRetOp should have just one operand unless they have // control dependencies. - if (inst->getNumOperands() != 1) + if (inner_op->getNumOperands() != 1) return errors::Unimplemented("Return node with multiple inputs."); - inst_to_returned.push_back(inst->getOperand(0)); - node_values_[ret.node->id()]->dropAllReferences(); - node_values_[ret.node->id()]->erase(); + inst_to_return.push_back(inner_op->getOperand(0)); + inst->dropAllReferences(); + inst->erase(); } else { - inst_to_returned.push_back(inst->getResult(ret.index)); + inst_to_return.push_back(inst->getResult(ret.index)); } } - builder_->setInsertionPointToEnd(bb); - builder_->create( - mlir::UnknownLoc::get(context_), - llvm::makeArrayRef(inst_to_returned.begin(), inst_to_returned.end())); + + for (Node* control_ret : control_ret_nodes) { + auto* inst = node_values_[control_ret->id()]; + inst_to_return.push_back(*std::prev(inst->result_end())); + } + + // Terminate the function by adding a Fetch operation to terminate the graph + // and a return operation to return the Graph results. + builder_.setInsertionPointToEnd(&graph_op.body().front()); + builder_.create(graph_op.getLoc(), + inst_to_return); + inst_to_return.assign(graph_op.getResults().begin(), + graph_op.getResults().end()); + builder_.setInsertionPointToEnd(bb); + builder_.create(mlir::UnknownLoc::get(context_), + inst_to_return); return Status::OK(); } -mlir::Location Importer::GetLocation(const NodeDef& node_def) { +mlir::Location ImporterBase::GetLocation(const NodeDef& node_def) { const auto& debug_info = debug_info_.traces(); // Get the CallSiteLoc for a node name. @@ -900,14 +1048,14 @@ mlir::Location Importer::GetLocation(const NodeDef& node_def) { // Use the front FileLineColLoc to generate a NameLoc. 
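// NOTE (editorial sketch, not part of the original patch; file names are made
// up for illustration): for a node "add" whose debug info records the frames
// model.py:10 and train.py:42, the resulting location is roughly
//   loc(callsite("add"("model.py":10:0) at "train.py":42:0))
// i.e. a NameLoc over the innermost frame, wrapped in a CallSiteLoc carrying
// the remaining frames.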
mlir::Location node_name_loc = - mlir::NameLoc::get(name_id, locations.front(), context_); + mlir::NameLoc::get(name_id, locations.front()); // If there are more locations then generate a stack trace, otherwise just // return the name loc. auto callsite_locs = llvm::makeArrayRef(locations).drop_front(); return callsite_locs.empty() ? node_name_loc - : mlir::CallSiteLoc::get(node_name_loc, callsite_locs, context_); + : mlir::CallSiteLoc::get(node_name_loc, callsite_locs); }; // For NextIteration nodes, location is used to pair source and sink nodes. @@ -950,7 +1098,8 @@ mlir::Location Importer::GetLocation(const NodeDef& node_def) { } } -std::string Importer::GetLocationStr(const Node& node, bool includeNodeName) { +std::string ImporterBase::GetLocationStr(const Node& node, + bool includeNodeName) { const auto location = GetLocation(node.def()); std::string s; llvm::raw_string_ostream ss(s); @@ -963,7 +1112,80 @@ std::string Importer::GetLocationStr(const Node& node, bool includeNodeName) { return s; } -Status Importer::ConvertNode(const Node& node) { +mlir::Operation* ImporterBase::createOperation( + const Node& node, llvm::StringRef op_name, + const mlir::OperationState& result, + const llvm::SmallVectorImpl& control_operands) { + // For the tf.executor specific operations (not wrapped in an island), we + // have an extra returned value for the control result, and we concatenate + // control and non-control operands. + mlir::SmallVector types(result.types); + types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext())); + mlir::SmallVector operands(result.operands); + operands.append(control_operands.begin(), control_operands.end()); + + auto loc = result.location; + // Dispatch based on the name and create the appropriate operation. + if (node.IsSwitch()) { + // Switch and _SwitchN both are in switch class, differentiate based on + // number of outputs. + if (node.num_outputs() > 2) { + return builder_.create(loc, types, operands, + result.attributes); + } + return builder_.create(loc, types, operands, + result.attributes); + } + if (node.IsMerge()) { + return builder_.create(loc, types, operands, + result.attributes); + } + if (node.IsNextIteration()) { + // NextIteration is a bit special, we create a pair of operations that are + // linked together through a token returned by the source. + // We make use of a separate builder to insert the source at the top of + // the block. + mlir::OpBuilder builder_at_begin(builder_.getBlock(), + builder_.getBlock()->begin()); + auto source_op = + builder_at_begin.create( + loc, operands[0]->getType(), result.attributes); + return builder_.create( + loc, source_op.token(), operands, result.attributes); + } + if (node.IsLoopCond()) { + return builder_.create(loc, types, operands, + result.attributes); + } + if (node.IsEnter()) { + return builder_.create(loc, types, operands, + result.attributes); + } + if (node.IsExit()) { + return builder_.create(loc, types, operands, + result.attributes); + } + if (node.IsControlTrigger()) { + return builder_.create( + loc, operands, result.attributes); + } + // Regular TensorFlow operation are wrapped in a tf_executor.island. + auto island = builder_.create( + result.location, types, control_operands, + mlir::ArrayRef{}); + island.body().push_back(new mlir::Block); + mlir::OpBuilder island_builder(&island.GetBody()); + + // Create the operation inside the island now. 
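// NOTE (editorial sketch, not part of the original patch): a regular node such
// as tf.Add is expected to end up wrapped as
//
//   %out, %ctl = tf_executor.island(%incoming_ctl) {
//     %sum = "tf.Add"(%a, %b) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//     tf_executor.yield %sum : tensor<f32>
//   }
//
// with control operands feeding the island and a single control result added
// after the wrapped op's results.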
+ mlir::Operation* inner_op = island_builder.createOperation(result); + + // Add the terminator for the island + mlir::SmallVector ret_vals(inner_op->getResults()); + island_builder.create(result.location, ret_vals); + return island.getOperation(); +} + +Status ImporterBase::ConvertNode(const Node& node) { if (!node.IsOp()) { // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by // Graph and don't exist in GraphDef. @@ -979,9 +1201,12 @@ Status Importer::ConvertNode(const Node& node) { node_type_name = (*tf_name_to_mlir_name_)[node_type_name]; } - const char* kTfControlFlowFormPrefix = "_tf."; - std::string op_name = kTfControlFlowFormPrefix + node_type_name; + auto get_full_op_name = [&](const std::string& op_name) { + const char* kTfPrefix = "tf."; + return kTfPrefix + op_name; + }; + std::string op_name = get_full_op_name(node_type_name); if (back_edge_node_output_.contains(&node)) { op_name = op_name + ".sink"; } @@ -989,7 +1214,6 @@ Status Importer::ConvertNode(const Node& node) { const auto& node_def = node.def(); mlir::OperationState result(GetLocation(node_def), op_name); - ExtendedInferenceContext* context = shape_refiner_->GetExtendedContext(&node); for (int i = 0; i < node.num_outputs(); ++i) { // The backedge has been removed, so we shouldn't count the corresponding // output from the src node when converting to an operation. @@ -997,11 +1221,9 @@ Status Importer::ConvertNode(const Node& node) { back_edge_node_output_[&node] == i) { continue; } - TF_ASSIGN_OR_RETURN(auto type, InferOutputType(context, i, *builder_)); + TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, builder_)); result.types.push_back(type); } - result.types.push_back( - builder_->getType()); // Surprisingly input edges can be nondeterministically ordered. This // particularly seems to be the case for the control edges between _SOURCE @@ -1019,6 +1241,10 @@ Status Importer::ConvertNode(const Node& node) { }); result.operands.reserve(in_edges.size()); + + // Collect the control operands separately, they will be held by the island. + mlir::SmallVector control_operands; + for (const auto* input_edge : in_edges) { const Node& input_node = *input_edge->src(); if (input_node.IsSource()) { @@ -1046,9 +1272,10 @@ Status Importer::ConvertNode(const Node& node) { return errors::FailedPrecondition( "Graph not traversed in reverse post order; use seen before def!"); mlir::Operation* inst = node_values_[input_node.id()]; - result.operands.push_back(inst->getResult(input_edge->IsControlEdge() - ? 
inst->getNumResults() - 1 - : input_edge->src_output())); + if (input_edge->IsControlEdge()) + control_operands.push_back(inst->getResult(inst->getNumResults() - 1)); + else + result.operands.push_back(inst->getResult(input_edge->src_output())); } using FuncPairType = std::pair; @@ -1064,7 +1291,7 @@ Status Importer::ConvertNode(const Node& node) { funcs.emplace_back(&attr_name, &attr_value); } else { TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value)); - result.attributes.push_back(builder_->getNamedAttr(attr_name, attr)); + result.attributes.push_back(builder_.getNamedAttr(attr_name, attr)); } } @@ -1077,12 +1304,32 @@ Status Importer::ConvertNode(const Node& node) { &result.attributes)); } - result.attributes.push_back(builder_->getNamedAttr( - "name", builder_->getStringAttr(std::string(node.name())))); - result.attributes.push_back(builder_->getNamedAttr( - "device", builder_->getStringAttr(std::string(node_def.device())))); + result.attributes.push_back(builder_.getNamedAttr( + "name", builder_.getStringAttr(std::string(node.name())))); + result.attributes.push_back(builder_.getNamedAttr( + "device", builder_.getStringAttr(std::string(node_def.device())))); + + // Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add + // the differentiating attribute. + if (node.IsIfNode()) { + result.name = mlir::OperationName(get_full_op_name("If"), context_); + mlir::BoolAttr val = builder_.getBoolAttr(node_type_name == "StatelessIf"); + result.attributes.push_back(builder_.getNamedAttr("is_stateless", val)); + } + + // Map While and StatelessWhile op in TensorFlow to the common While op in + // MLIR and add the differentiating attribute. + if (node.IsWhileNode()) { + result.name = mlir::OperationName(get_full_op_name("While"), context_); + mlir::BoolAttr val = + builder_.getBoolAttr(node_type_name == "StatelessWhile"); + result.attributes.push_back(builder_.getNamedAttr("is_stateless", val)); + } + + // Register the mapping between the TF node and the newly created operation. + node_values_[node.id()] = + createOperation(node, op_name, result, control_operands); - node_values_[node.id()] = builder_->createOperation(result); return Status::OK(); } @@ -1098,7 +1345,7 @@ Status Importer::ConvertNode(const Node& node) { // operation. // TODO(fengliuai): Preserve the order of the results and operands if // necessary. -Status Importer::AddBackedges() { +Status ImporterBase::AddBackedges() { for (auto it : back_edge_dst_inputs_) { BackEdge& edge = it.second; if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) { @@ -1112,9 +1359,10 @@ Status Importer::AddBackedges() { return Status::OK(); } -Status Importer::AddBackedge(mlir::Operation* sink, mlir::Operation* dst, - int dst_input) { - mlir::Operation* source = GetOrCreateNextIterationSource(sink, dst); +Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst, + int dst_input) { + // Get the NextIteration.Source operation from the token operand of the sink. + mlir::Operation* source = sink->getOperand(0)->getDefiningOp(); // Adds the "source" to the operands of the dst by creating a new dst // operation. 
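// NOTE (editorial sketch, not part of the original patch): conceptually the
// rebuilt Merge gains the Source's output as an extra operand, closing the
// loop through the token:
//
//   %ni.out, %ni.token, %ni.ctl = tf_executor.NextIteration.Source
//   %m.out, %m.idx, %m.ctl = tf_executor.Merge(%enter.out, %ni.out)
//   ...
//   tf_executor.NextIteration.Sink[%ni.token] %loop_body_value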
@@ -1130,12 +1378,11 @@ Status Importer::AddBackedge(mlir::Operation* sink, mlir::Operation* dst, state.operands.push_back(dst->getOperand(input - 1)); } } - state.attributes.append(dst->getAttrs().begin(), dst->getAttrs().end()); - for (auto* result : dst->getResults()) { - state.types.push_back(result->getType()); - } - builder_->setInsertionPoint(dst); - auto* new_dst = builder_->createOperation(state); + state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end()); + state.types.assign(dst->getResultTypes().begin(), + dst->getResultTypes().end()); + builder_.setInsertionPoint(dst); + auto* new_dst = builder_.createOperation(state); // Replaces the output uses of the old operation by the corresponding // result of the new operation, and deletes the old operation. @@ -1148,134 +1395,7 @@ Status Importer::AddBackedge(mlir::Operation* sink, mlir::Operation* dst, return Status::OK(); } -mlir::Operation* Importer::GetOrCreateNextIterationSource( - mlir::Operation* sink, mlir::Operation* dst) { - auto iter = next_iteration_sink_source_.find(sink); - if (iter != next_iteration_sink_source_.end()) return iter->second; - - auto inst_name = sink->getName().getStringRef(); - inst_name.consume_back(".sink"); - mlir::OperationState src_state(sink->getLoc(), - inst_name.str().append(".source")); - src_state.attributes.append(sink->getAttrs().begin(), sink->getAttrs().end()); - src_state.types.push_back(dst->getResult(0)->getType()); - src_state.types.push_back( - builder_->getType()); - builder_->setInsertionPoint(dst->getBlock(), dst->getBlock()->begin()); - mlir::Operation* source = builder_->createOperation(src_state); - next_iteration_sink_source_[sink] = source; - return source; -} - -Status Importer::Convert(llvm::StringRef func_name, - mlir::FunctionType func_type, - const absl::InlinedVector& arg_nodes, - const absl::InlinedVector& ret_nodes, - llvm::ArrayRef attrs) { - // TODO(b/122040776): Uses debug info for FunctionDef. - auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_), - func_name, func_type, attrs); - - module_.push_back(function); - builder_ = absl::make_unique(function.getBody()); - // Seeds the builder with an initial block. - auto* bb = builder_->createBlock(&function.getBody()); - - for (const Node* node : ordered_nodes_) { - TF_RETURN_IF_ERROR(ConvertNode(*node)); - } - - // Adds the backedges back to the function by creating the source and sink - // pairs. - TF_RETURN_IF_ERROR(AddBackedges()); - - return ConvertFunctionArgAndRets(bb, func_type.getInputs(), arg_nodes, - ret_nodes); -} - -StatusOr Importer::InferMainFunctionType( - absl::InlinedVector* arg_nodes, - absl::InlinedVector* ret_nodes) { - // Finds out all the input nodes and output nodes. - if (!specs_.inputs.empty() || !specs_.output_arrays.empty()) { - arg_nodes->resize(specs_.inputs.size()); - ret_nodes->resize(specs_.output_arrays_order.size()); - - for (Node* n : ordered_nodes_) { - // Handle inputs/arguments. - auto input_it = specs_.inputs.find(n->name()); - if (input_it != specs_.inputs.end()) { - (*arg_nodes)[std::distance(specs_.inputs.begin(), input_it)] = {n, 0}; - } - - // Handle outputs/returns. 
- if (specs_.output_arrays.find(n->name()) != specs_.output_arrays.end()) { - for (int i = 0, e = specs_.output_arrays_order.size(); i != e; ++i) { - std::pair name_and_port = - absl::StrSplit(specs_.output_arrays_order[i], ':'); - auto name = name_and_port.first; - if (name != n->name()) continue; - int port = 0; - if (!name_and_port.second.empty() && - !absl::SimpleAtoi(name_and_port.second, &port)) { - return errors::InvalidArgument("Invalid port specification: ", - specs_.output_arrays_order[i]); - } - (*ret_nodes)[i] = {n, port}; - } - } - } - } - - int i = 0; - for (auto it : specs_.inputs) { - if (arg_nodes->at(i++).node == nullptr) { - return errors::InvalidArgument("Input ", it.first, - " was not found in graph"); - } - } - for (int i = 0, e = specs_.output_arrays_order.size(); i != e; ++i) { - if (ret_nodes->at(i).node == nullptr) { - return errors::InvalidArgument("Output ", specs_.output_arrays_order[i], - " was not found in graph"); - } - } - - // Starts to construct the function type. - llvm::SmallVector arg_types; - llvm::SmallVector ret_types; - arg_types.reserve(specs_.inputs.size()); - ret_types.reserve(specs_.output_arrays.size()); - mlir::Builder builder(context_); - - // Input nodes as function arguments. - for (const auto& input : specs_.inputs) { - mlir::Type element_type; - const auto& node_info = input.second; - TF_RETURN_IF_ERROR(::tensorflow::ConvertDataType(node_info.imported_dtype, - builder, &element_type)); - llvm::SmallVector shape; - TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); - arg_types.push_back(builder.getTensorType(shape, element_type)); - } - - // Output nodes as function returns. - for (const auto& ret : *ret_nodes) { - if (ret.node->num_outputs() < 1) { - return errors::FailedPrecondition( - "Invalid output node; should have at least 1 output: " + - ret.node->name()); - } - auto* shape_context = shape_refiner_->GetExtendedContext(ret.node); - TF_ASSIGN_OR_RETURN(auto type, - InferOutputType(shape_context, ret.index, builder)); - ret_types.push_back(type); - } - - return builder.getFunctionType(arg_types, ret_types); -} - -StatusOr Importer::InferLibFunctionType( +StatusOr ImporterBase::InferLibFunctionType( const FunctionBody& fbody) { mlir::Builder builder(context_); @@ -1297,76 +1417,273 @@ StatusOr Importer::InferLibFunctionType( // Find node in the graph using the node id instead of using `ret` directly // because the graph has been cloned. auto* node = graph_->FindNodeId(ret->id()); - auto* shape_context = shape_refiner_->GetExtendedContext(node); // Return type of the function is type of the only input of the respective // return node in the function. - TF_ASSIGN_OR_RETURN(auto type, - InferInputType(shape_context, /*idx=*/0, builder)); + TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder)); ret_types.push_back(type); } return builder.getFunctionType(arg_types, ret_types); } -StatusOr Importer::Convert( +// Stateful helper class to import a TensorFlow model expressed in GraphDef into +// an MLIR Module. +// +// The nodes defined in the graph is converted to a function called "main". All +// the library function definitions are converted to MLIR functions in the +// module. +class GraphDefImporter : public ImporterBase { + public: + // Main entry point: converts the given graph to an MLIR Module. 
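The main-function type inference removed here (and reintroduced below on GraphDefImporter) resolves user-specified outputs of the form "node:port" with absl::StrSplit and absl::SimpleAtoi, defaulting to port 0 when no port suffix is given. A standard-library sketch of that parsing:

#include <cstdlib>
#include <iostream>
#include <string>

// Splits an output spec of the form "node_name:port" into its parts; a missing
// ":port" suffix defaults to port 0. Standard-library stand-in for the
// StrSplit/SimpleAtoi logic in the hunk above.
bool ParseOutputSpec(const std::string& spec, std::string* name, int* port) {
  const auto colon = spec.find(':');
  if (colon == std::string::npos) {
    *name = spec;
    *port = 0;
    return true;
  }
  *name = spec.substr(0, colon);
  const std::string port_str = spec.substr(colon + 1);
  if (port_str.empty()) return false;  // invalid port specification
  char* end = nullptr;
  const long parsed = std::strtol(port_str.c_str(), &end, 10);
  if (end == nullptr || *end != '\0' || parsed < 0) return false;
  *port = static_cast<int>(parsed);
  return true;
}

int main() {
  std::string name;
  int port = 0;
  if (ParseOutputSpec("softmax:1", &name, &port))
    std::cout << name << " -> output port " << port << "\n";
}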
+ static StatusOr Convert( + mlir::MLIRContext* context, const Graph& graph, + const GraphDebugInfo& debug_info, + const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs); + + private: + explicit GraphDefImporter( + const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, + const NodeSpecs& specs, mlir::ModuleOp module, + std::unordered_map* tf_name_to_mlir_name) + : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name) {} + + // Returns the function signature of the main function of converted MLIR + // module, the input nodes and output nodes. The type and shape information + // for the function arguments are read from `specs`, but the type and shape + // information for the function returns are inferred by the shape refiner in + // ImporterBase. + StatusOr InferMainFunctionType( + const NodeSpecs& specs, mlir::MLIRContext* context, + absl::InlinedVector* arg_nodes, + absl::InlinedVector* ret_nodes); +}; + +StatusOr GraphDefImporter::Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs) { mlir::OwningModuleRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); std::unordered_map tf_name_to_mlir_name; - Importer importer(flib_def, debug_info, specs, module.get(), - &tf_name_to_mlir_name); - TF_RETURN_IF_ERROR(importer.PrepareConvert(graph)); - // Collects the argument and return nodes by looking up the node names - // specified by the user. + GraphDefImporter importer(flib_def, debug_info, specs, module.get(), + &tf_name_to_mlir_name); + + mlir::FunctionType func_type; absl::InlinedVector arg_nodes; absl::InlinedVector ret_nodes; - TF_ASSIGN_OR_RETURN(auto func_type, - importer.InferMainFunctionType(&arg_nodes, &ret_nodes)); - - // TODO(prakalps): Refactor to keep attribute strings (tf.entry_function, - // tf.versions) shared by importer and exporter in a centralized place. - // Record the input and output mapping. + absl::InlinedVector control_ret_nodes; llvm::SmallVector attrs; - if (!specs.inputs.empty() || !specs.output_arrays.empty()) { - mlir::Builder b(context); - std::string s; - llvm::raw_string_ostream ss(s); - mlir::interleaveComma( - specs.inputs, ss, - [&](const std::pair& v) { ss << v.first; }); - auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); - s.clear(); - mlir::interleaveComma(specs.output_arrays, ss, - [&](const std::string& v) { ss << v; }); - auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); + std::unique_ptr graph_fbody; + if (specs.graph_as_function) { + if (specs.prune_unused_nodes || !specs.inputs.empty() || + !specs.output_arrays.empty() || !specs.output_arrays_order.empty()) + return errors::InvalidArgument( + "Pruning of graph is currently unsupported when the main graph is " + "converted to a function."); + // Converts graph into a FunctionDef. + FunctionDef graph_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(graph, "main", &graph_fdef)); - attrs.push_back(b.getNamedAttr("tf.entry_function", - b.getDictionaryAttr({inputs, outputs}))); + // Converts FunctionDef into a FunctionBody. 
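When the graph is imported as a function (graph_as_function), the code above records the argument and return node names as comma-separated strings inside a tf.entry_function dictionary attribute via mlir::interleaveComma. A tiny standalone equivalent of that string building:

#include <iostream>
#include <string>
#include <vector>

// Joins node names with commas, the same shape of string the importer stores
// in the "inputs"/"outputs" entries of the tf.entry_function attribute.
std::string JoinWithCommas(const std::vector<std::string>& names) {
  std::string out;
  for (size_t i = 0; i < names.size(); ++i) {
    if (i) out += ",";
    out += names[i];
  }
  return out;
}

int main() {
  std::cout << "inputs = \"" << JoinWithCommas({"image", "is_training"}) << "\"\n";
  std::cout << "outputs = \"" << JoinWithCommas({"logits"}) << "\"\n";
}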
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(graph_fdef, AttrSlice(), + &flib_def, &graph_fbody)); + + TF_RETURN_IF_ERROR(importer.PrepareConvert(*graph_fbody->graph)); + TF_ASSIGN_OR_RETURN(func_type, importer.InferLibFunctionType(*graph_fbody)); + importer.GetArgsAndRetsFromFunctionBody(*graph_fbody, &arg_nodes, + &ret_nodes, &control_ret_nodes); + + if (!arg_nodes.empty() || !ret_nodes.empty()) { + mlir::Builder b(context); + std::string s; + llvm::raw_string_ostream ss(s); + auto node_name = [&](const Node* node) { ss << node->name(); }; + mlir::interleaveComma(graph_fbody->arg_nodes, ss, node_name); + auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); + s.clear(); + mlir::interleaveComma(graph_fbody->ret_nodes, ss, node_name); + auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); + + attrs.push_back(b.getNamedAttr("tf.entry_function", + b.getDictionaryAttr({inputs, outputs}))); + } + } else { + TF_RETURN_IF_ERROR(importer.PrepareConvert(graph)); + + // Collects the argument and return nodes by looking up the node names + // specified by the user. + TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType( + specs, context, &arg_nodes, &ret_nodes)); + + // TODO(prakalps): Refactor to keep attribute strings (tf.entry_function, + // tf.versions) shared by importer and exporter in a centralized place. + // Record the input and output mapping. + if (!specs.inputs.empty() || !specs.output_arrays.empty()) { + mlir::Builder b(context); + std::string s; + llvm::raw_string_ostream ss(s); + mlir::interleaveComma( + specs.inputs, ss, + [&](const std::pair& v) { ss << v.first; }); + auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); + s.clear(); + mlir::interleaveComma(specs.output_arrays, ss); + auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); + + attrs.push_back(b.getNamedAttr("tf.entry_function", + b.getDictionaryAttr({inputs, outputs}))); + } } // Record version info. - if (importer.graph_versions_) { - mlir::Builder b(context); - auto producer = b.getNamedAttr( - "producer", b.getI32IntegerAttr(importer.graph_versions_->producer())); - auto min_consumer = b.getNamedAttr( - "min_consumer", - b.getI32IntegerAttr(importer.graph_versions_->min_consumer())); - auto bad_consumers = b.getNamedAttr( - "bad_consumers", b.getI32ArrayAttr(llvm::ArrayRef( - importer.graph_versions_->bad_consumers().begin(), - importer.graph_versions_->bad_consumers().end()))); - module->setAttr("tf.versions", - b.getDictionaryAttr(llvm::ArrayRef( - {producer, min_consumer, bad_consumers}))); + const auto& graph_versions = graph.versions(); + mlir::Builder b(context); + auto producer = b.getNamedAttr( + "producer", b.getI32IntegerAttr(graph_versions.producer())); + auto min_consumer = b.getNamedAttr( + "min_consumer", b.getI32IntegerAttr(graph_versions.min_consumer())); + auto bad_consumers = b.getNamedAttr( + "bad_consumers", b.getI32ArrayAttr(llvm::ArrayRef( + graph_versions.bad_consumers().begin(), + graph_versions.bad_consumers().end()))); + module->setAttr("tf.versions", + b.getDictionaryAttr(llvm::ArrayRef( + {producer, min_consumer, bad_consumers}))); + + TF_RETURN_IF_ERROR(importer.ImporterBase::Convert( + "main", func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs)); + return module; +} + +StatusOr GraphDefImporter::InferMainFunctionType( + const NodeSpecs& specs, mlir::MLIRContext* context, + absl::InlinedVector* arg_nodes, + absl::InlinedVector* ret_nodes) { + // Finds out all the input nodes and output nodes. 
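The importer also stamps the module with a tf.versions dictionary attribute built from the graph's version information (producer, min_consumer, bad_consumers). The sketch below only formats those same fields as text; the GraphVersions struct is a stand-in, not the TensorFlow proto, and the exact MLIR attribute syntax is approximate.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Rough stand-in for the version fields recorded above; the real code reads
// them from graph.versions() and attaches them to the module as "tf.versions".
struct GraphVersions {
  int32_t producer;
  int32_t min_consumer;
  std::vector<int32_t> bad_consumers;
};

std::string FormatVersionsAttr(const GraphVersions& v) {
  std::ostringstream os;
  os << "tf.versions = {producer = " << v.producer
     << " : i32, min_consumer = " << v.min_consumer << " : i32, bad_consumers = [";
  for (size_t i = 0; i < v.bad_consumers.size(); ++i) {
    if (i) os << ", ";
    os << v.bad_consumers[i] << " : i32";
  }
  os << "]}";
  return os.str();
}

int main() {
  GraphVersions v{27, 0, {12, 13}};
  std::cout << FormatVersionsAttr(v) << "\n";
}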
+ if (!specs.inputs.empty() || !specs.output_arrays.empty()) { + arg_nodes->resize(specs.inputs.size()); + ret_nodes->resize(specs.output_arrays_order.size()); + + for (Node* n : GetOrderedNodes()) { + // Handle inputs/arguments. + auto input_it = specs.inputs.find(n->name()); + if (input_it != specs.inputs.end()) { + (*arg_nodes)[std::distance(specs.inputs.begin(), input_it)] = {n, 0}; + } + + // Handle outputs/returns. + if (specs.output_arrays.find(n->name()) != specs.output_arrays.end()) { + for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) { + std::pair name_and_port = + absl::StrSplit(specs.output_arrays_order[i], ':'); + auto name = name_and_port.first; + if (name != n->name()) continue; + int port = 0; + if (!name_and_port.second.empty() && + !absl::SimpleAtoi(name_and_port.second, &port)) { + return errors::InvalidArgument("Invalid port specification: ", + specs.output_arrays_order[i]); + } + (*ret_nodes)[i] = {n, port}; + } + } + } + } + + int i = 0; + for (auto it : specs.inputs) { + if (arg_nodes->at(i++).node == nullptr) { + return errors::InvalidArgument("Input ", it.first, + " was not found in graph"); + } + } + for (int i = 0, e = specs.output_arrays_order.size(); i != e; ++i) { + if (ret_nodes->at(i).node == nullptr) { + return errors::InvalidArgument("Output ", specs.output_arrays_order[i], + " was not found in graph"); + } + } + + // Starts to construct the function type. + llvm::SmallVector arg_types; + llvm::SmallVector ret_types; + arg_types.reserve(specs.inputs.size()); + ret_types.reserve(specs.output_arrays.size()); + mlir::Builder builder(context); + + // Input nodes as function arguments. + for (const auto& input : specs.inputs) { + mlir::Type element_type; + const auto& node_info = input.second; + TF_RETURN_IF_ERROR(::tensorflow::ConvertDataType(node_info.imported_dtype, + builder, &element_type)); + llvm::SmallVector shape; + TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); + arg_types.push_back(builder.getTensorType(shape, element_type)); + } + + // Output nodes as function returns. + for (const auto& ret : *ret_nodes) { + if (ret.node->num_outputs() <= ret.index) { + return errors::InvalidArgument("Invalid output index ", ret.index, + " specified for node: ", ret.node->name()); + } + TF_ASSIGN_OR_RETURN(auto type, + InferOutputType(*ret.node, ret.index, builder)); + ret_types.push_back(type); + } + + return builder.getFunctionType(arg_types, ret_types); +} + +// Stateful helper class to import a TensorFlow model expressed in SavedModel +// into an MLIR Module. +class SavedModelImporter : public ImporterBase { + public: + // Main entry point: converts all functions in the given meta graph to an MLIR + // Module. 
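As relocated above, InferMainFunctionType builds the main function signature from the user specs: each input contributes a tensor type derived from its dtype and shape (unknown dimensions print as ? in MLIR), and each requested output index is checked against the node's actual output count. A self-contained sketch of those two steps, using strings in place of MLIR types:

#include <iostream>
#include <string>
#include <vector>

// Each input spec carries a dtype name and a shape, with -1 marking a dynamic
// dimension, matching the convention used by the converter above.
struct InputSpec {
  std::string dtype;
  std::vector<long long> shape;
};

std::string TensorTypeString(const InputSpec& spec) {
  std::string out = "tensor<";
  for (long long d : spec.shape)
    out += (d < 0 ? std::string("?") : std::to_string(d)) + "x";
  return out + spec.dtype + ">";
}

// Mirrors the "Invalid output index" check: an output port must exist on the
// node before a result type can be inferred for it.
bool ValidateOutputIndex(int requested_index, int num_outputs, std::string* error) {
  if (requested_index >= num_outputs) {
    *error = "Invalid output index " + std::to_string(requested_index);
    return false;
  }
  return true;
}

int main() {
  std::cout << TensorTypeString({"f32", {-1, 224, 224, 3}}) << "\n";
  std::string err;
  if (!ValidateOutputIndex(2, 1, &err)) std::cout << err << "\n";
}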
+ static StatusOr Convert( + const MetaGraphDef& meta_graph, const GraphDebugInfo& debug_info, + bool add_default_attributes, mlir::MLIRContext* context); + + private: + explicit SavedModelImporter( + const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, + const NodeSpecs& specs, mlir::ModuleOp module, + std::unordered_map* tf_name_to_mlir_name) + : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name) {} +}; + +StatusOr SavedModelImporter::Convert( + const MetaGraphDef& meta_graph, const GraphDebugInfo& debug_info, + bool add_default_attributes, mlir::MLIRContext* context) { + NodeSpecs specs; + mlir::OwningModuleRef module = + mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); + std::unordered_map tf_name_to_mlir_name; + + const auto& graphdef = meta_graph.graph_def(); + GraphConstructorOptions options; + options.allow_internal_ops = true; + Graph graph(OpRegistry::Global()); + + GraphDef preprocessed_graphdef(graphdef); + if (add_default_attributes) { + TF_RETURN_IF_ERROR(PreprocessGraphDef(nullptr, &preprocessed_graphdef)); } TF_RETURN_IF_ERROR( - importer.Convert("main", func_type, arg_nodes, ret_nodes, attrs)); + ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph)); + + SavedModelImporter importer(graph.flib_def(), debug_info, specs, module.get(), + &tf_name_to_mlir_name); + + auto fn_names = graph.flib_def().ListFunctionNames(); + for (const auto& fn_name : fn_names) { + TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name)); + } return module; } } // namespace @@ -1381,10 +1698,10 @@ StatusOr ConvertGraphdefToMlir( GraphDef preprocessed_graphdef(graphdef); if (add_default_attributes) { - TF_RETURN_IF_ERROR(AddDefaultsToNodeDef(&preprocessed_graphdef)); + TF_RETURN_IF_ERROR(PreprocessGraphDef(&specs, &preprocessed_graphdef)); } - TF_RETURN_IF_ERROR( - ConvertGraphDefToGraph(options, preprocessed_graphdef, &graph)); + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph( + options, std::move(preprocessed_graphdef), &graph)); return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs, context); @@ -1394,7 +1711,14 @@ StatusOr ConvertGraphToMlir( const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs, mlir::MLIRContext* context) { - return Importer::Convert(context, graph, debug_info, flib_def, specs); + return GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs); +} + +StatusOr ConvertSavedModelToMlir( + const SavedModelBundle& saved_model, const GraphDebugInfo& debug_info, + mlir::MLIRContext* context, bool add_default_attributes) { + return SavedModelImporter::Convert(saved_model.meta_graph_def, debug_info, + add_default_attributes, context); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h similarity index 69% rename from tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h rename to tensorflow/compiler/mlir/tensorflow/translate/import_model.h index c494526bb4d..98bb607fa6a 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
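SavedModelImporter::Convert above walks the function library's function names and converts each one; ImporterBase keeps a tf_name_to_mlir_name map so repeated references reuse the already-imported function. A toy restatement of that de-duplication loop (the name uniquification a real importer performs is omitted here):

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Converts a TF function name at most once and records the chosen MLIR symbol
// name so later references reuse it, mirroring the tf_name_to_mlir_name map.
std::string ConvertLibFunctionOnce(
    const std::string& tf_name,
    std::unordered_map<std::string, std::string>* tf_name_to_mlir_name) {
  auto it = tf_name_to_mlir_name->find(tf_name);
  if (it != tf_name_to_mlir_name->end()) return it->second;  // already imported
  std::string mlir_name = tf_name;  // a real importer would also uniquify here
  (*tf_name_to_mlir_name)[tf_name] = mlir_name;
  std::cout << "imported func @" << mlir_name << "\n";
  return mlir_name;
}

int main() {
  std::unordered_map<std::string, std::string> tf_name_to_mlir_name;
  const std::vector<std::string> fns = {"cond_true", "cond_false", "cond_true"};
  for (const auto& fn : fns) ConvertLibFunctionOnce(fn, &tf_name_to_mlir_name);
}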
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_GRAPHDEF_H_ -#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_GRAPHDEF_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir +#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -27,20 +28,26 @@ limitations under the License. namespace tensorflow { -// Given a GraphDef, returns a MLIR module containing the graph in control-flow -// form. +// Given a GraphDef, returns a MLIR module containing the graph, expressed with +// tf_executor dialect. stream_executor::port::StatusOr ConvertGraphdefToMlir( const GraphDef& graphdef, const GraphDebugInfo& debug_info, const NodeSpecs& specs, mlir::MLIRContext* context, bool add_default_attributes = true); -// Given a Graph, returns a MLIR module containing the graph in control-flow -// form. +// Given a Graph, returns a MLIR module containing the graph, expressed with +// tf_executor dialect. stream_executor::port::StatusOr ConvertGraphToMlir( const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const NodeSpecs& specs, mlir::MLIRContext* context); +// Given a SavedModel, returns a MLIR module containing the functions, expressed +// with tf_executor dialect. +stream_executor::port::StatusOr ConvertSavedModelToMlir( + const SavedModelBundle& saved_model, const GraphDebugInfo& debug_info, + mlir::MLIRContext* context, bool add_default_attributes = true); + } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_GRAPHDEF_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h index 3fc7ee55b4f..6adf1f07339 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -56,6 +56,13 @@ struct NodeSpecs { // setting prune_unused_nodes to true, would prune unreachable nodes if // output_arrays is specified. bool prune_unused_nodes = false; + // If true, inputs of type LegacyFedInput are replaced with Placeholder ops. + // LegacyFedInput ops have two outputs unlike Placeholder which has only one + // output, so if both outputs of the LegacyFedInput ops are used then returns + // an error. + bool convert_legacy_fed_inputs = false; + // If true, the main graph will be treated as a function. + bool graph_as_function = false; }; struct ExporterConfigs { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc index 3d71910edcd..3ebd722c580 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.cc @@ -15,11 +15,12 @@ limitations under the License. 
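The renamed import_model.h now declares three entry points. The sketch below shows one plausible way to drive the GraphDef path with the new NodeSpecs flags; GraphDefToModule is a hypothetical helper name, and the sketch assumes the caller already holds a GraphDef and links against this translate library (the signatures are the ones declared above).

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"

// Converts `graphdef` to an MLIR module, importing the whole graph as a
// "main" function. Returns a null module on failure.
mlir::OwningModuleRef GraphDefToModule(const tensorflow::GraphDef& graphdef,
                                       mlir::MLIRContext* context) {
  tensorflow::NodeSpecs specs;
  specs.graph_as_function = true;           // import the graph as a function
  specs.convert_legacy_fed_inputs = false;  // keep LegacyFedInput nodes as-is
  tensorflow::GraphDebugInfo debug_info;    // no separate debug info file
  auto module_or =
      tensorflow::ConvertGraphdefToMlir(graphdef, debug_info, specs, context);
  if (!module_or.status().ok()) return nullptr;
  return module_or.ConsumeValueOrDie();
}

Note that graph_as_function requires leaving prune_unused_nodes and the input/output arrays at their defaults; the converter above rejects that combination as unsupported.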
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "mlir/Analysis/Verifier.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" @@ -35,9 +36,15 @@ Status MlirRoundtripPass::Run(const GraphOptimizationPassOptions& options) { TF_ASSIGN_OR_RETURN(auto module, ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def, specs, &context)); - // TODO(jpienaar): Remove, just simple verification that this works. - module->dump(); - return ConvertMlirToGraph(*module, confs, options.graph, options.flib_def); + if (failed(mlir::verify(*module))) { + // TODO(jpienaar): Remove, just simple verification that this works. + module->dump(); + return errors::Internal("Verifier failed on MLIR import for the graph"); + } + auto status = + ConvertMlirToGraph(*module, confs, options.graph, options.flib_def); + if (!status.ok()) module->dump(); + return status; } REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0, diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h index 96a66d4eab3..41417edcecf 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_pass.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_PASS_H_ -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 5c7b1e824fe..604fced24d4 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -26,7 +26,7 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/Parser.h" // TF:local_config_mlir -#include "tensorflow/compiler/mlir/tensorflow/translate/import_graphdef.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" @@ -46,6 +46,7 @@ static StatusOr GraphdefToMlirImport( absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view inference_type, absl::string_view min_values, absl::string_view max_values, bool prune_unused_nodes, + bool convert_legacy_fed_inputs, bool graph_as_function, mlir::MLIRContext* context) { GraphDef graphdef; TF_RETURN_IF_ERROR(tensorflow::LoadProtoFromFile(input_filename, &graphdef)); @@ -57,6 +58,8 @@ static StatusOr GraphdefToMlirImport( NodeSpecs specs; specs.prune_unused_nodes = prune_unused_nodes; + specs.convert_legacy_fed_inputs = convert_legacy_fed_inputs; + specs.graph_as_function = graph_as_function; TF_RETURN_IF_ERROR(ParseInputArrayInfo( input_arrays, input_dtypes, input_shapes, inference_type, min_values, max_values, &specs.inputs)); @@ -71,15 +74,51 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction( absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view inference_type, absl::string_view min_values, absl::string_view max_values, bool prune_unused_nodes, + bool convert_legacy_fed_inputs, bool graph_as_function, mlir::MLIRContext* context) { auto module_or = GraphdefToMlirImport( input_filename, debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, inference_type, min_values, max_values, prune_unused_nodes, - context); + convert_legacy_fed_inputs, graph_as_function, context); if (!module_or.status().ok()) { LOG(ERROR) << "Graph import failed: " << module_or.status(); return nullptr; } + + return module_or.ConsumeValueOrDie(); +} + +mlir::OwningModuleRef SavedModelToMlirImport( + absl::string_view saved_model_dir, + const std::unordered_set& tags, + absl::string_view debug_info_file, mlir::MLIRContext* context) { + SessionOptions session_options; + RunOptions run_options; + tensorflow::SavedModelBundle bundle; + auto load_status = LoadSavedModel( + session_options, run_options, + std::string(saved_model_dir.data(), saved_model_dir.length()), tags, + &bundle); + if (!load_status.ok()) { + LOG(ERROR) << "Failed to load saved model '" << saved_model_dir + << "': " << load_status; + return nullptr; + } + + GraphDebugInfo debug_info; + if (!debug_info_file.empty()) { + if (!LoadProtoFromFile(debug_info_file, &debug_info).ok()) { + LOG(ERROR) << "Failed to load debug info file: " << debug_info_file; + return nullptr; + } + } + + auto module_or = ConvertSavedModelToMlir(bundle, debug_info, context); + + if (!module_or.status().ok()) { + LOG(ERROR) << "SavedModel import failed: " << module_or.status(); + return nullptr; + } return module_or.ConsumeValueOrDie(); } @@ -89,11 +128,12 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction( absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view inference_type, absl::string_view min_values, absl::string_view max_values, bool prune_unused_nodes, + bool convert_legacy_fed_inputs, bool graph_as_function, mlir::MLIRContext* context) { auto module_or = GraphdefToMlirImport( input_filename, 
debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, inference_type, min_values, max_values, prune_unused_nodes, - context); + convert_legacy_fed_inputs, graph_as_function, context); if (!module_or.status().ok()) { LOG(ERROR) << "Graph import failed: " << module_or.status(); return nullptr; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index 794a2ef9fcb..290223017b4 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_ +#include +#include + #include "absl/strings/string_view.h" #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir @@ -33,6 +36,7 @@ mlir::OwningModuleRef GraphdefToMlirTranslateFunction( absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view inference_type, absl::string_view min_values, absl::string_view max_values, bool prune_unused_nodes, + bool convert_legacy_fed_inputs, bool graph_as_function, mlir::MLIRContext* context); // Similar as the above function, but replaces all constant tensors @@ -43,7 +47,16 @@ mlir::OwningModuleRef GraphdefToSplattedMlirTranslateFunction( absl::string_view input_shapes, absl::string_view output_arrays, absl::string_view inference_type, absl::string_view min_values, absl::string_view max_values, bool prune_unused_nodes, + bool convert_legacy_fed_inputs, bool graph_as_function, mlir::MLIRContext* context); + +// Converts a TensorFlow SavedModel stored in the directory with the given +// `saved_model_dir` into a MLIR module. Creates MLIR entities into the +// given MLIR `context`. 
+mlir::OwningModuleRef SavedModelToMlirImport( + absl::string_view saved_model_dir, + const std::unordered_set& tags, + absl::string_view debug_info_file, mlir::MLIRContext* context); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc index 8e74296b4fc..80df3665007 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.cc @@ -84,3 +84,21 @@ opt prune_unused_nodes( "tf-prune-unused-nodes", llvm::cl::desc("Prune unused nodes in the input graphdef "), llvm::cl::init(false)); + +// NOLINTNEXTLINE +opt convert_legacy_fed_inputs( + "tf-convert-legacy-fed-inputs", + llvm::cl::desc( + "Eliminate LegacyFedInput nodes by replacing them with Placeholder "), + llvm::cl::init(false)); + +opt graph_as_function("tf-graph-as-function", + llvm::cl::desc("Treat main graph as a function "), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +opt saved_model_tags( + "tf-savedmodel-tags", + llvm::cl::desc("Tags used to indicate which MeataGraphDef to import, " + "separated by ','"), + llvm::cl::init("serve")); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h index 8cf17e3a3f0..c5d609acb95 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h @@ -35,5 +35,9 @@ extern llvm::cl::opt min_values; extern llvm::cl::opt max_values; extern llvm::cl::opt debug_info_file; extern llvm::cl::opt prune_unused_nodes; +extern llvm::cl::opt convert_legacy_fed_inputs; +extern llvm::cl::opt graph_as_function; + +extern llvm::cl::opt saved_model_tags; #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_TF_MLIR_TRANSLATE_CL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index 7d7632d7e82..90e305f64aa 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -45,18 +45,30 @@ static OwningModuleRef GraphdefToMlirTranslateFunction( return tensorflow::GraphdefToMlirTranslateFunction( StringRefToView(input_filename), debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, inference_type, min_values, - max_values, prune_unused_nodes, context); + max_values, prune_unused_nodes, convert_legacy_fed_inputs, + graph_as_function, context); } static TranslateToMLIRRegistration GraphdefToMlirTranslate( "graphdef-to-mlir", GraphdefToMlirTranslateFunction); +static OwningModuleRef SavedModelToMlirTranslateFunction( + llvm::StringRef input_filename, MLIRContext* context) { + std::unordered_set tags = absl::StrSplit(saved_model_tags, ','); + return tensorflow::SavedModelToMlirImport(StringRefToView(input_filename), + tags, debug_info_file, context); +} + +static TranslateToMLIRRegistration SavedModelToMlirTranslate( + "savedmodel-to-mlir", SavedModelToMlirTranslateFunction); + static OwningModuleRef GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input_filename, MLIRContext* context) { return tensorflow::GraphdefToSplattedMlirTranslateFunction( 
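The new --tf-savedmodel-tags option above is a comma-separated list (default "serve") that the registration code splits with absl::StrSplit into the tag set handed to the SavedModel loader. A standard-library version of that splitting:

#include <iostream>
#include <sstream>
#include <string>
#include <unordered_set>

// Splits a comma-separated tag string such as "serve,gpu" into the tag set
// passed to LoadSavedModel; empty segments are dropped.
std::unordered_set<std::string> ParseTags(const std::string& flag_value) {
  std::unordered_set<std::string> tags;
  std::stringstream ss(flag_value);
  std::string tag;
  while (std::getline(ss, tag, ',')) {
    if (!tag.empty()) tags.insert(tag);
  }
  return tags;
}

int main() {
  for (const auto& t : ParseTags("serve,gpu")) std::cout << t << "\n";
}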
StringRefToView(input_filename), debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, inference_type, min_values, - max_values, prune_unused_nodes, context); + max_values, prune_unused_nodes, convert_legacy_fed_inputs, + graph_as_function, context); } static TranslateToMLIRRegistration GraphdefToSplattedMlirTranslate( @@ -67,8 +79,8 @@ static LogicalResult MlirToGraphdefTranslateFunction( if (!module) return failure(); std::error_code error; - auto result = llvm::make_unique(output_filename, error, - llvm::sys::fs::F_None); + auto result = std::make_unique(output_filename, error, + llvm::sys::fs::F_None); if (error) { LOG(ERROR) << error.message(); return failure(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc index 9c02ce2278f..ac0f4d2adc0 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc @@ -59,7 +59,8 @@ static LogicalResult MlirToTfNodeDef(ModuleOp module, return failure(); } - auto node_def_or = tensorflow::ConvertTFDialectOpToNodeDef(op, "node_name"); + auto node_def_or = tensorflow::ConvertTFDialectOpToNodeDef( + op, "node_name", /*ignore_unregistered_attrs=*/false); if (!node_def_or.ok()) { op->emitError("failed to convert to TF NodeDef:") << node_def_or.status().ToString(); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index f66b07b246a..df19e169d3c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -42,20 +42,23 @@ namespace tensorflow { using llvm::ArrayRef; using llvm::SmallVector; -using mlir::Attribute; -using mlir::BoolAttr; using mlir::Builder; using mlir::DenseFPElementsAttr; using mlir::DenseIntElementsAttr; using mlir::ElementsAttr; -using mlir::FloatAttr; -using mlir::IntegerAttr; using mlir::OpaqueElementsAttr; using mlir::ShapedType; -using mlir::SplatElementsAttr; using mlir::Type; using tensorflow::errors::InvalidArgument; +void ConvertToMlirShape(const TensorShape& input_shape, + llvm::SmallVectorImpl* shape) { + shape->reserve(input_shape.dims()); + for (const auto& d : input_shape) { + shape->push_back(d.size); + } +} + Status ConvertToMlirShape(const TensorShapeProto& input_shape, llvm::SmallVectorImpl* shape) { shape->reserve(input_shape.dim_size()); @@ -69,174 +72,72 @@ Status ConvertToMlirShape(const TensorShapeProto& input_shape, return Status::OK(); } -// Converts an TensorFlow tensor proto to an MLIR opaque elements attribute. -StatusOr ConvertToOpaqueElementsAttr( - const TensorProto& input_tensor, ShapedType type, Builder* builder) { - // TODO(shpeisman): restructure code to reuse dialect pointer across calls. - auto* dialect = builder->getContext()->getRegisteredDialect("tf"); - return builder->getOpaqueElementsAttr( - dialect, type, mangling_util::MangleTensor(input_tensor)); +static TensorProto ConvertToProto(const Tensor& input_tensor, + bool use_tensor_content = true) { + TensorProto tensor_proto; + // Using tensor content (mostly*) reduces serialization overhead during RPC + // calls, but is less human reader friendly. People reading protobufs are less + // frequent than serialization, so default to using tensor content + // representation. 
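ConvertToProto above defaults to Tensor::AsProtoTensorContent, which stores the values as one opaque tensor_content blob rather than as per-element repeated proto fields. The standalone sketch below illustrates why the content form is compact: for plain numeric types it is essentially the raw bytes of the flat value array. The real encoding is defined by TensorProto, so treat this as an approximation of the idea, not the actual serializer.

#include <cstring>
#include <iostream>
#include <string>
#include <vector>

// Packs a float array into a single byte string, roughly what a
// tensor_content-style representation amounts to for POD element types.
std::string PackAsContentBytes(const std::vector<float>& values) {
  std::string bytes(values.size() * sizeof(float), '\0');
  if (bytes.empty()) return bytes;
  std::memcpy(&bytes[0], values.data(), bytes.size());
  return bytes;
}

int main() {
  std::vector<float> values(1024, 1.5f);
  std::cout << "raw content bytes: " << PackAsContentBytes(values).size() << "\n";
}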
+ // * For scalars and short strings it may be marginally worse and a more + // intelligent decision could be made by caller. + if (use_tensor_content) + input_tensor.AsProtoTensorContent(&tensor_proto); + else + input_tensor.AsProtoField(&tensor_proto); + return tensor_proto; } -// Template predicate that provides a constant member `value` equal to true if -// a sequence of `From` values can be copied wholesale to locations for `To` -// values. - -// Primary template declaration -template -struct IsBatchCopyable; - -// Partial template specialization: allow wholesale copy for the same type -template -struct IsBatchCopyable : std::true_type {}; - -// SFINAE: integral types depend on the bitwidth -template -struct IsBatchCopyable< - From, To, - typename std::enable_if::value && - std::is_integral::value>::type> { - static constexpr bool value = - std::numeric_limits::digits == std::numeric_limits::digits; -}; - -// Converts an TensorFlow tensor proto to an MLIR dense elements attribute. -// To save the memory held by the attribute, the value is casted to the -// specified type. -template -typename std::enable_if::value, - StatusOr>::type -ConvertToDenseElementsAttr( - const tensorflow::protobuf::RepeatedField& values, ShapedType type, - Builder* builder) { - return mlir::DenseElementsAttr::get( - type, llvm::makeArrayRef(values.data(), values.size())); +static std::string MangleTensor(const Tensor& tensor) { + return mangling_util::MangleTensor(ConvertToProto(tensor)); } -template -typename std::enable_if::value, - StatusOr>::type -ConvertToDenseElementsAttr( - const tensorflow::protobuf::RepeatedField& values, ShapedType type, - Builder* builder) { - std::vector buff; - buff.reserve(values.size()); - for (auto value : values) { - buff.push_back(value); - } - return mlir::DenseElementsAttr::get(type, llvm::makeArrayRef(buff)); -} - -// Converts an TensorFlow tensor proto with DT_FLOAT data type into an MLIR -// elements attribute. -StatusOr ConvertFloatTensor(const TensorProto& input_tensor, - ShapedType type, Builder* builder) { - // When the repeated "float_val" field only has one element, it is converted - // to a splat elements attribute; When it has more than one element, it is - // converted to a dense elements attribute; otherwise, convert the whole - // tensor to an opaque elements attribute if the "tensor_content" field is - // set. - auto repeated_val_size = input_tensor.float_val_size(); - if (repeated_val_size == 1 || repeated_val_size == type.getNumElements()) { - return ConvertToDenseElementsAttr(input_tensor.float_val(), - type, builder); - } - return ConvertToOpaqueElementsAttr(input_tensor, type, builder); -} - -// Converts an TensorFlow tensor proto with DT_INT32, DT_INT16, DT_INT8, -// DT_UINT8, DT_QUINT8 data type into an MLIR elements attribute. +// Converts a TensorFlow tensor into an MLIR elements attribute. template -StatusOr ConvertIntTensor(const TensorProto& input_tensor, - ShapedType type, Builder* builder) { - // When the repeated "int_val" field only has one element, it is converted to - // a splat elements attribute; When it has more than one element, it is - // converted to a dense elements attribute; otherwise, convert the whole - // tensor to an opaque elements attribute if the "tensor_content" field is - // set. 
- auto repeated_val_size = input_tensor.int_val_size(); - if (repeated_val_size == 1 || repeated_val_size == type.getNumElements()) { - return ConvertToDenseElementsAttr(input_tensor.int_val(), type, - builder); - } - return ConvertToOpaqueElementsAttr(input_tensor, type, builder); -} - -// Converts an TensorFlow tensor proto with DT_INT64 data type into an MLIR -// elements attribute. -StatusOr ConvertInt64Tensor(const TensorProto& input_tensor, - ShapedType type, Builder* builder) { - // When the repeated "int64_val" field only has one element, it is converted - // to a splat elements attribute; When it has more than one element, it is - // converted to a dense elements attribute; otherwise, convert the whole - // tensor to an opaque elements attribute if the "tensor_content" field is - // set. - auto repeated_val_size = input_tensor.int64_val_size(); - if (repeated_val_size == 1 || repeated_val_size == type.getNumElements()) { - return ConvertToDenseElementsAttr(input_tensor.int64_val(), type, - builder); - } - return ConvertToOpaqueElementsAttr(input_tensor, type, builder); -} - -// Converts an TensorFlow tensor proto with DT_BOOL data type into an MLIR -// elements attribute. -StatusOr ConvertBoolTensor(const TensorProto& input_tensor, +StatusOr ConvertFlatTensor(const Tensor& input_tensor, ShapedType type, Builder* builder) { - // When the repeated "bool_val" field only has one element, it is converted to - // a splat elements attribute; When it has more than one element, it is - // converted to a dense elements attribute; otherwise, convert the whole - // tensor to an opaque elements attribute if the "tensor_content" field is - // set. - auto repeated_val_size = input_tensor.bool_val_size(); - if (repeated_val_size == 1 || repeated_val_size == type.getNumElements()) { - const auto& proto = input_tensor.bool_val(); - return mlir::DenseElementsAttr::get( - type, llvm::makeArrayRef(proto.data(), proto.size())); + auto arr = input_tensor.flat(); + return mlir::DenseElementsAttr::get( + type, llvm::makeArrayRef(arr.data(), arr.size())); +} + +StatusOr ConvertTensor(const Tensor& input_tensor, + Builder* builder) { + const auto& input_dtype = input_tensor.dtype(); + const auto& input_shape = input_tensor.shape(); + Type elt_type; + TF_RETURN_IF_ERROR(ConvertDataType(input_dtype, *builder, &elt_type)); + SmallVector shape; + ConvertToMlirShape(input_shape, &shape); + auto type = builder->getTensorType(shape, elt_type); + +#define CONVERT_FLAT(DTYPE, CTYPE) \ + case DTYPE: \ + return ConvertFlatTensor(input_tensor, type, builder); + + // TODO(fengliuai): customize the conversions for more types. + switch (input_dtype) { + CONVERT_FLAT(DT_BOOL, bool) + CONVERT_FLAT(DT_FLOAT, float) + CONVERT_FLAT(DT_INT32, int32) + CONVERT_FLAT(DT_INT64, int64) + default: + // TODO(shpeisman): restructure code to reuse dialect pointer across + // calls. 
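The new ConvertTensor above dispatches on dtype with a small CONVERT_FLAT macro, so each supported type expands to the same templated flat-tensor conversion and anything else falls back to an opaque attribute. A self-contained imitation of that dispatch pattern, with toy enum values and a toy summary function in place of the TensorFlow and MLIR types:

#include <cstdint>
#include <iostream>
#include <string>

enum ToyDType { DT_BOOL_TOY, DT_FLOAT_TOY, DT_INT32_TOY };

// One function template handles every element type; the macro below stamps
// out a switch case per supported dtype that instantiates it.
template <typename T>
std::string Summarize(const void* data, size_t n) {
  const T* typed = static_cast<const T*>(data);
  std::string out;
  for (size_t i = 0; i < n; ++i) out += std::to_string(typed[i]) + " ";
  return out;
}

#define CONVERT_FLAT_TOY(DTYPE, CTYPE) \
  case DTYPE:                          \
    return Summarize<CTYPE>(data, n);

std::string SummarizeFlat(ToyDType dtype, const void* data, size_t n) {
  switch (dtype) {
    CONVERT_FLAT_TOY(DT_BOOL_TOY, bool)
    CONVERT_FLAT_TOY(DT_FLOAT_TOY, float)
    CONVERT_FLAT_TOY(DT_INT32_TOY, int32_t)
    default:
      return "<unsupported dtype: fall back to an opaque representation>";
  }
}
#undef CONVERT_FLAT_TOY

int main() {
  float vals[] = {1.0f, 2.5f};
  std::cout << SummarizeFlat(DT_FLOAT_TOY, vals, 2) << "\n";
}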
+ auto* dialect = builder->getContext()->getRegisteredDialect("tf"); + return builder->getOpaqueElementsAttr(dialect, type, + MangleTensor(input_tensor)); } - return ConvertToOpaqueElementsAttr(input_tensor, type, builder); + +#undef CONVERT_FLAT } StatusOr ConvertTensorProto(const TensorProto& input_tensor, Builder* builder) { - const auto& input_dtype = input_tensor.dtype(); - const auto& input_shape = input_tensor.tensor_shape(); - Type elt_type; - TF_RETURN_IF_ERROR(ConvertDataType(input_dtype, *builder, &elt_type)); - SmallVector shape; - TF_RETURN_IF_ERROR(ConvertToMlirShape(input_shape, &shape)); - auto type = builder->getTensorType(shape, elt_type); - - // TODO(fengliuai): customize the conversions for more types. - switch (input_dtype) { - case DT_FLOAT: - return ConvertFloatTensor(input_tensor, type, builder); - case DT_INT32: - return ConvertIntTensor(input_tensor, type, builder); - case DT_INT64: - return ConvertInt64Tensor(input_tensor, type, builder); - case DT_BOOL: - return ConvertBoolTensor(input_tensor, type, builder); - default: - // The value of the opaque elements attribute contains the whole tensor - // proto, not just the tensor content. - - // TODO(shpeisman): restructure code to reuse dialect pointer across - // calls. - auto* dialect = builder->getContext()->getRegisteredDialect("tf"); - - return builder->getOpaqueElementsAttr( - dialect, type, mangling_util::MangleTensor(input_tensor)); - } -} - -StatusOr ConvertTensor(const Tensor& input_tensor, - mlir::Builder* builder) { - TensorProto input_proto; - // This decodes the tensor content into a proper proto field. - input_tensor.AsProtoField(&input_proto); - return ConvertTensorProto(input_proto, builder); + Tensor t; + if (!t.FromProto(input_tensor)) + return InvalidArgument("Failed to parse input_tensor."); + return ConvertTensor(t, builder); } Status ConvertToTensorShapeProto(ArrayRef shape, @@ -247,7 +148,7 @@ Status ConvertToTensorShapeProto(ArrayRef shape, return Status::OK(); } -// Converts an MLIR opaque elements attribute to an TensorFlow tensor proto. +// Converts an MLIR opaque elements attribute to a TensorFlow tensor proto. Status ConvertOpaqueElementsAttr(const ElementsAttr attr, TensorProto* output_tensor) { if (attr.isa()) { @@ -258,49 +159,70 @@ Status ConvertOpaqueElementsAttr(const ElementsAttr attr, return InvalidArgument("Unexpected elements attribute type from MLIR."); } -// Converts an MLIR elements attribute to an TensorFlow tensor proto +// Converts an MLIR elements attribute to a TensorFlow tensor proto // with the float_val field updated. Status ConvertFloatElementsAttr(const ElementsAttr attr, TensorProto* output_tensor) { if (auto elts = attr.dyn_cast()) { - for (auto value : elts.getValues()) { - output_tensor->add_float_val(value); + if (elts.isSplat()) { + output_tensor->add_float_val(elts.getSplatValue()); + } else { + for (auto value : elts.getValues()) + output_tensor->add_float_val(value); } - } else { - return ConvertOpaqueElementsAttr(attr, output_tensor); + return Status::OK(); } - return Status::OK(); + return ConvertOpaqueElementsAttr(attr, output_tensor); } -// Converts an MLIR elements attribute to an TensorFlow tensor proto +// Converts an MLIR elements attribute to a TensorFlow tensor proto +// with the half_val field updated. 
+Status ConvertHalfElementsAttr(const ElementsAttr attr, + TensorProto* output_tensor) { + if (auto elts = attr.dyn_cast()) { + if (elts.isSplat()) { + output_tensor->add_half_val( + (*elts.begin()).bitcastToAPInt().getSExtValue()); + } else { + for (auto value : elts.getFloatValues()) + output_tensor->add_half_val(value.bitcastToAPInt().getSExtValue()); + } + return Status::OK(); + } + return ConvertOpaqueElementsAttr(attr, output_tensor); +} + +// Converts an MLIR elements attribute to a TensorFlow tensor proto // with the int_val field updated. Status ConvertIntElementsAttr(const mlir::ElementsAttr attr, TensorProto* output_tensor) { if (auto elts = attr.dyn_cast()) { - for (auto val : elts) { - output_tensor->add_int_val(val.getSExtValue()); + if (elts.isSplat()) { + output_tensor->add_int_val((*elts.begin()).getSExtValue()); + } else { + for (auto val : elts) output_tensor->add_int_val(val.getSExtValue()); } - } else { - return ConvertOpaqueElementsAttr(attr, output_tensor); + return Status::OK(); } - return Status::OK(); + return ConvertOpaqueElementsAttr(attr, output_tensor); } -// Converts an MLIR elements attribute to an TensorFlow tensor proto +// Converts an MLIR elements attribute to a TensorFlow tensor proto // with the int64_val field updated. Status ConvertInt64ElementsAttr(const mlir::ElementsAttr attr, TensorProto* output_tensor) { if (auto elts = attr.dyn_cast()) { - for (auto val : elts) { - output_tensor->add_int64_val(val.getSExtValue()); + if (elts.isSplat()) { + output_tensor->add_int64_val((*elts.begin()).getSExtValue()); + } else { + for (auto val : elts) output_tensor->add_int64_val(val.getSExtValue()); } - } else { - return ConvertOpaqueElementsAttr(attr, output_tensor); + return Status::OK(); } - return Status::OK(); + return ConvertOpaqueElementsAttr(attr, output_tensor); } -// Converts an MLIR elements attribute to an TensorFlow tensor proto +// Converts an MLIR elements attribute to a TensorFlow tensor proto // with bool_val field updated. Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr, TensorProto* output_tensor) { @@ -308,10 +230,9 @@ Status ConvertBoolElementsAttr(const mlir::ElementsAttr attr, for (auto val : elts) { output_tensor->add_bool_val(val.getBoolValue()); } - } else { - return ConvertOpaqueElementsAttr(attr, output_tensor); + return Status::OK(); } - return Status::OK(); + return ConvertOpaqueElementsAttr(attr, output_tensor); } Status ConvertToTensorProto(const ElementsAttr attr, @@ -327,6 +248,9 @@ Status ConvertToTensorProto(const ElementsAttr attr, switch (output_dtype) { case DT_FLOAT: return ConvertFloatElementsAttr(attr, output_tensor); + case DT_HALF: + // Handles both DenseFPElementsAttr and OpaqueElementsAttr. 
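The element-attribute converters above now special-case splat attributes: a constant whose elements are all identical is written as a single value in the corresponding *_val field instead of one entry per element (the DT_HALF case additionally stores the fp16 bit pattern as an integer). A standalone sketch of the splat shortcut, using int values and a plain vector in place of the proto's repeated field:

#include <iostream>
#include <vector>

// Appends either one value (splat) or every element to the output field,
// mirroring the isSplat() branch added above.
void AppendValues(const std::vector<int>& elements, std::vector<int>* repeated_field) {
  bool is_splat = !elements.empty();
  for (int v : elements) {
    if (v != elements.front()) {
      is_splat = false;
      break;
    }
  }
  if (is_splat) {
    repeated_field->push_back(elements.front());  // single splat value
  } else {
    repeated_field->insert(repeated_field->end(), elements.begin(), elements.end());
  }
}

int main() {
  std::vector<int> out;
  AppendValues({7, 7, 7, 7}, &out);
  std::cout << "stored " << out.size() << " value(s)\n";
}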
+ return ConvertHalfElementsAttr(attr, output_tensor); case DT_QUINT8: case DT_UINT8: case DT_INT8: diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc index 5d6cd1bb222..4e59cec86ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc @@ -29,12 +29,8 @@ using testing::HasSubstr; TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) { MLIRContext context; - - auto emit_error = [&](const std::string& msg) { - emitError(FileLineColLoc::get(Identifier::get("test.cc", &context), 10, 32, - &context), - msg); - }; + auto id = Identifier::get("test.cc", &context); + auto loc = FileLineColLoc::get(id, 0, 0, &context); // Test OK without diagnostic gets passed through. { @@ -44,7 +40,7 @@ TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) { // Verify diagnostics are captured as Unknown status. { StatusScopedDiagnosticHandler handler(&context); - emit_error("Diagnostic message"); + emitError(loc) << "Diagnostic message"; ASSERT_TRUE(tensorflow::errors::IsUnknown(handler.ConsumeStatus())); } @@ -58,8 +54,8 @@ TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) { // Verify diagnostic reported are append to passed in error. { auto function = [&]() { - emit_error("Diagnostic message reported"); - emit_error("Second diagnostic message reported"); + emitError(loc) << "Diagnostic message reported"; + emitError(loc) << "Second diagnostic message reported"; return tensorflow::errors::Internal("Passed in error"); }; Status s = StatusScopedDiagnosticHandler(&context).Combine(function()); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc index a821c868d4a..29a4388de30 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/eval_util.cc @@ -78,7 +78,8 @@ mlir::LogicalResult EvaluateOperation( if (auto attr = inst->getAttrOfType("name")) { node_name = attr.getValue(); } - auto node_def_or = ConvertTFDialectOpToNodeDef(inst, node_name.c_str()); + auto node_def_or = ConvertTFDialectOpToNodeDef( + inst, node_name.c_str(), /*ignore_unregistered_attrs=*/true); RETURN_FAILURE_IF_ERROR(node_def_or.status()); const auto& node_def = node_def_or.ValueOrDie(); TFE_Op* op = TFE_NewOp(context, node_def->op().c_str(), status); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index a2f803c0858..48826520949 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/Identifier.h" // TF:local_config_mlir @@ -30,7 +31,6 @@ limitations under the License. 
#include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/OperationSupport.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -160,9 +160,34 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { return Status::OK(); } +// Updates NodeDef constructed out of an MLIR If op to map it to either +// TensorFlow StatelessIf or If op depending on the additional attribute. +void UpdateCompositeIfOp(NodeDef* node_def) { + auto it = node_def->mutable_attr()->find("is_stateless"); + if (it != node_def->attr().end()) { + if (it->second.b()) { + *node_def->mutable_op() = "StatelessIf"; + } + node_def->mutable_attr()->erase(it); + } +} + +// Updates NodeDef constructed out of an MLIR While op to map it to either +// TensorFlow StatelessWhile or While op depending on the additional attribute. +void UpdateCompositeWhileOp(NodeDef* node_def) { + auto it = node_def->mutable_attr()->find("is_stateless"); + if (it != node_def->attr().end()) { + if (it->second.b()) { + *node_def->mutable_op() = "StatelessWhile"; + } + node_def->mutable_attr()->erase(it); + } +} + } // anonymous namespace StatusOr> GetOperationNodeDef( + const absl::flat_hash_set& attrs_to_ignore, mlir::Operation* inst, llvm::StringRef name, OpNameMappingFunc op_name_func) { auto node_def = absl::make_unique(); @@ -184,7 +209,6 @@ StatusOr> GetOperationNodeDef( } // Add the node attributes. - absl::flat_hash_set attrs_to_ignore; TF_RETURN_WITH_CONTEXT_IF_ERROR( ConvertAttributes(inst->getAttrs(), attrs_to_ignore, node_def->mutable_attr()), @@ -194,12 +218,16 @@ StatusOr> GetOperationNodeDef( TF_RETURN_IF_ERROR(ConvertLocation( inst->getLoc(), node_def->mutable_experimental_debug_info())); + if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get()); + if (node_def->op() == "While") UpdateCompositeWhileOp(node_def.get()); + return node_def; } -Status ConvertAttributes(const llvm::ArrayRef attrs, - const absl::flat_hash_set& attrs_to_ignore, - AttrValueMap* values) { +Status ConvertAttributes( + const llvm::ArrayRef attrs, + const absl::flat_hash_set& attrs_to_ignore, + AttrValueMap* values) { AttrValueMap func_call_attrs; for (const mlir::NamedAttribute& named_attr : attrs) { auto name_strref = named_attr.first.str(); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h index 4c6d8ade04a..0f1994aca43 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.h @@ -43,17 +43,21 @@ using OpNameMappingFunc = std::function(llvm::StringRef)>; // Converts an MLIR operation to TensorFlow NodeDef with given node name. This // name should be unique to the graph it is being inserted into. `op_name_func` -// is to map the op name of `inst` to its op name in TensorFlow. +// is to map the op name of `inst` to its op name in TensorFlow. "name" and +// "device" attributes are ignored by default. Use attrs_to_ignore to specify +// any other attributes that should be ignored. 
StatusOr> GetOperationNodeDef( + const absl::flat_hash_set& attrs_to_ignore, mlir::Operation* inst, llvm::StringRef name, OpNameMappingFunc op_name_func); // Converts MLIR attributes with values to their tensorflow equivalent. // "name" and "device" attributes are ignored by default. Use attrs_to_ignore to // specify any other attributes that should be ignored. -Status ConvertAttributes(const llvm::ArrayRef attrs, - const absl::flat_hash_set& attrs_to_ignore, - AttrValueMap* values); +Status ConvertAttributes( + const llvm::ArrayRef attrs, + const absl::flat_hash_set& attrs_to_ignore, + AttrValueMap* values); // Sets type attribute with the given name. If the attribute already exists with // a different value, returns an error. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc index 776a7ac71b2..691caab526a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/mangling_util.cc @@ -69,7 +69,7 @@ MangledKind GetMangledKind(absl::string_view str) { } string MangleShape(const TensorShapeProto& shape) { - return absl::StrCat(kTensorShapePrefix, shape.DebugString()); + return absl::StrCat(kTensorShapePrefix, shape.ShortDebugString()); } Status DemangleShape(absl::string_view str, TensorShapeProto* proto) { @@ -85,7 +85,7 @@ Status DemangleShape(absl::string_view str, TensorShapeProto* proto) { } string MangleTensor(const TensorProto& tensor) { - return absl::StrCat(kTensorPrefix, tensor.DebugString()); + return absl::StrCat(kTensorPrefix, tensor.ShortDebugString()); } Status DemangleTensor(absl::string_view str, TensorProto* proto) { diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index aaf4f68f739..3f649c67abf 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -15,13 +15,13 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" -#include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Pass/PassManager.h" // TF:local_config_mlir #include "mlir/Support/FileUtilities.h" // TF:local_config_mlir #include "mlir/Support/MlirOptMain.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -58,8 +58,7 @@ static llvm::cl::opt verify_passes( static std::vector *pass_list; int main(int argc, char **argv) { - llvm::PrettyStackTraceProgram x(argc, argv); - llvm::InitLLVM y(argc, argv); + tensorflow::InitMlir y(&argc, &argv); // Register any pass manager command line options. mlir::registerPassManagerCLOptions(); @@ -71,10 +70,6 @@ int main(int argc, char **argv) { llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR modular optimizer driver\n"); - // TODO(jpienaar): Enable command line parsing for both sides. - int fake_argc = 1; - tensorflow::port::InitMain(argv[0], &fake_argc, &argv); - // Set up the input file. 
std::string error_message; auto file = mlir::openInputFile(input_filename, &error_message); diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc new file mode 100644 index 00000000000..fc61e4bc5d0 --- /dev/null +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -0,0 +1,45 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/InitLLVM.h" +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir +#include "mlir/Support/TranslateClParser.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/core/platform/init_main.h" + +// NOLINTNEXTLINE +static llvm::cl::opt input_filename(llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt output_filename( + "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("-")); + +int main(int argc, char** argv) { + tensorflow::InitMlir y(&argc, &argv); + + // Add flags for all the registered translations. + llvm::cl::opt + requested_translation("", llvm::cl::desc("Translation to perform"), + llvm::cl::Required); + llvm::cl::ParseCommandLineOptions(argc, argv, "TF MLIR translation driver\n"); + + mlir::MLIRContext context; + return failed( + (*requested_translation)(input_filename, output_filename, &context)); +} diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index c36299ee263..35c8d2bd0eb 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -13,49 +13,63 @@ package_group( "//babelfish/device/...", "//learning/brain/experimental/mlir/...", "//tensorflow/compiler/mlir/...", + "//tensorflow/compiler/xla/...", "//third_party/mlir_edge/...", ], ) filegroup( - name = "xla_ops_td_files", + name = "hlo_ops_td_files", srcs = [ - "ir/xla_ops.td", + "ir/hlo_ops.td", + "ir/hlo_ops_base.td", + "ir/lhlo_ops.td", "@local_config_mlir//:OpBaseTdFiles", ], ) gentbl( - name = "xla_ops_inc_gen", + name = "hlo_ops_inc_gen", tbl_outs = [ - ( - "-gen-op-decls", - "ir/xla_ops.h.inc", - ), - ( - "-gen-op-defs", - "ir/xla_ops.cc.inc", - ), + ("-gen-op-decls", "ir/hlo_ops.h.inc"), + ("-gen-op-defs", "ir/hlo_ops.cc.inc"), ], tblgen = "@local_config_mlir//:mlir-tblgen", - td_file = "ir/xla_ops.td", - td_srcs = [ - ":xla_ops_td_files", + td_file = "ir/hlo_ops.td", + td_srcs = [":hlo_ops_td_files"], +) + +gentbl( + name = "hlo_ops_base_inc_gen", + tbl_outs = [ + ("-gen-op-decls", "ir/hlo_ops_base.h.inc"), + ("-gen-op-defs", "ir/hlo_ops_base.cc.inc"), ], + tblgen = "@local_config_mlir//:mlir-tblgen", + td_file = "ir/hlo_ops_base.td", + td_srcs = [":hlo_ops_td_files"], +) + +gentbl( + name = "lhlo_ops_inc_gen", + tbl_outs = [ + ("-gen-op-decls", "ir/lhlo_ops.h.inc"), + ("-gen-op-defs", "ir/lhlo_ops.cc.inc"), 
+ ], + tblgen = "@local_config_mlir//:mlir-tblgen", + td_file = "ir/lhlo_ops.td", + td_srcs = [":hlo_ops_td_files"], ) gentbl( name = "xla_legalize_tf_inc_gen", tbl_outs = [ - ( - "-gen-rewriters", - "transforms/generated_legalize_tf.inc", - ), + ("-gen-rewriters", "transforms/generated_legalize_tf.inc"), ], tblgen = "@local_config_mlir//:mlir-tblgen", td_file = "transforms/legalize_tf_patterns.td", td_srcs = [ - ":xla_ops_td_files", + ":hlo_ops_td_files", "@local_config_mlir//:StdOpsTdFiles", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", ], @@ -67,8 +81,9 @@ cc_library( "transforms/generated_legalize_tf.inc", "transforms/legalize_tf.cc", ], + copts = ["-std=c++14"], deps = [ - ":xla", + ":hlo", "//tensorflow/compiler/mlir/tensorflow", "@llvm//:support", "@local_config_mlir//:Analysis", @@ -83,15 +98,12 @@ cc_library( gentbl( name = "xla_legalize_to_standard_inc_gen", tbl_outs = [ - ( - "-gen-rewriters", - "transforms/generated_legalize_to_standard.inc", - ), + ("-gen-rewriters", "transforms/generated_legalize_to_standard.inc"), ], tblgen = "@local_config_mlir//:mlir-tblgen", td_file = "transforms/legalize_to_standard_patterns.td", td_srcs = [ - ":xla_ops_td_files", + ":hlo_ops_td_files", "@local_config_mlir//:StdOpsTdFiles", ], ) @@ -101,8 +113,9 @@ cc_library( srcs = [ "transforms/legalize_control_flow.cc", ], + copts = ["-std=c++14"], deps = [ - ":xla", + ":hlo", "//tensorflow/compiler/mlir/tensorflow", "@llvm//:support", "@local_config_mlir//:Analysis", @@ -115,11 +128,10 @@ cc_library( cc_library( name = "xla_legalize_to_standard", - srcs = [ - "transforms/legalize_to_standard.cc", - ], + srcs = ["transforms/legalize_to_standard.cc"], + copts = ["-std=c++14"], deps = [ - ":xla", + ":hlo", ":xla_legalize_to_standard_inc_gen", "//tensorflow/compiler/mlir/tensorflow", "@llvm//:support", @@ -132,19 +144,47 @@ cc_library( ) cc_library( - name = "xla", + name = "hlo", srcs = [ - "ir/xla_ops.cc", - "ir/xla_ops.cc.inc", - "ir/xla_ops.h.inc", + "ir/hlo_ops.cc", + "ir/hlo_ops.cc.inc", + "ir/hlo_ops.h.inc", ], hdrs = [ - "ir/xla_ops.h", + "ir/hlo_ops.h", "transforms/passes.h", ], + copts = ["-std=c++14"], includes = ["include"], deps = [ - ":xla_ops_inc_gen", + ":hlo_ops_base_inc_gen", + ":hlo_ops_inc_gen", + "@llvm//:support", + "@local_config_mlir//:Analysis", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:Support", + "@local_config_mlir//:TransformUtils", + ], + alwayslink = 1, +) + +cc_library( + name = "lhlo", + srcs = [ + "ir/lhlo_ops.cc", + "ir/lhlo_ops.cc.inc", + "ir/lhlo_ops.h.inc", + ], + hdrs = [ + "ir/lhlo_ops.h", + "transforms/passes.h", + ], + includes = ["include"], + deps = [ + ":hlo_ops_base_inc_gen", + ":lhlo_ops_inc_gen", "@llvm//:support", "@local_config_mlir//:Analysis", "@local_config_mlir//:IR", @@ -152,7 +192,6 @@ cc_library( "@local_config_mlir//:StandardOps", "@local_config_mlir//:Support", "@local_config_mlir//:TransformUtils", - "@local_config_mlir//:TypeUtilities", ], alwayslink = 1, ) @@ -161,8 +200,10 @@ cc_library( cc_library( name = "xla_dialect_registration", srcs = ["ir/dialect_registration.cc"], + copts = ["-std=c++14"], deps = [ - ":xla", + ":hlo", + ":lhlo", "@local_config_mlir//:IR", ], alwayslink = 1, @@ -172,11 +213,11 @@ cc_library( name = "type_to_shape", srcs = ["type_to_shape.cc"], hdrs = ["type_to_shape.h"], + copts = ["-std=c++14"], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - 
"@com_google_absl//absl/base:core_headers", "@local_config_mlir//:IR", "@local_config_mlir//:Support", ], @@ -190,6 +231,7 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", "//tensorflow/core:test_main", "@local_config_mlir//:IR", ], @@ -202,9 +244,10 @@ cc_library( "operator_writers.inc", ], hdrs = ["mlir_hlo_to_hlo.h"], + copts = ["-std=c++14"], deps = [ + ":hlo", ":type_to_shape", - ":xla", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:status_macros", @@ -223,12 +266,9 @@ cc_library( cc_library( name = "hlo_to_mlir_hlo", - srcs = [ - "hlo_to_mlir_hlo.cc", - ], - hdrs = [ - "hlo_to_mlir_hlo.h", - ], + srcs = ["hlo_to_mlir_hlo.cc"], + hdrs = ["hlo_to_mlir_hlo.h"], + copts = ["-std=c++14"], deps = [ ":hlo_module_importer", "//tensorflow/compiler/mlir/tensorflow:error_util", @@ -248,8 +288,9 @@ cc_library( "hlo_function_importer.h", "hlo_module_importer.h", ], + copts = ["-std=c++14"], deps = [ - ":xla", + ":hlo", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status", @@ -258,6 +299,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", "@llvm//:support", "@local_config_mlir//:IR", "@local_config_mlir//:StandardOps", @@ -266,12 +308,9 @@ cc_library( cc_library( name = "xla_mlir_translate", - srcs = [ - "xla_mlir_translate.cc", - ], - hdrs = [ - "xla_mlir_translate.h", - ], + srcs = ["xla_mlir_translate.cc"], + hdrs = ["xla_mlir_translate.h"], + copts = ["-std=c++14"], deps = [ ":hlo_to_mlir_hlo", ":mlir_hlo_to_hlo", @@ -290,11 +329,8 @@ cc_library( tf_native_cc_binary( name = "operator_writer_gen", - srcs = [ - "operator_writer_gen.cc", - ], + srcs = ["operator_writer_gen.cc"], deps = [ - "@llvm//:config", "@llvm//:support", "@llvm//:tablegen", "@local_config_mlir//:TableGen", @@ -305,13 +341,13 @@ genrule( name = "operator_writer_inc", srcs = [ "@local_config_mlir//:include/mlir/IR/OpBase.td", - "//tensorflow/compiler/mlir/xla:ir/xla_ops.td", - ], - outs = [ - "operator_writers.inc", + "//tensorflow/compiler/mlir/xla:ir/hlo_ops.td", + "//tensorflow/compiler/mlir/xla:ir/hlo_ops_base.td", ], + outs = ["operator_writers.inc"], cmd = ("$(location :operator_writer_gen) " + "-I external/local_config_mlir/include " + - "$(location //tensorflow/compiler/mlir/xla:ir/xla_ops.td) " + " -o $@"), + "$(location //tensorflow/compiler/mlir/xla:ir/hlo_ops.td) " + + " -o $@"), tools = [":operator_writer_gen"], ) diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index b9ba5fcb9fb..8a69310ced9 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -19,14 +19,14 @@ limitations under the License. 
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/Identifier.h" // TF:local_config_mlir #include "mlir/IR/Location.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -77,6 +77,11 @@ StatusOr CreateDenseAttrFromLiteral(ShapedType type, DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S16, int16) DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S32, int32) DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S64, int64) + // TODO(b/130356985): Update once MLIR supports unsigned integers. + DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::U8, uint8) + DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::U16, uint16) + DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::U32, uint32) + DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::U64, uint64) default: return tensorflow::errors::Internal( absl::StrCat("Unsupported type: ", @@ -174,18 +179,19 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kIota: { return func_builder - ->create( + ->create( loc, result_type, func_builder->getI64IntegerAttr( static_cast(instruction) ->iota_dimension())) .getOperation(); } -#define MakeAndReturn(mlir_op) \ - { \ - mlir::Operation* new_operation = func_builder->create( \ - loc, result_type, operands, attributes); \ - return new_operation; \ +#define MakeAndReturn(mlir_op) \ + { \ + mlir::Operation* new_operation = \ + func_builder->create(loc, result_type, \ + operands, attributes); \ + return new_operation; \ } case HloOpcode::kBroadcast: { // Note that the HLO broadcast is more powerful than the XLA broadcast op. @@ -237,7 +243,7 @@ StatusOr HloFunctionImporter::ImportInstruction( // TODO(b/132057942): Change to explicitly passing an integer instead of // call getI64IntegerAttr here. return func_builder - ->create( + ->create( loc, result_type, operands[0], operands[1], func_builder->getI64IntegerAttr( gather_dimensions.index_vector_dim()), @@ -247,7 +253,7 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kDynamicUpdateSlice: { return func_builder - ->create( + ->create( loc, result_type, operands[0], operands[1], llvm::ArrayRef(operands.begin() + 2, operands.end())) .getOperation(); @@ -268,15 +274,15 @@ StatusOr HloFunctionImporter::ImportInstruction( } return func_builder - ->create(loc, result_type, operands[0], operands[1], - Convert(edge_padding_low), - Convert(edge_padding_high), - Convert(interior_padding)) + ->create(loc, result_type, operands[0], + operands[1], Convert(edge_padding_low), + Convert(edge_padding_high), + Convert(interior_padding)) .getOperation(); } case HloOpcode::kSlice: { return func_builder - ->create( + ->create( loc, result_type, operands[0], ConvertDimensions(instruction->slice_starts()), ConvertDimensions(instruction->slice_limits())) @@ -286,7 +292,7 @@ StatusOr HloFunctionImporter::ImportInstruction( // TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr // for concatenate dimension. 
return func_builder - ->create( + ->create( loc, result_type, operands, builder_->getI64IntegerAttr(instruction->concatenate_dimension())) .getOperation(); @@ -297,7 +303,7 @@ StatusOr HloFunctionImporter::ImportInstruction( // TODO(b/132057942): Make more convenient constructors, e.g. pass // mlir function pointer instead of a function attr. return func_builder - ->create( + ->create( loc, result_type, operands, func_builder->getSymbolRefAttr(reduction), ConvertDimensions(instruction->dimensions())) @@ -305,7 +311,7 @@ StatusOr HloFunctionImporter::ImportInstruction( } case HloOpcode::kReverse: { return func_builder - ->create( + ->create( loc, result_type, operands[0], ConvertDimensions(instruction->dimensions())) .getOperation(); @@ -324,7 +330,7 @@ StatusOr HloFunctionImporter::ImportInstruction( auto cond_attr = func_builder->getSymbolRefAttr(cond); auto body_attr = func_builder->getSymbolRefAttr(body); - Operation* op = func_builder->create( + Operation* op = func_builder->create( loc, types, operands, cond_attr, body_attr); return op; } @@ -350,10 +356,19 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kAdd, AddOp); NoAttributeCase(kAnd, AndOp); NoAttributeCase(kConvert, ConvertOp); + NoAttributeCase(kClamp, ClampOp); NoAttributeCase(kDivide, DivOp); + NoAttributeCase(kExp, ExpOp); + NoAttributeCase(kFloor, FloorOp); + NoAttributeCase(kLog, LogOp); NoAttributeCase(kMaximum, MaxOp); NoAttributeCase(kMinimum, MinOp); NoAttributeCase(kMultiply, MulOp); + // The dimensions attribute is not present on the HLO Reshape instruction. + // If dimensions are non-default, the XLA builder implementes it as a + // separate transpose. + NoAttributeCase(kReshape, ReshapeOp); + NoAttributeCase(kRsqrt, RsqrtOp); NoAttributeCase(kSelect, SelectOp); NoAttributeCase(kSubtract, SubOp); NoAttributeCase(kTanh, TanhOp); @@ -365,7 +380,6 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kCopy, CopyOp); // TODO(b/129422361) Ops below need additional work to handle attributes. NoAttributeCase(kConvolution, ConvOp); - NoAttributeCase(kReshape, ReshapeOp); #undef NoAttributeCase #undef MakeAndReturn case HloOpcode::kAddDependency: @@ -374,7 +388,7 @@ StatusOr HloFunctionImporter::ImportInstruction( // is not mentioned in xla client anywhere or in the hlo of our sample // models. default: { - mlir::OperationState result(loc, "xla.unknown"); + mlir::OperationState result(loc, "xla_hlo.unknown"); result.addOperands(operands); result.addTypes(result_type); for (auto attr : attributes) { @@ -429,6 +443,15 @@ StatusOr HloFunctionImporter::ConvertTensorType( return builder_->getTensorType(array, builder_->getIntegerType(32)); case PrimitiveType::S64: return builder_->getTensorType(array, builder_->getIntegerType(64)); + // TODO(b/130356985): Update once MLIR supports unsigned integers. 
+ case PrimitiveType::U8: + return builder_->getTensorType(array, builder_->getIntegerType(8)); + case PrimitiveType::U16: + return builder_->getTensorType(array, builder_->getIntegerType(16)); + case PrimitiveType::U32: + return builder_->getTensorType(array, builder_->getIntegerType(32)); + case PrimitiveType::U64: + return builder_->getTensorType(array, builder_->getIntegerType(64)); default: return tensorflow::errors::Internal( absl::StrCat("Unsupported type: ", PrimitiveType_Name(type))); diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index ee321432f4d..13671dd0310 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -88,7 +89,8 @@ class HloFunctionImporter { xla::HloInstruction* instruction); // Converts the dimensions of an HLO instruction into an MLIR attribute. - mlir::ElementsAttr ConvertDimensions(llvm::ArrayRef op_dimensions); + mlir::ElementsAttr ConvertDimensions( + llvm::ArrayRef op_dimensions); // Converts Array ref to an ElementsAttr. mlir::ElementsAttr Convert(llvm::ArrayRef op_dimensions); diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc index f11e06a56f9..ba6519211ce 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/hlo_module_importer.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Location.h" // TF:local_config_mlir #include "mlir/IR/OperationSupport.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "mlir/IR/Types.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.h b/tensorflow/compiler/mlir/xla/hlo_module_importer.h index 6603ef8500f..5e8005f9489 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.h @@ -23,7 +23,7 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc b/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc index 79eda9cd278..f5e5b0ad257 100644 --- a/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc +++ b/tensorflow/compiler/mlir/xla/ir/dialect_registration.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" -using namespace mlir; +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" // Static initialization for XLA dialect registration. -static DialectRegistration XlaOps; +static mlir::DialectRegistration xla_hlo_ops; +static mlir::DialectRegistration xla_lhlo_ops; diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc new file mode 100644 index 00000000000..a5df379d90b --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -0,0 +1,239 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the XLA dialect. 
+ +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" + +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Dialect.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir +#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Value.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc" + +using namespace mlir; +using namespace mlir::xla_hlo; + +XlaHloDialect::XlaHloDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.cc.inc" + >(); + + // Support unknown operations because not all XLA operations are registered. + // allowUnknownOperations(); +} + +Operation* XlaHloDialect::materializeConstant(OpBuilder& builder, + Attribute value, Type type, + Location loc) { + // If this is an opaque elements attribute, then generate an xla_hlo.constant. + if (value.isa()) + return builder.create(loc, type, + value.cast()); + return nullptr; +} + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.cc.inc" + +//===----------------------------------------------------------------------===// +// ConstOp +//===----------------------------------------------------------------------===// + +OpFoldResult ConstOp::fold(ArrayRef operands) { + assert(operands.empty() && "constant has no operands"); + + // Return the held attribute value. + return value(); +} + +// Builds a constant op with the specified attribute `value`. +void ConstOp::build(Builder* builder, OperationState* result, Attribute value) { + Type type; + if (auto elemAttr = value.dyn_cast()) { + type = elemAttr.getType(); + } else if (value.isa() || value.isa() || + value.isa()) { + // All XLA types must be tensor types. In the build() method, we want to + // provide more flexiblity by allowing attributes of scalar types. But we + // need to wrap it up with ElementsAttr to construct valid XLA constants. + type = RankedTensorType::get(/*shape=*/{}, value.getType()); + value = DenseElementsAttr::get(type.cast(), value); + } + + // TODO: support other XLA specific types. + assert(type && "unsupported attribute type for building xla_hlo.constant"); + result->types.push_back(type); + result->addAttribute("value", value); +} + +//===----------------------------------------------------------------------===// +// ConvertOp +//===----------------------------------------------------------------------===// + +namespace { + +// Converts the values of an ElementsAttr into the corresponding type. 
+ElementsAttr ConvertElements(const ElementsAttr& elements, Type newType) { + auto oldType = getElementTypeOrSelf(elements); + size_t bitWidth = newType.isBF16() ? 64 : newType.getIntOrFloatBitWidth(); + + if (oldType.isa()) { + // mapValues always takes a function returning APInt, even when the output + // is actually float. + using func_type = APInt(const APFloat&); + if (auto newFloatType = newType.dyn_cast()) { + // Float -> Float + return elements.mapValues( + newType, llvm::function_ref([&newFloatType]( + const APFloat& floatVal) { + APFloat newDouble(FloatAttr::getValueAsDouble(floatVal)); + bool losesInfo = false; + newDouble.convert(newFloatType.getFloatSemantics(), + llvm::APFloat::rmNearestTiesToEven, &losesInfo); + return newDouble.bitcastToAPInt(); + })); + } + // Float -> Int + return elements.mapValues( + newType, + llvm::function_ref([&bitWidth](const APFloat& floatVal) { + return APInt(bitWidth, FloatAttr::getValueAsDouble(floatVal)); + })); + } + + // oldType is Integer + // mapValues always takes a function returning APInt, even when the output + // is actually float. + using func_type = APInt(const APInt&); + if (auto newFloatType = newType.dyn_cast()) { + // Int -> Float + return elements.mapValues( + newType, + llvm::function_ref([&newFloatType](const APInt& intVal) { + APFloat newDouble(static_cast(intVal.getLimitedValue())); + bool losesInfo = false; + newDouble.convert(newFloatType.getFloatSemantics(), + llvm::APFloat::rmNearestTiesToEven, &losesInfo); + return newDouble.bitcastToAPInt(); + })); + } + // newType is Integer + // Int -> Int + return elements.mapValues( + newType, llvm::function_ref([&bitWidth](const APInt& intVal) { + return APInt(bitWidth, intVal.getLimitedValue()); + })); +} + +} // namespace + +OpFoldResult ConvertOp::fold(ArrayRef operands) { + if (getOperand()->getType() == getResult()->getType()) return getOperand(); + + // If the operand is constant, we can do the conversion now. 
+ if (auto elementsAttr = operands.front().dyn_cast_or_null()) { + return ConvertElements(elementsAttr, getElementTypeOrSelf(getResult())); + } + + return {}; +} + +//===----------------------------------------------------------------------===// +// IotaOp +//===----------------------------------------------------------------------===// + +OpFoldResult IotaOp::fold(ArrayRef operands) { + const auto output_type = getResult()->getType().cast(); + const auto output_size = output_type.getNumElements(); + const auto dimension = iota_dimension().getLimitedValue(); + const auto max_dim_size = output_type.getDimSize(dimension); + int bitwidth = output_type.getElementType().getIntOrFloatBitWidth(); + + llvm::SmallVector values; + values.reserve(output_size); + + int64_t increase_stride = output_size; + for (int i = 0; i <= dimension; i++) { + increase_stride /= output_type.getDimSize(i); + } + + int64_t current_value = 0; + for (int i = 0; i < output_size; i++) { + int64_t value = (current_value / increase_stride) % max_dim_size; + values.push_back(APInt(bitwidth, value)); + ++current_value; + } + + return DenseIntElementsAttr::get(output_type, values); +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +OpFoldResult ReshapeOp::fold(ArrayRef operands) { + if (getOperand()->getType() == getType()) { + return getOperand(); + } + + if (auto prev_op = + dyn_cast_or_null(getOperand()->getDefiningOp())) { + setOperand(prev_op.getOperand()); + return getResult(); + } + + if (auto elements = operands.front().dyn_cast_or_null()) { + return elements.reshape(getResult()->getType().cast()); + } + + return {}; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +OpFoldResult TransposeOp::fold(ArrayRef operands) { + for (auto it : llvm::enumerate(permutation().getValues())) { + if (it.index() != it.value()) { + return {}; + } + } + return getOperand(); +} diff --git a/tensorflow/compiler/mlir/xla/ir/xla_ops.h b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h similarity index 66% rename from tensorflow/compiler/mlir/xla/ir/xla_ops.h rename to tensorflow/compiler/mlir/xla/ir/hlo_ops.h index 2be8160d4ec..3260a829734 100644 --- a/tensorflow/compiler/mlir/xla/ir/xla_ops.h +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.h @@ -15,24 +15,29 @@ limitations under the License. // This file defines the operations used in the XLA dialect. 
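// --- Illustrative sketch (editor's addition, not part of the diff) ---------
// The IotaOp::fold shown above materializes the iota values with a simple
// stride computation. The standalone function below reproduces the same
// arithmetic with plain standard-library types so it can be read and run in
// isolation; the names (FoldIota, shape, iota_dimension) are hypothetical and
// do not exist in the code base.
#include <cstdint>
#include <vector>

std::vector<int64_t> FoldIota(const std::vector<int64_t>& shape,
                              int64_t iota_dimension) {
  int64_t num_elements = 1;
  for (int64_t d : shape) num_elements *= d;

  // In row-major order, consecutive runs of `stride` elements share the same
  // coordinate along `iota_dimension`.
  int64_t stride = num_elements;
  for (int64_t i = 0; i <= iota_dimension; ++i) stride /= shape[i];

  std::vector<int64_t> values;
  values.reserve(num_elements);
  for (int64_t i = 0; i < num_elements; ++i)
    values.push_back((i / stride) % shape[iota_dimension]);
  return values;
}

// Example: FoldIota({2, 3, 4}, 1) yields
//   0 0 0 0 1 1 1 1 2 2 2 2 0 0 0 0 1 1 1 1 2 2 2 2,
// i.e. each element takes its coordinate along dimension 1, which is what the
// folder packs into the resulting DenseIntElementsAttr.
// ---------------------------------------------------------------------------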
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_XLA_OPS_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_IR_XLA_OPS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_ +#include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Dialect.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/OpDefinition.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir #include "mlir/Support/Functional.h" // TF:local_config_mlir namespace mlir { -class Builder; +class OpBuilder; -namespace XLA { +namespace xla_hlo { -class XLADialect : public Dialect { +class XlaHloDialect : public Dialect { public: - XLADialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "xla"; } + explicit XlaHloDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "xla_hlo"; } // Registered hook to materialize a constant operation from a given attribute // value with the desired resultant type. @@ -41,9 +46,9 @@ class XLADialect : public Dialect { }; #define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h.inc" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc" -} // end namespace XLA +} // end namespace xla_hlo } // end namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_XLA_OPS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/xla/ir/xla_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td similarity index 50% rename from tensorflow/compiler/mlir/xla/ir/xla_ops.td rename to tensorflow/compiler/mlir/xla/ir/hlo_ops.td index a05dd9b3d1d..7775377c94b 100644 --- a/tensorflow/compiler/mlir/xla/ir/xla_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -15,22 +15,27 @@ limitations under the License. // This is the operation definition file for XLA. -#ifdef XLA_OPS +#ifdef HLO_OPS #else -#define XLA_OPS +#define HLO_OPS #ifdef OP_BASE #else include "mlir/IR/OpBase.td" #endif // OP_BASE -def XLA_Dialect : Dialect { - let name = "xla"; - let cppNamespace = "XLA"; +#ifdef HLO_OPS_BASE +#else +include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" +#endif + +def HLO_Dialect : Dialect { + let name = "xla_hlo"; + let cppNamespace = "xla_hlo"; } -class XLA_Op traits> : - Op { +class HLO_Op traits> : + Op { // Whether this operation has a custom conversion to HLO or not. bit hasCustomHLOConverter = 0b0; } @@ -39,44 +44,34 @@ class XLA_Op traits> : // XLA type definitions. 
//===----------------------------------------------------------------------===// -def XLA_Int : IntOfWidths<[8, 16, 32, 64]>; - // Any integer tensor types -def XLA_IntTensor : StaticShapeTensorOf<[XLA_Int]>; +def HLO_IntTensor : StaticShapeTensorOf<[HLO_Int]>; // Any floating-point tensor types -def XLA_FpTensor : StaticShapeTensorOf<[AnyFloat]>; +def HLO_FpTensor : StaticShapeTensorOf<[AnyFloat]>; -def XLA_Pred : TypeAlias; - -def XLA_PredTensor : StaticShapeTensorOf<[XLA_Pred]>; +def HLO_PredTensor : StaticShapeTensorOf<[HLO_Pred]>; // Any integer or floating-point tensor types -def XLA_IntOrFpTensor : StaticShapeTensorOf<[XLA_Int, AnyFloat]>; +def HLO_IntOrFpTensor : StaticShapeTensorOf<[HLO_Int, AnyFloat]>; -def XLA_Tensor : StaticShapeTensorOf<[AnyFloat, AnyInteger]>; +def HLO_Tensor : StaticShapeTensorOf<[AnyFloat, AnyInteger]>; -def XLA_Tuple : NestedTupleOf<[XLA_Tensor]>; +def HLO_Tuple : NestedTupleOf<[HLO_Tensor]>; -def XLA_TensorOrTuple : AnyTypeOf<[XLA_Tensor, XLA_Tuple]>; +def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>; //===----------------------------------------------------------------------===// // XLA nullary op definitions. //===----------------------------------------------------------------------===// -def XLA_ConstOp : XLA_Op<"constant", [NoSideEffect]> { - let summary = "Constant operator"; - - let description = [{ - Represents a constant value. - }]; - +def HLO_ConstOp : BASE_HLO_ConstOp, HLO_Op<"constant", [NoSideEffect]> { let arguments = (ins ElementsAttr:$value ); let results = (outs - XLA_Tensor:$output + HLO_Tensor:$output ); let builders = [OpBuilder< @@ -89,16 +84,10 @@ def XLA_ConstOp : XLA_Op<"constant", [NoSideEffect]> { let hasCustomHLOConverter = 1; } -def XLA_IotaOp : XLA_Op<"iota", [NoSideEffect]> { - let summary = "Iota operator"; - - let description = [{ - Creates a rank 1 array of values starting at zero and incrementing by one. - }]; - +def HLO_IotaOp : BASE_HLO_IotaOp, HLO_Op<"iota", [NoSideEffect]> { let arguments = (ins I64Attr:$iota_dimension); - let results = (outs XLA_Tensor:$output); + let results = (outs HLO_Tensor:$output); let hasFolder = 1; @@ -110,32 +99,17 @@ def XLA_IotaOp : XLA_Op<"iota", [NoSideEffect]> { // XLA unary elementwise op definitions. //===----------------------------------------------------------------------===// // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions -class XLA_UnaryElementwiseOp traits>: - XLA_Op, Arguments<(ins XLA_Tensor:$operand)>, - Results<(outs XLA_Tensor:$res)>; +class HLO_UnaryElementwiseOp traits>: + HLO_Op { -def XLA_AbsOp: XLA_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultType]> { - let summary = "Absolute value operator"; - - let description = [{ - Returns `abs(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; + let arguments = (ins HLO_Tensor); + let results = (outs HLO_Tensor); } -def XLA_ConvertOp : XLA_UnaryElementwiseOp< - "convert", [NoSideEffect, SameOperandsAndResultShape]> { - let summary = "Convert operator"; - - let description = [{ - Performs element-wise conversion of values from one type to another, e.g. - float to int. - - See https://www.tensorflow.org/xla/operation_semantics#convertelementtype. 
- }]; +def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AbsOp; +def HLO_ConvertOp : HLO_UnaryElementwiseOp< + "convert", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ConvertOp { let hasFolder = 1; // TODO(b/130357376) Convert has a special constructor. Use a custom @@ -143,153 +117,65 @@ def XLA_ConvertOp : XLA_UnaryElementwiseOp< let hasCustomHLOConverter = 1; } -def XLA_NegOp: XLA_UnaryElementwiseOp<"neg", [NoSideEffect, SameOperandsAndResultType]> { - let summary = "Negation operator"; +def HLO_ExpOp: HLO_UnaryElementwiseOp<"exp", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ExpOp; - let description = [{ - Returns `-operand` element-wise. +def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_FloorOp; - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} +def HLO_LogOp: HLO_UnaryElementwiseOp<"log", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_LogOp; -def XLA_SignOp: XLA_UnaryElementwiseOp<"sign", [NoSideEffect, SameOperandsAndResultShape]> { - let summary = "Sign operator"; +def HLO_NegOp: HLO_UnaryElementwiseOp<"neg", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_NegOp; - let description = [{ - Returns `sign(operand)` element-wise, where +def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RsqrtOp; - ``` - sign(x) = -1 : x < 0 - = -0 : x = -0 - = NaN : x = NaN - = +0 : x = +0 - = 1 : x > 0 - ``` +def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_SignOp; - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -def XLA_TanhOp: XLA_UnaryElementwiseOp<"tanh", - [ResultsAreFloatLike, NoSideEffect, SameOperandsAndResultType]> { - let summary = "Tanh operator"; - - let description = [{ - Returns `tanh(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} +def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", + [ResultsAreFloatLike, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_TanhOp; //===----------------------------------------------------------------------===// // XLA binary elementwise op definitions. //===----------------------------------------------------------------------===// -// The broadcasting dimensions correspond to a tuple that describes how a -// smaller rank shape is broadcast into a larger rank shape. For example, -// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means -// matching the matrix to dimensions 1 and 2 of the cuboid. 
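// --- Illustrative sketch (editor's addition, not part of the diff) ---------
// The broadcast_dimensions comment above describes how a lower-rank operand is
// matched to the dimensions of a higher-rank result. A minimal check of that
// rule with plain standard-library types; the function name and parameters are
// hypothetical and only illustrate the documented semantics.
#include <cstdint>
#include <vector>

bool BroadcastDimsCompatible(const std::vector<int64_t>& operand_shape,
                             const std::vector<int64_t>& result_shape,
                             const std::vector<int64_t>& broadcast_dimensions) {
  // One mapping entry per operand dimension.
  if (broadcast_dimensions.size() != operand_shape.size()) return false;
  for (size_t i = 0; i < operand_shape.size(); ++i) {
    const int64_t result_dim = broadcast_dimensions[i];
    if (result_dim < 0 ||
        result_dim >= static_cast<int64_t>(result_shape.size()))
      return false;
    // A mapped dimension must match exactly or be 1 (degenerate broadcast).
    if (operand_shape[i] != result_shape[result_dim] && operand_shape[i] != 1)
      return false;
  }
  return true;
}

// The worked example from the comment above, a 3x4 matrix broadcast into a
// 2x3x4 cuboid with the tuple (1,2), corresponds to
//   BroadcastDimsCompatible({3, 4}, {2, 3, 4}, {1, 2}) == true.
// ---------------------------------------------------------------------------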
-def BroadcastDimAttr : OptionalAttr; - // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations -class XLA_BinaryElementwiseOp traits, dag args = (ins)> : - XLA_Op, - Arguments<( - ins XLA_Tensor:$lhs, - XLA_Tensor:$rhs, - BroadcastDimAttr:$broadcast_dimensions - )>, - Results<(outs XLA_Tensor:$res)> { +class HLO_BinaryElementwiseOp traits> : + HLO_Op { + let arguments = (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + BroadcastDimAttr:$broadcast_dimensions + ); + let results = (outs HLO_Tensor); let parser = [{ return mlir::impl::parseBinaryOp(parser, result); }]; let printer = [{ return mlir::impl::printBinaryOp(getOperation(), p); }]; } -def XLA_AddOp : XLA_BinaryElementwiseOp<"add", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Addition operator"; +def HLO_AddOp : HLO_BinaryElementwiseOp<"add", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_AddOp; - let description = [{ - Returns `lhs + rhs` element-wise. +def HLO_DivOp : HLO_BinaryElementwiseOp<"div", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_DivOp; - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} +def HLO_MaxOp : HLO_BinaryElementwiseOp<"max", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MaxOp; -def XLA_DivOp : XLA_BinaryElementwiseOp<"div", - [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Division operator"; +def HLO_MinOp : HLO_BinaryElementwiseOp<"min", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MinOp; - let description = [{ - Returns `lhs / rhs` element-wise. +def HLO_MulOp : HLO_BinaryElementwiseOp<"mul", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_MulOp; - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} +def HLO_SubOp : HLO_BinaryElementwiseOp<"sub", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_SubOp; -def XLA_MaxOp : XLA_BinaryElementwiseOp<"max", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Maximum operator"; - - let description = [{ - Returns `max(lhs, rhs)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -def XLA_MinOp : XLA_BinaryElementwiseOp<"min", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Minimum operator"; - - let description = [{ - Returns `min(lhs, rhs)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -def XLA_MulOp : XLA_BinaryElementwiseOp<"mul", - [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Multiplication operator"; - - let description = [{ - Returns `lhs * rhs` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -def XLA_SubOp : XLA_BinaryElementwiseOp<"sub", - [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Subtraction operator"; - - let description = [{ - Returns `lhs - rhs` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
- }]; -} - -def XLA_AndOp: XLA_BinaryElementwiseOp<"and", [Commutative, NoSideEffect]>; +def HLO_AndOp: HLO_BinaryElementwiseOp<"and", [Commutative, NoSideEffect]>, BASE_HLO_AndOp; //===----------------------------------------------------------------------===// // XLA control flow op definitions. //===----------------------------------------------------------------------===// -def XLA_WhileOp: XLA_Op<"while", [NoSideEffect, SameOperandsAndResultType]> { - let summary = "While operator"; +def HLO_WhileOp: HLO_Op<"while", [NoSideEffect, SameOperandsAndResultType]> { + string summary = "While operator"; - let description = [{ + string description = [{ Returns the result of executing a body function until the cond body returns true. @@ -297,34 +183,26 @@ def XLA_WhileOp: XLA_Op<"while", [NoSideEffect, SameOperandsAndResultType]> { }]; let arguments = (ins - Variadic:$val, + Variadic:$val, SymbolRefAttr:$cond, SymbolRefAttr:$body ); - let results = (outs Variadic:$res); + let results = (outs Variadic); // TODO(b/129422361): WhileOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; } -def XLA_ReduceOp: XLA_Op<"reduce", [NoSideEffect]> { - let summary = "Reduce operator"; - - let description = [{ - Returns the result of executing a reduction function on one or more arrays - in parallel. - - See https://www.tensorflow.org/xla/operation_semantics#reduce. - }]; +def HLO_ReduceOp: HLO_Op<"reduce", [NoSideEffect]>, BASE_HLO_ReduceOp { let arguments = (ins - Variadic:$operands_and_init, + Variadic:$operands_and_init, SymbolRefAttr:$computation, ElementsAttr:$dimensions ); - let results = (outs Variadic:$res); + let results = (outs Variadic); // TODO(b/129422361): ReduceOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; @@ -333,147 +211,67 @@ def XLA_ReduceOp: XLA_Op<"reduce", [NoSideEffect]> { //===----------------------------------------------------------------------===// // XLA tuple op definitions. //===----------------------------------------------------------------------===// -def XLA_GetTupleElementOp: XLA_Op<"get_tuple_element", [NoSideEffect]> { - let summary = "GetTupleElement operator"; - - let description = [{ - Returns a member of a tuple specified by an index. - - See https://www.tensorflow.org/xla/operation_semantics#gettupleelement. - }]; - +def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO_GetTupleElementOp { let arguments = (ins - XLA_Tuple, + HLO_Tuple, I32Attr:$index ); - let results = (outs XLA_TensorOrTuple); + let results = (outs HLO_TensorOrTuple); // GetTupleElementOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; } -def XLA_TupleOp : XLA_Op<"tuple", [NoSideEffect]> { - let summary = "XLA's tuple op"; - - let description = [{ - Groups a set of tensor inputs into a single tuple object. - - See https://www.tensorflow.org/xla/operation_semantics#tuple. - }]; - - let arguments = (ins Variadic:$val); - let results = (outs XLA_Tuple:$res); +def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { + let arguments = (ins Variadic:$val); + let results = (outs HLO_Tuple); // TupleOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; } -//===----------------------------------------------------------------------===// -// Precision Config enum definitions. -//===----------------------------------------------------------------------===// - -// These mirror the XLA PrecisionConfig proto enum. 
-def XLA_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">; -def XLA_PRECISION_HIGH : StrEnumAttrCase<"HIGH">; -def XLA_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">; - -def XLA_PrecisionAttr : StrEnumAttr<"Precision", - "XLA precision for an operand. Has backend specific meaning.", - [XLA_PRECISION_DEFAULT, XLA_PRECISION_HIGH, XLA_PRECISION_HIGHEST]>; - -// TODO(b/129153247) See if it's possible to also validate the size. -def XLA_PrecisionConfigAttr: - OptionalAttr< - TypedArrayAttrBase>; - -//===----------------------------------------------------------------------===// -// Comparison op definitions. -//===----------------------------------------------------------------------===// - -// These mirror the XLA ComparisonDirection enum. -def XLA_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">; -def XLA_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">; -def XLA_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">; -def XLA_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">; -def XLA_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">; -def XLA_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">; - -def XLA_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection", - "Which comparison operation to perform.", - [ - XLA_COMPARISON_DIRECTION_EQ, - XLA_COMPARISON_DIRECTION_NE, - XLA_COMPARISON_DIRECTION_GE, - XLA_COMPARISON_DIRECTION_GT, - XLA_COMPARISON_DIRECTION_LE, - XLA_COMPARISON_DIRECTION_LT - ]>; - -def XLA_CompareOp: XLA_Op<"compare", - [NoSideEffect, SameOperandsAndResultShape]> { +def HLO_CompareOp: HLO_Op<"compare", + [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_CompareOp { let arguments = (ins - XLA_Tensor:$lhs, - XLA_Tensor:$rhs, + HLO_Tensor:$lhs, + HLO_Tensor:$rhs, BroadcastDimAttr:$broadcast_dimensions, - XLA_ComparisonDirectionAttr:$comparison_direction + HLO_ComparisonDirectionAttr:$comparison_direction ); - let results = (outs I1Tensor:$res); - let summary = "Comparison operator"; - - let description = [{ - Compares `lhs` and `rhs` elementwise according to `comparison_direction`. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. - }]; + let results = (outs HLO_PredTensor); } //===----------------------------------------------------------------------===// // XLA Slice definitions. //===----------------------------------------------------------------------===// -def XLA_SliceOp: XLA_UnaryElementwiseOp<"slice", - [NoSideEffect, SameOperandsAndResultElementType]> { +def HLO_SliceOp: HLO_Op< + "slice", + [NoSideEffect, SameOperandsAndResultElementType, + AllTypesMatch<["start_indices", "limit_indices"]>]> { let arguments = ( - ins XLA_Tensor:$operand, + ins HLO_Tensor:$operand, ElementsAttr:$start_indices, ElementsAttr:$limit_indices ); - let results = (outs XLA_Tensor:$res); - - let summary = "Slice operator"; - - let description = [{ - Slices a portion of the `operand` into a new configuration. - - See https://www.tensorflow.org/xla/operation_semantics#slice. - }]; + let results = (outs HLO_Tensor); // TODO(b/129422361) Two of the required arguments comes from the start and // limit indices which aren't handled by the codegen. 
let hasCustomHLOConverter = 1; } -def XLA_DynamicUpdateSliceOp: XLA_UnaryElementwiseOp<"dynamic-update-slice", - [NoSideEffect, AllElementTypesMatch<["operand", "res"]>]> { +def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice", + [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> { let arguments = (ins - XLA_Tensor:$operand, - XLA_Tensor:$update, - Variadic:$start_indices + HLO_Tensor:$operand, + HLO_Tensor:$update, + Variadic:$start_indices ); - let results = (outs XLA_Tensor:$res); - - let summary = "Dynamic Update Slice operator"; - - let description = [{ - DynamicUpdateSlice generates a result which is the value of the input array - operand, with a slice update overwritten at start_indices. - - See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice. - }]; + let results = (outs HLO_Tensor:$result); // TODO(b/129422361) Requires a custom constructor. let hasCustomHLOConverter = 1; @@ -484,52 +282,30 @@ def XLA_DynamicUpdateSliceOp: XLA_UnaryElementwiseOp<"dynamic-update-slice", // XLA Other op definitions. //===----------------------------------------------------------------------===// -def XLA_BatchNormInferenceOp : XLA_Op<"batch_norm_inference", [NoSideEffect]> { - let summary = "Batch Normalization for Inference"; - - let description = [{ - Normalizes an array across batch and spatial dimensions. - - See https://www.tensorflow.org/xla/operation_semantics#batchnorminference - }]; +def HLO_BatchNormInferenceOp : HLO_Op<"batch_norm_inference", [NoSideEffect]>, + BASE_HLO_BatchNormInferenceOp { let arguments = (ins - XLA_Tensor:$operand, - XLA_Tensor:$scale, - XLA_Tensor:$offset, - XLA_Tensor:$mean, - XLA_Tensor:$variance, + HLO_Tensor:$operand, + HLO_Tensor:$scale, + HLO_Tensor:$offset, + HLO_Tensor:$mean, + HLO_Tensor:$variance, F32Attr:$epsilon, I64Attr:$feature_index ); - let results = (outs - XLA_Tensor:$res - ); + let results = (outs HLO_Tensor); } -def XLA_BroadcastOp : XLA_Op<"broadcast", - [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Broadcast a tensor to a higher rank by prepending dimensions"; - - let description = [{ - Broadcasts the operand tensor to a higher rank by prepending - `broadcast_sizes` to the dimensions. The current values of the operand are - copied into the other dimensions. - - This is a more limited form of broadcasting, that corresponds to the XLA - client Broadcast method. For a more general form of broadcasting, see the - BroadcastInDimOp. - - See https://www.tensorflow.org/xla/operation_semantics#broadcast. - }]; - +def HLO_BroadcastOp : HLO_Op<"broadcast", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastOp { let arguments = (ins - XLA_Tensor:$operand, + HLO_Tensor:$operand, ElementsAttr:$broadcast_sizes ); - let results = (outs XLA_Tensor:$res); + let results = (outs HLO_Tensor); // TODO(b/129012527) These should be expressed as type constraints. 
let verifier = [{ @@ -546,7 +322,7 @@ def XLA_BroadcastOp : XLA_Op<"broadcast", "broadcast_sizes has rank {0} instead of rank 1", sizesRank)); } - auto resultType = res()->getType().cast(); + auto resultType = getResult()->getType().cast(); auto resultRank = resultType.getRank(); auto operandType = operand()->getType().cast(); auto operandRank = operandType.getRank(); @@ -560,12 +336,7 @@ def XLA_BroadcastOp : XLA_Op<"broadcast", resultRank, operandRank, sizesSize)); } - auto raw_sizes = sizes.getValues(); - llvm::SmallVector expectedShape(raw_sizes.begin(), - raw_sizes.end()); - if (sizes.isSplat()) { - expectedShape.resize(sizesSize, raw_sizes.front()); - } + llvm::SmallVector expectedShape(sizes.getValues()); auto operandShape = operandType.getShape(); expectedShape.insert(expectedShape.end(), operandShape.begin(), @@ -585,33 +356,14 @@ def XLA_BroadcastOp : XLA_Op<"broadcast", }]; } -def XLA_BroadcastInDimOp : XLA_Op<"broadcast_in_dim", - [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Broadcast a tensor into the given shape by adding dimensions."; - - let description = [{ - Broadcasts the `operand` tensor to a higher rank. This is not the limited - form of broadcasting exposed as the XLA client broadcast op, but rather the - more powerful "InDim" broadcasting, which is closer to the HLO broadcast op - and exposed in the XLA client BroadcastInDim method. - - `broadcast_dimensions` maps the operand dimension number to the target shape - dimension number. It must have the same size as the rank of the operand. The - mapped dimensions must either be the same size or the dimension being - broadcast from must be size 1 (degenerate broadcasting). - - For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The - The scalar value will be broadcast to every element in the target shape. - - See https://www.tensorflow.org/xla/broadcasting. - }]; - +def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastInDimOp { let arguments = (ins - XLA_Tensor:$operand, + HLO_Tensor:$operand, BroadcastDimAttr:$broadcast_dimensions ); - let results = (outs XLA_Tensor:$res); + let results = (outs HLO_Tensor); // TODO(b/129012527) These should be expressed as type constraints. let verifier = [{ @@ -649,7 +401,7 @@ def XLA_BroadcastInDimOp : XLA_Op<"broadcast_in_dim", dimensionsSize, operandRank)); } - auto resultType = res()->getType().cast(); + auto resultType = getResult()->getType().cast(); auto resultRank = resultType.getRank(); if (resultRank < operandRank) { return emitOpError( @@ -658,7 +410,7 @@ def XLA_BroadcastInDimOp : XLA_Op<"broadcast_in_dim", } for (int i = 0; i != dimensionsSize; ++i) { - auto dimIndex = dimensions.getValue(i).cast().getInt(); + auto dimIndex = dimensions.getValue(i); if (dimIndex >= resultRank) { return emitOpError( llvm::formatv("broadcast_dimensions contains invalid value {0} for " @@ -684,29 +436,15 @@ def XLA_BroadcastInDimOp : XLA_Op<"broadcast_in_dim", let hasCustomHLOConverter = 1; } -def XLA_ClampOp : XLA_Op<"clamp", - [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Clamp operator"; - - let description = [{ - Clamps an operand to within the range between a minimum and maximum value. - - Note: All three arrays must be the same shape. Alternatively, as a - restricted form of broadcasting, min and/or max can be a scalar (0D - tensor) of the element type of the tensor operand. - - See https://www.tensorflow.org/xla/operation_semantics#clamp. 
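// --- Illustrative sketch (editor's addition, not part of the diff) ---------
// Element-wise semantics of the clamp op described above: the operand is
// limited to the range [min, max]. A single-element version with hypothetical
// names, ignoring the scalar-broadcast variant mentioned in the description.
#include <algorithm>

float ClampSemantics(float min_value, float operand, float max_value) {
  return std::min(std::max(operand, min_value), max_value);
}

// Example: ClampSemantics(0.0f, 3.7f, 1.0f) == 1.0f.
// ---------------------------------------------------------------------------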
- }]; - +def HLO_ClampOp : HLO_Op<"clamp", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ClampOp { let arguments = (ins - XLA_Tensor:$min, - XLA_Tensor:$operand, - XLA_Tensor:$max + HLO_Tensor:$min, + HLO_Tensor:$operand, + HLO_Tensor:$max ); - let results = (outs - XLA_Tensor:$res - ); + let results = (outs HLO_Tensor); // TODO(b/129012527) These should be expressed as type constraints. let verifier = [{ @@ -739,18 +477,11 @@ def XLA_ClampOp : XLA_Op<"clamp", }]; } -def XLA_ConcatenateOp : XLA_Op<"concatenate", - [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "XLA's concantenate op"; - - let description = [{ - Concatenates a set of tensors along the specified dimension. - - See https://www.tensorflow.org/xla/operation_semantics#concatenate. - }]; +def HLO_ConcatenateOp : HLO_Op<"concatenate", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ConcatenateOp { let arguments = ( - ins Variadic:$val, + ins Variadic:$val, I64Attr: $dimension ); @@ -781,101 +512,72 @@ def XLA_ConcatenateOp : XLA_Op<"concatenate", return success(); }]; - let results = (outs XLA_Tensor:$res); + let results = (outs HLO_Tensor); // TODO(b/129422361) ConcatOp has special conversion logic to HLO. let hasCustomHLOConverter = 1; } -def XLA_ConvOp : XLA_Op<"conv", [NoSideEffect]> { - let summary = "Convolution operator"; - - let description = [{ - Computes a convolution of the kind used in neural networks. - - See https://www.tensorflow.org/xla/operation_semantics#conv_convolution. - }]; - +def HLO_ConvOp : HLO_Op<"conv", [NoSideEffect]>, BASE_HLO_ConvOp { let arguments = (ins - XLA_Tensor:$lhs, - XLA_Tensor:$rhs + HLO_Tensor:$lhs, + HLO_Tensor:$rhs ); - let results = (outs XLA_Tensor:$res); + let results = (outs HLO_Tensor); // TODO(b/129422361) Needs additional work to handle attributes. // Conv has custom handling because its other args are passed as attributes let hasCustomHLOConverter = 1; } -def XLA_CopyOp: XLA_UnaryElementwiseOp<"copy", [NoSideEffect, SameOperandsAndResultType]> { - let summary = "Copy operator"; +def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> { + string summary = "Copy operator"; - let description = [{ + string description = [{ Returns a copy of `operand`. }]; + let arguments = (ins HLO_Tensor); + let results = (outs HLO_Tensor); + // TODO(b/129422361) Implement special handling. // Copy has an HloOpcode, but is not one of the ops defined in xla_builder. let hasCustomHLOConverter = 1; } -def XLA_DotOp: XLA_Op<"dot", [NoSideEffect]> { +def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp { let arguments = ( - ins XLA_Tensor:$lhs, - XLA_Tensor:$rhs, - XLA_PrecisionConfigAttr:$precision_config + ins HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + HLO_PrecisionConfigAttr:$precision_config ); - let results = (outs XLA_Tensor:$res); - - let description = [{ - Performs dot products between vectors, vector/matrix and matrix/matrix - multiplication. - - See https://www.tensorflow.org/xla/operation_semantics#dot. 
- }]; + let results = (outs HLO_Tensor); } -def XLA_GatherOp: XLA_Op<"gather", [NoSideEffect]> { +def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { let arguments = ( - ins XLA_Tensor:$operand, - XLA_IntTensor:$start_indices, + ins HLO_Tensor:$operand, + HLO_IntTensor:$start_indices, I64Attr: $index_vector_dim, - ElementsAttr: $offsets_dim, + ElementsAttr: $offset_dims, ElementsAttr: $slice_sizes, - ElementsAttr: $collapsed_slice_sizes, + ElementsAttr: $collapsed_slice_dims, ElementsAttr: $start_index_map ); - let results = (outs XLA_Tensor:$res); - - let summary = "Gather operator"; - - let description = [{ - Stitches together several slices of an input array. - - See https://www.tensorflow.org/xla/operation_semantics#gather. - }]; + let results = (outs HLO_Tensor); // TODO(b/129422361) Attributes are not by the codegen. The optional argument // (dimensions) needs to be added as an attribute. let hasCustomHLOConverter = 1; } -def XLA_ReshapeOp: XLA_Op<"reshape", - [NoSideEffect, SameOperandsAndResultElementType]> { - let arguments = (ins XLA_Tensor:$operand); - - let results = (outs XLA_Tensor:$res); - - let summary = "Reshape operator"; - - let description = [{ - Reshapes the dimensions of `operand` into a new configuration. - - See https://www.tensorflow.org/xla/operation_semantics#reshape. - }]; +def HLO_ReshapeOp: HLO_Op<"reshape", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ReshapeOp { + let arguments = (ins HLO_Tensor:$operand); + let results = (outs HLO_Tensor); let hasFolder = 1; // TODO(b/129422361) One of the required arguments comes from the new shape, @@ -885,29 +587,14 @@ def XLA_ReshapeOp: XLA_Op<"reshape", } -def XLA_SelectOp: XLA_Op<"select", [NoSideEffect]> { - let summary = "Select operator"; - - let description = [{ - Constructs an output tensor from the elements of `on_true` and `on_false` - based on the values of `pred`. - - `on_true` and `on_false` must be the same shape. For each element of `pred`, - `res` has the corresponding element of `on_true` or `on_false` depending on - the value in `pred`. `pred` must be the same shape as `on_true` and - `on_false` or a scalar, in which case `res` is equal to either `on_true` or - `on_false`. - - See https://www.tensorflow.org/xla/operation_semantics#select. - }]; - +def HLO_SelectOp: HLO_Op<"select", [NoSideEffect]>, BASE_HLO_SelectOp { let arguments = (ins - XLA_PredTensor:$pred, - XLA_Tensor:$on_true, - XLA_Tensor:$on_false + HLO_PredTensor:$pred, + HLO_Tensor:$on_true, + HLO_Tensor:$on_false ); - let results = (outs XLA_Tensor:$res); + let results = (outs HLO_Tensor); // TODO(b/129012527) These should be expressed as type constraints. let verifier = [{ @@ -938,48 +625,30 @@ def XLA_SelectOp: XLA_Op<"select", [NoSideEffect]> { }]; } -def XLA_ReverseOp: XLA_Op<"reverse", - [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Reverse operator"; - - let description = [{ - Reverses the specified dimensions of `operand` according to the given - `dimensions`. - - See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. - }]; - +def HLO_ReverseOp: HLO_Op<"reverse", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ReverseOp { let arguments = (ins - XLA_Tensor:$operand, + HLO_Tensor:$operand, ElementsAttr:$dimensions ); - let results = (outs XLA_Tensor:$res); + let results = (outs HLO_Tensor); // TODO(b/129422361): ReverseOp has a custom constructor for HLO. 
let hasCustomHLOConverter = 1; } -def XLA_PadOp: XLA_Op<"pad", - [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Pad operator"; - - let description = [{ - Pads the edges of `operand` with the `padding_value` and according to - the passed configuration. - - See https://www.tensorflow.org/xla/operation_semantics#pad. - }]; - +def HLO_PadOp: HLO_Op<"pad", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PadOp { let arguments = (ins - XLA_Tensor:$operand, - XLA_Tensor:$padding_value, + HLO_Tensor:$operand, + HLO_Tensor:$padding_value, ElementsAttr: $edge_padding_low, ElementsAttr: $edge_padding_high, ElementsAttr: $interior_padding ); - let results = (outs XLA_Tensor:$res); + let results = (outs HLO_Tensor); let description = [{ Pads the `operand` according to TBD. @@ -1018,8 +687,8 @@ def XLA_PadOp: XLA_Op<"pad", for (int i = 0, e = input_shape.size(); i < e; i++) { int expected_output = input_shape[i] - + padding_low.getValue(i).cast().getInt() - + padding_high.getValue(i).cast().getInt(); + + padding_low.getValue(i).getInt() + + padding_high.getValue(i).getInt(); if (expected_output != output_shape[i]) { return emitOpError(llvm::formatv("Expected output shape ({0}) and " "output shape ({1}) should match.", @@ -1034,23 +703,15 @@ def XLA_PadOp: XLA_Op<"pad", let hasCustomHLOConverter = 1; } -def XLA_TransposeOp: XLA_Op<"transpose", - [NoSideEffect, SameOperandsAndResultElementType]> { - let summary = "Transpose operator"; - - let description = [{ - Permutes the dimensions of `operand` according to the given `permutation`. - - `res_dimensions[i] = operand_dimensions[permutation[i]]` - - See https://www.tensorflow.org/xla/operation_semantics#transpose. - }]; - +def HLO_TransposeOp: HLO_Op<"transpose", + [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_TransposeOp { let arguments = (ins - XLA_Tensor:$operand, + HLO_Tensor:$operand, ElementsAttr:$permutation ); - let results = (outs XLA_Tensor:$res); + let results = (outs HLO_Tensor); + + let hasFolder = 1; // TODO(b/129012527) These should be expressed as type constraints. let verifier = [{ @@ -1076,7 +737,7 @@ def XLA_TransposeOp: XLA_Op<"transpose", permutationSize, operandRank)); } - auto resultType = res()->getType().cast(); + auto resultType = getResult()->getType().cast(); auto resultRank = resultType.getRank(); if (resultRank != operandRank) { return emitOpError( @@ -1088,7 +749,7 @@ def XLA_TransposeOp: XLA_Op<"transpose", auto expectedShape = SmallVector(operandRank); for (int i = 0; i != operandRank; ++i) { - auto permutedDim = permutation().getValue(i).cast().getInt(); + auto permutedDim = permutation().getValue(i).getInt(); expectedShape[i] = operandType.getDimSize(permutedDim); } @@ -1105,4 +766,4 @@ def XLA_TransposeOp: XLA_Op<"transpose", }]; } -#endif // XLA_OPS +#endif // HLO_OPS diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td new file mode 100644 index 00000000000..28d6efd0aad --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td @@ -0,0 +1,528 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef HLO_OPS_BASE +#else +#define HLO_OPS_BASE + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def HLO_Int : IntOfWidths<[8, 16, 32, 64]>; +def HLO_Pred : TypeAlias; + +//===----------------------------------------------------------------------===// +// XLA nullary op definitions. +//===----------------------------------------------------------------------===// + +class BASE_HLO_ConstOp { + string summary = "Constant operator"; + + string description = [{ + Represents a constant value. + }]; +} + +class BASE_HLO_IotaOp { + string summary = "Iota operator"; + + string description = [{ + Creates a rank 1 array of values starting at zero and incrementing by one. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA unary elementwise op definitions. +//===----------------------------------------------------------------------===// +// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions + +class BASE_HLO_AbsOp { + string summary = "Absolute value operator"; + + string description = [{ + Returns `abs(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_ConvertOp { + string summary = "Convert operator"; + + string description = [{ + Performs element-wise conversion of values from one type to another, e.g. + float to int. + + See https://www.tensorflow.org/xla/operation_semantics#convertelementtype. + }]; +} + +class BASE_HLO_ExpOp { + string summary = "Exponential operator"; + + string description = [{ + Returns `e^(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_FloorOp { + string summary = "Floor operator"; + + string description = [{ + Returns `Floor(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_LogOp { + string summary = "Logarithm operator"; + + string description = [{ + Returns `log(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_NegOp { + string summary = "Negation operator"; + + string description = [{ + Returns `-operand` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_RsqrtOp { + string summary = "Reciprocal Square-root operator"; + + string description = [{ + Returns `1.0 / sqrt(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_SignOp { + string summary = "Sign operator"; + + string description = [{ + Returns `sign(operand)` element-wise, where + + ``` + sign(x) = -1 : x < 0 + = -0 : x = -0 + = NaN : x = NaN + = +0 : x = +0 + = 1 : x > 0 + ``` + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; +} + +class BASE_HLO_TanhOp { + string summary = "Tanh operator"; + + string description = [{ + Returns `tanh(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA binary elementwise op definitions. +//===----------------------------------------------------------------------===// + +// The broadcasting dimensions correspond to a tuple that describes how a +// smaller rank shape is broadcast into a larger rank shape. For example, +// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means +// matching the matrix to dimensions 1 and 2 of the cuboid. +def BroadcastDimAttr : OptionalAttr; + +class BASE_HLO_AddOp { + string summary = "Addition operator"; + + string description = [{ + Returns `lhs + rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_DivOp { + string summary = "Division operator"; + + string description = [{ + Returns `lhs / rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_MaxOp { + string summary = "Maximum operator"; + + string description = [{ + Returns `max(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_MinOp { + string summary = "Minimum operator"; + + string description = [{ + Returns `min(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_MulOp { + string summary = "Multiplication operator"; + + string description = [{ + Returns `lhs * rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_SubOp { + string summary = "Subtraction operator"; + + string description = [{ + Returns `lhs - rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_AndOp { + string summary = "Logical and"; + + string description = [{ + Returns `lhs /\ rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA control flow op definitions. +//===----------------------------------------------------------------------===// + +class BASE_HLO_ReduceOp { + string summary = "Reduce operator"; + + string description = [{ + Returns the result of executing a reduction function on one or more arrays + in parallel. + + See https://www.tensorflow.org/xla/operation_semantics#reduce. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA tuple op definitions. +//===----------------------------------------------------------------------===// +class BASE_HLO_GetTupleElementOp { + string summary = "GetTupleElement operator"; + + string description = [{ + Returns a member of a tuple specified by an index. + + See https://www.tensorflow.org/xla/operation_semantics#gettupleelement. + }]; +} + +class BASE_HLO_TupleOp { + string summary = "XLA's tuple op"; + + string description = [{ + Groups a set of tensor inputs into a single tuple object. 
+ + See https://www.tensorflow.org/xla/operation_semantics#tuple. + }]; +} + +//===----------------------------------------------------------------------===// +// Precision Config enum definitions. +//===----------------------------------------------------------------------===// + +// These mirror the XLA PrecisionConfig proto enum. +def HLO_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">; +def HLO_PRECISION_HIGH : StrEnumAttrCase<"HIGH">; +def HLO_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">; + +def HLO_PrecisionAttr : StrEnumAttr<"Precision", + "XLA precision for an operand. Has backend specific meaning.", + [HLO_PRECISION_DEFAULT, HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST]>; + +// TODO(b/129153247) See if it's possible to also validate the size. +def HLO_PrecisionConfigAttr: + OptionalAttr< + TypedArrayAttrBase>; + +//===----------------------------------------------------------------------===// +// Comparison op definitions. +//===----------------------------------------------------------------------===// + +// These mirror the XLA ComparisonDirection enum. +def HLO_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">; +def HLO_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">; +def HLO_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">; +def HLO_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">; +def HLO_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">; +def HLO_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">; + +def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection", + "Which comparison operation to perform.", + [ + HLO_COMPARISON_DIRECTION_EQ, + HLO_COMPARISON_DIRECTION_NE, + HLO_COMPARISON_DIRECTION_GE, + HLO_COMPARISON_DIRECTION_GT, + HLO_COMPARISON_DIRECTION_LE, + HLO_COMPARISON_DIRECTION_LT + ]>; + +class BASE_HLO_CompareOp { + string summary = "Comparison operator"; + + string description = [{ + Compares `lhs` and `rhs` elementwise according to `comparison_direction`. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA Slice definitions. +//===----------------------------------------------------------------------===// + +class BASE_HLO_SliceOp { + string summary = "Slice operator"; + + string description = [{ + Slices a portion of the `operand` into a new configuration. + + See https://www.tensorflow.org/xla/operation_semantics#slice. + }]; +} + +class BASE_HLO_DynamicUpdateSliceOp { + string summary = "Dynamic Update Slice operator"; + + string description = [{ + DynamicUpdateSlice generates a result which is the value of the input array + operand, with a slice update overwritten at start_indices. + + See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA Other op definitions. +//===----------------------------------------------------------------------===// + +class BASE_HLO_BatchNormInferenceOp { + string summary = "Batch Normalization for Inference"; + + string description = [{ + Normalizes an array across batch and spatial dimensions. + + See https://www.tensorflow.org/xla/operation_semantics#batchnorminference + }]; +} + +class BASE_HLO_BroadcastOp { + string summary = "Broadcast a tensor to a higher rank by prepending dimensions"; + + string description = [{ + Broadcasts the operand tensor to a higher rank by prepending + `broadcast_sizes` to the dimensions. 
The current values of the operand are + copied into the other dimensions. + + This is a more limited form of broadcasting, that corresponds to the XLA + client Broadcast method. For a more general form of broadcasting, see the + BroadcastInDimOp. + + See https://www.tensorflow.org/xla/operation_semantics#broadcast. + }]; +} + +class BASE_HLO_BroadcastInDimOp { + string summary = "Broadcast a tensor into the given shape by adding dimensions."; + + string description = [{ + Broadcasts the `operand` tensor to a higher rank. This is not the limited + form of broadcasting exposed as the XLA client broadcast op, but rather the + more powerful "InDim" broadcasting, which is closer to the HLO broadcast op + and exposed in the XLA client BroadcastInDim method. + + `broadcast_dimensions` maps the operand dimension number to the target shape + dimension number. It must have the same size as the rank of the operand. The + mapped dimensions must either be the same size or the dimension being + broadcast from must be size 1 (degenerate broadcasting). + + For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The + The scalar value will be broadcast to every element in the target shape. + + See https://www.tensorflow.org/xla/broadcasting. + }]; +} + +class BASE_HLO_ClampOp { + string summary = "Clamp operator"; + + string description = [{ + Clamps an operand to within the range between a minimum and maximum value. + + Note: All three arrays must be the same shape. Alternatively, as a + restricted form of broadcasting, min and/or max can be a scalar (0D + tensor) of the element type of the tensor operand. + + See https://www.tensorflow.org/xla/operation_semantics#clamp. + }]; +} + +class BASE_HLO_ConcatenateOp { + string summary = "XLA's concantenate op"; + + string description = [{ + Concatenates a set of tensors along the specified dimension. + + See https://www.tensorflow.org/xla/operation_semantics#concatenate. + }]; +} + +class BASE_HLO_ConvOp { + string summary = "Convolution operator"; + + string description = [{ + Computes a convolution of the kind used in neural networks. + + See https://www.tensorflow.org/xla/operation_semantics#conv_convolution. + }]; +} + +class BASE_HLO_DotOp { + string summary = "Dot operator"; + string description = [{ + Performs dot products between vectors, vector/matrix and matrix/matrix + multiplication. + + See https://www.tensorflow.org/xla/operation_semantics#dot. + }]; +} + +class BASE_HLO_GatherOp{ + string summary = "Gather operator"; + + string description = [{ + Stitches together several slices of an input array. + + See https://www.tensorflow.org/xla/operation_semantics#gather. + }]; +} + +class BASE_HLO_ReshapeOp { + string summary = "Reshape operator"; + + string description = [{ + Reshapes the dimensions of `operand` into a new configuration. + + See https://www.tensorflow.org/xla/operation_semantics#reshape. + }]; +} + +class BASE_HLO_SelectOp { + string summary = "Select operator"; + + string description = [{ + Constructs an output tensor from the elements of `on_true` and `on_false` + based on the values of `pred`. + + `on_true` and `on_false` must be the same shape. For each element of `pred`, + `res` has the corresponding element of `on_true` or `on_false` depending on + the value in `pred`. `pred` must be the same shape as `on_true` and + `on_false` or a scalar, in which case `res` is equal to either `on_true` or + `on_false`. + + See https://www.tensorflow.org/xla/operation_semantics#select. 
+ }]; +} + +class BASE_HLO_ReverseOp { + string summary = "Reverse operator"; + + string description = [{ + Reverses the specified dimensions of `operand` according to the given + `dimensions`. + + See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. + }]; +} + +class BASE_HLO_PadOp { + string summary = "Pad operator"; + + string description = [{ + Pads the edges of `operand` with the `padding_value` and according to + the passed configuration. + + See https://www.tensorflow.org/xla/operation_semantics#pad. + }]; +} + +class BASE_HLO_TransposeOp { + string summary = "Transpose operator"; + + string description = [{ + Permutes the dimensions of `operand` according to the given `permutation`. + + `res_dimensions[i] = operand_dimensions[permutation[i]]` + + See https://www.tensorflow.org/xla/operation_semantics#transpose. + }]; +} + +#endif // HLO_OPS_BASE diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc new file mode 100644 index 00000000000..312654ef320 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc @@ -0,0 +1,64 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the XLA dialect. + +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" + +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Dialect.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir +#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir +#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/IR/Value.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h.inc" + +namespace mlir { +namespace xla_lhlo { + +XlaLhloDialect::XlaLhloDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc.inc" + >(); +} + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.cc.inc" + +// TODO(cheshire): Support folding, reuse code from hlo_ops.cc. 
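+// For orientation: unlike the tensor-based xla_hlo ops, the xla_lhlo ops
+// declared in lhlo_ops.td take explicit output buffers as operands and return
+// no results. A hypothetical textual form (syntax assumed for illustration):
+//   "xla_lhlo.add"(%lhs, %rhs, %out)
+//       : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()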
+ +} // namespace xla_lhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h new file mode 100644 index 00000000000..f73e5026541 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the LXLA dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_ + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Dialect.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/OpDefinition.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/Support/Functional.h" // TF:local_config_mlir + +namespace mlir { +class OpBuilder; + +namespace xla_lhlo { + +class XlaLhloDialect : public Dialect { + public: + explicit XlaLhloDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "xla_lhlo"; } +}; + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h.inc" + +} // namespace xla_lhlo +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td new file mode 100644 index 00000000000..003247cca8c --- /dev/null +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -0,0 +1,323 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for LXLA. + +#ifdef LHLO_OPS +#else +#define LHLO_OPS + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +#ifdef HLO_OPS_BASE +#else +include "tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td" +#endif + +def LHLO_Dialect : Dialect { + let name = "xla_lhlo"; + let cppNamespace = "xla_lhlo"; +} + +//===----------------------------------------------------------------------===// +// XLA type definitions. 
+//===----------------------------------------------------------------------===// + +// Any integer tensor types +def LHLO_IntBuffer : StaticShapeMemRefOf<[HLO_Int]>; + +// Any floating-point tensor types +def LHLO_FpBuffer : StaticShapeMemRefOf<[AnyFloat]>; + + +def LHLO_PredBuffer : StaticShapeMemRefOf<[HLO_Pred]>; + +// Any integer or floating-point tensor types +def LHLO_IntOrFpBuffer : StaticShapeMemRefOf<[HLO_Int, AnyFloat]>; + +def LHLO_Buffer : StaticShapeMemRefOf<[AnyFloat, AnyInteger]>; + +def LHLO_TupleBuffer : NestedTupleOf<[LHLO_Buffer]>; + +def LHLO_BufferOrTuple : AnyTypeOf<[LHLO_Buffer, LHLO_TupleBuffer]>; + +//===----------------------------------------------------------------------===// +// XLA nullary op definitions. +//===----------------------------------------------------------------------===// + +class LHLO_Op traits> : Op; + +def LHLO_ConstOp : BASE_HLO_ConstOp, LHLO_Op<"constant", []> { + let arguments = (ins + ElementsAttr:$value, + LHLO_Buffer:$output + ); +} + +def LHLO_IotaOp : BASE_HLO_IotaOp, LHLO_Op<"iota", []> { + let arguments = (ins I64Attr:$iota_dimension, + LHLO_Buffer:$output); +} + +//===----------------------------------------------------------------------===// +// XLA unary elementwise op definitions. +//===----------------------------------------------------------------------===// +// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions + +class LHLO_UnaryElementwiseOp : + LHLO_Op { + let arguments = (ins LHLO_Buffer:$input, + LHLO_Buffer:$output); +} + +def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp; + +def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert">, BASE_HLO_ConvertOp; + +def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exp">, BASE_HLO_ExpOp; + +def LHLO_NegOp: LHLO_UnaryElementwiseOp<"neg">, BASE_HLO_NegOp; + +def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp; + +def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp; + +//===----------------------------------------------------------------------===// +// XLA binary elementwise op definitions. +//===----------------------------------------------------------------------===// + +class LHLO_BinaryElementwiseOp traits> : + LHLO_Op { + let arguments = (ins + LHLO_Buffer:$lhs, + LHLO_Buffer:$rhs, + LHLO_Buffer:$out, + BroadcastDimAttr:$broadcast_dimensions + ); +} + +def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add", []>, BASE_HLO_AddOp; + +def LHLO_DivOp : LHLO_BinaryElementwiseOp<"div", []>, BASE_HLO_DivOp; + +def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"max", []>, BASE_HLO_MaxOp; + +def LHLO_MinOp : LHLO_BinaryElementwiseOp<"min", []>, BASE_HLO_MinOp; + +def LHLO_MulOp : LHLO_BinaryElementwiseOp<"mul", []>, BASE_HLO_MulOp; + +def LHLO_SubOp : LHLO_BinaryElementwiseOp<"sub", []>, BASE_HLO_SubOp; + +def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", []>, BASE_HLO_AndOp; + +//===----------------------------------------------------------------------===// +// XLA control flow op definitions. +//===----------------------------------------------------------------------===// + +// TODO(b/139813999): specify required function signature in a type-safe way. +def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp { + let arguments = (ins + Variadic:$operands_and_init, + Variadic:$out, + SymbolRefAttr:$computation, + ElementsAttr:$dimensions + ); +} +//===----------------------------------------------------------------------===// +// XLA tuple op definitions. 
+//===----------------------------------------------------------------------===// + +def LHLO_GetTupleElementOp: LHLO_Op<"get_tuple_element", []>, BASE_HLO_GetTupleElementOp { + let arguments = (ins + LHLO_TupleBuffer:$input, + LHLO_BufferOrTuple:$out, + I32Attr:$index + ); +} + +def LHLO_TupleOp : LHLO_Op<"tuple", []>, BASE_HLO_TupleOp { + let arguments = (ins + Variadic:$val, + LHLO_TupleBuffer:$out); +} + +def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { + let arguments = (ins + LHLO_Buffer:$lhs, + LHLO_Buffer:$rhs, + LHLO_PredBuffer:$out, + BroadcastDimAttr:$broadcast_dimensions, + HLO_ComparisonDirectionAttr:$comparison_direction + ); +} + +//===----------------------------------------------------------------------===// +// XLA Slice definitions. +//===----------------------------------------------------------------------===// + +def LHLO_SliceOp: LHLO_Op< + "slice", + [AllTypesMatch<["start_indices", "limit_indices"]>]> { + let arguments = (ins + LHLO_Buffer:$operand, + LHLO_Buffer:$output, + ElementsAttr:$start_indices, + ElementsAttr:$limit_indices + ); +} + +def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> { + let arguments = (ins + LHLO_Buffer:$operand, + LHLO_Buffer:$update, + LHLO_Buffer:$output, + Variadic:$start_indices + ); +} + +//===----------------------------------------------------------------------===// +// XLA Other op definitions. +//===----------------------------------------------------------------------===// + +def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>, + BASE_HLO_BatchNormInferenceOp { + + let arguments = (ins + LHLO_Buffer:$operand, + LHLO_Buffer:$scale, + LHLO_Buffer:$offset, + LHLO_Buffer:$mean, + LHLO_Buffer:$variance, + LHLO_Buffer:$output, + F32Attr:$epsilon, + I64Attr:$feature_index + ); +} + +def LHLO_BroadcastOp : LHLO_Op<"broadcast", + []>, BASE_HLO_BroadcastOp { + let arguments = (ins + LHLO_Buffer:$operand, + LHLO_Buffer:$output, + ElementsAttr:$broadcast_sizes + ); +} + +def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim", + []>, BASE_HLO_BroadcastInDimOp { + let arguments = (ins + LHLO_Buffer:$operand, + LHLO_Buffer:$output, + BroadcastDimAttr:$broadcast_dimensions + ); +} + +def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp { + let arguments = (ins + LHLO_Buffer:$min, + LHLO_Buffer:$operand, + LHLO_Buffer:$max, + LHLO_Buffer:$output + ); +} + +def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp { + let arguments = (ins + Variadic:$val, + LHLO_Buffer:$output, + I64Attr: $dimension + ); +} + +def LHLO_ConvOp : LHLO_Op<"conv", []>, BASE_HLO_ConvOp { + let arguments = (ins + LHLO_Buffer:$lhs, + LHLO_Buffer:$rhs, + LHLO_Buffer:$output + ); +} + +def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp { + let arguments = (ins + LHLO_Buffer:$lhs, + LHLO_Buffer:$rhs, + HLO_PrecisionConfigAttr:$precision_config, + LHLO_Buffer:$output + ); +} + +def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp { + let arguments = (ins + LHLO_Buffer:$operand, + LHLO_IntBuffer:$start_indices, + I64Attr: $index_vector_dim, + ElementsAttr: $offset_dims, + ElementsAttr: $slice_sizes, + ElementsAttr: $collapsed_slice_dims, + ElementsAttr: $start_index_map, + LHLO_Buffer:$output + ); +} + +def LHLO_ReshapeOp: LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp { + let arguments = (ins + LHLO_Buffer:$operand, + LHLO_Buffer:$output + ); +} + + +def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp { + let arguments = (ins + LHLO_PredBuffer:$pred, + LHLO_Buffer:$on_true, + 
LHLO_Buffer:$on_false, + LHLO_Buffer:$output + ); +} + +def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp { + let arguments = (ins + LHLO_Buffer:$operand, + ElementsAttr:$dimensions, + LHLO_Buffer:$output + ); +} + +def LHLO_PadOp: LHLO_Op<"pad", []>, BASE_HLO_PadOp { + let arguments = (ins + LHLO_Buffer:$operand, + LHLO_Buffer:$padding_value, + ElementsAttr: $edge_padding_low, + ElementsAttr: $edge_padding_high, + ElementsAttr: $interior_padding, + LHLO_Buffer: $output + ); +} + +def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp { + let arguments = (ins + LHLO_Buffer:$operand, + ElementsAttr:$permutation, + LHLO_Buffer:$output + ); +} + + +#endif // LHLO_OPS diff --git a/tensorflow/compiler/mlir/xla/ir/xla_ops.cc b/tensorflow/compiler/mlir/xla/ir/xla_ops.cc deleted file mode 100644 index 25da9da3d1d..00000000000 --- a/tensorflow/compiler/mlir/xla/ir/xla_ops.cc +++ /dev/null @@ -1,198 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file defines the operations used in the XLA dialect. - -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" - -#include "llvm/Support/FormatVariadic.h" -#include "mlir/IR/Attributes.h" // TF:local_config_mlir -#include "mlir/IR/Builders.h" // TF:local_config_mlir -#include "mlir/IR/OpImplementation.h" // TF:local_config_mlir -#include "mlir/Support/TypeUtilities.h" // TF:local_config_mlir - -using namespace mlir; -using namespace mlir::XLA; - -XLADialect::XLADialect(MLIRContext* context) - : Dialect(getDialectNamespace(), context) { - addOperations< -#define GET_OP_LIST -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.cc.inc" - >(); - - // Support unknown operations because not all XLA operations are registered. - allowUnknownOperations(); -} - -Operation* XLADialect::materializeConstant(OpBuilder& builder, Attribute value, - Type type, Location loc) { - // If this is an opaque elements attribute, then generate an xla.constant. - if (value.isa()) - return builder.create(loc, type, value.cast()); - return nullptr; -} - -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.cc.inc" - -//===----------------------------------------------------------------------===// -// ConstOp -//===----------------------------------------------------------------------===// - -OpFoldResult ConstOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); - - // Return the held attribute value. - return value(); -} - -// Builds a constant op with the specified attribute `value`. -void ConstOp::build(Builder* builder, OperationState* result, Attribute value) { - Type type; - if (auto elemAttr = value.dyn_cast()) { - type = elemAttr.getType(); - } else if (value.isa() || value.isa() || - value.isa()) { - // All XLA types must be tensor types. In the build() method, we want to - // provide more flexiblity by allowing attributes of scalar types. 
But we - // need to wrap it up with ElementsAttr to construct valid XLA constants. - type = RankedTensorType::get(/*shape=*/{}, value.getType()); - value = DenseElementsAttr::get(type.cast(), value); - } - - // TODO: support other XLA specific types. - assert(type && "unsupported attribute type for building xla.constant"); - result->types.push_back(type); - result->addAttribute("value", value); -} - -//===----------------------------------------------------------------------===// -// ConvertOp -//===----------------------------------------------------------------------===// - -OpFoldResult ConvertOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "convert must take one operand"); - auto operand = operands[0]; - - if (!operand) return {}; - - if (auto elementsAttr = operand.dyn_cast()) { - auto inType = elementsAttr.getType(); - auto outType = getResult()->getType().cast(); - - if (inType == outType) { - return operand; - } - - auto inElement = inType.getElementType(); - auto outElement = outType.getElementType(); - size_t bitWidth = - outElement.isBF16() ? 64 : outElement.getIntOrFloatBitWidth(); - - if (inElement.isa()) { - if (outElement.isa()) { - auto func = [&](const APFloat& floatValue) -> APInt { - return APInt(bitWidth, FloatAttr::getValueAsDouble(floatValue)); - }; - llvm::function_ref func_ref = func; - return elementsAttr.mapValues(outType.getElementType(), func_ref); - } - - if (outElement.isa()) { - auto& semantics = outElement.cast().getFloatSemantics(); - auto func = [&](const APFloat& floatValue) -> APInt { - APFloat newDouble(FloatAttr::getValueAsDouble(floatValue)); - bool losesInfo = false; - newDouble.convert(semantics, llvm::APFloat::rmNearestTiesToEven, - &losesInfo); - return newDouble.bitcastToAPInt(); - }; - llvm::function_ref func_ref = func; - return elementsAttr.mapValues(outType.getElementType(), func_ref); - } - } - - if (inElement.isa()) { - if (outElement.isa()) { - auto func = [&](const APInt& val) -> APInt { - return APInt(bitWidth, val.getLimitedValue()); - }; - llvm::function_ref func_ref = func; - return elementsAttr.mapValues(outType.getElementType(), func_ref); - } - - if (outElement.isa()) { - auto& semantics = outElement.cast().getFloatSemantics(); - auto func = [&](const APInt& val) -> APInt { - APFloat newDouble(static_cast(val.getLimitedValue())); - bool losesInfo = false; - newDouble.convert(semantics, llvm::APFloat::rmNearestTiesToEven, - &losesInfo); - return newDouble.bitcastToAPInt(); - }; - llvm::function_ref func_ref = func; - return elementsAttr.mapValues(outType.getElementType(), func_ref); - } - } - } - - return {}; -} - -//===----------------------------------------------------------------------===// -// IotaOp -//===----------------------------------------------------------------------===// - -OpFoldResult IotaOp::fold(ArrayRef operands) { - const auto output_type = getResult()->getType().cast(); - const auto output_size = output_type.getNumElements(); - const auto dimension = iota_dimension().getLimitedValue(); - const auto max_dim_size = output_type.getDimSize(dimension); - int bitwidth = output_type.getElementType().getIntOrFloatBitWidth(); - - llvm::SmallVector values; - values.reserve(output_size); - - int64_t increase_stride = output_size; - for (int i = 0; i <= dimension; i++) { - increase_stride /= output_type.getDimSize(i); - } - - int64_t current_value = 0; - for (int i = 0; i < output_size; i++) { - int64_t value = (current_value / increase_stride) % max_dim_size; - values.push_back(APInt(bitwidth, value)); - 
++current_value; - } - - return DenseIntElementsAttr::get(output_type, values); -} - -//===----------------------------------------------------------------------===// -// ReshapeOp -//===----------------------------------------------------------------------===// - -OpFoldResult ReshapeOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "convert must take one operand"); - auto operand = operands[0]; - if (!operand) return {}; - - if (auto elements = operand.dyn_cast()) { - return elements.reshape(getResult()->getType().cast()); - } - - return {}; -} diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 2ec1324a1cf..230044d538b 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -22,14 +22,14 @@ limitations under the License. #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Attributes.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/Location.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/Module.h" // TF:local_config_mlir #include "mlir/IR/Operation.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/comparison_util.h" @@ -37,11 +37,11 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +using tensorflow::int64; + static std::vector ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) { - llvm::ArrayRef raw_data = attr.getValues(); - if (attr.isSplat()) - return std::vector(attr.getType().getNumElements(), raw_data[0]); - return raw_data; + auto values = attr.getValues(); + return {values.begin(), values.end()}; } // Converts the broadcast_dimensions attribute into a span of dimension numbers @@ -154,7 +154,7 @@ class ConvertToHloModule { // if an error was encountered. LogicalResult RunOnFunction(mlir::FuncOp f); - xla::HloModuleProto ConsumeMainProto() { + ::xla::HloModuleProto ConsumeMainProto() { return lowered_computation_[module_.lookupSymbol("main")] .proto(); } @@ -176,8 +176,8 @@ LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder, if (auto xla_op = CreateXlaOperator(inst, value_lowering)) return success(); // TODO(riverriddle) We currently don't support lowering constant operations. - if (isa(inst)) { - inst->emitError("unable to lower 'xla.constant' operation"); + if (isa(inst)) { + inst->emitError("unable to lower 'xla_hlo.constant' operation"); return failure(); } diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 0fb315b90f9..6aecf70b385 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "llvm/ADT/StringExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/InitLLVM.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Main.h" @@ -51,8 +51,8 @@ static std::string GetConversionFunction( return "Convert_" + named_attr.name.str(); } -using ArgumentName = string; -using ArgumentDeclaration = string; +using ArgumentName = std::string; +using ArgumentDeclaration = std::string; using Argument = std::pair; using ArgumentList = std::vector; @@ -63,7 +63,7 @@ static std::string BuildOperator(const Operator& op) { // Signature. os << "static xla::XlaOp " << GetOperatorBuilderName(op_name) - << "(mlir::XLA::" << op_name.str() << " xla_op, " + << "(mlir::xla_hlo::" << op_name.str() << " xla_op, " << "llvm::DenseMap* " "value_lowering) {\n"; @@ -148,7 +148,7 @@ static void EmitBuilder(const std::vector& defs, StringRef op_name = def->getName().drop_front(4); // Try to cast to each op and call the corresponding op builder. - os << " if (auto xla_op = llvm::dyn_cast(op))\n return " << GetOperatorBuilderName(op_name) << "(xla_op, value_lowering);\n"; } @@ -163,17 +163,17 @@ static void EmitBuilder(const std::vector& defs, static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { emitSourceFileHeader("MLIR XLA Builders", os); - // Retrieve all the definitions derived from XLA_Op and sort by record name. - std::vector defs = records.getAllDerivedDefinitions("XLA_Op"); + // Retrieve all the definitions derived from HLO_Op and sort by record name. + std::vector defs = records.getAllDerivedDefinitions("HLO_Op"); llvm::sort(defs, LessRecord()); for (const auto* def : defs) { // XLA ops in the .td file are expected to follow the naming convention: - // XLA_Op. - // The generated XLA op C++ class should be XLA::Op. - if (!def->getName().startswith("XLA_")) + // HLO_Op. + // The generated XLA op C++ class should be HLO::Op. 
+ if (!def->getName().startswith("HLO_")) PrintFatalError(def->getLoc(), - "unexpected op name format: 'XLA_' prefix missing"); + "unexpected op name format: 'HLO_' prefix missing"); if (!def->getName().endswith("Op")) PrintFatalError(def->getLoc(), "unexpected op name format: 'Op' suffix missing"); @@ -187,10 +187,7 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { } int main(int argc, char** argv) { - llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); - llvm::PrettyStackTraceProgram X(argc, argv); - - llvm::llvm_shutdown_obj Y; + llvm::InitLLVM y(argc, argv); llvm::cl::ParseCommandLineOptions(argc, argv); return TableGenMain(argv[0], &OperatorWritersMain); } diff --git a/tensorflow/compiler/mlir/xla/tests/convert.mlir b/tensorflow/compiler/mlir/xla/tests/convert.mlir index 93de3b30ec0..76cdab37a4e 100644 --- a/tensorflow/compiler/mlir/xla/tests/convert.mlir +++ b/tensorflow/compiler/mlir/xla/tests/convert.mlir @@ -1,218 +1,203 @@ -// RUN: tf-opt %s -split-input-file -xla-legalize-to-std | FileCheck %s +// RUN: tf-opt %s -split-input-file -canonicalize | FileCheck %s // ----- -// CHECK-LABEL: func @convert.1(%arg0: tensor) -> tensor { -func @convert.1(%arg0: tensor) -> tensor { - // CHECK-NEXT: %0 = "xla.convert"(%arg0) : (tensor) -> tensor - %0 = "xla.convert"(%arg0) : (tensor) -> tensor - // CHECK-NEXT: return %0 : tensor +// CHECK-LABEL: func @same_type +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @same_type(%arg: tensor) -> tensor { + %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[ARG]] return %0 : tensor } // ----- -// CHECK-LABEL: func @convert.2(%arg0: tensor) -> tensor { -func @convert.2(%arg0: tensor) -> tensor { - // CHECK-NEXT: %0 = "xla.convert"(%arg0) : (tensor) -> tensor - %0 = "xla.convert"(%arg0) : (tensor) -> tensor - // CHECK-NEXT: return %0 : tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @convert.3(%arg0: tensor) -> tensor { -func @convert.3(%arg0: tensor) -> tensor { - // CHECK-NEXT: %0 = "xla.convert"(%arg0) : (tensor) -> tensor - %0 = "xla.convert"(%arg0) : (tensor) -> tensor - // CHECK-NEXT: return %0 : tensor +// CHECK-LABEL: func @int_widening +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @int_widening(%arg: tensor) -> tensor { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] return %0 : tensor } // ----- -// CHECK-LABEL: func @convert.4(%arg0: tensor) -> tensor { -func @convert.4(%arg0: tensor) -> tensor { - // CHECK-NEXT: %0 = "xla.convert"(%arg0) : (tensor) -> tensor - %0 = "xla.convert"(%arg0) : (tensor) -> tensor - // CHECK-NEXT: return %0 : tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @convert.5(%arg0: tensor) -> tensor { -func @convert.5(%arg0: tensor) -> tensor { - // CHECK-NEXT: %0 = "xla.convert"(%arg0) : (tensor) -> tensor - %0 = "xla.convert"(%arg0) : (tensor) -> tensor - // CHECK-NEXT: return %0 : tensor - return %0 : tensor -} - -// ----- - - -// CHECK-LABEL: func @convert.const.1() -> tensor { -func @convert.const.1() -> tensor { - // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor - %cst = constant dense<42.0> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor - return %0 : tensor -} - -// ----- - -// check-label: func @convert.const.2() -> tensor { -func @convert.const.2() -> tensor { - // check-next: %cst = constant dense<42> : tensor - %cst = constant dense<42> : tensor - %0 = 
"xla.convert"(%cst) : (tensor) -> tensor - // check-next: return %cst : tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @convert.const.3() -> tensor { -func @convert.const.3() -> tensor { - // CHECK-NEXT: %cst = constant dense<42> : tensor - %cst = constant dense<42.0> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @convert.const.4() -> tensor { -func @convert.const.4() -> tensor { - // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor - %cst = constant dense<42> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @convert.const.5() -> tensor { -func @convert.const.5() -> tensor { - // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor - %cst = constant dense<42> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @convert.const.6() -> tensor { -func @convert.const.6() -> tensor { - // CHECK-NEXT: %cst = constant dense<42> : tensor - %cst = constant dense<42.0> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor +// CHECK-LABEL: func @int_narrowing +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @int_narrowing(%arg: tensor) -> tensor { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] return %0 : tensor } // ----- -// CHECK-LABEL: func @convert.const.7() -> tensor { -func @convert.const.7() -> tensor { - // CHECK-NEXT: %cst = constant dense<42> : tensor - %cst = constant dense<42> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor +// CHECK-LABEL: func @float_int +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @float_int(%arg: tensor) -> tensor { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] return %0 : tensor } // ----- -// CHECK-LABEL: func @convert.const.8() -> tensor { -func @convert.const.8() -> tensor { - // CHECK-NEXT: %cst = constant dense<42> : tensor - %cst = constant dense<42> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @convert.const.9() -> tensor { -func @convert.const.9() -> tensor { - // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor - %cst = constant dense<42.0> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor +// CHECK-LABEL: func @int_float +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @int_float(%arg: tensor) -> tensor { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] return %0 : tensor } // ----- -// CHECK-LABEL: func @convert.const.9() -> tensor { -func @convert.const.9() -> tensor { - // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor - %cst = constant dense<42.0> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor +// CHECK-LABEL: func @high_rank_tensor +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> { + // CHECK-NEXT: [[RES:%.+]] = 
"xla_hlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32> + %0 = "xla_hlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32> + // CHECK-NEXT: return [[RES]] + return %0 : tensor<2x3xf32> +} + +// ----- + + +// CHECK-LABEL: func @const_same_type +func @const_same_type() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = constant dense<42> : tensor + %cst = constant dense<42> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_float_int +func @const_float_int() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = constant dense<42> : tensor + %cst = constant dense<42.0> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_int_float +func @const_int_float() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = constant dense<4.{{0*}}e+00> : tensor + %cst = constant dense<4> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_int_bf16 +func @const_int_bf16() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = constant dense<4.{{0*}}e+00> : tensor + %cst = constant dense<4> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] return %0 : tensor } // ----- -// CHECK-LABEL: func @convert.const.10() -> tensor { -func @convert.const.10() -> tensor { - // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor - %cst = constant dense<42.0> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor - return %0 : tensor +// CHECK-LABEL: func @const_bf16_int +func @const_bf16_int() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = constant dense<42> : tensor + %cst = constant dense<42.0> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor } // ----- -// CHECK-LABEL: func @convert.const.11() -> tensor { -func @convert.const.11() -> tensor { - // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor - %cst = constant dense<42.0> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor - return %0 : tensor +// CHECK-LABEL: func @const_int_narrowing +func @const_int_narrowing() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = constant dense<42> : tensor + %cst = constant dense<42> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor } - // ----- -// CHECK-LABEL: func @convert.const.12() -> tensor { -func @convert.const.12() -> tensor { - // CHECK-NEXT: %cst = constant dense<42> : tensor - %cst = constant dense<42.0> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor +// CHECK-LABEL: func @const_int_widening +func @const_int_widening() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = constant dense<42> : tensor + %cst = constant dense<42> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] return %0 : tensor } // ----- -// CHECK-LABEL: func @convert.const.13() -> tensor { -func @convert.const.13() -> tensor { - // CHECK-NEXT: %cst = constant dense<42> : tensor - %cst = constant dense<42> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor - return %0 : tensor +// CHECK-LABEL: func @const_float_narrowing +func @const_float_narrowing() -> tensor { + // CHECK-NEXT: [[CST:%.+]] 
= constant dense<4.2{{0*}}e+00> : tensor + %cst = constant dense<4.2> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor } // ----- -// CHECK-LABEL: func @convert.const.14() -> tensor { -func @convert.const.14() -> tensor { - // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor - %cst = constant dense<42> : tensor - %0 = "xla.convert"(%cst) : (tensor) -> tensor - // CHECK-NEXT: return %cst : tensor +// CHECK-LABEL: func @const_f32_bf16 +func @const_f32_bf16() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = constant dense<4.2{{0*}}e+01> : tensor + %cst = constant dense<42.0> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_bf16_f64 +func @const_bf16_f64() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = constant dense<4.2{{0*}}e+00> : tensor + %cst = constant dense<4.2> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] return %0 : tensor } + +// ----- + +// CHECK-LABEL: func @const_bf16_int +func @const_bf16_int() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = constant dense<42> : tensor + %cst = constant dense<42.0> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + + +// ----- + +// CHECK-LABEL: func @const_high_rank_tensor +func @const_high_rank_tensor() -> tensor<2x3xi32> { + // CHECK-NEXT: [[CST:%.+]] = constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6] + // CHECK-SAME: ]> : tensor<2x3xi32> + %cst = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> + %0 = "xla_hlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<2x3xi32> +} + diff --git a/tensorflow/compiler/mlir/xla/tests/iota.mlir b/tensorflow/compiler/mlir/xla/tests/iota.mlir index 10559a4bfe8..46e0984cd77 100644 --- a/tensorflow/compiler/mlir/xla/tests/iota.mlir +++ b/tensorflow/compiler/mlir/xla/tests/iota.mlir @@ -5,7 +5,7 @@ // CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> { func @iota.const.1() -> tensor<4xi32> { // CHECK-NEXT: %cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32> - %0 = "xla.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> // CHECK-NEXT: return %cst : tensor<4xi32> return %0 : tensor<4xi32> } @@ -15,7 +15,7 @@ func @iota.const.1() -> tensor<4xi32> { // CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> { func @iota.const.2() -> tensor<2x4xi32> { // CHECK-NEXT: %cst = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32> - %0 = "xla.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32> // CHECK-NEXT: return %cst : tensor<2x4xi32> return %0 : tensor<2x4xi32> } @@ -25,7 +25,7 @@ func @iota.const.2() -> tensor<2x4xi32> { // CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> { func @iota.const.3() -> tensor<2x4xi32> { // CHECK-NEXT: %cst = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32> - %0 = "xla.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32> // CHECK-NEXT: return %cst : tensor<2x4xi32> return %0 : tensor<2x4xi32> } @@ -35,7 +35,7 @@ func @iota.const.3() -> tensor<2x4xi32> { // CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> { func @iota.const.4() -> tensor<2x3x4xi32> { 
// CHECK-NEXT: %cst = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32> - %0 = "xla.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %cst : tensor<2x3x4xi32> return %0 : tensor<2x3x4xi32> } @@ -45,7 +45,7 @@ func @iota.const.4() -> tensor<2x3x4xi32> { // CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> { func @iota.const.5() -> tensor<2x3x4xi32> { // CHECK-NEXT: %cst = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32> - %0 = "xla.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %cst : tensor<2x3x4xi32> return %0 : tensor<2x3x4xi32> } @@ -55,7 +55,7 @@ func @iota.const.5() -> tensor<2x3x4xi32> { // CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> { func @iota.const.6() -> tensor<2x3x4xi32> { // CHECK-NEXT: %cst = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32> - %0 = "xla.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32> // CHECK-NEXT: return %cst : tensor<2x3x4xi32> return %0 : tensor<2x3x4xi32> } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir index 74dd0034283..92d9c3530fc 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-control-flow.mlir @@ -2,16 +2,16 @@ // CHECK-LABEL: func @cond(%arg0: tensor) -> tensor { func @cond(%arg0: tensor) -> tensor { - // CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor - %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor + // CHECK-NEXT: %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor // CHECK-NEXT: return %0 : tensor return %0 : tensor } // CHECK-LABEL: func @loop(%arg0: tensor) -> tensor { func @loop(%arg0: tensor) -> tensor { - // CHECK-NEXT: %0 = xla.add %arg0, %arg0 {name = "compare.0"} : tensor - %0 = "xla.add"(%arg0, %arg0) {name = "compare.0"} : (tensor, tensor) -> tensor + // CHECK-NEXT: %0 = xla_hlo.add %arg0, %arg0 {name = "compare.0"} : tensor + %0 = "xla_hlo.add"(%arg0, %arg0) {name = "compare.0"} : (tensor, tensor) -> tensor // CHECK-NEXT: return %0 : tensor return %0 : tensor } @@ -27,7 +27,7 @@ func @main(%arg0: tensor) -> tensor { // CHECK-NEXT: %4 = call @loop(%3) : (tensor) -> tensor // CHECK-NEXT: br ^bb1(%4 : tensor) // CHECK-NEXT: b3(%5: tensor): // pred: ^bb1 - %0 = "xla.while"(%arg0) {body = @loop, cond = @cond} : (tensor) -> tensor + %0 = "xla_hlo.while"(%arg0) {body = @loop, cond = @cond} : (tensor) -> tensor // CHECK-NEXT: return %5 : tensor return %0 : tensor } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 69be9789818..5b45862a2b3 100644 --- 
a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -xla-legalize-tf %s | FileCheck %s +// RUN: tf-opt -xla-legalize-tf %s | FileCheck %s --dump-input-on-failure //===----------------------------------------------------------------------===// // BatchNorm op legalizations. @@ -6,7 +6,7 @@ // CHECK-LABEL: fusedBatchNorm_notraining func @fusedBatchNorm_notraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { - // CHECK-NEXT: "xla.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: "xla_hlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> %0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) return %0#0 : tensor<8x8x8x8xf32> } @@ -25,14 +25,14 @@ func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, // CHECK-LABEL: func @biasAdd_NHWC func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} + // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } // CHECK-LABEL: func @biasAdd_NCHW func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } @@ -42,14 +42,14 @@ func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tens // CHECK-LABEL: func @biasAdd_NHWC_invalid func @biasAdd_NHWC_invalid(%arg0: tensor<1x32x10x2xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x2xi32> { - // CHECK-NOT: xla.add + // CHECK-NOT: xla_hlo.add %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x2xi32>, tensor<32xi32>) -> tensor<1x32x10x2xi32> return %0 : tensor<1x32x10x2xi32> } // CHECK-LABEL: func @biasAdd_NCHW_invalid func @biasAdd_NCHW_invalid(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x10x10x32xi32> { - // CHECK-NOT: xla.add + // CHECK-NOT: xla_hlo.add %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x10x10x32xi32>, tensor<32xi32>) -> tensor<1x10x10x32xi32> return %0 : tensor<1x10x10x32xi32> } @@ -60,29 +60,31 @@ func @biasAdd_NCHW_invalid(%arg0: tensor<1x10x10x32xi32>, %arg1: tensor<32xi32>) 
// CHECK-LABEL: func @add func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla.add %arg0, %arg0 : tensor<2xi32> - // CHECK-NEXT: return %0 : tensor<2xi32> + // CHECK-NEXT: %[[SUM0:.*]] = xla_hlo.add %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %[[SUM1:.*]] = xla_hlo.add %[[SUM0]], %arg0 : tensor<2xi32> + // CHECK-NEXT: return %[[SUM1]] : tensor<2xi32> %0 = "tf.Add"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> - return %0: tensor<2xi32> + %1 = "tf.AddV2"(%0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> + return %1: tensor<2xi32> } // CHECK-LABEL: func @broadcast_add func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } // CHECK-LABEL: func @broadcast_multi_dim_add func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - // CHECK-NEXT: "xla.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} + // CHECK-NEXT: "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0: tensor<4x4x4x4xi32> } // CHECK-LABEL: func @div func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla.div %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = xla_hlo.div %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> %0 = "tf.Div"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> @@ -90,14 +92,14 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-LABEL: func @broadcast_div func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.Div"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } // CHECK-LABEL: func @mul func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla.mul %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = xla_hlo.mul %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> %0 = "tf.Mul"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> @@ -105,28 +107,28 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-LABEL: func @broadcast_mul func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla.mul"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: "xla_hlo.mul"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } // CHECK-LABEL: func @real_div func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla.div %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = xla_hlo.div %arg0, %arg0 : tensor<2xi32> %0 = "tf.RealDiv"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } // CHECK-LABEL: func @broadcast_real_div func 
@broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: "xla_hlo.div"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.RealDiv"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } // CHECK-LABEL: func @sub func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { - // CHECK-NEXT: %0 = xla.sub %arg0, %arg0 : tensor<2xi32> + // CHECK-NEXT: %0 = xla_hlo.sub %arg0, %arg0 : tensor<2xi32> // CHECK-NEXT: return %0 : tensor<2xi32> %0 = "tf.Sub"(%arg0, %arg0) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> return %0: tensor<2xi32> @@ -134,7 +136,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK-LABEL: func @broadcast_sub func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - // CHECK-NEXT: "xla.sub"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-NEXT: "xla_hlo.sub"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} %0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0: tensor<1x2xi32> } @@ -156,7 +158,7 @@ func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: @const func @const() -> tensor<2xi32> { - // tf.Const is legalized into xla.constant, which is folded into constant. + // tf.Const is legalized into xla_hlo.constant, which is folded into constant. // CHECK-NEXT: constant dense<0> : tensor<2xi32> %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>) @@ -170,7 +172,7 @@ func @const() -> tensor<2xi32> { // CHECK-LABEL: func @relu func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-NEXT: %cst = constant dense<0> : tensor<1xi32> - // CHECK-NEXT: %0 = xla.max %arg0, %cst : tensor<1xi32> + // CHECK-NEXT: %0 = xla_hlo.max %arg0, %cst : tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -179,7 +181,7 @@ func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-NEXT: %cst = constant dense<0> : tensor<1xi32> // CHECK-NEXT: %cst_0 = constant dense<6> : tensor<1xi32> - // CHECK-NEXT: %0 = "xla.clamp"(%cst, %arg0, %cst_0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK-NEXT: %0 = "xla_hlo.clamp"(%cst, %arg0, %cst_0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -190,7 +192,7 @@ func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: reshape func @reshape(%arg0: tensor<2xf32>, %arg1: tensor<2xi32>) -> tensor<1x1xf32> { - // CHECK: %0 = "xla.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x1xf32> + // CHECK: %0 = "xla_hlo.reshape"(%arg0) : (tensor<2xf32>) -> tensor<1x1xf32> %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xi32>) -> tensor<1x1xf32> return %0 : tensor<1x1xf32> } @@ -204,7 +206,7 @@ func @reshape_dynamic(%arg0: tensor<*xf32>, %arg1: tensor<2xi32>) -> tensor) -> tensor<1x10xf32> { - // CHECK-NEXT: %0 = "xla.reshape"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> + // CHECK-NEXT: %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> return %0 : tensor<1x10xf32> } diff --git 
a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir index d75b283e633..6dad19179f1 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-to-std.mlir @@ -3,16 +3,16 @@ // CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %0 = addf %arg0, %arg1 : tensor<4xf32> - %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32> - %1 = "xla.mul"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %1 = "xla_hlo.mul"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32> - %2 = "xla.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %2 = "xla_hlo.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32> - %3 = "xla.div"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %3 = "xla_hlo.div"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %3 : tensor<4xf32> return %3 : tensor<4xf32> @@ -21,16 +21,16 @@ func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf // CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { // CHECK-NEXT: %0 = addi %arg0, %arg1 : tensor<4xi32> - %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32> - %1 = "xla.mul"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "xla_hlo.mul"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32> - %2 = "xla.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %2 = "xla_hlo.sub"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: %3 = divis %2, %arg1 : tensor<4xi32> - %3 = "xla.div"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %3 = "xla_hlo.div"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> // CHECK-NEXT: return %3 : tensor<4xi32> return %3 : tensor<4xi32> @@ -41,23 +41,23 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32 // them to separate broadcast and binary op. 
// CHECK-LABEL: func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tensor<4x4xf32> { - // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "add.3"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %0 = "xla.add"(%arg0, %arg1) { + // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "add.3"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> + %0 = "xla_hlo.add"(%arg0, %arg1) { name = "add.3", broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %1 = "xla.mul"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %1 = "xla.mul"(%0, %arg1) { + // CHECK-NEXT: %1 = "xla_hlo.mul"(%0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "mul.4"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> + %1 = "xla_hlo.mul"(%0, %arg1) { name = "mul.4", broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %2 = "xla.sub"(%1, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %2 = "xla.sub"(%1, %arg1) { + // CHECK-NEXT: %2 = "xla_hlo.sub"(%1, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "sub.5"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> + %2 = "xla_hlo.sub"(%1, %arg1) { name = "sub.5", broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - // CHECK-NEXT: %3 = "xla.div"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> - %3 = "xla.div"(%2, %arg1) { + // CHECK-NEXT: %3 = "xla_hlo.div"(%2, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "div.6"} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> + %3 = "xla_hlo.div"(%2, %arg1) { name = "div.6", broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4x4xf32> @@ -68,17 +68,17 @@ func @binary_ops_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>) -> tens // CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32> - %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32> - %1 = "xla.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32> - %2 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %3 = cmpi "sle", 
%arg0, %arg0 : tensor<4xi32> - %3 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32> - %4 = "xla.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32> - %5 = "xla.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> // CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> } @@ -86,17 +86,17 @@ func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi // CHECK-LABEL: func @compare_float func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { // CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32> - %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32> - %1 = "xla.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32> - %2 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32> - %3 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32> - %4 = "xla.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32> - %5 = "xla.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> } diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir new file mode 100644 index 00000000000..070386a0393 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir @@ -0,0 +1,143 @@ +// RUN: tf-opt %s -verify-diagnostics -split-input-file + +// ----- + 
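The new lhlo_ops.mlir suite checks the verifiers of the buffer-based xla_lhlo dialect. Unlike xla_hlo ops, which return tensor results as SSA values, xla_lhlo ops write into a pre-allocated output memref passed as the trailing operand and produce no results. The RUN line uses the standard mlir-opt-style flags: -split-input-file re-verifies each chunk delimited by // ----- in isolation, and -verify-diagnostics checks the expected-error annotations instead of piping output to FileCheck. A minimal sketch of the two forms for orientation (the function names below are illustrative and not part of this diff):

// Tensor form (xla_hlo): the result is a fresh SSA value.
func @tanh_tensor(%arg0: tensor<10xf32>) -> tensor<10xf32> {
  %0 = "xla_hlo.tanh"(%arg0) : (tensor<10xf32>) -> tensor<10xf32>
  return %0 : tensor<10xf32>
}

// Buffer form (xla_lhlo): the trailing operand is the output buffer; no results.
func @tanh_buffer(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
  "xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
  return
}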
+func @enforce_static_shapes(%arg0: memref, %arg1: memref) -> () { + // expected-error@+1{{op operand #0 must be statically shaped memref of floating-point or integer values}} + "xla_lhlo.tanh"(%arg0, %arg1) : (memref, memref) -> () + return +} + +// ----- + +func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () { + // expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}} + "xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @add_memrefs +func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { + "xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @abs_memref +func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @convert_memref +func @convert_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @exp_memref +func @exp_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.exp"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @neg_memref +func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.neg"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @sign_memref +func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @tanh_memref +func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @add_memref +func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @div_memref +func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.div"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @max_memref +func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.max"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @min_memref +func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.min"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @mul_memref +func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.mul"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @sub_memref +func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.sub"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @and_memref +func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: 
memref<10xf32>) -> () { + "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +func @reduce_computation(%sum: memref<1xf32>, %element: memref<1xf32>) -> () { + "xla_lhlo.add"(%element, %sum, %sum) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +} + +// CHECK-LABEL: func @reduce_memref +func @reduce_memref(%input: memref<10xf32>, %out: memref<1xf32>) -> () { + "xla_lhlo.reduce"(%input, %out) {computation = @reduce_computation, + dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref<1xf32>) -> () + return +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/ops.mlir b/tensorflow/compiler/mlir/xla/tests/ops.mlir index fcd93bb1b97..06c98fb39b0 100644 --- a/tensorflow/compiler/mlir/xla/tests/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/ops.mlir @@ -4,7 +4,7 @@ func @enforce_static_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> { // expected-error@+1 {{op operand #0 must be statically shaped tensor}} - %0 = "xla.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + %0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0: tensor<*xf32> } @@ -12,7 +12,7 @@ func @enforce_static_shapes(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: func @add_tensors func @add_tensors(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { - %0 = "xla.add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -20,7 +20,7 @@ func @add_tensors(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @add_scalars func @add_scalars(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "xla.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor return %0: tensor } @@ -28,7 +28,7 @@ func @add_scalars(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-LABEL: func @add_scalar_tensor func @add_scalar_tensor(%arg0: tensor<1xi32>, %arg1: tensor) -> tensor<1xi32> { - %0 = "xla.add"(%arg0, %arg1) : (tensor<1xi32>, tensor) -> tensor<1xi32> + %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1xi32>, tensor) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -36,7 +36,7 @@ func @add_scalar_tensor(%arg0: tensor<1xi32>, %arg1: tensor) -> tensor<1xi3 // CHECK-LABEL: func @batch_norm_inference func @batch_norm_inference(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8x8x8x8xf32> { - %0 = "xla.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + %0 = "xla_hlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> return %0 : tensor<8x8x8x8xf32> } @@ -44,7 +44,7 @@ func @batch_norm_inference(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %ar // CHECK-LABEL: func @broadcast func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { - %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ 
-52,7 +52,7 @@ func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_nonint_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[1.0, 2.0]> : tensor<2xf64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1.0, 2.0]> : tensor<2xf64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -60,7 +60,7 @@ func @broadcast_nonint_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_splat_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<2.0> : tensor<2xf64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<2.0> : tensor<2xf64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -68,7 +68,7 @@ func @broadcast_splat_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_sparse_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla.broadcast"(%arg0) {broadcast_sizes = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -76,7 +76,7 @@ func @broadcast_sparse_sizes(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_sizes has rank 2 instead of rank 1}} - %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -84,7 +84,7 @@ func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_bad_result_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{result rank (3) does not match operand rank (1) plus size of broadcast_sizes (3)}} - %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -92,7 +92,7 @@ func @broadcast_bad_result_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { func @broadcast_bad_first_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{result has shape [1, 3] instead of [2, 3]}} - %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x3xi32> + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x3xi32> return %0 : tensor<1x3xi32> } @@ -100,7 +100,7 @@ func @broadcast_bad_first_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{result has shape [2, 1] instead of [2, 3]}} - %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x1xi32> + %0 = "xla_hlo.broadcast"(%arg0) 
{broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x1xi32> return %0 : tensor<2x1xi32> } @@ -108,7 +108,7 @@ func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2 // CHECK-LABEL: func @broadcast_in_dim func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { - %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> return %0 : tensor<1x2x2xi32> } @@ -116,7 +116,7 @@ func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { // CHECK-LABEL: func @broadcast_in_dim_zero_rank func @broadcast_in_dim_zero_rank(%arg0: tensor) -> tensor<1x2x3xi32> { - %0 = "xla.broadcast_in_dim"(%arg0) : (tensor) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast_in_dim"(%arg0) : (tensor) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -124,7 +124,7 @@ func @broadcast_in_dim_zero_rank(%arg0: tensor) -> tensor<1x2x3xi32> { func @broadcast_in_dim_bad_nonint_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1.0, 2.0]> : tensor<2xf64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1.0, 2.0]> : tensor<2xf64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -132,7 +132,7 @@ func @broadcast_in_dim_bad_nonint_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1 func @broadcast_in_dim_bad_splat_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2.0> : tensor<2xf64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<2.0> : tensor<2xf64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -140,7 +140,7 @@ func @broadcast_in_dim_bad_splat_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1x func @broadcast_in_dim_bad_sparse_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_sizes must be a DenseIntElementsAttr}} - %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -148,7 +148,7 @@ func @broadcast_in_dim_bad_sparse_dimensions(%arg0: tensor<1x2xi32>) -> tensor<1 func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}} - %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -156,7 +156,7 @@ func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x func @broadcast_in_dim_bad_dimension_size(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 
{{broadcast_dimensions size (1) does not match operand rank (2)}} - %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -164,7 +164,7 @@ func @broadcast_in_dim_bad_dimension_size(%arg0: tensor<1x2xi32>) -> tensor<1x2x func @broadcast_in_dim_bad_rank_decrease(%arg0: tensor<1x2x3xi32>) -> tensor<3xi32> { // expected-error@+1 {{result rank (1) is less than operand rank (3)}} - %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x3xi32>) -> tensor<3xi32> + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x3xi32>) -> tensor<3xi32> return %0 : tensor<3xi32> } @@ -172,7 +172,7 @@ func @broadcast_in_dim_bad_rank_decrease(%arg0: tensor<1x2x3xi32>) -> tensor<3xi func @broadcast_in_dim_dimension_values_too_large(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions contains invalid value 9 for result result with rank 3}} - %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[9, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[9, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -180,7 +180,7 @@ func @broadcast_in_dim_dimension_values_too_large(%arg0: tensor<1x2xi32>) -> ten func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{size of operand dimension 0 (3) is not equal to 1 or size of result dimension 1 (2)}} - %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> return %0 : tensor<1x2x3xi32> } @@ -188,7 +188,7 @@ func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3x // CHECK-LABEL: func @comp_eq func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { - %0 = "xla.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> return %0 : tensor<3xi1> } @@ -196,7 +196,7 @@ func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { func @comp_bad_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { // expected-error@+1 {{'comparison_direction' failed to satisfy constraint}} - %0 = "xla.compare"(%arg0, %arg1) {comparison_direction = "FOOBAR"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "FOOBAR"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> return %0 : tensor<3xi1> } @@ -204,7 +204,7 @@ func @comp_bad_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3 func @comp_no_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { // expected-error@+1 {{op requires attribute 'comparison_direction'}} - %0 = "xla.compare"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %0 = "xla_hlo.compare"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> return %0 : tensor<3xi1> } 
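The negative cases in ops.mlir follow the standard MLIR diagnostic-verification idiom: each chunk between // ----- markers is parsed and verified in isolation, and expected-error@+1 {{...}} asserts that the op one line below the comment emits a verifier error containing the quoted substring (the offset is relative, so @+2 would point two lines further down). A minimal self-contained sketch of the pattern; the RUN line and function name here are illustrative, since the actual RUN line of ops.mlir lies outside this excerpt:

// RUN: tf-opt %s -verify-diagnostics -split-input-file

// -----

func @compare_missing_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> {
  // expected-error@+1 {{op requires attribute 'comparison_direction'}}
  %0 = "xla_hlo.compare"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
  return %0 : tensor<3xi1>
}

// -----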
@@ -212,7 +212,7 @@ func @comp_no_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3x // CHECK-LABEL: func @conv func @conv(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi32> { - %0 = "xla.conv"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + %0 = "xla_hlo.conv"(%arg0, %arg1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> return %0: tensor<3xi32> } @@ -220,7 +220,7 @@ func @conv(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi32> { // CHECK-LABEL: func @copy func @copy(%arg0: tensor<1xi32>) -> tensor<1xi32> { - %0 = "xla.copy"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + %0 = "xla_hlo.copy"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -228,7 +228,7 @@ func @copy(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @clamp func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> { - %0 = "xla.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %0 = "xla_hlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -236,39 +236,39 @@ func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @clamp_scalar func @clamp_scalar(%arg0: tensor<1xi32>, %arg1: tensor) -> tensor<1xi32> { - %0 = "xla.clamp"(%arg1, %arg0, %arg1) : (tensor, tensor<1xi32>, tensor) -> tensor<1xi32> + %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg1) : (tensor, tensor<1xi32>, tensor) -> tensor<1xi32> + return %0: tensor<1xi32> +} + +// ----- + +func @clamp_invalid_clamp_element_type(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> tensor<1xi32> { + // expected-error@+1 {{'xla_hlo.clamp' op requires the same element type for all operands and results}} + %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg0) : (tensor<1xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + return %0: tensor<1xi32> +} + +// ----- + +func @clamp_invalid_clamp_shape(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<1xi32> { + // expected-error@+1 {{min shape [2] is not scalar and does not match operand shape [1]}} + %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg0) : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } // ----- func @clamp_invalid_min_element_type(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> tensor<1xi32> { - // expected-error@+1 {{'xla.clamp' op requires the same element type for all operands and results}} - %0 = "xla.clamp"(%arg1, %arg0, %arg0) : (tensor<1xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - return %0: tensor<1xi32> -} - -// ----- - -func @clamp_invalid_min_shape(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<1xi32> { - // expected-error@+1 {{min shape [2] is not scalar and does not match operand shape [1]}} - %0 = "xla.clamp"(%arg1, %arg0, %arg0) : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // expected-error@+1 {{'xla_hlo.min' op requires the same element type for all operands and results}} + %0 = "xla_hlo.min"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> tensor<1xi32> return %0: tensor<1xi32> } // ----- func @clamp_invalid_max_element_type(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> tensor<1xi32> { - // expected-error@+1 {{'xla.clamp' op requires the same element type for all operands and results}} - %0 = "xla.clamp"(%arg0, %arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>, tensor<1xf32>) -> tensor<1xi32> - return %0: tensor<1xi32> -} - -// ----- - -func @clamp_invalid_max_shape(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<1xi32> { - // 
expected-error@+1 {{max shape [2] is not scalar and does not match operand shape [1]}} - %0 = "xla.clamp"(%arg0, %arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<1xi32> + // expected-error@+1 {{'xla_hlo.max' op requires the same element type for all operands and results}} + %0 = "xla_hlo.max"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -276,7 +276,7 @@ func @clamp_invalid_max_shape(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> ten // CHECK-LABEL: func @dot_vector func @dot_vector(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) -> tensor { - %0 = "xla.dot"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor + %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor return %0: tensor } @@ -284,7 +284,7 @@ func @dot_vector(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) -> tensor // CHECK-LABEL: func @dot_matrix func @dot_matrix(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = "xla.dot"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0: tensor<2x2xi32> } @@ -292,7 +292,7 @@ func @dot_matrix(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi // CHECK-LABEL: func @dot_precision_config func @dot_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { - %0 = "xla.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0: tensor<2x2xi32> } @@ -300,7 +300,7 @@ func @dot_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> te func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { // expected-error@+1 {{'precision_config' failed to satisfy constraint}} - %0 = "xla.dot"(%arg0, %arg1) {precision_config = ["FOO", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["FOO", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0: tensor<2x2xi32> } @@ -308,15 +308,47 @@ func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) - // CHECK-LABEL: func @tanh func @tanh(%arg0: tensor<1xf32>) -> tensor<1xf32> { - %0 = "xla.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + %0 = "xla_hlo.tanh"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> return %0: tensor<1xf32> } // ----- +func @exp_invalid_result_type(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // expected-error@+1 {{'xla_hlo.exp' op requires the same type for all operands and results}} + %0 = "xla_hlo.exp"(%arg0) : (tensor<1xf32>) -> tensor<1xi32> + return %0: tensor<1xi32> +} + +// ----- + +func @floor_invalid_result_type(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // expected-error@+1 {{'xla_hlo.floor' op requires the same type for all operands and results}} + %0 = "xla_hlo.floor"(%arg0) : (tensor<1xf32>) -> tensor<1xi32> + return %0: tensor<1xi32> +} + +// ----- + +func @log_invalid_result_type(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // expected-error@+1 {{'xla_hlo.log' op requires the same type for all operands and results}} + %0 = "xla_hlo.log"(%arg0) : (tensor<1xf32>) -> tensor<1xi32> + return %0: tensor<1xi32> +} + +// ----- + +func @rsqrt_invalid_result_type(%arg0: tensor<1xf32>) -> tensor<1xf32> 
{ + // expected-error@+1 {{'xla_hlo.rsqrt' op requires the same type for all operands and results}} + %0 = "xla_hlo.rsqrt"(%arg0) : (tensor<1xf32>) -> tensor<1xi32> + return %0: tensor<1xi32> +} + +// ----- + // CHECK-LABEL: func @reshape_same_shape func @reshape_same_shape(%arg0: tensor<1xi32>) -> tensor<1xi32> { - %0 = "xla.reshape"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + %0 = "xla_hlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -324,7 +356,7 @@ func @reshape_same_shape(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @reshape_different_shape func @reshape_different_shape(%arg0: tensor<1x16xi32>) -> tensor<4x4xi32> { - %0 = "xla.reshape"(%arg0) : (tensor<1x16xi32>) -> tensor<4x4xi32> + %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x16xi32>) -> tensor<4x4xi32> return %0: tensor<4x4xi32> } @@ -332,7 +364,7 @@ func @reshape_different_shape(%arg0: tensor<1x16xi32>) -> tensor<4x4xi32> { // CHECK-LABEL: func @reshape_from_scalar func @reshape_from_scalar(%arg0: tensor) -> tensor<1xi32> { - %0 = "xla.reshape"(%arg0) : (tensor) -> tensor<1xi32> + %0 = "xla_hlo.reshape"(%arg0) : (tensor) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -340,7 +372,7 @@ func @reshape_from_scalar(%arg0: tensor) -> tensor<1xi32> { // CHECK-LABEL: func @reshape_to_scalar func @reshape_to_scalar(%arg0: tensor<1xi32>) -> tensor { - %0 = "xla.reshape"(%arg0) : (tensor<1xi32>) -> tensor + %0 = "xla_hlo.reshape"(%arg0) : (tensor<1xi32>) -> tensor return %0: tensor } @@ -348,7 +380,7 @@ func @reshape_to_scalar(%arg0: tensor<1xi32>) -> tensor { // CHECK-LABEL: func @select func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { - %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -356,7 +388,7 @@ func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi3 // CHECK-LABEL: func @select_scalar_pred func @select_scalar_pred(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { - %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -364,7 +396,7 @@ func @select_scalar_pred(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tenso func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // expected-error@+1 {{must be statically shaped tensor of pred (AKA boolean or 1-bit integer)}} - %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -372,7 +404,7 @@ func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // expected-error@+1 {{on_true type (tensor<2x4xi32>) does not match on_false type (tensor<2x3xi32>)}} - %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, 
tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -380,7 +412,7 @@ func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x4xi32>, %ar func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // expected-error@+1 {{on_true type (tensor<2x3xf32>) does not match on_false type (tensor<2x3xi32>)}} - %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } @@ -388,15 +420,39 @@ func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf3 func @select_bad_pred_shape(%arg0: tensor<3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { // expected-error@+1 {{red shape ([3]) is not scalar and does not match operand shapes ([2, 3])}} - %0 = "xla.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } // ----- +// CHECK-LABEL: func @slice +func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + +func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { + // expected-error@+1 {{failed to verify that all of {start_indices, limit_indices} have same type}} + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + +func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> { + // expected-error@+1 {{requires the same element type for all operands and results}} + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- + // CHECK-LABEL: func @transpose -func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { - %0 = "xla.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> +func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } @@ -404,7 +460,7 @@ func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { func @transpose_bad_permutations_float(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // expected-error@+1 {{permutation must be a DenseIntElementsAttr}} - %0 = "xla.transpose"(%arg0) {permutation = dense<[1.0, 0.0, 3.0, 2.0]> : tensor<4xf64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1.0, 0.0, 3.0, 2.0]> : tensor<4xf64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } @@ -412,7 +468,7 @@ func @transpose_bad_permutations_float(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x func @transpose_bad_permutations_splat(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // 
expected-error@+1 {{permutation must be a DenseIntElementsAttr}} - %0 = "xla.transpose"(%arg0) {permutation = dense<2.0> : tensor<2xf64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<2.0> : tensor<2xf64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } @@ -420,7 +476,7 @@ func @transpose_bad_permutations_splat(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x func @transpose_bad_permutations_sparse(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // expected-error@+1 {{permutation must be a DenseIntElementsAttr}} - %0 = "xla.transpose"(%arg0) {permutation = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "xla_hlo.transpose"(%arg0) {permutation = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } @@ -428,7 +484,7 @@ func @transpose_bad_permutations_sparse(%arg0: tensor<1x2x3x4xi32>) -> tensor<2 func @transpose_bad_permutations_rank(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // expected-error@+1 {{permutation has rank 2 instead of rank 1}} - %0 = "xla.transpose"(%arg0) {permutation = dense<[[1]]> : tensor<1x1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[[1]]> : tensor<1x1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } @@ -436,7 +492,7 @@ func @transpose_bad_permutations_rank(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1 func @transpose_bad_permutations_size(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // expected-error@+1 {{permutation size (1) does not match operand rank (4)}} - %0 = "xla.transpose"(%arg0) {permutation = dense<[1]> : tensor<1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1]> : tensor<1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0: tensor<2x1x4x3xi32> } @@ -444,7 +500,7 @@ func @transpose_bad_permutations_size(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1 func @transpose_operand_result_rank_mismatch(%arg0: tensor<1x2x3x4xi32>) -> tensor<2xi32> { // expected-error@+1 {{result rank (1) does not match operand rank (4)}} - %0 = "xla.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2xi32> + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2xi32> return %0: tensor<2xi32> } @@ -452,7 +508,7 @@ func @transpose_operand_result_rank_mismatch(%arg0: tensor<1x2x3x4xi32>) -> ten func @transpose_operand_result_permutation_mismatch(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> { // expected-error@+1 {{result shape is [1, 2, 3, 4] instead of [2, 1, 4, 3]}} - %0 = "xla.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> return %0: tensor<1x2x3x4xi32> } @@ -460,6 +516,6 @@ func @transpose_operand_result_permutation_mismatch(%arg0: tensor<1x2x3x4xi32>) // CHECK-LABEL: func @tuple func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> { - %0 = "xla.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> + %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> return 
%0: tuple<tensor<1xi32>, tensor<1x2xf32>>
-}
\ No newline at end of file
+}
diff --git a/tensorflow/compiler/mlir/xla/tests/reshape.mlir b/tensorflow/compiler/mlir/xla/tests/reshape.mlir
index ee29a718abf..34cb3cb2729 100644
--- a/tensorflow/compiler/mlir/xla/tests/reshape.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/reshape.mlir
@@ -1,80 +1,149 @@
-// RUN: tf-opt %s -split-input-file -xla-legalize-to-std | FileCheck %s
+// RUN: tf-opt %s -split-input-file -canonicalize | FileCheck %s
-// -----
-
-// CHECK-LABEL: func @reshape.const.1() -> tensor<f32> {
-func @reshape.const.1() -> tensor<f32> {
-  // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<f32>
-  %cst = constant {name = "constant.1"} dense<42.0> : tensor<1x1xf32>
-  %0 = "xla.reshape"(%cst) : (tensor<1x1xf32>) -> tensor<f32>
-  // CHECK-NEXT: return %cst : tensor<f32>
-  return %0 : tensor<f32>
+// CHECK-LABEL: func @const_fold_collapse_to_scalar
+func @const_fold_collapse_to_scalar() -> tensor<i32> {
+  // CHECK-NEXT: [[CST:%.+]] = constant dense<42> : tensor<i32>
+  %cst = constant dense<42> : tensor<1x1xi32>
+  %0 = "xla_hlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
+  // CHECK-NEXT: return [[CST]]
+  return %0 : tensor<i32>
 }
 // -----
-// CHECK-LABEL: func @reshape.const.2() -> tensor<2xf32> {
-func @reshape.const.2() -> tensor<2xf32> {
-  // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<2xf32>
-  %cst = constant {name = "constant.1"} dense<42.0> : tensor<1x2xf32>
-  %0 = "xla.reshape"(%cst) : (tensor<1x2xf32>) -> tensor<2xf32>
-  // CHECK-NEXT: return %cst : tensor<2xf32>
-  return %0 : tensor<2xf32>
+// CHECK-LABEL: func @const_fold_collapse_to_tensor
+func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
+  // CHECK-NEXT: [[CST:%.+]] = constant dense<42> : tensor<2xi32>
+  %cst = constant dense<42> : tensor<1x2xi32>
+  %0 = "xla_hlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
+  // CHECK-NEXT: return [[CST]]
+  return %0 : tensor<2xi32>
 }
 // -----
-// CHECK-LABEL: func @reshape.const.3() -> tensor<1xf32> {
-func @reshape.const.3() -> tensor<1xf32> {
-  // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<1xf32>
-  %cst = constant {name = "constant.1"} dense<42.0> : tensor<f32>
-  %0 = "xla.reshape"(%cst) : (tensor<f32>) -> tensor<1xf32>
-  // CHECK-NEXT: return %cst : tensor<1xf32>
-  return %0 : tensor<1xf32>
+// CHECK-LABEL: func @const_fold_expand
+func @const_fold_expand() -> tensor<1xi32> {
+  // CHECK-NEXT: [[CST:%.+]] = constant dense<42> : tensor<1xi32>
+  %cst = constant dense<42> : tensor<i32>
+  %0 = "xla_hlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
+  // CHECK-NEXT: return [[CST]]
+  return %0 : tensor<1xi32>
 }
 // -----
-// CHECK-LABEL: func @reshape.const.4() -> tensor<16xi64> {
-func @reshape.const.4() -> tensor<16xi64> {
-  // CHECK-NEXT: %cst = constant dense<42> : tensor<16xi64>
-  %cst = constant dense<42> : tensor<4x4xi64>
-  %0 = "xla.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
-  // CHECK-NEXT: return %cst : tensor<16xi64>
+// CHECK-LABEL: func @const_fold_nontrivial
+func @const_fold_nontrivial() -> tensor<16xi64> {
+  // CHECK-NEXT: [[CST:%.+]] = constant dense<42> : tensor<16xi64>
+  %cst = constant dense<42> : tensor<4x4xi64>
+  %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
+  // CHECK-NEXT: return [[CST]]
   return %0 : tensor<16xi64>
 }
 // -----
-// CHECK-LABEL: func @reshape.const.5() -> tensor<16xf64> {
-func @reshape.const.5() -> tensor<16xf64> {
-  // CHECK-NEXT: %cst = constant dense<4.200000e+01> : tensor<16xf64>
-  %cst = constant dense<4.200000e+01> : tensor<4x4xf64>
-  %0 = "xla.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
- 
// CHECK-NEXT: return %cst : tensor<16xf64> - return %0 : tensor<16xf64> +// CHECK-LABEL: func @const_fold_flatten +func @const_fold_flatten() -> tensor<16xi64> { + // CHECK-NEXT: [[CST:%.+]] = constant dense<42> : tensor<16xi64> + %cst = constant dense<42> : tensor<4x4xi64> + %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<16xi64> } - // ----- -// CHECK-LABEL: func @reshape.const.6() -> tensor<6xi32> { -func @reshape.const.6() -> tensor<6xi32> { - // CHECK-NEXT: %cst = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> - %cst = constant {name = "constant.1"} dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> - %0 = "xla.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32> - // CHECK-NEXT: return %cst : tensor<6xi32> +// CHECK-LABEL: func @const_fold_6 +func @const_fold_6() -> tensor<6xi32> { + // CHECK-NEXT: [[CST:%.+]] = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> + %cst = constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> + %0 = "xla_hlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32> + // CHECK-NEXT: return [[CST]] return %0 : tensor<6xi32> } +// ----- + +// CHECK-LABEL: func @const_fold_same_shape +func @const_fold_same_shape() -> tensor<2x3xi32> { + // CHECK-NEXT: [[CST:%.+]] = constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6] + // CHECK-SAME: ]> : tensor<2x3xi32> + %cst = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> + %0 = "xla_hlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<2x3xi32> +} // ----- -// CHECK-LABEL: func @reshape.const.7() -> tensor<2x3xi32> { -func @reshape.const.7() -> tensor<2x3xi32> { - // CHECK-NEXT: %cst = constant dense<{{\[\[}}1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> - %cst = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> - %0 = "xla.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32> - // CHECK-NEXT: return %cst : tensor<2x3xi32> +// CHECK-LABEL: func @const_fold_float +func @const_fold_float() -> tensor<16xf64> { + // CHECK-NEXT: [[CST:%.+]] = constant dense<4.2{{0*}}e+00> : tensor<16xf64> + %cst = constant dense<4.2> : tensor<4x4xf64> + %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<16xf64> +} + +// ----- + +// CHECK-LABEL: func @non_const_same_shape +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> { + // CHECK-NEXT: return [[ARG]] + %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> -} \ No newline at end of file +} + +// ----- + +// CHECK-LABEL: func @non_const_chained_reshape +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) { + // CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32> + // CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> + %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> + return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed +} + +// ----- + +// CHECK-LABEL: func @non_const_chained_reshape_unused_parent +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> + %0 = "xla_hlo.reshape"(%arg) : 
(tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> + // CHECK-NEXT: return [[RES]] + return %1 : tensor<6xi32> +} + +// ----- + +// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> { + %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: return [[ARG]] + return %1 : tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @non_const_many_chained_reshapes +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> + %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32> + %1 = "xla_hlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32> + %2 = "xla_hlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32> + %3 = "xla_hlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32> + %4 = "xla_hlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32> + // CHECK-NEXT: return [[RES]] + return %4 : tensor<1x2x4x3xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt index d285df18bc9..96423e0d12b 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt @@ -13,15 +13,15 @@ ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4], Arg_2.3: f32[], Arg_3.4: f32[]) %Arg_3.4 = f32[] parameter(3) // Add two tensors - // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) // Add two scalars - // CHECK-NEXT: %1 = "xla.add"(%arg2, %arg3) {name = "add.4"} : (tensor, tensor) -> tensor + // CHECK-NEXT: %1 = "xla_hlo.add"(%arg2, %arg3) {name = "add.4"} : (tensor, tensor) -> tensor %add.4 = f32[] add(f32[] %Arg_2.3, f32[] %Arg_3.4) // Add a tensor and scalar - // CHECK-NEXT: %2 = "xla.add"(%0, %1) {name = "add.5"} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: %2 = "xla_hlo.add"(%0, %1) {name = "add.5"} : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK-NEXT: return %2 : tensor<4xf32> ROOT %add.5 = f32[4] add(f32[4] %add.3, f32[] %add.4) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/add.mlir b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir index 4009759f3b8..a77b90ca083 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/add.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir @@ -6,9 +6,9 @@ func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) // CHECK-NEXT: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) - %0 = "xla.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: ROOT %add.4 = f32[4] add(f32[4] %add.3, f32[4] %Arg_1.2) - %1 = "xla.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %1 = "xla_hlo.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> } diff --git 
a/tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt index 1826809db63..25cf3ecd16a 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt @@ -7,7 +7,7 @@ ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = "xla.and"(%arg0, %arg1) {name = "and.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.and"(%arg0, %arg1) {name = "and.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> ROOT %and.3 = f32[4] and(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir b/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir index 9aff6393e86..38aa4f04bad 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir @@ -8,19 +8,19 @@ func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi // CHECK-NEXT: %broadcast.5 = s32[2,4] broadcast(s32[4] %reshape.4) // CHECK-NEXT: %Arg_1.2 = s32[2,4] parameter(1) // CHECK-NEXT: %add.6 = s32[2,4] add(s32[2,4] %broadcast.5, s32[2,4] %Arg_1.2) - %0 = "xla.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> // Broadcast up rank // CHECK-NEXT: %broadcast.7 = s32[2,3,4] broadcast(s32[2,4] %Arg_1.2), dimensions={0,2} // CHECK-NEXT: %Arg_2.3 = s32[2,3,4] parameter(2) // CHECK-NEXT: %add.8 = s32[2,3,4] add(s32[2,3,4] %broadcast.7, s32[2,3,4] %Arg_2.3) - %1 = "xla.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> + %1 = "xla_hlo.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> // Broadcast up rank + degenerate broadcast // CHECK-NEXT: %broadcast.9 = s32[2,1,4] broadcast(s32[1,4] %Arg_0.1), dimensions={1,2} // CHECK-NEXT: %reshape.10 = s32[2,4] reshape(s32[2,1,4] %broadcast.9) // CHECK-NEXT: %broadcast.11 = s32[2,3,4] broadcast(s32[2,4] %reshape.10), dimensions={0,2} // CHECK-NEXT: ROOT %add.12 = s32[2,3,4] add(s32[2,3,4] %broadcast.11, s32[2,3,4] %Arg_2.3) - %2 = "xla.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> + %2 = "xla_hlo.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> return %2 : tensor<2x3x4xi32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir b/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir index 1d231535703..0b64ab23d54 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir @@ -4,6 +4,6 @@ func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> { // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0) // CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3} - %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : 
(tensor<4xi32>) -> tensor<1x2x3x4xi32>
   return %0 : tensor<1x2x3x4xi32>
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.hlotxt
index d9c2e9fe094..3d520fc1bc2 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.hlotxt
@@ -6,14 +6,14 @@ HloModule main
 ENTRY %main {
   %Arg_0.1 = f32[1, 2] parameter(0)
-  // CHECK-NEXT: %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.2"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
+  // CHECK-NEXT: %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.2"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
   %broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1}
   // Degenerate broadcast
-  // CHECK-NEXT: %1 = "xla.broadcast_in_dim"(%arg0) {name = "broadcast.3"} : (tensor<1x2xf32>) -> tensor<3x2xf32>
+  // CHECK-NEXT: %1 = "xla_hlo.broadcast_in_dim"(%arg0) {name = "broadcast.3"} : (tensor<1x2xf32>) -> tensor<3x2xf32>
   broadcast.3 = f32[3,2] broadcast(%Arg_0.1), dimensions={}
-  // CHECK-NEXT: %2 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "broadcast.4"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
+  // CHECK-NEXT: %2 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "broadcast.4"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
   // CHECK-NEXT: return %2 : tensor<3x1x2xf32>
   ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2}
 }
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt
index c7ea0f9637e..350c372796d 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt
@@ -4,16 +4,16 @@ HloModule foo
 // CHECK-LABEL: func @call(%arg0: tensor<i64>) -> tensor<i64> {
 %call (arg_1: s64[]) -> s64[] {
-  %arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
-  // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg0) {name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
+  %arg_1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
+  // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg0) {name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
   // CHECK-NEXT: return %0 : tensor<i64>
   ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
 }
 // CHECK-LABEL: func @main(%arg0: tensor<i64>) -> tensor<i64> {
 ENTRY %foo (arg0.1: s64[]) -> s64[] {
-  %arg0.1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
+  %arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
   // CHECK-NEXT: %0 = call @call(%arg0) : (tensor<i64>) -> tensor<i64>
   // CHECK-NEXT: return %0 : tensor<i64>
   ROOT %call.2 = s64[] call(%arg0.1), to_apply=%call
-}
+}
\ No newline at end of file
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/clamp.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/clamp.hlotxt
new file mode 100644
index 00000000000..ea0ca3c1031
--- /dev/null
+++ b/tensorflow/compiler/mlir/xla/tests/translate/clamp.hlotxt
@@ -0,0 +1,16 @@
+// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
+
+HloModule main.5
+
+// CHECK-LABEL: func @main(
+// CHECK-SAME: [[A0:%.+]]: tensor<f32>, [[A1:%.+]]: tensor<4xf32>, [[A2:%.+]]: tensor<f32>) -> tensor<4xf32> {
+ENTRY %foo.5 (Arg_0.1: f32[], Arg_1.2: f32[4], Arg_1.3: f32[]) -> f32[4] {
+  %Arg_0.1 = f32[] parameter(0)
+  %Arg_1.2 = f32[4] 
parameter(1) + %Arg_2.3 = f32[] parameter(2) + + // CHECK-NEXT: [[R0:%.+]] = "xla_hlo.clamp"([[A0]], [[A1]], [[A2]]) {name = "clamp.3"} : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: return [[R0]] : tensor<4xf32> + ROOT %clamp.3 = f32[4] clamp(f32[] %Arg_0.1, f32[4] %Arg_1.2, f32[] %Arg_2.3) +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt index ed3019b81cb..637629d9744 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt @@ -8,14 +8,14 @@ ENTRY %main (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[1]) -> pred[3] { %Arg_1.2 = f32[3] parameter(1) %Arg_2.3 = f32[1] parameter(2) - // CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "compare.4"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "compare.4"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ - // CHECK-NEXT: %1 = "xla.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "compare.5"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + // CHECK-NEXT: %1 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "compare.5"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE // Requires broadcast of compatible tensors. - // CHECK-NEXT: %2 = "xla.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "compare.6"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1> + // CHECK-NEXT: %2 = "xla_hlo.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "compare.6"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1> // CHECK-NEXT: return %2 : tensor<3xi1> ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt index e73447d768d..b23c22b73c0 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt @@ -7,7 +7,7 @@ ENTRY %foo.5 (Arg_0.1: f32[4, 1], Arg_1.2: f32[4, 2]) -> f32[4, 3] { %Arg_0.1 = f32[4, 1] parameter(0) %Arg_1.2 = f32[4, 2] parameter(1) - // CHECK-NEXT: %0 = "xla.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32> + // CHECK-NEXT: %0 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32> // CHECK-NEXT: return %0 : tensor<4x3xf32> ROOT %concatenate.3 = f32[4, 3] concatenate(f32[4, 1] %Arg_0.1, f32[4, 2] %Arg_1.2), dimensions={1} } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt index 0de3ac6bffe..35fe1363b2e 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt @@ -6,12 +6,12 @@ HloModule tfcompile.7 // implementations with attributes, etc. 
// CHECK-LABEL: func @main(%arg0: tensor<1x16x16x1xf32>) -> tuple> { ENTRY %tfcompile.7 { - %arg0.1 = f32[1,16,16,1]{3,2,1,0} parameter(0), metadata={op_name="XLA_Args"} + %arg0.1 = f32[1,16,16,1]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %0 = "xla.copy"(%arg0) {name = "copy.1"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> - %copy.1 = f32[1,16,16,1]{2,1,3,0} copy(%arg0.1), metadata={op_name="XLA_Args"} + // CHECK-NEXT: %0 = "xla_hlo.copy"(%arg0) {name = "copy.1"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + %copy.1 = f32[1,16,16,1]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %1 = "xla.reshape"(%0) {name = "reshape.2"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + // CHECK-NEXT: %1 = "xla_hlo.reshape"(%0) {name = "reshape.2"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> %reshape.2 = f32[1,16,16,1]{2,1,3,0} reshape(%copy.1) // Note that double brackets "[[" have to be escaped as they denote variables @@ -19,13 +19,13 @@ ENTRY %tfcompile.7 { // CHECK-NEXT: %cst = constant {name = "constant.3"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> %constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %2 = "xla.conv"(%1, %cst) {name = "convolution.4"} : (tensor<1x16x16x1xf32>, tensor<2x2x1x1xf32>) -> tensor<1x16x16x1xf32> + // CHECK-NEXT: %2 = "xla_hlo.conv"(%1, %cst) {name = "convolution.4"} : (tensor<1x16x16x1xf32>, tensor<2x2x1x1xf32>) -> tensor<1x16x16x1xf32> %convolution.4 = f32[1,16,16,1]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} - // CHECK-NEXT: %3 = "xla.reshape"(%2) {name = "reshape.5"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> - %reshape.5 = f32[1,16,16,1]{3,2,1,0} reshape(%convolution.4), metadata={op_name="XLA_Retvals"} + // CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "reshape.5"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + %reshape.5 = f32[1,16,16,1]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"} - // CHECK-NEXT: %4 = "xla.tuple"(%3) {name = "tuple.6"} : (tensor<1x16x16x1xf32>) -> tuple> + // CHECK-NEXT: %4 = "xla_hlo.tuple"(%3) {name = "tuple.6"} : (tensor<1x16x16x1xf32>) -> tuple> // CHECK-NEXT: return %4 : tuple> - ROOT %tuple.6 = (f32[1,16,16,1]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="XLA_Retvals"} -} + ROOT %tuple.6 = (f32[1,16,16,1]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"} +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt index 3c0c7a9c1d1..f22646fc23e 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt @@ -7,13 +7,13 @@ ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f64[4] { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[] parameter(1) - // CHECK-NEXT: %0 = "xla.convert"(%arg0) {name = "convert.3"} : (tensor<4xf32>) -> tensor<4xf64> + // CHECK-NEXT: %0 = "xla_hlo.convert"(%arg0) {name = "convert.3"} : (tensor<4xf32>) -> tensor<4xf64> %convert.3 = f64[4] convert(f32[4] %Arg_0.1) - // CHECK-NEXT: %1 = "xla.convert"(%arg1) {name = "convert.4"} : (tensor) -> tensor + // 
CHECK-NEXT: %1 = "xla_hlo.convert"(%arg1) {name = "convert.4"} : (tensor) -> tensor %convert.4 = f64[] convert(f32[] %Arg_1.2) - // CHECK-NEXT: %2 = "xla.add"(%0, %1) {name = "add.5"} : (tensor<4xf64>, tensor) -> tensor<4xf64> + // CHECK-NEXT: %2 = "xla_hlo.add"(%0, %1) {name = "add.5"} : (tensor<4xf64>, tensor) -> tensor<4xf64> // CHECK-NEXT: return %2 : tensor<4xf64> ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[] %convert.4) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt index 602ad96b852..772e47a0a35 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt @@ -7,7 +7,7 @@ ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = "xla.div"(%arg0, %arg1) {name = "divide.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.div"(%arg0, %arg1) {name = "divide.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt index 5b7d0c6c2ef..88beb2f4803 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt @@ -7,17 +7,17 @@ ENTRY %main (Arg_0.1: f32[1, 4], Arg_1.2: f32[4, 1]) -> f32[] { %Arg_0.1 = f32[1, 4] parameter(0) %Arg_1.2 = f32[4, 1] parameter(1) - // CHECK-NEXT: %0 = "xla.dot"(%arg0, %arg1) {name = "dot.3", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %0 = "xla_hlo.dot"(%arg0, %arg1) {name = "dot.3", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest} - // CHECK-NEXT: %1 = "xla.dot"(%arg0, %arg1) {name = "dot.4", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %1 = "xla_hlo.dot"(%arg0, %arg1) {name = "dot.4", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default} - // CHECK-NEXT: %2 = "xla.dot"(%arg0, %arg1) {name = "dot.5", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %2 = "xla_hlo.dot"(%arg0, %arg1) {name = "dot.5", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default} // TODO(b/129709049) consider making this default precision config inferred. 
- // CHECK-NEXT: %3 = "xla.dot"(%arg0, %arg1) {name = "dot.6", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: %3 = "xla_hlo.dot"(%arg0, %arg1) {name = "dot.6", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor // CHECK-NEXT: return %3 : tensor ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/dynamic-update-slice.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/dynamic-update-slice.hlotxt index d31160cfb21..85369451e2f 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/dynamic-update-slice.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/dynamic-update-slice.hlotxt @@ -9,7 +9,7 @@ HloModule main %Arg_2.3 = f32[] parameter(2) %Arg_3.4 = f32[] parameter(3) - // CHECK-NEXT: %0 = "xla.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> + // CHECK-NEXT: %0 = "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> // CHECK-NEXT: return %0 : tensor<4x4xf32> ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4) } @@ -20,7 +20,7 @@ HloModule main %Arg_1.2 = f32[2] parameter(1) %Arg_2.3 = f32[] parameter(2) - // CHECK-NEXT: %0 = "xla.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3) } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/translate/exp.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/exp.hlotxt new file mode 100644 index 00000000000..fb523f9cd16 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/exp.hlotxt @@ -0,0 +1,12 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule foo + +// CHECK-LABEL: func @main(%arg0: tensor<16xf32>) -> tensor<16xf32> { +ENTRY %foo (arg0.1: f32[16]) -> f32[16] { + %arg0.1 = f32[16] parameter(0) + + // CHECK-NEXT: %0 = "xla_hlo.exp"(%arg0) {name = "exp.2"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: return %0 : tensor<16xf32> + ROOT %exp.2 = f32[16] exponential(f32[16] %arg0.1) +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/floor.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/floor.hlotxt new file mode 100644 index 00000000000..80e66da5642 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/floor.hlotxt @@ -0,0 +1,13 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule foo + +// CHECK-LABEL: func @main( +// CHECK-SAME: [[A0:%.+]]: tensor<16xf32>) -> tensor<16xf32> { +ENTRY %foo (arg0.1: f32[16]) -> f32[16] { + %arg0.1 = f32[16] parameter(0) + + // CHECK-NEXT: [[R0:%.+]] = "xla_hlo.floor"([[A0]]) {name = "floor.2"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: return [[R0]] : tensor<16xf32> + ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1) +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt index a4e5b19e1e1..fca13d7f0b7 100644 --- 
a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt @@ -9,95 +9,95 @@ ENTRY %tfcompile.48 { %arg0.1 = f32[1,300] parameter(0) %arg1.2 = f32[1,300,3,1] parameter(1) - // CHECK-NEXT: %0 = "xla.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32> + // CHECK-NEXT: %0 = "xla_hlo.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32> %reshape.3 = f32[1,300] reshape(%arg0.1) - // CHECK-NEXT: %1 = "xla.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %1 = "xla_hlo.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> %transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0} - // CHECK-NEXT: %2 = "xla.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32> + // CHECK-NEXT: %2 = "xla_hlo.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32> %reshape.28 = f32[300,1,1] reshape(%transpose.27) - // CHECK-NEXT: %3 = "xla.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32> + // CHECK-NEXT: %3 = "xla_hlo.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32> %reshape.29 = f32[300,1] reshape(%reshape.28) - // CHECK-NEXT: %4 = "xla.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %4 = "xla_hlo.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> %broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1} // CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor %constant.8 = f32[] constant(1) - // CHECK-NEXT: %5 = "xla.broadcast_in_dim"(%cst) {name = "broadcast.9"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %5 = "xla_hlo.broadcast_in_dim"(%cst) {name = "broadcast.9"} : (tensor) -> tensor<300x1x5xf32> %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={} - // CHECK-NEXT: %6 = "xla.mul"(%4, %5) {name = "multiply.31"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %6 = "xla_hlo.mul"(%4, %5) {name = "multiply.31"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9) // CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor %constant.32 = f32[] constant(0) - // CHECK-NEXT: %7 = "xla.broadcast_in_dim"(%cst_0) {name = "broadcast.33"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %7 = "xla_hlo.broadcast_in_dim"(%cst_0) {name = "broadcast.33"} : (tensor) -> tensor<300x1x5xf32> %broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={} - // CHECK-NEXT: %8 = "xla.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> + // CHECK-NEXT: %8 = "xla_hlo.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> %compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT // CHECK-NEXT: %cst_1 = constant {name = "constant.10"} dense<0.000000e+00> : tensor %constant.10 = f32[] constant(0) - // 
CHECK-NEXT: %9 = "xla.broadcast_in_dim"(%cst_1) {name = "broadcast.11"} : (tensor) -> tensor<300x1x5xf32> + // CHECK-NEXT: %9 = "xla_hlo.broadcast_in_dim"(%cst_1) {name = "broadcast.11"} : (tensor) -> tensor<300x1x5xf32> %broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={} // CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor %constant.40 = f32[] constant(0) - // CHECK-NEXT: %10 = "xla.broadcast_in_dim"(%cst_2) {name = "broadcast.41"} : (tensor) -> tensor<300x5xf32> + // CHECK-NEXT: %10 = "xla_hlo.broadcast_in_dim"(%cst_2) {name = "broadcast.41"} : (tensor) -> tensor<300x5xf32> %broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={} - // CHECK-NEXT: %11 = "xla.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %11 = "xla_hlo.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %copy.1 = f32[1,300,3,1] copy(%arg1.2) - // CHECK-NEXT: %12 = "xla.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + // CHECK-NEXT: %12 = "xla_hlo.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> %reshape.4 = f32[1,300,3,1] reshape(%copy.1) - // CHECK-NEXT: %13 = "xla.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> + // CHECK-NEXT: %13 = "xla_hlo.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> %reshape.24 = f32[1,300,3] reshape(%reshape.4) - // CHECK-NEXT: %14 = "xla.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> + // CHECK-NEXT: %14 = "xla_hlo.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> %transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2} - // CHECK-NEXT: %15 = "xla.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32> + // CHECK-NEXT: %15 = "xla_hlo.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32> %reshape.26 = f32[300,3] reshape(%transpose.25) // CHECK-NEXT: %cst_3 = constant {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32> %constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } }) // TODO(b/129709049) consider making this default precision config implied. 
- // CHECK-NEXT: %16 = "xla.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %16 = "xla_hlo.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> %dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0} // CHECK-NEXT: %cst_4 = constant {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32> %constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0}) - // CHECK-NEXT: %17 = "xla.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %17 = "xla_hlo.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32> %broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1} - // CHECK-NEXT: %18 = "xla.add"(%16, %17) {name = "add.39"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %18 = "xla_hlo.add"(%16, %17) {name = "add.39"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32> %add.39 = f32[300,5] add(%dot.36, %broadcast.38) - // CHECK-NEXT: %19 = "xla.max"(%10, %18) {name = "maximum.42"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32> + // CHECK-NEXT: %19 = "xla_hlo.max"(%10, %18) {name = "maximum.42"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32> %maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39) - // CHECK-NEXT: %20 = "xla.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %20 = "xla_hlo.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32> %reshape.44 = f32[300,1,5] reshape(%maximum.42) - // CHECK-NEXT: %21 = "xla.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %21 = "xla_hlo.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44) - // CHECK-NEXT: %22 = "xla.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + // CHECK-NEXT: %22 = "xla_hlo.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> %reshape.46 = f32[300,1,5] reshape(%select.45) - // CHECK-NEXT: %23 = "xla.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple> + // CHECK-NEXT: %23 = "xla_hlo.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple> // CHECK-NEXT: return %23 : tuple> ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt index 9a4944d414e..35c762c067c 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt @@ -4,14 +4,14 @@ HloModule main.5 // CHECK-LABEL: func @main() -> tensor<4xf32> { ENTRY %iota.1 () -> f32[4] { - // CHECK-NEXT: %0 = "xla.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> ROOT %iota.0 = f32[4] iota(), iota_dimension=0 } // CHECK-LABEL: func @iota.2() -> tensor<4x5xf32> { %iota.2 () -> 
f32[4, 5] { - // CHECK-NEXT: %0 = "xla.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32> + // CHECK-NEXT: %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32> // CHECK-NEXT: return %0 : tensor<4x5xf32> ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1 } \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/translate/log.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/log.hlotxt new file mode 100644 index 00000000000..616ad0c0eb4 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/log.hlotxt @@ -0,0 +1,12 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule foo + +// CHECK-LABEL: func @main(%arg0: tensor<16xf32>) -> tensor<16xf32> { +ENTRY %foo (arg0.1: f32[16]) -> f32[16] { + %arg0.1 = f32[16] parameter(0) + + // CHECK-NEXT: %0 = "xla_hlo.log"(%arg0) {name = "log.2"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: return %0 : tensor<16xf32> + ROOT %log.2 = f32[16] log(f32[16] %arg0.1) +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt index dd6c0f504f5..f4ba76b4675 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt @@ -7,7 +7,7 @@ ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = "xla.max"(%arg0, %arg1) {name = "maximum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.max"(%arg0, %arg1) {name = "maximum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt index 5efe44aa53a..880fc0f76ca 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt @@ -7,7 +7,7 @@ ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = "xla.min"(%arg0, %arg1) {name = "minimum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.min"(%arg0, %arg1) {name = "minimum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt index 1bfb6662124..ad7feef19bc 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt @@ -7,7 +7,7 @@ ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = "xla.mul"(%arg0, %arg1) {name = "multiply.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.mul"(%arg0, %arg1) {name = "multiply.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt index 412f267ce42..84e1fbc9cf6 100644 --- 
a/tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt @@ -7,7 +7,7 @@ ENTRY %padding.1 (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[] parameter(1) - // CHECK-NEXT: %0 = "xla.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0 } @@ -17,7 +17,7 @@ ENTRY %padding.1 (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[4] { %Arg_0.1 = f32[4, 4, 4] parameter(0) %Arg_1.2 = f32[] parameter(1) - // CHECK-NEXT: %0 = "xla.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor) -> tensor<7x11x15xf32> + // CHECK-NEXT: %0 = "xla_hlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor) -> tensor<7x11x15xf32> // CHECK-NEXT: return %0 : tensor<7x11x15xf32> ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6 } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt index 37e638eb1f7..e4dc4d5e211 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt @@ -33,19 +33,19 @@ ENTRY %foo.5 (Arg_0.1: f32[4, 4], Arg_1.2: f32[4], Arg_2.3: f32[]) -> ((f32[], f %Arg_1.2 = f32[4] parameter(1) %Arg_2.3 = f32[] parameter(2) - // CHECK-NEXT: %0 = "xla.reduce"(%arg0, %arg0, %arg2, %arg2) {computation = @reduce_helper.3, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor, tensor) -> tuple, tensor> + // CHECK-NEXT: %0 = "xla_hlo.reduce"(%arg0, %arg0, %arg2, %arg2) {computation = @reduce_helper.3, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor, tensor) -> tuple, tensor> %reduce.1 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.1 - // CHECK-NEXT: %1 = "xla.reduce"(%arg0, %arg1) {computation = @reduce_helper.1, dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %1 = "xla_hlo.reduce"(%arg0, %arg1) {computation = @reduce_helper.1, dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32> %reduce.2 = f32[] reduce(%reduce.1, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.2 - // CHECK-NEXT: %2 = "xla.reduce"(%1, %arg2) {computation = @reduce_helper.2, dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor + // CHECK-NEXT: %2 = "xla_hlo.reduce"(%1, %arg2) {computation = @reduce_helper.2, dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor %reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.2 - // CHECK-NEXT: %3 = "xla.reduce"(%arg0, %arg2) {computation = @reduce_helper.2, dimensions = dense<[0, 1]> : tensor<2xi64>} : 
(tensor<4x4xf32>, tensor) -> tensor + // CHECK-NEXT: %3 = "xla_hlo.reduce"(%arg0, %arg2) {computation = @reduce_helper.2, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor) -> tensor %reduce.4 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3 - // CHECK-NEXT: %4 = "xla.sub"(%2, %3) {name = "sub.5"} : (tensor, tensor) -> tensor + // CHECK-NEXT: %4 = "xla_hlo.sub"(%2, %3) {name = "sub.5"} : (tensor, tensor) -> tensor %sub.5 = f32[] subtract(%reduce.2, %reduce.3) ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.4, %sub.5) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt index 7c8303d5966..f89f3eb89bf 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt @@ -6,7 +6,7 @@ HloModule main.5 ENTRY %reverse.1 (Arg_0.1: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) - // CHECK-NEXT: %0 = "xla.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0} } @@ -15,7 +15,7 @@ ENTRY %reverse.1 (Arg_0.1: f32[4]) -> f32[4] { %reverse.2 (Arg_0.1: f32[4, 4]) -> f32[4, 4] { %Arg_0.1 = f32[4, 4] parameter(0) - // CHECK-NEXT: %0 = "xla.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: %0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32> // CHECK-NEXT: return %0 : tensor<4x4xf32> ROOT reverse.2 = f32[4, 4] reverse(%Arg_0.1), dimensions={0, 1} } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/rsqrt.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/rsqrt.hlotxt new file mode 100644 index 00000000000..a7b9b73f239 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/rsqrt.hlotxt @@ -0,0 +1,13 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule foo + +// CHECK-LABEL: func @main( +// CHECK-SAME: [[ARG0:%.+]]: tensor<16xf32>) -> tensor<16xf32> { +ENTRY %foo (arg0.1: f32[16]) -> f32[16] { + %arg0.1 = f32[16] parameter(0) + + // CHECK-NEXT: [[P0:%.+]] = "xla_hlo.rsqrt"([[ARG0]]) {name = "rsqrt.2"} : (tensor<16xf32>) -> tensor<16xf32> + // CHECK-NEXT: return [[P0]] : tensor<16xf32> + ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1) +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt index b9ae08d8c8c..d3fe6a51e56 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt @@ -8,7 +8,7 @@ ENTRY %main { %Arg_1.2 = s32[2,3] parameter(1) %Arg_2.3 = s32[2,3] parameter(2) - // CHECK-NEXT: %0 = "xla.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> // CHECK-NEXT: return %0 : tensor<2x3xi32> ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/select.mlir 
b/tensorflow/compiler/mlir/xla/tests/translate/select.mlir index 4990ae712f8..f00aa0ade15 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/select.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/select.mlir @@ -7,7 +7,7 @@ func @main(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32> // CHECK-NEXT: %Arg_2.3 = s32[2,3] parameter(2) // CHECK-NEXT: ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3) - %0 = "xla.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %0 : tensor<2x3xi32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo index 83d85f7d45e..5d358596d54 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo +++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo @@ -139,8 +139,8 @@ dynamic_parameter_binding { } # CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { -# CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +# CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> # TODO(b/129709049) consider making this default precision config inferred. -# CHECK-NEXT: %1 = "xla.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor +# CHECK-NEXT: %1 = "xla_hlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor # CHECK-NEXT: return %1 : tensor # CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt index 09462625bbb..b3f8e977bfe 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt @@ -7,11 +7,11 @@ ENTRY %main.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[] { %Arg_0.1 = f32[4]{0} parameter(0) %Arg_1.2 = f32[4]{0} parameter(1) - // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2) // TODO(b/129709049) consider making this default precision config inferred. 
- // CHECK-NEXT: %1 = "xla.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor + // CHECK-NEXT: %1 = "xla_hlo.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor // CHECK-NEXT: return %1 : tensor ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir b/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir index f6e277c97de..e68262ba9ff 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir @@ -2,8 +2,8 @@ func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>): - %0 = "xla.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %1 = "xla.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %1 = "xla_hlo.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt index 6fc493aa764..24d4dff6270 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt @@ -7,7 +7,7 @@ ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { %Arg_0.1 = f32[4] parameter(0) %Arg_1.2 = f32[4] parameter(1) - // CHECK-NEXT: %0 = "xla.sub"(%arg0, %arg1) {name = "subtract.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = "xla_hlo.sub"(%arg0, %arg1) {name = "subtract.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt index 54dc0faef09..054e6af355e 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt @@ -4,9 +4,9 @@ HloModule foo // CHECK-LABEL: func @main(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> { ENTRY %foo (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { - %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="XLA_Args"} + %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"} - // CHECK-NEXT: %0 = "xla.tanh"(%arg0) {name = "tanh.3"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: %0 = "xla_hlo.tanh"(%arg0) {name = "tanh.3"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> // CHECK-NEXT: return %0 : tensor<1x16x16x3xf32> ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"} -} +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/translate/transpose.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/transpose.hlotxt index 335e54669eb..203152d1ca4 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/transpose.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/transpose.hlotxt @@ -6,7 +6,7 @@ HloModule main ENTRY %main { %Arg_0.1 = s32[1,2,3,4] parameter(0) - // CHECK-NEXT: %0 = "xla.transpose"(%arg0) {name = "transpose.2", permutation = 
dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + // CHECK-NEXT: %0 = "xla_hlo.transpose"(%arg0) {name = "transpose.2", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> // CHECK-NEXT: return %0 : tensor<2x1x4x3xi32> ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2} } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir b/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir index e28d0a37d84..77048e6c902 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir @@ -5,7 +5,7 @@ func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { // CHECK-NEXT: %Arg_0.1 = s32[1,2,3,4] parameter(0) // CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2} - %0 = "xla.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> return %0 : tensor<2x1x4x3xi32> } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt index c98fa93fcd9..bcaf1c81982 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt @@ -7,10 +7,10 @@ ENTRY %main(Arg_0.1: s32[1], Arg_1.2: f32[1, 2]) -> (s32[1], f32[1,2]) { %Arg_0.1 = s32[1] parameter(0) %Arg_1.2 = f32[1, 2] parameter(1) - // CHECK-NEXT: %0 = "xla.tuple"(%arg0) {name = "tuple.3"} : (tensor<1xi32>) -> tuple> + // CHECK-NEXT: %0 = "xla_hlo.tuple"(%arg0) {name = "tuple.3"} : (tensor<1xi32>) -> tuple> %tuple.3 = (s32[1]) tuple(%Arg_0.1) - // CHECK-NEXT: %1 = "xla.tuple"(%arg0, %arg1) {name = "tuple.4"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> + // CHECK-NEXT: %1 = "xla_hlo.tuple"(%arg0, %arg1) {name = "tuple.4"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> // CHECK-NEXT: return %1 : tuple, tensor<1x2xf32>> ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt new file mode 100644 index 00000000000..2db52dd9023 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/types.hlotxt @@ -0,0 +1,44 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule tfcompile.1 + +// CHECK-LABEL: func @main() -> tensor { +ENTRY %tfcompile.1 { + // CHECK-NEXT: %cst = constant {name = "constant.0"} dense<1.000000e+00> : tensor + %constant.0 = f32[] constant(1) + + // CHECK-NEXT: %cst_0 = constant {name = "constant.1"} dense<1.000000e+00> : tensor + %constant.1 = f64[] constant(1) + + // CHECK-NEXT: %cst_1 = constant {name = "constant.2"} dense<1> : tensor + %constant.2 = s8[] constant(1) + + // CHECK-NEXT: %cst_2 = constant {name = "constant.3"} dense<1> : tensor + %constant.3 = s16[] constant(1) + + // CHECK-NEXT: %cst_3 = constant {name = "constant.4"} dense<1> : tensor + %constant.4 = s32[] constant(1) + + // CHECK-NEXT: %cst_4 = constant {name = "constant.5"} dense<1> : tensor + %constant.5 = s64[] constant(1) + + // TODO(b/130356985): Update once MLIR supports unsigned integers. 
+ // CHECK-NEXT: %cst_5 = constant {name = "constant.6"} dense<1> : tensor + %constant.6 = u8[] constant(1) + + // TODO(b/130356985): Update once MLIR supports unsigned integers. + // CHECK-NEXT: %cst_6 = constant {name = "constant.7"} dense<1> : tensor + %constant.7 = u16[] constant(1) + + // TODO(b/130356985): Update once MLIR supports unsigned integers. + // CHECK-NEXT: %cst_7 = constant {name = "constant.8"} dense<1> : tensor + %constant.8 = u32[] constant(1) + + // TODO(b/130356985): Update once MLIR supports unsigned integers. + // CHECK-NEXT: %cst_8 = constant {name = "constant.9"} dense<1> : tensor + %constant.9 = u64[] constant(1) + + // CHECK-NEXT: %cst_9 = constant {name = "constant.10"} dense : tensor + // CHECK-NEXT: return %cst_9 : tensor + ROOT %constant.10 = pred[] constant(1) +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt index 42d52fd78c8..daf7dd8d01d 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt @@ -6,6 +6,6 @@ HloModule main ENTRY %main (Arg_0.1: f32[1, 4], Arg_1.2: f32[4, 1]) -> f32[1] { %Arg_0.1 = f32[1] parameter(0) - // CHECK-NEXT: %0 = "xla.unknown"(%arg0, %arg0) {name = "add-dependency.2"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // CHECK-NEXT: %0 = "xla_hlo.unknown"(%arg0, %arg0) {name = "add-dependency.2"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> ROOT add-dependency.2 = f32[1] add-dependency(Arg_0.1, Arg_0.1) } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt index a6d2a48797e..784ad891111 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt @@ -4,24 +4,24 @@ HloModule foo // CHECK-LABEL: func @cond(%arg0: tensor) -> tensor { %cond (arg_1: s64[]) -> pred[] { - %arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"} - // CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor + %arg_1 = s64[] parameter(0), metadata={op_name="HLO_Args"} + // CHECK-NEXT: %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor // CHECK-NEXT: return %0 : tensor ROOT %compare.2 = pred[] compare(%arg_1, %arg_1), direction=LT, metadata={op_type="Less" op_name="Less"} } // CHECK-LABEL: func @loop(%arg0: tensor) -> tensor { %loop (arg_1: s64[]) -> s64[] { - %arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"} - // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg0) {name = "compare.0"} : (tensor, tensor) -> tensor + %arg_1 = s64[] parameter(0), metadata={op_name="HLO_Args"} + // CHECK-NEXT: %0 = "xla_hlo.add"(%arg0, %arg0) {name = "compare.0"} : (tensor, tensor) -> tensor // CHECK-NEXT: return %0 : tensor ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"} } // CHECK-LABEL: func @main(%arg0: tensor) -> tensor { ENTRY %foo (arg0.1: s64[]) -> s64[] { - %arg0.1 = s64[] parameter(0), metadata={op_name="XLA_Args"} - // CHECK-NEXT: %0 = "xla.while"(%arg0) {body = @loop, cond = @cond} : (tensor) -> tensor + %arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"} + // CHECK-NEXT: %0 = "xla_hlo.while"(%arg0) {body = @loop, cond = @cond} : (tensor) -> tensor // CHECK-NEXT: return %0 : tensor ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond -} +} \ No newline 
at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/transpose.mlir b/tensorflow/compiler/mlir/xla/tests/transpose.mlir new file mode 100644 index 00000000000..0ed7e709ed4 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/transpose.mlir @@ -0,0 +1,29 @@ +// RUN: tf-opt %s -split-input-file -canonicalize | FileCheck %s --dump-input=fail + +// CHECK-LABEL: func @remove_noop +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> { + %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> + // CHECK-NEXT: return [[ARG]] + return %0 : tensor<2x3x9x5xi32> +} + +// ----- + +// CHECK-LABEL: func @keep_real_transpose +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { + // CHECK-NEXT: "xla_hlo.transpose"([[ARG]]) + %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> + return %0 : tensor<3x2x5x9xi32> +} + +// ----- + +// CHECK-LABEL: func @keep_same_shape_real_transpose +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> { + // CHECK-NEXT: "xla_hlo.transpose"([[ARG]]) + %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32> + return %0 : tensor<4x4xi32> +} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc index cf271f42814..b40c89c1f8c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_control_flow.cc @@ -16,6 +16,7 @@ limitations under the License. // This file implements logic for lowering XLA dialect to Standard dialect. #include "llvm/ADT/StringSwitch.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Block.h" // TF:local_config_mlir #include "mlir/IR/BlockAndValueMapping.h" // TF:local_config_mlir #include "mlir/IR/Builders.h" // TF:local_config_mlir @@ -23,28 +24,27 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" using mlir::PassRegistration; namespace mlir { -namespace XLA { +namespace xla_hlo { namespace { struct LegalizeControlFlow : public mlir::FunctionPass { // Perform the lowering to MLIR control flow. void runOnFunction() override; }; -bool LowerWhileOp(mlir::XLA::WhileOp while_op) { +bool LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { // Converts an xla while loop into control flow. This mostly generates the // right MLIR boilerplate for calling the body / condition functions, then // branching on their results appropriately. 
The operation should look similar // to below: // // - // %0 = "xla.while"(%arg0) {body: @loop, cond: @cond} + // %0 = "xla_hlo.while"(%arg0) {body: @loop, cond: @cond} // auto* opInst = while_op.getOperation(); mlir::OpBuilder builder(while_op); @@ -147,9 +147,14 @@ void LegalizeControlFlow::runOnFunction() { } } } // namespace -} // namespace XLA +} // namespace xla_hlo } // namespace mlir -static PassRegistration legalize_cf_pass( +std::unique_ptr +mlir::xla_hlo::createLegalizeControlFlowPass() { + return std::make_unique(); +} + +static PassRegistration legalize_cf_pass( "xla-legalize-control-flow", "Legalize from XLA control flow to MLIR control flow"); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index a10329cea06..00c9c238f1e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -17,10 +17,11 @@ limitations under the License. #include +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" using namespace mlir; @@ -32,7 +33,9 @@ struct LegalizeTF : public FunctionPass { }; } // end anonymous namespace -FunctionPassBase *mlir::XLA::createLegalizeTFPass() { return new LegalizeTF(); } +std::unique_ptr mlir::xla_hlo::createLegalizeTFPass() { + return std::make_unique(); +} /// Returns if the given TF data format string is the default format. static bool isDefaultDataFormat(StringRef format) { return format == "NHWC"; } @@ -127,11 +130,11 @@ static ElementsAttr getBroadcastDimensionsAttr(Builder &b, Value *x, Value *y) { } namespace mlir { -namespace XLA { +namespace xla { namespace { #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" } // end anonymous namespace -} // end namespace XLA +} // end namespace xla } // end namespace mlir /// Perform the lowering to XLA dialect. @@ -140,8 +143,8 @@ void LegalizeTF::runOnFunction() { auto func = getFunction(); // Add the generated patterns to the list. - XLA::populateWithGenerated(func.getContext(), &patterns); - applyPatternsGreedily(func, std::move(patterns)); + xla::populateWithGenerated(func.getContext(), &patterns); + applyPatternsGreedily(func, patterns); } static PassRegistration pass( diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 7835fcf9213..1730e5374a4 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -16,9 +16,9 @@ limitations under the License. // This is the legalization pattern definition file for TF to XLA. 
include "mlir/IR/OpBase.td" -include "mlir/StandardOps/Ops.td" +include "mlir/Dialect/StandardOps/Ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "tensorflow/compiler/mlir/xla/ir/xla_ops.td" +include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" def NullElementsAttr : NativeCodeCall<"ElementsAttr()">; @@ -30,17 +30,19 @@ def FeatureDimension : NativeCodeCall< "getFeatureDimensionAttr($_builder, $0, $1)">; def FalseBoolAttr : AttrConstraint>; -def : Pattern<(TF_FusedBatchNormOp F32Tensor:$x, F32Tensor:$scale, - F32Tensor:$offset, F32Tensor:$mean, - F32Tensor:$variance, F32Attr:$epsilon, +def : Pattern< + (TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon, $data_format, FalseBoolAttr:$is_training), - [(XLA_BatchNormInferenceOp $x, $scale, $offset, $mean, $variance, - $epsilon, (FeatureDimension $data_format, $x)), - /*batch_mean=*/(verifyUnusedValue), - /*batch_variance=*/(verifyUnusedValue), - /*reserve_space_1=*/(verifyUnusedValue), - /*reserve_space_2=*/(verifyUnusedValue) - ]>; + [(HLO_BatchNormInferenceOp $x, $scale, $offset, $mean, $variance, + $epsilon, (FeatureDimension $data_format, $x)), + // We already guaranteed that the last four results have no use, so it + // does not matter what value we provide here for replacement. + /*batch_mean=*/(replaceWithValue $x), + /*batch_variance=*/(replaceWithValue $x), + /*reserve_space_1=*/(replaceWithValue $x), + /*reserve_space_2=*/(replaceWithValue $x)], + [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2), + (HasNoUseOf:$root__3), (HasNoUseOf:$root__4)]>; //===----------------------------------------------------------------------===// // Bias op patterns. @@ -60,7 +62,7 @@ def ValidBiasAddFeatureDimension : Constraint< def : Pat<(TF_BiasAddOp IsAtleast3DShapeTensor:$input, Is1DShapeTensor:$bias, TF_ConvnetDataFormatAttr:$data_format), - (XLA_AddOp $input, $bias, + (HLO_AddOp $input, $bias, (BiasAddFeatureDimension $data_format, $input)), [(ValidBiasAddFeatureDimension $data_format, $input, $bias)]>; @@ -76,11 +78,12 @@ class DirectBinaryPat : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; -foreach fromToBinPair = [[TF_AddOp, XLA_AddOp], - [TF_DivOp, XLA_DivOp], - [TF_MulOp, XLA_MulOp], - [TF_RealDivOp, XLA_DivOp], - [TF_SubOp, XLA_SubOp]] in +foreach fromToBinPair = [[TF_AddOp, HLO_AddOp], + [TF_AddV2Op, HLO_AddOp], + [TF_DivOp, HLO_DivOp], + [TF_MulOp, HLO_MulOp], + [TF_RealDivOp, HLO_DivOp], + [TF_SubOp, HLO_SubOp]] in def : DirectBinaryPat; //===----------------------------------------------------------------------===// @@ -94,7 +97,7 @@ def : Pat<(TF_IdentityOp $op), (replaceWithValue $op)>; //===----------------------------------------------------------------------===// // TODO(riverriddle) Formalize a policy on converting opaque attributes. 
-def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (XLA_ConstOp $value), +def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (HLO_ConstOp $value), [(AnyStaticShapeTensor $res)]>; //===----------------------------------------------------------------------===// @@ -105,11 +108,11 @@ class ConstantSplat : NativeCodeCall< "getSplat($_builder, $0, " # value # ")">; def : Pat<(TF_ReluOp AnyTensor:$input), - (XLA_MaxOp (ConstantOp (ConstantSplat<"0"> $input)), $input, + (HLO_MaxOp (ConstantOp (ConstantSplat<"0"> $input)), $input, (NullElementsAttr))>; def : Pat<(TF_Relu6Op AnyTensor:$input), - (XLA_ClampOp (ConstantOp (ConstantSplat<"0"> $input)), $input, + (HLO_ClampOp (ConstantOp (ConstantSplat<"0"> $input)), $input, (ConstantOp (ConstantSplat<"6"> $input)))>; //===----------------------------------------------------------------------===// @@ -117,7 +120,7 @@ def : Pat<(TF_Relu6Op AnyTensor:$input), //===----------------------------------------------------------------------===// def : Pat<(TF_ReshapeOp:$res AnyStaticShapeTensor:$arg, $ignored), - (XLA_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>; + (HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>; def : Pat<(TF_SqueezeOp AnyStaticShapeTensor:$arg, $ignored_dims), - (XLA_ReshapeOp $arg)>; + (HLO_ReshapeOp $arg)>; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc index 4ac42d39f06..934e9f91820 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard.cc @@ -16,11 +16,11 @@ limitations under the License. // This file implements logic for lowering XLA dialect to Standard dialect. #include "llvm/ADT/StringSwitch.h" +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir #include "mlir/IR/Function.h" // TF:local_config_mlir #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir #include "mlir/Pass/Pass.h" // TF:local_config_mlir -#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir -#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" using mlir::Builder; @@ -30,13 +30,13 @@ using mlir::OwningRewritePatternList; using mlir::PassRegistration; namespace mlir { -namespace XLA { +namespace xla_hlo { namespace { #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_to_standard.inc" struct CompareIConvert : public RewritePattern { explicit CompareIConvert(MLIRContext *context) - : RewritePattern("xla.compare", 1, context) {} + : RewritePattern("xla_hlo.compare", 1, context) {} PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { @@ -75,7 +75,7 @@ struct CompareIConvert : public RewritePattern { struct CompareFConvert : public RewritePattern { explicit CompareFConvert(MLIRContext *context) - : RewritePattern("xla.compare", 1, context) {} + : RewritePattern("xla_hlo.compare", 1, context) {} PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { @@ -113,7 +113,7 @@ struct CompareFConvert : public RewritePattern { }; } // end anonymous namespace -} // end namespace XLA +} // end namespace xla_hlo } // end namespace mlir namespace { @@ -123,8 +123,9 @@ struct LegalizeToStandard : public FunctionPass { }; } // end anonymous namespace -FunctionPassBase *mlir::XLA::createLegalizeToStdPass() { - return new LegalizeToStandard(); +std::unique_ptr 
+mlir::xla_hlo::createLegalizeToStdPass() { + return std::make_unique(); } /// Perform the lowering to standard dialect. @@ -132,12 +133,11 @@ void LegalizeToStandard::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); - mlir::XLA::populateWithGenerated(func.getContext(), &patterns); - patterns.push_back( - llvm::make_unique(&getContext())); - patterns.push_back( - llvm::make_unique(&getContext())); - applyPatternsGreedily(func, std::move(patterns)); + mlir::xla_hlo::populateWithGenerated(func.getContext(), &patterns); + patterns + .insert( + &getContext()); + applyPatternsGreedily(func, patterns); } static PassRegistration legalize_pass( diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td index 5f03ee6e70d..d0925cc9fb7 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_to_standard_patterns.td @@ -16,8 +16,8 @@ limitations under the License. // This is the legalization pattern definition file for XLA to StandardOps. include "mlir/IR/OpBase.td" -include "mlir/StandardOps/Ops.td" -include "tensorflow/compiler/mlir/xla/ir/xla_ops.td" +include "mlir/Dialect/StandardOps/Ops.td" +include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td" //===----------------------------------------------------------------------===// // Binary op patterns. @@ -28,37 +28,36 @@ def IsSameSizePred : CPred< "== $1->getType().cast().getShape()">; def IsSameSizeConstraint : Constraint; -def : Pat<(XLA_AddOp XLA_FpTensor:$l, XLA_FpTensor:$r, +def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r, IsNullAttr:$broadcast_dimensions), (AddFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(XLA_SubOp XLA_FpTensor:$l, XLA_FpTensor:$r, +def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r, IsNullAttr:$broadcast_dimensions), (SubFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(XLA_MulOp XLA_FpTensor:$l, XLA_FpTensor:$r, +def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r, IsNullAttr:$broadcast_dimensions), (MulFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(XLA_DivOp XLA_FpTensor:$l, XLA_FpTensor:$r, +def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r, IsNullAttr:$broadcast_dimensions), (DivFOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(XLA_AddOp XLA_IntTensor:$l, XLA_IntTensor:$r, +def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r, IsNullAttr:$broadcast_dimensions), (AddIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(XLA_SubOp XLA_IntTensor:$l, XLA_IntTensor:$r, +def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r, IsNullAttr:$broadcast_dimensions), (SubIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(XLA_MulOp XLA_IntTensor:$l, XLA_IntTensor:$r, +def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r, IsNullAttr:$broadcast_dimensions), (MulIOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; -def : Pat<(XLA_DivOp XLA_IntTensor:$l, XLA_IntTensor:$r, +def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r, IsNullAttr:$broadcast_dimensions), (DivISOp $l, $r), [(IsSameSizeConstraint $l, $r)]>; - diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 2ed045396e7..3eb97dd6a0f 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -16,18 +16,23 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_ #define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_ +#include + namespace mlir { class FunctionPassBase; -namespace XLA { +namespace xla_hlo { /// Lowers from TF dialect to XLA dialect. -FunctionPassBase *createLegalizeTFPass(); +std::unique_ptr createLegalizeTFPass(); -// Lowers from XLA dialect to Standard dialect. -FunctionPassBase *createLegalizeToStdPass(); +/// Lowers XLA control flow ops to the Standard dialect. +std::unique_ptr createLegalizeControlFlowPass(); -} // end namespace XLA +/// Lowers from XLA dialect to Standard dialect. +std::unique_ptr createLegalizeToStdPass(); + +} // end namespace xla_hlo } // end namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index 40c896fef9c..e64182889cb 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include "absl/base/integral_types.h" #include "mlir/IR/AffineMap.h" // TF:local_config_mlir #include "mlir/IR/Diagnostics.h" // TF:local_config_mlir #include "mlir/IR/Location.h" // TF:local_config_mlir @@ -25,11 +24,13 @@ limitations under the License. #include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" using mlir::IntegerType; using mlir::MemRefType; using mlir::RankedTensorType; using mlir::VectorType; +using tensorflow::int64; using xla::PrimitiveType; using xla::ShapeUtil; diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc index 9a77be947d5..57922fe1532 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -15,20 +15,48 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/type_to_shape.h" +#include + #include "mlir/IR/Builders.h" // TF:local_config_mlir #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/protobuf.h" using mlir::Builder; using mlir::MLIRContext; -using ::testing::EqualsProto; namespace xla { namespace { +// Simple implementation of a proto matcher comparing string representations. +// Only works as ShapeProto's textual representation is deterministic. 
+class ProtoStringMatcher { + public: + explicit ProtoStringMatcher(const tensorflow::protobuf::Message& expected) + : expected_(expected.SerializeAsString()) {} + + template + bool MatchAndExplain(const Message& p, testing::MatchResultListener*) const { + return p.SerializeAsString() == expected_; + } + + void DescribeTo(::std::ostream* os) const { *os << expected_; } + void DescribeNegationTo(::std::ostream* os) const { + *os << "not equal to expected message: " << expected_; + } + + private: + const std::string expected_; +}; + +inline ::testing::PolymorphicMatcher EqualsProto( + const tensorflow::protobuf::Message& x) { + return ::testing::MakePolymorphicMatcher(ProtoStringMatcher(x)); +} + TEST(TypeToShapeTest, ConvertPrimitiveTypes) { MLIRContext context; Builder b(&context); diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index 9804858c084..ad7e4724d90 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h" -#include "google/protobuf/text_format.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/IR/Module.h" // TF:local_config_mlir @@ -26,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/protobuf.h" using stream_executor::port::Status; using stream_executor::port::StatusOr; // NOLINT TODO(b/130822468) fix this @@ -34,13 +34,13 @@ namespace xla { namespace { // Error collector that simply ignores errors reported. 
-class NoOpErrorCollector : public ::proto2::io::ErrorCollector { +class NoOpErrorCollector : public tensorflow::protobuf::io::ErrorCollector { public: void AddError(int line, int column, const string& message) override {} }; bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) { - ::proto2::TextFormat::Parser parser; + tensorflow::protobuf::TextFormat::Parser parser; NoOpErrorCollector collector; parser.RecordErrorsTo(&collector); return hlo_proto->ParseFromString(contents) || @@ -114,8 +114,8 @@ static mlir::LogicalResult MlirHloToHloTranslateFunction( if (!module) return mlir::failure(); std::error_code error; - auto result = llvm::make_unique(output_filename, error, - llvm::sys::fs::F_None); + auto result = std::make_unique(output_filename, error, + llvm::sys::fs::F_None); if (error) { LOG(ERROR) << error.message(); return mlir::failure(); @@ -147,8 +147,8 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction( if (!module) return mlir::failure(); std::error_code error; - auto result = llvm::make_unique(output_filename, error, - llvm::sys::fs::F_None); + auto result = std::make_unique(output_filename, error, + llvm::sys::fs::F_None); if (error) { LOG(ERROR) << error.message(); return mlir::failure(); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 15bb0a863d1..307eb1d3213 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -4,7 +4,7 @@ load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites") load( - "//tensorflow/core:platform/default/build_config_root.bzl", + "//tensorflow/core/platform:default/build_config_root.bzl", "tf_cuda_tests_tags", ) @@ -36,7 +36,6 @@ py_library( srcs_version = "PY2AND3", visibility = [":friends"], deps = [ - "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", @@ -46,6 +45,7 @@ py_library( "//tensorflow/python:random_seed", "//tensorflow/python:session", "//tensorflow/python:variables", + "//tensorflow/python/compiler/xla:compiler_py", "//third_party/py/numpy", ], ) @@ -665,6 +665,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "matrix_diag_ops_test", + size = "medium", + timeout = "long", + srcs = ["matrix_diag_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "momentum_test", size = "small", @@ -1024,7 +1037,10 @@ tf_xla_py_test( name = "unary_ops_test", size = "medium", srcs = ["unary_ops_test.py"], - tags = ["notap"], # b/136030724 + tags = [ + "noguitar", # TODO(b/140174740): Re-enable when fixed. 
+ "notap", # b/136030724 + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1179,6 +1195,7 @@ cuda_py_test( "//tensorflow/python:framework", "//tensorflow/python:math_ops", ], + xla_enable_strict_auto_jit = False, ) cuda_py_test( @@ -1187,7 +1204,7 @@ cuda_py_test( srcs = ["jit_test.py"], additional_deps = [ ":test_utils", - "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/python/compiler/xla:compiler_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", @@ -1204,21 +1221,23 @@ cuda_py_test( "nogpu", "no_cuda_on_cpu_tap", ], + xla_enable_strict_auto_jit = False, ) cuda_py_test( name = "dense_layer_test", - size = "small", + size = "medium", srcs = ["dense_layer_test.py"], additional_deps = [ ":test_utils", - "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/python/compiler/xla:compiler_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:layers", "//tensorflow/python:variables", ], + xla_enable_strict_auto_jit = False, ) cc_library( @@ -1304,6 +1323,7 @@ cuda_py_test( "//tensorflow/python:platform", "//tensorflow/python:variables", ], + xla_enable_strict_auto_jit = False, ) # An example of ahead-of-time compilation using tfcompile. The @@ -1382,3 +1402,20 @@ tf_xla_py_test( "@absl_py//absl/testing:parameterized", ], ) + +tf_xla_py_test( + name = "conv_node_name_test", + size = "medium", + srcs = ["conv_node_name_test.py"], + shard_count = 5, + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:layers", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + "//tensorflow/python:nn_ops_gen", + "//tensorflow/python:platform_test", + ], +) diff --git a/tensorflow/compiler/tests/adagrad_da_test.py b/tensorflow/compiler/tests/adagrad_da_test.py index 369d0097a0f..e08435b5713 100644 --- a/tensorflow/compiler/tests/adagrad_da_test.py +++ b/tensorflow/compiler/tests/adagrad_da_test.py @@ -56,9 +56,9 @@ class AdagradDAOptimizerTest(xla_test.XLATestCase): # Run a step of AdagradDA update.run() - # Let g to be gradient accumulator, gg to be gradient squared - # accumulator, T be the global step, lr is the learning rate, and k the - # initial gradient squared accumulator value. + # Let g be the gradient accumulator, gg be the gradient squared + # accumulator, T be the global step, lr be the learning rate, + # and k the initial gradient squared accumulator value. # w = \dfrac{sign(-g)*lr*|g - l1*T|_{+}}{l2*T*lr + \sqrt{k+gg})} # For -0.1*3.0*(0.1 - 0)/(0 + sqrt(0.1 + 0.1*0.1)) = -0.904534 # similarly for others. 
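A quick plain-numpy sketch of the AdagradDA weight formula quoted in the adagrad_da_test.py comment above. It is independent of the TensorFlow optimizer implementation; the concrete values (lr = 3.0, l1 = l2 = 0, initial accumulator k = 0.1, a single step with gradient 0.1) are assumptions chosen to reproduce the -0.904534 figure, and |g - l1*T|_{+} is read as max(|g| - l1*T, 0).

import numpy as np

def adagrad_da_weight(g, gg, lr, l1, l2, t, k):
    # w = sign(-g) * lr * max(|g| - l1*t, 0) / (l2*t*lr + sqrt(k + gg)),
    # where g and gg are the accumulated gradient and squared gradient after
    # t global steps and k is the initial squared-gradient accumulator value.
    numerator = np.sign(-g) * lr * np.maximum(np.abs(g) - l1 * t, 0.0)
    return numerator / (l2 * t * lr + np.sqrt(k + gg))

# One step with gradient 0.1 under the assumed hyperparameters:
print(adagrad_da_weight(g=0.1, gg=0.01, lr=3.0, l1=0.0, l2=0.0, t=1, k=0.1))
# prints roughly -0.904534, matching the value quoted in the comment above.
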
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 0171be42148..14af571d62f 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops -from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops @@ -1464,53 +1463,6 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([4, 5, 6], dtype=np.int32), expected=None) - def testMatrixSetDiag(self): - # TODO(penporn): Once XLA supports MatrixSetDiagV2, change the call to - # gen_array_ops.matrix_set_diag (V1) to array_ops.matrix_set_diag (V2). - for dtype in self.numeric_types: - # Square - self._testBinary( - gen_array_ops.matrix_set_diag, - np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]], - dtype=dtype), - np.array([1.0, 2.0, 3.0], dtype=dtype), - expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]], - dtype=dtype)) - - self._testBinary( - gen_array_ops.matrix_set_diag, - np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]], - [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]], - dtype=dtype), - np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype), - expected=np.array( - [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]], - [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]], - dtype=dtype)) - - # Rectangular - self._testBinary( - gen_array_ops.matrix_set_diag, - np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype), - np.array([3.0, 4.0], dtype=dtype), - expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype)) - - self._testBinary( - gen_array_ops.matrix_set_diag, - np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype), - np.array([3.0, 4.0], dtype=dtype), - expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype)) - - self._testBinary( - gen_array_ops.matrix_set_diag, - np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]], - [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype), - np.array([[-1.0, -2.0], [-4.0, -5.0]], - dtype=dtype), - expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]], - [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]], - dtype=dtype)) - def testBroadcastTo(self): for dtype in self.all_types: x = np.random.randint(0, high=100, size=[2, 3]) diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 96d389a81f2..a3b17e42fb0 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -3,7 +3,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/tests:plugin.bzl", "plugins") load( - "//tensorflow/core:platform/default/build_config_root.bzl", + "//tensorflow/core/platform:default/build_config_root.bzl", "tf_cuda_tests_tags", "tf_exec_compatible_with", ) diff --git a/tensorflow/compiler/tests/conv_node_name_test.py b/tensorflow/compiler/tests/conv_node_name_test.py new file mode 100644 index 00000000000..85e8bce8617 --- /dev/null +++ b/tensorflow/compiler/tests/conv_node_name_test.py @@ -0,0 +1,115 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Convolution node name match via the XLA JIT. + +The canned results in these tests are created by running each test using the +Tensorflow CPU device and saving the output. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import ops +from tensorflow.python.layers import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.platform import googletest + + +class ConvolutionNodeNameTest(xla_test.XLATestCase): + """Verify convolution node name match. + + Verify convolution node names on TPU and CPU match with dilation > 1. + """ + + def _verifyNodeNameMatch(self, layer, input_sizes, filter_sizes, strides, + dilations): + + def _GetNodeNames(use_xla): + with self.session(): + input_tensor = array_ops.placeholder(np.float32, shape=input_sizes) + + if use_xla: + with self.test_scope(): + # pylint: disable=protected-access + graph = ops.get_default_graph() + graph._set_control_flow_context( + control_flow_ops.XLAControlFlowContext()) + # pylint: enable=protected-access + conv2d_op = layer( + filters=64, + kernel_size=filter_sizes, + dilation_rate=dilations, + padding="same") + _ = conv2d_op(input_tensor) + return [n.name for n in ops.get_default_graph().as_graph_def().node] + else: + with ops.device("CPU"): + conv2d_op = layer( + filters=64, + kernel_size=filter_sizes, + dilation_rate=dilations, + padding="same") + _ = conv2d_op(input_tensor) + names = [ + n.name for n in ops.get_default_graph().as_graph_def().node + ] + # filter out space to depth ops. 
+ return [ + name for name in names + if "space" not in name and "Space" not in name + ] + + xla_names = _GetNodeNames(use_xla=True) + no_xla_names = _GetNodeNames(use_xla=False) + self.assertListEqual( + xla_names, + no_xla_names, + ) + + def testConv1DNodeNameMatch(self): + input_sizes = [8, 16, 3] + filter_sizes = [7] + strides = 1 + dilations = [2] + layer = layers.Conv1D + self._verifyNodeNameMatch(layer, input_sizes, filter_sizes, strides, + dilations) + + def testConv2DNodeNameMatch(self): + input_sizes = [8, 16, 16, 3] + filter_sizes = [7, 7] + strides = 1 + dilations = [2, 2] + layer = layers.Conv2D + self._verifyNodeNameMatch(layer, input_sizes, filter_sizes, strides, + dilations) + + def testConv3DNodeNameMatch(self): + input_sizes = [8, 16, 16, 16, 3] + filter_sizes = [7, 7, 7] + strides = 1 + dilations = [2, 2, 2] + layer = layers.Conv3D + self._verifyNodeNameMatch(layer, input_sizes, filter_sizes, strides, + dilations) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 74f16292334..8020aa28ce4 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -22,8 +22,8 @@ import os import numpy as np from tensorflow.compiler.tests import test_utils -from tensorflow.contrib.compiler import jit from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.compiler.xla import jit from tensorflow.python.layers import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index c55bc23cf47..a49985f0446 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_impl from tensorflow.python.ops import nn_ops import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -87,6 +88,32 @@ def ConfigsToTest(): yield i, f, o, s, p +def ConfigsWithDilationsToTest(): + """Iterator for different convolution shapes, strides and paddings. + + Yields: + Tuple (input_size, filter_size, out_size, stride, dilation, padding), the + depthwise + convolution parameters. + """ + input_sizes = [[4, 6, 6, 48], [4, 8, 8, 84], [4, 36, 36, 2], [4, 148, 148, 2], + [3, 300, 300, 3]] + filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [5, 5, 2, 1], [4, 4, 2, 8], + [2, 2, 3, 8]] + out_sizes = [[4, 6, 6, 96], [4, 8, 8, 84], [4, 36, 36, 2], [4, 74, 74, 16], + [3, 296, 296, 24]] + strides = [1, 1, 2, 2, 1] + dilations = [2, 2, 4, 2, 4] + # pylint: disable=invalid-name + VALID = "VALID" + SAME = "SAME" + # pylint: enable=invalid-name + paddings = [SAME, SAME, SAME, SAME, VALID] + for i, f, o, s, d, p in zip(input_sizes, filter_sizes, out_sizes, strides, + dilations, paddings): + yield i, f, o, s, d, p + + def CheckGradConfigsToTest(): """Iterator for different convolution shapes, strides and paddings. @@ -315,6 +342,118 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): padding="VALID", expected=expected_output) + # This is testing that depthwise_conv2d with dilation produces + # the same results between CPU and TPU. 
It also tests that NCHW + # and NWHC formats agree. + def _VerifyValuesWithDilation(self, + tensor_in_sizes, + filter_in_sizes, + stride, + dilation, + padding, + data_type, + data_format="NHWC"): + """Verifies the output values of the convolution function. + + Args: + tensor_in_sizes: Input tensor dimensions in [batch, input_rows, + input_cols, input_depth]. + filter_in_sizes: Filter tensor dimensions in [filter_rows, filter_cols, + input_depth, depth_multiplier]. + stride: Stride. + dilation: Dilation. + padding: Padding type. + data_type: The data type to use. + data_format: The data_format of the input. "NHWC" or "NCHW". + """ + total_size_1 = 1 + total_size_2 = 1 + for s in tensor_in_sizes: + total_size_1 *= s + for s in filter_in_sizes: + total_size_2 *= s + # Initializes the input and filter tensor with numbers incrementing from 1. + x1 = np.array([f * 1.0 for f in range(1, total_size_1 + 1)], + dtype=data_type).reshape(tensor_in_sizes) + x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)], + dtype=data_type).reshape(filter_in_sizes) + with self.session() as sess: + if data_type == np.float32: + # TODO(b/64210055): Tolerance for TPU is high. + tolerance = 1e-2 + else: + self.assertEqual(data_type, np.float64) + tolerance = 1e-8 + + t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=data_type) + t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=data_type) + + native_t1 = t1 + strides = [1, stride, stride, 1] + dilations = [dilation, dilation] + if data_format == "NCHW": + # Transpose from NWHC input to NCHW + # Ex. [4, 5, 5, 48] to [4, 48, 5, 5] + native_t1 = array_ops.transpose(t1, [0, 3, 1, 2]) + strides = [1, 1, stride, stride] + + with self.test_scope(): + conv_native = nn_impl.depthwise_conv2d( + native_t1, + t2, + strides=strides, + rate=dilations, + data_format=data_format, + padding=padding) + + if data_format == "NCHW": + # Transpose back from NCHW to NHWC + conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1]) + + with ops.device("CPU"): + # CPU only support NHWC format + strides = [1, stride, stride, 1] + conv_interface = nn_impl.depthwise_conv2d( + t1, t2, strides=strides, rate=dilations, padding=padding) + + native_result = sess.run(conv_native, {t1: x1, t2: x2}) + interface_result = sess.run(conv_interface, {t1: x1, t2: x2}) + + print("data_type:", data_type, "max diff = ", + np.amax(np.absolute(native_result - interface_result))) + self.assertAllClose( + np.ravel(native_result), np.ravel(interface_result), rtol=tolerance) + + def testDilationDepthwiseConv2DWith(self): + for index, (input_size, filter_size, _, stride, dilation, + padding) in enumerate(ConfigsWithDilationsToTest()): + print("Testing DilationDepthwiseConv2D,", index, "th config:", input_size, + "*", filter_size, "stride:", stride, "dilation: ", dilation, + "padding:", padding) + for data_type in self.float_types: + # TODO(phawkins): the reference implementation only supports float32. + if data_type == np.float32: + self._VerifyValuesWithDilation(input_size, filter_size, stride, + dilation, padding, data_type) + + def testDilationDepthwiseConv2DWithFormat(self): + for index, (input_size, filter_size, _, stride, dilation, + padding) in enumerate(ConfigsWithDilationsToTest()): + print("Testing DilationDepthwiseConv2DFormat,", index, "th config:", + input_size, "*", filter_size, "stride:", stride, "dilation:", + dilation, "padding:", padding) + for data_type in self.float_types: + # TODO(phawkins): the reference implementation only supports float32. 
+ if data_type == np.float32: + self._VerifyValuesWithDilation( + input_size, + filter_size, + stride, + dilation, + padding, + data_type, + data_format="NCHW") + def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes, stride, padding): x1 = np.random.rand(*filter_sizes).astype(np.float32) @@ -420,5 +559,139 @@ class DepthwiseConv2DTest(xla_test.XLATestCase): padding, data_format="NCHW") + def _CompareBackpropInputWithDilation(self, input_sizes, filter_sizes, + output_sizes, stride, dilation, + padding): + x1 = np.random.rand(*filter_sizes).astype(np.float32) + x2 = np.random.rand(*output_sizes).astype(np.float32) + + def _GetVal(use_xla): + with self.session(): + t1 = array_ops.placeholder(np.float32, shape=filter_sizes) + t2 = array_ops.placeholder(np.float32, shape=output_sizes) + if use_xla: + with self.test_scope(): + t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)]) + backprop = nn_ops.depthwise_conv2d_native_backprop_input( + t0, + t1, + t2, + strides=[1, stride, stride, 1], + dilations=[1, dilation, dilation, 1], + padding=padding) + else: + # TODO(wangtao): figure out gradient with stride > 1. + # depthwise_conv2d_native_backprop_input on CPU doesn't support + # dilation. + t3 = array_ops.space_to_batch( + t2, block_size=dilation, paddings=[[0, 0], [0, 0]]) + input_sizes_transform = [ + input_sizes[0] * dilation * dilation, input_sizes[1] // dilation, + input_sizes[2] // dilation, input_sizes[3] + ] + t0 = constant_op.constant( + input_sizes_transform, shape=[len(input_sizes)]) + backprop_naive = nn_ops.depthwise_conv2d_native_backprop_input( + t0, t1, t3, strides=[1, stride, stride, 1], padding=padding) + backprop = array_ops.batch_to_space( + backprop_naive, [[0, 0], [0, 0]], block_size=dilation) + + ret = backprop.eval({t1: x1, t2: x2}) + self.assertShapeEqual(ret, backprop) + return ret + + gpu_value = _GetVal(use_xla=True) + cpu_value = _GetVal(use_xla=False) + + # TODO (b/64210055): Tolerance for TPU is high. + self.assertAllClose(cpu_value, gpu_value, rtol=1e-2, atol=1e-3) + + def testDilationDepthwiseConv2DInputGradWithCompare(self): + for index, (input_size, filter_size, output_size, stride, dilation, + padding) in enumerate(ConfigsWithDilationsToTest()): + print("Testing DilationDepthwiseConv2DInputGradWithDilationCompare,", + index, "th config:", input_size, "*", filter_size, "stride:", + stride, "dilation:", dilation, "padding:", padding) + # TODO(wangtao): implement CPU grad computation with stride > 1. + if stride == 1: + self._CompareBackpropInputWithDilation(input_size, filter_size, + output_size, stride, dilation, + padding) + + def _CompareBackpropFilterWithDilation(self, + input_sizes, + filter_sizes, + output_sizes, + stride, + dilation, + padding, + data_format="NHWC"): + x0 = np.random.rand(*input_sizes).astype(np.float32) + x2 = np.random.rand(*output_sizes).astype(np.float32) + + def _GetVal(use_xla): + with self.session(): + t0 = array_ops.placeholder(np.float32, shape=input_sizes) + t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)]) + t2 = array_ops.placeholder(np.float32, shape=output_sizes) + native_t0 = t0 + native_t2 = t2 + strides = [1, stride, stride, 1] + dilations = [1, dilation, dilation, 1] + + if use_xla: + if data_format == "NCHW": + # Transpose from NWHC input to NCHW + # Ex. 
[4, 5, 5, 48] to [4, 48, 5, 5] + native_t0 = array_ops.transpose(t0, [0, 3, 1, 2]) + native_t2 = array_ops.transpose(t2, [0, 3, 1, 2]) + strides = [1, 1, stride, stride] + dilations = [1, 1, dilation, dilation] + with self.test_scope(): + backprop = nn_ops.depthwise_conv2d_native_backprop_filter( + native_t0, + t1, + native_t2, + strides=strides, + padding=padding, + dilations=dilations, + data_format=data_format) + else: + # For CPU, the format NCHW is not supported. Therefore we always use + # NHWC here. + # depthwise_conv2d_native_backprop_filter on CPU doesn't support + # dilation. + native_t3 = array_ops.space_to_batch( + native_t2, block_size=dilation, paddings=[[0, 0], [0, 0]]) + native_t0_transform = array_ops.space_to_batch( + native_t0, block_size=dilation, paddings=[[0, 0], [0, 0]]) + backprop = nn_ops.depthwise_conv2d_native_backprop_filter( + native_t0_transform, + t1, + native_t3, + strides=strides, + padding=padding) + ret = backprop.eval({t0: x0, t2: x2}) + self.assertShapeEqual(ret, backprop) + return ret + + gpu_value = _GetVal(use_xla=True) + cpu_value = _GetVal(use_xla=False) + # TODO(b/64210055): Tolerance for TPU is high. + self.assertAllClose(cpu_value, gpu_value, rtol=1e-3, atol=1e-4) + + def testDilationDepthwiseConv2DFilterGradCompare(self): + for index, (input_size, filter_size, output_size, stride, dilation, + padding) in enumerate(ConfigsWithDilationsToTest()): + print("Testing DilationDepthwiseConv2DFilterGradCompare,", index, + "th config:", input_size, "*", filter_size, "producing output", + output_size, "stride:", stride, "dilation:", dilation, "padding:", + padding) + if stride == 1: + # TODO(wangtao): implement CPU grad computation with stride > 1. + self._CompareBackpropFilterWithDilation(input_size, filter_size, + output_size, stride, dilation, + padding) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index d2c459bf1ec..a03980f20ba 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -693,8 +693,7 @@ class EagerFunctionTest(xla_test.XLATestCase): return x, y wholly_compiled_f = def_function.function(f) - op_by_op_f = function.defun_with_attributes( - f, attributes={'_XlaCompile': False}) + op_by_op_f = def_function.function(f, experimental_compile=False) x = constant_op.constant([0.0, 2.0], name='data') diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index fb4b2711905..5889a011296 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -514,6 +514,27 @@ class ResizeNearestNeighborTest(xla_test.XLATestCase): [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9]], dtype=np.float32)) + def testAlignCorners3x3To12x12_uint8(self): + # TODO(b/72099414): enable the test for TPU when the issue is fixed. 
+ if (self.device not in ["XLA_GPU", "XLA_CPU"]): + return + # Ensure that resize with convolution works on XLA/GPU for integer types + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.uint8), [12, 12], + expected=np.array([[1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9]], + dtype=np.uint8)) + class ResizeBilinearTest(parameterized.TestCase, xla_test.XLATestCase): diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 29444c19014..109a7932c20 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -22,10 +22,10 @@ import os import numpy as np from tensorflow.compiler.tests import test_utils -from tensorflow.contrib.compiler import jit from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as session_lib +from tensorflow.python.compiler.xla import jit from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function diff --git a/tensorflow/compiler/tests/matrix_diag_ops_test.py b/tensorflow/compiler/tests/matrix_diag_ops_test.py new file mode 100644 index 00000000000..6437c2749af --- /dev/null +++ b/tensorflow/compiler/tests/matrix_diag_ops_test.py @@ -0,0 +1,655 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for XLA matrix diag ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.compat import compat +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +# Test cases shared by MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2. 
+# Copied from //third_party/tensorflow/python/kernel_tests/diag_op_test.py +def square_cases(): + # pyformat: disable + mat = np.array([[[1, 2, 3, 4, 5], + [6, 7, 8, 9, 1], + [3, 4, 5, 6, 7], + [8, 9, 1, 2, 3], + [4, 5, 6, 7, 8]], + [[9, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 1], + [2, 3, 4, 5, 6]]]) + tests = dict() + # tests[d_lower, d_upper] = (compact_diagonals, padded_diagnals) + tests[-1, -1] = (np.array([[6, 4, 1, 7], + [5, 2, 8, 5]]), + np.array([[[0, 0, 0, 0, 0], + [6, 0, 0, 0, 0], + [0, 4, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 7, 0]], + [[0, 0, 0, 0, 0], + [5, 0, 0, 0, 0], + [0, 2, 0, 0, 0], + [0, 0, 8, 0, 0], + [0, 0, 0, 5, 0]]])) + tests[-4, -3] = (np.array([[[8, 5], + [4, 0]], + [[6, 3], + [2, 0]]]), + np.array([[[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [8, 0, 0, 0, 0], + [4, 5, 0, 0, 0]], + [[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [6, 0, 0, 0, 0], + [2, 3, 0, 0, 0]]])) + tests[-2, 1] = (np.array([[[2, 8, 6, 3, 0], + [1, 7, 5, 2, 8], + [6, 4, 1, 7, 0], + [3, 9, 6, 0, 0]], + [[1, 7, 4, 1, 0], + [9, 6, 3, 9, 6], + [5, 2, 8, 5, 0], + [1, 7, 4, 0, 0]]]), + np.array([[[1, 2, 0, 0, 0], + [6, 7, 8, 0, 0], + [3, 4, 5, 6, 0], + [0, 9, 1, 2, 3], + [0, 0, 6, 7, 8]], + [[9, 1, 0, 0, 0], + [5, 6, 7, 0, 0], + [1, 2, 3, 4, 0], + [0, 7, 8, 9, 1], + [0, 0, 4, 5, 6]]])) + tests[2, 4] = (np.array([[[5, 0, 0], + [4, 1, 0], + [3, 9, 7]], + [[4, 0, 0], + [3, 9, 0], + [2, 8, 5]]]), + np.array([[[0, 0, 3, 4, 5], + [0, 0, 0, 9, 1], + [0, 0, 0, 0, 7], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + [[0, 0, 2, 3, 4], + [0, 0, 0, 8, 9], + [0, 0, 0, 0, 5], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]])) + # pyformat: enable + return (mat, tests) + + +def tall_cases(): + # pyformat: disable + mat = np.array([[[1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [9, 8, 7], + [6, 5, 4]], + [[3, 2, 1], + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [9, 8, 7]]]) + tests = dict() + # tests[d_lower, d_upper] = (compact_diagonals, padded_diagnals) + tests[0, 0] = (np.array([[1, 5, 9], + [3, 2, 6]]), + np.array([[[1, 0, 0], + [0, 5, 0], + [0, 0, 9], + [0, 0, 0]], + [[3, 0, 0], + [0, 2, 0], + [0, 0, 6], + [0, 0, 0]]])) + tests[-4, -3] = (np.array([[[9, 5], + [6, 0]], + [[7, 8], + [9, 0]]]), + np.array([[[0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [9, 0, 0], + [6, 5, 0]], + [[0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [7, 0, 0], + [9, 8, 0]]])) + tests[-2, -1] = (np.array([[[4, 8, 7], + [7, 8, 4]], + [[1, 5, 9], + [4, 8, 7]]]), + np.array([[[0, 0, 0], + [4, 0, 0], + [7, 8, 0], + [0, 8, 7], + [0, 0, 4]], + [[0, 0, 0], + [1, 0, 0], + [4, 5, 0], + [0, 8, 9], + [0, 0, 7]]])) + tests[-2, 1] = (np.array([[[2, 6, 0], + [1, 5, 9], + [4, 8, 7], + [7, 8, 4]], + [[2, 3, 0], + [3, 2, 6], + [1, 5, 9], + [4, 8, 7]]]), + np.array([[[1, 2, 0], + [4, 5, 6], + [7, 8, 9], + [0, 8, 7], + [0, 0, 4]], + [[3, 2, 0], + [1, 2, 3], + [4, 5, 6], + [0, 8, 9], + [0, 0, 7]]])) + tests[1, 2] = (np.array([[[3, 0], + [2, 6]], + [[1, 0], + [2, 3]]]), + np.array([[[0, 2, 3], + [0, 0, 6], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]], + [[0, 2, 1], + [0, 0, 3], + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]]])) + # pyformat: enable + return (mat, tests) + + +def fat_cases(): + # pyformat: disable + mat = np.array([[[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 1, 2, 3]], + [[4, 5, 6, 7], + [8, 9, 1, 2], + [3, 4, 5, 6]]]) + tests = dict() + # tests[d_lower, d_upper] = (compact_diagonals, padded_diagnals) + tests[0, 0] = (np.array([[1, 6, 2], + [4, 9, 5]]), + np.array([[[1, 0, 0, 0], + [0, 6, 0, 0], + [0, 0, 2, 0]], + [[4, 0, 0, 0], + [0, 9, 0, 0], + [0, 0, 
5, 0]]])) + tests[2, 2] = (np.array([[3, 8], + [6, 2]]), + np.array([[[0, 0, 3, 0], + [0, 0, 0, 8], + [0, 0, 0, 0]], + [[0, 0, 6, 0], + [0, 0, 0, 2], + [0, 0, 0, 0]]])) + tests[-2, 0] = (np.array([[[1, 6, 2], + [5, 1, 0], + [9, 0, 0]], + [[4, 9, 5], + [8, 4, 0], + [3, 0, 0]]]), + np.array([[[1, 0, 0, 0], + [5, 6, 0, 0], + [9, 1, 2, 0]], + [[4, 0, 0, 0], + [8, 9, 0, 0], + [3, 4, 5, 0]]])) + tests[-1, 1] = (np.array([[[2, 7, 3], + [1, 6, 2], + [5, 1, 0]], + [[5, 1, 6], + [4, 9, 5], + [8, 4, 0]]]), + np.array([[[1, 2, 0, 0], + [5, 6, 7, 0], + [0, 1, 2, 3]], + [[4, 5, 0, 0], + [8, 9, 1, 0], + [0, 4, 5, 6]]])) + tests[0, 3] = (np.array([[[4, 0, 0], + [3, 8, 0], + [2, 7, 3], + [1, 6, 2]], + [[7, 0, 0], + [6, 2, 0], + [5, 1, 6], + [4, 9, 5]]]), + np.array([[[1, 2, 3, 4], + [0, 6, 7, 8], + [0, 0, 2, 3]], + [[4, 5, 6, 7], + [0, 9, 1, 2], + [0, 0, 5, 6]]])) + # pyformat: enable + return (mat, tests) + + +class MatrixDiagTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, + params, + solution, + rtol=1e-3, + atol=1e-5): + """Verifies that matrix_diag produces `solution` when fed `params`. + + Args: + params: dictionary containing input parameters to matrix_diag. + solution: numpy array representing the expected output of matrix_diag. + rtol: relative tolerance for equality test. + atol: absolute tolerance for equality test. + """ + diagonal = params["diagonal"] + with self.session() as session: + for dtype in self.numeric_types - {np.int8, np.uint8}: + expected = solution.astype(dtype) + with self.test_scope(): + params["diagonal"] = array_ops.placeholder( + dtype, diagonal.shape, name="diagonal") + output = array_ops.matrix_diag(**params) + result = session.run(output, + {params["diagonal"]: diagonal.astype(dtype)}) + self.assertEqual(output.dtype, expected.dtype) + self.assertAllCloseAccordingToType( + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + + # Generic tests applicable to both v1 and v2 ops. + # Originally from unary_ops_tests.py. + def testV1(self): + # pyformat: disable + vecs1 = np.array([[1, 2], + [3, 4]]) + solution1 = np.array([[[1, 0], [0, 2]], + [[3, 0], [0, 4]]]) + vecs2 = np.array([1, 2, 3, 4]) + solution2 = np.array([[1, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]]) + vecs3 = np.array([[[1, 2, 3], + [4, 5, 6]], + [[7, 8, 9], # pylint: disable=bad-whitespace + [10, 11, 12]]]) + solution3 = np.array([[[[1, 0, 0], + [0, 2, 0], + [0, 0, 3]], + [[4, 0, 0], + [0, 5, 0], + [0, 0, 6]]], + [[[7, 0, 0], + [0, 8, 0], + [0, 0, 9]], + [[10, 0, 0], + [0, 11, 0], + [0, 0, 12]]]]) + # pyformat: enable + self._assertOpOutputMatchesExpected({"diagonal": vecs1}, solution1) + self._assertOpOutputMatchesExpected({"diagonal": vecs2}, solution2) + self._assertOpOutputMatchesExpected({"diagonal": vecs3}, solution3) + + # From here onwards are v2-only tests. 
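The fixtures above pair each band `(d_lower, d_upper)` with a compact form (diagonals stacked into rows, upper-most diagonal first, shorter diagonals zero-padded) and a padded form (the full matrix with zeros off the band). As a rough reference for what the op is expected to produce for a single diagonal index, here is a minimal NumPy sketch; `matrix_diag_reference` is a hypothetical helper written only for illustration and is not part of this test file:

```python
import numpy as np

def matrix_diag_reference(diagonal, k, num_rows, num_cols, padding_value=0):
  """Places a 1-D `diagonal` onto band `k` of a num_rows x num_cols matrix."""
  out = np.full((num_rows, num_cols), padding_value, dtype=diagonal.dtype)
  for i, v in enumerate(diagonal):
    # k >= 0 is a superdiagonal (shift the column), k < 0 a subdiagonal
    # (shift the row).
    r, c = (i, i + k) if k >= 0 else (i - k, i)
    out[r, c] = v
  return out

print(matrix_diag_reference(np.array([1, 2, 3]), k=1, num_rows=4, num_cols=4,
                            padding_value=9))
# [[9 1 9 9]
#  [9 9 2 9]
#  [9 9 9 3]
#  [9 9 9 9]]
```

The v2 tests below drive `array_ops.matrix_diag` with the same keyword arguments exercised here (`diagonal`, `k`, `num_rows`, `num_cols`, `padding_value`) and compare the result against the padded fixtures.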
+ def testSquare(self): + # LINT.IfChange + if compat.forward_compatible(2019, 8, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for _, tests in [square_cases()]: + for diag_index, (vecs, solution) in tests.items(): + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs[0], + "k": diag_index + }, solution[0]) + + def testSquareBatch(self): + # LINT.IfChange + if compat.forward_compatible(2019, 8, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for _, tests in [square_cases()]: + for diag_index, (vecs, solution) in tests.items(): + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs, + "k": diag_index + }, solution) + + def testRectangularBatch(self): + # LINT.IfChange + if not compat.forward_compatible(2019, 8, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + return + + # Stores expected num_rows and num_cols (when the other is given). + # expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols) + test_list = list() + + # Square cases: + expected = { + (-1, -1): (5, 4), + (-4, -3): (5, 2), + (-2, 1): (5, 5), + (2, 4): (3, 5), + } + test_list.append((expected, square_cases())) + + # Tall cases + expected = { + (0, 0): (3, 3), + (-4, -3): (5, 2), + (-2, -1): (4, 3), + (-2, 1): (3, 3), + (1, 2): (2, 3) + } + test_list.append((expected, tall_cases())) + + # Fat cases + expected = { + (2, 2): (2, 4), + (-2, 0): (3, 3), + (-1, 1): (3, 3), + (0, 3): (3, 3) + } + test_list.append((expected, fat_cases())) + + # Giving both num_rows and num_cols + for _, tests in [tall_cases(), fat_cases()]: + for diag_index, (vecs, solution) in tests.items(): + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs, + "k": diag_index, + "num_rows": solution.shape[-2], + "num_cols": solution.shape[-1] + }, solution) + + # Giving just num_rows or num_cols. + for expected, (_, tests) in test_list: + for diag_index, (new_num_rows, new_num_cols) in expected.items(): + vecs, solution = tests[diag_index] + solution_given_num_rows = solution.take( + indices=range(new_num_cols), axis=-1) + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs, + "k": diag_index, + "num_rows": solution_given_num_rows.shape[-2] + }, solution_given_num_rows) + solution_given_num_cols = solution.take( + indices=range(new_num_rows), axis=-2) + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs, + "k": diag_index, + "num_cols": solution_given_num_cols.shape[-1] + }, solution_given_num_cols) + + def testPadding(self): + # LINT.IfChange + if compat.forward_compatible(2019, 8, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for padding_value in [555, -11]: + for _, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (vecs, solution) in tests.items(): + mask = (solution == 0) + solution = solution + (mask * padding_value) + self._assertOpOutputMatchesExpected( + { + "diagonal": vecs, + "k": diag_index, + "num_rows": solution.shape[-2], + "num_cols": solution.shape[-1], + "padding_value": padding_value + }, solution) + + +class MatrixSetDiagTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, + params, + solution, + rtol=1e-3, + atol=1e-5): + """Verifies that matrix_set_diag produces `solution` when fed `params`. + + Args: + params: dictionary containing input parameters to matrix_set_diag. + solution: numpy array representing the expected output of matrix_set_diag. + rtol: relative tolerance for equality test. + atol: absolute tolerance for equality test. 
+ """ + input = params["input"] # pylint: disable=redefined-builtin + diagonal = params["diagonal"] + with self.session() as session: + for dtype in self.numeric_types - {np.int8, np.uint8}: + expected = solution.astype(dtype) + with self.test_scope(): + params["input"] = array_ops.placeholder( + dtype, input.shape, name="input") + params["diagonal"] = array_ops.placeholder( + dtype, diagonal.shape, name="diagonal") + output = array_ops.matrix_set_diag(**params) + result = session.run( + output, { + params["input"]: input.astype(dtype), + params["diagonal"]: diagonal.astype(dtype) + }) + self.assertEqual(output.dtype, expected.dtype) + self.assertAllCloseAccordingToType( + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + + # Generic tests applicable to both v1 and v2 ops. + # Originally from binary_ops_tests.py. + def testV1(self): + test_cases = list() + + # pyformat: disable + # pylint: disable=bad-whitespace + # Square cases. + input = np.array([[0, 1, 0], # pylint: disable=redefined-builtin + [1, 0, 1], + [1, 1, 1]]) + diag = np.array([1, 2, 3]) + solution = np.array([[1, 1, 0], + [1, 2, 1], + [1, 1, 3]]) + test_cases.append(({"input": input, "diagonal": diag}, solution)) + + input = np.array([[[1, 0, 3], + [0, 2, 0], + [1, 0, 3]], + [[4, 0, 4], + [0, 5, 0], + [2, 0, 6]]]) + diag = np.array([[-1, 0, -3], + [-4, -5, -6]]) + solution = np.array([[[-1, 0, 3], + [ 0, 0, 0], + [ 1, 0, -3]], + [[-4, 0, 4], + [ 0, -5, 0], + [ 2, 0, -6]]]) + test_cases.append(({"input": input, "diagonal": diag}, solution)) + + # Rectangular cases. + input = np.array([[0, 1, 0], + [1, 0, 1]]) + diag = np.array([3, 4]) + solution = np.array([[3, 1, 0], + [1, 4, 1]]) + test_cases.append(({"input": input, "diagonal": diag}, solution)) + + input = np.array([[0, 1], + [1, 0], + [1, 1]]) + diag = np.array([3, 4]) + solution = np.array([[3, 1], + [1, 4], + [1, 1]]) + test_cases.append(({"input": input, "diagonal": diag}, solution)) + + input = np.array([[[1, 0, 3], + [0, 2, 0]], + [[4, 0, 4], + [0, 5, 0]]]) + diag = np.array([[-1, -2], [-4, -5]]) + solution = np.array([[[-1, 0, 3], + [ 0, -2, 0]], + [[-4, 0, 4], + [ 0, -5, 0]]]) + test_cases.append(({"input": input, "diagonal": diag}, solution)) + # pylint: enable=bad-whitespace + # pyformat: enable + + for test in test_cases: + self._assertOpOutputMatchesExpected(test[0], test[1]) + + # From here onwards are v2-only tests. 
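For readers unfamiliar with the op, `matrix_set_diag` keeps the input matrix and only overwrites the requested diagonal band. A minimal NumPy sketch of the single-diagonal case, matching the first rectangular v1 case above; `matrix_set_diag_reference` is an illustrative name, not part of the test file:

```python
import numpy as np

def matrix_set_diag_reference(matrix, diagonal, k=0):
  """Returns a copy of `matrix` with diagonal `k` overwritten by `diagonal`."""
  out = np.array(matrix, copy=True)
  for i, v in enumerate(diagonal):
    r, c = (i, i + k) if k >= 0 else (i - k, i)
    out[r, c] = v
  return out

mat = np.array([[0, 1, 0],
                [1, 0, 1]])
print(matrix_set_diag_reference(mat, np.array([3, 4])))
# [[3 1 0]
#  [1 4 1]]
```

The v2 tests below build their expected outputs the same way in spirit: a mask marks the off-band entries of the padded fixture, and the solution is `input * mask + banded`.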
+ def testSingleMatrix(self): + # LINT.IfChange + if compat.forward_compatible(2019, 8, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for _, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (vecs, banded_mat) in tests.items(): + mask = (banded_mat[0] == 0) + input_mat = np.random.randint(10, size=mask.shape) + solution = input_mat * mask + banded_mat[0] + self._assertOpOutputMatchesExpected( + { + "input": input_mat, + "diagonal": vecs[0], + "k": diag_index + }, solution) + + def testBatch(self): + # LINT.IfChange + if compat.forward_compatible(2019, 8, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for _, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (vecs, banded_mat) in tests.items(): + mask = (banded_mat == 0) + input_mat = np.random.randint(10, size=mask.shape) + solution = input_mat * mask + banded_mat + self._assertOpOutputMatchesExpected( + { + "input": input_mat, + "diagonal": vecs, + "k": diag_index + }, solution) + + +class MatrixDiagPartTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, + params, + solution, + rtol=1e-3, + atol=1e-5): + """Verifies that matrix_diag_part produces `solution` when fed `params`. + + Args: + params: dictionary containing input parameters to matrix_diag_part. + solution: numpy array representing the expected output. + rtol: relative tolerance for equality test. + atol: absolute tolerance for equality test. + """ + input = params["input"] # pylint: disable=redefined-builtin + with self.session() as session: + for dtype in self.numeric_types - {np.int8, np.uint8}: + expected = solution.astype(dtype) + with self.test_scope(): + params["input"] = array_ops.placeholder( + dtype, input.shape, name="input") + output = array_ops.matrix_diag_part(**params) + result = session.run(output, { + params["input"]: input.astype(dtype), + }) + self.assertEqual(output.dtype, expected.dtype) + self.assertAllCloseAccordingToType( + expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + + # Generic tests applicable to both v1 and v2 ops. + # Originally from unary_ops_tests.py. + def testV1(self): + matrices = np.arange(3 * 2 * 4).reshape([3, 2, 4]) + solution = np.array([[0, 5], [8, 13], [16, 21]]) + self._assertOpOutputMatchesExpected({"input": matrices}, solution) + + # From here onwards are v2-only tests. 
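`matrix_diag_part` is the inverse direction: it reads a diagonal band out of a matrix into the compact layout used by `square_cases`/`tall_cases`/`fat_cases`. A sketch for a single diagonal index, consistent with the v1 case above (`arange(3*2*4)` yielding `[[0, 5], [8, 13], [16, 21]]`); the helper name is hypothetical:

```python
import numpy as np

def matrix_diag_part_reference(matrix, k=0):
  """Extracts diagonal `k` of a 2-D `matrix` (k > 0: above the main diagonal)."""
  rows, cols = matrix.shape
  length = min(rows + min(k, 0), cols - max(k, 0))
  return np.array([matrix[i - min(k, 0), i + max(k, 0)] for i in range(length)])

mat = np.arange(8).reshape(2, 4)              # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(matrix_diag_part_reference(mat))        # [0 5]
print(matrix_diag_part_reference(mat, k=2))   # [2 7]
```

When a whole band `(d_lower, d_upper)` is requested, the v2 op stacks such per-index results into the compact form, zero-padding the shorter diagonals, which is exactly what the shared fixtures encode.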
+ def testSingleMatrix(self): + # LINT.IfChange + if compat.forward_compatible(2019, 8, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for mat, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (solution, _) in tests.items(): + self._assertOpOutputMatchesExpected({ + "input": mat[0], + "k": diag_index + }, solution[0]) + + def testBatch(self): + # LINT.IfChange + if compat.forward_compatible(2019, 8, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for mat, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (solution, _) in tests.items(): + self._assertOpOutputMatchesExpected({ + "input": mat, + "k": diag_index + }, solution) + + def testPadding(self): + # LINT.IfChange + if compat.forward_compatible(2019, 8, 31): + # LINT.ThenChange(//tensorflow/python/ops/array_ops.py) + for padding_value in [555, -11]: + for mat, tests in [square_cases(), tall_cases(), fat_cases()]: + for diag_index, (solution, _) in tests.items(): + mask = (solution == 0) + solution = solution + (mask * padding_value) + self._assertOpOutputMatchesExpected( + { + "input": mat, + "k": diag_index, + "padding_value": padding_value + }, solution) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/stateful_random_ops_test.py b/tensorflow/compiler/tests/stateful_random_ops_test.py index a54cd60cfd7..343969c40d7 100644 --- a/tensorflow/compiler/tests/stateful_random_ops_test.py +++ b/tensorflow/compiler/tests/stateful_random_ops_test.py @@ -278,10 +278,11 @@ class StatefulRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase): maxval = 1 if dtype.is_integer: maxval = 100 - x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy() + t = gen.uniform(shape=[n], maxval=maxval, dtype=dtype) + x = t.numpy().astype(float) if maxval > 1: # Normalize y to range [0, 1). - x = x.astype(float) / maxval + x = x / maxval # Tests that the values are distributed amongst 10 bins with equal # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with # p=0.05. This test is probabilistic and would be flaky if the random diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 8eba83e285d..6576e274300 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -86,9 +86,9 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): x = stateless.stateless_random_uniform( shape=[n], seed=seed_t, maxval=maxval, dtype=dtype) y = sess.run(x, {seed_t: [565656, 121212]}) - if maxval > 1: - # Normalize y to range [0, 1). - y = y.astype(float) / maxval + # Convert y to float and normalize its value to range [0, 1) when + # maxval != 1. + y = y.astype(float) / maxval # Tests that the values are distributed amongst 10 bins with equal # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with # p=0.05. 
This test is probabilistic and would be flaky if the random diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py index b24e807b034..7d2425ee205 100644 --- a/tensorflow/compiler/tests/tensor_list_ops_test.py +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os +from absl.testing import parameterized import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op @@ -29,7 +30,7 @@ from tensorflow.python.ops import list_ops from tensorflow.python.platform import test -class ListOpsTest(xla_test.XLATestCase): +class ListOpsTest(parameterized.TestCase, xla_test.XLATestCase): def testElementShape(self): with self.session() as sess, self.test_scope(): @@ -204,6 +205,20 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(t.shape.as_list(), [None]) self.assertAllEqual(t, [1.0, 2.0]) + @parameterized.named_parameters( + ("FlatList", [1.0, 2.0, 3.0], [], [0, 2], [1.0, 3.0]), + ("NestedList", [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0] + ], [2], [1], [[3.0, 4.0]]), + ("EmptyIndices", [1.0, 2.0, 3.0], [], [], []), + ) + def testGather(self, input_list, element_shape, indices, output): + with self.session(), self.test_scope(): + tensor_list = list_ops.tensor_list_from_tensor( + input_list, element_shape=element_shape) + gather_t = list_ops.tensor_list_gather( + tensor_list, indices, element_dtype=dtypes.float32) + self.assertAllEqual(gather_t, output) + def testStackWithUninitializedTensors(self): with self.session(), self.test_scope(): l = list_ops.tensor_list_reserve( @@ -224,6 +239,6 @@ class ListOpsTest(xla_test.XLATestCase): self.assertAllEqual(z, [0.0, 0.0]) if __name__ == "__main__": - os.environ['TF_XLA_FLAGS'] = ('--tf_xla_min_cluster_size=2 ' + - os.environ.get('TF_XLA_FLAGS', '')) + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_min_cluster_size=2 " + + os.environ.get("TF_XLA_FLAGS", "")) test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index bac30b63bf8..349dabbb393 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -27,7 +27,6 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import bitwise_ops -from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -108,31 +107,6 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([[-1, 1]], dtype=dtype), expected=np.array([[-1, 1]], dtype=dtype)) - # TODO(penporn): Once XLA supports MatrixDiagV2, change the call to - # gen_array_ops.matrix_diag* (V1) to array_ops.matrix_diag* (V2). 
- self._assertOpOutputMatchesExpected( - gen_array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype), - np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) - self._assertOpOutputMatchesExpected( - gen_array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype), - np.array( - [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], - dtype=dtype)) - self._assertOpOutputMatchesExpected( - gen_array_ops.matrix_diag, - np.array( - [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype), - np.array( - [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], [[4, 0, 0], [0, 5, 0], [ - 0, 0, 6 - ]]], [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], [[10, 0, 0], [0, 11, 0], - [0, 0, 12]]]], - dtype=dtype)) - self._assertOpOutputMatchesExpected( - gen_array_ops.matrix_diag_part, - np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype), - np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype)) - self._assertOpOutputMatchesExpected( array_ops.prevent_gradient, np.array([[-1, 1]], dtype=dtype), @@ -323,11 +297,12 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.tanh, - np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype), - expected=np.array( - [[0.76159418, 0.76159418, 0.76159418, 0.76159418], - [0.76159418, 0.96402758, 0.99505478, 0.99932933]], - dtype=dtype)) + np.array( + [[1, 2, 3, 4], [np.inf, -np.inf, np.nan, 20], [19, -19, 22, -22]], + dtype=dtype), + expected=np.array([[0.76159418, 0.96402758, 0.99505478, 0.99932933], + [1.0, -1.0, np.nan, 1.0], [1.0, -1.0, 1.0, -1.0]], + dtype=dtype)) self._assertOpOutputMatchesExpected( nn_ops.log_softmax, diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 59e46b06d68..d6e02ecc827 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -25,12 +25,12 @@ import re import numpy as np -from tensorflow.python.eager import context -from tensorflow.contrib.compiler import jit -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session +from tensorflow.python.compiler.xla import jit +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index bfaae215709..79afa0b82dd 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -17,19 +17,13 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_proto_library", ) load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") # Placeholder for Google-internal load statements. -# NOTE: we always assume that if_static returns "otherwise" list in open source. 
-load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "if_static", -) - package( default_visibility = ["//visibility:public"], licenses = ["notice"], # Apache 2.0 @@ -53,10 +47,10 @@ cc_library( alias( name = "tensorrt_lib", - actual = if_static( - "@local_config_tensorrt//:tensorrt", - ":tensorrt_stub", - ), + actual = select({ + "//tensorflow:oss": ":tensorrt_stub", + "//conditions:default": "@local_config_tensorrt//:tensorrt", + }), visibility = ["//visibility:private"], ) @@ -97,10 +91,17 @@ cc_library( ":utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@local_config_cuda//cuda:cuda_headers", + "//tensorflow/core:core_cpu_lib_no_ops", + "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:stream_executor", "//tensorflow/core:stream_executor_headers_lib", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/stream_executor/lib", ] + if_tensorrt([":tensorrt_lib"]) + tf_custom_op_library_additional_deps(), alwayslink = 1, ) @@ -168,8 +169,12 @@ tf_cuda_cc_test( ":trt_op_kernels", ":trt_op_libs", ":trt_resources", + ":trt_conversion", + ":utils", "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/cc:scope", "//tensorflow/core:framework", @@ -235,12 +240,10 @@ tf_custom_op_py_library( tf_cuda_library( name = "trt_resources", srcs = [ - "utils/calibration_resource.cc", "utils/trt_int8_calibrator.cc", "utils/trt_lru_cache.cc", ], hdrs = [ - "utils/calibration_resource.h", "utils/trt_int8_calibrator.h", "utils/trt_lru_cache.h", ], @@ -250,6 +253,9 @@ tf_cuda_library( ":utils", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_lite", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core:graph", + "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib_proto_parsing", ] + if_tensorrt([":tensorrt_lib"]), ) @@ -320,11 +326,13 @@ tf_cuda_library( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//tensorflow/core:gpu_runtime", "//tensorflow/core:graph", + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:devices", @@ -489,7 +497,10 @@ cc_library( srcs = ["utils/py_utils.cc"], hdrs = ["utils/py_utils.h"], copts = tf_copts(), - deps = if_tensorrt([":tensorrt_lib"]), + deps = if_tensorrt([ + ":tensorrt_lib", + "//tensorflow/stream_executor/platform:dso_loader", + ]), ) tf_py_wrap_cc( diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index fb5dda9953e..cd5c7d126c6 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -28,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/segment/segment.h" -#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" #include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" @@ -125,46 +124,37 @@ Status GetEngineInfo(const Graph* g, ++it) { const Node* node = *it; if (segment_nodes.count(node) == 0) continue; - auto node_device = node->requested_device(); - if (!node_device.empty()) { - // If device is set, it means device placement may have been done before, - // so we need to assign a device for the TRTEngineOp to maintain the - // invariance. - // If the device is CPU in this case, it tries to find the first available - // GPU and use it as the device. - DeviceNameUtils::ParsedName parsed_name; - const bool parse_succeeded = - DeviceNameUtils::ParseFullName(node_device, &parsed_name); - if (!parse_succeeded || (parse_succeeded && parsed_name.type == "CPU")) { - string msg; - if (!parse_succeeded) { - msg = StrCat("Failed to parse assigned device of node ", node->name(), - ". "); - } else { - msg = StrCat("Node ", node->name(), " was assigned to the CPU. "); - } - VLOG(1) << msg << "Attempting to place on GPU."; - TfGpuId tf_gpu_id; - PlatformGpuId platform_gpu_id; - std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId(); - if (tf_gpu_id.value() >= 0) { - parsed_name.type = "GPU"; - parsed_name.id = tf_gpu_id.value(); - segment_devices.insert(DeviceNameUtils::FullName( - parsed_name.job, parsed_name.replica, parsed_name.task, - parsed_name.type, parsed_name.id)); - } - } else { - segment_devices.insert(node_device); - } + + std::string device_name; + if (!node->requested_device().empty()) { + device_name = node->requested_device(); } else if (node->has_assigned_device_name()) { // It appears that nodes will not have assigned devices at this point in // execution. - segment_devices.insert(node->assigned_device_name()); + device_name = node->assigned_device_name(); } else { VLOG(2) << "Node " << node->name() << " neither have requested device nor assigned device"; } + + if (!device_name.empty()) { + // If device is set, it means device placement may have been done before, + // so we need to assign a device for the TRTEngineOp if the assigned + // device is a GPU device. + DeviceNameUtils::ParsedName parsed_name; + const bool parse_succeeded = + DeviceNameUtils::ParseFullName(device_name, &parsed_name); + if (!parse_succeeded) { + VLOG(1) << "Failed to parse " + << (node->requested_device().empty() ? "assigned" : "requested") + << " device " << device_name << " of node " << node->name(); + } else if (parsed_name.type != "GPU") { + VLOG(1) << "Node " << node->name() + << " was assigned to a non-GPU device " << device_name; + } else { + segment_devices.insert(device_name); + } + } subgraph_nodes.push_back(node); const int node_id = node->id(); @@ -269,8 +259,20 @@ Status GetEngineInfo(const Graph* g, << ") devices for the segment. Picking first one to continue."; info->device = *segment_devices.begin(); } else { - VLOG(1) << "No device is assigned to the segment. 
" - << "A device will be assigned during graph execution (inference)."; + TfGpuId tf_gpu_id; + PlatformGpuId platform_gpu_id; + std::tie(tf_gpu_id, platform_gpu_id) = GetFirstValidDeviceId(); + if (tf_gpu_id.value() >= 0) { + DeviceNameUtils::ParsedName parsed_name; + parsed_name.type = "GPU"; + parsed_name.has_type = true; + parsed_name.id = tf_gpu_id.value(); + parsed_name.has_id = true; + info->device = DeviceNameUtils::ParsedNameToString(parsed_name); + } else { + VLOG(1) << "No device is assigned to the segment. A device will be " + "assigned during graph execution (inference)."; + } } return Status::OK(); } @@ -325,8 +327,6 @@ Status CreateTRTNode(const ConversionParams& params, nvinfer1::IGpuAllocator* alloc, std::vector* engine_nodes) { const auto& info = infos.at(pos); - std::vector output_shape_protos; - std::vector input_shape_protos; std::vector input_shapes; std::vector inputs; std::vector input_nodes; @@ -360,25 +360,16 @@ Status CreateTRTNode(const ConversionParams& params, } else { // Data edges if (!conn.is_input_edge) { - // Set the shapes and data types of output edge. - TensorShapeProto out_shape; - // shape of the output node inside segment - conn.inside_shape.AsProto(&out_shape); - if (output_shape_protos.size() <= conn.port_number) { - output_shape_protos.resize(conn.port_number + 1); + // Set the data types of output edge. + if (out_types.size() <= conn.port_number) { out_types.resize(conn.port_number + 1); } - output_shape_protos.at(conn.port_number) = out_shape; out_types.at(conn.port_number) = conn.connection_type; } else { // Set the shapes and data types of input edge. - TensorShapeProto in_shape; - conn.outside_shape.AsProto(&in_shape); - if (input_shape_protos.size() <= conn.port_number) { - input_shape_protos.resize(conn.port_number + 1); + if (input_shapes.size() <= conn.port_number) { input_shapes.resize(conn.port_number + 1); } - input_shape_protos.at(conn.port_number) = in_shape; input_shapes.at(conn.port_number) = conn.outside_shape; // Shape must be fully defined (excluding batch dimension) for static // mode. @@ -440,8 +431,6 @@ Status CreateTRTNode(const ConversionParams& params, TrtUniquePtrType engine_data(engine->serialize()); segment_string = string(static_cast(engine_data->data()), engine_data->size()); - } else { - segment_string = info.segment_graph_def.SerializeAsString(); } string prec_string; @@ -461,15 +450,13 @@ Status CreateTRTNode(const ConversionParams& params, } NodeDef trt_node; + NameAttrList function; + function.set_name(StrCat(info.engine_name, "_native_segment")); Status status = - node_builder.Attr("input_shapes", input_shape_protos) - .Attr("output_shapes", output_shape_protos) + node_builder .Attr("static_engine", info.engine_type == EngineInfo::EngineType::TRTStatic) - .Attr("segment_funcdef_name", - params.use_function_backup - ? StrCat(info.engine_name, "_native_segment") - : "") + .Attr("segment_func", function) .Attr("serialized_segment", segment_string) .Attr("calibration_data", "") .Attr("max_cached_engines_count", info.maximum_cached_engines) @@ -538,103 +525,27 @@ Status CreateTRTNode(const ConversionParams& params, return Status::OK(); } -// Function to construct a funcdef from the segment and add it to the graph. 
-Status RegisterSegmentFunctionToFunctionLibrary(Graph* graph, - const GraphDef& segment, - const string& engine_name) { - Graph sgraph(graph->flib_def()); - GraphConstructorOptions gcopts; - TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(gcopts, segment, &sgraph)); - std::map io_nodes; - int num_inputs = 0; - for (auto n : sgraph.op_nodes()) { - if (absl::StartsWith(n->name(), kInputPHName)) { - num_inputs++; - io_nodes.insert({n->name(), n}); - } else if (absl::StartsWith(n->name(), kOutputPHName)) { - io_nodes.insert({n->name(), n}); - } - } - - for (int i = 0; i < num_inputs; ++i) { - auto name = StrCat(kInputPHName, i); - auto node = io_nodes[name]; - NodeDef nd; - NodeDefBuilder node_builder(StrCat(name, "_Arg"), - FunctionLibraryDefinition::kArgOp); - VLOG(1) << "Adding " << StrCat(name, "_Arg"); - TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) - .Attr("index", i) - .Finalize(&nd)); - Status s; - auto node_arg = sgraph.AddNode(nd, &s); - if (!s.ok()) { - LOG(ERROR) << "Couldn't add _Arg node for " << name; - } - for (auto edge : node->out_edges()) { - sgraph.AddEdge(node_arg, 0, edge->dst(), edge->dst_input()); - VLOG(1) << "Updating funcdef input " << node_arg->name() << ":" << 0 - << " - > " << edge->dst()->name() << ":" << edge->dst_input(); - if (!s.ok()) { - LOG(ERROR) << "Failed to update edge from " << node_arg->name() - << " to " << edge->dst()->name() << ":" << edge->dst_input(); - } - } - sgraph.RemoveNode(node); - } - - for (int i = 0; i < io_nodes.size() - num_inputs; ++i) { - auto name = StrCat(kOutputPHName, i); - auto node = io_nodes[name]; - NodeDef nd; - NodeDefBuilder node_builder(StrCat(name, "_Ret"), - FunctionLibraryDefinition::kRetOp); - auto edge = *(node->in_edges().begin()); - NodeDefBuilder::NodeOut nout(edge->src()->name(), edge->src_output(), - edge->src()->output_type(edge->src_output())); - VLOG(1) << " input " << nout.node << ":" << nout.index - << " dtype=" << DataTypeString(nout.data_type); - // nvcc complains that Input() is - // ambiguous, so do not use Input({nout}). - node_builder.Input(nout); - TF_RETURN_IF_ERROR(node_builder.Attr("T", node->output_type(0)) - .Attr("index", i) - .Finalize(&nd)); - if (VLOG_IS_ON(3)) { - VLOG(3) << nd.DebugString(); - } - Status s; - auto node_ret = sgraph.AddNode(nd, &s); - if (!s.ok()) { - LOG(ERROR) << "Couldn't add _Ret node for " << name; - } - VLOG(1) << "Update edge from " << edge->src()->name() << ":" - << edge->src_output() << " - > " << node_ret->name() << ":" << 0; - sgraph.AddEdge(edge->src(), edge->src_output(), node_ret, 0); - s = sgraph.UpdateEdge(edge->src(), edge->src_output(), node_ret, 0); - if (!s.ok()) { - LOG(ERROR) << "Failed to update edge from " << edge->src()->name() << ":" - << edge->src_output() << " - > " << node_ret->name() << ":" - << 0; - } - sgraph.RemoveNode(node); - } - FunctionDefLibrary fdeflib; - auto native_segment = fdeflib.add_function(); +Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def, + Graph* graph, const string& engine_name) { + Graph segment_graph(graph->flib_def()); + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), + segment_graph_def, &segment_graph)); + FunctionDefLibrary library; + auto segment_func = library.add_function(); TF_RETURN_IF_ERROR(GraphToFunctionDef( - sgraph, StrCat(engine_name, "_native_segment"), native_segment)); + segment_graph, StrCat(engine_name, "_native_segment"), segment_func)); // Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on // a GPU device as expected. 
Otherwise, some of the tensors of type DT_INT32 // would be on host if the op generating the tensor has host memory tag set. - (*native_segment - ->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr] + (*segment_func->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr] .set_b(true); if (VLOG_IS_ON(7)) { VLOG(7) << engine_name << " Function_Def "; - VLOG(7) << native_segment->DebugString(); + VLOG(7) << segment_func->DebugString(); } - VLOG(1) << "Adding funcdef to graphlib"; - TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdeflib)); + VLOG(1) << "Adding funcdef " << segment_func->signature().name() + << " to graphlib"; + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(library)); return Status::OK(); } @@ -691,16 +602,10 @@ std::pair GetDeviceAndAllocator(const ConversionParams& params, // Entry function from optimization pass. Status ConvertAfterShapes(const ConversionParams& params) { // Sanity checks. - if (params.precision_mode == TrtPrecisionMode::INT8) { - if (params.use_calibration && !params.use_function_backup) { - return errors::InvalidArgument( - "Calibration requires enabling fallback to TF function execution."); - } - } else { - if (params.use_calibration) { - return errors::InvalidArgument( - "Calibration with FP32 or FP16 is not supported."); - } + if (params.precision_mode != TrtPrecisionMode::INT8 && + params.use_calibration) { + return errors::InvalidArgument( + "Calibration with FP32 or FP16 is not supported."); } // Convert graphdef to graph. @@ -761,14 +666,14 @@ Status ConvertAfterShapes(const ConversionParams& params) { : EngineInfo::EngineType::TRTStatic); curr_engine.use_calibration = params.use_calibration; curr_engine.maximum_cached_engines = params.max_cached_engines; - if (params.use_function_backup) { - status = RegisterSegmentFunctionToFunctionLibrary( - &graph, curr_engine.segment_graph_def, curr_engine.engine_name); - if (!status.ok()) { - LOG(WARNING) << "Failed to register segment graphdef as a function " - << t << ": " << status; - continue; - } + + status = RegisterGraphToFunctionLibrary(curr_engine.segment_graph_def, + &graph, curr_engine.engine_name); + + if (!status.ok()) { + LOG(WARNING) << "Failed to register segment graphdef to the library " << t + << ": " << status; + continue; } engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong()); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index d7f1df5a102..9288829574e 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/graph_properties.h" @@ -46,8 +47,6 @@ struct ConversionParams { // maximum number of cached engines int max_cached_engines = 1; bool use_calibration = true; - // Whether to use function fallback for TRTEngineOp - bool use_function_backup = true; }; // Method to call from optimization pass @@ -57,6 +56,11 @@ Status ConvertAfterShapes(const ConversionParams& params); std::pair GetDeviceAndAllocator(const ConversionParams& params, const EngineInfo& engine); +// Helper method that registers `segment_graph` as a function to the function +// library in `graph`. 
+Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def, + Graph* graph, const string& engine_name); + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index c068c4cc06c..43f920b9ccc 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -29,7 +29,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" -#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/framework/node_def.pb.h" // NOLINT #include "tensorflow/core/framework/node_def_builder.h" @@ -40,6 +39,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -76,18 +76,15 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -// TODO(aaroey): put these constants into some class. -const char* const kInputPHName = "TensorRTInputPH_"; -const char* const kOutputPHName = "TensorRTOutputPH_"; +namespace convert { bool IsEngineInput(absl::string_view name) { - return absl::StartsWith(name, kInputPHName); + return absl::StartsWith(name, IONamePrefixes::kInputPHName); } bool IsEngineOutput(absl::string_view name) { - return absl::StartsWith(name, kOutputPHName); + return absl::StartsWith(name, IONamePrefixes::kOutputPHName); } -namespace convert { using absl::StrAppend; using absl::StrCat; @@ -620,7 +617,7 @@ bool AreDimsStaticWithDifferentSize(const nvinfer1::Dims& lhs, } static std::vector> CreateSamePadding( - const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel, + const nvinfer1::Dims& stride, const nvinfer1::Dims& kernel, const std::vector& input_dims) { std::vector> padding(input_dims.size()); CHECK_EQ(stride.nbDims, input_dims.size()); // TODO(jie): N+C? NC+? 
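The hunk above widens `CreateSamePadding` from `nvinfer1::DimsHW` to `nvinfer1::Dims`, so the same helper can serve the 3-D convolution support added later in this patch. For orientation, the per-dimension arithmetic involved is the standard TensorFlow SAME rule (output size rounded up, any shortfall padded with the extra element at the end, and the kernel enlarged by the dilation rate, as the Conv3D SAME-deconvolution path further down also computes). The Python below is only an illustrative sketch under that assumption, not the C++ helper itself, and `same_padding_1d` is a hypothetical name:

```python
def same_padding_1d(input_size, kernel_size, stride, dilation=1):
  """TF-style SAME padding for one spatial dimension, as (pad_before, pad_after)."""
  effective_kernel = kernel_size + (kernel_size - 1) * (dilation - 1)
  output_size = -(-input_size // stride)  # ceil division
  total = max((output_size - 1) * stride + effective_kernel - input_size, 0)
  before = total // 2
  return before, total - before

print(same_padding_1d(4, 3, 2))  # (0, 1): the odd leftover goes after the data
```

Asymmetric results like `(0, 1)` are why the Conv3D transpose path below rejects SAME padding whenever `padding[i].first != padding[i].second`.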
@@ -779,7 +776,9 @@ class TRT_TensorOrWeights::SimpleITensor : public nvinfer1::ITensor { nvinfer1::TensorFormats getAllowedFormats() const override { return 1; } - bool isShape() const override { return false; } + bool isShapeTensor() const override { return false; } + + bool isExecutionTensor() const override { return true; } #endif private: @@ -847,6 +846,30 @@ string TRT_TensorOrWeights::DebugString() const { return output; } +// Perform 5 dimensional reorder of data on CPU +// This is done once at convert time and does not affect GPU inference perf +// Example: reorder NDHWC (Tensorflow) -> NCDHW (TensorRT) +template +void Reorder5(const nvinfer1::Dims& shape, const T* idata, + const nvinfer1::Dims& istrides, T* odata, + const nvinfer1::Dims& ostrides) { + for (int k = 0; k < shape.d[0]; ++k) { + for (int c = 0; c < shape.d[1]; ++c) { + for (int d = 0; d < shape.d[2]; ++d) { + for (int r = 0; r < shape.d[3]; ++r) { + for (int s = 0; s < shape.d[4]; ++s) { + odata[k * ostrides.d[0] + c * ostrides.d[1] + d * ostrides.d[2] + + r * ostrides.d[3] + s * ostrides.d[4]] = + idata[k * istrides.d[0] + c * istrides.d[1] + + d * istrides.d[2] + r * istrides.d[3] + + s * istrides.d[4]]; + } + } + } + } + } +} + // TODO(jie): reorder4 & reorder2 should be merged? // TODO(aaroey): fix the order of parameters. template @@ -945,6 +968,67 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, } } +// Initialize a Dims object with arbitrary dimension +nvinfer1::Dims InitDimsN(std::initializer_list list) { + nvinfer1::Dims dim; + dim.nbDims = list.size(); + std::copy(list.begin(), list.end(), dim.d); + return dim; +} + +// Reorder 3D convolution weights from TF to TRT +void ReorderDRSCKToKCDRS(const TRT_ShapedWeights& iweights, + TRT_ShapedWeights* oweights, const int num_groups) { + DCHECK(iweights.TrtDType() == oweights->TrtDType()); + CHECK_EQ(iweights.size_bytes(), oweights->size_bytes()); + // K indexes over output channels, C over input channels, and R, S, D over the + // height, width, depth + const int d = iweights.shape_.d[0]; + const int r = iweights.shape_.d[1]; + const int s = iweights.shape_.d[2]; + // TRT requires GKcRS, while TF depthwise has RSCK where c=1, C=G + const int c = iweights.shape_.d[3] / num_groups; + const int k = iweights.shape_.d[4] * num_groups; + + VLOG(2) << "num_groups: " << num_groups << ", c: " << iweights.shape_.d[3] + << " becomes " << c << ", k: " << iweights.shape_.d[4] << " becomes " + << k << ", d: " << d << ", r: " << r << ", s: " << s; + + oweights->shape_.d[0] = iweights.shape_.d[4]; // k / num_groups; + oweights->shape_.d[1] = iweights.shape_.d[3]; // c * num_groups; + oweights->shape_.d[2] = d; + oweights->shape_.d[3] = r; + oweights->shape_.d[4] = s; + + nvinfer1::Dims shape = + InitDimsN({k, c, d, r, s}); // KCDRS shape (same as output) + + nvinfer1::Dims ostrides = + InitDimsN({c * d * r * s, d * r * s, r * s, s, + 1}); // Output = KCDRS = k*CDRS + c*DRS + d*RS + r*S + s + + nvinfer1::Dims istrides = + InitDimsN({1, k, r * s * c * k, s * c * k, + c * k}); // Input = DRSCK = k*1 + c*K + d*RSCK + r*SCK + s*CK + + switch (iweights.TrtDType()) { + case nvinfer1::DataType::kFLOAT: { + Reorder5(shape, static_cast(iweights.GetValues()), istrides, + static_cast(oweights->GetValues()), ostrides); + break; + } + case nvinfer1::DataType::kHALF: { + Reorder5(shape, static_cast(iweights.GetValues()), + istrides, static_cast(oweights->GetValues()), + ostrides); + break; + } + default: + LOG(FATAL) << "Unsupported type, expected fp32 or fp16 but got " + << 
DebugString(iweights.TrtDType()); + } +} + TRT_ShapedWeights TrtWeightStore::GetTempWeights(nvinfer1::DataType trt_dtype, const nvinfer1::Dims& dims) { TensorShape shape; @@ -1453,6 +1537,15 @@ bool IsClipOrRelu(const nvinfer1::ILayer* layer) { #endif } +bool IsAdd(const nvinfer1::ILayer* layer) { + if (layer->getType() != nvinfer1::LayerType::kELEMENTWISE) { + return false; + } + auto operation = + static_cast(layer)->getOperation(); + return operation == nvinfer1::ElementWiseOperation::kSUM; +} + } // namespace void Converter::MaybeApplyQuantizationRanges() { @@ -1508,11 +1601,25 @@ void Converter::MaybeApplyQuantizationRanges() { } } // Identify fused tensors. + // Conv+BiasAdd+Add+Activation(Clip or Relu), Conv+BiasAdd+Add, // Conv+BiasAdd+Activation(Clip or Relu), Conv+BiasAdd, // Conv+Activation(Clip or Relu) are fused. std::set fused_tensors; typedef std::function matcher; const std::vector>> fused_patterns = { + {"Fused Conv+Bias+Add+Activation", + { + IsConvolution, + IsScale, + IsAdd, + IsClipOrRelu, + }}, + {"Fused Conv+Bias+Add", + { + IsConvolution, + IsScale, + IsAdd, + }}, {"Fused Conv+Bias+Activation", { IsConvolution, @@ -2600,6 +2707,203 @@ Status ConvertConv2DBackpropInput(OpConverterParams* params) { return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/true); } +#if IS_TRT_VERSION_GE(6, 0, 0, 0) +Status ConvertConv3DHelper(OpConverterParams* params, int group, + bool is_conv3d_backprop_input = false) { + const int kNumDims = 5; + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + TRT_TensorOrWeights backprop_output_size; + nvinfer1::ITensor* tensor = nullptr; + if (is_conv3d_backprop_input) { + // In the case when Conv3dBackpropInput is used for conv3d_transpose, these + // inputs correspond to: output size, filter, and input. + TF_RETURN_IF_ERROR(CheckInputsWeights( + *params, + {{"input_sizes", true}, {"filter", true}, {"out_backprop", false}})); + backprop_output_size = inputs.at(0); + tensor = inputs.at(2).tensor(); + } else { + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", false}, {"filter", true}})); + tensor = inputs.at(0).tensor(); + } + TF_RETURN_IF_ERROR( + AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); + const TRT_ShapedWeights weights_drsck = inputs.at(1).weights(); + if (weights_drsck.shape_.nbDims != kNumDims) { + return errors::InvalidArgument("Conv3D expects kernel of dimension 5, at ", + node_def.name()); + } + TFAttrs attrs(node_def); + auto data_format = attrs.get("data_format"); + const bool is_ndhwc = (data_format == "NDHWC"); // Or NCDHW 01234 - > 02341 + const int d_index = is_ndhwc ? 1 : 2; + const int h_index = is_ndhwc ? 2 : 3; + const int w_index = is_ndhwc ? 3 : 4; + const int c_index = is_ndhwc ? 
4 : 1; + auto tf_dilations = attrs.get>("dilations"); + if (tf_dilations.size() != kNumDims) { + return errors::InvalidArgument( + "Convolution dilations field must specify 5 dimensions, at ", + node_def.name()); + } + if (tf_dilations[0] != 1 || tf_dilations[c_index] != 1) { + return errors::Unimplemented( + "Dilation rate must be 1 for batch and channel dimensions, at ", + node_def.name()); + } + + const nvinfer1::Dims3 dilation_dhw( + tf_dilations[d_index], tf_dilations[h_index], tf_dilations[w_index]); + if (is_conv3d_backprop_input && + (dilation_dhw.d[0] != 1 || dilation_dhw.d[1] != 1 || + dilation_dhw.d[2] != 1)) { + return errors::Unimplemented( + "Dilation with Conv3DBackpropInputV2 (conv3d_transpose) is not " + "supported", + ", at ", node_def.name()); + } + + const auto tf_stride = attrs.get>("strides"); + if (tf_stride.size() != kNumDims) { + return errors::InvalidArgument( + "Convolution strides field must specify 5 dimensions, at ", + node_def.name()); + } + if (tf_stride[0] != 1 || tf_stride[c_index] != 1) { + return errors::Unimplemented( + "Stride must be 1 for batch and channel dimensions, at ", + node_def.name()); + } + + const nvinfer1::Dims3 stride_dhw(tf_stride[d_index], tf_stride[h_index], + tf_stride[w_index]); + const auto tensor_dim = tensor->getDimensions(); + + // Asymmetric padding on Deconv not supported for now + if (is_conv3d_backprop_input && attrs.get("padding") == "SAME") { + const int tensor_c_idx = c_index - 1; + const int num_groups = (group == 0) ? tensor_dim.d[tensor_c_idx] : group; + + TRT_ShapedWeights weights = + params->weight_store->GetTempWeights(weights_drsck); + + nvinfer1::Dims3 effective_kernel_size( + weights.shape_.d[0] + + (weights.shape_.d[0] - 1) * (dilation_dhw.d[0] - 1), // D + weights.shape_.d[1] + + (weights.shape_.d[1] - 1) * (dilation_dhw.d[1] - 1), // R + weights.shape_.d[2] + + (weights.shape_.d[2] - 1) * (dilation_dhw.d[2] - 1) // S + ); + + const auto output_size_weights = + static_cast(backprop_output_size.weights().GetValues()); + const std::vector input_dims = {output_size_weights[d_index], + output_size_weights[h_index], + output_size_weights[w_index]}; + + const std::vector> padding = + CreateSamePadding(stride_dhw, effective_kernel_size, input_dims); + + if (padding[0].first != padding[0].second || + padding[1].first != padding[1].second || + padding[2].first != padding[2].second) { + return errors::Unimplemented( + "Asymmetric padding with Conv3DBackpropInputV2 (conv3d_transpose) is " + "not supported, at ", + node_def.name()); + } + } + + if (params->validation_only) + return Status::OK(); // Finished validation checks + + // Transpose to NCDHW (NCDHW is required for IConvLayer). + const bool need_transpose = is_ndhwc; + if (need_transpose) { + TF_RETURN_IF_ERROR( + params->converter->TransposeTensor(tensor, {0, 4, 1, 2, 3}, &tensor)); + } + + // group == 0 signifies that this is a depthwise convolution, so set + // num_groups to size of input's channel dim. For a non-depthwise conv, + // num_groups will be 1. + const int num_groups = (group == 0) ? tensor_dim.d[0] : group; + + // For conv, TF weights are DRSCK, and TRT expects KCDRS. + // For backprop, TF weights are DRSKC, and TRT expects KCDRS. + // Therefore, this reorder will work for both cases. + TRT_ShapedWeights weights = + params->weight_store->GetTempWeights(weights_drsck); + ReorderDRSCKToKCDRS(weights_drsck, &weights, num_groups); + TRT_ShapedWeights biases(weights.TrtDType()); + const int output_axis = is_conv3d_backprop_input ? 
1 : 0; + const int noutput = weights.shape_.d[output_axis] * num_groups; + nvinfer1::Dims3 kernel_size_drs(weights.shape_.d[2], // D + weights.shape_.d[3], // R + weights.shape_.d[4] // S + ); + + // Add convolution. + nvinfer1::ILayer* conv_layer = nullptr; + if (is_conv3d_backprop_input) { + nvinfer1::IDeconvolutionLayer* layer = + params->converter->network()->addDeconvolutionNd( + *tensor, noutput, kernel_size_drs, weights.GetTrtWeights(), + biases.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setStrideNd(stride_dhw); // change to nd set stride + + // TensorRT 5.1.3 added support for padding modes. + if (attrs.get("padding") == "SAME") { + VLOG(2) << "Using SAME padding"; + // SAME_UPPER means that post padding is preferred. + layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); + } + + layer->setName(node_def.name().c_str()); + layer->setNbGroups(num_groups); + conv_layer = layer; + } else { + nvinfer1::IConvolutionLayer* layer = + params->converter->network()->addConvolutionNd( + *tensor, noutput, kernel_size_drs, weights.GetTrtWeights(), + biases.GetTrtWeights()); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setStrideNd(stride_dhw); + + if (attrs.get("padding") == "SAME") { + VLOG(2) << "Using SAME padding"; + layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); + } + + layer->setName(node_def.name().c_str()); + layer->setNbGroups(num_groups); + layer->setDilationNd(dilation_dhw); + conv_layer = layer; + } + nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0); + + // Restore transpose. + if (need_transpose) { + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + output_tensor, {0, 2, 3, 4, 1}, &output_tensor)); + } + params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); + return Status::OK(); +} + +Status ConvertConv3D(OpConverterParams* params) { + return ConvertConv3DHelper(params, 1, /*is_conv3d_backprop_input=*/false); +} + +Status ConvertConv3DBackpropInputV2(OpConverterParams* params) { + return ConvertConv3DHelper(params, 1, /*is_conv3d_backprop_input=*/true); +} +#endif // #if IS_TRT_VERSION_GE(6, 0, 0, 0) + Status ConvertFusedConv2DBiasActivation(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; @@ -3908,6 +4212,7 @@ Status ConvertPad(OpConverterParams* params) { *tensor, pre_padding, post_padding); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); + params->converter->MarkQuantizationRangesAsInferrable(tensor, output_tensor); if (!legit_pad) { TF_RETURN_IF_ERROR(params->converter->TransposeTensor( @@ -5093,6 +5398,8 @@ static void RegisterValidatableOpConverters( (*registration)["Relu6"] = ConvertRelu6; (*registration)["Reshape"] = ConvertReshape; #if IS_TRT_VERSION_GE(6, 0, 0, 0) + (*registration)["Conv3D"] = ConvertConv3D; + (*registration)["Conv3DBackpropInputV2"] = ConvertConv3DBackpropInputV2; for (auto resize_mode : {"ResizeBilinear", "ResizeNearestNeighbor"}) { (*registration)[resize_mode] = ConvertResize; } @@ -5194,26 +5501,44 @@ Status ConvertGraphDefToEngine( } // Build the network - VLOG(1) << "Starting engine conversion "; + if (VLOG_IS_ON(1)) { + string mode_str; + TF_RETURN_IF_ERROR(TrtPrecisionModeToName(precision_mode, &mode_str)); + VLOG(1) << "Starting engine conversion, precision mode: " << mode_str; + } Converter converter(trt_network.get(), precision_mode, use_calibration); std::vector output_tensors; // Graph nodes are already topologically 
sorted during construction for (const auto& node_def : gdef.node()) { - string node_name = node_def.name(); - VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op(); - if (IsEngineInput(node_name) && (node_def.op() == "Placeholder")) { + const string& node_name = node_def.name(); + VLOG(2) << "Converting node " << node_name << ", op=" << node_def.op(); + if (IsEngineInput(node_name)) { int32 slot_number = -1; - if (!strings::safe_strto32( // non-absl ok - node_name.c_str() + strlen(kInputPHName), &slot_number)) { - return errors::InvalidArgument("Failed to parse slot number from ", - node_name); + string type_key; + if (node_def.op() == "Placeholder") { + if (!strings::safe_strto32( // non-absl ok + node_name.c_str() + strlen(IONamePrefixes::kInputPHName), + &slot_number)) { + return errors::InvalidArgument("Failed to parse slot number from ", + node_name); + } + type_key = "dtype"; + } else if (tensorflow::grappler::IsArg(node_def)) { + // Maybe remove the dependence on grappler and re-implement IsArg, + // which is pretty simple (but could change if new Arg nodes are added) + slot_number = node_def.attr().at("index").i(); + type_key = "T"; + } else { + return errors::InvalidArgument( + "Node ", node_name, + " with is neither Placeholder nor Arg, instead ", node_def.op()); } nvinfer1::DataType trt_dtype; nvinfer1::Dims trt_dims; int batch_size = -1; auto shape = input_shapes.at(slot_number); auto status = ValidateTensorProperties( - node_def.op(), node_def.attr().at("dtype").type(), shape, + node_def.op(), node_def.attr().at(type_key).type(), shape, /*validation_only=*/false, &trt_dtype, &trt_dims, &batch_size); if (!status.ok()) { const string error_message = @@ -5229,12 +5554,23 @@ Status ConvertGraphDefToEngine( // engines offline, by calling sess.run() and cache/serialize the engines. TF_RETURN_IF_ERROR( converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size)); - } else if (IsEngineOutput(node_name) && (node_def.op() == "Identity")) { + } else if (IsEngineOutput(node_name)) { int32 slot_number = -1; - if (!strings::safe_strto32( // non-absl ok - node_name.c_str() + strlen(kOutputPHName), &slot_number)) { - return errors::InvalidArgument("Failed to parse slot number from ", - node_name); + if (node_def.op() == "Identity") { + if (!strings::safe_strto32( // non-absl ok + node_name.c_str() + strlen(IONamePrefixes::kOutputPHName), + &slot_number)) { + return errors::InvalidArgument("Failed to parse slot number from ", + node_name); + } + } else if (tensorflow::grappler::IsRetval(node_def)) { + slot_number = node_def.attr().at("index").i(); + } else { + return errors::InvalidArgument( + "Node with name ", node_name, + " starting with IONamePrefixes::kOutputPHName is " + "neither Identity nor Retval, instead ", + node_def.op()); } // Get output type that TensorFlow expects TFAttrs attrs(node_def); @@ -5247,8 +5583,6 @@ Status ConvertGraphDefToEngine( output_tensors.at(slot_number) = {node_def.input(0), node_name, trt_dtype}; } else { - VLOG(2) << "Converting node: " << node_def.name() << " , " - << node_def.op(); TF_RETURN_IF_ERROR(converter.ConvertNode(node_def)); } } @@ -5303,7 +5637,8 @@ Status ConvertSegmentToGraphDef( // Add dummy input/output nodes to the segment graphdef. 
if (connection.is_input_edge) { - const string node_name = StrCat(kInputPHName, connection.port_number); + const string node_name = + StrCat(IONamePrefixes::kInputPHName, connection.port_number); if (marker_nodes.count(node_name)) { VLOG(1) << "Reusing input " << node_name << " for the edge " << connection.outside_node_name << ":" @@ -5313,16 +5648,18 @@ Status ConvertSegmentToGraphDef( } marker_nodes.insert(node_name); auto seg_node = segment_def->add_node(); - NodeDefBuilder builder(node_name, "Placeholder"); + NodeDefBuilder builder(node_name, "_Arg"); auto status = builder.Attr("shape", partial_shape) - .Attr("dtype", dtype) + .Attr("T", dtype) + .Attr("index", connection.port_number) .Finalize(seg_node); VLOG(1) << "Constructing input " << node_name << " for the edge " << connection.outside_node_name << ":" << connection.outside_port << " -> " << connection.inside_node_name << ":" << connection.inside_port; } else { - const string node_name = StrCat(kOutputPHName, connection.port_number); + const string node_name = + StrCat(IONamePrefixes::kOutputPHName, connection.port_number); if (marker_nodes.count(node_name)) { VLOG(1) << "Reusing output " << node_name << " for the edge " << connection.inside_node_name << ":" << connection.inside_port @@ -5332,9 +5669,10 @@ Status ConvertSegmentToGraphDef( } marker_nodes.insert(node_name); auto seg_node = segment_def->add_node(); - NodeDefBuilder builder(node_name, "Identity"); + NodeDefBuilder builder(node_name, "_Retval"); auto status = - builder + builder.Attr("T", dtype) + .Attr("index", connection.port_number) .Input(connection.inside_node_name, connection.inside_port, dtype) .Finalize(seg_node); VLOG(1) << "Constructing output " << node_name << " for the edge " @@ -5360,12 +5698,12 @@ Status ConvertSegmentToGraphDef( if (connection.is_control_edge() || !connection.is_input_edge) continue; auto snode = segment_def->mutable_node(old_to_new_id_map[connection.inside_id]); - const string placeholder_name = - StrCat(kInputPHName, connection.port_number); + const string arg_name = + StrCat(IONamePrefixes::kInputPHName, connection.port_number); VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port << " from " << snode->input(connection.inside_port) << " to " - << placeholder_name; - snode->set_input(connection.inside_port, placeholder_name); + << arg_name; + snode->set_input(connection.inside_port, arg_name); } std::set subgraph_node_names; for (const Node* node : subgraph_nodes) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index a6a7afe121e..9d475e25ff7 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -23,7 +23,6 @@ limitations under the License. #include #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" -#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" @@ -38,8 +37,6 @@ limitations under the License. namespace tensorflow { namespace tensorrt { -extern const char* const kInputPHName; -extern const char* const kOutputPHName; namespace convert { @@ -120,8 +117,8 @@ struct EngineInfo { bool use_calibration; }; -// Constructs a graphdef from the segment in the given graph. 
Adds placeholder -// nodes for input edges (InputPH_*) and identity nodes for output edges +// Constructs a graphdef from the segment in the given graph. Adds _Arg +// nodes for input edges (InputPH_*) and _Retval nodes for output edges // (OutputPH_*). This function needs to be called before TensorRT nodes // inserted in order to correctly get sizes from the original graph. // diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index b6a3587005c..84898108a4d 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -307,7 +307,9 @@ class FakeITensor : public nvinfer1::ITensor { nvinfer1::TensorFormats getAllowedFormats() const override { return 1; } - bool isShape() const override { return false; } + bool isShapeTensor() const override { return false; } + bool isExecutionTensor() const override { return true; } + #endif private: @@ -1158,7 +1160,7 @@ class ConvertGraphDefToEngineTest : public ::testing::Test { int batch_size = -1; for (const NodeDef& node : gdef.node()) { absl::string_view node_name(node.name()); - if (absl::ConsumePrefix(&node_name, kInputPHName)) { + if (absl::ConsumePrefix(&node_name, IONamePrefixes::kInputPHName)) { int port = -1; EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name(); if (input_shapes.size() < port + 1) input_shapes.resize(port + 1); @@ -1188,11 +1190,13 @@ class ConvertGraphDefToEngineTest : public ::testing::Test { TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) { Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName(StrCat(kInputPHName, 0)), DT_FLOAT, - ops::Placeholder::Shape({1, 1})); + auto input = + ops::Placeholder(s.WithOpName(StrCat(IONamePrefixes::kInputPHName, 0)), + DT_FLOAT, ops::Placeholder::Shape({1, 1})); auto output = ops::Identity(s.WithOpName("identity1"), input); output = ops::Identity(s.WithOpName("identity2"), output); - output = ops::Identity(s.WithOpName(StrCat(kOutputPHName, 0)), output); + output = ops::Identity(s.WithOpName(StrCat(IONamePrefixes::kOutputPHName, 0)), + output); // If the converter marks the input tensor as output tensor, the conversion // below will fail with: // > TensorRTOutputPH_0 cannot be both input and output @@ -1453,6 +1457,9 @@ class OpConverterTest : public ::testing::Test { return converter_->quantization_ranges_; } + void PropagateQuantizationRanges() { + converter_->PropagateQuantizationRanges(); + } std::unique_ptr converter_; protected: @@ -3971,6 +3978,340 @@ TEST_F(OpConverterTest, ConvertConv2D) { } } +#if IS_TRT_VERSION_GE(6, 0, 0, 0) +TEST_F(OpConverterTest, ConvertConv3D) { + // Get nodedef for Conv3D layer. 
+ auto get_conv3d_nodedef = + [](std::vector strides = {1, 1, 1, 1, 1}, string padding = "SAME", + string data_format = "NCDHW", + std::vector dilations = {1, 1, 1, 1, 1}, + bool is_conv3d_backprop_input = false) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT); + + if (is_conv3d_backprop_input) { + auto input_sizes = + ops::Placeholder(s.WithOpName("input_sizes"), DT_INT32); + ops::Conv3DBackpropInputV2::Attrs attrs = + ops::Conv3DBackpropInputV2::Attrs() + .DataFormat(data_format) + .Dilations(dilations); + auto conv3d = + ops::Conv3DBackpropInputV2(s.WithOpName("my_conv3d"), input_sizes, + filter, input, strides, padding, attrs); + return conv3d.operation.node()->def(); + } else { + ops::Conv3D::Attrs attrs = + ops::Conv3D::Attrs().DataFormat(data_format).Dilations(dilations); + auto conv3d = ops::Conv3D(s.WithOpName("my_conv3d"), input, filter, + strides, padding, attrs); + return conv3d.operation.node()->def(); + } + }; + + { + // Input is weights, should fail. + Reset(); + NodeDef node_def = get_conv3d_nodedef(); + + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"input\" for Conv3D must be a tensor, at my_conv3d"); + } + { + // Filter is tensor, should fail. + Reset(); + NodeDef node_def = get_conv3d_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("weights", {3, 3, 1, 1, 3, 3, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"filter\" for Conv3D must be a constant, at my_conv3d"); + } + { + // Filter is not 5D, should fail. + Reset(); + NodeDef node_def = get_conv3d_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Conv3D expects kernel of dimension 5, at my_conv3d"); + } + { + // Dilations is not 5D, should fail. + Reset(); + NodeDef node_def = + get_conv3d_nodedef({1, 1, 1, 1, 1}, "SAME", "NCDHW", {1, 1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights( + "weights", {3, 3, 1, 1, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}); // Dimensions, then values + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Convolution dilations field must specify 5 dimensions, at my_conv3d"); + } + { + // Dilation value is not 1 for channel, should fail. + Reset(); + NodeDef node_def = + get_conv3d_nodedef({1, 1, 1, 1, 1}, "SAME", "NCDHW", {1, 2, 1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation rate must be 1 for batch and channel " + "dimensions, at my_conv3d"); + } + { + // Dilation value is not 1 for channel (NDHWC), should fail. + Reset(); + NodeDef node_def = + get_conv3d_nodedef({1, 1, 1, 1, 1}, "SAME", "NDHWC", {1, 1, 1, 1, 2}); + AddTestTensor("input", {2, 3, 1}); + AddTestWeights("weights", {3, 3, 1, 1, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation rate must be 1 for batch and channel " + "dimensions, at my_conv3d"); + } + { + // Dilation + Conv3DBackpropInputV2, should fail. 
+ Reset(); + NodeDef node_def = get_conv3d_nodedef({1, 1, 1, 1, 1}, "SAME", "NDHWC", + {1, 1, 2, 1, 1}, true); + AddTestTensor("input", {2, 3, 1}); + AddTestWeights("weights", {3, 3, 1, 1, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}); + AddTestWeights("input_sizes", {4}, {1, 2, 3, 1}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Dilation with Conv3DBackpropInputV2 " + "(conv3d_transpose) is not supported, " + "at my_conv3d"); + } + { + // Asymmetric+ Conv3DBackpropInputV2, should fail. + Reset(); + NodeDef node_def = get_conv3d_nodedef({1, 1, 1, 1, 1}, "SAME", "NDHWC", + {1, 1, 1, 1, 1}, true); + AddTestTensor("input", {1, 2, 2, 2}); + AddTestWeights("weights", {1, 1, 2, 1, 1}, {1, 1}); + AddTestWeights("input_sizes", {8}, {1, 2, 3, 4, 5, 6, 7, 8}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "Asymmetric padding with Conv3DBackpropInputV2 " + "(conv3d_transpose) is not supported, at " + "my_conv3d"); + } + { + // Strides is not 5D, should fail. + Reset(); + NodeDef node_def = get_conv3d_nodedef({1, 1, 1, 1, 1, 1}, "SAME", "NCDHW", + {1, 1, 1, 1, 1}); + AddTestTensor("input", {1, 2, 2, 2}); + AddTestWeights("weights", {1, 1, 2, 1, 1}, {1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "Convolution strides field must specify 5 dimensions, at my_conv3d"); + } + { + // Stride value is not 1 for channel, should fail. + Reset(); + NodeDef node_def = + get_conv3d_nodedef({1, 2, 1, 1, 1}, "SAME", "NCDHW", {1, 1, 1, 1, 1}); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("weights", {3, 3, 1, 1, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "Stride must be 1 for batch and channel dimensions, at my_conv3d"); + } + struct TestParams { + std::vector input_dims; + std::vector input; + std::vector filter_dims; + std::vector filter; + std::vector strides; + string padding; + string data_format; + std::vector dilations; + bool is_conv3d_backprop_input; + std::vector expected_output_dims; + std::vector expected_output; + }; + + // Start here + const int kConv3DOKCases = 8; + TestParams ok_params[kConv3DOKCases] = { + // Basic - just 1x1 conv - input = output + TestParams{ + /*input_dims=*/{1, 3, 3, 3}, // CDHW + /*input=*/{1, 2, 15, 3, 6, -3, 22, 1, 88, 56, 36, 1, 1, 105, + 1, 16, -28, 1, 42, 9, 3, 1, 7, 1, 11, 61, 5}, + /*filter_dims=*/{1, 1, 1, 1, 1}, // DRSCK + /*filter=*/{1}, + /*strides=*/{1, 1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCDHW", + /*dilations=*/{1, 1, 1, 1, 1}, + /*is_conv3d_backprop_input=*/false, + /*expected_output_dims=*/{1, 3, 3, 3}, + /*expected_output=*/{1, 2, 15, 3, 6, -3, 22, 1, 88, + 56, 36, 1, 1, 105, 1, 16, -28, 1, + 42, 9, 3, 1, 7, 1, 11, 61, 5}}, + // Basic - 2x1 filter + TestParams{/*input_dims=*/{1, 3, 3, 3}, // CDHW + /*input=*/{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6}, + /*filter_dims=*/{2, 1, 1, 1, 1}, // DRSCK + /*filter=*/{1, 1}, + /*strides=*/{1, 1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCDHW", + /*dilations=*/{1, 1, 1, 1, 1}, + /*is_conv3d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 3, 3}, + /*expected_output=*/ + {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 7}}, + // SAME padding (Asymmetric) + TestParams{ + /*input_dims=*/{1, 2, 3, 2}, // CDHW + /*input=*/{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + /*filter_dims=*/{2, 1, 1, 1, 1}, // DRSCK + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCDHW", + /*dilations=*/{1, 1, 
1, 1, 1}, + /*is_conv3d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 3, 2}, + /*expected_output=*/ + {6, 6, 6, 6, 6, 6, -6, -7, -8, -9, -10, + -11} // Diff in first 2 depths is const 6 + }, + // SAME padding (Symmetric) + TestParams{ + /*input_dims=*/{1, 2, 3, 2}, // CDHW + /*input=*/{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + /*filter_dims=*/{3, 1, 1, 1, 1}, // DRSCK + /*filter=*/{-1, 0, 1}, + /*strides=*/{1, 1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCDHW", + /*dilations=*/{1, 1, 1, 1, 1}, + /*is_conv3d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 3, 2}, + /*expected_output=*/ + {6, 7, 8, 9, 10, 11, 0, -1, -2, -3, -4, + -5} // Swaps front two depths, negates + }, + + // NDHWC (multi-channel) + TestParams{ + /*input_dims=*/{2, 3, 2, 2}, // DHWC + /*input=*/{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + /*filter_dims=*/{2, 1, 1, 2, 1}, // DRSCK + /*filter=*/{-1, 1, 1, -1}, + /*strides=*/{1, 1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NDHWC", + /*dilations=*/{1, 1, 1, 1, 1}, + /*is_conv3d_backprop_input=*/false, + /*expected_output_dims=*/{1, 3, 2, 1}, + /*expected_output=*/{0, 0, 0, 0, 0, 0} // Each filter opposes the + // other + }, + + // Dilated + TestParams{ + /*input_dims=*/{1, 3, 3, 3}, // CDHW + /*input=*/{1, 1, 1, 1, 1, 1, 1, 1, 1, -10, -10, -10, -10, -10, + -10, -10, -10, -10, 7, 7, 7, 7, 7, 7, 7, 7, 7}, + /*filter_dims=*/{2, 1, 1, 1, 1}, // DRSCK + /*filter=*/{1, 1}, + /*strides=*/{1, 1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCDHW", + /*dilations=*/{1, 1, 2, 1, 1}, + /*is_conv3d_backprop_input=*/false, + /*expected_output_dims=*/{1, 1, 3, 3}, + /*expected_output=*/{8, 8, 8, 8, 8, 8, 8, 8, 8} // Only front depth + // is valid, skips + // neg values + }, + // Strided + TestParams{ + /*input_dims=*/{1, 3, 3, 3}, + /*input=*/{1, 0, 2, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 5, 0, 6, 0, 0, 0, 7, 0, 8}, + /*filter_dims=*/{1, 1, 1, 1, 1}, + /*filter=*/{1}, + /*strides=*/{1, 1, 2, 2, 2}, + /*padding=*/"VALID", + /*data_format=*/"NCDHW", + /*dilations=*/{1, 1, 1, 1, 1}, + /*is_conv3d_backprop_input=*/false, + /*expected_output_dims=*/{1, 2, 2, 2}, + /*expected_output=*/{1, 2, 3, 4, 5, 6, 7, 8} // Should only pick up + // the corners + }, + // Transpose Strided + TestParams{/*input_dims=*/{1, 2, 2, 2}, // CDHW + /*input=*/{1, 2, 3, 4, 5, 6, 7, 8}, + /*filter_dims=*/{1, 1, 1, 1, 1}, + /*filter=*/{1}, + /*strides=*/{1, 1, 2, 2, 2}, + /*padding=*/"VALID", + /*data_format=*/"NCDHW", + /*dilations=*/{1, 1, 1, 1, 1}, + /*is_conv3d_backprop_input=*/true, + /*expected_output_dims=*/{1, 3, 3, 3}, + /*expected_output=*/ + {1, 0, 2, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 5, 0, 6, 0, 0, 0, 7, 0, 8}}, // Cube + // expands and + // fills + // center with + // zeroes + + }; + + for (int i = 0; i < kConv3DOKCases; i++) { + Reset(); + NodeDef node_def = get_conv3d_nodedef( + ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format, + ok_params[i].dilations, ok_params[i].is_conv3d_backprop_input); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("weights", ok_params[i].filter_dims, + ok_params[i].filter); + if (ok_params[i].is_conv3d_backprop_input) { + AddTestWeights( + "input_sizes", + {static_cast(ok_params[i].expected_output.size())}, + ok_params[i].expected_output); + } + RunValidationAndConversion(node_def); + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_conv3d", &output)); + ASSERT_TRUE(output.is_tensor()); + 
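The "Dilated" case above can be checked by hand: a 2x1x1x1x1 filter of {1, 1} with depth dilation 2 taps depths 0 and 2 and skips the -10 plane at depth 1, so every output element is 1 + 7 = 8. A standalone sketch of that arithmetic, assuming the CDHW layout used by the test (C = 1, 3x3 planes):

// Standalone check of the "Dilated" Conv3D expectation: with filter {1, 1}
// and depth dilation 2, output(h, w) = input(d=0, h, w) + input(d=2, h, w).
#include <iostream>
#include <vector>

int main() {
  const std::vector<float> input = {
      1,   1,   1,   1,   1,   1,   1,   1,   1,    // depth 0 (3x3 plane)
      -10, -10, -10, -10, -10, -10, -10, -10, -10,  // depth 1, skipped by dilation
      7,   7,   7,   7,   7,   7,   7,   7,   7};   // depth 2
  const int plane = 9;  // spatial positions per depth
  for (int i = 0; i < plane; ++i) {
    std::cout << input[0 * plane + i] + input[2 * plane + i] << " ";
  }
  std::cout << "\n";  // prints 8 8 8 8 8 8 8 8 8, matching expected_output
  return 0;
}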
ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, + output.tensor()->getDimensions()); + + const DataVec input_data{ + {"input", test::AsTensor(ok_params[i].input)}}; + DataVec output_data{ + {"my_conv3d", + ConstructTensor(ok_params[i].expected_output.size())}}; + BuildAndRun(input_data, &output_data); + EXPECT_THAT(GetSpanForData(output_data[0]), + ElementsAreArray(ok_params[i].expected_output)); + } +} +#endif // IS_TRT_VERSION_GE(6, 0, 0, 0) + TEST_F(OpConverterTest, ConvertTopK) { // TODO(tmorris): This test isn't setting the input dtype properly. TopK with // int32 is unsupported by TRT. @@ -5847,6 +6188,111 @@ TEST_F(OpConverterTest, ConvertResize) { } #endif // IS_TRT_VERSION_GE(6, 0, 0, 0) +NodeDef MakePadNodeDef(std::string name, DataType dtype) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), dtype); + auto padding = ops::Placeholder(s.WithOpName("padding"), DT_INT32); + auto pad = ops::Pad(s.WithOpName(name), input, padding); + return pad.operation.node()->def(); +} + +template +struct PadTestParams { + std::vector input_dims; + std::vector pad_dims; + std::vector input_values; + std::vector expected_output_dims; + std::vector expected_output_values; +}; + +template +void TestConvertPad(OpConverterTest* test) { + typedef typename EnumToDataType::Type CType; + + std::vector> params{ + { + /*input_dims=*/{1, 2, 1}, // H, W, C + /*pad_dims=*/{4, 2}, // #dims, {pad_before, pad_after} + /*input_values=*/CastTestVector({2.0f, -1.0f}), + /*expected_output_dims=*/{2, 3, 1}, // H, W, C + /*expected_output_values=*/ + CastTestVector({0.0, 0.0, 0.0, 2.0f, -1.0f, 0.0}), + }, + }; + + for (int i = 0; i < params.size(); ++i) { + test->Reset(); + // Create pad node. + NodeDef node_def = MakePadNodeDef("my_pad", dtype); + // Create input tensor + test->AddTestTensor("input", params[i].input_dims, /*batch_size=*/1, + /*trt_dtype=*/TfDataTypeToTrt(dtype)); + // Create output size. + test->AddTestWeights("padding", params[i].pad_dims, + {0, 0, 1, 0, 0, 1, 0, 0}); + test->RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(test->GetTensorOrWeights("padding", &output)); + + // Create input data for tensors. + const DataVec input_data{ + {"input", test::AsTensor(params[i].input_values)}}; + DataVec output_data{ + {"my_pad", + ConstructTensor(params[i].expected_output_values.size())}}; + + test->BuildAndRun( + input_data, &output_data, + dtype == DT_HALF ? TrtPrecisionMode::FP16 : TrtPrecisionMode::FP32); + ExpectArrayAlmostEqual(params[i].expected_output_values, + GetSpanForData(output_data[0]), CType(1e-5)); + } +} + +TEST_F(OpConverterTest, ConvertPad) { + { + // First input is weight, should fail. + Reset(); + NodeDef node_def = MakePadNodeDef("my_pad", DT_FLOAT); + AddTestWeights("input", {1, 2}, {1, 2}); + AddTestWeights("padding", {1, 2}, {1, 2}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"tensor\" for Pad must be a " + "tensor"); + } + { + // padding is a tensor, should fail. + Reset(); + NodeDef node_def = MakePadNodeDef("my_pad", DT_FLOAT); + AddTestTensor("input", {1, 2}); + AddTestTensor("padding", {1, 2}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + "The input \"paddings\" for Pad must be a " + "constant"); + } + TestConvertPad(this); + TestConvertPad(this); + { + // Make sure that ranges are inferred across a Pad. 
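For reference, the numbers in the TestConvertPad parameters above follow from the paddings weight {0, 0, 1, 0, 0, 1, 0, 0} in NHWC order: one zero row is added before H and one zero column after W, turning the 1x2x1 input [2, -1] into the 2x3x1 expected output. A standalone sketch of that zero padding (the quantization-range check for Pad continues below):

// Standalone sketch of the Pad expectation: pad one row before H and one
// column after W of a 1x2 (HxW) input, filling with zeros.
#include <iostream>
#include <vector>

int main() {
  const std::vector<float> input = {2.0f, -1.0f};  // H=1, W=2, C=1
  const int H = 1, W = 2, pad_top = 1, pad_left = 0;
  const int out_h = H + pad_top + /*pad_bottom=*/0;
  const int out_w = W + pad_left + /*pad_right=*/1;
  std::vector<float> out(out_h * out_w, 0.0f);
  for (int h = 0; h < H; ++h) {
    for (int w = 0; w < W; ++w) {
      out[(h + pad_top) * out_w + (w + pad_left)] = input[h * W + w];
    }
  }
  for (float v : out) std::cout << v << " ";
  std::cout << "\n";  // prints 0 0 0 2 -1 0, i.e. the 2x3x1 expected output
  return 0;
}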
+ Reset(); + NodeDef node_def = MakePadNodeDef("my_pad", DT_FLOAT); + AddTestTensor("input", {1, 2, 1}); + AddTestWeights("padding", {4, 2}, {0, 0, 1, 0, 0, 1, 0, 0}); + TRT_TensorOrWeights input; + TRT_TensorOrWeights output; + RunValidationAndConversion(node_def); + TF_EXPECT_OK(GetTensorOrWeights("input", &input)); + TF_EXPECT_OK(GetTensorOrWeights("my_pad", &output)); + converter_->ProvideQuantizationRange(input.tensor(), -5.0f, 5.0f); + // Input range should be inferred across pad. + PropagateQuantizationRanges(); + auto ranges = quantization_ranges(); + EXPECT_EQ(5.0f, ranges[input.tensor()]); + EXPECT_EQ(5.0f, ranges[output.tensor()]); + } +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index 6af483d37cf..35a8c6340f8 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -67,9 +67,6 @@ Status TRTOptimizationPass::Init( if (params.count("use_calibration")) { use_calibration_ = params.at("use_calibration").b(); } - if (params.count("use_function_backup")) { - use_function_backup_ = params.at("use_function_backup").b(); - } return Status::OK(); } @@ -193,31 +190,30 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, LOG(INFO) << CurrentStackTrace(); PrintDebugInfo(cluster, item); } - int max_dim = -1; - if (!item.feed.empty()) { - for (const auto& f : item.feed) { - const auto& shape = f.second.shape(); - if (shape.dims() > 0) { - if (shape.dim_size(0) > max_dim) max_dim = shape.dim_size(0); + if (!is_dynamic_op_) { + int max_batch_dim = -1; + if (!item.feed.empty()) { + for (const auto& f : item.feed) { + const auto& shape = f.second.shape(); + if (shape.dims() > 0) { + if (shape.dim_size(0) > max_batch_dim) + max_batch_dim = shape.dim_size(0); + VLOG(2) << "Setting max_batch_dim to " << max_batch_dim + << " using batch dimension of " << f.first << " with shape " + << shape; + } } } - } - if (maximum_batch_size_ < 0) { // automatic batch size from input - if (max_dim > 0) { - maximum_batch_size_ = max_dim; - VLOG(1) << "Setting maximum batch size to " << max_dim; - } else { - maximum_batch_size_ = 128; - LOG(WARNING) << "Maximum batch size is not set" - " and can't be deduced from inputs setting it to" - << maximum_batch_size_ - << ". Suggest configuring it from configuration parameters"; - } - } else { - if (max_dim > maximum_batch_size_) { - LOG(WARNING) << "Configured batch size " << maximum_batch_size_ - << " is less than input batch size " << max_dim - << " adjusting maximum batch size to match input batch size"; + if (max_batch_dim > maximum_batch_size_) { + return errors::InvalidArgument( + "Specified max_batch_size=", maximum_batch_size_, + " is less than maximum batch dimension of inputs (", max_batch_dim, + "). ", "To continue, set max_batch_size to >= ", max_batch_dim); + } else if (max_batch_dim < maximum_batch_size_) { + LOG(INFO) << "Specified max_batch_size=" << maximum_batch_size_ + << " is larger than maximum batch dimension of inputs (" + << max_batch_dim << "). 
" + << "This can result in poor performance."; } } grappler::GraphProperties static_graph_properties(item); @@ -259,7 +255,6 @@ Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, cp.is_dyn_op = is_dynamic_op_; cp.max_cached_engines = max_cached_batches_; cp.use_calibration = use_calibration_; - cp.use_function_backup = use_function_backup_; auto status = ConvertAfterShapes(cp); VLOG(1) << "Returning from " << name_; return status; diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h index d3fd914b302..35a92341ee9 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h @@ -40,13 +40,14 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer { is_dynamic_op_(false), max_cached_batches_(1), max_workspace_size_bytes_(256LL << 20), - use_calibration_(true), - use_function_backup_(true) { + use_calibration_(true) { VLOG(1) << "Constructing " << name_; } string name() const override { return name_; }; + bool UsesFunctionLibrary() const override { return true; } + Status Init( const RewriterConfig_CustomGraphOptimizer* config = nullptr) override; @@ -71,8 +72,6 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer { int64_t max_workspace_size_bytes_; bool use_calibration_; - // Whether to allow TF function fallback path in TRTEngineOp. - bool use_function_backup_; }; } // namespace convert diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 91c8c660f85..eb60829d31d 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -23,6 +23,12 @@ limitations under the License. namespace tensorflow { namespace tensorrt { +class IONamePrefixes { + public: + static constexpr const char* const kInputPHName = "TensorRTInputPH_"; + static constexpr const char* const kOutputPHName = "TensorRTOutputPH_"; +}; + template struct TrtDestroyer { void operator()(T* t) { diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc index 2898602b879..3143b06817e 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -39,28 +39,25 @@ class GetCalibrationDataOp : public OpKernel { // TODO(laigd): it will allocate the tensor on the device and copy the // serialized string to that tensor, and later sess.run() will copy it back // to host. We need to optimize this. - const string& resource_name = context->input(0).scalar()(); + const string& resource_name = context->input(0).scalar()(); // Get the resource. 
- TRTCalibrationResource* resource = nullptr; + TRTEngineCacheResource* resource = nullptr; OP_REQUIRES_OK(context, context->resource_manager()->Lookup( - std::string(kCalibrationContainerName), - resource_name, &resource)); + std::string(kTfTrtContainerName), resource_name, + &resource)); core::ScopedUnref sc(resource); // Serialize the resource as output. - string serialized_resource; - OP_REQUIRES_OK(context, resource->SerializeToString(&serialized_resource)); + string serialized_resource = resource->calib_ctx_->TerminateCalibration(); + OP_REQUIRES(context, !serialized_resource.empty(), + errors::Unknown("Calibration table is empty.")); Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({}), &output)); - // Destroy the resource. - OP_REQUIRES_OK(context, - context->resource_manager()->Delete( - std::string(kCalibrationContainerName), resource_name)); - output->scalar()() = serialized_resource; + output->scalar()() = serialized_resource; } }; diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index ab0b21edc41..646a44f1405 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -17,18 +17,23 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" -#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -54,6 +59,7 @@ using ::stream_executor::port::StatusOr; // A helper class to call done() when destructed for asynchronous execution. // Helps simultaneous execution of native and TRT engines. + class AsyncHelper : public core::RefCounted { public: AsyncHelper(AsyncOpKernel::DoneCallback done) : done_(done) {} @@ -87,10 +93,15 @@ class TRTEngineOp : public AsyncOpKernel { VectorTensorShapeHasher>; // Execute calibration - void ExecuteCalibration(OpKernelContext* ctx, AsyncHelper* helper); + void ExecuteCalibration(OpKernelContext* ctx, + TRTEngineCacheResource* cache_res, + AsyncHelper* helper); // Construct a function handle for executing native funcdef graph - Status ConstructFunctionHandle(OpKernelContext* ctx); + // These are the exact same function. + + Status ConstructFunctionHandle(FunctionLibraryRuntime* lib, + const string& device_name); // Execute replaced native segment as function Op. 
void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper); @@ -101,15 +112,15 @@ class TRTEngineOp : public AsyncOpKernel { // Allocate necessary resources for calibration Status AllocateCalibrationResources(OpKernelContext* ctx, - TRTEngineCacheResource* cache_res, - TRTCalibrationResource** cr); + TRTEngineCacheResource* cache_res); Status GetEngineCacheResource(OpKernelContext* ctx, TRTEngineCacheResource** cache_res); // Get engine for the input shape StatusOr GetEngine( - const std::vector& input_shapes, OpKernelContext* ctx); + const std::vector& input_shapes, OpKernelContext* ctx, + TRTEngineCacheResource* cache_res); // Verify that the input shapes are consistent and can be handled by this op. Status VerifyInputShapes(const std::vector& shapes); @@ -127,10 +138,8 @@ class TRTEngineOp : public AsyncOpKernel { // serialized protobuf segment or trt engine depending on static_engine_ flag. string serialized_segment_; - // Name of the function for TF native execution of the segment. If empty, it - // means TF native execution is not allowed, and if TRT engine fails to run - // an error will be returned. - string funcdef_name_; + // The function for TF native execution of the segment. + NameAttrList func_; // GraphDef representation of the segment. GraphDef segment_graph_; @@ -150,7 +159,7 @@ class TRTEngineOp : public AsyncOpKernel { int64 workspace_size_; mutex engine_mutex_; - FunctionLibraryRuntime::Handle native_func_; + FunctionLibraryRuntime::Handle func_handle_; // The finalized calibrator for inference. std::unique_ptr calibrator_; @@ -179,23 +188,61 @@ void* GetTensorAddress(const Tensor* tensor_ptr) { } } -Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) { +static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle, + FunctionLibraryRuntime* flib_runtime, + GraphDef* graph_def) { + const FunctionLibraryDefinition* flib_def = + flib_runtime->GetFunctionLibraryDefinition(); + const FunctionBody* fbody; + fbody = flib_runtime->GetFunctionBody(handle); + if (!fbody) { + return errors::Internal( + "Function body is null when converting from FuncDef to GraphDef."); + } + std::unique_ptr graph(new Graph(flib_def)); + CopyGraph(*fbody->graph, graph.get()); + + auto replace_name = [](const char* const prefix, string* name) { + if (absl::StartsWith(*name, absl::AsciiStrToLower(prefix))) { + name->replace(0, strlen(prefix), prefix); + return true; + } + return false; + }; + graph->ToGraphDef(graph_def); + // GraphToFunctionDef() will convert all the node names to lowercase. + for (auto& node : *graph_def->mutable_node()) { + if (!replace_name(IONamePrefixes::kInputPHName, node.mutable_name())) { + if (replace_name(IONamePrefixes::kOutputPHName, node.mutable_name())) { + // Instantiation of the function will append _RetVal to the node name, + // need to remove it for backward compatibility. 
+ const char* const suffix_to_remove = "_RetVal"; + if (absl::EndsWith(node.name(), suffix_to_remove)) { + node.mutable_name()->erase(node.name().size() - + strlen(suffix_to_remove)); + } + } + } + for (auto& input : *node.mutable_input()) { + if (!replace_name(IONamePrefixes::kInputPHName, &input)) { + replace_name(IONamePrefixes::kOutputPHName, &input); + } + } + } + return Status::OK(); +} + +Status TRTEngineOp::ConstructFunctionHandle(FunctionLibraryRuntime* lib, + const string& device_name) { VLOG(1) << "Constructing function handle"; - auto lib = ctx->function_library(); if (lib == nullptr) { return errors::Internal("Context function library is null"); } - auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_); - if (fdef == nullptr) { - return errors::Internal("Native FunctionDef ", funcdef_name_, - " can't be found in function library"); - } FunctionLibraryRuntime::InstantiateOptions inst_ops; inst_ops.state_handle = ""; - inst_ops.target = ctx->device()->name(); - native_func_ = 0; - return lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()), inst_ops, - &native_func_); + inst_ops.target = device_name; + return lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_ops, + &func_handle_); } TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) @@ -206,15 +253,7 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) OP_REQUIRES_OK(context, context->GetAttr("workspace_size_bytes", &workspace_size_)); OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_)); - if (!static_engine_) { - OP_REQUIRES(context, segment_graph_.ParseFromString(serialized_segment_), - errors::InvalidArgument("Failed to parse segment graphdef!")); - VLOG(1) << "Size of serialized GraphDef: " - << serialized_segment_.capacity(); - string tmp; - // Swap with temporary empty string to deallocate the CPU memory. - serialized_segment_.swap(tmp); - } + VLOG(1) << "Constructing " << name(); string precision_string; OP_REQUIRES_OK(context, @@ -222,12 +261,25 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) string calibration_data; OP_REQUIRES_OK(context, context->GetAttr("calibration_data", &calibration_data)); - OP_REQUIRES_OK(context, - context->GetAttr("segment_funcdef_name", &funcdef_name_)); + OP_REQUIRES_OK(context, context->GetAttr("segment_func", &func_)); + OP_REQUIRES(context, !func_.name().empty(), + errors::InvalidArgument( + "The TF function for the TRT segment could not be empty")); OP_REQUIRES_OK(context, TrtPrecisionModeFromName(precision_string, &precision_mode_)); OP_REQUIRES_OK(context, context->GetAttr("use_calibration", &use_calibration_)); + func_handle_ = kInvalidHandle; + if (!static_engine_) { + FunctionLibraryRuntime* lib = context->function_library(); + OP_REQUIRES_OK(context, + ConstructFunctionHandle(lib, context->device()->name())); + OP_REQUIRES_OK(context, + FunctionDefToGraphDef(func_handle_, lib, &segment_graph_)); + } + // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for + // backward compatibility reasons. Remove it once all known users switch to + // 2.0. 
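FunctionDefToGraphDef above undoes two naming artifacts noted in its comments: GraphToFunctionDef lower-cases node names, and function instantiation appends "_RetVal" to return nodes, so a boundary node such as TensorRTOutputPH_0 comes back as tensorrtoutputph_0_RetVal. A standalone sketch of that fix-up; RestoreName is an illustrative helper, the prefix value comes from utils.h, and plain std::string calls stand in for the absl helpers:

// Standalone sketch of the name restoration: re-apply the original prefix
// casing and strip the "_RetVal" suffix added by instantiation.
#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>

constexpr const char* kOutputPHName = "TensorRTOutputPH_";  // IONamePrefixes value

std::string RestoreName(std::string name) {
  std::string lower(kOutputPHName);
  std::transform(lower.begin(), lower.end(), lower.begin(),
                 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
  if (name.compare(0, lower.size(), lower) == 0) {
    name.replace(0, lower.size(), kOutputPHName);  // restore original casing
    const std::string suffix = "_RetVal";
    if (name.size() >= suffix.size() &&
        name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0) {
      name.erase(name.size() - suffix.size());  // drop the instantiation suffix
    }
  }
  return name;
}

int main() {
  std::cout << RestoreName("tensorrtoutputph_0_RetVal") << "\n";
  // prints: TensorRTOutputPH_0
  return 0;
}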
calibration_mode_ = (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 && calibration_data.empty()); @@ -235,20 +287,19 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) calibrator_.reset(new TRTInt8Calibrator(calibration_data)); calibration_data.resize(0); } - native_func_ = kInvalidHandle; OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count", &max_cached_engines_)); } void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* helper) { - OP_REQUIRES_ASYNC(ctx, !funcdef_name_.empty(), - errors::Internal("Fallback path is disabled, for ", name()), - *helper); std::vector inputs; std::vector* outputs = new std::vector(); - if (native_func_ == kInvalidHandle) { - OP_REQUIRES_OK_ASYNC(ctx, ConstructFunctionHandle(ctx), *helper); + if (func_handle_ == kInvalidHandle) { + OP_REQUIRES_OK_ASYNC( + ctx, + ConstructFunctionHandle(ctx->function_library(), ctx->device()->name()), + *helper); } auto lib = ctx->function_library(); FunctionLibraryRuntime::Options opts; @@ -261,7 +312,7 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, } helper->Ref(); // Increment count for calculating native graph VLOG(1) << "Executing native segment: " << name(); - lib->Run(opts, native_func_, inputs, outputs, + lib->Run(opts, func_handle_, inputs, outputs, [this, ctx, outputs, helper](const Status& s) { core::ScopedUnref sc(helper); OP_REQUIRES_OK_ASYNC(ctx, s, *helper); @@ -274,27 +325,14 @@ void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx, } void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, + TRTEngineCacheResource* cache_res, AsyncHelper* helper) { VLOG(1) << "Executing TRT calibration: " << name(); helper->Ref(); core::ScopedUnref sc(helper); - // Get the cache resource outside the LookupOrCreate() below to avoid - // deadlock. - TRTEngineCacheResource* cache_res = nullptr; - OP_REQUIRES_OK_ASYNC(ctx, GetEngineCacheResource(ctx, &cache_res), *helper); - core::ScopedUnref unref_cache_res(cache_res); - TRTCalibrationResource* calib_res = nullptr; - OP_REQUIRES_OK_ASYNC( - ctx, - ctx->resource_manager()->LookupOrCreate( - std::string(kCalibrationContainerName), name(), - reinterpret_cast(&calib_res), - {[ctx, cache_res, this](TRTCalibrationResource** cr) -> Status { - return this->AllocateCalibrationResources(ctx, cache_res, cr); - }}), - *helper); - core::ScopedUnref calib_sc(calib_res); - int num_inputs = ctx->num_inputs(); + + CalibrationContext* calib_ctx = cache_res->calib_ctx_.get(); + const int num_inputs = ctx->num_inputs(); // TODO(laigd): need to check that input shape matches. // Pass input data to calibrator std::unordered_map input_data; @@ -307,9 +345,9 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, *helper); // Check the allocated buffer is sufficient for input const auto device_tensor = - calib_res->device_tensors_.at(i).AccessTensor(ctx); + calib_ctx->device_tensors_.at(i).AccessTensor(ctx); CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); - input_data.emplace(StrCat(kInputPHName, i), data_address); + input_data.emplace(StrCat(IONamePrefixes::kInputPHName, i), data_address); } VLOG(2) << "Filled map for sending"; // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files @@ -326,7 +364,7 @@ void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx, // until setDone() is called later by the calibration thread in // AllocateCalibrationResources(). In that case, this setBatch() will always // be able to detect the error and return false. 
- OP_REQUIRES_ASYNC(ctx, calib_res->calibrator_->setBatch(input_data, *stream), + OP_REQUIRES_ASYNC(ctx, calib_ctx->calibrator_->setBatch(input_data, *stream), errors::Internal("Failed to feed calibration data"), *helper); VLOG(2) << "Passed calibration data"; @@ -354,9 +392,8 @@ Status TRTEngineOp::VerifyInputShapes(const std::vector& shapes) { return Status::OK(); } -Status TRTEngineOp::GetEngineInputShapes( - const CacheType& cache, const std::vector& actual_input_shapes, - std::vector* engine_input_shapes) { +bool AreShapesCompatible(const std::vector& actual_shapes, + const std::vector& cached_shapes) { auto match_shape = [](const TensorShape& actual_shape, const TensorShape& cached_shape) { // Match the rank. @@ -369,16 +406,17 @@ Status TRTEngineOp::GetEngineInputShapes( } return true; }; - auto match_shapes = [&](const std::vector& actual_shapes, - const std::vector& cached_shapes) { - for (int i = 0; i < actual_shapes.size(); ++i) { - if (!match_shape(actual_shapes[i], cached_shapes[i])) { - return false; - } + for (int i = 0; i < actual_shapes.size(); ++i) { + if (!match_shape(actual_shapes[i], cached_shapes[i])) { + return false; } - return true; - }; + } + return true; +} +Status TRTEngineOp::GetEngineInputShapes( + const CacheType& cache, const std::vector& actual_input_shapes, + std::vector* engine_input_shapes) { // VerifyInputShapes() already ensured that all input shapes have same // batch size, and are not scalars. *engine_input_shapes = actual_input_shapes; @@ -392,7 +430,7 @@ Status TRTEngineOp::GetEngineInputShapes( ", cached size: ", cached_input_shapes.size(), " vs. actual size: ", actual_input_shapes.size()); } - if (match_shapes(actual_input_shapes, cached_input_shapes)) { + if (AreShapesCompatible(actual_input_shapes, cached_input_shapes)) { const int cached_batch_size = cached_input_shapes[0].dim_size(0); if (min_matched_batch_size > cached_batch_size) { min_matched_batch_size = cached_batch_size; @@ -407,10 +445,44 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, AsyncOpKernel::DoneCallback done) { auto helper = new AsyncHelper(done); core::ScopedUnref sc(helper); - if (calibration_mode_) { - ExecuteCalibration(ctx, helper); + + // Get TRT resource. + TRTEngineCacheResource* cache_res = nullptr; + OP_REQUIRES_OK_ASYNC(ctx, GetEngineCacheResource(ctx, &cache_res), *helper); + core::ScopedUnref unref_cache_res(cache_res); + + // Run calibration if in int8+calibration mode. + // * Logic in TF 1.x: + // - During conversion: calibration_mode_ is true and cache size is 0, so it + // will run calibration. + // - During inference: calibration_data will be set, so calibration_mode_ is + // false and it won't trigger calibration. + // * Logic in TF 2.0: + // - During conversion: similar to 1.x. + // - During inference: calibration_data will still be empty, but cache will + // contain the the calibrated engine, so it won't trigger calibration. + // + // TODO(laigd): consider the following alternatives: + // 1. Serialize the state (calibration or inference) using + // TRTEngineInstance proto (or a new proto), so we know which mode we're + // in and don't run calibration during inference (which is invalid). + // 2. Reuse the calibration_data attribute or use a new attribute in the + // NodeDef to indicate whether it's in calibration mode. + if (calibration_mode_ && cache_res->cache_.size() == 0) { + if (!cache_res->calib_ctx_) { + // TODO(laigd): better encapsulation. 
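AreShapesCompatible above encodes the rule hinted at by the surrounding comments: an actual input shape can be served by a cached engine shape when the ranks agree, the batch (first) dimension of the input is no larger than the engine's, and every other dimension matches exactly. A standalone sketch of that rule; MatchShape is an illustrative stand-in and plain vectors replace TensorShape:

// Standalone sketch of the per-shape compatibility check used when looking
// up a cached engine: same rank, input batch <= engine batch, other dims equal.
#include <iostream>
#include <vector>

using Shape = std::vector<long long>;

bool MatchShape(const Shape& actual, const Shape& cached) {
  if (actual.size() != cached.size()) return false;  // rank must match
  if (actual[0] > cached[0]) return false;           // batch dim may only shrink
  for (size_t i = 1; i < actual.size(); ++i) {
    if (actual[i] != cached[i]) return false;        // non-batch dims must be exact
  }
  return true;
}

int main() {
  std::cout << std::boolalpha;
  std::cout << MatchShape({5, 224, 224, 3}, {8, 224, 224, 3}) << "\n";  // true
  std::cout << MatchShape({9, 224, 224, 3}, {8, 224, 224, 3}) << "\n";  // false
  std::cout << MatchShape({8, 224, 112, 3}, {8, 224, 224, 3}) << "\n";  // false
  return 0;
}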
+ mutex_lock lock(engine_mutex_); + if (!cache_res->calib_ctx_) { + OP_REQUIRES_OK_ASYNC(ctx, AllocateCalibrationResources(ctx, cache_res), + *helper); + } + } + // TODO(laigd): check that the input shapes match the shapes of the + // persistent tensor in the calibration resource. + ExecuteCalibration(ctx, cache_res, helper); return; } + // Get shapes of inputs to engine. std::vector input_shapes; input_shapes.reserve(ctx->num_inputs()); @@ -418,8 +490,9 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, input_shapes.push_back(ctx->input(i).shape()); } OP_REQUIRES_OK_ASYNC(ctx, VerifyInputShapes(input_shapes), *helper); - StatusOr status = GetEngine(input_shapes, ctx); + StatusOr status = GetEngine(input_shapes, ctx, cache_res); OP_REQUIRES_OK_ASYNC(ctx, status.status(), *helper); + EngineContext* engine_context = status.ValueOrDie(); if (!engine_context->cuda_engine) { VLOG(1) << "Engine retrieval for input shapes: " @@ -446,9 +519,11 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, // input. const int num_batch = ctx->input(0).shape().dim_size(0); const int num_binding = ctx->num_inputs() + ctx->num_outputs(); + std::vector buffers(num_binding); + for (int i = 0; i < ctx->num_inputs(); i++) { - const string input_name = StrCat(kInputPHName, i); + const string input_name = StrCat(IONamePrefixes::kInputPHName, i); const int binding_index = cuda_engine->getBindingIndex(input_name.c_str()); if (binding_index == -1) { const string msg = @@ -490,7 +565,7 @@ bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx, for (int i = 0; i < ctx->num_outputs(); i++) { // Create an output tensor - const string output_name = StrCat(kOutputPHName, i); + const string output_name = StrCat(IONamePrefixes::kOutputPHName, i); const int binding_index = cuda_engine->getBindingIndex(output_name.c_str()); Tensor* output_tensor = nullptr; @@ -580,7 +655,7 @@ Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx, // Get engine cache. return ctx->resource_manager()->LookupOrCreate( - "TF-TRT-Engine-Cache", string(resource_name), cache_res, + std::string(kTfTrtContainerName), std::string(resource_name), cache_res, {[this, ctx](TRTEngineCacheResource** cr) -> Status { *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_); return Status::OK(); @@ -588,14 +663,13 @@ Status TRTEngineOp::GetEngineCacheResource(OpKernelContext* ctx, } StatusOr TRTEngineOp::GetEngine( - const std::vector& input_shapes, OpKernelContext* ctx) { + const std::vector& input_shapes, OpKernelContext* ctx, + TRTEngineCacheResource* cache_res) { static EngineContext empty_context; - TRTEngineCacheResource* cache_res = nullptr; - TF_RETURN_IF_ERROR(GetEngineCacheResource(ctx, &cache_res)); - core::ScopedUnref sc(cache_res); mutex_lock lock(engine_mutex_); - // TODO(tmorris): using first input to get batch size - is this reliable? + // Using first input to get batch size is reliable - VerifyInputShapes() has + // verified that. const int batch_size = input_shapes[0].dim_size(0); auto& cache = cache_res->cache_; auto allocator = cache_res->allocator_.get(); @@ -607,9 +681,7 @@ StatusOr TRTEngineOp::GetEngine( // single element containing the only engine. if (static_engine_) { if (cache.size()) { - // Batch size of engine must be >= the input batch size - // TODO(tmorris): use match compatible function? 
- if (cache.begin()->first[0].dim_size(0) >= batch_size) { + if (AreShapesCompatible(input_shapes, cache.begin()->first)) { return cache.begin()->second.get(); } return &empty_context; @@ -648,9 +720,7 @@ StatusOr TRTEngineOp::GetEngine( return cache.at(engine_input_shapes).get(); } // static_engine_ - // Handle the dynamic engine case. - // See if there is a compatible engine cached. The batch size should be <= the - // cached batch size. + // Handle the dynamic engine case. See if there is a compatible engine cached. std::vector engine_input_shapes; TF_RETURN_IF_ERROR( GetEngineInputShapes(cache, input_shapes, &engine_input_shapes)); @@ -694,17 +764,19 @@ StatusOr TRTEngineOp::GetEngine( return cache.at(engine_input_shapes).get(); } +// TODO(hinsu): Move this allocation to CalibrationContext constructor, if +// possible. Status TRTEngineOp::AllocateCalibrationResources( - OpKernelContext* ctx, TRTEngineCacheResource* cache_res, - TRTCalibrationResource** cr) { - auto cres = new TRTCalibrationResource(); - *cr = cres; + OpKernelContext* ctx, TRTEngineCacheResource* cache_res) { + cache_res->calib_ctx_ = absl::make_unique(); + auto* cres = cache_res->calib_ctx_.get(); + // Get the input shapes. const int batch_size = ctx->input(0).dim_size(0); const int num_inputs = ctx->num_inputs(); std::vector shapes; cres->device_tensors_.resize(num_inputs); - VLOG(1) << " Constructing calibrator"; + VLOG(1) << "Constructing calibrator"; for (int i = 0; i < num_inputs; i++) { // allocate workspace on device for inputs const Tensor& t = ctx->input(i); @@ -719,7 +791,7 @@ Status TRTEngineOp::AllocateCalibrationResources( "Unsupported data type encountered in input ", i); } cres->device_buffers_.emplace( - StrCat(kInputPHName, i), + StrCat(IONamePrefixes::kInputPHName, i), std::pair(device_address, device_tensor->TotalBytes())); } cres->calibrator_.reset( @@ -733,55 +805,52 @@ Status TRTEngineOp::AllocateCalibrationResources( } cache_res->Ref(); - cres->thr_.reset( - new std::thread([this, cres, shapes, platform_gpu_id, cache_res]() { - core::ScopedUnref sc(cache_res); + cres->thr_.reset(new std::thread([this, cres, shapes, platform_gpu_id, + cache_res]() { + core::ScopedUnref sc(cache_res); - LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id - << ", Calibration Resource @ " << cres; - auto err = cudaSetDevice(platform_gpu_id); - if (err != cudaSuccess) { - // TODO(aaroey): should return error here. - LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id - << " in calibration thread"; - } - std::vector partial_shapes(shapes.begin(), - shapes.end()); - // ConvertGraphDefToEngine() will try to build the engine. This thread - // will loop inside buildCudaEngine() consuming the calibration data - // that is set by the TF op, and drive the builder until calibrator - // returns false. Engine is discarded after calibration table is - // generated - // - // TODO(aaroey): maybe setting the max batch size using the python - // calibration wrapper class. 
- auto s = convert::ConvertGraphDefToEngine( - this->segment_graph_, TrtPrecisionMode::INT8, - cres->calibrator_->getBatchSize(), this->workspace_size_, - partial_shapes, &cres->logger_, cache_res->allocator_.get(), - cres->calibrator_.get(), &cres->engine_, - /*use_calibration=*/true, - /*convert_successfully=*/nullptr); - if (!s.ok()) { - LOG(ERROR) << "Calibration failed: " << s; - cres->calibrator_->setDone(); // Ignore further pushes - } + LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id + << ", Calibration Resource @ " << cres; + auto err = cudaSetDevice(platform_gpu_id); + if (err != cudaSuccess) { + // TODO(aaroey): should return error here. + LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id + << " in calibration thread"; + } + std::vector partial_shapes(shapes.begin(), + shapes.end()); + // ConvertGraphDefToEngine() will try to build the engine. This thread + // will loop inside buildCudaEngine() consuming the calibration data + // that is set by the TF op, and drive the builder until calibrator + // returns false. Engine is discarded after calibration table is + // generated + // + // TODO(aaroey): maybe setting the max batch size using the python + // calibration wrapper class. + auto s = convert::ConvertGraphDefToEngine( + this->segment_graph_, TrtPrecisionMode::INT8, + cres->calibrator_->getBatchSize(), this->workspace_size_, + partial_shapes, &cache_res->GetLogger(), cache_res->allocator_.get(), + cres->calibrator_.get(), &cres->engine_, + /*use_calibration=*/true, + /*convert_successfully=*/nullptr); + if (!s.ok()) { + LOG(ERROR) << "Calibration failed: " << s; + cres->calibrator_->setDone(); // Ignore further pushes + } else { + // Transfer the ownership of the engine to the engine cache, so we can + // dump it out during conversion for TF 2.0. + mutex_lock lock(this->engine_mutex_); + this->calibrator_ = std::move(cres->calibrator_); + TrtUniquePtrType exec_context( + cres->engine_->createExecutionContext()); + cache_res->cache_.emplace( + shapes, absl::make_unique(std::move(cres->engine_), + std::move(exec_context))); + } - // Transfer the ownership of the engine to the engine cache, so we can - // dump it out during conversion for TF 2.0. - if (cache_res) { - mutex_lock lock(this->engine_mutex_); - cres->SetCalibrationTable(); - this->calibrator_ = std::move(cres->calibrator_); - TrtUniquePtrType exec_context( - cres->engine_->createExecutionContext()); - cache_res->cache_.emplace( - shapes, absl::make_unique( - std::move(cres->engine_), std::move(exec_context))); - } - - VLOG(1) << "Calibration loop terminated " << this->name(); - })); + VLOG(1) << "Calibration loop terminated " << this->name(); + })); VLOG(1) << "initialized calibrator resource"; return Status::OK(); } diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc index d859d5f957f..4228136e0c8 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc @@ -22,11 +22,17 @@ limitations under the License. 
#include #include +#include "absl/strings/str_cat.h" +#include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" +#include "tensorflow/compiler/tf2tensorrt/convert/convert_graph.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h" #include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/ops_testutil.h" @@ -39,6 +45,7 @@ limitations under the License. namespace tensorflow { namespace tensorrt { +using ::absl::StrCat; using ::testing::ElementsAre; class TRTEngineOpTestBase : public OpsTestBase { @@ -50,25 +57,32 @@ class TRTEngineOpTestBase : public OpsTestBase { // Create simple TF graph. Scope s = Scope::NewRootScope(); - auto feed = ops::Placeholder(s.WithOpName("TensorRTInputPH_0"), dtype, - ops::Placeholder::Shape({-1, -1})); + auto feed = ops::_Arg(s.WithOpName("TensorRTInputPH_0"), dtype, 0); auto add = ops::Add(s.WithOpName("add"), feed, feed); - ops::Identity(s.WithOpName("TensorRTOutputPH_0"), add); + ops::_Retval(s.WithOpName("TensorRTOutputPH_0"), add, 0); // Serialize the graph. TRTEngineOp will convert it using dynamic mode. GraphDef graph_def; TF_ASSERT_OK(s.ToGraphDef(&graph_def)); + Graph* graph = s.graph(); + const char* op_name = "myop"; + TF_ASSERT_OK( + convert::RegisterGraphToFunctionLibrary(graph_def, graph, op_name)); + TF_ASSERT_OK(flib_def_->AddLibrary(graph->flib_def())); + PartialTensorShape shape({-1, -1}); // Create the op. OpsTestBase::SetDevice(DEVICE_GPU, std::move(device)); - TF_ASSERT_OK(NodeDefBuilder("myop", "TRTEngineOp") + NameAttrList function; + function.set_name(StrCat(op_name, "_native_segment")); + TF_ASSERT_OK(NodeDefBuilder(op_name, "TRTEngineOp") .Input(FakeInput(1, dtype)) .Attr("input_shapes", {shape}) .Attr("output_shapes", {shape}) .Attr("static_engine", false) - .Attr("segment_funcdef_name", "") // no native fallback - .Attr("serialized_segment", graph_def.SerializeAsString()) + .Attr("segment_func", function) + .Attr("serialized_segment", "") .Attr("calibration_data", "") .Attr("max_cached_engines_count", max_cached_engines_count) .Attr("workspace_size_bytes", 1 << 20) @@ -76,7 +90,7 @@ class TRTEngineOpTestBase : public OpsTestBase { .Attr("use_calibration", false) .Attr("OutT", {dtype}) .Finalize(OpsTestBase::node_def())); - TF_ASSERT_OK(OpsTestBase::InitOp()); + TF_ASSERT_OK(InitOpWithFunctionLibrary()); } template @@ -90,9 +104,20 @@ class TRTEngineOpTestBase : public OpsTestBase { inputs_.clear(); gtl::STLDeleteElements(&tensors_); } + + private: + Status InitOpWithFunctionLibrary() { + OpKernel* kernel = nullptr; + Status status = CreateOpKernel(device_type_, device_, allocator(), + pflr_->GetFLR(device_->name()), node_def_, + TF_GRAPH_DEF_VERSION, &kernel); + kernel_ = std::unique_ptr(kernel); + if (kernel_ != nullptr) input_types_ = kernel_->input_types(); + return status; + } }; -TEST_F(TRTEngineOpTestBase, dynamic_shapes) { +TEST_F(TRTEngineOpTestBase, DynamicShapes) { TRTEngineOpTestBase::AddSimpleTrtOp(DT_FLOAT, /*max_cached_engines_count=*/4); // Execute the op with batch size > 1. 
@@ -101,8 +126,8 @@ TEST_F(TRTEngineOpTestBase, dynamic_shapes) { // Get the engine cache. TRTEngineCacheResource* cache_resource = nullptr; - TF_ASSERT_OK(device_->resource_manager()->Lookup("TF-TRT-Engine-Cache", - "myop", &cache_resource)); + TF_ASSERT_OK( + device_->resource_manager()->Lookup("TF-TRT", "myop", &cache_resource)); core::ScopedUnref sc(cache_resource); // It should contain only one engine. diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc index 8f6f08710d1..891b75be824 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/io/record_reader.h" #include "tensorflow/core/lib/io/record_writer.h" @@ -40,11 +41,9 @@ namespace tensorflow { namespace tensorrt { using ::nvinfer1::IRuntime; -class CreateTRTEngineCacheHandle : public OpKernel { +class CreateTRTResourceHandle : public OpKernel { public: - explicit CreateTRTEngineCacheHandle(OpKernelConstruction* ctx) - : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); + explicit CreateTRTResourceHandle(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("resource_name", &resource_name_)); } @@ -57,12 +56,11 @@ class CreateTRTEngineCacheHandle : public OpKernel { OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle_, attr)); - VLOG(1) << "Creating TRT engine cache resource handle for container " - << container_ << " and op " << resource_name_ << " on device " - << ctx->device()->name(); + VLOG(1) << "Creating TRT engine cache resource handle for op " + << resource_name_ << " on device " << ctx->device()->name(); handle_.scalar()() = - MakeResourceHandle(ctx, container_, - resource_name_); + MakeResourceHandle( + ctx, std::string(kTfTrtContainerName), resource_name_); initialized_ = true; } } @@ -70,23 +68,22 @@ class CreateTRTEngineCacheHandle : public OpKernel { } private: - string container_; string resource_name_; Tensor handle_; mutex mutex_; bool initialized_ GUARDED_BY(mutex_) = false; - TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTEngineCacheHandle); + TF_DISALLOW_COPY_AND_ASSIGN(CreateTRTResourceHandle); }; -REGISTER_KERNEL_BUILDER(Name("CreateTRTEngineCacheHandle") +REGISTER_KERNEL_BUILDER(Name("CreateTRTResourceHandle") .Device(DEVICE_GPU) - .HostMemory("engine_cache_handle"), - CreateTRTEngineCacheHandle); + .HostMemory("resource_handle"), + CreateTRTResourceHandle); -class PopulateTRTEngineCache : public OpKernel { +class InitializeTRTResource : public OpKernel { public: - explicit PopulateTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) { + explicit InitializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) { OP_REQUIRES_OK( ctx, ctx->GetAttr("max_cached_engines_count", &max_cached_engines_)); } @@ -112,7 +109,7 @@ class PopulateTRTEngineCache : public OpKernel { resource->cache_.size(), " entries.")); // Get the file name. 
- const string& filename = ctx->input(1).scalar()(); + const string& filename = ctx->input(1).scalar()(); OP_REQUIRES(ctx, !filename.empty(), errors::InvalidArgument("filename cannot be empty.")); @@ -124,7 +121,7 @@ class PopulateTRTEngineCache : public OpKernel { uint64 offset = 0; int num_loaded_engine = 0; do { - string record; + tstring record; Status status = reader->ReadRecord(&offset, &record); if (errors::IsOutOfRange(status)) break; @@ -150,48 +147,51 @@ class PopulateTRTEngineCache : public OpKernel { raw_engine->createExecutionContext()))); ++num_loaded_engine; } while (1); - VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines to container " - << handle.container() << " for op " << handle.name() - << " on device " << ctx->device()->name() << " from file " - << filename; + VLOG(1) << "Loaded " << num_loaded_engine << " TRT engines for op " + << handle.name() << " on device " << ctx->device()->name() + << " from file " << filename; } private: // Maximum number of cached engines int max_cached_engines_; - TF_DISALLOW_COPY_AND_ASSIGN(PopulateTRTEngineCache); + TF_DISALLOW_COPY_AND_ASSIGN(InitializeTRTResource); }; -REGISTER_KERNEL_BUILDER(Name("PopulateTRTEngineCache") +REGISTER_KERNEL_BUILDER(Name("InitializeTRTResource") .Device(DEVICE_GPU) - .HostMemory("engine_cache_handle"), - PopulateTRTEngineCache); + .HostMemory("resource_handle"), + InitializeTRTResource); -class DumpTRTEngineCache : public OpKernel { +class SerializeTRTResource : public OpKernel { public: - explicit DumpTRTEngineCache(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_cache_after_dump", - &delete_cache_after_dump_)); + explicit SerializeTRTResource(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("delete_resource", &delete_resource_)); } void Compute(OpKernelContext* ctx) override { - const string& container = ctx->input(0).scalar()(); - const string& resource_name = ctx->input(1).scalar()(); - const string& filename = ctx->input(2).scalar()(); + const string& resource_name = ctx->input(0).scalar()(); + const string& filename = ctx->input(1).scalar()(); OP_REQUIRES(ctx, !filename.empty(), errors::InvalidArgument("filename cannot be empty.")); + // Lookup engine cache resource. TRTEngineCacheResource* resource = nullptr; - OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup( - container, resource_name, &resource)); + OP_REQUIRES_OK( + ctx, ctx->resource_manager()->Lookup(std::string(kTfTrtContainerName), + resource_name, &resource)); core::ScopedUnref unref_me(resource); + // Terminate the calibration if any. + if (resource->calib_ctx_) resource->calib_ctx_->TerminateCalibration(); + // Serialize the engines and write them to file. std::unique_ptr file; OP_REQUIRES_OK(ctx, ctx->env()->NewWritableFile(filename, &file)); auto writer = absl::make_unique(file.get()); + int num_serialized_engines = 0; for (const auto& pair : resource->cache_) { // Ignore engines that failed to build. 
const std::unique_ptr& engine = pair.second; @@ -211,30 +211,29 @@ class DumpTRTEngineCache : public OpKernel { OP_REQUIRES_OK(ctx, writer->WriteRecord(engine_instance.SerializeAsString())); + ++num_serialized_engines; } - VLOG(1) << "Serialized " << resource->cache_.size() - << " TRT engines in container " << container << " for op " + VLOG(1) << "Serialized " << num_serialized_engines << " TRT engines for op " << resource_name << " on device " << ctx->device()->name() << " to file " << filename; - if (delete_cache_after_dump_) { - VLOG(1) << "Destroying TRT engine cache resource in container " - << container << " for op " << resource_name << " on device " - << ctx->device()->name(); + if (delete_resource_) { + VLOG(1) << "Destroying TRT engine cache resource for op " << resource_name + << " on device " << ctx->device()->name(); OP_REQUIRES_OK(ctx, ctx->resource_manager()->Delete( - container, resource_name)); + std::string(kTfTrtContainerName), resource_name)); } } private: - bool delete_cache_after_dump_ = false; + bool delete_resource_ = false; - TF_DISALLOW_COPY_AND_ASSIGN(DumpTRTEngineCache); + TF_DISALLOW_COPY_AND_ASSIGN(SerializeTRTResource); }; -REGISTER_KERNEL_BUILDER(Name("DumpTRTEngineCache").Device(DEVICE_GPU), - DumpTRTEngineCache); +REGISTER_KERNEL_BUILDER(Name("SerializeTRTResource").Device(DEVICE_GPU), + SerializeTRTResource); } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc index b3e541aab40..d27a67582d8 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc @@ -92,11 +92,10 @@ TEST_F(TRTEngineResourceOpsTest, Basic) { SetDevice(DEVICE_GPU, std::move(device)); // Create the resource handle. - const string container = "mycontainer"; + const string container(kTfTrtContainerName); const string resource_name = "myresource"; Reset(); - TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTEngineCacheHandle") - .Attr("container", container) + TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTResourceHandle") .Attr("resource_name", resource_name) .Finalize(node_def())); TF_ASSERT_OK(InitOp()); @@ -108,7 +107,7 @@ TEST_F(TRTEngineResourceOpsTest, Basic) { EXPECT_TRUE( errors::IsNotFound(rm->Lookup(container, resource_name, &resource))); - // Create the resouce using an empty file with PopulateTRTEngineCache. + // Create the resouce using an empty file with InitializeTRTResource. Reset(); Env* env = Env::Default(); const string filename = io::JoinPath(testing::TmpDir(), "trt_engine_file"); @@ -116,7 +115,7 @@ TEST_F(TRTEngineResourceOpsTest, Basic) { std::unique_ptr file; TF_ASSERT_OK(env->NewWritableFile(filename, &file)); } - TF_ASSERT_OK(NodeDefBuilder("op", "PopulateTRTEngineCache") + TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTResource") .Input(FakeInput(DT_RESOURCE)) .Input(FakeInput(DT_STRING)) .Attr("max_cached_engines_count", 1) @@ -137,18 +136,16 @@ TEST_F(TRTEngineResourceOpsTest, Basic) { absl::make_unique(std::move(engine), std::move(context))); resource->Unref(); - // Serialize the engine using DumpTRTEngineCache op. + // Serialize the engine using SerializeTRTResource op. 
Reset(); - TF_ASSERT_OK(NodeDefBuilder("op", "DumpTRTEngineCache") - .Attr("delete_cache_after_dump", true) - .Input(FakeInput(DT_STRING)) + TF_ASSERT_OK(NodeDefBuilder("op", "SerializeTRTResource") + .Attr("delete_resource", true) .Input(FakeInput(DT_STRING)) .Input(FakeInput(DT_STRING)) .Finalize(node_def())); TF_ASSERT_OK(InitOp()); - AddInputFromArray(TensorShape({}), {container}); - AddInputFromArray(TensorShape({}), {resource_name}); - AddInputFromArray(TensorShape({}), {filename}); + AddInputFromArray(TensorShape({}), {resource_name}); + AddInputFromArray(TensorShape({}), {filename}); TF_ASSERT_OK(RunOpKernel()); // Make sure the cache is deleted. @@ -178,14 +175,14 @@ TEST_F(TRTEngineResourceOpsTest, Basic) { // Recreate the cache resource. Reset(); - TF_ASSERT_OK(NodeDefBuilder("op", "PopulateTRTEngineCache") + TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTResource") .Input(FakeInput(DT_RESOURCE)) .Input(FakeInput(DT_STRING)) .Attr("max_cached_engines_count", 1) .Finalize(node_def())); TF_ASSERT_OK(InitOp()); AddInputFromArray(TensorShape({}), {handle}); - AddInputFromArray(TensorShape({}), {filename}); + AddInputFromArray(TensorShape({}), {filename}); TF_ASSERT_OK(RunOpKernel()); EXPECT_TRUE(rm->Lookup(container, resource_name, &resource).ok()); EXPECT_EQ(1, resource->cache_.size()); diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc index b8f9058d8f6..7d8ff6dbe43 100644 --- a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc @@ -33,7 +33,7 @@ namespace tensorflow { // key to cache the instantiated functions for different executor subgraphs. REGISTER_OP("TRTEngineOp") .Attr("serialized_segment: string") - .Attr("segment_funcdef_name: string") + .Attr("segment_func: func = {}") .Attr("InT: list({int8,float16,float32,int32})") .Attr("OutT: list({int8,float16,float32,int32})") .Attr("max_cached_engines_count: int = 1") @@ -51,10 +51,11 @@ REGISTER_OP("TRTEngineOp") // inference function as a workaround. .SetShapeFn(shape_inference::UnknownShape) // Deprecated attributes. + .Attr("segment_funcdef_name: string = ''") .Attr("cached_engine_batches: list(int) >= 0 = []") .Attr("fixed_input_size: bool = true") - .Attr("input_shapes: list(shape)") - .Attr("output_shapes: list(shape)") + .Attr("input_shapes: list(shape) = []") + .Attr("output_shapes: list(shape) = []") .Attr("static_engine: bool = true"); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc index 67177efe228..01911de66ec 100644 --- a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc @@ -24,23 +24,21 @@ limitations under the License. 
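// The resource ops registered below replace the old engine-cache ops, and the
// container argument is gone: handles always live in kTfTrtContainerName
// ("TF-TRT"). A condensed sketch of how they are driven, mirroring the test
// above (NodeDefBuilder, FakeInput and node_def() come from the OpsTestBase-
// style fixture assumed there; the Reset(), InitOp(), AddInputFromArray and
// RunOpKernel() calls between the steps are elided):
//
//   // 1) Create a scalar handle for the engine cache resource.
//   TF_ASSERT_OK(NodeDefBuilder("op", "CreateTRTResourceHandle")
//                    .Attr("resource_name", "myresource")
//                    .Finalize(node_def()));
//
//   // 2) Populate the cache from a file of serialized engines.
//   TF_ASSERT_OK(NodeDefBuilder("op", "InitializeTRTResource")
//                    .Input(FakeInput(DT_RESOURCE))  // resource_handle
//                    .Input(FakeInput(DT_STRING))    // filename
//                    .Attr("max_cached_engines_count", 1)
//                    .Finalize(node_def()));
//
//   // 3) Serialize the cache back to a file, optionally deleting it.
//   TF_ASSERT_OK(NodeDefBuilder("op", "SerializeTRTResource")
//                    .Attr("delete_resource", true)
//                    .Input(FakeInput(DT_STRING))    // resource_name
//                    .Input(FakeInput(DT_STRING))    // filename
//                    .Finalize(node_def()));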
namespace tensorflow { -REGISTER_OP("CreateTRTEngineCacheHandle") - .Attr("container: string") +REGISTER_OP("CreateTRTResourceHandle") .Attr("resource_name: string") - .Output("engine_cache_handle: resource") + .Output("resource_handle: resource") .SetIsStateful() .SetShapeFn(shape_inference::ScalarShape); -REGISTER_OP("PopulateTRTEngineCache") +REGISTER_OP("InitializeTRTResource") .Attr("max_cached_engines_count: int = 1") - .Input("engine_cache_handle: resource") + .Input("resource_handle: resource") .Input("filename: string") .SetIsStateful() .SetShapeFn(shape_inference::NoOutputs); -REGISTER_OP("DumpTRTEngineCache") - .Attr("delete_cache_after_dump: bool = false") - .Input("container: string") +REGISTER_OP("SerializeTRTResource") + .Attr("delete_resource: bool = false") .Input("resource_name: string") .Input("filename: string") .SetIsStateful() diff --git a/tensorflow/compiler/tf2tensorrt/utils/calibration_resource.cc b/tensorflow/compiler/tf2tensorrt/utils/calibration_resource.cc deleted file mode 100644 index 5d6e11b536e..00000000000 --- a/tensorflow/compiler/tf2tensorrt/utils/calibration_resource.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT - -namespace tensorflow { -namespace tensorrt { - -const absl::string_view kCalibrationContainerName = "TF-TRT-Calibration"; - -TRTCalibrationResource::~TRTCalibrationResource() { - VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); -} - -string TRTCalibrationResource::DebugString() const { - std::stringstream oss; - using std::dec; - using std::endl; - using std::hex; - oss << " Calibrator = " << hex << calibrator_.get() << dec << endl - << " Builder = " << hex << builder_.get() << dec << endl - << " Engine = " << hex << engine_.get() << dec << endl - << " Logger = " << hex << &logger_ << dec << endl - << " Thread = " << hex << thr_.get() << dec << endl; - return oss.str(); -} - -void TRTCalibrationResource::SetCalibrationTable() { - calibration_table_ = calibrator_->getCalibrationTableAsString(); -} - -Status TRTCalibrationResource::SerializeToString(string* serialized) { - calibrator_->waitAndSetDone(); - thr_->join(); - *serialized = calibration_table_; - if (serialized->empty()) { - return errors::Unknown("Calibration table is empty."); - } - return Status::OK(); -} - -} // namespace tensorrt -} // namespace tensorflow - -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA diff --git a/tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h b/tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h deleted file mode 100644 index e7c29e9f1ed..00000000000 --- a/tensorflow/compiler/tf2tensorrt/utils/calibration_resource.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ -#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" -#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" - -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT -#include "third_party/tensorrt/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { - -ABSL_CONST_INIT extern const absl::string_view kCalibrationContainerName; - -class TRTCalibrationResource : public ResourceBase { - public: - ~TRTCalibrationResource() override; - - string DebugString() const override; - - void SetCalibrationTable(); - - Status SerializeToString(string* serialized); - - // Lookup table for temporary staging areas of input tensors for calibration. - std::unordered_map> device_buffers_; - - // Temporary staging areas for calibration inputs. - std::vector device_tensors_; - - string calibration_table_; - std::unique_ptr calibrator_; - TrtUniquePtrType builder_; - TrtUniquePtrType engine_; - Logger logger_; - // TODO(sami): Use threadpool threads! - std::unique_ptr thr_; -}; - -} // namespace tensorrt -} // namespace tensorflow - -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA -#endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_RESOURCES_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc index 008cabb9cb4..885f58cd70c 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "tensorflow/stream_executor/platform/dso_loader.h" #include "third_party/tensorrt/NvInfer.h" #endif @@ -23,13 +24,16 @@ namespace tensorflow { namespace tensorrt { bool IsGoogleTensorRTEnabled() { - // TODO(laigd): consider also checking if tensorrt shared libraries are - // accessible. We can then direct users to this function to make sure they can - // safely write code that uses tensorrt conditionally. E.g. if it does not - // check for for tensorrt, and user mistakenly uses tensorrt, they will just - // crash and burn. #if GOOGLE_CUDA && GOOGLE_TENSORRT - return true; + auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries(); + if (!handle_or.ok()) { + LOG(WARNING) << "Cannot dlopen some TensorRT libraries. 
If you would like " + "to use Nvidia GPU with TensorRT, please make sure the " + "missing libraries mentioned above are installed properly."; + return false; + } else { + return true; + } #else return false; #endif diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc index 43dcd52b5a2..5ab6bf1a317 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc @@ -30,6 +30,28 @@ limitations under the License. namespace tensorflow { namespace tensorrt { +string CalibrationContext::TerminateCalibration() { + mutex_lock l(mu_); + if (terminated_) return calibration_table_; + + TRTInt8Calibrator* raw_calibrator = calibrator_.get(); + raw_calibrator->waitAndSetDone(); + terminated_ = true; + + // At this point the calibration thread `thr_` is woken up and can + // transfer the ownership of `calibrator_` and `engine_` at any time, so + // it's not safe to use `calibrator_` below, but we can still access it + // using raw pointer. + // TODO(laigd): make TRTEngineOp::AllocateCalibrationResources() a member + // function of this class instead. + + thr_->join(); + calibration_table_ = raw_calibrator->getCalibrationTableAsString(); + return calibration_table_; +} + +const absl::string_view kTfTrtContainerName = "TF-TRT"; + Logger& TRTEngineCacheResource::GetLogger() { static Logger* logger = new Logger(); return *logger; diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h index 442e0bcfb53..8d603ac4d55 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -17,10 +17,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LRU_CACHE_H_ #include +#include #include #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" +#include "tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/errors.h" @@ -137,6 +139,31 @@ struct EngineContext { GUARDED_BY(mu); }; +// Contains the context required to build the calibration data. +class CalibrationContext { + public: + string TerminateCalibration(); + + // Lookup table for temporary staging areas of input tensors for calibration. + std::unordered_map> device_buffers_; + + // Temporary staging areas for calibration inputs. + std::vector device_tensors_; + + std::unique_ptr calibrator_; + TrtUniquePtrType builder_; + TrtUniquePtrType engine_; + // TODO(sami): Use threadpool threads! + std::unique_ptr thr_; + + private: + mutex mu_; + bool terminated_ GUARDED_BY(mu_) = false; + std::string calibration_table_ GUARDED_BY(mu_); +}; + +ABSL_CONST_INIT extern const absl::string_view kTfTrtContainerName; + class TRTEngineCacheResource : public ResourceBase { public: // According to the TensorRT API, the logger is considered a singleton by the @@ -159,6 +186,10 @@ class TRTEngineCacheResource : public ResourceBase { LRUCache, std::unique_ptr, VectorTensorShapeHasher> cache_; + + // TODO(hinsu): Use different calibration context for the available shapes and + // attach it to each item of the cache. 
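// The calib_ctx_ member declared just below owns any in-flight calibration.
// A minimal usage sketch, assuming `resource` is a TRTEngineCacheResource*
// already looked up from the resource manager, as SerializeTRTResource does
// above:
//
//   if (resource->calib_ctx_) {
//     // Wakes and joins the calibration thread; returns the cached table if
//     // calibration was already terminated, so repeated calls are safe.
//     string calibration_table =
//         resource->calib_ctx_->TerminateCalibration();
//   }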
+ std::unique_ptr calib_ctx_; }; #endif // GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 6a28a5acb14..f6bf672d6a0 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,6 +1,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test") load( - "//tensorflow/core:platform/default/cuda_build_defs.bzl", + "//tensorflow/core/platform:default/cuda_build_defs.bzl", "if_cuda_is_configured", ) load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library", "xla_py_proto_library") @@ -29,6 +29,7 @@ package_group( packages = [ "//learning/brain/tools/tf_replay/...", "//tensorflow/...", + "//tensorflow_models/...", ], ) @@ -202,13 +203,15 @@ cc_library( visibility = [":friends"], deps = [ ":common", + ":frontend_attributes_util", ":host_compute_metadata_proto", + ":rearrange_function_argument", ":sharding_util", ":side_effect_util", ":tf2xla_util", "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/jit:xla_cluster_util", - "//tensorflow/compiler/tf2xla:rearrange_function_argument", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -269,6 +272,21 @@ cc_library( ], ) +cc_library( + name = "frontend_attributes_util", + srcs = ["frontend_attributes_util.cc"], + hdrs = ["frontend_attributes_util.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", + ], +) + cc_library( name = "sharding_util", srcs = ["sharding_util.cc"], @@ -577,6 +595,7 @@ cc_library( "functionalize_while.h", ], deps = [ + ":frontend_attributes_util", ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index ad2cc7b32f0..48513a43fb3 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -91,7 +91,7 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, FunctionLibraryRuntime* flib_runtime) { DCHECK(op_def != nullptr || op_kernel != nullptr); // TODO(b/124403063): Implement similar functionality for function call nodes. - if (node.op() == "While") { + if (node.op() == "While" || node.op() == "StatelessWhile") { // For While nodes, recurse into the body and cond graphs. const FunctionBody* fcond = nullptr; const FunctionBody* fbody = nullptr; diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.cc b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc new file mode 100644 index 00000000000..e0c70b81771 --- /dev/null +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.cc @@ -0,0 +1,41 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +const char kXlaFrontendAttributesAttrName[] = "_XlaFrontendAttributes"; + +xla::StatusOr> +GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs) { + const AttrValue* attr = attrs.Find(kXlaFrontendAttributesAttrName); + if (attr == nullptr) { + return xla::StatusOr>( + absl::nullopt); + } + xla::FrontendAttributes attributes; + if (!attributes.ParseFromString(attr->s())) { + return errors::InvalidArgument( + "Experimental _XlaFrontendAttributes attribute was not a valid encoded " + "xla::FrontendAttributes proto."); + } + return absl::optional(attributes); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/frontend_attributes_util.h b/tensorflow/compiler/tf2xla/frontend_attributes_util.h new file mode 100644 index 00000000000..421f21e71d1 --- /dev/null +++ b/tensorflow/compiler/tf2xla/frontend_attributes_util.h @@ -0,0 +1,38 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_ + +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/node_def_util.h" + +namespace tensorflow { + +// Frontend Attributes Id. +extern const char kXlaFrontendAttributesAttrName[]; +// Return the FrontendAttributes stored in the AttrSlice if there are some. +// +// Return an InvalidArgument error if some attributes are present but +// cannot be parsed. +xla::StatusOr> +GetFrontendAttributesFromAttrSlice(const AttrSlice& attrs); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FRONTEND_ATTRIBUTES_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc index 54cebc61778..793a56e865d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -48,6 +48,43 @@ xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { return AddNodeDefToGraph(ret_def, graph); } +Status ExtractWhileLoopFrames( + const std::vector& cf_info, const Graph* graph, + std::unordered_map* frames) { + for (Node* node : graph->op_nodes()) { + const ControlFlowInfo& cf = cf_info[node->id()]; + + VLOG(2) << "node: " << node->name() << " (" << node->id() + << ") frame_name: " << cf.frame_name + << " frame: " << (cf.frame ? 
cf.frame->name() : "---") + << " parent_frame: " + << (cf.parent_frame ? cf.parent_frame->name() : "---"); + TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr); + + WhileLoopFrame& frame = (*frames)[cf.frame_name]; + WhileLoopFrame* parent = + &(*frames)[cf_info[cf.parent_frame->id()].frame_name]; + if (frame.parent == nullptr) { + frame.parent = parent; + frame.name = cf.frame_name; + ++parent->num_children; + } + + if (IsEnter(node)) { + WhileLoopArg arg; + arg.enter = node; + TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant", + &arg.is_loop_invariant)); + frame.args.push_back(arg); + } else if (IsLoopCond(node)) { + frame.loop_cond = node; + } + frame.nodes.insert(node); + } + + return Status::OK(); +} + // Check that the graph has no cycle containing the given node. Status CheckNodeNotInCycle(const Node* node, const int num_nodes) { std::vector ready; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h index 582b49d5116..f986376c8e3 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -18,12 +18,56 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/graph/graph.h" -// Utility functions shared between functionalize cond and while. +// Utility functions shared between functionalize cond and while +// or used by other graph optimization passes. namespace tensorflow { +// Information about a loop argument. +struct WhileLoopArg { + // Every loop argument has an Enter node. + Node* enter; + + // Is the loop argument a loop-invariant value? Taken from the `is_constant` + // attribute on the Enter node. + bool is_loop_invariant; + + // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant + // arguments must have all of the following nodes: + Node* merge = nullptr; + Node* switch_node = nullptr; + Node* next_iteration = nullptr; + Node* exit = nullptr; +}; + +// Information about a loop frame. +struct WhileLoopFrame { + string name; + + // Pointer to the parent frame. The root frame has a pointer to itself. + WhileLoopFrame* parent = nullptr; + int num_children = 0; + + // Arguments to this loop. + std::vector args; + + // The loop condition of the loop. There should be exactly one loop condition + // in every loop. + Node* loop_cond = nullptr; + + // Set of nodes that belong to the loop frame. + std::unordered_set nodes; +}; + +// Extracts v1 while loops within a graph and creates a map of +// . +Status ExtractWhileLoopFrames( + const std::vector& cf_info, const Graph* graph, + std::unordered_map* frames); + // Check that the graph has no cycle containing the given node. Status CheckNodeNotInCycle(const Node* node, const int num_nodes); diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index e4a21f90598..74790f9ee4d 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -24,6 +24,7 @@ limitations under the License. 
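// A sketch of the intended call pattern for ExtractWhileLoopFrames inside a
// Status-returning pass (it mirrors FunctionalizeWhileLoop further below;
// BuildControlFlowInfo comes from tensorflow/core/graph/control_flow.h, which
// functionalize_control_flow_util.h now includes):
//
//   std::vector<ControlFlowInfo> cf_info;
//   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info));
//   // Group nodes by frame, recording Enter args and the LoopCond per frame.
//   std::unordered_map<string, WhileLoopFrame> frames;
//   TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames));
//   // Innermost frames (num_children == 0) are then functionalized first.
//   std::deque<WhileLoopFrame*> worklist;
//   for (auto& frame : frames) {
//     if (frame.second.num_children == 0) worklist.push_back(&frame.second);
//   }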
#include "absl/memory/memory.h" #include "absl/types/optional.h" #include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -42,42 +43,6 @@ namespace { using xla::StatusOr; -// Information about a loop argument. -struct Arg { - // Every loop argument has an Enter node. - Node* enter; - - // Is the loop argument a loop-invariant value? Taken from the `is_constant` - // attribute on the Enter node. - bool is_loop_invariant; - - // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant - // arguments must have all of the following nodes: - Node* merge = nullptr; - Node* switch_node = nullptr; - Node* next_iteration = nullptr; - Node* exit = nullptr; -}; - -// Information about a loop frame. -struct Frame { - string name; - - // Pointer to the parent frame. The root frame has a pointer to itself. - Frame* parent = nullptr; - int num_children = 0; - - // Arguments to this loop. - std::vector args; - - // The loop condition of the loop. There should be exactly one loop condition - // in every loop. - Node* loop_cond = nullptr; - - // Set of nodes that belong to the loop frame. - std::unordered_set nodes; -}; - // Copies a subgraph from `graph` to `output` by performing a reverse DFS // starting at nodes in vector `stack`. // `node_map` is a vector indexed by source node ID to dest nodes. @@ -93,7 +58,7 @@ struct Frame { // taking from the Switch node was not necessarily the first output, but _Arg // nodes only have one output. By adding the Switch node to `squash_src_outputs` // we rewrite the src_output of the corresponding edge to be 0. -Status CopySubgraph(const Graph& graph, const Frame* frame, +Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame, std::vector stack, const std::vector& squash_src_outputs, std::vector* node_map, Graph* output) { @@ -154,7 +119,7 @@ StatusOr BuildArgNode(Graph* graph, DataType type, int index) { } // Builds a graph for the loop condition. -Status BuildLoopCondition(const Graph& graph, Frame* frame, +Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame, std::unique_ptr* cond_output) { VLOG(2) << "Building loop condition for " << frame->name; *cond_output = absl::make_unique(graph.op_registry()); @@ -166,7 +131,7 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame, // Build one _Arg node for each Enter node. for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; + const WhileLoopArg& arg = frame->args[i]; TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, arg.enter->input_type(0), i)); @@ -190,7 +155,7 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame, } // Builds a graph for the loop body. 
-Status BuildLoopBody(const Graph& graph, Frame* frame, +Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame, DataTypeVector* arg_types, std::unique_ptr* body_output) { VLOG(2) << "Building loop body for " << frame->name; @@ -206,7 +171,7 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, next_iterations.reserve(frame->args.size()); arg_types->reserve(frame->args.size()); for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; + const WhileLoopArg& arg = frame->args[i]; DataType dtype = arg.enter->input_type(0); arg_types->push_back(dtype); @@ -297,7 +262,7 @@ Status AddMissingFunctionDef(const FunctionDef& fdef, } Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, - Graph* graph, Frame* frame, + Graph* graph, WhileLoopFrame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " << DumpGraphToFile("functionalize_before", *graph, library); @@ -307,8 +272,8 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, // shared Enter node. We clone Enter nodes with multiple successors to // maintain the invariant of a unique Enter node per argument of the final // loop. - std::vector args; - for (const Arg& arg : frame->args) { + std::vector args; + for (const WhileLoopArg& arg : frame->args) { if (arg.is_loop_invariant) { args.push_back(arg); } else { @@ -319,7 +284,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, continue; } TF_RET_CHECK(!edges[i]->IsControlEdge()) << edges[i]->src()->name(); - Arg new_arg; + WhileLoopArg new_arg; new_arg.is_loop_invariant = false; if (i == 0) { new_arg.enter = arg.enter; @@ -342,7 +307,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, frame->args = std::move(args); std::sort(frame->args.begin(), frame->args.end(), - [](const Arg& a, const Arg& b) { + [](const WhileLoopArg& a, const WhileLoopArg& b) { return NodeCmpByNameResourcesLast()(a.enter, b.enter); }); @@ -368,7 +333,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, // ^ ^ // | | // ... ... - for (Arg& arg : frame->args) { + for (WhileLoopArg& arg : frame->args) { if (!arg.is_loop_invariant) { // Follow the edge from the Enter to Merge. const Edge* enter_merge = nullptr; @@ -530,6 +495,12 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, builder.Attr("cond", cond_name); builder.Attr("body", body_name); string outside_compilation; + string frontend_attributes; + if (GetNodeAttr(frame->loop_cond->def(), kXlaFrontendAttributesAttrName, + &frontend_attributes) + .ok()) { + builder.Attr(kXlaFrontendAttributesAttrName, frontend_attributes); + } if (GetNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName, &outside_compilation) .ok()) { @@ -537,7 +508,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, } std::vector inputs; for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; + const WhileLoopArg& arg = frame->args[i]; const Edge* in_edge; TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); if (in_edge->IsControlEdge()) { @@ -553,7 +524,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, // Copies edges to the Enter nodes and from the Exit nodes onto the While. 
for (int i = 0; i < frame->args.size(); ++i) { - const Arg& arg = frame->args[i]; + const WhileLoopArg& arg = frame->args[i]; const Edge* in_edge; TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); if (in_edge->IsControlEdge()) { @@ -613,39 +584,11 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, } // Builds Frames, indexed by name. - std::unordered_map frames; - for (Node* node : graph->op_nodes()) { - const ControlFlowInfo& cf = cf_info[node->id()]; - - VLOG(2) << "node: " << node->name() << " (" << node->id() - << ") frame_name: " << cf.frame_name - << " frame: " << (cf.frame ? cf.frame->name() : "---") - << " parent_frame: " - << (cf.parent_frame ? cf.parent_frame->name() : "---"); - TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr); - - Frame& frame = frames[cf.frame_name]; - Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name]; - if (frame.parent == nullptr) { - frame.parent = parent; - frame.name = cf.frame_name; - ++parent->num_children; - } - - if (IsEnter(node)) { - Arg arg; - arg.enter = node; - TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant", - &arg.is_loop_invariant)); - frame.args.push_back(arg); - } else if (IsLoopCond(node)) { - frame.loop_cond = node; - } - frame.nodes.insert(node); - } + std::unordered_map frames; + TF_RETURN_IF_ERROR(ExtractWhileLoopFrames(cf_info, graph, &frames)); // Adds frames with no children (i.e., the innermost frames) to a worklist. - std::deque worklist; + std::deque worklist; for (auto& frame : frames) { if (frame.second.num_children == 0) { worklist.push_back(&frame.second); @@ -654,7 +597,7 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, // Eliminate loops from innermost to outermost. while (!worklist.empty()) { - Frame* frame = worklist.front(); + WhileLoopFrame* frame = worklist.front(); worklist.pop_front(); if (frame->parent == frame) { // Skip the root frame. diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 139d6709215..d60b4ca0b2b 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -55,8 +55,8 @@ tf_kernel_library( "lrn_ops.cc", "matmul_op.cc", "matrix_band_part_op.cc", + "matrix_diag_ops.cc", "matrix_inverse_op.cc", - "matrix_set_diag_op.cc", "matrix_triangular_solve_op.cc", "mirror_pad_op.cc", "next_after_op.cc", @@ -132,6 +132,8 @@ tf_kernel_library( ":if_op", ":tensor_list_utils", ":while_op", + "//tensorflow/compiler/jit:xla_activity_listener", + "//tensorflow/compiler/jit:xla_activity_proto_cc", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:broadcast", @@ -202,6 +204,7 @@ tf_kernel_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 747ec133983..1f12c7980e7 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -20,8 +20,10 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/pooling.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -153,52 +155,5 @@ class DiagPartOp : public XlaOpKernel { REGISTER_XLA_OP(Name("DiagPart"), DiagPartOp); -class MatrixDiagOp : public XlaOpKernel { - public: - explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, ctx->num_inputs() >= 1, - errors::InvalidArgument("MatrixDiag op must have at an input")); - const TensorShape input_shape = ctx->InputShape(0); - - auto dims = input_shape.dim_sizes(); - OP_REQUIRES(ctx, !dims.empty(), - errors::InvalidArgument("Expected 1 <= dims, got shape ", - input_shape.DebugString())); - - - int last_dim = dims.size() - 1; - int64 last_dim_size = input_shape.dim_size(last_dim); - absl::Span other_dims(dims); - other_dims.remove_suffix(1); - - xla::XlaOp input = ctx->Input(0); - xla::XlaOp diag = CreateDiagonal(input, last_dim_size, other_dims); - ctx->SetOutput(0, diag); - } -}; - -REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp); - -class MatrixDiagPartOp : public XlaOpKernel { - public: - explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - - void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - auto dims = input_shape.dim_sizes(); - - OP_REQUIRES(ctx, 2 <= dims.size(), - errors::InvalidArgument("Expected 2 <= dims, got shape ", - input_shape.DebugString())); - - xla::XlaOp input = ctx->Input(0); - ctx->SetOutput(0, xla::GetMatrixDiagonal(input)); - } -}; - -REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp); - } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index b309541a864..8e53ca162f5 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -14,7 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/image_resize_ops.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "tensorflow/compiler/jit/xla_activity.pb.h" +#include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -255,6 +258,15 @@ xla::XlaOp ResizeUsingDilationAndConvolution( ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, out_size, align_corners); + + if (dims.kernel_size[0] * dims.kernel_size[1] > + kMax2DKernelSize * kMax2DKernelSize) { + BroadcastOptimizationRemark( + XlaOptimizationRemark::SLOW_IMAGE_RESIZE_DIMENSIONS, + absl::StrFormat("%dx%d", dims.kernel_size[0], dims.kernel_size[1])) + .IgnoreError(); + } + xla::XlaOp output; // Concatenation and padding below currently assumes num_spatial_dims is 2 to diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc new file mode 100644 index 00000000000..7eeb05a4920 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -0,0 +1,425 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +// Reads or infers lower_diag_index and upper_diag_index from kernel's input +// parameter "k". Also validates their values. +std::pair ProcessDiagIndex(XlaOpKernelContext* context) { + int64 lower_diag_index = 0; + int64 upper_diag_index = 0; + TensorShape diag_index_shape = context->InputShape("k"); + + // Wrapping OP_REQUIRES* macros with a function because they can "return;" + // early (without values) which contradicts ProcessDiagIndex's signature. 
+ auto validate_diag_indices = [&]() { + if (diag_index_shape.dims() == 0) { + OP_REQUIRES_OK(context, + context->ConstantInputAsIntScalar("k", &lower_diag_index)); + upper_diag_index = lower_diag_index; + } else { + std::vector diag_index; + OP_REQUIRES_OK(context, + context->ConstantInputAsIntVector("k", &diag_index)); + OP_REQUIRES( + context, !diag_index.empty() && diag_index.size() <= 2, + errors::InvalidArgument( + "diag_index must have only one or two elements, received ", + diag_index.size(), " elements.")); + lower_diag_index = diag_index[0]; + upper_diag_index = + (diag_index.size() > 1) ? diag_index[1] : lower_diag_index; + } + OP_REQUIRES( + context, lower_diag_index <= upper_diag_index, + errors::InvalidArgument( + "lower_diag_index must not be larger than upper_diag_index: ", + lower_diag_index, " > ", upper_diag_index)); + }; + validate_diag_indices(); + return {lower_diag_index, upper_diag_index}; +} + +// Makes sure lower_diag_index and upper_diag_index are consistent with the +// input matrix size. +void ValidateDiagIndexWithOutputMatrixSize(XlaOpKernelContext* context, + const int64 lower_diag_index, + const int64 upper_diag_index, + const int64 num_rows, + const int64 num_cols) { + // `lower_diag_index == 0` condition is added to handle matrix shape = 0. + OP_REQUIRES(context, + (-num_rows < lower_diag_index && lower_diag_index < num_cols) || + lower_diag_index == 0, + errors::InvalidArgument( + "lower_diag_index is out of bound: ", lower_diag_index, + " It must be between ", -num_rows, " and ", num_cols)); + OP_REQUIRES(context, + (-num_rows < upper_diag_index && upper_diag_index < num_cols) || + upper_diag_index == 0, + errors::InvalidArgument( + "upper_diag_index is out of bound: ", upper_diag_index, + " It must be between ", -num_rows, " and ", num_cols)); + OP_REQUIRES(context, lower_diag_index <= upper_diag_index, + errors::InvalidArgument( + "lower_diag_index must not be larger than upper_diag_index: ", + lower_diag_index, " > ", upper_diag_index)); +} + +// Kernel to set matrix diagonals. +xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag, + const TensorShape& input_shape, const int64 diag_rank, + const int64 num_diags, const int64 lower_diag_index, + const int64 upper_diag_index, const int64 max_diag_len, + const int64 num_rows, const int64 num_cols) { + // Creates a padding config. + const int input_rank = input_shape.dims(); + xla::PaddingConfig padding_config; + padding_config = xla::MakeNoPaddingConfig(input_rank - 1); + + // Processes one diagonal at a time: + // 1) Extracts a single diagonal (diag_slice). + // 2) Broadcasts its contents to fill the whole matrix (diag_broadcast). + // 3) Masks diag_broadcast to get the right diagonal shape. + // + // XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow. + // + // For example, + // diag = [[2, 3, 0], k = (-1, 1), and num_rows = 4. + // [4, 5, 6], + // [7, 8, 9]] + // The expected output is [[4, 2, 0], + // [7, 5, 4], + // [0, 8, 6], + // [0, 0, 9]] + // The 1st diagonal is created by: + // 1) Extracting diag_slice = [1, 2, 0]. + // 2) Padding the vector to be as long as num_rows, + // diag_slice = [1, 2, 0, 0], + // then broadcasting diag_slice row-wise to a full matrix, + // diag_broadcast = [[1, 1, 1], + // [2, 2, 2], + // [0, 0, 0], + // [0, 0, 0]] + // The padding value can be anything because it will not appear in the + // results after masking. Here, we use zero. + // 3) Masking diag_broadcast with a mask of the shape of the 1st diagonal. 
+ // mask = [[0, 1, 0], --> output = [[x, 2, x], + // [0, 0, 1], [x, x, 3], + // [0, 0, 0], [x, x, x], + // [0, 0, 0]] [x, x, x]], + // where x denotes the existing input contents. + std::vector broadcast_dimensions(input_rank - 1); + absl::c_iota(broadcast_dimensions, 0); + auto output = input; + for (int64 diag_index = lower_diag_index; diag_index <= upper_diag_index; + ++diag_index) { + // Extracts a single diagonal. + auto diag_slice = diag; + if (num_diags > 1) { + const int64 mapped_diag_index = upper_diag_index - diag_index; + diag_slice = xla::Collapse( + xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1, + diag_rank - 2), + {diag_rank - 2, diag_rank - 1}); + } + + // Pads if necessary. Always pad at the end because shorter diagonals in + // the input come padded at the end. + const int64 padding_length = + ((diag_index <= 0) ? num_cols : num_rows) - max_diag_len; + const xla::XlaOp zero = xla::ScalarLike(input, 0); + if (padding_length > 0) { + padding_config.mutable_dimensions(input_rank - 2) + ->set_edge_padding_high(padding_length); + diag_slice = xla::Pad(diag_slice, zero, padding_config); + } + + // Broadcasts column-wise for subdiagonals; row-wise for superdiagonals. + broadcast_dimensions.back() = + (diag_index <= 0) ? input_rank - 1 : input_rank - 2; + xla::XlaOp diag_broadcast = xla::BroadcastInDim( + diag_slice, input_shape.dim_sizes(), broadcast_dimensions); + const auto mask = xla::GetDiagonalMask(output, diag_index); + output = xla::Select(mask, diag_broadcast, output); + } + return output; +} + +} // namespace + +class MatrixDiagOp : public XlaOpKernel { + public: + explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + OP_REQUIRES( + context, context->num_inputs() >= 1, + errors::InvalidArgument("MatrixDiag op must have at least one input")); + const TensorShape diag_shape = context->InputShape(0); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape), + errors::InvalidArgument("Expected >= 1 dims, got shape ", + diag_shape.DebugString())); + + const DataType dtype = context->expected_output_dtype(0); + const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype); + + // Initializes MatrixDiagV2-specific variables. + // Input arguments providing the values of num_rows and num_cols can be + // absent (-1) and will be inferred later. + int64 lower_diag_index = 0; + int64 upper_diag_index = 0; + int64 num_rows = -1; + int64 num_cols = -1; + xla::XlaOp padding_value = zero; + + // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has + // one input, so we have to check the number of inputs before reading + // additional parameters for MatrixDiagV2. + if (context->num_inputs() > 1) { + std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context); + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols)); + padding_value = context->Input(4); + } + + // More size validations. 
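// Worked example for the size checks and the num_rows/num_cols inference
// below (values chosen here only for illustration): for a single diagonal of
// length max_diag_len = 3 placed at k = 1,
//   min_num_rows = 3 - min(1, 0) = 3 and min_num_cols = 3 + max(1, 0) = 4.
// With both num_rows and num_cols unspecified (-1), the output becomes the
// square max(3, 4) = 4x4 matrix; with num_rows = 3 given, num_cols defaults
// to min_num_cols = 4 and the diagonal fills positions (0,1), (1,2), (2,3)
// of the 3x4 output, satisfying num_rows == min_num_rows.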
+ const int64 diag_rank = diag_shape.dims(); + const int64 max_diag_len = diag_shape.dim_size(diag_rank - 1); + const int64 num_diags = upper_diag_index - lower_diag_index + 1; + OP_REQUIRES( + context, + num_diags == 1 || num_diags == diag_shape.dim_size(diag_rank - 2), + errors::InvalidArgument( + "The number of diagonals provided in the input does not " + "match the lower_diag_index and upper_diag_index range.")); + const int64 min_num_rows = max_diag_len - std::min(upper_diag_index, 0LL); + const int64 min_num_cols = max_diag_len + std::max(lower_diag_index, 0LL); + OP_REQUIRES(context, num_rows == -1 || num_rows >= min_num_rows, + errors::InvalidArgument("The number of rows is too small.")); + OP_REQUIRES(context, num_cols == -1 || num_cols >= min_num_cols, + errors::InvalidArgument("The number of columns is too small.")); + + // Infers num_rows and num_cols. If both are unknown, assume that the output + // is square. Otherwise, use smallest possible values. + if (num_rows == -1 && num_cols == -1) { + num_rows = std::max(min_num_rows, min_num_cols); + num_cols = num_rows; + } else if (num_rows == -1) { + num_rows = min_num_rows; + } else if (num_cols == -1) { + num_cols = min_num_cols; + } + + // At least one of num_rows and num_cols must match its minimum length. + // Otherwise, we'll have some incomplete diagonals. + OP_REQUIRES(context, num_rows == min_num_rows || num_cols == min_num_cols, + errors::InvalidArgument( + "The number of rows or columns is not consistent with " + "the specified d_lower, d_upper, and diagonal.")); + + // Actual processing. + // Initializes the output tensor with padding_value. + TensorShape output_shape = diag_shape; + output_shape.RemoveLastDims((num_diags == 1) ? 1 : 2); + output_shape.AddDim(num_rows); + output_shape.AddDim(num_cols); + xla::XlaOp output = xla::Broadcast(padding_value, output_shape.dim_sizes()); + xla::XlaOp diag = context->Input(0); + context->SetOutput( + 0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags, + lower_diag_index, upper_diag_index, max_diag_len, + num_rows, num_cols)); + } +}; + +REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp); +REGISTER_XLA_OP(Name("MatrixDiagV2") + .CompileTimeConstantInput("k") + .CompileTimeConstantInput("num_rows") + .CompileTimeConstantInput("num_cols") + .CompileTimeConstantInput("padding_value"), + MatrixDiagOp); + +class MatrixDiagPartOp : public XlaOpKernel { + public: + explicit MatrixDiagPartOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const int input_rank = input_shape.dims(); + + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument( + "input must be at least 2-dim, received shape: ", + input_shape.DebugString())); + + const DataType dtype = context->expected_output_dtype(0); + const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype); + + // Initializes MatrixDiagPartV2-specific variables. + int64 lower_diag_index = 0; + int64 upper_diag_index = 0; + xla::XlaOp padding_value = zero; + + // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel. + // MatrixDiagPart only has one input, so we have to check the number of + // inputs before reading additional parameters in MatrixDiagV2. + if (context->num_inputs() > 1) { + std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context); + padding_value = context->Input(2); + } + + // Checks if diag sizes are consistent with input. 
+ const int64 num_rows = input_shape.dim_size(input_rank - 2); + const int64 num_cols = input_shape.dim_size(input_rank - 1); + ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index, + upper_diag_index, num_rows, num_cols); + + // Creates output shape. + TensorShape output_shape = input_shape; + output_shape.RemoveLastDims(2); + const int num_diags = upper_diag_index - lower_diag_index + 1; + if (num_diags > 1) output_shape.AddDim(num_diags); + const int32 max_diag_len = + std::min(num_rows + std::min(upper_diag_index, 0LL), + num_cols - std::max(lower_diag_index, 0LL)); + output_shape.AddDim(max_diag_len); + + // Computes output. + xla::XlaOp input = context->Input(0); + std::vector diag_list; + xla::PaddingConfig padding_config; + if (num_diags == 1) { + context->SetOutput(0, xla::GetMatrixDiagonal(input, upper_diag_index)); + return; + } + padding_config = xla::MakeNoPaddingConfig(input_rank - 1); + for (int diag_index = upper_diag_index; diag_index >= lower_diag_index; + --diag_index) { + auto single_diag = xla::GetMatrixDiagonal(input, diag_index); + const int64 diag_length = + (diag_index >= 0) ? (num_cols - diag_index) : (num_rows + diag_index); + const int64 padding_length = max_diag_len - diag_length; + if (padding_length > 0) { + padding_config.mutable_dimensions(input_rank - 2) + ->set_edge_padding_high(padding_length); + single_diag = xla::Pad(single_diag, padding_value, padding_config); + } + diag_list.emplace_back(single_diag); + } + auto concat = + xla::ConcatInDim(context->builder(), diag_list, input_rank - 2); + context->SetOutput(0, xla::Reshape(concat, output_shape.dim_sizes())); + } +}; + +REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp); +REGISTER_XLA_OP(Name("MatrixDiagPartV2") + .CompileTimeConstantInput("k") + .CompileTimeConstantInput("padding_value"), + MatrixDiagPartOp); + +class MatrixSetDiagOp : public XlaOpKernel { + public: + explicit MatrixSetDiagOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const TensorShape diag_shape = context->InputShape(1); + const int input_rank = input_shape.dims(); + const int diag_rank = diag_shape.dims(); + + // Preliminary validation of sizes. + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument( + "input must be at least 2-dim, received shape: ", + input_shape.DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape), + errors::InvalidArgument( + "diagonal must be at least 1-dim, received shape: ", + diag_shape.DebugString())); + + // MatrixSetDiag and MatrixSetDiagV2 both use this OpKernel. MatrixSetDiag + // only has two inputs, so we have to check the number of inputs before + // reading additional parameters in MatrixSetDiagV2. + int64 lower_diag_index = 0; + int64 upper_diag_index = 0; + if (context->num_inputs() > 2) { + std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context); + } + + // Checks if diag sizes are consistent with input. 
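// Worked example for the diagonal-band bounds used by MatrixDiagPartV2 above
// and MatrixSetDiagV2 below (values chosen here only for illustration): with
// num_rows = 4, num_cols = 3 and k = (-1, 1),
//   max_diag_len = min(4 + min(1, 0), 3 - max(-1, 0)) = min(4, 3) = 3,
// so each of the num_diags = 3 diagonals is handled with length 3: k = 1 has
// true length 3 - 1 = 2 and is padded by one element at the end, while k = 0
// (length 3) and k = -1 (length 4 - 1 = 3) need no padding. The diagonals are
// ordered from k = 1 down to k = -1, giving a [..., 3, 3] result for
// MatrixDiagPart and a diagonal input of that same shape for MatrixSetDiag.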
+ const int64 num_rows = input_shape.dim_size(input_rank - 2); + const int64 num_cols = input_shape.dim_size(input_rank - 1); + ValidateDiagIndexWithOutputMatrixSize(context, lower_diag_index, + upper_diag_index, num_rows, num_cols); + const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1; + OP_REQUIRES( + context, + lower_diag_index == upper_diag_index || + (diag_shape.dim_size(input_rank - 2) == num_diags), + errors::InvalidArgument("The number of diagonals provided in `diag` " + "is not consistent with `lower_diag_index` and " + "`upper_diag_index`")); + + TensorShape expected_diag_shape = input_shape; + expected_diag_shape.RemoveLastDims(2); + if (num_diags > 1) expected_diag_shape.AddDim(num_diags); + const int32 max_diag_len = + std::min(num_rows + std::min(upper_diag_index, 0LL), + num_cols - std::max(lower_diag_index, 0LL)); + expected_diag_shape.AddDim(max_diag_len); + OP_REQUIRES( + context, expected_diag_shape == diag_shape, + errors::InvalidArgument( + "Either first dimensions of diagonal don't match input.shape[:-2], " + "or diagonal.shape[:-1] is not equal to the longests diagonal in " + "range [lower_diag_index:upper_diag_index].\nInput shape: ", + input_shape.DebugString(), + "\nDiagonal shape: ", diag_shape.DebugString(), + "\nExpected diagonal shape: ", expected_diag_shape.DebugString())); + + // Actual processing. + xla::XlaOp input = context->Input(0); + xla::XlaOp diag = context->Input(1); + context->SetOutput( + 0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags, + lower_diag_index, upper_diag_index, max_diag_len, + num_rows, num_cols)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp); +}; + +REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp); +REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"), + MatrixSetDiagOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc deleted file mode 100644 index ee9764c0c35..00000000000 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/primitive_util.h" - -namespace tensorflow { - -class MatrixSetDiagOp : public XlaOpKernel { - public: - explicit MatrixSetDiagOp(OpKernelConstruction* context) - : XlaOpKernel(context) {} - - void Compile(XlaOpKernelContext* context) override { - const TensorShape input_shape = context->InputShape(0); - const TensorShape diag_shape = context->InputShape(1); - - const int rank = input_shape.dims(); - - // Preliminary validation of sizes. 
- OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), - errors::InvalidArgument( - "input must be at least 2-dim, received shape: ", - input_shape.DebugString())); - - // Check to make sure the last dimension of diag is equal to the smaller of - // the last two dimensions of input. - const int64 m = input_shape.dim_size(rank - 2); - const int64 n = input_shape.dim_size(rank - 1); - const int64 min_dim = std::min(m, n); - - TensorShape batch_shape = input_shape; - batch_shape.RemoveLastDims(2); - - TensorShape expected_diag_shape = batch_shape; - expected_diag_shape.AddDim(min_dim); - OP_REQUIRES(context, expected_diag_shape == diag_shape, - errors::InvalidArgument( - "must have diagonal.shape == input.shape[:-2] + " - "min(input.shape[-2:]), but received input shape: ", - input_shape.DebugString(), - " and diagonal shape: ", diag_shape.DebugString())); - - xla::XlaBuilder* builder = context->builder(); - xla::XlaOp input = context->Input(0); - xla::XlaOp diag = context->Input(1); - - auto zero = XlaHelpers::Zero(builder, context->input_type(0)); - - // Create an indicator tensor that is true only on the diagonal. - xla::XlaOp iota_m = xla::Iota(builder, xla::S32, m); - xla::XlaOp iota_n = xla::Iota(builder, xla::S32, n); - auto indicator = xla::Eq(iota_m, xla::Broadcast(iota_n, {m}), - /*broadcast_dimensions=*/{0}); - indicator = xla::Broadcast(indicator, batch_shape.dim_sizes()); - - // Broadcast diag up to the input shape. Use an implicit broadcast (Add/Or) - // because we need to broadcast on the right. - std::vector diag_broadcast_dims(rank - 1); - std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0); - if (min_dim != m) { - diag_broadcast_dims.back() = rank - 1; - } - if (context->input_xla_type(0) == xla::PRED) { - diag = xla::Or(diag, xla::Broadcast(zero, input_shape.dim_sizes()), - /*broadcast_dimensions=*/diag_broadcast_dims); - - } else { - diag = xla::Add(diag, xla::Broadcast(zero, input_shape.dim_sizes()), - /*broadcast_dimensions=*/diag_broadcast_dims); - } - - auto output = xla::Select(indicator, diag, input); - context->SetOutput(0, output); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp); -}; - -REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp); - -} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 063b97cd593..905f83fef9a 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -47,6 +47,11 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_)); + // TODO(b/140109958): Implement for axis != -1. 
+ OP_REQUIRES(ctx, axis_ == -1, + errors::Unimplemented("QuantizeAndDequantizeOp with axis >= 0 " + "not yet implemented for XLA")); round_mode_ = ROUND_HALF_TO_EVEN; } @@ -156,6 +161,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { protected: int64 num_bits_ = -1; + int axis_; bool signed_input_; bool range_given_; bool narrow_range_; diff --git a/tensorflow/compiler/tf2xla/kernels/roll_op.cc b/tensorflow/compiler/tf2xla/kernels/roll_op.cc index a6cc5960c90..99f4a5f46d7 100644 --- a/tensorflow/compiler/tf2xla/kernels/roll_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/roll_op.cc @@ -47,11 +47,8 @@ class RollOp : public XlaOpKernel { xla::PrimitiveType shift_type = ctx->input_xla_type(1); int64 num_axes = axis_shape.dims() == 0 ? 1 : axis_shape.dim_size(0); for (int64 i = 0; i != num_axes; ++i) { - auto cur_axis_status = axis_shape.dims() == 0 - ? axis.GetIntegralAsS64({}) - : axis.GetIntegralAsS64({i}); - OP_REQUIRES_OK(ctx, cur_axis_status.status()); - int64 cur_axis = cur_axis_status.ValueOrDie(); + int64 cur_axis = axis_shape.dims() == 0 ? *axis.GetIntegralAsS64({}) + : *axis.GetIntegralAsS64({i}); xla::XlaOp offset = shift_shape.dims() == 0 diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 265e7e784a9..88af12dacee 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -40,9 +40,23 @@ class ShapeOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); - Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); - OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); - ctx->SetConstantOutput(0, shape_constant); + std::vector operands; + const int rank = input_shape.dims(); + if (rank != 0) { + for (int64 i = 0; i < rank; ++i) { + operands.push_back(xla::Broadcast( + xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i), + ctx->output_xla_type(0)), + {1})); + } + + ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), operands, 0)); + } else { + // Rank 0 won't have dynamic size dimension, use constant output. + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); + ctx->SetConstantOutput(0, shape_constant); + } } private: diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index ac3d2c22d65..4af3d4233dd 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -307,6 +308,59 @@ class TensorListGetItemOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorListGetItem"), TensorListGetItemOp); +class TensorListGatherOp : public XlaOpKernel { + public: + explicit TensorListGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + // Check that the TensorList is initialized. 
+ bool is_initialized; + OP_REQUIRES_OK(ctx, + (IsTensorListInitialized(ctx->Input(0), &is_initialized))); + OP_REQUIRES(ctx, is_initialized, + errors::InvalidArgument("TensorList is not initialized")); + + // Only non-nested TensorList is supported for now. + bool is_nested; + OP_REQUIRES_OK(ctx, IsNestedTensorList(ctx->Input(0), &is_nested)); + OP_REQUIRES(ctx, !is_nested, + errors::Unimplemented("Only non-nested TensorList is supported " + "for TensorListGather.")); + + DataType indices_type = ctx->input_type(1); + + const TensorShape indices_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, indices_shape.dims() == 1, + errors::InvalidArgument("indices must be rank 1")); + + xla::XlaOp list = ctx->Input(0); + xla::XlaOp indices = ctx->Input(1); + + xla::XlaOp buffer; + OP_REQUIRES_OK(ctx, GetTensorListBuffer(list, &buffer)); + xla::Shape buffer_xla_shape; + OP_REQUIRES_OK(ctx, GetTensorListBufferShape(list, &buffer_xla_shape)); + TensorShape buffer_shape; + OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(buffer_xla_shape, &buffer_shape)); + + xla::XlaOp result; + OP_REQUIRES_OK( + ctx, XlaGather(buffer, buffer_shape, indices, indices_shape, /*axis=*/0, + /*indices_are_nd=*/false, dtype_, indices_type, + ctx->builder(), &result)); + ctx->SetOutput(0, result); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListGatherOp); +}; + +REGISTER_XLA_OP(Name("TensorListGather"), TensorListGatherOp); + class TensorListStackOp : public XlaOpKernel { public: explicit TensorListStackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 247db8d5d17..191ce9dee2b 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -270,6 +270,53 @@ class ResourceApplyAdagrad : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes), ResourceApplyAdagrad); +class ResourceApplyAdagradV2 : public XlaOpKernel { + public: + explicit ResourceApplyAdagradV2(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(2); + + TensorShape var_shape, accum_shape; + xla::XlaOp var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); + + OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), + errors::InvalidArgument( + "var and accum do not have the same shape", + var_shape.DebugString(), " ", accum_shape.DebugString())); + + TensorShape lr_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), + errors::InvalidArgument("lr is not a scalar: ", + lr_shape.DebugString())); + + TensorShape epsilon_shape = ctx->InputShape(3); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon_shape), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon_shape.DebugString())); + + TensorShape grad_shape = ctx->InputShape(4); + OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape), + errors::InvalidArgument( + "var and grad do not have the same shape", + var_shape.DebugString(), " ", grad_shape.DebugString())); + + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp epsilon = ctx->Input(3); + xla::XlaOp grad = ctx->Input(4); + + accum = accum + xla::Square(grad); + var = var - grad * lr / (xla::Sqrt(accum) + epsilon); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var)); + OP_REQUIRES_OK(ctx, 
ctx->AssignVariable(1, type, accum)); + } +}; +REGISTER_XLA_OP(Name("ResourceApplyAdagradV2").TypeConstraint("T", kFloatTypes), + ResourceApplyAdagradV2); + class ResourceApplyProximalAdagrad : public XlaOpKernel { public: explicit ResourceApplyProximalAdagrad(OpKernelConstruction* ctx) diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index d348d2b41dd..1991e332be8 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -69,6 +69,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::U8: literal = xla::LiteralUtil::CreateR0(value); break; + case xla::U16: + literal = xla::LiteralUtil::CreateR0(value); + break; case xla::U32: literal = xla::LiteralUtil::CreateR0(value); break; @@ -78,6 +81,9 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, case xla::S8: literal = xla::LiteralUtil::CreateR0(value); break; + case xla::S16: + literal = xla::LiteralUtil::CreateR0(value); + break; case xla::S32: literal = xla::LiteralUtil::CreateR0(value); break; @@ -98,9 +104,6 @@ xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; - case xla::S16: - case xla::U16: - LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::BF16: literal = xla::LiteralUtil::CreateR0(static_cast(value)); diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD index 3cc551e08aa..eaba5d3c420 100644 --- a/tensorflow/compiler/tf2xla/python/BUILD +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -1,5 +1,5 @@ load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_py_clif_cc", ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc index b376fe94743..b6f8928f31e 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -527,7 +527,7 @@ Status RearrangeFunctionArguments( // Rewrite If/While nodes. 
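Before the control-flow rewrites continue below, a note on the ResourceApplyAdagradV2 kernel added earlier in this change: it lowers the standard Adagrad-with-epsilon update, accum += grad^2 followed by var -= lr * grad / (sqrt(accum) + epsilon). A scalar, XLA-free sketch of that arithmetic (values are illustrative only):

#include <cmath>
#include <iostream>

int main() {
  // Hypothetical scalar state; the kernel applies the same update elementwise.
  double var = 1.0, accum = 0.1;
  const double lr = 0.01, epsilon = 1e-7, grad = 0.5;
  accum += grad * grad;                               // accum = accum + Square(grad)
  var -= grad * lr / (std::sqrt(accum) + epsilon);    // var = var - grad * lr / (Sqrt(accum) + epsilon)
  std::cout << "var=" << var << " accum=" << accum << "\n";
}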
for (Node* n : g->nodes()) { - if (n->type_string() == "While") { + if (n->IsWhileNode()) { bool node_rewritten; TF_RETURN_IF_ERROR(MaybeRewriteWhileNode(get_function_body_fn, g, n, fld, &node_rewritten)); diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 1243e31a047..2db431c0413 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -57,6 +57,7 @@ CreateResourceOpInfoMap() { add("ResourceApplyAdaMax" , kReadWrite, kVariable); add("ResourceApplyAdadelta" , kReadWrite, kVariable); add("ResourceApplyAdagrad" , kReadWrite, kVariable); + add("ResourceApplyAdagradV2" , kReadWrite, kVariable), add("ResourceApplyAdagradDA" , kReadWrite, kVariable); add("ResourceApplyAdam" , kReadWrite, kVariable); add("ResourceApplyAddSign" , kReadWrite, kVariable); diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 8aae498be10..4d5bf0835e1 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -53,7 +53,7 @@ xla::StatusOr> ParseShardingFromDevice( const string& device_name, int num_cores_per_replica, absl::optional explicit_sharding) { if (device_name.empty()) { - return absl::optional(); + return explicit_sharding; } DeviceNameUtils::ParsedName parsed_device; if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) { diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index eebeec87b60..86d900363b8 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -28,6 +28,9 @@ const char kXlaHasHostTransferAttrName[] = "_xla_has_host_transfer"; const char kXlaReplicaIdAttrName[] = "_xla_replica_id"; +const char kXlaIsPlaceholderForTailOcAttrName[] = + "_xla_is_placeholder_for_tail_oc"; + Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) { return errors::InvalidArgument("Node ", node->DebugString(), @@ -50,7 +53,7 @@ Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { node->ClearAttr(attr_name); node->AddAttr(attr_name, branch_func); } - } else if (node->type_string() == "While") { + } else if (node->IsWhileNode()) { AttrValue device_ordinal_value; device_ordinal_value.set_i(device_ordinal); for (const string& attr_name : std::vector{"cond", "body"}) { diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index be26ba5769c..31326044738 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -41,6 +41,9 @@ extern const char kXlaHasHostTransferAttrName[]; // This attribute is the replica id for an outside compilation node node. extern const char kXlaReplicaIdAttrName[]; +// This node is a Placeholder node added for tail outside compilation. +extern const char kXlaIsPlaceholderForTailOcAttrName[]; + // Sets device ordinal attribute for nodes with attribute // `kXlaHasHostTransferAttrName`. 
Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 3e4188f3c6d..3c2b256800c 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -384,8 +384,8 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef( &second_copy_def, *g->op_registry(), /*node_offset=*/0)); - TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(GraphConstructorOptions(), - second_copy_def, g.get())); + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph( + GraphConstructorOptions(), std::move(second_copy_def), g.get())); TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, feed_remapping)); // Functionalize control flow. diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 3e8b9eb79d8..e82546def46 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -765,7 +765,7 @@ Status PropagateConstIntoFunctionalNodes( for (Node* n : g->op_nodes()) { if (n->IsIfNode()) { TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld)); - } else if (n->type_string() == "While") { + } else if (n->IsWhileNode()) { TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld)); } } @@ -796,7 +796,7 @@ Status RewriteTensorListWithConstElement(Graph* g, // Find the forward While op. std::vector fwd_while_edges; for (const Edge* e : n->out_edges()) { - if (!e->IsControlEdge() && e->dst()->type_string() == "While") { + if (!e->IsControlEdge() && e->dst()->IsWhileNode()) { fwd_while_edges.push_back(e); } } @@ -810,8 +810,7 @@ Status RewriteTensorListWithConstElement(Graph* g, int fwd_while_dst_input = fwd_while_edges[0]->dst_input(); std::vector bwd_while_edges; for (const Edge* e : fwd_while->out_edges()) { - if (e->src_output() == fwd_while_dst_input && - e->dst()->type_string() == "While") { + if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) { bwd_while_edges.push_back(e); } } diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index c14519c3ade..06423019f23 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -98,6 +99,20 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, absl::optional op_sharding = sharding_parse_result.ValueOrDie(); + auto frontend_attributes_result = + GetFrontendAttributesFromAttrSlice(AttrSlice(op_kernel->def())); + OP_REQUIRES_OK(context, frontend_attributes_result.status()); + absl::optional attributes = + frontend_attributes_result.ValueOrDie(); + + xla::FrontendAttributes merged_attributes = b->frontend_attributes(); + if (attributes.has_value()) { + merged_attributes.mutable_map()->insert(attributes.value().map().begin(), + attributes.value().map().end()); + } + xla::XlaScopedFrontendAttributesAssignment assign_frontend_attributes( + b, std::move(merged_attributes)); + // If no sharding metadata is found, XLA is free to use whatever device it // wants. In practice this usually has the effect of placing things on device // 0. 
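The frontend-attribute handling added above merges the per-op attributes into whatever the builder already carries; because the merge goes through map-style insert, a key the builder has already set is kept rather than overwritten by the op-level value. A small std::map model of that merge (the key names are made up for illustration):

#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, std::string> builder_attrs = {{"_scheduling_group", "a"}};
  std::map<std::string, std::string> op_attrs = {{"_scheduling_group", "b"},
                                                 {"_pipeline_stage", "1"}};
  // insert() only adds keys that are not already present, so the builder's value wins.
  builder_attrs.insert(op_attrs.begin(), op_attrs.end());
  for (const auto& kv : builder_attrs) {
    std::cout << kv.first << "=" << kv.second << "\n";
  }
  // Prints: _pipeline_stage=1 then _scheduling_group=a.
}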
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 2ee8c7e5cfb..cfb118281e4 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/variant.h" #include "tensorflow/compiler/jit/flags.h" +#include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -76,41 +77,38 @@ Status CheckSignature(const DataTypeVector& types, return Status::OK(); } -// Uses the _Arg and _Retval nodes in the graph to determine a core assignment -// for each argument and return value. -xla::StatusOr, std::map>> -ComputeArgAndRetvalCores(const Graph& graph) { - auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr { +// Uses the _Arg and _Retval nodes in the graph to determine an OpSharding for +// each argument and return value. +xla::StatusOr< + std::pair, std::map>> +ComputeArgAndRetvalShardings(const Graph& graph) { + auto get_sharding_for_node = + [](const Node* n) -> xla::StatusOr> { TF_ASSIGN_OR_RETURN( auto sharding, ParseShardingFromDevice(*n, std::numeric_limits::max())); - if (sharding.has_value()) { - TF_RET_CHECK(sharding.value().type() == xla::OpSharding::MAXIMAL); - return sharding.value().tile_assignment_devices(0); - } else { - return -1; - } + return sharding; }; - std::map arg_cores; - std::map retval_cores; + std::map arg_shardings; + std::map retval_shardings; for (const Node* n : graph.nodes()) { if (n->IsArg()) { - TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); - if (core < 0) continue; + TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n)); + if (!sharding.has_value()) continue; int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); TF_RET_CHECK(index >= 0) << "Negative _Arg index"; - arg_cores[index] = core; + arg_shardings[index] = std::move(*sharding); } else if (n->IsRetval()) { - TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); - if (core < 0) continue; + TF_ASSIGN_OR_RETURN(auto sharding, get_sharding_for_node(n)); + if (!sharding.has_value()) continue; int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); TF_RET_CHECK(index >= 0) << "Negative _Retval index"; - retval_cores[index] = core; + retval_shardings[index] = std::move(*sharding); } } - return std::make_pair(std::move(arg_cores), std::move(retval_cores)); + return std::make_pair(std::move(arg_shardings), std::move(retval_shardings)); } Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, @@ -144,8 +142,8 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, // - `args` is the list of input arguments // - `retvals` is the list of retvals produced by _Retval operators, in index // order. -// - `args_core` and `retval_cores` are mapping from arg/return indices to core -// assignments. +// - `arg_shardings` and `retval_shardings` are mapping from arg/return indices +// to sharding. // - If `return_updated_values_for_all_resources` is true, all resources will be // included in `resource_updates`, regardless of whether their value changed. // - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. 
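ComputeArgAndRetvalShardings above generalizes the old per-core maps (int values, with -1 meaning "unspecified") to maps of full sharding protos keyed by argument or return-value index; absence from the map now plays the role of the old -1 sentinel. A minimal sketch of that lookup pattern with a stand-in Sharding type (the real code uses xla::OpSharding and absl::optional):

#include <iostream>
#include <map>
#include <string>

// Stand-in for xla::OpSharding; the real maps are std::map<int, xla::OpSharding>.
struct Sharding { std::string description; };

int main() {
  std::map<int, Sharding> arg_shardings = {{0, {"maximal, device 2"}}};
  for (int i = 0; i < 2; ++i) {
    auto it = arg_shardings.find(i);
    if (it == arg_shardings.end()) {
      // The old code signalled this with core == -1; now absence simply
      // means "no explicit sharding".
      std::cout << "arg " << i << ": no explicit sharding\n";
    } else {
      std::cout << "arg " << i << ": " << it->second.description << "\n";
    }
  }
}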
@@ -158,7 +156,8 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, Status BuildComputation( const std::vector& args, const std::vector& retvals, - const std::map& arg_cores, const std::map& retval_cores, + const std::map& arg_shardings, + const std::map& retval_shardings, const std::vector>& resources, std::unique_ptr token_output, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, @@ -212,19 +211,20 @@ Status BuildComputation( output.is_constant = false; TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); xla::XlaOp value = retval.handle(); - auto it = retval_cores.find(i); + auto it = retval_shardings.find(i); xla::XlaScopedShardingAssignment assign_sharding( - builder, it == retval_cores.end() + builder, it == retval_shardings.end() ? absl::optional() - : xla::sharding_builder::AssignDevice(it->second)); + : it->second); if (shape_representation_fn) { // If there is a shape representation function, reshape the output // tensor to the shape given by the representation shape function. TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( - output.shape, output.type)); + output.shape, output.type, + /*use_fast_memory=*/false)); value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); retval_index_and_layout.emplace_back(elems.size(), shape.layout()); - } else if (it != retval_cores.end()) { + } else if (it != retval_shardings.end()) { // Apply the sharding to the output, if there is a core assignment. value = identity_op(value); } @@ -265,8 +265,7 @@ Status BuildComputation( for (const XlaResource* resource : arg_resources) { DCHECK_LT(resource->arg_num(), args.size()); const XlaCompiler::Argument& arg = args[resource->arg_num()]; - auto it = arg_cores.find(resource->arg_num()); - const int core = it == arg_cores.end() ? -1 : it->second; + auto it = arg_shardings.find(resource->arg_num()); bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); // TensorArray gradients were modified if their values changed or there are // any newly created gradients. @@ -289,8 +288,8 @@ Status BuildComputation( // Request that the value be returned on a specific core. xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? absl::optional() - : xla::sharding_builder::AssignDevice(core)); + builder, it == arg_shardings.end() ? absl::optional() + : it->second); xla::XlaOp handle; TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); @@ -303,7 +302,8 @@ Status BuildComputation( if (shape_representation_fn) { TF_ASSIGN_OR_RETURN( xla::Shape xla_shape, - shape_representation_fn(resource->shape(), resource->type())); + shape_representation_fn(resource->shape(), resource->type(), + /*use_fast_memory=*/false)); representation_shape = xla_shape; } if (resource->representation_shape().has_value()) { @@ -479,8 +479,8 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) // The default shape representation function is the identity. 
if (!options_.shape_representation_fn) { options_.shape_representation_fn = - [](const TensorShape& shape, - DataType dtype) -> xla::StatusOr { + [](const TensorShape& shape, DataType dtype, + bool use_fast_memory) -> xla::StatusOr { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); return xla_shape; @@ -532,6 +532,11 @@ Status XlaCompiler::FindFunctionBody(const NameAttrList& function, std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); + + // Performs a first function inlining pass before shape inference, since + // otherwise shape inference can't see inside functions and a comprehensive + // shape_map, including function ops, is needed to constant-propagate Shape + // Ops below. auto flags = GetBuildXlaOpsPassFlags(); OptimizerOptions opts; opts.set_opt_level(OptimizerOptions::L0); @@ -570,6 +575,28 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { optimizer.Optimize(flib_runtime_, flib_runtime_->env(), /*device=*/nullptr, &graph, graph_optimizer_options); + // Run shape inference on the graph and optimize the graph again. + GraphShapeInfo shape_info; + InferShapes(graph.get(), /*arg_shapes=*/{}, + flib_runtime_->GetFunctionLibraryDefinition(), &shape_info) + .IgnoreError(); + auto node_name_index = graph->BuildNodeNameIndex(); + std::unordered_map> shape_map; + for (const auto& node_shape_info : shape_info) { + const string& node_name = node_shape_info.first; + const std::vector& output_shapes = node_shape_info.second; + const auto& node_iter = node_name_index.find(node_name); + if (node_iter != node_name_index.end()) { + auto& partial_shapes = shape_map[node_name]; + for (const auto& inferred_shape : output_shapes) { + partial_shapes.push_back(inferred_shape.shape); + } + } + } + graph_optimizer_options.shape_map = &shape_map; + optimizer.Optimize(flib_runtime_, flib_runtime_->env(), + /*device=*/nullptr, &graph, graph_optimizer_options); + return graph; } @@ -596,6 +623,33 @@ Status XlaCompiler::CompileFunction( CheckSignature(fbody->arg_types, args), "Signature check failure while compiling: ", fn_name_attrs.name()); + // Set shapes for _Arg nodes. They are useful for constant folding (e.g. an + // Xla op requires a compile-time constant input, and that input is shape of + // an _Arg node. + for (int i = 0; i < args.size(); i++) { + // Skip resource variables and tensor lists. + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(fbody->arg_nodes[i]->def(), "T", &dtype)); + if (dtype == DT_RESOURCE || dtype == DT_VARIANT) { + continue; + } + + if (absl::holds_alternative(args[i].shape)) { + xla::Shape xla_shape = absl::get(args[i].shape); + TensorShape tensor_shape; + if (XLAShapeToTensorShape(xla_shape, &tensor_shape).ok()) { + fbody->arg_nodes[i]->ClearAttr("_output_shapes"); + fbody->arg_nodes[i]->AddAttr("_output_shapes", + std::vector{tensor_shape}); + } + } else { + TensorShape tensor_shape = absl::get(args[i].shape); + fbody->arg_nodes[i]->ClearAttr("_output_shapes"); + fbody->arg_nodes[i]->AddAttr("_output_shapes", + std::vector{tensor_shape}); + } + } + std::unique_ptr graph = GetGraph(fbody); // Clear the "_kernel" attribute if it is set to "host". 
This is used to @@ -604,7 +658,7 @@ Status XlaCompiler::CompileFunction( const char* const kKernelAttr = "_kernel"; for (Node* n : graph->nodes()) { string value; - if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") { + if (TryGetNodeAttr(n->attrs(), kKernelAttr, &value) && value == "host") { n->ClearAttr(kKernelAttr); } } @@ -659,8 +713,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, TF_RETURN_IF_ERROR( XLAShapeToTensorShape(absl::get(arg.shape), &shape)); } - TF_ASSIGN_OR_RETURN(*xla_shape, - options_.shape_representation_fn(shape, arg.type)); + TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( + shape, arg.type, + /*use_fast_memory=*/false)); } else { if (absl::holds_alternative(arg.shape)) { *xla_shape = absl::get(arg.shape); @@ -684,7 +739,8 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, TF_RET_CHECK(absl::holds_alternative(arg.shape)); TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( - absl::get(arg.shape), arg.type)); + absl::get(arg.shape), arg.type, + /*use_fast_memory=*/false)); return Status::OK(); } @@ -742,7 +798,7 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, Status XlaCompiler::BuildArguments( const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, - const std::map& arg_cores, + const std::map& arg_shardings, std::vector* arg_expressions, std::vector* input_to_args, std::vector* input_shapes, bool is_entry_computation) { @@ -833,10 +889,10 @@ Status XlaCompiler::BuildArguments( xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::TUPLE); for (int64 parameter : *input_to_args) { - auto it = arg_cores.find(parameter); - const int core = it == arg_cores.end() ? 0 : it->second; + auto it = arg_shardings.find(parameter); *tuple_sharding.add_tuple_shardings() = - xla::sharding_builder::AssignDevice(core); + it == arg_shardings.end() ? xla::sharding_builder::AssignDevice(0) + : it->second; } std::vector is_same_across_replicas; for (int i = 0; i < input_to_args->size(); ++i) { @@ -867,20 +923,18 @@ Status XlaCompiler::BuildArguments( } for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { - auto it = arg_cores.find(i); - const int core = it == arg_cores.end() ? -1 : it->second; + auto it = arg_shardings.find(i); xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? absl::optional() - : xla::sharding_builder::AssignDevice(core)); + builder, it == arg_shardings.end() ? absl::optional() + : it->second); arg_handles[i] = xla::GetTupleElement(tuple, i); } } else { for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { - auto it = arg_cores.find(i); - const int core = it == arg_cores.end() ? -1 : it->second; + auto it = arg_shardings.find(i); xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? absl::optional() - : xla::sharding_builder::AssignDevice(core)); + builder, it == arg_shardings.end() ? absl::optional() + : it->second); if (is_entry_computation) { // Add an entry to is_same_across_replicas for every leaf buffer. 
std::vector is_same_across_replicas( @@ -1155,16 +1209,16 @@ Status XlaCompiler::CompileGraph( real_args.push_back(token_arg); } - std::map arg_cores; - std::map retval_cores; - TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores), - ComputeArgAndRetvalCores(*graph)); + std::map arg_shardings; + std::map retval_shardings; + TF_ASSIGN_OR_RETURN(std::tie(arg_shardings, retval_shardings), + ComputeArgAndRetvalShardings(*graph)); std::vector arg_expressions; TF_RETURN_IF_ERROR(BuildArguments( - *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores, - &arg_expressions, &result->input_mapping, &result->xla_input_shapes, - options.is_entry_computation)); + *graph, real_args, options.use_tuple_arg, &builder, context, + arg_shardings, &arg_expressions, &result->input_mapping, + &result->xla_input_shapes, options.is_entry_computation)); context->set_args(std::move(arg_expressions)); // Propagate any aliases given to us by the user. @@ -1233,7 +1287,7 @@ Status XlaCompiler::CompileGraph( ConvertConstantsToExpressions(&builder, absl::Span(retvals)); } TF_RETURN_IF_ERROR(BuildComputation( - real_args, retvals, arg_cores, retval_cores, context->resources(), + real_args, retvals, arg_shardings, retval_shardings, context->resources(), std::move(token_output), options.is_entry_computation ? options_.shape_representation_fn : ShapeRepresentationFn{}, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 1cc5d8d4728..98c487c9973 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -286,7 +286,8 @@ class XlaCompiler { std::shared_ptr computation; }; - typedef std::function(const TensorShape&, DataType)> + typedef std::function(const TensorShape&, DataType, + bool)> ShapeRepresentationFn; struct Options { // Name of the compilation device to use. It must be set by the caller. @@ -446,7 +447,7 @@ class XlaCompiler { const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, - const std::map& arg_cores, + const std::map& arg_shardings, std::vector* arg_expressions, std::vector* input_to_args, std::vector* input_shapes, diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 34b785754b9..4413625dc3c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -304,7 +304,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForUnwrittenResource) { auto options = DefaultOptions(); options.shape_representation_fn = - [](const TensorShape& shape, DataType dt) -> xla::StatusOr { + [](const TensorShape& shape, DataType dt, + bool use_fast_memory) -> xla::StatusOr { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape)); *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); @@ -357,7 +358,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) { auto options = DefaultOptions(); options.shape_representation_fn = - [](const TensorShape& shape, DataType dt) -> xla::StatusOr { + [](const TensorShape& shape, DataType dt, + bool use_fast_memory) -> xla::StatusOr { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape)); *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); @@ -1080,7 +1082,8 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) { auto options = DefaultOptions(); // Sets the representation function to return a non-default layout. 
options.shape_representation_fn = - [](const TensorShape& shape, DataType type) -> xla::StatusOr { + [](const TensorShape& shape, DataType type, + bool use_fast_memory) -> xla::StatusOr { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); @@ -1118,7 +1121,8 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) { auto options = DefaultOptions(); // Sets the representation function to return a non-default layout. options.shape_representation_fn = - [](const TensorShape& shape, DataType type) -> xla::StatusOr { + [](const TensorShape& shape, DataType type, + bool use_fast_memory) -> xla::StatusOr { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1}); @@ -1252,7 +1256,8 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); options.shape_representation_fn = - [](const TensorShape& shape, DataType type) -> xla::StatusOr { + [](const TensorShape& shape, DataType type, + bool use_fast_memory) -> xla::StatusOr { xla::PrimitiveType ptype; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype)); return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()}); @@ -1322,7 +1327,8 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); options.shape_representation_fn = - [](const TensorShape& shape, DataType type) -> xla::StatusOr { + [](const TensorShape& shape, DataType type, + bool use_fast_memory) -> xla::StatusOr { xla::PrimitiveType ptype; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype)); return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()}); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 6996e39ba16..c95cd4e5475 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -415,7 +415,8 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, ctx->compiler()->options().shape_representation_fn( - variable->shape(), variable->type())); + variable->shape(), variable->type(), + /*use_fast_memory=*/false)); xla::Shape xla_shape; TF_RETURN_IF_ERROR( TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape)); @@ -550,9 +551,10 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); - TF_ASSIGN_OR_RETURN( - xla::Shape representation_shape, - ctx->compiler()->options().shape_representation_fn(shape, type)); + TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, + ctx->compiler()->options().shape_representation_fn( + shape, type, + /*use_fast_memory=*/false)); xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape)); if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) { diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index b11e43a74d0..fa51753aa45 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -47,19 +47,20 @@ extern const char* const DEVICE_XLA_GPU; constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; -constexpr std::array kNumericTypes = { - {DT_UINT8, DT_UINT32, 
DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, - DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}}; +constexpr std::array kNumericTypes = { + {DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_INT16, DT_INT32, + DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, + DT_BFLOAT16}}; -constexpr std::array kCpuAllTypes = { - {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, - DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; +constexpr std::array kCpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, + DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; -constexpr std::array kGpuAllTypes = { - {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, - DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; +constexpr std::array kGpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, + DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index eeb598b165b..9066fb7e1e3 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -1,7 +1,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test", "cc_header_only_library") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library_py", ) @@ -12,6 +12,7 @@ package( package_group( name = "friends", + includes = ["//tensorflow:internal"], packages = [ "//tensorflow/compiler/...", "//tensorflow/contrib/tpu/...", @@ -62,6 +63,7 @@ cc_library( hdrs = ["bit_cast.h"], visibility = [":friends"], deps = [ + ":types", "//tensorflow/core:lib", "//third_party/eigen3", "@com_google_absl//absl/base", diff --git a/tensorflow/compiler/xla/README.md b/tensorflow/compiler/xla/README.md index f9c93707f7a..029a2e0081f 100644 --- a/tensorflow/compiler/xla/README.md +++ b/tensorflow/compiler/xla/README.md @@ -3,4 +3,5 @@
XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear -algebra that optimizes TensorFlow computations. See the [documentation](./g3doc/overview.md). +algebra that optimizes TensorFlow computations. See the +[documentation](./g3doc/index.md). diff --git a/tensorflow/compiler/xla/bit_cast.h b/tensorflow/compiler/xla/bit_cast.h index c9edd7417eb..90e9a5c25dd 100644 --- a/tensorflow/compiler/xla/bit_cast.h +++ b/tensorflow/compiler/xla/bit_cast.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/base/casts.h" #include "third_party/eigen3/Eigen/Core" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index acf59c47f3c..b46d04dc328 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -296,6 +296,8 @@ cc_library( srcs = ["slicing.cc"], hdrs = ["slicing.h"], deps = [ + ":arithmetic", + ":constants", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:xla_builder", diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h index 03ebe4e0098..203b67082bd 100644 --- a/tensorflow/compiler/xla/client/lib/constants.h +++ b/tensorflow/compiler/xla/client/lib/constants.h @@ -62,12 +62,16 @@ XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { return ConstantR0(builder, static_cast(value)); case U8: return ConstantR0(builder, static_cast(value)); + case U16: + return ConstantR0(builder, static_cast(value)); case U32: return ConstantR0(builder, static_cast(value)); case U64: return ConstantR0(builder, static_cast(value)); case S8: return ConstantR0(builder, static_cast(value)); + case S16: + return ConstantR0(builder, static_cast(value)); case S32: return ConstantR0(builder, static_cast(value)); case S64: diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 3d15101ea66..ad525e69289 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/client/lib/math.h" + // This macro is required to make MSVC defines math constants in math.h #define _USE_MATH_DEFINES #include -#include "tensorflow/compiler/xla/client/lib/math.h" - #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -26,6 +26,21 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" namespace xla { +namespace { + +// Evaluate the polynomial given `x` and coefficients in decreasing order. 
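The templated EvaluatePolynomial helper that follows is Horner's rule, with coefficients supplied from highest degree down to the constant term. An XLA-free illustration of the same evaluation order:

#include <iostream>
#include <vector>

// Horner's rule: coefficients ordered from highest degree to constant term.
double EvaluatePolynomial(double x, const std::vector<double>& coefficients) {
  double poly = 0.0;
  for (double c : coefficients) {
    poly = poly * x + c;
  }
  return poly;
}

int main() {
  // 2*x^2 + 3*x + 4 at x = 2 evaluates to 18.
  std::cout << EvaluatePolynomial(2.0, {2.0, 3.0, 4.0}) << "\n";
}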
+template +XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { + static_assert(std::is_floating_point::value, + "Template-argument 'FP' must be a floating-point type"); + XlaOp poly = ScalarLike(x, 0.0); + for (FP c : coefficients) { + poly = poly * x + ScalarLike(x, c); + } + return poly; +} + +} // namespace // Returns operation(operand), except if `operand` is one of the types in // upcast_types, in which case first converts it to F32, and then converts the @@ -134,88 +149,132 @@ XlaOp Square(XlaOp operand) { return operand * operand; } XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; } -// Evaluate the polynomial given coefficients and `x`. -// N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients) { - XlaOp poly = ScalarLike(x, 0.0); - for (float c : coefficients) { - poly = poly * x + ScalarLike(x, c); - } - return poly; -} - // Computes an approximation of the error function complement (1 - erf(x)). // // Precondition: abs(x) >= 1. Otherwise, use ErfImpl. // -// This follows Cephes's f32 implementation of erfc, and so it may have errors -// for double precision. -// -// See also these alternate implementations of erf and erfc: -// -// https://stackoverflow.com/questions/35148198 -// https://stackoverflow.com/questions/35966695 -// -static XlaOp ErfcImpl(XlaOp x) { +// This follows Cephes's f32 implementation of erfc. +static XlaOp ErfcImpl32(XlaOp x) { // Coefficients for erfc(f32), from Cephes. - // - // erfc(x) = exp(-x^2) P(1/x), 1 < x < 2 - static std::array kErfcPCoefficient{ + const double kMaxlog = 88.72283905206835; + // erfc(x) = exp(-x^2) P(1/x^2), 1 < x < 2 + static const std::array kErfcPCoefficient{ +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1, -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1, +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1, }; - // erfc(x) = exp(-x^2) 1/x P(1/x^2), 2 < x < 14 - static std::array kErfcRCoefficient{ + // erfc(x) = exp(-x^2) R(1/x^2), 2 <= x < kMaxlog + static const std::array kErfcRCoefficient{ -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0, +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1, -2.820767439740514E-1, +5.641895067754075E-1, }; - XlaOp abs_x = Abs(x); XlaOp z = Exp(-x * x); XlaOp q = ScalarLike(x, 1) / abs_x; XlaOp y = q * q; XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)), - EvaluatePolynomial(y, kErfcPCoefficient), - EvaluatePolynomial(y, kErfcRCoefficient)); + EvaluatePolynomial(y, kErfcPCoefficient), + EvaluatePolynomial(y, kErfcRCoefficient)); y = z * q * p; - return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y, y); + XlaOp y_clamp = Select(Lt(z, ScalarLike(x, -kMaxlog)), ScalarLike(x, 0), y); + return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y_clamp, y_clamp); } // Compute a polynomial approximation of the error function. // // Precondition: abs(x) <= 1. Otherwise, use ErfcImpl. // -// This follows Cephes's f32 implementation of erf, so it may have errors for -// double precision. -static XlaOp ErfImpl(XlaOp x) { +// This follows Cephes's f32 implementation of erf. +static XlaOp ErfImpl32(XlaOp x) { // Coefficients for by erf(f32), from Cephes. 
// // erf(x) = x P(x^2), 0 < x < 1 - static std::array kErfTCoefficient{ + static const std::array kErfTCoefficient{ +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3, -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1, +1.128379165726710E+0, }; + return x * EvaluatePolynomial(x * x, kErfTCoefficient); +} - return x * EvaluatePolynomial(x * x, kErfTCoefficient); +static XlaOp ErfcImpl64(XlaOp x) { + // Coefficients for erfc(f64), from Cephes. + const double kMaxlog = 7.09782712893383996843E2; + // erfc(x) = exp(-x^2) P(|x|) / Q(|x|), 1 < x < 8 + static const std::array kErfcPCoefficient{ + 2.46196981473530512524E-10, 5.64189564831068821977E-1, + 7.46321056442269912687E0, 4.86371970985681366614E1, + 1.96520832956077098242E2, 5.26445194995477358631E2, + 9.34528527171957607540E2, 1.02755188689515710272E3, + 5.57535335369399327526E2}; + static const std::array kErfcQCoefficient{ + 1.00000000000000000000E0, 1.32281951154744992508E1, + 8.67072140885989742329E1, 3.54937778887819891062E2, + 9.75708501743205489753E2, 1.82390916687909736289E3, + 2.24633760818710981792E3, 1.65666309194161350182E3, + 5.57535340817727675546E2}; + + // erfc(x) = exp(-x^2) R(|x|) / S(|x|), 8 <= x < kMaxlog + static const std::array kErfcRCoefficient{ + 5.64189583547755073984E-1, 1.27536670759978104416E0, + 5.01905042251180477414E0, 6.16021097993053585195E0, + 7.40974269950448939160E0, 2.97886665372100240670E0}; + static const std::array kErfcSCoefficient{ + 1.00000000000000000000E0, 2.26052863220117276590E0, + 9.39603524938001434673E0, 1.20489539808096656605E1, + 1.70814450747565897222E1, 9.60896809063285878198E0, + 3.36907645100081516050E0}; + + XlaOp z = -x * x; + XlaOp abs_x = Abs(x); + XlaOp y = + Select(Lt(abs_x, ScalarLike(x, 8.0)), + Exp(z) * EvaluatePolynomial(abs_x, kErfcPCoefficient) / + EvaluatePolynomial(abs_x, kErfcQCoefficient), + Exp(z) * EvaluatePolynomial(abs_x, kErfcRCoefficient) / + EvaluatePolynomial(abs_x, kErfcSCoefficient)); + XlaOp y_clamp = Select(Lt(z, ScalarLike(x, -kMaxlog)), ScalarLike(x, 0), y); + return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y_clamp, y_clamp); +} + +// Compute a polynomial approximation of the error function. +// +// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl. +static XlaOp ErfImpl64(XlaOp x) { + // Coefficients for by erf(f64), from Cephes. + // + // erf(x) = x T(x^2) / U(x^2), 0 < x < 1 + static std::array kErfTCoefficient{ + 9.60497373987051638749E0, 9.00260197203842689217E1, + 2.23200534594684319226E3, 7.00332514112805075473E3, + 5.55923013010394962768E4}; + static std::array kErfUCoefficient{ + 1.00000000000000000000E0, 3.35617141647503099647E1, + 5.21357949780152679795E2, 4.59432382970980127987E3, + 2.26290000613890934246E4, 4.92673942608635921086E4}; + XlaOp z = x * x; + return x * EvaluatePolynomial(z, kErfTCoefficient) / + EvaluatePolynomial(z, kErfUCoefficient); } XlaOp Erfc(XlaOp x) { auto& b = *x.builder(); return b.ReportErrorOrReturn([&]() -> StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x)); - + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); // erfc(x) = // erfc_impl(x) if x > 1 // 1 - erf_impl(x) otherwise - // + if (shape.element_type() == F64) { + return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl64(x), + ScalarLike(x, 1) - ErfImpl64(x)); + } // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. 
- return DoWithUpcastToF32(x, {BF16}, [](XlaOp x) { - return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl(x), - ScalarLike(x, 1) - ErfImpl(x)); + return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { + return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x), + ScalarLike(x, 1) - ErfImpl32(x)); }); }); } @@ -224,15 +283,19 @@ XlaOp Erf(XlaOp x) { auto& b = *x.builder(); return b.ReportErrorOrReturn([&]() -> StatusOr { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x)); + TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x)); // erf(x) = // erf_impl(x) if x < 1 // 1 - erfc_impl(x) otherwise - // + if (shape.element_type() == F64) { + return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl64(x), + ScalarLike(x, 1) - ErfcImpl64(x)); + } // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. - return DoWithUpcastToF32(x, {BF16}, [](XlaOp x) { - return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl(x), - ScalarLike(x, 1) - ErfcImpl(x)); + return DoWithUpcastToF32(x, {BF16, F16}, [](XlaOp x) { + return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl32(x), + ScalarLike(x, 1) - ErfcImpl32(x)); }); }); } diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h index 89a58aa3970..57e50e56fa7 100644 --- a/tensorflow/compiler/xla/client/lib/math.h +++ b/tensorflow/compiler/xla/client/lib/math.h @@ -43,10 +43,6 @@ XlaOp Square(XlaOp operand); // Computes the reciprocal of 'operand'. XlaOp Reciprocal(XlaOp operand); -// Evaluates a polynomial given coefficients and 'x'. -// N.B. Coefficients should be supplied in decreasing order. -XlaOp EvaluatePolynomial(XlaOp x, absl::Span coefficients); - // Computes an approximation of the error function complement (1 - erf(x)). XlaOp Erfc(XlaOp x); diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index d4bc560b03f..f10342a8bf8 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -17,6 +17,8 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/util.h" @@ -138,18 +140,54 @@ XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, }); } -XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim) { +XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse) { XlaBuilder* builder = input.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); - ShapeUtil::AppendMajorDimension(1, &index_shape); - std::vector to_concat; TF_ASSIGN_OR_RETURN(Shape input_shape, builder->GetShape(input)); if (ShapeUtil::ElementHasBitWidth(index_shape, 64) && input_shape.dimensions(dim) < std::numeric_limits::max()) { index = ConvertElementType(index, U32); index_shape.set_element_type(U32); } + if (index_shape.rank() == 1) { + return TorchIndexSelect(input, index, 0); + } + if (!sparse) { + std::vector index_broacast_dims; + std::vector input_broacast_dims; + std::vector sizes; + for (int64 i = 0; i < index_shape.rank(); ++i) { + if (i < dim) { + input_broacast_dims.push_back(i); + index_broacast_dims.push_back(i); + } else if (i == dim) { + sizes.push_back(input_shape.dimensions(i)); + input_broacast_dims.push_back(i); + index_broacast_dims.push_back(i + 1); + } else { + input_broacast_dims.push_back(i + 1); + index_broacast_dims.push_back(i + 1); + } + sizes.push_back(index_shape.dimensions(i)); + } + auto mask = Eq( + BroadcastInDim(index, sizes, index_broacast_dims), + Iota(builder, ShapeUtil::MakeShape(index_shape.element_type(), sizes), + dim)); + auto masked_input = Select( + mask, BroadcastInDim(input, sizes, input_broacast_dims), + Zeros(builder, + ShapeUtil::MakeShape(input_shape.element_type(), sizes))); + return Reduce(masked_input, Zero(builder, input_shape.element_type()), + CreateScalarIdentityWithZeroComputation( + input_shape.element_type(), builder), + {dim}); + } + + ShapeUtil::AppendMajorDimension(1, &index_shape); + std::vector to_concat; + to_concat.reserve(input_shape.rank()); for (int64 i = 0; i < input_shape.rank(); ++i) { if (i == dim) { diff --git a/tensorflow/compiler/xla/client/lib/slicing.h b/tensorflow/compiler/xla/client/lib/slicing.h index 89ec1fe510e..9a59a048b9f 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.h +++ b/tensorflow/compiler/xla/client/lib/slicing.h @@ -55,7 +55,7 @@ XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, // [X0,X1,X2,..XN] and dim = i `index` must be an n-dimensional tensor with size // [X0,X1,...Y,Xi+1,...,X[N] where y >= 1 and `out` will have the same sizes as // `index`. -XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim); +XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse = true); // Returns a new tensor which indexes the input tensor along dimension dim using // the entries in index. 
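The new dense (sparse = false) path of TorchGather above avoids an explicit gather: it broadcasts the index tensor, compares it against an iota along `dim` to build a one-hot mask, masks the broadcast input, and reduces along `dim` so only the selected entry survives. A small plain-C++ model of the resulting semantics along dim = 1 (data chosen to mirror the unit test added below):

#include <iostream>
#include <vector>

int main() {
  // out[i][j] = input[i][index[i][j]] for dim == 1.
  std::vector<std::vector<int>> input = {{1, 2}, {3, 4}};
  std::vector<std::vector<int>> index = {{0, 0}, {1, 0}};
  for (size_t i = 0; i < index.size(); ++i) {
    for (size_t j = 0; j < index[i].size(); ++j) {
      int out = 0;
      // One-hot mask over the gathered dimension, then reduce-sum.
      for (size_t k = 0; k < input[i].size(); ++k) {
        const bool mask = (static_cast<size_t>(index[i][j]) == k);
        out += mask ? input[i][k] : 0;
      }
      std::cout << out << (j + 1 == index[i].size() ? '\n' : ' ');
    }
  }
  // Prints: 1 1 / 4 3
}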
diff --git a/tensorflow/compiler/xla/client/lib/slicing_test.cc b/tensorflow/compiler/xla/client/lib/slicing_test.cc index 04d3f96b6a5..107cbae0a73 100644 --- a/tensorflow/compiler/xla/client/lib/slicing_test.cc +++ b/tensorflow/compiler/xla/client/lib/slicing_test.cc @@ -102,7 +102,7 @@ XLA_TEST_F(SlicingTest, SimpleSliceUpdate) { {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); } -XLA_TEST_F(SlicingTest, TorchGather) { +XLA_TEST_F(SlicingTest, TorchGatherSparse) { xla::XlaBuilder builder(TestName()); xla::XlaOp input, index; @@ -116,6 +116,20 @@ XLA_TEST_F(SlicingTest, TorchGather) { {input_data.get(), index_data.get()}); } +XLA_TEST_F(SlicingTest, TorchGatherDense) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp input, index; + auto input_data = + CreateR2Parameter({{1, 2}, {3, 4}}, 0, "input", &builder, &input); + auto index_data = + CreateR2Parameter({{0, 0}, {1, 0}}, 1, "index", &builder, &index); + TorchGather(input, index, 1, false); + + ComputeAndCompareR2(&builder, {{1, 1}, {4, 3}}, + {input_data.get(), index_data.get()}); +} + XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) { xla::XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 1bd9d7b7228..153cb9f5212 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -176,12 +176,13 @@ StatusOr LocalExecutable::Run( ExecutableRunOptions run_options) { TF_ASSIGN_OR_RETURN(auto options_and_stream, RunHelper(arguments, run_options)); - - if (executable_->dumping_snapshot()) { - return ExecuteAndDump(&options_and_stream.first, arguments); - } - return executable_->ExecuteOnStreamWrapper( - &options_and_stream.first, run_options.execution_profile(), arguments); + ExecutableRunOptions options = options_and_stream.first.run_options(); + options.set_device_ordinal(-1); + auto result = RunAsync(arguments, options); + Status block_status = options.stream()->BlockHostUntilDone(); + TF_RETURN_IF_ERROR(result.status()); + TF_RETURN_IF_ERROR(block_status); + return result; } StatusOr LocalExecutable::RunAsync( @@ -189,50 +190,49 @@ StatusOr LocalExecutable::RunAsync( ExecutableRunOptions run_options) { TF_ASSIGN_OR_RETURN(auto options_and_stream, RunHelper(arguments, run_options)); - return executable_->ExecuteAsyncOnStream(&options_and_stream.first, - arguments); -} + se::Stream* stream = run_options.stream(); -StatusOr LocalExecutable::ExecuteAndDump( - const ServiceExecutableRunOptions* run_options, - const absl::Span arguments) { - executable_->hlo_snapshot()->set_execution_platform( - backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot())); - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result, - executable_->ExecuteOnStream(run_options, arguments, - /*hlo_execution_profile=*/nullptr)); - TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot())); - DumpHloSnapshotIfEnabled(executable_->module(), *executable_->hlo_snapshot()); - return std::move(result); -} - -Status LocalExecutable::RecordArguments( - const absl::Span arguments, - HloSnapshot* hlo_snapshot) { - hlo_snapshot->clear_arguments(); - for (const ShapedBuffer* argument : arguments) { - TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*argument)); - *hlo_snapshot->add_arguments() = literal.ToProto(); + std::shared_ptr snapshot; + if (executable_->dumping_snapshot()) { + snapshot = std::make_shared(); + 
snapshot->set_execution_platform(backend_->platform()->Name()); + *snapshot->mutable_hlo() = *executable_->hlo_proto(); + for (const ShapedBuffer* arg : arguments) { + auto literal = std::make_shared(arg->on_host_shape()); + backend_->transfer_manager()->TransferLiteralFromDevice( + stream, *arg, literal.get(), [snapshot, literal](Status status) { + if (!status.ok()) { + LOG(ERROR) << "TransferLiteralFromDevice for HLO snapshot inputs " + "failed: " + << status; + return; + } + *snapshot->add_arguments() = literal->ToProto(); + }); + } } - return Status::OK(); -} -Status LocalExecutable::RecordResult(const ShapedBuffer* result, - HloSnapshot* hlo_snapshot) { - hlo_snapshot->clear_result(); - TF_ASSIGN_OR_RETURN(Literal literal, LiteralFromShapedBuffer(*result)); - *hlo_snapshot->mutable_result() = literal.ToProto(); - return Status::OK(); -} + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer outputs, + executable_->ExecuteAsyncOnStreamWrapper( + &options_and_stream.first, arguments)); -StatusOr LocalExecutable::LiteralFromShapedBuffer( - const ShapedBuffer& shaped_buffer) { - TF_ASSIGN_OR_RETURN(auto stream, - backend_->BorrowStream(shaped_buffer.device_ordinal())); - return backend_->transfer_manager()->TransferLiteralFromDevice(stream.get(), - shaped_buffer); + // Transfer the outputs and save the snapshot to disk. + if (snapshot) { + auto literal = std::make_shared(outputs.on_host_shape()); + backend_->transfer_manager()->TransferLiteralFromDevice( + stream, outputs, literal.get(), [snapshot, literal](Status status) { + if (status.ok()) { + *snapshot->mutable_result() = literal->ToProto(); + } else { + LOG(ERROR) + << "TransferLiteralFromDevice for HLO snapshot outputs failed: " + << status; + } + DumpHloSnapshotIfEnabled(*snapshot, GetDebugOptionsFromFlags()); + }); + } + + return std::move(outputs); } se::Platform* LocalClient::platform() const { diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 1e7c97d6f06..b697fb031fd 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -72,23 +72,6 @@ class LocalExecutable { const absl::Span arguments, const ExecutableRunOptions& run_options, const Backend& backend); - // Records the computation in a SessionModule proto with the arguments used to - // invoke it, and the result. Enabled by flag: --xla_dump_hlo_snapshots. - // - // The given ServiceExecutableRunOptions override any values from the - // XLA_FLAGS environment variable. - StatusOr ExecuteAndDump( - const ServiceExecutableRunOptions* run_options, - const absl::Span arguments); - - // Records the arguments used to invoke the computation in a SessionModule - // proto. - Status RecordArguments(const absl::Span arguments, - HloSnapshot* hlo_snapshot); - - // Records the result of the computation in a SessionModule proto. - Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); - // Returns a literal containing the contents of the given ShapedBuffer. 
StatusOr LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 318d5f3be35..dccdec22fb9 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -289,6 +289,15 @@ Status XlaBuilder::SetDynamicBinding(int64 dynamic_size_param_num, return Status::OK(); } +Status XlaBuilder::SetInstructionFrontendAttribute(const XlaOp op, + std::string attribute, + std::string value) { + TF_ASSIGN_OR_RETURN(auto instr_proto, LookUpMutableInstruction(op)); + auto* frontend_attributes = instr_proto->mutable_frontend_attributes(); + (*frontend_attributes->mutable_map())[attribute] = std::move(value); + return Status::OK(); +} + XlaComputation XlaBuilder::BuildAndNoteError() { DCHECK(parent_builder_ != nullptr); auto build_status = Build(); @@ -702,6 +711,12 @@ XlaOp XlaBuilder::BroadcastInDim( // not necessarily the same as the dimension sizes of the output shape. auto output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), out_dim_size); + if (operand_shape.rank() != broadcast_dimensions.size()) { + return InvalidArgument( + "Size of broadcast_dimensions has to match operand's rank; operand " + "rank: %lld, size of broadcast_dimensions %u.", + operand_shape.rank(), broadcast_dimensions.size()); + } for (int i = 0; i < broadcast_dimensions.size(); i++) { if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] > out_dim_size.size()) { @@ -1028,6 +1043,11 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) { "Operand to GetTupleElement() is not a tuple; got %s", ShapeUtil::HumanString(tuple_shape)); } + if (index < 0 || index >= ShapeUtil::TupleElementCount(tuple_shape)) { + return InvalidArgument( + "GetTupleElement() index (%d) out of range for tuple shape %s", index, + ShapeUtil::HumanString(tuple_shape)); + } *instr.mutable_shape() = ShapeUtil::GetTupleElementShape(tuple_shape, index).ToProto(); @@ -1204,8 +1224,9 @@ XlaOp XlaBuilder::ConvGeneralDilated( rhs_shape.dimensions(dimension_numbers.kernel_spatial_dimensions(i)); } TF_ASSIGN_OR_RETURN(*instr.mutable_window(), - MakeWindow(window_dimensions, window_strides, padding, - lhs_dilation, rhs_dilation)); + ShapeInference::InferWindowFromDimensions( + window_dimensions, window_strides, padding, + lhs_dilation, rhs_dilation)); TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferConvolveShape( @@ -1226,60 +1247,6 @@ XlaOp XlaBuilder::ConvGeneralDilated( }); } -StatusOr XlaBuilder::MakeWindow( - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation) const { - const auto verify_size = [&](const size_t x, const char* x_name) { - if (x == 0 || x == window_dimensions.size()) { - return Status::OK(); - } else { - return InvalidArgument( - "%s", absl::StrCat( - "Window has different number of window dimensions than of ", - x_name, - "\nNumber of window dimensions: ", window_dimensions.size(), - "\nNumber of ", x_name, ": ", x, "\n")); - } - }; - TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides")); - TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries")); - TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors")); - TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors")); - - Window window; - for (size_t i = 0; i < window_dimensions.size(); i++) { - auto dim = window.add_dimensions(); - 
dim->set_size(window_dimensions[i]); - if (!window_strides.empty()) { - dim->set_stride(window_strides[i]); - } else { - dim->set_stride(1); - } - if (!padding.empty()) { - dim->set_padding_low(padding[i].first); - dim->set_padding_high(padding[i].second); - } else { - dim->set_padding_low(0); - dim->set_padding_high(0); - } - if (!lhs_dilation.empty()) { - dim->set_base_dilation(lhs_dilation[i]); - } else { - dim->set_base_dilation(1); - } - if (!rhs_dilation.empty()) { - dim->set_window_dilation(rhs_dilation[i]); - } else { - dim->set_window_dilation(1); - } - dim->set_window_reversal(false); - } - return window; -} - XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, const absl::Span fft_length) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1739,9 +1706,11 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes) { + absl::Span slice_sizes, + bool indices_are_sorted) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; + instr.set_indices_are_sorted(indices_are_sorted); TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape, @@ -1764,9 +1733,11 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices, XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers) { + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; + instr.set_indices_are_sorted(indices_are_sorted); TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape, @@ -1952,9 +1923,10 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, computation.GetProgramShape()); TF_ASSIGN_OR_RETURN(*instr.mutable_window(), - MakeWindow(window_dimensions, window_strides, padding, - /*lhs_dilation=*/base_dilations, - /*rhs_dilation=*/window_dilations)); + ShapeInference::InferWindowFromDimensions( + window_dimensions, window_strides, padding, + /*lhs_dilation=*/base_dilations, + /*rhs_dilation=*/window_dilations)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape( operand_shape, init_shape, instr.window(), to_apply_shape)); @@ -2199,8 +2171,9 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding( TF_ASSIGN_OR_RETURN(const ProgramShape& scatter_shape, scatter.GetProgramShape()); TF_ASSIGN_OR_RETURN(*instr.mutable_window(), - MakeWindow(window_dimensions, window_strides, padding, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); + ShapeInference::InferWindowFromDimensions( + window_dimensions, window_strides, padding, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSelectAndScatterShape( operand_shape, select_shape, instr.window(), @@ -2662,6 +2635,7 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, if (sharding_) { *instr.mutable_sharding() = *sharding_; } + *instr.mutable_frontend_attributes() = frontend_attributes_; handle_to_index_[handle] = instructions_.size(); instructions_.push_back(std::move(instr)); @@ -2719,32 +2693,67 @@ void XlaBuilder::AddCalledComputation(const XlaComputation& computation, } } -StatusOr 
XlaBuilder::LookUpInstruction( - const XlaOp& op) const { - TF_RETURN_IF_ERROR(first_error_); +namespace { - if (op.builder_ == nullptr) { +template +StatusOr LookUpInstructionByHandleInternal( + const absl::flat_hash_map& handle_to_index, + const std::vector& instructions, int64 handle) { + auto it = handle_to_index.find(handle); + if (it == handle_to_index.end()) { + return InvalidArgument("No XlaOp with handle %d", handle); + } + return const_cast(&instructions.at(it->second)); +} + +template +StatusOr LookUpInstructionInternal( + const absl::flat_hash_map& handle_to_index, + const std::vector& instructions, + OpBuilderType op_builder, BuilderType builder, OpType op_handle) { + if (op_builder == nullptr) { return InvalidArgument( "invalid XlaOp with handle %d; the builder of this op is freed", - op.handle()); + op_handle); } - if (op.builder_ != this) { + if (op_builder != builder) { return InvalidArgument( "XlaOp with handle %d is built by builder '%s', but is trying to use " "it in builder '%s'", - op.handle(), op.builder_->name(), this->name()); + op_handle, op_builder->name(), builder->name()); } - return LookUpInstructionByHandle(op.handle()); + return LookUpInstructionByHandleInternal( + handle_to_index, instructions, op_handle); +} + +} // namespace + +StatusOr XlaBuilder::LookUpInstruction( + const XlaOp op) const { + TF_RETURN_IF_ERROR(first_error_); + return LookUpInstructionInternal( + handle_to_index_, instructions_, op.builder_, this, op.handle()); } StatusOr XlaBuilder::LookUpInstructionByHandle( int64 handle) const { - auto it = handle_to_index_.find(handle); - if (it == handle_to_index_.end()) { - return InvalidArgument("No XlaOp with handle %d", handle); - } - return &instructions_[it->second]; + return LookUpInstructionByHandleInternal( + handle_to_index_, instructions_, handle); +} + +StatusOr XlaBuilder::LookUpMutableInstruction( + const XlaOp op) { + TF_RETURN_IF_ERROR(first_error_); + return LookUpInstructionInternal( + handle_to_index_, instructions_, op.builder_, this, op.handle()); +} + +StatusOr XlaBuilder::LookUpMutableInstructionByHandle( + int64 handle) { + return LookUpInstructionByHandleInternal( + handle_to_index_, instructions_, handle); } // Enqueues a "retrieve parameter value" instruction for a parameter that was @@ -3361,16 +3370,18 @@ XlaOp ReducePrecision(const XlaOp operand, const int exponent_bits, XlaOp Gather(const XlaOp input, const XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes) { + absl::Span slice_sizes, bool indices_are_sorted) { return input.builder()->Gather(input, start_indices, dimension_numbers, - slice_sizes); + slice_sizes, indices_are_sorted); } XlaOp Scatter(const XlaOp input, const XlaOp scatter_indices, const XlaOp updates, const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers) { + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted) { return input.builder()->Scatter(input, scatter_indices, updates, - update_computation, dimension_numbers); + update_computation, dimension_numbers, + indices_are_sorted); } void Send(const XlaOp operand, const ChannelHandle& handle) { diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 89e8be7de1e..5c28e8b5150 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -147,8 +147,8 @@ class XlaBuilder { // Sets OpMetadata that will be added to all instructions until cleared. 
// // OpMetadata is often applied to a series of XLA HLO instructions. As a - // result, OpMetadata is set on the Computation Builder. All subsequent - // instructions generated via this Computation Builder will have the same + // result, OpMetadata is set on the computation builder. All subsequent + // instructions generated via this computation builder will have the same // OpMetadata attached until a call to ClearOpMetadata. void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } @@ -158,6 +158,35 @@ class XlaBuilder { // Sets an OpSharding that will be attached to all instructions until cleared. void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } + // Sets the FrontendAttributes that will be added to all instructions until + // cleared. + // + // FrontendAttributes are often applied to a series of XLA HLO instructions. + // As a result they are set on the computation builder and all the + // instructions generated via the computation builder will have the same + // frontend attributes attached to them. + void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) { + frontend_attributes_ = frontend_attributes; + } + + // Swap the passed FrontendAttributes with the ones currently set. + // + // Return the old attributes. + FrontendAttributes SwapFrontendAttributes( + const FrontendAttributes& frontend_attributes) { + FrontendAttributes old_attributes = std::move(frontend_attributes_); + frontend_attributes_ = frontend_attributes; + return old_attributes; + } + + // Returns the FrontendAttributes that will be attached to all instructions. + const FrontendAttributes& frontend_attributes() const { + return frontend_attributes_; + } + + // Clears all the frontend attributes. + void ClearFrontendAttributes() { frontend_attributes_.Clear(); } + // Clears the sharding. Ops will be sharded according to the default placement // policy. void ClearSharding() { sharding_ = absl::nullopt; } @@ -314,6 +343,16 @@ class XlaBuilder { ShapeIndex param_index; }; + // Looks up the HloInstruction and sets the frontend attribute "attribute" to + // "value". + // + // If the attribute already existed then its value is updated. + // + // Note: the attribute is only added to the HloInstruction, not to the + // builder. + Status SetInstructionFrontendAttribute(XlaOp op, string attribute, + string value); + private: // Build helper which takes the id of the root operation.. 
StatusOr Build(int64 root_id, bool remove_dynamic_dimensions); @@ -547,11 +586,13 @@ class XlaBuilder { XlaOp Gather(const XlaOp& input, const XlaOp& start_indices, const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes); + absl::Span slice_sizes, + bool indices_are_sorted = false); XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, const XlaOp& updates, const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers); + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted = false); void Send(const XlaOp& operand, const ChannelHandle& handle); XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token, @@ -593,9 +634,11 @@ class XlaBuilder { void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr); - StatusOr LookUpInstruction(const XlaOp& op) const; + StatusOr LookUpInstruction(XlaOp op) const; StatusOr LookUpInstructionByHandle( int64 handle) const; + StatusOr LookUpMutableInstruction(XlaOp op); + StatusOr LookUpMutableInstructionByHandle(int64 handle); // Internal helper method that does the building for an arbitrary unary op. XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand); @@ -649,14 +692,6 @@ class XlaBuilder { const Shape& lhs_shape, const Shape& rhs_shape, const ConvolutionDimensionNumbers& dimension_numbers) const; - // Helper function for creating a Window proto from user-supplied data. - // Returns error if the user-supplied data was invalid. - StatusOr MakeWindow(absl::Span window_dimensions, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation) const; - int64 GetNextId() { return ++next_id_; } // Populates the module with the input/output alias information stored within @@ -713,6 +748,8 @@ class XlaBuilder { XlaBuilder* parent_builder_{nullptr}; + FrontendAttributes frontend_attributes_; + friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape, const string& name, const std::vector& replicated_at_leaf_buffers); @@ -968,10 +1005,12 @@ class XlaBuilder { const int mantissa_bits); friend XlaOp Gather(XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes); + absl::Span slice_sizes, + bool indices_are_sorted); friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers); + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted); friend void Send(XlaOp operand, const ChannelHandle& handle); friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, const ChannelHandle& handle); @@ -1038,6 +1077,27 @@ class XlaScopedShardingAssignment { absl::optional prev_sharding_; }; +// RAII-style object: save the current builder's frontend attributes, and merge +// them with the new ones on construction. +// Restore the original attributes on destruction. +class XlaScopedFrontendAttributesAssignment { + public: + XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder, + FrontendAttributes attributes) + : builder_(builder) { + saved_ = builder_->SwapFrontendAttributes(attributes); + } + + ~XlaScopedFrontendAttributesAssignment() { + builder_->SetFrontendAttributes(saved_); + } + + private: + xla::XlaBuilder* const builder_; + FrontendAttributes saved_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment); +}; // Free functions for building XlaOps. 
The intention is that these will // become the public API for building XlaOps rather than calling methods on // XlaBuilder directly. @@ -1802,12 +1862,14 @@ XlaOp ReducePrecision(XlaOp operand, const int exponent_bits, // Enqueues a Gather node onto the computation. XlaOp Gather(XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes); + absl::Span slice_sizes, + bool indices_are_sorted = false); // Enqueues a Scatter node onto the computation. XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers); + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted = false); // Enqueues a Send node onto the computation for device-to-device // communication. This operation sends the given operand to diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 12656a89943..701729b94f3 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -978,5 +978,151 @@ TEST_F(XlaBuilderTest, CheckInputOutputAlias) { EXPECT_EQ(*alias_p1, ShapeIndex({0})); } +void ExpectAttributesMatch(const FrontendAttributes& attr, + const FrontendAttributes& ref) { + EXPECT_EQ(ref.map_size(), attr.map_size()); + for (auto reference : ref.map()) { + auto other = attr.map().find(reference.first); + EXPECT_NE(other, attr.map().end()); + EXPECT_EQ(other->second, reference.second); + } +} + +void ExpectInstructionsAttributesMatch( + const HloModule& module, const std::vector& expected) { + ASSERT_EQ(module.computation_count(), 1); + auto expected_it = expected.begin(); + for (auto inst : module.entry_computation()->instructions()) { + ASSERT_NE(expected_it, expected.end()); + ExpectAttributesMatch(inst->frontend_attributes(), *expected_it); + expected_it++; + } + EXPECT_EQ(expected_it, expected.end()); +} + +TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) { + XlaBuilder b(TestName()); + FrontendAttributes attributes; + + ConstantR0(&b, 0); // No attribute set + + (*attributes.mutable_map())["attr_a"] = "a"; + b.SetFrontendAttributes(attributes); + ConstantR0(&b, 0); // One attribute: { "attr_a": "a" } + + b.ClearFrontendAttributes(); + ConstantR0(&b, 0); // No attribute set + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + + std::vector expected{FrontendAttributes(), attributes, + FrontendAttributes()}; + ExpectInstructionsAttributesMatch(*module, expected); +} + +TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) { + XlaBuilder b(TestName()); + + ConstantR0(&b, 0); // No attribute set. 
+ std::vector expected{FrontendAttributes()}; + + { + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_a"] = "a"; + b.SetFrontendAttributes(attributes); + ConstantR0(&b, 0); // One attribute: { "attr_a": "a" } + expected.push_back(attributes); + } + + { + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_b"] = "b"; + b.SetFrontendAttributes(attributes); + ConstantR0(&b, 0); // One attribute: { "attr_b": "b" } + expected.push_back(attributes); + } + + { + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_b"] = "b"; + (*attributes.mutable_map())["attr_c"] = "c"; + b.SetFrontendAttributes(attributes); + ConstantR0(&b, 0); // Two attributes: { "attr_b": "b", "attr_c": "c" } + expected.push_back(attributes); + } + + b.ClearFrontendAttributes(); + ConstantR0(&b, 0); // No attribute set + expected.push_back(FrontendAttributes()); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + ExpectInstructionsAttributesMatch(*module, expected); +} + +TEST_F(XlaBuilderTest, AddFrontendAttribute) { + XlaBuilder b(TestName()); + + ConstantR0(&b, 0); + std::vector expected{FrontendAttributes()}; + + // One attribute: { "attr_a": "a" } + { + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_a"] = "a"; + b.SetFrontendAttributes(attributes); + ConstantR0(&b, 0); + expected.push_back(attributes); + } + + // Two attributes: {"attra": "a", "attr_c": "c"} + { + auto op = ConstantR0(&b, 0); + EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_c", "c")); + + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_a"] = "a"; + (*attributes.mutable_map())["attr_c"] = "c"; + expected.push_back(attributes); + } + + // Override value of existing "attr_a" + // One attribute: { "attr_a", "a2"} + { + auto op = ConstantR0(&b, 0); + EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_a", "a2")); + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_a"] = "a2"; + expected.push_back(attributes); + } + + // Check "attr_a" is back to its original value + // One attribute: { "attr_a", "a"} + { + auto op = ConstantR0(&b, 0); + (void)op; + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_a"] = "a"; + expected.push_back(attributes); + } + + b.ClearFrontendAttributes(); + ConstantR0(&b, 0); // No attribute set + expected.push_back(FrontendAttributes()); + + // One attribute: { "attr_d", "d"} + { + auto op = ConstantR0(&b, 0); + EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_d", "d")); + FrontendAttributes attributes; + (*attributes.mutable_map())["attr_d"] = "d"; + expected.push_back(attributes); + } + + ConstantR0(&b, 0); // No attribute set + expected.push_back(FrontendAttributes()); + + TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); + ExpectInstructionsAttributesMatch(*module, expected); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 45f9cbe4ce8..13173e0dbc8 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -149,6 +149,12 @@ static void AllocateFlags() { return true; }; + // Custom "sub-parser" lambda for xla_gpu_ptx_file. + auto setter_for_xla_gpu_ptx_file = [](string value) { + flag_values->add_xla_gpu_ptx_file(value); + return true; + }; + // Custom "sub-parser" lambda for xla_backend_extra_options. 
auto setter_for_xla_backend_extra_options = [](string comma_separated_values) { @@ -244,6 +250,13 @@ static void AllocateFlags() { "When xla_cpu_enable_fast_math is true then this controls whether " "we forbid to use multiplication by the reciprocal instead of " "division. Ignored when xla_cpu_enable_fast_math is false."), + tensorflow::Flag( + "xla_cpu_fast_math_honor_functions", + bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_functions), + flag_values->xla_cpu_fast_math_honor_functions(), + "When xla_cpu_enable_fast_math is true then this controls whether " + "we forbid to approximate calculations for functions. Ignored when " + "xla_cpu_enable_fast_math is false."), tensorflow::Flag( "xla_gpu_enable_fast_min_max", bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max), @@ -342,6 +355,13 @@ static void AllocateFlags() { int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor), flag_values->xla_gpu_max_kernel_unroll_factor(), "Specify the maximum kernel unroll factor for the GPU backend."), + tensorflow::Flag("xla_gpu_ptx_file", setter_for_xla_gpu_ptx_file, "", + "If non-empty, speficies a file containing ptx to use. " + "The filename prefix must have the same pattern as PTX " + "dumped by XLA. This allows to match one specific " + "module. General workflow. Get the generated module " + "ptx from XLA. Modify it. Then pass it back via this " + "option."), tensorflow::Flag( "xla_test_all_output_layouts", bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), @@ -508,6 +528,12 @@ static void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw), flag_values->xla_gpu_force_conv_nchw(), "For cuDNN convolutions, always NCHW layouts."), + tensorflow::Flag("xla_gpu_algorithm_blacklist_path", + string_setter_for( + &DebugOptions::set_xla_gpu_algorithm_blacklist_path), + flag_values->xla_gpu_algorithm_blacklist_path(), + "An AlgorithmBlacklist text proto file as a blacklist " + "of convolutions to avoid to use."), }); ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects); } diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index dafc3345555..7d225e1240c 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -11,16 +11,16 @@ upper_tabs: lower_tabs: # Subsite tabs other: - - name: Guide & Tutorials + - name: Overview contents: - - title: XLA overview - path: /xla/overview + - title: Overview + path: /xla + - title: XLA architecture + path: /xla/architecture - title: Broadcasting semantics path: /xla/broadcasting - title: Developing a new backend for XLA path: /xla/developing_new_backend - - title: Using JIT compilation - path: /xla/jit - title: Operation semantics path: /xla/operation_semantics - title: Shapes and layout @@ -32,6 +32,8 @@ upper_tabs: - title: Writing custom calls path: /xla/custom_call - heading: Tutorials + - title: XLA autoclustering + path: /xla/tutorials/autoclustering_xla - title: XLA compile API path: /xla/tutorials/xla_compile status: experimental diff --git a/tensorflow/compiler/xla/g3doc/_index.yaml b/tensorflow/compiler/xla/g3doc/_index.yaml deleted file mode 100644 index 858de427119..00000000000 --- a/tensorflow/compiler/xla/g3doc/_index.yaml +++ /dev/null @@ -1,35 +0,0 @@ -book_path: /xla/_book.yaml -project_path: /xla/_project.yaml -description: -landing_page: - custom_css_path: /site-assets/css/style.css - rows: - - heading: XLA is a compiler that optimizes TensorFlow computations. 
- items: - - classname: devsite-landing-row-50 - description: > - XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear - algebra that optimizes TensorFlow computations. The results are - improvements in speed, memory usage, and portability on server and mobile - platforms. The XLA framework is experimental and in active development. - For details, read the XLA guide. - - - classname: devsite-landing-row-cards - items: - - heading: XLA - TensorFlow, compiled - image_path: /resources/images/tf-logo-card-16x9.png - path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html - buttons: - - label: Read on Google Developers blog - path: https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html - - heading: XLA at the Dev Summit - youtube_id: kAOanJczHA0 - buttons: - - label: Watch the video - path: https://www.youtube.com/watch?v=kAOanJczHA0 - - heading: XLA on GitHub - image_path: /resources/images/github-card-16x9.png - path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla - buttons: - - label: View on GitHub - path: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla diff --git a/tensorflow/compiler/xla/g3doc/overview.md b/tensorflow/compiler/xla/g3doc/architecture.md similarity index 75% rename from tensorflow/compiler/xla/g3doc/overview.md rename to tensorflow/compiler/xla/g3doc/architecture.md index d3428b72761..f9be646c441 100644 --- a/tensorflow/compiler/xla/g3doc/overview.md +++ b/tensorflow/compiler/xla/g3doc/architecture.md @@ -1,25 +1,9 @@ -# XLA Overview +# XLA Architecture
-> Note: XLA is still under development. Some use cases will not -> see improvements in speed or decreased memory usage. - -XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear -algebra that optimizes TensorFlow computations. The results are improvements in -speed, memory usage, and portability on server and mobile platforms. Initially, -most users will not see large benefits from XLA, but are welcome to experiment -by using XLA via [just-in-time (JIT) compilation](./jit.md) or -[ahead-of-time (AOT) compilation](./tfcompile.md). Developers targeting new -hardware accelerators are especially encouraged to try out XLA. - -The XLA framework is experimental and in active development. In particular, -while it is unlikely that the semantics of existing operations will change, it -is expected that more operations will be added to cover important use cases. The -team welcomes feedback from the community about missing functionality and -community contributions via GitHub. ## Why did we build XLA? @@ -91,8 +75,3 @@ code from this LLVM IR. The GPU backend currently supports NVIDIA GPUs via the LLVM NVPTX backend; the CPU backend supports multiple CPU ISAs. - -## Supported Platforms - -XLA currently supports [JIT compilation](./jit.md) on x86-64 and NVIDIA GPUs; and -[AOT compilation](./tfcompile.md) for x86-64 and ARM. diff --git a/tensorflow/compiler/xla/g3doc/custom_call.md b/tensorflow/compiler/xla/g3doc/custom_call.md index acc2c9a92f5..7837f0aefaf 100644 --- a/tensorflow/compiler/xla/g3doc/custom_call.md +++ b/tensorflow/compiler/xla/g3doc/custom_call.md @@ -128,8 +128,8 @@ using xla::ShapeUtil; Shape p0_shape = ShapeUtil::MakeTuple({ ShapeUtil::MakeShape(F32, {32}), ShapeUtil::MakeTuple({ - ShapeUtil::MakeTuple(F32, {64}), - ShapeUtil::MakeTuple(F32, {128}), + ShapeUtil::MakeShape(F32, {64}), + ShapeUtil::MakeShape(F32, {128}), }), ShapeUtil::MakeShape(F32, {256}), }); @@ -197,133 +197,18 @@ subbuffers of `output_tuple` are accessible by dereferencing `out`. ### Tuples in GPU custom-calls In GPU code, we have a function `do_custom_call(..., void** buffers, ...)`. In -this case `buffers` is a host array of *nine* device pointers, one for each -nested buffer. To generate the flat list, we iterate over the parameters and -output, and then do preorder traversal of their shapes. Concretely: +this case `buffers` is a host array of *six* device pointers, one for each leaf +buffer in the input/output. To generate the flat list, we iterate over the +parameters and output, and for each we do a preorder traversal of its shape. +Concretely: ```c++ // Layout of `buffers` parameter to GPU custom call function for custom-call // above. -buffers[0] == param0 -buffers[1] == subbuf0 or null -buffers[2] == subtuple or null -buffers[3] == subbuf1 or null -buffers[4] == subbuf2 or null -buffers[5] == subbuf3 or null -buffers[6] == output_tuple -buffers[7] == output_subbuf0 -buffers[8] == output_subbuf1 +buffers[0] == subbuf0 +buffers[1] == subbuf1 +buffers[2] == subbuf2 +buffers[3] == subbuf3 +buffers[4] == output_subbuf0 +buffers[5] == output_subbuf1 ``` - -The `or null` part is significant. A sub-buffer of an input tuple will be -non-null in the `buffers` list if XLA is able to statically analyze the program -and figure out the address of the sub-buffer. This is usually the case, but may -not be in programs with control flow and/or `select` ops over tuples. 
- -A correct custom-call implementation that accepts a tuple as input must always -handle null input sub-buffers, by dereferencing the root tuple. - -The rule is reversed for output buffers. The output sub-buffers will always be -populated, but it's up to the custom call to populate the root tuple at the end. - -See the following code. Note that we leave out CUDA error handling for clarity, -but you'll be thankful if you do it, because otherwise it can be hard to tell -when a stream encounters an error. - -```c++ -void do_custom_call(CUstream stream, void** buffers, const char* opaque, - size_t opaque_len) { - bool needs_sync = false; - const float* subbuf0 = reinterpret_cast(buffers[1]); - if (subbuf0 == nullptr) { - needs_sync = true; - cudaMemcpyAsync(&subbuf0, buffers[0], sizeof(void*), - cudaMemcpyDeviceToHost, stream); - } - const void** subtuple = reinterpret_cast(buffers[2]); - if (subtuple == nullptr) { - needs_sync = true; - cudaMemcpyAsync(&subtuple, buffers[2], ...); - } - - // ... similarly for other params ... - - // Wait for copies enqueued above to complete. - if (needs_sync) { - cudaStreamSynchronize(stream); - } - needs_sync = false; - - // Now that we have `subtuple`, we can get subbuf1 and subbuf2. - float* subbuf1 = buffers[3]; - if (subbuf1 == nullptr) { - needs_sync = true; - cudaMemcpyAsync(&subbuf1, subtuple, ...); - } - float* subbuf2 = buffers[4]; - if (subbuf2 == nullptr) { - needs_sync = true; - cudaMemcpyAsync(&subbuf2, subtuple + 1, ...); - } - - // Wait for copies enqueued above to complete. - if (needs_sync) { - cudaStreamSynchronize(stream); - } - - // ... actually run the kernel ... - - // Fill the output tuple. - void* outputs[2] = {buffers[7], buffers[8]}; - cudaMemcpyAsync(buffers[6], outputs, sizeof(outputs), cudaMemcpyHostToDevice, - stream); - - // Necessary to force the cudaMemcpyAsync above to complete before `outputs` - // goes out of scope. A sync is only necessary in the tuple output case, and - // see below for a way to avoid this. - cudaStreamSynchronize(stream); -} -``` - -The `cudaStreamSynchronize` at the end of the function is unfortunate, as it's -not required in the non-tuple-output case, and it can be expensive. One way to -get around this would be to make `outputs` into a global variable and ensure -that the previous cudaMemcpyAsync completed before overwriting the global and -enqueueing another one. This is sketched below. - -``` -void do_custom_call(CUstream stream, void** buffers, const char* opaque, - size_t opaque_len) { - - // ... Beginning of function is the same as above ... - - // ... actually run the kernel ... - - static std::atomic first_time{true}; - static CUevent event; - static void* outputs[2]; - if (first_time.fetch_and(false)) { - // First time running this function. Initialize `event`. - cuEventCreate(&event, CU_EVENT_DISABLE_TIMING); - } else { - // Not first time running this function. Wait for previous event to - // complete before touching `outputs`. - cuEventSynchronize(event); - } - - // Fill the output tuple. - outputs[0] = buffers[7]; - outputs[1] = buffers[8]; - cudaMemcpyAsync(buffers[6], outputs, sizeof(outputs), cudaMemcpyHostToDevice, - stream); - - // Unblock `event` after the memcpy completes. - cuEventRecord(event, stream); -} -``` - -This simple implementation would limit parallelism if you want to run this op on -multiple GPUs concurrently (or on one GPU with multiple streams); in that case -you might need multiple events and globals. 
We have seen one implementation of -this algorithm which keeps a pool of globals and events and periodically polls -them (perhaps on each call to the op) to garbage collect. diff --git a/tensorflow/compiler/xla/g3doc/images/jit_cpu_xla_graph.png b/tensorflow/compiler/xla/g3doc/images/jit_cpu_xla_graph.png deleted file mode 100644 index 4e2dc091fee..00000000000 Binary files a/tensorflow/compiler/xla/g3doc/images/jit_cpu_xla_graph.png and /dev/null differ diff --git a/tensorflow/compiler/xla/g3doc/images/jit_gpu_xla_graph.png b/tensorflow/compiler/xla/g3doc/images/jit_gpu_xla_graph.png deleted file mode 100644 index 39d7c90c4fc..00000000000 Binary files a/tensorflow/compiler/xla/g3doc/images/jit_gpu_xla_graph.png and /dev/null differ diff --git a/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu.png b/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu.png deleted file mode 100644 index a38f636983b..00000000000 Binary files a/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu.png and /dev/null differ diff --git a/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu_xla.png b/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu_xla.png deleted file mode 100644 index 285c3a96d5a..00000000000 Binary files a/tensorflow/compiler/xla/g3doc/images/jit_timeline_cpu_xla.png and /dev/null differ diff --git a/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu.png b/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu.png deleted file mode 100644 index 488fc2c2f10..00000000000 Binary files a/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu.png and /dev/null differ diff --git a/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu_xla.png b/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu_xla.png deleted file mode 100644 index d0df38cf181..00000000000 Binary files a/tensorflow/compiler/xla/g3doc/images/jit_timeline_gpu_xla.png and /dev/null differ diff --git a/tensorflow/compiler/xla/g3doc/images/tf_xla_performance.png b/tensorflow/compiler/xla/g3doc/images/tf_xla_performance.png new file mode 100644 index 00000000000..70087f5747c Binary files /dev/null and b/tensorflow/compiler/xla/g3doc/images/tf_xla_performance.png differ diff --git a/tensorflow/compiler/xla/g3doc/index.md b/tensorflow/compiler/xla/g3doc/index.md new file mode 100644 index 00000000000..c3b708d6907 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/index.md @@ -0,0 +1,168 @@ +# XLA: Optimizing Compiler for TensorFlow + +XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear +algebra that accelerates TensorFlow models with potentially no source code +changes. + +The results are improvements in speed and memory usage: most internal benchmarks +run ~1.15x faster after XLA is enabled. The dataset below is evaluated on a +single NVidia V100 GPU: + +
+![XLA performance comparison on internal benchmarks (single NVIDIA V100 GPU)](./images/tf_xla_performance.png)
+
+## Introduction
+
+When a TensorFlow program is run, all of the operations are executed
+individually by the TensorFlow executor. Each TensorFlow operation has a
+precompiled GPU kernel implementation that the executor dispatches to.
+
+XLA provides an alternative mode of running TF models: it compiles the
+TensorFlow graph into a sequence of computation kernels generated specifically
+for the given model. Because these kernels are unique to the model, they can
+exploit model-specific information for optimization. For example, let's look at
+an optimization XLA does in the context of a simple TensorFlow computation:
+
+```
+def model_fn(x, y, z):
+  return tf.reduce_sum(x + y * z)
+```
+
+Run without XLA, the graph launches three kernels: one for the multiplication,
+one for the addition and one for the reduction. However, XLA can optimize the
+graph so that it computes the result in a single kernel launch. It does this by
+"fusing" the addition, multiplication and reduction into a single GPU kernel.
+Moreover, this fused operation does not write out the intermediate values
+produced by `y*z` and `x+y*z` to memory; instead it "streams" the results of
+these intermediate computations directly to their users while keeping them
+entirely in GPU registers. Fusion is XLA's single most important optimization.
+Memory bandwidth is typically the scarcest resource on hardware accelerators, so
+removing memory operations is one of the best ways to improve performance.
+
+## Enable XLA for TensorFlow models
+
+### Auto-clustering
+
+The simplest way to start using XLA in TensorFlow models is to enable
+_auto-clustering_, which automatically finds _clusters_ (connected subgraphs)
+within the TensorFlow graph which can be compiled and executed using XLA.
+Auto-clustering on GPU can be enabled by either modifying the `TF_XLA_FLAGS`
+environment variable:
+
+```
+$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program
+```
+
+Or by setting a configuration value within the program:
+
+```
+import tensorflow as tf
+
+tf.config.optimizer.set_jit(True)
+
+# ... the rest of your program ...
+```
+
+Note: The JIT level is cached for a session, and can only be set in the very
+beginning of the program. In order to change it midway through, the session
+needs to be cleared: `tf.keras.backend.clear_session()`
+
+Auto-clustering is currently optimized for GPU workloads, but it can also be
+enabled on CPU by additionally using the flag `--tf_xla_cpu_global_jit`:
+
+```
+$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program
+```
+
+For a detailed usage example, see the
+[auto-clustering tutorial colab](./tutorials/autoclustering_xla.ipynb).
+
+### Use `xla.compile`
+
+The `xla.compile` API offers more fine-grained control for choosing which
+functions should be compiled with XLA. However, it requires restructuring source
+code, as not all TensorFlow operations can be represented in XLA. That is, when
+using `xla.compile` you pass it the functions which should be compiled using
+XLA; a failure to compile results in an exception.
+
+See the [`xla.compile` tutorial colab](./tutorials/xla_compile.ipynb) for usage
+examples.
+
+### AOT (Ahead-of-time) compilation for CPU with `tfcompile`
+
+You can also use a standalone [`tfcompile`](./tfcompile) tool,
+which converts a TensorFlow graph into executable code (for CPU only).
+
+## Inspect compiled programs
+
+XLA provides introspection facilities which let you inspect the generated
+programs.
To dump the generated programs, use the environment variable +`XLA_FLAGS`: + +``` +$ XLA_FLAGS="--dump_hlo_as_text --xla_dump_to=/tmp/generated" +TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program +``` + +After the dumping is performed, you can find the following files in +`/tmp/generated`: + +- `module_XXXX.*_optimizations.txt` Generated + [XLA programs](./operation_semantics.md), one per each compiled cluster. + Attaching those when submitting XLA bug reports is extremely helpful! + +- `module_XXXX.ir-*.ll` Generated files in + [LLVM](https://llvm.org/docs/LangRef.html) intermediate representation, with + [NVPTX](https://llvm.org/docs/NVPTXUsage.html) intrinsics. + +- `module_XXXX.ptx` Generated + [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + files. + +You can also dump the graph visualizing the embedding of XLA clusters inside of +the TensorFlow graph with: + +``` +$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug" +``` + +## Supported platforms + +Auto-clustering is supported on NVIDIA GPUs, and ahead-of-time compilation is +supported on x86-64 CPUs. Auto-clustering support on multi-GPU environments and +on a CPU is experimental. + +## Generating great bug reports + +A bug report is much easier to reproduce if it includes dumps for the generated +XLA programs and the used auto-clustering embedding. +To generate them for a TensorFlow program running with auto-clustering, launch: + +``` +$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug +--tf_xla_auto_jit=2" XLA_FLAGS="--dump_hlo_as_text --xla_dump_to=/tmp/generated" +my/tensorflow/program" +``` + +When filing bugs, attach the contents of the `/tmp/generated` directory +(referenced above). + +If possible, try to isolate +a bug to a single XLA program by using the +[`replay_computation`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/tools/replay_computation.cc) +and iteratively running it on generated programs. + +## Further reading + +- [XLA Architecture](./architecture.md): Overview of the XLA architecture +- [XLA - TensorFlow, Compiled](https://developers.googleblog.com/2017/03/xla-tensorflow-compiled.html): + Read on Google Developers Blog +- Check out the + [XLA source](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/xla) + on Github! + + diff --git a/tensorflow/compiler/xla/g3doc/jit.md b/tensorflow/compiler/xla/g3doc/jit.md deleted file mode 100644 index d7ce5ee1ba6..00000000000 --- a/tensorflow/compiler/xla/g3doc/jit.md +++ /dev/null @@ -1,163 +0,0 @@ -# Using JIT Compilation - -> Note: TensorFlow must be compiled from source to include XLA. - -## Why use just-in-time (JIT) compilation? - -The TensorFlow/XLA JIT compiler compiles and runs parts of TensorFlow graphs via -XLA. The benefit of this over the standard TensorFlow implementation is that XLA -can fuse multiple operators (kernel fusion) into a small number of compiled -kernels. Fusing operators can reduce memory bandwidth requirements and improve -performance compared to executing operators one-at-a-time, as the TensorFlow -executor does. - -## Running TensorFlow graphs via XLA - -There are two ways to run TensorFlow computations via XLA, either by -JIT-compiling operators placed on a CPU or GPU device, or by placing operators -on the `XLA_CPU` or `XLA_GPU` TensorFlow devices. Placing operators directly on -a TensorFlow XLA device forces the operator to run on that device and is mainly -used for testing. 
- -> Note: The XLA CPU backend supports intra-op parallelism (i.e. it can shard a -> single operation across multiple cores) but it does not support inter-op -> parallelism (i.e. it cannot execute independent operations concurrently across -> multiple cores). The XLA GPU backend is competitive with the standard -> TensorFlow implementation, sometimes faster, sometimes slower. - -### Turning on JIT compilation - -JIT compilation can be turned on at the session level or manually for select -operations. Both of these approaches are zero-copy --- data does not need to be -copied when passing data between a compiled XLA kernel and a TensorFlow operator -placed on the same device. - -#### Session - -Turning on JIT compilation at the session level will result in all possible -operators being greedily compiled into XLA computations. Each XLA computation -will be compiled into one or more kernels for the underlying device. - -Subject to a few constraints, if there are two adjacent operators in the graph -that both have XLA implementations, then they will be compiled into a single XLA -computation. - -JIT compilation is turned on at the session level by setting the -`global_jit_level` config to `tf.OptimizerOptions.ON_1` and passing the config -during session initialization. - -```python -# Config to turn on JIT compilation -config = tf.ConfigProto() -config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 - -sess = tf.Session(config=config) -``` - -> Note: Turning on JIT at the session level will not result in operations being -> compiled for the CPU. JIT compilation for CPU operations must be done via -> the manual method documented below. - -#### Manual with experimental_jit_scope() - -JIT compilation can also be turned on manually for one or more operators. This -is done by tagging the operators to compile with the attribute -`_XlaCompile=true`. The simplest way to do this is via the -`tf.contrib.compiler.jit.experimental_jit_scope()` scope defined in -[`tensorflow/contrib/compiler/jit.py`](https://www.tensorflow.org/code/tensorflow/contrib/compiler/jit.py). -Example usage: - -```python - jit_scope = tf.contrib.compiler.jit.experimental_jit_scope - - x = tf.placeholder(np.float32) - with jit_scope(): - y = tf.add(x, x) # The "add" will be compiled with XLA. -``` - -The `_XlaCompile` attribute is currently supported on a best-effort basis. If an -operator cannot be compiled, TensorFlow will silently fall back to the normal -implementation. - -#### Manual with xla.compile() - -Unlike experimental_jit_scope() which silently falls back to normal Tensorflow -on uncompilable operator, xla.compile() returns an explicit error. This is -useful if you want more predictable behaviors from XLA compilation. - -Please see -[xla.compile() tutorial Colab](./tutorials/xla_compile.ipynb) -for how to use it. - -### Placing operators on XLA devices - -Another way to run computations via XLA is to place an operator on a specific -XLA device. This method is normally only used for testing. Valid targets are -`XLA_CPU` or `XLA_GPU`. - -```python -with tf.device("/job:localhost/replica:0/task:0/device:XLA_GPU:0"): - output = tf.add(input1, input2) -``` - -Unlike JIT compilation on the standard CPU and GPU devices, these devices make a -copy of data when it is transferred on and off the device. The extra copy makes -it expensive to mix XLA and TensorFlow operators in the same graph. - -## Tutorial - -This tutorial covers training a simple version of MNIST softmax with JIT turned -on. 
Currently JIT at the session level, which is what is used for the tutorial, -only supports GPU. - -Before starting the tutorial verify that the LD_LIBRARY environment variable or -ldconfig contains `$CUDA_ROOT/extras/CUPTI/lib64`, which contains libraries for -the CUDA Profiling Tools Interface -[(CUPTI)](http://docs.nvidia.com/cuda/cupti/index.html). TensorFlow uses CUPTI -to pull tracing information from the GPU. - -### Step #1: Prepare sample script - -Download or move -[mnist_softmax_xla.py](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_softmax_xla.py) -into a folder outside of the TensorFlow source tree. - -### Step #2: Run without XLA - -Execute the python script to train the model without XLA. - -```shell -python mnist_softmax_xla.py --xla='' -``` - -Using the Chrome Trace Event Profiler (browse to chrome://tracing), -open the timeline file created when the script finishes: `timeline.ctf.json`. -The rendered timeline should look similar to the picture below with multiple -green boxes labeled `MatMul`, possibly across multiple CPUs. -
- -### Step #3 Run with XLA - -Execute the python script to train the model with XLA and turn on a debugging -feature of XLA via an environmental variable that outputs the XLA graph. - -```shell -XLA_FLAGS="--xla_hlo_profile --xla_dump_to=/tmp/foo --xla_dump_hlo_as_text" -python mnist_softmax_xla.py -``` - -Open the timeline file created (`timeline.ctf.json`). The rendered timeline -should look similar to the picture below with one long bar labeled `XlaLaunch`. -
- -To understand what is happening in `XlaLaunch`, look at the console output. Each -XLA cluster that's launched will have a corresponding profile (from -`--xla_hlo_profile`) showing how long each HLO took to run. - -`/tmp/foo` will contain the HLO before and after optimizations for each HLO -module that's run. You can read this as-is, or you can visualize it using -`tensorflow/compiler/xla/tools:interactive_graphviz`. diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index d6c99580c39..1f2790e98bb 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -1379,13 +1379,16 @@ For a more intuitive description, see the "Informal Description" section below. : : : map indices in : : : : `start_indices` to legal : : : : indices into operand. : +| `indices_are_sorted` | `bool` | Whether the indices are | +: : : guaranteed to be sorted by : +: : : the caller. : For convenience, we label dimensions in the output array not in `offset_dims` as `batch_dims`. The output is an array of rank `batch_dims.size` + `offset_dims.size`. -The `operand.rank` must equal the sume of `offset_dims.size` and +The `operand.rank` must equal the sum of `offset_dims.size` and `collapsed_slice_dims`. Also, `slice_sizes.size` has to be equal to `operand.rank`. @@ -1443,6 +1446,10 @@ and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., `offset.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`, `2`} then `remapped_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}. +If `indices_are_sorted` is set to true then XLA can assume that `start_indices` +are sorted (in ascending `start_index_map` order) by the user. If they are not +then the semantics is implementation defined. + ### Informal Description and Examples Informally, every index `Out` in the output array corresponds to an element `E` @@ -1980,8 +1987,12 @@ window_strides, padding)` | `window_dilations` | `ArraySlice` | array of integers for window | : : : dilation values : | `padding` | `Padding` | padding type for window | -: : : (Padding\:\:kSame or : -: : : Padding\:\:kValid) : +: : : (Padding\:\:kSame, which pads so : +: : : as to have the same output shape : +: : : as input if the stride is 1, or : +: : : Padding\:\:kValid, which uses no : +: : : no padding and "stops" the : +: : : window once it no longer fits) : Below code and figure shows an example of using `ReduceWindow`. Input is a matrix of size [4x6] and both window_dimensions and window_stride_dimensions are @@ -2027,6 +2038,17 @@ padding. +For a non-trivial padding example, consider computing reduce-window minimum +(initial value is `MAX_FLOAT`) with dimension `3` and stride `2` over the input +array `[10000, 1000, 100, 10, 1]`. Padding `kValid` computes minimums over two +_valid_ windows: `[10000, 1000, 100]` and `[100, 10, 1]`, resulting in the +output `[100, 1]`. Padding `kSame` first pads the array so that the shape after +the reduce-window would be the _same_ as input for stride one by adding initial +elements on both sides, getting `[MAX_VALUE, 10000, 1000, 100, 10, 1, +MAX_VALUE]`. Running reduce-window over the padded array operates on three +windows `[MAX_VALUE, 10000, 1000]`, `[1000, 100, 10]`, `[10, 1, MAX_VALUE]`, and +yields `[1000, 10, 1]`. + The evaluation order of the reduction function is arbitrary and may be non-deterministic. Therefore, the reduction function should not be overly sensitive to reassociation. 
See the discussion about associativity in the @@ -2213,6 +2235,7 @@ Arguments | Type | Semantics `update_window_dims` | `ArraySlice` | The set of dimensions in `updates` shape that are _window dimensions_. `inserted_window_dims` | `ArraySlice` | The set of _window dimensions_ that must be inserted into `updates` shape. `scatter_dims_to_operand_dims` | `ArraySlice` | A dimensions map from the scatter indices to the operand index space. This array is interpreted as mapping `i` to `scatter_dims_to_operand_dims[i]` . It has to be one-to-one and total. +`indices_are_sorted` | `bool` | Whether the indices are guaranteed to be sorted by the caller. If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider `scatter_indices` to have a trailing `1` dimension. @@ -2299,6 +2322,10 @@ always be the current value from the `output` array and the second parameter will always be the value from the `updates` array. This is important specifically for cases when the `update_computation` is _not commutative_. +If `indices_are_sorted` is set to true then XLA can assume that `start_indices` +are sorted (in ascending `start_index_map` order) by the user. If they are not +then the semantics is implementation defined. + Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e. the scatter op updates the elements in the input that are extracted by the corresponding gather op. @@ -2517,6 +2544,11 @@ arguments to the slice operation. : : : respective `start_indices` value for : : : : the dimension and less than or equal : : : : to the size of the dimension. : +| `strides` | `ArraySlice` | List of N integers that decides the | +: : : input stride of the slice. The slice : +: : : picks every `strides[d]` element in : +: : : dimension `d`. : + 1-dimensional example: diff --git a/tensorflow/compiler/xla/g3doc/tfcompile.md b/tensorflow/compiler/xla/g3doc/tfcompile.md index 5ee09fd302b..c80e2745341 100644 --- a/tensorflow/compiler/xla/g3doc/tfcompile.md +++ b/tensorflow/compiler/xla/g3doc/tfcompile.md @@ -16,9 +16,7 @@ kernels that are actually used in the computation. The compiler is built on top of the XLA framework. The code bridging TensorFlow to the XLA framework resides under -[tensorflow/compiler](https://www.tensorflow.org/code/tensorflow/compiler/), -which also includes support for [just-in-time (JIT) compilation](jit.md) of -TensorFlow graphs. +[tensorflow/compiler](https://www.tensorflow.org/code/tensorflow/compiler/). ## What does tfcompile do? diff --git a/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb new file mode 100644 index 00000000000..78f1bca1478 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/tutorials/autoclustering_xla.ipynb @@ -0,0 +1,222 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "CIFT with XLA.ipynb", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "metadata": { + "colab_type": "text", + "id": "b7noD9NjFRL-" + }, + "cell_type": "markdown", + "source": [ + "\n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "mz65veHXsmnS" + }, + "source": [ + "# Classifying CIFAR-10 with XLA\n", + "\n", + "In this colab we train a TensorFlow model to classify the [CIFAR-10](https://en.wikipedia.org/wiki/CIFAR-10) dataset, and we compile it using XLA.\n", + "\n", + "We start by loading and normalizing the dataset using the Keras API:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "7vm2QsMisCxI" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "\n", + "# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb\n", + "assert(tf.test.is_gpu_available())\n", + "\n", + "tf.keras.backend.clear_session()\n", + "tf.config.optimizer.set_jit(False) # Start with XLA disabled.\n", + "\n", + "def load_data():\n", + " (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()\n", + " x_train = x_train.astype('float32') / 256\n", + " x_test = x_test.astype('float32') / 256\n", + "\n", + " # Convert class vectors to binary class matrices.\n", + " y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)\n", + " y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)\n", + " return ((x_train, y_train), (x_test, y_test))\n", + "\n", + "(x_train, y_train), (x_test, y_test) = load_data()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "MgNM2tbgtScx" + }, + "source": [ + "We define the model, adapted from the Keras [CIFAR-10 example](https://keras.io/examples/cifar10_cnn/):" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3ZRQSwoRsKM_" + }, + "outputs": [], + "source": [ + "def generate_model():\n", + " return tf.keras.models.Sequential([\n", + " tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.Conv2D(32, (3, 3)),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", + " tf.keras.layers.Dropout(0.25),\n", + "\n", + " tf.keras.layers.Conv2D(64, (3, 3), padding='same'),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.Conv2D(64, (3, 3)),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", + " tf.keras.layers.Dropout(0.25),\n", + "\n", + " tf.keras.layers.Flatten(),\n", + " tf.keras.layers.Dense(512),\n", + " tf.keras.layers.Activation('relu'),\n", + " tf.keras.layers.Dropout(0.5),\n", + " tf.keras.layers.Dense(10),\n", + " tf.keras.layers.Activation('softmax')\n", + " ])\n", + "\n", + "model = generate_model()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-M4GtGDZtb8a" + }, + "source": [ + "We train the model using the\n", + "[RMSprop](https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer)\n", + "optimizer:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "UKCmrhF0tiMa" + }, + "outputs": [], + "source": [ + "def compile_model(model):\n", + " opt = tf.keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)\n", + " model.compile(loss='categorical_crossentropy',\n", + " optimizer=opt,\n", + " metrics=['accuracy'])\n", + " return model\n", + "\n", + "model = compile_model(model)\n", + "\n", + "def train_model(model, x_train, y_train, 
x_test, y_test, epochs=25):\n", + " model.fit(x_train, y_train, batch_size=256, epochs=epochs, validation_data=(x_test, y_test), shuffle=True)\n", + "\n", + "def warmup(model, x_train, y_train, x_test, y_test):\n", + " # Warm up the JIT, we do not wish to measure the compilation time.\n", + " initial_weights = model.get_weights()\n", + " train_model(model, x_train, y_train, x_test, y_test, epochs=1)\n", + " model.set_weights(initial_weights)\n", + "\n", + "warmup(model, x_train, y_train, x_test, y_test)\n", + "%time train_model(model, x_train, y_train, x_test, y_test)\n", + "\n", + "scores = model.evaluate(x_test, y_test, verbose=1)\n", + "print('Test loss:', scores[0])\n", + "print('Test accuracy:', scores[1])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "SLpfQ0StRgsu" + }, + "source": [ + "Now let's train the model again, using the XLA compiler.\n", + "To enable the compiler in the middle of the application, we need to reset the Keras session." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "jxU-Tzy4SX7p" + }, + "outputs": [], + "source": [ + "tf.keras.backend.clear_session() # We need to clear the session to enable JIT in the middle of the program.\n", + "tf.config.optimizer.set_jit(True) # Enable XLA.\n", + "model = compile_model(generate_model())\n", + "(x_train, y_train), (x_test, y_test) = load_data()\n", + "\n", + "warmup(model, x_train, y_train, x_test, y_test)\n", + "%time train_model(model, x_train, y_train, x_test, y_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "iWHz6P1se92F" + }, + "source": [ + "On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU the speed up is ~1.17x." + ] + } + ], + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb index 2a83092805b..38abda8974f 100644 --- a/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb +++ b/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb @@ -370,4 +370,4 @@ "outputs": [] } ] -} \ No newline at end of file +} diff --git a/tensorflow/compiler/xla/layout.h b/tensorflow/compiler/xla/layout.h index f216bd63d77..4f309cd9f70 100644 --- a/tensorflow/compiler/xla/layout.h +++ b/tensorflow/compiler/xla/layout.h @@ -28,7 +28,7 @@ limitations under the License. namespace xla { // Describes a tile used in tiling-based layout. Refer to -// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for +// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for // details. 
class Tile { public: @@ -136,6 +136,7 @@ class Layout { Equal& MinorToMajorOnly() { ignore_tiles_ = true; ignore_element_size_ = true; + ignore_memory_space_ = true; return *this; } diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 63d9a1e9067..03b47ba7089 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -891,7 +891,7 @@ string LiteralBase::GetSparseElementAsString( } } -StatusOr LiteralBase::GetIntegralAsS64( +absl::optional LiteralBase::GetIntegralAsS64( absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { @@ -908,12 +908,11 @@ StatusOr LiteralBase::GetIntegralAsS64( case U64: return Get(multi_index); default: - return FailedPrecondition("Array element type is not integral: %s", - PrimitiveType_Name(shape().element_type())); + return absl::nullopt; } } -StatusOr LiteralBase::GetAsDouble( +absl::optional LiteralBase::GetAsDouble( absl::Span multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { @@ -926,8 +925,27 @@ StatusOr LiteralBase::GetAsDouble( case BF16: return static_cast(Get(multi_index)); default: - return FailedPrecondition("Array element type is not floating: %s", - PrimitiveType_Name(shape().element_type())); + return absl::nullopt; + } +} + +absl::optional LiteralBase::GetAsComplex128( + absl::Span multi_index) const { + switch (shape().element_type()) { + case BF16: + return {{static_cast(Get(multi_index)), 0}}; + case F16: + return {{static_cast(Get(multi_index)), 0}}; + case F32: + return {{Get(multi_index), 0}}; + case F64: + return {{Get(multi_index), 0}}; + case C64: + return {Get(multi_index)}; + case C128: + return {Get(multi_index)}; + default: + return absl::nullopt; } } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index ffd5a883240..af15cab4a94 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -130,13 +130,47 @@ class LiteralBase { // value into text. string GetSparseElementAsString(int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Return whether the value at the specified index is equal to the provided + // generic `value` (T must be an arithmetic type). + // + // Precondition: must be an array. + template + typename std::enable_if<(std::is_arithmetic::value || + std::is_same::value || + std::is_same::value), + bool>::type + IsEqualAt(absl::Span multi_index, T value) const { + if (auto as_s64 = GetIntegralAsS64(multi_index)) { + return *as_s64 == value; + } + complex128 as_complex128 = *GetAsComplex128(multi_index); + return as_complex128.imag() == 0 && as_complex128.real() == value; + } + + bool IsEqualAt(absl::Span multi_index, complex128 value) const { + if (auto as_s64 = GetIntegralAsS64(multi_index)) { + return *as_s64 == value.real() && value.imag() == 0; + } + auto as_complex128 = GetAsComplex128(multi_index); + return *as_complex128 == value; + } + // As Get(), but determines the correct type and converts the value into // int64. This literal must be an array. - StatusOr GetIntegralAsS64(absl::Span multi_index) const; + absl::optional GetIntegralAsS64( + absl::Span multi_index) const; // As Get(), but determines the correct type, and converts the value into // double. This literal must be an array. 
- StatusOr GetAsDouble(absl::Span multi_index) const; + absl::optional GetAsDouble(absl::Span multi_index) const; + + // As Get(), but determines the correct type, and converts the value into + // complex128. All floating point types can be converted into complex128. + // + // This literal must be an array. + absl::optional GetAsComplex128( + absl::Span multi_index) const; // Returns the multi-index of the element in a sparse literal at the given // sparse element number. The sparse element number is the position with in diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 8d46d30b4cf..885d18db673 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -2021,5 +2021,46 @@ TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { LiteralUtil::CreateR2({{9, 9}, {9, 9}})); } +TEST_F(LiteralUtilTest, GetAsComplex128) { + complex128 value = {1, 0}; + Literal c1 = LiteralUtil::CreateR0(value); + EXPECT_EQ(*c1.GetAsComplex128({}), value); + Literal c2 = LiteralUtil::CreateR0(1); + EXPECT_EQ(*c2.GetAsComplex128({}), value); + complex64 float_value = {1, 0}; + Literal c4 = LiteralUtil::CreateR0(float_value); + EXPECT_EQ(*c4.GetAsComplex128({}), value); + complex128 other_value = {1, 2}; + Literal c5 = LiteralUtil::CreateR0(other_value); + EXPECT_EQ(*c5.GetAsComplex128({}), other_value); + Literal c6 = LiteralUtil::CreateR0(1); + EXPECT_FALSE(c6.GetAsComplex128({}).has_value()); +} + +TEST_F(LiteralUtilTest, IsEqualAt) { + double val_double = 10.0; + int val_integral = 10; + Literal c1 = LiteralUtil::CreateR0(10); + EXPECT_TRUE(c1.IsEqualAt({}, val_double)); + EXPECT_TRUE(c1.IsEqualAt({}, val_integral)); + Literal c2 = LiteralUtil::CreateR0(10); + EXPECT_TRUE(c2.IsEqualAt({}, val_double)); + EXPECT_TRUE(c2.IsEqualAt({}, val_integral)); + complex128 val_complex = {10, 0}; + EXPECT_TRUE(c2.IsEqualAt({}, val_complex)); + EXPECT_TRUE(c1.IsEqualAt({}, val_complex)); + Literal c3 = LiteralUtil::CreateR0(val_complex); + EXPECT_TRUE(c3.IsEqualAt({}, val_double)); + EXPECT_TRUE(c3.IsEqualAt({}, val_integral)); + EXPECT_TRUE(c3.IsEqualAt({}, val_complex)); + double val_inf = 1. 
/ 0; + EXPECT_FALSE(c3.IsEqualAt({}, val_inf)); + complex128 val_true_complex = {10, 3}; + complex64 val_smaller_complex = {10, 3}; + Literal c4 = LiteralUtil::CreateR0(val_true_complex); + EXPECT_TRUE(c4.IsEqualAt({}, val_true_complex)); + EXPECT_TRUE(c4.IsEqualAt({}, val_smaller_complex)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 95186b94511..70dc386eb14 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -147,12 +147,16 @@ Literal ConvertType(LiteralSlice literal) { switch (primitive_type) { case U8: return LiteralUtil::CreateR0(1); + case U16: + return LiteralUtil::CreateR0(1); case U32: return LiteralUtil::CreateR0(1); case U64: return LiteralUtil::CreateR0(1); case S8: return LiteralUtil::CreateR0(1); + case S16: + return LiteralUtil::CreateR0(1); case S32: return LiteralUtil::CreateR0(1); case S64: @@ -171,9 +175,6 @@ Literal ConvertType(LiteralSlice literal) { return LiteralUtil::CreateR0(1); case PRED: return LiteralUtil::CreateR0(true); - case S16: - case U16: - LOG(FATAL) << "u16/s16 literals not yet implemented"; case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE_TYPE: @@ -187,12 +188,16 @@ Literal ConvertType(LiteralSlice literal) { switch (primitive_type) { case U8: return LiteralUtil::CreateR0(std::numeric_limits::min()); + case U16: + return LiteralUtil::CreateR0(std::numeric_limits::min()); case U32: return LiteralUtil::CreateR0(std::numeric_limits::min()); case U64: return LiteralUtil::CreateR0(std::numeric_limits::min()); case S8: return LiteralUtil::CreateR0(std::numeric_limits::min()); + case S16: + return LiteralUtil::CreateR0(std::numeric_limits::min()); case S32: return LiteralUtil::CreateR0(std::numeric_limits::min()); case S64: @@ -209,9 +214,6 @@ Literal ConvertType(LiteralSlice literal) { LOG(FATAL) << "C128 element type has no minimum value"; case PRED: return LiteralUtil::CreateR0(false); - case S16: - case U16: - LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: return LiteralUtil::CreateR0( static_cast(-std::numeric_limits::infinity())); @@ -231,12 +233,16 @@ Literal ConvertType(LiteralSlice literal) { switch (primitive_type) { case U8: return LiteralUtil::CreateR0(std::numeric_limits::max()); + case U16: + return LiteralUtil::CreateR0(std::numeric_limits::max()); case U32: return LiteralUtil::CreateR0(std::numeric_limits::max()); case U64: return LiteralUtil::CreateR0(std::numeric_limits::max()); case S8: return LiteralUtil::CreateR0(std::numeric_limits::max()); + case S16: + return LiteralUtil::CreateR0(std::numeric_limits::max()); case S32: return LiteralUtil::CreateR0(std::numeric_limits::max()); case S64: @@ -249,9 +255,6 @@ Literal ConvertType(LiteralSlice literal) { std::numeric_limits::infinity()); case PRED: return LiteralUtil::CreateR0(true); - case S16: - case U16: - LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: return LiteralUtil::CreateR0( static_cast(std::numeric_limits::infinity())); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index c50c0baf007..2f12db73330 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -226,8 +226,7 @@ class LiteralUtil { // in invocation between the above signature and this one. 
static Literal MakeTupleOwned(std::vector elements); - // This overload lets you pass a braced list of Literals to - // MakeTupleOwned: + // This overload lets you pass a list of Literals to MakeTupleOwned: // // LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1(...), ...). // diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 295d3530032..034c14e8930 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -45,7 +45,7 @@ const int kBFloat16MantissaBits = 7; template PrimitiveType NativeToPrimitiveType() { // Make the expression depend on the template parameter NativeT so - // that this compile-time error only apperas if this function is + // that this compile-time error only appears if this function is // instantiated with some concrete type that is not specialized // below. static_assert(!std::is_same::value, diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index e476015f94f..b7c30531923 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -39,12 +39,17 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, } Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, - const string& directory, const string& file_name) { + const string& directory, const string& file_name, + string* full_path) { tensorflow::Env* env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); string safe_file_name = SanitizeFileName(file_name) + ".pb"; - const string path = tensorflow::io::JoinPath(directory, safe_file_name); - return tensorflow::WriteBinaryProto(env, path, message); + string full_path_impl; + if (!full_path) { + full_path = &full_path_impl; + } + *full_path = tensorflow::io::JoinPath(directory, safe_file_name); + return tensorflow::WriteBinaryProto(env, *full_path, message); } } // namespace protobuf_util diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index e20a7e95a63..7db020982b9 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -37,8 +37,12 @@ extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, // 'directory/file_name.pb'. The 'directory' is recursively created if it // doesn't already exist, and the 'file_name' is sanitized by replacing // illegal characters with underscore '_'. +// +// If 'full_name' is not null then it is set to the name of the file the +// protobuf was written to. Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, - const string& directory, const string& file_name); + const string& directory, const string& file_name, + string* full_path = nullptr); // Registers a function that may either expand a dirpath or forward the original // dirpath along as-is. 
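As a usage note for the new optional `full_path` out-parameter on `DumpProtoToDirectory` above, here is a hedged caller sketch; the wrapper function, directory, and file name are assumptions for illustration:

// Hypothetical caller of DumpProtoToDirectory's new `full_path` out-parameter.
// The wrapper name, directory, and file name are illustrative assumptions.
#include <string>

#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/core/platform/logging.h"

xla::Status DumpModuleProtoForDebug(
    const tensorflow::protobuf::Message& module_proto) {
  std::string full_path;
  xla::Status status = xla::protobuf_util::DumpProtoToDirectory(
      module_proto, /*directory=*/"/tmp/xla_dumps",
      /*file_name=*/"module_0001", &full_path);
  if (status.ok()) {
    // `full_path` is the sanitized path actually written, e.g.
    // "/tmp/xla_dumps/module_0001.pb".
    LOG(INFO) << "Dumped proto to " << full_path;
  }
  return status;
}

Passing nullptr (the default) preserves the old behavior; callers that previously re-derived the sanitized file name can instead read it back from `full_path`.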
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index a6a1bd1830e..4377dabaa9d 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -1,7 +1,7 @@ -load("//tensorflow/core:platform/default/build_config.bzl", "pyx_library") -load("//tensorflow/compiler/xla:xla.bzl", "xla_python_default_plugins") -load("//tensorflow:tensorflow.bzl", "tf_pybind_extension") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:default/build_config.bzl", "pyx_library") +load("//tensorflow/compiler/xla:xla.bzl", "xla_py_test_deps", "xla_python_default_plugins") +load("//tensorflow:tensorflow.bzl", "pybind_extension") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_cc_test") package( default_visibility = ["//tensorflow:internal"], @@ -29,15 +29,14 @@ py_test( name = "xla_client_test", srcs = ["xla_client_test.py"], main = "xla_client_test.py", - python_version = "PY2", srcs_version = "PY2AND3", - tags = ["no_oss"], + tags = ["no_oss"], # TODO(phawkins): This test passes, but requires --config=monolithic. deps = [ ":custom_call_for_test", ":xla_client", - "//tensorflow/compiler/xla:xla_data_proto_py", - "//tensorflow/python:platform_test", - ], + ":xla_extension", + "@absl_py//absl/testing:absltest", + ] + xla_py_test_deps(), ) cc_library( @@ -69,7 +68,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", "@pybind11", @@ -171,9 +169,9 @@ tf_cc_test( ) cc_library( - name = "device", - srcs = ["device.cc"], - hdrs = ["device.h"], + name = "device_state", + srcs = ["device_state.cc"], + hdrs = ["device_state.h"], deps = [ ":event_pool", ":semaphore", @@ -189,24 +187,11 @@ cc_library( cc_library( name = "local_client", - srcs = [ - "local_client.cc", - "python_ref_manager.cc", - "python_ref_manager.h", - ], - hdrs = [ - "local_client.h", - ], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - "-Wno-c++98-c++11-compat", - ], - features = ["-use_header_modules"], + srcs = ["local_client.cc"], + hdrs = ["local_client.h"], deps = [ - ":device", + ":device_state", ":shared_device_buffer", - ":types", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -222,22 +207,39 @@ cc_library( "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:allocator", "//tensorflow/core:bfc_allocator", "//tensorflow/core:gpu_mem_allocator", "//tensorflow/core:lib", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:tf_allocator_adapter", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "python_ref_manager", + srcs = ["python_ref_manager.cc"], + hdrs = ["python_ref_manager.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + "-Wno-c++98-c++11-compat", + ], + features = ["-use_header_modules"], + deps = [ + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", "@pybind11", ], ) 
-tf_pybind_extension( +pybind_extension( name = "xla_extension", srcs = [ "xla.cc", @@ -252,6 +254,7 @@ tf_pybind_extension( deps = [ ":local_client", ":shared_device_buffer", + ":python_ref_manager", ":types", ":xrt", "@com_google_absl//absl/base", diff --git a/tensorflow/compiler/xla/python/custom_call_for_test.pyx b/tensorflow/compiler/xla/python/custom_call_for_test.pyx index 530dffd1755..4f7c4c3e5a8 100644 --- a/tensorflow/compiler/xla/python/custom_call_for_test.pyx +++ b/tensorflow/compiler/xla/python/custom_call_for_test.pyx @@ -15,7 +15,7 @@ cdef void test_subtract_f32(void* out_ptr, void** data_ptr) nogil: cpu_custom_call_targets = {} cdef register_custom_call_target(fn_name, void* fn): - cdef const char* name = "xla._CPU_CUSTOM_CALL_TARGET" + cdef const char* name = "xla._CUSTOM_CALL_TARGET" cpu_custom_call_targets[fn_name] = PyCapsule_New(fn, name, NULL) register_custom_call_target(b"test_subtract_f32", (test_subtract_f32)) diff --git a/tensorflow/compiler/xla/python/device.cc b/tensorflow/compiler/xla/python/device_state.cc similarity index 80% rename from tensorflow/compiler/xla/python/device.cc rename to tensorflow/compiler/xla/python/device_state.cc index 73df698a274..6363a5a488f 100644 --- a/tensorflow/compiler/xla/python/device.cc +++ b/tensorflow/compiler/xla/python/device_state.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/python/device.h" +#include "tensorflow/compiler/xla/python/device_state.h" #include #include @@ -24,8 +24,9 @@ limitations under the License. namespace xla { -Device::Device(se::StreamExecutor* executor, bool synchronous_deallocation, - bool asynchronous, bool allow_event_reuse) +DeviceState::DeviceState(se::StreamExecutor* executor, + bool synchronous_deallocation, bool asynchronous, + bool allow_event_reuse) : synchronous_deallocation_(synchronous_deallocation), event_pool_(allow_event_reuse), compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1) { @@ -49,14 +50,14 @@ Device::Device(se::StreamExecutor* executor, bool synchronous_deallocation, "py_xla_callback"); } -Device::~Device() { +DeviceState::~DeviceState() { Status status = SynchronizeAllActivity(); if (!status.ok()) { LOG(ERROR) << "Error when closing device: " << status; } } -Status Device::SynchronizeAllActivity() { +Status DeviceState::SynchronizeAllActivity() { Status status; // TODO(phawkins): in theory the call to SynchronizeAllActivity below should // suffice. However on the Host platform SynchronizeAllActivity is a dummy @@ -64,6 +65,7 @@ Status Device::SynchronizeAllActivity() { // stopped, also block on the compute stream. If SynchronizeAllActivity is // fixed, we could remove the BlockHostUntilDone call. 
status.Update(compute_stream_->BlockHostUntilDone()); + status.Update(callback_stream_->BlockHostUntilDone()); bool ok = compute_stream_->parent()->SynchronizeAllActivity(); if (!ok) { status.Update(Unknown("SynchronizeAllActivity failed.")); @@ -71,10 +73,10 @@ Status Device::SynchronizeAllActivity() { return status; } -Status Device::ThenMemcpyDeviceToDevice(se::Stream* src_stream, - se::Stream* dst_stream, - se::DeviceMemoryBase src_buffer, - se::DeviceMemoryBase dst_buffer) { +Status DeviceState::ThenMemcpyDeviceToDevice(se::Stream* src_stream, + se::Stream* dst_stream, + se::DeviceMemoryBase src_buffer, + se::DeviceMemoryBase dst_buffer) { // The default implementation simply calls ThenMemcpyD2D, and assumes that // the buffer addresses identify the devices. This does not work // on all platforms; this method is virtual so it can be overridden. @@ -82,14 +84,14 @@ Status Device::ThenMemcpyDeviceToDevice(se::Stream* src_stream, return Status::OK(); } -void Device::ThenExecuteOnCallbackThread(se::Stream* stream, - std::function callback) const { +void DeviceState::ThenExecuteOnCallbackThread( + se::Stream* stream, std::function callback) const { stream->ThenDoHostCallback([this, callback]() mutable { callback_thread_->Schedule(std::move(callback)); }); } -se::Stream* Device::GetDeviceToDeviceStream() { +se::Stream* DeviceState::GetDeviceToDeviceStream() { absl::MutexLock lock(&mu_); int i = next_device_to_device_stream_; next_device_to_device_stream_ = diff --git a/tensorflow/compiler/xla/python/device.h b/tensorflow/compiler/xla/python/device_state.h similarity index 91% rename from tensorflow/compiler/xla/python/device.h rename to tensorflow/compiler/xla/python/device_state.h index f40c5df7c61..f108c517169 100644 --- a/tensorflow/compiler/xla/python/device.h +++ b/tensorflow/compiler/xla/python/device_state.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_H_ -#define TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_ #include #include @@ -29,8 +29,9 @@ limitations under the License. namespace xla { // Class that encapsulates state relating to a device (e.g., a GPU) on which we -// can perform computation and transfers. -class Device { +// can perform computation and transfers. DeviceState objects only exist for +// devices local to this host. +class DeviceState { public: // If synchronous_deallocation is true, the host must not free buffers until // compute/transfers that use those buffers have completed. For example, this @@ -39,9 +40,9 @@ class Device { // // If asynchronous is false, the host will synchronize to the device after // each execution or transfer. This is intended for debugging only. 
- Device(se::StreamExecutor* executor, bool synchronous_deallocation, - bool asynchronous, bool allow_event_reuse); - virtual ~Device(); + DeviceState(se::StreamExecutor* executor, bool synchronous_deallocation, + bool asynchronous, bool allow_event_reuse); + virtual ~DeviceState(); bool synchronous_deallocation() const { return synchronous_deallocation_; } @@ -131,4 +132,4 @@ class Device { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_H_ +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_DEVICE_STATE_H_ diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index 982bf9eb21f..1d9bd1f0695 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -85,14 +85,13 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" -#include "include/pybind11/pybind11.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/python/shared_device_buffer.h" -#include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -106,13 +105,20 @@ limitations under the License. namespace xla { -namespace py = pybind11; +std::string CpuDevice::DebugString() const { + return absl::StrCat("CPU_", id()); +} + +std::string GpuDevice::DebugString() const { + return absl::StrCat("GPU_", id()); +} static StatusOr> CreateBFCAllocator( - se::Platform* platform, LocalClient* client, double memory_fraction, - bool preallocate) { + se::Platform* platform, + absl::Span> device_states, + LocalClient* client, double memory_fraction, bool preallocate) { CHECK_GT(client->backend().device_count(), 0); - std::vector> allocators; + std::vector allocators; for (se::StreamExecutor* executor : client->backend().stream_executors()) { int device_ordinal = executor->device_ordinal(); auto sub_allocator = absl::make_unique( @@ -141,12 +147,23 @@ static StatusOr> CreateBFCAllocator( sub_allocator.release(), allocator_memory, /*allow_growth=*/!preallocate, absl::StrCat("GPU_", device_ordinal, "_bfc")); - allocators.emplace_back(std::move(gpu_bfc_allocator)); + allocators.emplace_back(std::move(gpu_bfc_allocator), + device_states.at(device_ordinal)->compute_stream()); } return absl::make_unique(platform, std::move(allocators)); } +static std::shared_ptr MakeDevice(const std::string& platform_name, + int id, int local_device_ordinal) { + if (platform_name == "cpu") { + return std::make_shared(id, local_device_ordinal); + } else { + CHECK_EQ(platform_name, "gpu"); + return std::make_shared(id, local_device_ordinal); + } +} + StatusOr> PyLocalClient::Get( const std::string& platform_name, const std::string& xla_platform_name, bool asynchronous, const AllocatorConfig& allocator_config) { @@ -162,14 +179,26 @@ StatusOr> PyLocalClient::Get( ClientLibrary::GetOrCreateLocalClient(options)); bool gpu_platform = platform_name == "gpu"; + std::vector> device_states; + std::vector> devices; + bool synchronous_deallocation = platform_name == "cpu"; + for (int i = 0; i < client->device_count(); ++i) { + se::StreamExecutor* executor = + 
client->backend().stream_executor(i).ValueOrDie(); + device_states.push_back(absl::make_unique( + executor, synchronous_deallocation, asynchronous, + /*allow_event_reuse=*/gpu_platform)); + devices.push_back(MakeDevice(platform_name, i, i)); + } + std::unique_ptr allocator; std::unique_ptr host_memory_allocator; if (gpu_platform) { if (allocator_config.kind != AllocatorConfig::Kind::kPlatform) { - TF_ASSIGN_OR_RETURN( - allocator, - CreateBFCAllocator(platform, client, allocator_config.memory_fraction, - allocator_config.preallocate)); + TF_ASSIGN_OR_RETURN(allocator, + CreateBFCAllocator(platform, device_states, client, + allocator_config.memory_fraction, + allocator_config.preallocate)); } tensorflow::SubAllocator* sub_allocator = new tensorflow::GpuHostAllocator( @@ -186,29 +215,23 @@ StatusOr> PyLocalClient::Get( return Unimplemented("BFCAllocator only available for GPU."); } - std::vector> devices; - devices.reserve(client->device_count()); - bool synchronous_deallocation = platform_name == "cpu"; - for (int i = 0; i < client->device_count(); ++i) { - se::StreamExecutor* executor = - client->backend().stream_executor(i).ValueOrDie(); - devices.push_back(absl::make_unique( - executor, synchronous_deallocation, asynchronous, - /*allow_event_reuse=*/gpu_platform)); - } return std::make_shared( - platform_name, client, std::move(devices), std::move(allocator), + platform_name, client, std::move(devices), /*host_id=*/0, + std::move(device_states), std::move(allocator), std::move(host_memory_allocator)); } PyLocalClient::PyLocalClient( std::string platform_name, LocalClient* client, - std::vector> devices, + std::vector> devices, int host_id, + std::vector> device_states, std::unique_ptr allocator, std::unique_ptr host_memory_allocator) : platform_name_(std::move(platform_name)), client_(client), devices_(std::move(devices)), + host_id_(host_id), + device_states_(std::move(device_states)), owned_allocator_(std::move(allocator)), host_memory_allocator_(std::move(host_memory_allocator)), h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer", @@ -218,63 +241,48 @@ PyLocalClient::PyLocalClient( } else { allocator_ = client_->backend().memory_allocator(); } + + for (const std::shared_ptr& device : devices_) { + CHECK(id_to_device_.insert({device->id(), device}).second) + << "Duplicate device id: " << device->id(); + } } Status PyLocalClient::TransferToInfeed(const LiteralSlice& literal, int device_ordinal) { - py_ref_manager().CollectGarbage(); - py::gil_scoped_release gil_release; return client_->TransferToInfeedLocal(literal, device_ordinal); } -StatusOr PyLocalClient::TransferFromOutfeed( - const Shape& shape, int device_ordinal) { - py_ref_manager().CollectGarbage(); - Literal literal; - { - py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN( - literal, client_->TransferFromOutfeedLocal(shape, device_ordinal)); - } - return LiteralToPython(std::make_shared(std::move(literal))); +StatusOr PyLocalClient::TransferFromOutfeed(const Shape& shape, + int device_ordinal) { + return client_->TransferFromOutfeedLocal(shape, device_ordinal); +} + +StatusOr PyLocalClient::GetDefaultDeviceAssignment( + int num_replicas) const { + return client_->backend().computation_placer()->AssignDevices( + num_replicas, /*computation_count=*/1); } /* static */ -StatusOr> PyLocalBuffer::FromPython( - const py::object& argument, std::shared_ptr client, - int device_ordinal) { - tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromPython"); - struct H2DTransfer { - PythonBufferTree tree; - 
std::shared_ptr py_buffer_ref; - }; - auto transfer = std::make_shared(); - TF_ASSIGN_OR_RETURN(transfer->tree, GetPythonBufferTree(argument)); - - client->py_ref_manager().CollectGarbage(); - - // Take a reference to the buffer to ensure that the inputs in host memory - // remain live until the transfer is complete. - transfer->py_buffer_ref = client->py_ref_manager().ManageReferences( - absl::MakeSpan(transfer->tree.arrays)); - transfer->tree.arrays.clear(); - - // We are done manipulating Python objects; release the GIL. - py::gil_scoped_release gil_release; - VLOG(1) << "PyLocalBuffer::FromPython: shape: " - << transfer->tree.shape.ToString() +StatusOr> PyLocalBuffer::FromLiterals( + std::vector leaves_literals, const Shape& tuple_shape, + std::shared_ptr leaves_reference, + std::shared_ptr client, int device_ordinal) { + tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals"); + VLOG(1) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString() << " device ordinal: " << device_ordinal; - Device* device = &client->device(device_ordinal); + DeviceState* device = &client->device_state(device_ordinal); TransferManager* transfer_manager = client->client()->backend().transfer_manager(); se::DeviceMemoryAllocator* allocator = client->allocator(); TF_ASSIGN_OR_RETURN( - transfer->tree.shape, - transfer_manager->ChooseCompactLayoutForShape(transfer->tree.shape)); + Shape compact_shape, + transfer_manager->ChooseCompactLayoutForShape(tuple_shape)); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer scoped_buffer, transfer_manager->AllocateScopedShapedBuffer( - transfer->tree.shape, allocator, device_ordinal)); + compact_shape, allocator, device_ordinal)); // Make the host to device stream wait for the newly allocated buffer to be // available on the compute stream. We schedule this wait synchronously; while @@ -293,21 +301,25 @@ StatusOr> PyLocalBuffer::FromPython( SharedDeviceBuffer::FromScopedShapedBuffer(std::move(scoped_buffer), definition_event); + // TODO(makro): Use move capture once C++ 14 features are available. + auto leaves = std::make_shared>( + std::move(leaves_literals)); auto transfer_h2d = [client, transfer_manager, device, device_ordinal, - device_buffer, transfer]() { + device_buffer, compact_shape, leaves, + leaves_reference]() { // This function uses TF_CHECK_OK and ValueOrDie() since we have no way to // report failures from a callback. However, the operations here are // unlikely to fail and not recoverable even if we were to fail: DMAs to // memory that has already been allocated, and a possible Event allocation. 
- ShapedBuffer buffer = device_buffer->AsShapedBuffer(transfer->tree.shape); + ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape); TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync( device->host_to_device_stream(), buffer)); std::vector> staging_buffers; - staging_buffers.reserve(transfer->tree.leaves.size()); - auto it = transfer->tree.leaves.begin(); + staging_buffers.reserve(leaves->size()); + auto it = leaves->begin(); for (const ShapeUtil::IndexedShape& indexed_shape : - ShapeUtil::GetLeafShapes(transfer->tree.shape)) { - CHECK(it != transfer->tree.leaves.end()); + ShapeUtil::GetLeafShapes(compact_shape)) { + CHECK(it != leaves->end()); ShapedBuffer leaf( indexed_shape.shape, transfer_manager->HostShapeToDeviceShape(indexed_shape.shape), @@ -352,19 +364,19 @@ StatusOr> PyLocalBuffer::FromPython( device->ThenRelease(device->host_to_device_stream(), device_buffer); } - device->ThenRelease(device->host_to_device_stream(), - std::make_pair(std::move(transfer->py_buffer_ref), - std::move(staging_buffers))); + device->ThenRelease( + device->host_to_device_stream(), + std::make_pair(leaves_reference, std::move(staging_buffers))); }; client->h2d_transfer_pool()->Schedule(transfer_h2d); return absl::make_unique( - transfer->tree.shape, std::move(device_buffer), std::move(client)); + compact_shape, std::move(device_buffer), std::move(client)); } /* static */ StatusOr> PyLocalBuffer::MakeTuple( const std::vector buffers, std::shared_ptr client, int device_ordinal) { - std::vector host_shapes; + std::vector host_shapes; std::vector> device_buffers; host_shapes.reserve(buffers.size()); device_buffers.reserve(buffers.size()); @@ -382,7 +394,7 @@ StatusOr> PyLocalBuffer::FromPython( se::DeviceMemoryAllocator* allocator = client->allocator(); TransferManager* transfer_manager = client->client()->backend().transfer_manager(); - Device& device = client->device(device_ordinal); + DeviceState& device = client->device_state(device_ordinal); auto definition_event = std::make_shared(); TF_ASSIGN_OR_RETURN( @@ -445,7 +457,8 @@ Status PyLocalBuffer::CopyToHostAsync() { } host_value = host_value_ = std::make_shared(); } - se::Stream* stream = client_->device(device_ordinal_).device_to_host_stream(); + se::Stream* stream = + client_->device_state(device_ordinal_).device_to_host_stream(); WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); host_value->value = std::make_shared(on_host_shape_); TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, AsShapedBuffer()); @@ -458,29 +471,22 @@ Status PyLocalBuffer::CopyToHostAsync() { return Status::OK(); } -StatusOr PyLocalBuffer::ToPython() { - tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToPython"); +StatusOr> PyLocalBuffer::ToLiteral() { + tensorflow::profiler::TraceMe traceme("PyLocalBuffer::ToLiteral"); std::shared_ptr device_buffer = DeviceBuffer(); if (!device_buffer) { - return InvalidArgument("ToPython() called on invalid buffer."); + return InvalidArgument("ToLiteral() called on invalid buffer."); } - client_->py_ref_manager().CollectGarbage(); - std::shared_ptr literal; + TF_RETURN_IF_ERROR(CopyToHostAsync()); + std::shared_ptr host_value; { - py::gil_scoped_release gil_release; - TF_RETURN_IF_ERROR(CopyToHostAsync()); - std::shared_ptr host_value; - { - absl::MutexLock lock(&mu_); - host_value = host_value_; - } - host_value->ready.WaitForNotification(); - TF_RETURN_IF_ERROR(host_value->status); - literal = host_value->value; + absl::MutexLock lock(&mu_); + host_value = host_value_; } - - return 
LiteralToPython(std::move(literal)); + host_value->ready.WaitForNotification(); + TF_RETURN_IF_ERROR(host_value->status); + return host_value->value; } std::shared_ptr PyLocalBuffer::DeviceBuffer() const { @@ -524,15 +530,13 @@ PyLocalBuffer::DestructureTuple() { StatusOr> PyLocalBuffer::CopyToDevice( int dst_device_ordinal) { tensorflow::profiler::TraceMe traceme("PyLocalBuffer::CopyToDevice"); - client_->py_ref_manager().CollectGarbage(); - py::gil_scoped_release gil_release; std::shared_ptr src_device_buffer = DeviceBuffer(); if (dst_device_ordinal == device_ordinal_) { return absl::make_unique(on_host_shape_, src_device_buffer, client_); } - Device& src_device = client_->device(device_ordinal_); - const Device& dst_device = client_->device(dst_device_ordinal); + DeviceState& src_device = client_->device_state(device_ordinal_); + const DeviceState& dst_device = client_->device_state(dst_device_ordinal); se::Stream* src_device_to_device_stream = src_device.GetDeviceToDeviceStream(); @@ -554,7 +558,7 @@ StatusOr> PyLocalBuffer::CopyToDevice( // Copy the leaf buffers. for (const auto& leaf : src_buffer.buffers().leaves()) { - const xla::ShapeIndex& index = leaf.first; + const ShapeIndex& index = leaf.first; const se::DeviceMemoryBase& input_buffer = leaf.second; const se::DeviceMemoryBase& output_buffer = dst_buffer.buffer(index); TF_RET_CHECK(input_buffer.size() == output_buffer.size()) @@ -603,43 +607,58 @@ Status PyLocalBuffer::BlockHostUntilReady() { return InvalidArgument("BlockHostUntilReady() called on invalid buffer."); } - client_->py_ref_manager().CollectGarbage(); - py::gil_scoped_release gil_release; - // This code waits at least until the buffer is ready, but it may wait longer // if there are other device to host transfers scheduled. If this proves to // be an issue, we could either use a separate stream for this purpose, or // poll for the buffer definition events. 
- se::Stream* stream = - client_->device(device_buffer->device_ordinal()).device_to_host_stream(); + se::Stream* stream = client_->device_state(device_buffer->device_ordinal()) + .device_to_host_stream(); WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); return stream->BlockHostUntilDone(); } +static absl::optional LookupDeviceOrdinal(const PyLocalClient& client, + int device_id) { + auto it = client.id_to_device().find(device_id); + CHECK(it != client.id_to_device().end()) + << "Unknown device id: " << device_id; + int device_ordinal = it->second->local_device_ordinal(); + if (device_ordinal == -1) { + return absl::optional(); + } + return device_ordinal; +} + PyLocalExecutable::PyLocalExecutable( std::shared_ptr executable, DeviceAssignment device_assignment, std::shared_ptr client) : client_(std::move(client)), executable_(std::move(executable)), - device_assignment_(std::move(device_assignment)) {} - -std::vector PyLocalExecutable::DeviceOrdinals() const { + device_assignment_(std::move(device_assignment)) { int num_replicas = device_assignment_.replica_count(); - std::vector device_ordinals; - device_ordinals.reserve(num_replicas); - for (int i = 0; i < num_replicas; ++i) { - device_ordinals.push_back(device_assignment_(i, 0)); + for (int replica = 0; replica < num_replicas; ++replica) { + int device_id = device_assignment_(replica, 0); + absl::optional device_ordinal = + LookupDeviceOrdinal(*client_, device_id); + if (!device_ordinal) { + VLOG(3) << "Non-local device: " << device_id; + continue; + } + local_replicas_.push_back(replica); + device_ordinals_.push_back(*device_ordinal); } - return device_ordinals; + CHECK_GE(local_replicas_.size(), 1); } StatusOr> PyLocalExecutable::ExecuteHelper( absl::Span argument_handles, int replica, const RunId& run_id) { - const int device_ordinal = device_assignment_(replica, 0); + const int device_id = device_assignment_(replica, 0); + absl::optional device_ordinal = LookupDeviceOrdinal(*client_, device_id); + CHECK(device_ordinal); tensorflow::profiler::TraceMe traceme("LocalExecutable::Execute"); VLOG(3) << "Replica " << replica - << " mapped to device ordinal for execution: " << device_ordinal; + << " mapped to device ordinal for execution: " << *device_ordinal; absl::flat_hash_set events; std::vector> device_buffers; @@ -657,11 +676,11 @@ StatusOr> PyLocalExecutable::ExecuteHelper( "%d to replica %d", i, replica); } - if (device_buffer->device_ordinal() != device_ordinal) { + if (device_buffer->device_ordinal() != *device_ordinal) { return InvalidArgument( "Buffer passed to Execute() as argument %d to replica %d is on " "device %d, but replica is assigned to device %d.", - i, replica, device_buffer->device_ordinal(), device_ordinal); + i, replica, device_buffer->device_ordinal(), *device_ordinal); } TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, handle->AsShapedBuffer()); argument_buffers.push_back(std::move(shaped_buffer)); @@ -672,7 +691,7 @@ StatusOr> PyLocalExecutable::ExecuteHelper( << " buffer: " << argument_buffers.back().ToString(); } - Device* device = &client_->device(device_ordinal); + DeviceState* device = &client_->device_state(*device_ordinal); // The choice of where we wait is arbitrary; the reason for the wait is pacing // to avoid problems such as memory fragmentation and running ahead too far, // not for correctness. 
Placing it before the executable launch allows the @@ -740,45 +759,49 @@ StatusOr>> PyLocalExecutable::ExecutePerReplica( absl::Span> argument_handles) { tensorflow::profiler::TraceMe traceme("LocalExecutable::ExecutePerReplica"); - const int num_devices = client_->device_count(); + int num_local_replicas = local_replicas_.size(); + const int num_local_devices = client_->local_device_count(); - if (argument_handles.size() != num_replicas()) { + if (argument_handles.size() != num_local_replicas) { return InvalidArgument( - "Attempted to execute with %d replicas when replica count is %d", - argument_handles.size(), num_devices); + "Attempted to execute with %d local replicas when local replica count " + "is %d (total replica count: %d)", + argument_handles.size(), num_local_replicas, num_replicas()); } - if (argument_handles.size() > num_devices) { + if (argument_handles.size() > num_local_devices) { return InvalidArgument( "Attempted to execute with %d replicas when device count is %d", - argument_handles.size(), num_devices); + argument_handles.size(), num_local_devices); } - VLOG(1) << "Executing replicated computation; num_replicas=" - << num_replicas(); - std::vector>> results(num_replicas()); - if (num_replicas() == 1) { + VLOG(1) << "Executing replicated computation; num_replicas=" << num_replicas() + << " num_local_replicas=" << num_local_replicas; + std::vector>> results( + num_local_replicas); + if (num_local_replicas == 1) { // Fast-path if there is only one replica — run the computation on the // current thread. - results[0] = ExecuteHelper(argument_handles[0], /*replica=*/0, RunId()); + results[0] = + ExecuteHelper(argument_handles[0], local_replicas_[0], RunId()); } else { RunId run_id; absl::Mutex mu; - int running GUARDED_BY(mu) = num_replicas(); - int failed GUARDED_BY(mu) = 0; - Status first_failure_status GUARDED_BY(mu); + int running = num_local_replicas; + int failed = 0; + Status first_failure_status; - for (int replica = 0; replica < num_replicas(); ++replica) { - const int device_ordinal = device_assignment_(replica, 0); - const Device& device = client_->device(device_ordinal); - device.execute_thread()->Schedule([&, replica] { - results[replica] = - ExecuteHelper(argument_handles[replica], replica, run_id); + for (int i = 0; i < num_local_replicas; ++i) { + const int replica = local_replicas_[i]; + const int device_ordinal = device_ordinals_[i]; + const DeviceState& device = client_->device_state(device_ordinal); + device.execute_thread()->Schedule([&, replica, i] { + results[i] = ExecuteHelper(argument_handles[i], replica, run_id); absl::MutexLock lock(&mu); --running; - if (!results[replica].ok()) { + if (!results[i].ok()) { if (failed == 0) { - first_failure_status = results[replica].status(); + first_failure_status = results[i].status(); } ++failed; } @@ -813,18 +836,19 @@ PyLocalExecutable::ExecutePerReplica( } VLOG(1) << "Replicated execution complete."; - std::vector> wrapped_results(num_replicas()); - for (int replica = 0; replica < num_replicas(); ++replica) { - auto& statusor = results[replica]; + std::vector> wrapped_results( + num_local_replicas); + for (int i = 0; i < num_local_replicas; ++i) { + auto& statusor = results[i]; if (!statusor.ok()) { return AppendStatus( statusor.status(), absl::StrFormat( "while running replica %d of a replicated computation (other " "replicas may have failed as well).", - replica)); + local_replicas_[i])); } - wrapped_results[replica] = std::move(statusor.ValueOrDie()); + wrapped_results[i] = 
std::move(statusor.ValueOrDie()); } return wrapped_results; } @@ -858,10 +882,8 @@ PyLocalExecutable::Compile(const XlaComputation& computation, device_assignment->computation_count()); } } else { - TF_ASSIGN_OR_RETURN( - device_assignment, - client->client()->backend().computation_placer()->AssignDevices( - options.num_replicas(), /*computation_count=*/1)); + TF_ASSIGN_OR_RETURN(device_assignment, client->GetDefaultDeviceAssignment( + options.num_replicas())); } if (!argument_layouts) { diff --git a/tensorflow/compiler/xla/python/local_client.h b/tensorflow/compiler/xla/python/local_client.h index 8ad4c44d53f..37b3c56b7d2 100644 --- a/tensorflow/compiler/xla/python/local_client.h +++ b/tensorflow/compiler/xla/python/local_client.h @@ -23,12 +23,10 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "absl/types/span.h" -#include "include/pybind11/pybind11.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" -#include "tensorflow/compiler/xla/python/device.h" -#include "tensorflow/compiler/xla/python/python_ref_manager.h" +#include "tensorflow/compiler/xla/python/device_state.h" #include "tensorflow/compiler/xla/python/shared_device_buffer.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -40,6 +38,50 @@ limitations under the License. namespace xla { +class Device { + public: + explicit Device(int id, int local_device_ordinal, int host_id = 0) + : id_(id), + local_device_ordinal_(local_device_ordinal), + host_id_(host_id) {} + virtual ~Device() {} + + // The ID of this device. IDs are unique among devices of this type + // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all + // hosts' devices. This is the ID that should be used in a DeviceAssignment. + int id() const { return id_; } + + // If this is a device local to this host, the local index of this device as + // according to the underlying backend. Unlike id(), this will always be in + // the range [0, num_local_devices), and can be used with the xla::LocalClient + // and xla::Backend APIs. + // + // -1 if this device is not local to this host. + int local_device_ordinal() const { return local_device_ordinal_; } + + // The ID of this device's host. This is always 0 on single-host platforms. + int host_id() const { return host_id_; } + + virtual std::string DebugString() const = 0; + + private: + const int id_; + const int local_device_ordinal_; + const int host_id_; +}; + +class CpuDevice : public Device { + public: + using Device::Device; + std::string DebugString() const override; +}; + +class GpuDevice : public Device { + public: + using Device::Device; + std::string DebugString() const override; +}; + struct AllocatorConfig { enum class Kind { kDefault, // Client picks the best option for the platform. @@ -72,19 +114,31 @@ class PyLocalClient { // `allocator` may null, in which case the platform default allocator is used. 
explicit PyLocalClient( std::string platform_name, LocalClient* client, - std::vector> devices, + std::vector> devices, int host_id, + std::vector> device_states, std::unique_ptr allocator, std::unique_ptr host_memory_allocator); virtual ~PyLocalClient() = default; Status TransferToInfeed(const LiteralSlice& literal, int device_ordinal); - StatusOr TransferFromOutfeed(const Shape& shape, - int device_ordinal); + StatusOr TransferFromOutfeed(const Shape& shape, int device_ordinal); - int device_count() const { return client_->device_count(); } - Device& device(int device_ordinal) const { - return *devices_.at(device_ordinal); + virtual StatusOr GetDefaultDeviceAssignment( + int num_replicas) const; + + int device_count() const { return devices_.size(); } + const std::vector>& devices() { return devices_; } + const std::map>& id_to_device() const { + return id_to_device_; } + int host_id() const { return host_id_; } + const std::string& platform_name() const { return platform_name_; } + + int local_device_count() const { return device_states_.size(); } + DeviceState& device_state(int device_ordinal) const { + return *device_states_.at(device_ordinal); + } + LocalClient* client() const { return client_; } se::DeviceMemoryAllocator* allocator() const { return allocator_; } tensorflow::Allocator* host_memory_allocator() const { @@ -95,19 +149,18 @@ class PyLocalClient { return &h2d_transfer_pool_; } - PythonRefManager& py_ref_manager() { return py_ref_manager_; } - protected: std::string platform_name_; LocalClient* client_; - // py_ref_manager_ must come after devices_ in the class destruction order - // (i.e., appear first in the class.) - // Destruction of devices waits for them to quiesce; callbacks on device - // streams may refer to py_ref_manager_ and we must wait for them to complete. - PythonRefManager py_ref_manager_; + // Includes all devices, including non-local devices on multi-host platforms. + std::vector> devices_; + // Maps Device::id() to the corresponding Device. + std::map> id_to_device_; + int host_id_; - std::vector> devices_; + // Device states local to this host. Indexed by local device ordinal. + std::vector> device_states_; se::DeviceMemoryAllocator* allocator_; std::unique_ptr owned_allocator_; @@ -128,9 +181,10 @@ class PyLocalClient { // Thread-safe. class PyLocalBuffer { public: - static StatusOr> FromPython( - const pybind11::object& argument, std::shared_ptr client, - int device_ordinal); + static StatusOr> FromLiterals( + std::vector leaves_literals, const Shape& tuple_shape, + std::shared_ptr leaves_reference, + std::shared_ptr client, int device_ordinal); static StatusOr> MakeTuple( const std::vector buffers, @@ -148,16 +202,17 @@ class PyLocalBuffer { const Shape& on_host_shape() const { return on_host_shape_; } int device_ordinal() const { return device_ordinal_; } + const std::string& platform_name() const { return client_->platform_name(); } // Returns the buffer's value as a tuple DAG of Python arrays. If the value // has previously been prefetched to the host, then returns the prefetched // version, otherwise copies the buffer to the host. Blocks until the // value is ready. - StatusOr ToPython(); + StatusOr> ToLiteral(); // Initiates a copy of the buffer to the host. Does not block waiting for // the transfer to complete. The value can be retrieved by a later call to - // ToPython(). + // ToLiteral(). Status CopyToHostAsync(); // Returns the associated device buffer. 
Returns a nullptr if the buffer is @@ -190,14 +245,14 @@ class PyLocalBuffer { std::shared_ptr device_buffer_ GUARDED_BY(mu_); // The cached value of the buffer on the host, produced either from a call to - // CopyToHost or from a call to ToPython. Once a value has been fetched to + // CopyToHost or from a call to ToLiteral. Once a value has been fetched to // the host, it persists Delete() is called or the PyLocalBuffer is destroyed. struct HostValue { absl::Notification ready; // status and value are valid for reading only after `ready` has been // notified. Status status; - std::shared_ptr value; + std::shared_ptr value; }; std::shared_ptr host_value_ GUARDED_BY(mu_); }; @@ -222,8 +277,12 @@ class PyLocalExecutable { return executable_->build_options().num_replicas(); } + int64 SizeOfGeneratedCodeInBytes() const { + return executable_->executable()->SizeOfGeneratedCodeInBytes(); + } + // Returns the device ordinals to which each replica is assigned. - std::vector DeviceOrdinals() const; + const std::vector& DeviceOrdinals() const { return device_ordinals_; } const DeviceAssignment& device_assignment() const { return device_assignment_; @@ -248,6 +307,13 @@ class PyLocalExecutable { std::shared_ptr const client_; std::shared_ptr executable_; const DeviceAssignment device_assignment_; + // The replica indices of device_assignment_ to be run by this client. On + // single-host platforms, this is all replicas (i.e. local_replicas_[i] = i), + // but this may not be the case on multi-host platforms. + std::vector local_replicas_; + // device_ordinals_[i] is the device ordinal to which local_replicas_[i] is + // assigned. + std::vector device_ordinals_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/python/python_ref_manager.cc b/tensorflow/compiler/xla/python/python_ref_manager.cc index 1e9cc58d090..0a980f1a749 100644 --- a/tensorflow/compiler/xla/python/python_ref_manager.cc +++ b/tensorflow/compiler/xla/python/python_ref_manager.cc @@ -49,4 +49,9 @@ void PythonRefManager::CollectGarbage() { python_garbage_.clear(); } +PythonRefManager* GlobalPyRefManager() { + static PythonRefManager* static_ref_manager = new PythonRefManager(); + return static_ref_manager; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/python/python_ref_manager.h b/tensorflow/compiler/xla/python/python_ref_manager.h index 8be19336a89..054150faf25 100644 --- a/tensorflow/compiler/xla/python/python_ref_manager.h +++ b/tensorflow/compiler/xla/python/python_ref_manager.h @@ -74,6 +74,11 @@ class PythonRefManager { std::deque python_garbage_ GUARDED_BY(mu_); }; +// A global PythonRefManager. Unless `CollectGarbage()` is called before +// shutdown, this container will hold on to Python objects and thus cause a +// leak. This behavior is similar to `tensorflow::ClearDecRefCache()`. +PythonRefManager* GlobalPyRefManager(); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PYTHON_PYTHON_REF_MANAGER_H_ diff --git a/tensorflow/compiler/xla/python/types.h b/tensorflow/compiler/xla/python/types.h index bc0ee2b19b4..1873249b07c 100644 --- a/tensorflow/compiler/xla/python/types.h +++ b/tensorflow/compiler/xla/python/types.h @@ -104,7 +104,7 @@ struct type_caster> { using value_conv = make_caster; PYBIND11_TYPE_CASTER(absl::Span, - _("Span[") + value_conv::name() + _("]")); + _("Span[") + value_conv::name + _("]")); // absl::Span doesn't hold ownership. We therefore need a temporary array. // Pybind appears to keep type_casters alive until the callee has run. 
@@ -151,7 +151,7 @@ struct type_caster> { using value_conv = make_caster; PYBIND11_TYPE_CASTER(xla::StatusOr, - _("StatusOr[") + value_conv::name() + _("]")); + _("StatusOr[") + value_conv::name + _("]")); static handle cast(xla::StatusOr src, return_value_policy policy, handle parent) { diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index 172e24f801e..078fee8f652 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/python/local_client.h" +#include "tensorflow/compiler/xla/python/python_ref_manager.h" #include "tensorflow/compiler/xla/python/types.h" #include "tensorflow/compiler/xla/python/xrt.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" @@ -109,18 +110,23 @@ StatusOr GetComputationHloDotGraph( } // Registers a 'fn_capsule' as a CPU custom call target. -// 'fn_capsule' is a void* pointer encapsulated in a PyCapsule object, with name -// "xla._CPU_CUSTOM_CALL_TARGET". -Status RegisterCpuCustomCallTarget(const std::string& fn_name, - py::capsule capsule) { - static const char* const kName = "xla._CPU_CUSTOM_CALL_TARGET"; - if (absl::string_view(capsule.name()) != kName) { +// 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object, +// with name "xla._CUSTOM_CALL_TARGET". +// 'platform' is an XLA platform name, e.g., "Host" or "CUDA". +Status PyRegisterCustomCallTarget(const std::string& fn_name, + py::capsule capsule, + const std::string& platform) { + static const char* const kName = "xla._CUSTOM_CALL_TARGET"; + // TODO(phawkins): remove old name after fixing users. + static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET"; + if (absl::string_view(capsule.name()) != kName && + absl::string_view(capsule.name()) != kOldCpuName) { return InvalidArgument( - "Argument to RegisterCpuCustomCallTargetRegistry was not a " - "xla._CPU_CUSTOM_CALL_TARGET capsule."); + "Argument to RegisterCustomCallTargetRegistry was not a " + "xla._CUSTOM_CALL_TARGET capsule."); } CustomCallTargetRegistry::Global()->Register( - fn_name, static_cast(capsule), "Host"); + fn_name, static_cast(capsule), platform); return Status::OK(); } @@ -292,10 +298,34 @@ PYBIND11_MODULE(xla_extension, m) { .def("computation_count", &DeviceAssignment::computation_count) .def("__repr__", &DeviceAssignment::ToString); + py::class_>( + m, "Device", + "A descriptor of an available device.\n\nSubclasses are used to " + "represent specific types of devices, e.g. CPUs, GPUs. Subclasses may " + "have additional properties specific to that device type.") + .def_property_readonly( + "id", &Device::id, + "Integer ID of this device.\n\nUnique across all available devices " + "of this type, including remote devices on multi-host platforms.") + .def_property_readonly("host_id", &Device::host_id, + "Integer ID of this device's host.\n\n" + "This is always 0 except on multi-host platforms.") + .def("__str__", &Device::DebugString); + + py::class_>(m, "CpuDevice") + .def("__repr__", [](const CpuDevice& device) { + return absl::StrFormat("CpuDevice(id=%i)", device.id()); + }); + + py::class_>(m, "GpuDevice") + .def("__repr__", [](const GpuDevice& device) { + return absl::StrFormat("GpuDevice(id=%i)", device.id()); + }); + // Local XLA client methods. - // CPU custom-call targets. 
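[Editor's note] A hedged sketch of how the Device, CpuDevice, and GpuDevice descriptors bound above might be enumerated from Python; it assumes get_local_backend() is the existing xla_client helper that returns the LocalBackend wrapping this client:

from tensorflow.compiler.xla.python import xla_client

backend = xla_client.get_local_backend("cpu")  # assumed existing helper
print(backend.device_count(), backend.local_device_count(), backend.host_id())
for device in backend.devices():
  # Each entry is a Device subclass such as CpuDevice(id=0); id is unique
  # across all hosts, while host_id is 0 except on multi-host platforms.
  print(device.id, device.host_id, str(device))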
- m.def("RegisterCpuCustomCallTarget", &RegisterCpuCustomCallTarget); + // Custom-call targets. + m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget); py::class_ alloc_config(m, "AllocatorConfig"); alloc_config.def(py::init<>()) @@ -311,21 +341,84 @@ PYBIND11_MODULE(xla_extension, m) { .def_static("Get", &PyLocalClient::Get, py::arg("platform"), py::arg("xla_platform_id"), py::arg("asynchronous"), py::arg("allocator_config") = AllocatorConfig()) - .def("DeviceCount", &PyLocalClient::device_count) - .def("TransferToInfeed", &PyLocalClient::TransferToInfeed) - .def("TransferFromOutfeed", &PyLocalClient::TransferFromOutfeed); + .def("device_count", &PyLocalClient::device_count) + .def("local_device_count", &PyLocalClient::local_device_count) + .def("devices", &PyLocalClient::devices) + .def("host_id", &PyLocalClient::host_id) + .def("TransferToInfeed", + [](PyLocalClient* client, const LiteralSlice& literal, + int device_ordinal) { + GlobalPyRefManager()->CollectGarbage(); + py::gil_scoped_release gil_release; + return client->TransferToInfeed(literal, device_ordinal); + }) + .def("TransferFromOutfeed", + [](PyLocalClient* client, const Shape& shape, + int device_ordinal) -> StatusOr { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal_shared; + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(Literal literal, client->TransferFromOutfeed( + shape, device_ordinal)); + literal_shared = std::make_shared(std::move(literal)); + } + return LiteralToPython(std::move(literal_shared)); + }); py::class_(m, "PyLocalBuffer") - .def_static("from_python", &PyLocalBuffer::FromPython) + .def_static( + "from_python", + [](const pybind11::object& argument, + std::shared_ptr client, + int device_ordinal) -> StatusOr> { + GlobalPyRefManager()->CollectGarbage(); + TF_ASSIGN_OR_RETURN(PythonBufferTree tree, + GetPythonBufferTree(argument)); + std::shared_ptr py_buffer_ref = + GlobalPyRefManager()->ManageReferences( + absl::MakeSpan(tree.arrays)); + tree.arrays.clear(); + + std::vector leaves; + leaves.insert(leaves.end(), + std::make_move_iterator(tree.leaves.begin()), + std::make_move_iterator(tree.leaves.end())); + + py::gil_scoped_release gil_release; + return PyLocalBuffer::FromLiterals( + std::move(leaves), tree.shape, std::move(py_buffer_ref), + std::move(client), device_ordinal); + }) .def_static("make_tuple", &PyLocalBuffer::MakeTuple) - .def("copy_to_device", &PyLocalBuffer::CopyToDevice) + .def("copy_to_device", + [](PyLocalBuffer* buffer, int dst_device_ordinal) { + GlobalPyRefManager()->CollectGarbage(); + py::gil_scoped_release gil_release; + return buffer->CopyToDevice(dst_device_ordinal); + }) .def("delete", &PyLocalBuffer::Delete) .def("destructure", &PyLocalBuffer::DestructureTuple) - .def("block_host_until_ready", &PyLocalBuffer::BlockHostUntilReady) + .def("block_host_until_ready", + [](PyLocalBuffer* buffer) { + GlobalPyRefManager()->CollectGarbage(); + py::gil_scoped_release gil_release; + return buffer->BlockHostUntilReady(); + }) .def("copy_to_host_async", &PyLocalBuffer::CopyToHostAsync) - .def("to_py", &PyLocalBuffer::ToPython) + .def("to_py", + [](PyLocalBuffer* buffer) -> StatusOr { + GlobalPyRefManager()->CollectGarbage(); + std::shared_ptr literal; + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(literal, buffer->ToLiteral()); + } + return LiteralToPython(std::move(literal)); + }) .def("shape", &PyLocalBuffer::on_host_shape) .def("device", &PyLocalBuffer::device_ordinal) + .def("platform", &PyLocalBuffer::platform_name) 
.def("is_deleted", [](const PyLocalBuffer& buffer) { return buffer.DeviceBuffer() == nullptr; @@ -347,6 +440,8 @@ PYBIND11_MODULE(xla_extension, m) { .def_static("Compile", &PyLocalExecutable::Compile, py::call_guard()) .def("DeviceOrdinals", &PyLocalExecutable::DeviceOrdinals) + .def("SizeOfGeneratedCodeInBytes", + &PyLocalExecutable::SizeOfGeneratedCodeInBytes) .def("Delete", &PyLocalExecutable::Delete) .def("Execute", &PyLocalExecutable::Execute, py::call_guard(), py::arg("arguments")) @@ -365,7 +460,13 @@ PYBIND11_MODULE(xla_extension, m) { &DebugOptions::set_xla_cpu_fast_math_honor_nans) .def_property("xla_cpu_fast_math_honor_division", &DebugOptions::xla_cpu_fast_math_honor_division, - &DebugOptions::set_xla_cpu_fast_math_honor_division); + &DebugOptions::set_xla_cpu_fast_math_honor_division) + .def_property("xla_cpu_fast_math_honor_functions", + &DebugOptions::xla_cpu_fast_math_honor_functions, + &DebugOptions::set_xla_cpu_fast_math_honor_functions) + .def_property("xla_gpu_enable_fast_min_max", + &DebugOptions::xla_gpu_enable_fast_min_max, + &DebugOptions::set_xla_gpu_enable_fast_min_max); py::class_(m, "ExecutableBuildOptions") .def(py::init<>()) @@ -473,7 +574,8 @@ PYBIND11_MODULE(xla_extension, m) { .value("IRFFT", FftType::IRFFT); ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), - py::arg("dimension_numbers"), py::arg("slice_sizes")); + py::arg("dimension_numbers"), py::arg("slice_sizes"), + py::arg("indices_are_sorted")); ops.def("GetTupleElement", &GetTupleElement); ops.def("Infeed", &Infeed, py::arg("builder"), py::arg("shape"), py::arg("config") = ""); @@ -533,20 +635,26 @@ PYBIND11_MODULE(xla_extension, m) { py::arg("limit_index"), py::arg("stride"), py::arg("dimno")); ops.def( "Sort", - [](XlaBuilder* builder, absl::Span operands, - int64 dimension) -> XlaOp { + [](XlaBuilder* builder, absl::Span operands, int64 dimension, + absl::optional comparator) -> XlaOp { return builder->ReportErrorOrReturn([&]() -> StatusOr { std::vector operand_types; for (const auto& operand : operands) { TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand)); operand_types.push_back(operand_shape.element_type()); } - return Sort(operands, - CreateScalarLtComputation(operand_types, builder), - dimension); + + if (comparator) { + return Sort(operands, **comparator, dimension); + } else { + return Sort(operands, + CreateScalarLtComputation(operand_types, builder), + dimension); + } }); }, - py::arg("builder"), py::arg("operands"), py::arg("dimension") = -1); + py::arg("builder"), py::arg("operands"), py::arg("dimension") = -1, + py::arg("comparator") = absl::nullopt); ops.def("Transpose", &Transpose); ops.def("TriangularSolve", &TriangularSolve); ops.def("Tuple", &Tuple); @@ -640,6 +748,6 @@ PYBIND11_MODULE(xla_extension, m) { py::class_(m, "ChannelHandle"); tensorflow::AddXrtSubmodule(&m); -} +} // NOLINT(readability/fn_size) } // namespace xla diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 7e5692fef30..63a9ea37692 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -59,6 +59,18 @@ class Backend(object): def device_count(self): """Returns the number of devices known to the backend.""" + @abc.abstractmethod + def local_device_count(self): + """Returns the number of devices local to this host.""" + + @abc.abstractmethod + def devices(self): + """Returns a list of `device_count()` Device subclasses.""" + + @abc.abstractmethod + def host_id(self): 
+ """Returns the integer ID of this host.""" + @abc.abstractmethod def buffer_from_pyval(self, pyval, device=0): """Allocates a fresh buffer and populates it with `pyval`.""" @@ -93,7 +105,16 @@ class LocalBackend(Backend): self.client = client def device_count(self): - return self.client.DeviceCount() + return self.client.device_count() + + def local_device_count(self): + return self.client.local_device_count() + + def devices(self): + return self.client.devices() + + def host_id(self): + return self.client.host_id() def buffer_from_pyval(self, pyval, device=0): return _xla.PyLocalBuffer.from_python(pyval, self.client, device) @@ -109,15 +130,25 @@ class LocalBackend(Backend): options.debug_options.xla_cpu_fast_math_honor_infs = True options.debug_options.xla_cpu_fast_math_honor_nans = True options.debug_options.xla_cpu_fast_math_honor_division = True + options.debug_options.xla_cpu_fast_math_honor_functions = True + options.debug_options.xla_gpu_enable_fast_min_max = False return _xla.LocalExecutable.Compile(c_computation, compile_options.argument_layouts, options, self.client, compile_options.device_assignment) +xla_platform_names = { + 'cpu': 'Host', + 'gpu': 'CUDA', +} + + def _cpu_backend_factory(): client = _xla.LocalClient.Get( - platform='cpu', xla_platform_id='Host', asynchronous=True) + platform='cpu', + xla_platform_id=xla_platform_names['cpu'], + asynchronous=True) return LocalBackend(platform='cpu', client=client) @@ -142,7 +173,9 @@ def _gpu_backend_factory(): config.preallocate = preallocate not in ('0', 'false', 'False') client = _xla.LocalClient.Get( - platform='gpu', xla_platform_id='CUDA', asynchronous=True, + platform='gpu', + xla_platform_id=xla_platform_names['gpu'], + asynchronous=True, allocator_config=config) return LocalBackend(platform='gpu', client=client) @@ -449,6 +482,9 @@ def computation_count(): """ +Device = _xla.Device + + class CompileOptions(object): """Python object for XLA compile options. @@ -544,6 +580,9 @@ class Computation(object): # def Execute(self, arguments : [Buffer]) -> Buffer: # """Execute on one replica with Buffer arguments and return value.""" # +# def SizeOfGeneratedCodeInBytes(self) -> int: +# """Return generated binary size, or -1 if not known.""" +# # def ExecutePerReplica(self, arguments: [[Buffer]]) -> [Buffer]: # """Execute on many replicas with Buffer arguments and return value. # @@ -1431,12 +1470,31 @@ class ComputationBuilder(object): batch_group_count, precision_config=precision_config) - def Sort(self, operand, dimension=-1): - """Enqueues a sort operation onto the computation.""" - return ops.Sort(self._builder, [operand], dimension) + def Sort(self, operands, dimension=-1, comparator=None): + """Enqueues a sort operation onto the computation. + + Args: + operands: either an XlaOp or a sequence of XlaOps to sort. All operands + must be arrays with the same dimensions. + dimension: the array dimension over which to sort. + comparator: a comparator XlaComputation. See the XLA operation semantics + for details. + + Returns: + Either an XlaOp or a tuple of XlaOps (if `operands` was an XlaOp or + a tuple of XlaOps, respectively.) + """ + operands = ( + list(operands) + if isinstance(operands, collections.Sequence) else [operands]) + return ops.Sort(self._builder, operands, dimension, + comparator.computation if comparator else None) def SortKeyVal(self, keys, values, dimension=-1): - """Enqueues a key-value sort operation onto the computation.""" + """Enqueues a key-value sort operation onto the computation. + + Deprecated. 
Use `Sort` instead. + """ return ops.Sort(self._builder, [keys, values], dimension) def QR(self, a, full_matrices=True): @@ -1470,11 +1528,27 @@ class ComputationBuilder(object): """Enqueues a singular value decomposition.""" return self.Tuple(*ops.SVD(a)) - def Scatter(self, a, scatter_indices, updates, update_computation, - dimension_numbers): + def Gather(self, + a, + start_indices, + dimension_numbers, + slice_sizes, + indices_are_sorted=False): + """Enqueues a Gather operation onto the computation.""" + return ops.Gather(a, start_indices, dimension_numbers, slice_sizes, + indices_are_sorted) + + def Scatter(self, + a, + scatter_indices, + updates, + update_computation, + dimension_numbers, + indices_are_sorted=False): """Enqueues a Scatter operation onto the computation.""" return ops.Scatter(a, scatter_indices, updates, - update_computation.computation, dimension_numbers) + update_computation.computation, dimension_numbers, + indices_are_sorted) def Fft(self, operand, fft_type, fft_lengths): """Enqueues a FFT operation onto the computation.""" @@ -1558,7 +1632,6 @@ _OTHER_OPS = [ 'CollectivePermute', 'ConvertElementType', 'Dot', - 'Gather', 'GetTupleElement', 'ReducePrecision', 'Rev', @@ -1592,14 +1665,18 @@ def _forward_methods_to_local_builder(): _forward_methods_to_local_builder() -def register_cpu_custom_call_target(name, fn): - """Registers a CPU custom call target. +def register_custom_call_target(name, fn, platform='cpu'): + """Registers a custom call target. Args: name: bytes containing the name of the function. fn: a PyCapsule object containing the function pointer. + platform: the target platform. """ - _xla.RegisterCpuCustomCallTarget(name, fn) + _xla.RegisterCustomCallTarget(name, fn, xla_platform_names[platform]) + +# Deprecated. Use register_custom_call_target instead. 
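# [Editor's sketch, not part of the patch] Registering a custom call through the
# new platform-aware entry point, mirroring what xla_client_test.py does below:
#
#   from tensorflow.compiler.xla.python import custom_call_for_test
#   for name, fn in custom_call_for_test.cpu_custom_call_targets.items():
#     register_custom_call_target(name, fn, platform='cpu')
#
# The alias that follows keeps the old register_cpu_custom_call_target(name, fn)
# spelling working for existing callers.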
+register_cpu_custom_call_target = register_custom_call_target class PaddingConfigDimension(object): diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 16c1d4237a6..257e02ceec3 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -22,14 +22,14 @@ import functools import itertools import threading +from absl.testing import absltest import numpy as np from tensorflow.compiler.xla.python import custom_call_for_test from tensorflow.compiler.xla.python import xla_client -import unittest -class ComputationTest(unittest.TestCase): +class ComputationTest(absltest.TestCase): """Base class for running an XLA Computation through the local client.""" def _NewComputation(self, name=None): @@ -89,7 +89,7 @@ def NumpyArrayBool(*args, **kwargs): return np.array(*args, dtype=np.bool, **kwargs) -class ComputationPrinting(unittest.TestCase): +class ComputationPrinting(absltest.TestCase): def ExampleComputation(self): builder = xla_client.ComputationBuilder("acomputation") @@ -311,7 +311,7 @@ class ComputationsWithConstantsTest(ComputationTest): def testCustomCall(self): c = self._NewComputation() for name, fn in custom_call_for_test.cpu_custom_call_targets.items(): - xla_client.register_cpu_custom_call_target(name, fn) + xla_client.register_custom_call_target(name, fn, platform="cpu") c.CustomCall( b"test_subtract_f32", operands=(c.ConstantF32Scalar(1.25), c.ConstantF32Scalar(0.5)), @@ -448,14 +448,14 @@ class BufferTest(ComputationTest): local_buffer = xla_client.Buffer.from_pyval(t) pieces = local_buffer.destructure() self.assertFalse(local_buffer.is_deleted()) - self.assertEqual(len(pieces), 0) + self.assertEmpty(pieces) def testDestructureTupleOneArrayElement(self): t = (np.array([1, 2, 3, 4], dtype=np.int32),) local_buffer = xla_client.Buffer.from_pyval(t) pieces = local_buffer.destructure() self.assertFalse(local_buffer.is_deleted()) - self.assertEqual(len(pieces), 1) + self.assertLen(pieces, 1) array = pieces[0] got = array.to_py() want = NumpyArrayS32([1, 2, 3, 4]) @@ -472,7 +472,7 @@ class BufferTest(ComputationTest): for _ in range(2): pieces = local_buffer.destructure() self.assertFalse(local_buffer.is_deleted()) - self.assertEqual(len(pieces), 2) + self.assertLen(pieces, 2) array0, array1 = pieces got = array0.to_py() want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0]) @@ -486,14 +486,14 @@ class BufferTest(ComputationTest): local_buffer = xla_client.Buffer.from_pyval(t) pieces = local_buffer.destructure() self.assertFalse(local_buffer.is_deleted()) - self.assertEqual(len(pieces), 2) + self.assertLen(pieces, 2) tuple0, array1 = pieces got = array1.to_py() want = NumpyArrayS32([5]) np.testing.assert_equal(want, got) got = tuple0.to_py() self.assertEqual(type(got), tuple) - self.assertEqual(len(got), 2) + self.assertLen(got, 2) np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) @@ -506,7 +506,7 @@ class BufferTest(ComputationTest): b1 = xla_client.Buffer.from_pyval(t[1]) btup = xla_client.Buffer.make_tuple([b0, b1], device=0) pieces = btup.destructure() - self.assertEqual(len(pieces), 2) + self.assertLen(pieces, 2) array0, array1 = pieces np.testing.assert_equal( np.array([1, 2, 3, 4], dtype=np.float32), array0.to_py()) @@ -699,7 +699,7 @@ class SingleOpTest(ComputationTest): rhs = NumpyArrayF32(rng.randn(10, 4, 5)) dimension_numbers = (([2], [1]), ([0], [0])) c.DotGeneral(c.Constant(lhs), c.Constant(rhs), 
dimension_numbers) - self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) + self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6) def testDotGeneralWithDotDimensionNumbersProto(self): c = self._NewComputation() @@ -714,7 +714,7 @@ class SingleOpTest(ComputationTest): dimension_numbers.rhs_batch_dimensions.append(0) c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) - self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) + self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6) def testDotGeneralWithPrecisionConfig(self): c = self._NewComputation() @@ -730,7 +730,7 @@ class SingleOpTest(ComputationTest): c.Constant(rhs), dimension_numbers, precision_config=config) - self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) + self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs), rtol=1e-6) def testConvF32Same(self): c = self._NewComputation() @@ -1222,7 +1222,7 @@ class SingleOpTest(ComputationTest): result = xla_client.execute_with_python_values(c.Build().Compile()) # since the result is random, we just check shape and uniqueness self.assertEqual(result.shape, shape) - self.assertEqual(len(np.unique(result)), np.prod(shape)) + self.assertLen(np.unique(result), np.prod(shape)) def testRngUniformF32(self): lo, hi = 2., 4. @@ -1235,7 +1235,7 @@ class SingleOpTest(ComputationTest): result = xla_client.execute_with_python_values(c.Build().Compile()) # since the result is random, we just check shape, uniqueness, and range self.assertEqual(result.shape, shape) - self.assertEqual(len(np.unique(result)), np.prod(shape)) + self.assertLen(np.unique(result), np.prod(shape)) self.assertTrue(np.all(lo <= result)) self.assertTrue(np.all(result < hi)) @@ -1272,12 +1272,32 @@ class SingleOpTest(ComputationTest): keys = np.array([[2, 4, 1, 3], [3, 1, 4, 2]], dtype=np.float32) values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) c = self._NewComputation() - c.SortKeyVal(c.Constant(keys), c.Constant(values), dimension=0) + c.Sort((c.Constant(keys), c.Constant(values)), dimension=0) result = xla_client.execute_with_python_values(c.Build().Compile()) self.assertIsInstance(result, tuple) np.testing.assert_allclose(result[0], [[2, 1, 1, 2], [3, 4, 4, 3]]) np.testing.assert_equal(result[1], [[0, 5, 2, 7], [4, 1, 6, 3]]) + def testSortCustomComparator(self): + b = self._NewComputation("comparator") + p0 = b.ParameterFromNumpy(NumpyArrayF32(0)) + q0 = b.ParameterFromNumpy(NumpyArrayF32(0)) + p1 = b.ParameterFromNumpy(NumpyArrayS32(0)) + q1 = b.ParameterFromNumpy(NumpyArrayS32(0)) + b.Or(b.Lt(p0, q0), b.And(b.Eq(p0, q0), b.Gt(p1, q1))) + comparator = b.Build() + + keys = np.array([[2, 3, 1, 3], [3, 1, 2, 2]], dtype=np.float32) + values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) + c = self._NewComputation() + c.Sort((c.Constant(keys), c.Constant(values)), + dimension=1, + comparator=comparator) + result = xla_client.execute_with_python_values(c.Build().Compile()) + self.assertIsInstance(result, tuple) + np.testing.assert_allclose(result[0], [[1, 2, 3, 3], [1, 2, 2, 3]]) + np.testing.assert_equal(result[1], [[2, 0, 3, 1], [5, 7, 6, 4]]) + def testQR(self): a = np.array( [[4, 6, 8, 10], [6, 45, 54, 63], [8, 54, 146, 166], [10, 63, 166, 310]], @@ -1923,4 +1943,4 @@ class ComputationRootTest(ComputationTest): if __name__ == "__main__": - unittest.main() + absltest.main() diff --git a/tensorflow/compiler/xla/python/xrt.py b/tensorflow/compiler/xla/python/xrt.py index 40dea45e442..7ab2afa19d4 100644 --- 
a/tensorflow/compiler/xla/python/xrt.py +++ b/tensorflow/compiler/xla/python/xrt.py @@ -61,6 +61,15 @@ class XrtBackend(xla_client.Backend): def device_count(self): return self.context.DeviceCount() + def local_device_count(self): + raise NotImplementedError() + + def devices(self): + raise NotImplementedError() + + def host_id(self): + raise NotImplementedError() + def buffer_from_pyval(self, pyval, device=0): return _xla.xrt.XrtBuffer.from_literal(self.context, device, pyval) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD old mode 100644 new mode 100755 index c4af8863c05..c14048a18d6 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -4,10 +4,18 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library_py", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core/platform:default/cuda_build_defs.bzl", + "if_cuda_is_configured", +) +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", +) package( default_visibility = [":friends"], @@ -290,6 +298,64 @@ cc_library( ], ) +cc_library( + name = "hlo_live_range", + srcs = [ + "hlo_live_range.cc", + ], + hdrs = [ + "hlo_live_range.h", + ], + deps = [ + ":hlo", + ":hlo_alias_analysis", + ":hlo_buffer", + ":hlo_dataflow_analysis", + ":hlo_ordering", + ":logical_buffer", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +tf_cc_test( + name = "hlo_live_range_test", + srcs = ["hlo_live_range_test.cc"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_alias_analysis", + ":hlo_live_range", + ":hlo_memory_scheduler", + ":hlo_ordering", + ":hlo_parser", + ":hlo_value", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", + ], +) + tf_cc_test( name = "hlo_evaluator_test", srcs = ["hlo_evaluator_test.cc"], @@ -565,8 +631,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla/service/gpu:backend_configs", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -860,10 +928,25 @@ cc_library( name = "gpu_plugin", deps = [ ":service", + "//tensorflow/compiler/xla/service/gpu:gpu_compiler", "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", - 
"//tensorflow/compiler/xla/service/gpu:nvptx_compiler", "//tensorflow/core:stream_executor_no_cuda", + ] + if_cuda_is_configured([ + "//tensorflow/compiler/xla/service/gpu:nvptx_compiler", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", + ]) + if_rocm_is_configured([ + "//tensorflow/compiler/xla/service/gpu:amdgpu_compiler", + "//tensorflow/core/platform/default/build_config:stream_executor_rocm", + ]), +) + +cc_library( + name = "mlir_gpu_plugin", + deps = [ + ":service", + "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", + "//tensorflow/compiler/xla/service/mlir_gpu:mlir_compiler", + "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -950,6 +1033,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", + "//tensorflow/stream_executor:device_description", "//tensorflow/stream_executor:device_memory_allocator", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings:str_format", @@ -1111,8 +1195,10 @@ cc_library( ":hlo_alias_analysis", ":hlo_buffer", ":hlo_dataflow_analysis", + ":hlo_live_range", ":hlo_proto", ":logical_buffer", + ":memory_space_assignment", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1213,6 +1299,7 @@ cc_library( ":hlo_alias_analysis", ":hlo_buffer", ":hlo_dataflow_analysis", + ":hlo_live_range", ":hlo_ordering", ":hlo_proto", ":tuple_points_to_analysis", @@ -1424,6 +1511,7 @@ cc_library( hdrs = ["fusion_queue.h"], deps = [ ":hlo", + "@com_google_absl//absl/strings", ], ) @@ -1679,6 +1767,7 @@ cc_library( ":hlo", ":hlo_casting_utils", ":hlo_creation_utils", + ":hlo_evaluator", ":hlo_pass", ":hlo_query", ":pattern_matcher", @@ -1692,6 +1781,39 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "tree_reduction_rewriter", + srcs = ["tree_reduction_rewriter.cc"], + hdrs = ["tree_reduction_rewriter.h"], + deps = [ + ":hlo", + ":hlo_casting_utils", + ":hlo_creation_utils", + ":hlo_evaluator", + ":hlo_pass", + ":shape_inference", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/core:lib", + "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1891,6 +2013,41 @@ tf_cc_test( ], ) +cc_library( + name = "depthwise_convolution_converter", + srcs = ["depthwise_convolution_converter.cc"], + hdrs = ["depthwise_convolution_converter.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + 
"//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "depthwise_convolution_converter_test", + size = "small", + srcs = ["depthwise_convolution_converter_test.cc"], + deps = [ + ":depthwise_convolution_converter", + ":hlo", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + cc_library( name = "while_loop_analysis", srcs = ["while_loop_analysis.cc"], @@ -2096,13 +2253,14 @@ cc_library( hdrs = ["dynamic_dimension_inference.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":while_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:window_util", - "//tensorflow/core:lib", + "//tensorflow/core/platform:macros", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", ], @@ -2782,6 +2940,30 @@ tf_cc_test( ], ) +cc_library( + name = "memory_space_assignment", + srcs = ["memory_space_assignment.cc"], + hdrs = ["memory_space_assignment.h"], + deps = [ + ":heap_simulator", + ":hlo_pass", + ], +) + +tf_cc_test( + name = "memory_space_assignment_test", + srcs = ["memory_space_assignment_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":memory_space_assignment", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_dce", srcs = ["hlo_dce.cc"], @@ -4221,3 +4403,18 @@ cc_library( "//tensorflow/compiler/xla/client/lib:prng", ], ) + +cc_library( + name = "slow_operation_alarm", + srcs = ["slow_operation_alarm.cc"], + hdrs = ["slow_operation_alarm.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index eef570e2540..077b76c4c64 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -60,6 +61,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/lib/statusor.h" namespace xla { @@ -170,6 +172,10 @@ bool IsUnstridedSlice(const HloInstruction* hlo) { // more general case a worklist based approach would be needed. 
class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { public: + explicit AlgebraicSimplifierVisitor(const AlgebraicSimplifierOptions& options, + AlgebraicSimplifier* simplifier) + : options_(options), simplifier_(simplifier) {} + Status HandleAdd(HloInstruction* add) override; Status HandleAnd(HloInstruction* logical_and) override; @@ -204,10 +210,18 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { Status HandleDot(HloInstruction* dot) override; + Status HandleGather(HloInstruction* gather) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleLog(HloInstruction* log) override; + Status HandleMaximum(HloInstruction* maximum) override; + + Status HandleMinimum(HloInstruction* minimum) override; + + Status HandleClamp(HloInstruction* clamp) override; + Status HandleMultiply(HloInstruction* multiply) override; Status HandleNegate(HloInstruction* negate) override; @@ -224,7 +238,7 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { Status HandleReshape(HloInstruction* reshape) override; - Status HandleReduce(HloInstruction* reduce) override; + Status HandleReduce(HloInstruction* hlo) override; Status HandleReduceWindow(HloInstruction* reduce_window) override; @@ -246,16 +260,11 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { Status HandleMap(HloInstruction* map) override; // Runs the visitor on a computation. - static bool Run(HloComputation* computation, - const AlgebraicSimplifierOptions& options, - AlgebraicSimplifier* simplifier); + bool Run(HloComputation* computation, + const AlgebraicSimplifierOptions& options, + AlgebraicSimplifier* simplifier); private: - explicit AlgebraicSimplifierVisitor(HloComputation* computation, - const AlgebraicSimplifierOptions& options, - AlgebraicSimplifier* simplifier) - : computation_(computation), options_(options), simplifier_(simplifier) {} - // Removes degenerate dimension from dot. StatusOr RemoveDegenerateDimensionFromDot(HloInstruction* dot); @@ -385,6 +394,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Tries to convert slice(reshape(X)) into reshape(slice(X)) StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); + // Useful when we want to use the same visitor over multiple computations. + void ResetState(HloComputation* computation); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. 
HloComputation* computation_; @@ -403,12 +415,18 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { } // namespace +void AlgebraicSimplifierVisitor::ResetState(HloComputation* computation) { + changed_ = false; + ResetVisitStates(); + computation_ = computation; +} + bool AlgebraicSimplifierVisitor::Run(HloComputation* computation, const AlgebraicSimplifierOptions& options, AlgebraicSimplifier* simplifier) { - AlgebraicSimplifierVisitor visitor(computation, options, simplifier); - TF_CHECK_OK(computation->Accept(&visitor)); - return visitor.changed_ || visitor.changed(); + ResetState(computation); + TF_CHECK_OK(computation->Accept(this)); + return changed_ || changed(); } bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, @@ -431,8 +449,8 @@ void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction, CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()), ShapeUtil::ByteSizeOf(operand->shape())); - auto bitcast = computation_->AddInstruction(HloInstruction::CreateUnary( - instruction->shape(), HloOpcode::kBitcast, operand)); + auto bitcast = computation_->AddInstruction( + HloInstruction::CreateBitcast(instruction->shape(), operand)); TF_CHECK_OK(ReplaceInstruction(instruction, bitcast)); } @@ -573,8 +591,7 @@ Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { HloInstruction* op; if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) { return ReplaceWithNewInstruction( - bitcast, - HloInstruction::CreateUnary(bitcast->shape(), HloOpcode::kBitcast, op)); + bitcast, HloInstruction::CreateBitcast(bitcast->shape(), op)); } // All bitcasts can be eliminated (assuming layout constraints are // satisified). @@ -1875,6 +1892,175 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) { + const Shape& operand_shape = gather->operand(0)->shape(); + // If the operand of a gather is very small, it is easier to fuse a + // sequence of selects. 
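  // Editor's note: with the default very_small_gather_size of 4, gathering from
  // an f32[4] operand unrolls into three selects, each taking the next operand
  // element wherever the gather indices are at least the running index i and
  // keeping the previously accumulated value elsewhere.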
+ if (operand_shape.rank() == 1 && + operand_shape.dimensions(0) <= options_.very_small_gather_size() && + gather->gather_dimension_numbers().index_vector_dim() == + gather->operand(1)->shape().rank() && + gather->gather_dimension_numbers().collapsed_slice_dims_size() == 1) { + const Shape& index_shape = gather->operand(1)->shape(); + const int64 operand_elements = operand_shape.dimensions(0); + auto get_value = [&](int64 i) { + auto slice = computation_->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(operand_shape.element_type(), {1}), + gather->mutable_operand(0), {i}, {i + 1}, {1})); + auto scalar = computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(operand_shape.element_type(), {}), slice)); + return computation_->AddInstruction( + HloInstruction::CreateBroadcast(gather->shape(), scalar, {})); + }; + auto result = get_value(0); + auto one = computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::One(index_shape.element_type()))); + auto index = one; + auto pred_shape = ShapeUtil::ChangeElementType(gather->shape(), PRED); + auto iter_shape = ShapeUtil::ChangeElementType(gather->shape(), + index_shape.element_type()); + for (int64 i = 1; i < operand_elements; ++i) { + auto broadcasted_index = computation_->AddInstruction( + HloInstruction::CreateBroadcast(iter_shape, index, {})); + auto index_mask = + computation_->AddInstruction(HloInstruction::CreateCompare( + pred_shape, gather->mutable_operand(1), broadcasted_index, + ComparisonDirection::kGe)); + result = computation_->AddInstruction( + HloInstruction::CreateTernary(gather->shape(), HloOpcode::kSelect, + index_mask, get_value(i), result)); + index = computation_->AddInstruction(HloInstruction::CreateBinary( + index->shape(), HloOpcode::kAdd, index, one)); + } + return ReplaceInstruction(gather, result); + } + return Status::OK(); +} + +namespace { +StatusOr> MinMaxToClamp( + HloInstruction* clamp_lower_bound_bcast, HloInstruction* to_clamp, + HloInstruction* clamp_upper_bound_bcast) { + HloInstruction* clamp_lower_bound; + CHECK(Match(clamp_lower_bound_bcast, + m::Broadcast(m::ConstantEffectiveScalar(&clamp_lower_bound)))) + << clamp_lower_bound_bcast->ToString(); + + HloInstruction* clamp_upper_bound; + CHECK(Match(clamp_upper_bound_bcast, + m::Broadcast(m::ConstantEffectiveScalar(&clamp_upper_bound)))) + << clamp_upper_bound_bcast->ToString(); + + const Literal& lower_bound = + Cast(clamp_lower_bound)->literal(); + const Literal& upper_bound = + Cast(clamp_upper_bound)->literal(); + + std::unique_ptr lower_bound_instr = + HloInstruction::CreateConstant(lower_bound.Clone()); + std::unique_ptr upper_bound_instr = + HloInstruction::CreateConstant(upper_bound.Clone()); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(lower_bound_instr->shape(), PRED), + lower_bound_instr.get(), upper_bound_instr.get(), + ComparisonDirection::kLt); + + HloEvaluator evaluator; + TF_ASSIGN_OR_RETURN(auto result, + evaluator.Evaluate(cloned_instruction.get())); + if (result.IsAll(true)) { + return HloInstruction::CreateTernary(to_clamp->shape(), HloOpcode::kClamp, + clamp_lower_bound_bcast, to_clamp, + clamp_upper_bound_bcast); + } + return std::unique_ptr(); +} +} // namespace + +Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { + HloInstruction *lhs, *rhs; + CHECK(Match(maximum, m::Maximum(m::Op(&lhs), m::Op(&rhs)))); + + HloInstruction* clamp_upper_bound_bcast; + HloInstruction* 
clamp_lower_bound_bcast; + HloInstruction* to_clamp; + if (Match(maximum, m::MaximumAnyOrder( + m::Broadcast(&clamp_lower_bound_bcast, + m::ConstantEffectiveScalar()), + m::MinimumAnyOrder( + m::Op(&to_clamp), + m::Broadcast(&clamp_upper_bound_bcast, + m::ConstantEffectiveScalar()))))) { + TF_ASSIGN_OR_RETURN(auto clamp, + MinMaxToClamp(clamp_lower_bound_bcast, to_clamp, + clamp_upper_bound_bcast)); + if (clamp) { + return ReplaceWithNewInstruction(maximum, std::move(clamp)); + } + } + + HloInstruction* clamp_lower_bound; + HloInstruction* clamp_upper_bound; + HloInstruction* max_operand; + HloInstruction* clamp; + if (Match(maximum, + m::MaximumAnyOrder( + m::Op(&max_operand), + m::Clamp(&clamp, m::Op(&clamp_lower_bound), m::Op(&to_clamp), + m::Op(&clamp_upper_bound))))) { + if (max_operand == clamp_lower_bound && + ReplaceInstructionIfSameShape(maximum, clamp)) { + return Status::OK(); + } + } + + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { + HloInstruction *lhs, *rhs; + CHECK(Match(minimum, m::Minimum(m::Op(&lhs), m::Op(&rhs)))); + + HloInstruction* clamp_upper_bound_bcast; + HloInstruction* clamp_lower_bound_bcast; + HloInstruction* to_clamp; + if (Match(minimum, m::MinimumAnyOrder( + m::Broadcast(&clamp_upper_bound_bcast, + m::ConstantEffectiveScalar()), + m::MaximumAnyOrder( + m::Op(&to_clamp), + m::Broadcast(&clamp_lower_bound_bcast, + m::ConstantEffectiveScalar()))))) { + TF_ASSIGN_OR_RETURN(auto clamp, + MinMaxToClamp(clamp_lower_bound_bcast, to_clamp, + clamp_upper_bound_bcast)); + if (clamp) { + return ReplaceWithNewInstruction(minimum, std::move(clamp)); + } + } + + return Status::OK(); +} + +Status AlgebraicSimplifierVisitor::HandleClamp(HloInstruction* clamp) { + HloInstruction* clamp_lower_bound; + HloInstruction* clamp_upper_bound; + HloInstruction* to_clamp; + CHECK(Match(clamp, m::Clamp(m::Op(&clamp_lower_bound), m::Op(&to_clamp), + m::Op(&clamp_upper_bound)))); + + // clamp(a, clamp(a, x, b), b) -> clamp(a, x, b) + if (Match(to_clamp, m::Clamp(m::Op().Is(clamp_lower_bound), m::Op(), + m::Op().Is(clamp_upper_bound))) && + ReplaceInstructionIfSameShape(clamp, to_clamp)) { + return Status::OK(); + } + + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { HloInstruction *lhs, *rhs; CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs)))); @@ -2385,9 +2571,11 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { TF_ASSIGN_OR_RETURN( HloInstruction * slice, MakeSliceHlo(nonzero_pad, start_indices, end_indices, strides)); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + pad->shape(), slice->mutable_shape())); // Verify that the slice shape matches the pad shape. - TF_RET_CHECK(ShapeUtil::Compatible(slice->shape(), pad->shape())); + TF_RET_CHECK(ShapeUtil::Equal(slice->shape(), pad->shape())); return ReplaceInstruction(pad, slice); } @@ -2699,9 +2887,9 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { // this. But that's OK for our purposes here.) 
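  // Editor's note: e.g. remainder(iota, 8) with an s32 iota of length 8 is just
  // the iota itself, since its values 0..7 never reach the divisor.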
int64 iota_upper_bound = iota->shape().dimensions( Cast(iota)->iota_dimension()); - StatusOr divisor_val = divisor->literal().GetIntegralAsS64( + absl::optional divisor_val = divisor->literal().GetIntegralAsS64( std::vector(0, divisor->shape().dimensions_size())); - if (divisor_val.ok() && divisor_val.ValueOrDie() >= iota_upper_bound) { + if (divisor_val && *divisor_val >= iota_upper_bound) { return ReplaceInstruction(remainder, iota); } } @@ -2727,12 +2915,12 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { // smaller. int64 iota_upper_bound = iota->shape().dimensions( Cast(iota)->iota_dimension()); - StatusOr divisor_val = divisor->literal().GetIntegralAsS64( + absl::optional divisor_val = divisor->literal().GetIntegralAsS64( std::vector(0, divisor->shape().dimensions_size())); - if (divisor_val.ok()) { + if (divisor_val) { // Check whether divisor_val + iota_upper_bound - 1 overflows. absl::optional max_val = - OverflowSafeAdd(divisor_val.ValueOrDie(), iota_upper_bound); + OverflowSafeAdd(*divisor_val, iota_upper_bound); if (max_val.has_value() && FitsInIntegralType(*max_val, iota->shape().element_type())) { return ReplaceWithNewInstruction( @@ -3026,7 +3214,11 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { return Status::OK(); } - TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); + // Do not try to reorder slices and reshapes after layout assignment as it may + // be invalid. + if (!options_.is_layout_sensitive()) { + TF_ASSIGN_OR_RETURN(replaced, TryToReorderSliceAndReshape(slice)); + } if (replaced) { return Status::OK(); } @@ -3807,7 +3999,7 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( std::vector dims(operand->shape().dimensions_size()); std::iota(dims.begin(), dims.end(), 0); return computation_->AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kBitcast, operand)); + HloInstruction::CreateBitcast(shape, operand)); }; // Replace it with a dot, with bitcasts around it to get the right shape. @@ -3946,8 +4138,9 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); bool changed = false; + AlgebraicSimplifierVisitor visitor(options_, this); for (auto* comp : module->MakeNonfusionComputations()) { - if (AlgebraicSimplifierVisitor::Run(comp, options_, this)) { + if (visitor.Run(comp, options_, this)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 37ea35ade0d..74d8b1d4582 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -92,6 +92,13 @@ class AlgebraicSimplifierOptions { return enable_window_reduce_to_reduce_replacement_; } + // Sets the size of a gather operand that can be unrolled into many selects. + void set_very_small_gather_size(int64 size) { + very_small_gather_size_ = size; + } + + int64 very_small_gather_size() const { return very_small_gather_size_; } + private: ReshapeIsBitcastCallback reshape_is_bitcast_callback_; bool is_layout_sensitive_{false}; @@ -99,6 +106,7 @@ class AlgebraicSimplifierOptions { bool enable_dot_to_multiply_rewrite_{true}; bool enable_conv_simplification_{true}; bool enable_window_reduce_to_reduce_replacement_{true}; + int64 very_small_gather_size_{4}; }; // A pass which performs algebraic simplifications. 
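[Editor's note] An illustrative Python-client sketch (not part of the patch) of the min/max pattern that the new HandleMaximum/HandleMinimum rewrites above canonicalize into a clamp; the expected output is the same whether or not the rewrite fires:

import numpy as np
from tensorflow.compiler.xla.python import xla_client

b = xla_client.ComputationBuilder("min_max_to_clamp_example")
x = b.Constant(np.array([1.0, 3.5, 5.0, 2.0], dtype=np.float32))
# maximum(3, x) followed by minimum(., 4) matches the broadcast-scalar pattern
# above once the builder broadcasts the scalars, and may become clamp(3, x, 4).
b.Min(b.Max(b.ConstantF32Scalar(3.0), x), b.ConstantF32Scalar(4.0))
result = xla_client.execute_with_python_values(b.Build().Compile())
# Expected result: [3.0, 3.5, 4.0, 3.0]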
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 4c5e5ef9e7e..230a5a1c058 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -5543,5 +5543,103 @@ TEST_F(AlgebraicSimplifierTest, RepeatedRemainder) { GmockMatch(m::Remainder(m::Parameter(), m::Parameter()))); } +TEST_F(AlgebraicSimplifierTest, SlicePadLayout) { + const char* kModuleStr = R"( + HloModule m + test { + %param.0 = f32[128,9,9,1024]{0,3,2,1} parameter(0) + %param.1 = f32[] parameter(1) + %slice = f32[128,9,9,1024]{0,3,2,1} slice(%param.0), + slice={[0:128], [0:9], [0:9], [0:1024]} + ROOT %pad = f32[128,8,9,1024]{0,3,2,1} pad(%slice, %param.1), + padding=0_0x-1_0x0_0x0_0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + const Shape root_shape = m->entry_computation()->root_instruction()->shape(); + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + ASSERT_TRUE(AlgebraicSimplifier(options).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Slice().WithShapeEqualTo(&root_shape))); +} + +TEST_F(AlgebraicSimplifierTest, MinOfMaxToClamp) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + c0 = f32[] constant(3.0) + c1 = f32[] constant(4.0) + b0 = f32[4] broadcast(c0), dimensions={} + b1 = f32[4] broadcast(c1), dimensions={} + m0 = f32[4] maximum(b0, p0) + ROOT m1 = f32[4] minimum(m0, b1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Clamp(m::Broadcast(m::ConstantScalar(3.0)), m::Parameter(0), + m::Broadcast(m::ConstantScalar(4.0))))); +} + +TEST_F(AlgebraicSimplifierTest, MaxOfMinToClamp) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + c0 = f32[] constant(3.0) + c1 = f32[] constant(4.0) + b0 = f32[4] broadcast(c0), dimensions={} + b1 = f32[4] broadcast(c1), dimensions={} + m0 = f32[4] minimum(p0, b1) + ROOT m1 = f32[4] maximum(b0, m0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Clamp(m::Broadcast(m::ConstantScalar(3.0)), m::Parameter(0), + m::Broadcast(m::ConstantScalar(4.0))))); +} + +TEST_F(AlgebraicSimplifierTest, ClampOfClamp) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + p2 = f32[] parameter(2) + c0 = f32[] clamp(p0, p1, p2) + ROOT c1 = f32[] clamp(p0, c0, p2) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Clamp(m::Parameter(0), m::Parameter(1), m::Parameter(2)))); +} + +TEST_F(AlgebraicSimplifierTest, MaxOfClamp) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + p2 = f32[] parameter(2) + c0 = f32[] clamp(p0, p1, p2) + ROOT m0 = f32[] maximum(p0, c0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + 
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Clamp(m::Parameter(0), m::Parameter(1), m::Parameter(2)))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 147f3ae7b6d..9c19308bff3 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -29,9 +29,9 @@ namespace xla { class BatchNormExpander : public HloModulePass { public: // When use_fusion is set, a multi-output fusion node is created. - BatchNormExpander(bool rewrite_training_op = false, - bool rewrite_inference_op = false, - bool rewrite_grad_op = false) + explicit BatchNormExpander(bool rewrite_training_op = false, + bool rewrite_inference_op = false, + bool rewrite_grad_op = false) : rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), rewrite_grad_op_(rewrite_grad_op) {} diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 4d465640f2d..6331f02aa81 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -308,6 +308,28 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, return true; } +namespace { + +// Returns whether we should avoid changing the precision of inst regardless of +// the producers and users. +bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst) { + if (inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == HloInstruction::FusionKind::kCustom) { + return ShouldKeepPrecisionUnchanged( + inst->fused_instructions_computation()->root_instruction()); + } + // Do not change precision for side-effecting instructions, control flow, and + // bitcast-convert, because this pass might break the interfaces or + // assumptions for them. + return inst->opcode() == HloOpcode::kCustomCall || // + inst->opcode() == HloOpcode::kCall || // + inst->opcode() == HloOpcode::kConditional || // + inst->opcode() == HloOpcode::kBitcastConvert || // + inst->HasSideEffectNoRecurse(); +} + +} // namespace + void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters) { // We handle any fusion computation or while body/condition after the @@ -354,15 +376,7 @@ void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, return; } - // Do not change precision for instructions related to entry and exit of a - // computation, side-effecting instructions, control flow, and - // bitcast-convert, because this pass might break the interfaces or - // assumptions for them. - if (hlo->opcode() == HloOpcode::kCustomCall || // - hlo->opcode() == HloOpcode::kCall || // - hlo->opcode() == HloOpcode::kConditional || // - hlo->opcode() == HloOpcode::kBitcastConvert || // - hlo->HasSideEffectNoRecurse() || // + if (ShouldKeepPrecisionUnchanged(hlo) || (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) { return; } @@ -797,6 +811,39 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { // Apply the changes in changes_to_bf16_. 
for (auto& change : changes_to_bf16_) { + auto inst = change.first; + // It is possible that we marked inst to change precision even if it is an + // unsupported change, when inst is the root of a fusion computation and it + // has to match the fusion node's output precision. We do a convert instead + // of in-place change for such cases. + if (ShouldKeepPrecisionUnchanged(inst)) { + auto users = inst->users(); + bool is_root = inst == inst->parent()->root_instruction(); + TF_ASSIGN_OR_RETURN( + HloInstruction * copy, + inst->parent()->DeepCopyInstructionWithCustomCopier( + inst, [&](HloInstruction* leaf, const ShapeIndex& leaf_index, + HloComputation* comp) { + if (!ContainsKey(change.second, + ShapeUtil::GetMutableSubshape( + inst->mutable_shape(), leaf_index))) { + return leaf; + } + auto converted_shape = + ShapeUtil::ChangeElementType(leaf->shape(), BF16); + UpdateLayout(&converted_shape); + return comp->AddInstruction( + HloInstruction::CreateConvert(converted_shape, leaf)); + })); + for (auto user : users) { + TF_RETURN_IF_ERROR(inst->ReplaceUseWithDifferentShape(user, copy)); + } + if (is_root) { + inst->parent()->set_root_instruction(copy, + /*accept_different_shape=*/true); + } + continue; + } for (const auto& entry : change.second) { auto subshape = entry.first; CHECK_EQ(subshape->element_type(), F32); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 86eb8cb240c..d716e62d467 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -422,6 +422,35 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { EXPECT_TRUE(OutputsBF16(b_f1)); } +// Tests that a fusion with a bitcast-convert as its root is changed via adding +// extra convert, instead of changing the type in-place. +TEST_F(BFloat16PropagationTest, FusionWithBitcastConvertRoot) { + auto module = CreateNewVerifiedModule(); + auto builder = HloComputation::Builder(TestName()); + Shape u32_shape = ShapeUtil::MakeShape(U32, {4, 4}); + Shape f32_shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, u32_shape, "param")); + + auto builder_f = HloComputation::Builder("fusion"); + HloInstruction* a_f = builder_f.AddInstruction( + HloInstruction::CreateParameter(0, u32_shape, "a")); + HloInstruction* bc_f = builder_f.AddInstruction( + HloInstruction::CreateBitcastConvert(f32_shape, a_f)); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + f32_shape, HloInstruction::FusionKind::kLoop, {param}, comp_f)); + auto dot = builder.AddInstruction(CreateDot(f32_shape, fusion, fusion)); + + auto computation = module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_EQ(bc_f->shape(), f32_shape); + EXPECT_TRUE(OutputsBF16(bc_f)); +} + // Tests that changes to BF16 that cannot be propagated outside a fusion are // discarded. 
TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 3ae7235d887..d72a91f45df 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" +#include "tensorflow/compiler/xla/service/hlo_live_range.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -233,8 +234,8 @@ BufferAllocation::Slice BufferAllocation::GetSlice( void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset, int64 size) { - VLOG(4) << "Adding the following buffer to allocation #" << index() << ": " - << buffer; + VLOG(4) << "Adding the following buffer to allocation #" << index() << " [" + << offset << ", " << size << "]: " << buffer; CHECK(!assigned_buffers_.contains(&buffer)) << "LogicalBuffer " << buffer << " already assigned to allocation " << index_; @@ -250,6 +251,13 @@ void BufferAllocation::AddAssignment(const HloValue& buffer, int64 offset, offset_size.offset = offset; offset_size.size = size; assigned_buffers_.emplace(&buffer, offset_size); + // For debugging purposes, store the assigned memory space in the + // instruction's layout. + HloInstruction* defining_instruction = buffer.defining_instruction(); + if (defining_instruction->shape().has_layout()) { + defining_instruction->mutable_shape()->mutable_layout()->set_memory_space( + buffer.color().value()); + } } BufferAllocationProto BufferAllocation::ToProto() const { @@ -758,14 +766,69 @@ StatusOr> BufferAssigner::Run( LogicalBuffer::AlignmentFunction color_alignment, bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer, const absl::flat_hash_set& reuse_checker, - HloDataflowAnalysis::CanShareBuffer can_share_buffer) { + HloDataflowAnalysis::CanShareBuffer can_share_buffer, + std::unique_ptr preset_assignments) { BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer), - reuse_checker); + reuse_checker, std::move(preset_assignments)); return assigner.CreateAssignment( module, std::move(hlo_ordering), std::move(buffer_size), std::move(color_alignment), std::move(can_share_buffer)); } +bool BufferAssigner::LiveRangeInterferes(const HloValue* buffer1, + const HloValue* buffer2, + BufferAssignment* assignment) { + CHECK((assignment->hlo_live_range().total_order_scheduled())); + const HloLiveRange& hlo_live_range = assignment->hlo_live_range(); + + const auto& buffer_live_ranges = hlo_live_range.buffer_live_ranges(); + + CHECK(buffer_live_ranges.contains(buffer1)) + << "Buffer doesn't have a proper live range:" << buffer1; + + CHECK(buffer_live_ranges.contains(buffer2)) + << "Buffer doesn't have a proper live range:" << buffer2; + + // Check if a user value can share the same buffer as its operand. 
+ auto can_share_as_operand = [&assignment](const HloValue* user_value, + const HloValue* operand_value) { + return user_value->instruction()->IsUserOf(operand_value->instruction()) && + assignment->dataflow_analysis().CanShareOperandBufferWithUser( + operand_value->instruction(), operand_value->index(), + user_value->instruction(), user_value->index()) && + user_value->instruction()->opcode() != HloOpcode::kCopy; + }; + + auto live_range_1 = buffer_live_ranges.at(buffer1); + auto live_range_2 = buffer_live_ranges.at(buffer2); + + if (!(live_range_1.start > live_range_2.end || + live_range_2.start > live_range_1.end)) { + if (live_range_1.end == live_range_2.start) { + auto operand_value = buffer1; + auto user_value = buffer2; + if (!can_share_as_operand(user_value, operand_value)) { + return true; + } + } else if (live_range_2.end == live_range_1.start) { + auto operand_value = buffer2; + auto user_value = buffer1; + if (!can_share_as_operand(user_value, operand_value)) { + return true; + } + } else { + VLOG(4) << "Can't assign: assignee " << *buffer1 << " may interfere with " + << *buffer2; + VLOG(4) << "assigned_buffer.start: " << live_range_1.start; + VLOG(4) << "assigned_buffer.end: " << live_range_1.end; + VLOG(4) << "live_range_2.start" << live_range_2.start; + VLOG(4) << "live_range_2.end" << live_range_2.end; + return true; + } + } + return false; +} + bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, const HloBuffer& hlo_buffer, BufferAssignment* assignment) { @@ -777,7 +840,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, << " to allocation: " << *allocation; if (hlo_buffer.color() != allocation->color()) { - VLOG(4) << "Can't assign: buffer has color" << hlo_buffer.color() + VLOG(4) << "Can't assign: buffer has color " << hlo_buffer.color() << " and allocation has color " << allocation->color() << "."; return false; } @@ -833,10 +896,17 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, const HloValue& assigned_buffer = *CHECK_NOTNULL(dynamic_cast(buffer_offset_size.first)); for (const HloValue* new_value : hlo_buffer.values()) { - if (assignment->hlo_ordering().MayInterfere( - assigned_buffer, *new_value, assignment->dataflow_analysis())) { + if (assignment->hlo_live_range().total_order_scheduled()) { + if (LiveRangeInterferes(new_value, &assigned_buffer, assignment)) { + return false; + } + } else if (assignment->hlo_ordering().MayInterfere( + assigned_buffer, *new_value, + assignment->dataflow_analysis())) { + // Fallback to partial order based interference detection (slower) when + // we don't have a total order scheduled module. VLOG(4) << "Can't assign: assignee " << assigned_buffer - << " may interfere with " << new_value; + << " may interfere with " << new_value->ToShortString(); return false; } @@ -847,7 +917,8 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, assigned_buffer_position.instruction) && new_value->instruction()->opcode() == HloOpcode::kCopy) { VLOG(4) << "Can't assign: assignee " << assigned_buffer - << " is used at copy instruction " << new_value; + << " is used at copy instruction " + << new_value->ToShortString(); return false; } } @@ -1094,8 +1165,20 @@ Status BufferAssigner::AssignBuffersForComputations( } std::vector sorted_buffers; + // First assign the preset allocations. 
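LiveRangeInterferes above reduces interference checking to interval arithmetic once a total order is available: disjoint ranges never interfere, ranges that merely abut are fine only when the later value may reuse the earlier value's buffer, and anything else overlaps. A compact sketch of that decision with plain integers instead of HloValues (can_share_at_boundary stands in for the CanShareOperandBufferWithUser query):

```cpp
#include <cstdint>

struct LiveRange {
  int64_t start;  // position of the definition in the total order
  int64_t end;    // position of the last use
};

// Returns true if two buffers with these live ranges cannot share memory.
bool LiveRangesInterfere(const LiveRange& a, const LiveRange& b,
                         bool can_share_at_boundary) {
  if (a.end < b.start || b.end < a.start) {
    return false;  // fully disjoint: no interference
  }
  if (a.end == b.start || b.end == a.start) {
    // One value dies exactly where the other is defined; sharing is legal
    // only when the new value may reuse the dying value's buffer in place.
    return !can_share_at_boundary;
  }
  return true;  // genuine overlap
}
```

As the diff shows, this fast path is taken only when hlo_live_range().total_order_scheduled() holds; otherwise the assigner falls back to the slower ordering-based MayInterfere check.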
+ absl::flat_hash_set preset_assigned_buffers; + + TF_RETURN_IF_ERROR(AssignPresetBuffers(&preset_assigned_buffers, assignment)); + const HloAliasAnalysis& alias_analysis = assignment->alias_analysis(); + for (const HloBuffer& buffer : alias_analysis.buffers()) { + // Skip if the buffer is already assigned since it had a preset allocation. + if (preset_assigned_buffers.find(&buffer) != + preset_assigned_buffers.end()) { + VLOG(3) << "Skip allocation for buffer: " << buffer; + continue; + } TF_RET_CHECK(!buffer.values().empty()); const HloComputation* comp = buffer.values()[0]->instruction()->parent(); if (absl::c_linear_search(computations, comp)) { @@ -1124,9 +1207,12 @@ Status BufferAssigner::AssignBuffersForComputations( } } + HloSchedule schedule(&assignment->module()); + for (const HloComputation* computation : computations) { - const bool has_sequential_order = - assignment->hlo_ordering().SequentialOrder(*computation) != nullptr; + const HloInstructionSequence* instruction_sequence = + assignment->hlo_ordering().SequentialOrder(*computation); + const bool has_sequential_order = instruction_sequence != nullptr; if (has_sequential_order && buffers_to_assign_sequentially != nullptr) { // Every sequential computation must get an entry in the // buffers_to_assign_sequentially map, even if we end up with an empty @@ -1134,6 +1220,8 @@ Status BufferAssigner::AssignBuffersForComputations( // run whole-module heap simulation. buffers_to_assign_sequentially->emplace(computation, flat_hash_set()); + + schedule.set_sequence(computation, *instruction_sequence); } } @@ -1188,6 +1276,54 @@ BufferAssigner::SplitBuffersByColor( return color_map; } +Status BufferAssigner::AssignPresetBuffers( + absl::flat_hash_set* assigned_buffers, + BufferAssignment* assignment) { + if (!preset_assignments_) { + return Status::OK(); + } + + // Create an allocation for each preset color. + absl::flat_hash_map + preset_allocations; + for (auto& color_and_size : preset_assignments_->sizes()) { + LogicalBuffer::Color color(color_and_size.first); + auto inserted = preset_allocations.emplace( + color, assignment->NewEmptyAllocation(color_and_size.second, color)); + BufferAllocation* inserted_allocation = inserted.first->second; + VLOG(3) << "Created preset buffer allocation " + << inserted_allocation->index() + << ", color: " << inserted_allocation->color() + << ", size: " << inserted_allocation->size(); + } + + const HloAliasAnalysis& alias_analysis = assignment->alias_analysis(); + + for (auto& position_and_chunk : preset_assignments_->chunks()) { + const HloPosition& position = position_and_chunk.first; + const HloBuffer& buffer = + alias_analysis.GetUniqueBufferAt(position.instruction, position.index); + VLOG(3) << "Preset allocation for buffer: " << buffer; + const HeapSimulator::Chunk& chunk = position_and_chunk.second; + auto preset_allocations_iter = preset_allocations.find(buffer.color()); + CHECK(preset_allocations_iter != preset_allocations.end()) + << "No preset buffer allocation for color " << buffer.color() + << " found."; + preset_allocations_iter->second->AddAssignment(buffer.GetUniqueValue(), + chunk.offset, chunk.size); + // Ensure that there is at most one preset allocation for each buffer. + CHECK_EQ(assigned_buffers->count(&buffer), 0); + assigned_buffers->emplace(&buffer); + } + + // Upon consumption of the preset assignments, delete it so that if this + // method is called again, it does not assign the same buffers multiple times. 
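AssignPresetBuffers trusts the earlier pass to hand it chunks that do not overwrite each other (the Run comment in buffer_assignment.h makes that the caller's responsibility). If one wanted to sanity-check that invariant for the chunks of a single memory space, a sort-and-scan over offsets is enough; a hypothetical helper, sketched with plain offsets instead of HloPosition keys:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct Chunk {
  int64_t offset;
  int64_t size;
};

// Returns true if no two chunks overlap: after sorting by offset, each chunk
// must end at or before the next one begins.
bool ChunksAreDisjoint(std::vector<Chunk> chunks) {
  std::sort(chunks.begin(), chunks.end(),
            [](const Chunk& a, const Chunk& b) { return a.offset < b.offset; });
  for (size_t i = 1; i < chunks.size(); ++i) {
    if (chunks[i - 1].offset + chunks[i - 1].size > chunks[i].offset) {
      return false;
    }
  }
  return true;
}
```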
+ preset_assignments_ = {}; + + return Status::OK(); +} + Status BufferAssigner::AssignBuffersWithSequentialOrdering( const flat_hash_map>& buffers_to_assign_sequentially, @@ -1393,6 +1529,21 @@ StatusOr> BufferAssigner::CreateAssignment( TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, can_share_buffer)); + // Set up a schedule for each computation. + HloSchedule schedule(module); + for (const HloComputation* computation : module->computations()) { + const HloInstructionSequence* instruction_sequence = + hlo_ordering->SequentialOrder(*computation); + const bool has_sequential_order = instruction_sequence != nullptr; + if (has_sequential_order) { + schedule.set_sequence(computation, *instruction_sequence); + } + } + + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, + HloLiveRange::Run(schedule, *alias_analysis, + module->entry_computation(), true)); + VLOG(1) << "Assigning buffers to module " << module->name(); XLA_VLOG_LINES(3, module->ToString()); XLA_VLOG_LINES(3, alias_analysis->ToString()); @@ -1404,7 +1555,8 @@ StatusOr> BufferAssigner::CreateAssignment( // private. std::unique_ptr assignment(new BufferAssignment( module, std::move(hlo_ordering), std::move(buffer_size), - std::move(color_alignment), std::move(alias_analysis))); + std::move(color_alignment), std::move(alias_analysis), + std::move(hlo_live_range))); TF_RETURN_IF_ERROR( colorer_(&assignment->alias_analysis(), assignment->hlo_ordering())); @@ -1432,7 +1584,7 @@ StatusOr> BufferAssigner::CreateAssignment( // module, which reduces memory usage. const bool run_whole_module_heap_simulation = buffers_to_assign_sequentially.size() == global_computations.size(); - VLOG(2) << "Running whole module heap simulation" + VLOG(2) << "Running whole module heap simulation: " << run_whole_module_heap_simulation; TF_RETURN_IF_ERROR(AssignBuffersWithSequentialOrdering( buffers_to_assign_sequentially, run_whole_module_heap_simulation, diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index f60ad22fa51..9caf4bee0ad 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -31,8 +31,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_live_range.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/memory_space_assignment.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -445,9 +447,11 @@ class BufferAssignment { HloAliasAnalysis& alias_analysis() const { return *alias_analysis_; } - // Returns the BufferLiveness object used to construct this assignment. const HloOrdering& hlo_ordering() const { return *hlo_ordering_; } + // Returns the HloLiveRange object used to construct this assignment. 
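CreateAssignment now builds an HloSchedule from each computation's sequential order and feeds it to HloLiveRange::Run, which is what makes the interval-based interference test possible. The underlying computation is a single pass over the total order: a value becomes live at its definition and stays live through its last use. A toy version with string names instead of HloValues (the real analysis also handles aliasing, nested computations, and values that live out of the computation, all of which this sketch ignores):

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// One scheduled instruction: the value it defines and the values it reads.
struct Scheduled {
  std::string defines;
  std::vector<std::string> uses;
};

struct LiveRange {
  int64_t start;
  int64_t end;
};

std::map<std::string, LiveRange> ComputeLiveRanges(
    const std::vector<Scheduled>& sequence) {
  std::map<std::string, LiveRange> ranges;
  for (int64_t t = 0; t < static_cast<int64_t>(sequence.size()); ++t) {
    ranges[sequence[t].defines] = {t, t};          // live from its definition
    for (const std::string& use : sequence[t].uses) {
      auto it = ranges.find(use);
      if (it != ranges.end()) it->second.end = t;  // extend to the last use
    }
  }
  return ranges;
}
```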
+ const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } + string ToString() const; BufferAssignmentProto ToProto() const; @@ -480,12 +484,14 @@ class BufferAssignment { std::unique_ptr hlo_ordering, BufferValue::SizeFunction buffer_size, LogicalBuffer::AlignmentFunction color_alignment, - std::unique_ptr alias_analysis) + std::unique_ptr alias_analysis, + std::unique_ptr hlo_live_range) : module_(module), hlo_ordering_(std::move(hlo_ordering)), buffer_size_(std::move(buffer_size)), color_alignment_(std::move(color_alignment)), - alias_analysis_(std::move(alias_analysis)) {} + alias_analysis_(std::move(alias_analysis)), + hlo_live_range_(std::move(hlo_live_range)) {} // Creates and returns a new BufferAllocation, with no assigned // LogicalBuffers. Ownership is maintained internally. @@ -545,6 +551,8 @@ class BufferAssignment { std::unique_ptr alias_analysis_; + std::unique_ptr hlo_live_range_; + Stats stats_; TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment); @@ -558,7 +566,13 @@ class BufferAssigner { static Colorer DefaultColorer() { return [](HloAliasAnalysis* alias_analysis, const HloOrdering&) { for (HloValue* value : alias_analysis->dataflow_analysis().values()) { - value->set_color(BufferValue::Color(0)); + HloInstruction* defining_instruction = value->defining_instruction(); + if (defining_instruction->shape().has_layout()) { + value->set_color(BufferValue::Color( + defining_instruction->shape().layout().memory_space())); + } else { + value->set_color(BufferValue::Color(0)); + } } return Status::OK(); }; @@ -569,7 +583,9 @@ class BufferAssigner { // Build and return a BufferAssignment for the given module. The given // HloOrdering is used to determine buffer liveness. buffer_size and // color_alignment are functions which returns the size and alignment of a - // LogicalBuffer. + // LogicalBuffer. If preset_assignments is provided, those pre-set assignment + // offsets will be used. The caller guarantees that those assignments are + // valid and they do not overwrite each other. static StatusOr> Run( const HloModule* module, std::unique_ptr hlo_ordering, BufferValue::SizeFunction buffer_size, @@ -577,14 +593,17 @@ class BufferAssigner { bool allocate_buffers_for_constants = false, Colorer colorer = DefaultColorer(), const absl::flat_hash_set& must_not_live_out = {}, - HloDataflowAnalysis::CanShareBuffer can_share_buffer = nullptr); + HloDataflowAnalysis::CanShareBuffer can_share_buffer = nullptr, + std::unique_ptr preset_assignments = {}); private: BufferAssigner(bool allocate_buffers_for_constants, Colorer colorer, - const absl::flat_hash_set& must_not_live_out) + const absl::flat_hash_set& must_not_live_out, + std::unique_ptr preset_assignments) : allocate_buffers_for_constants_(allocate_buffers_for_constants), colorer_(colorer), - must_not_live_out_(must_not_live_out) {} + must_not_live_out_(must_not_live_out), + preset_assignments_(std::move(preset_assignments)) {} virtual ~BufferAssigner() = default; // Create a buffer assignment. @@ -606,6 +625,16 @@ class BufferAssigner { buffers_to_assign_sequentially, BufferAssignment* assignment); + // Returns true if buffer's live range interferences with buffer2's. + bool LiveRangeInterferes(const HloValue* buffer1, const HloValue* buffer2, + BufferAssignment* assignment); + + // Assigns pre-set assignments, if provided. These assignments will be added + // to assigned_buffers and skip buffer allocation. 
+ Status AssignPresetBuffers( + absl::flat_hash_set* assigned_buffers, + BufferAssignment* assignment); + // Promotes operations (DUS, scatter) to be done in place: If an operation can // be done in place, merge its buffer with its operand buffer. Status MergeInplaceOpBuffers(BufferAssignment* assignment); @@ -657,6 +686,9 @@ class BufferAssigner { // A set of hlo opcodes that can't live out of a computation. absl::flat_hash_set must_not_live_out_; + // Description of any buffer offsets that are already set by an earlier pass. + std::unique_ptr preset_assignments_; + TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner); }; diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 3bb98d5d1be..1c985485d43 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -143,6 +143,20 @@ class BufferAssignmentTest : public HloTestBase { .ConsumeValueOrDie(); } + std::unique_ptr RunBufferAssignmentWithPresetAssignments( + HloModule* module, std::unique_ptr preset_assignments, + int64 alignment = 1) { + return BufferAssigner::Run( + module, absl::make_unique(module), + backend().compiler()->BufferSizeBytesFunction(), + [alignment](LogicalBuffer::Color) { return alignment; }, + /*allocate_buffers_for_constants=*/true, + BufferAssigner::DefaultColorer(), + /*must_not_live_out=*/{}, + /*can_share_buffer=*/nullptr, std::move(preset_assignments)) + .ConsumeValueOrDie(); + } + // Builds an x+1.0 computation to use in a Map. std::unique_ptr BuildMapComputationPlus1(const string& name) { auto builder = HloComputation::Builder(name); @@ -599,6 +613,13 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { // The sub node has a valid output buffer assigned. GetAssignedOutputAllocation(*buffers, sub); + + // Check if the HLO instructions have the correct colors in the layout. + EXPECT_EQ(param0->shape().layout().memory_space(), 2); + EXPECT_EQ(param1->shape().layout().memory_space(), 3); + EXPECT_EQ(mul->shape().layout().memory_space(), 4); + EXPECT_EQ(add->shape().layout().memory_space(), 5); + EXPECT_EQ(sub->shape().layout().memory_space(), 6); } TEST_F(BufferAssignmentTest, BasicPartiallyColored) { @@ -666,6 +687,86 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { // The sub node has a valid output buffer assigned. GetAssignedOutputAllocation(*buffers, sub); + + // Check if the HLO instructions have the correct colors in the layout. + EXPECT_EQ(mul->shape().layout().memory_space(), 1); + EXPECT_EQ(add->shape().layout().memory_space(), 1); + EXPECT_EQ(sub->shape().layout().memory_space(), 0); + EXPECT_EQ(param0->shape().layout().memory_space(), 0); + EXPECT_EQ(param1->shape().layout().memory_space(), 0); +} + +TEST_F(BufferAssignmentTest, PresetAssignments) { + // paramscalar ------- (mul) -- (add) -- (sub) + // / / / + // param0[100] -------/ / / + // / / + // param1[100] --------------/--------/ + // Similar to BasicPartiallyColored, but the color is set in the layout. + // The output of the mul and the add have the color 1 and have preset + // assignments, and the other buffers have the color 0, which allows the mul + // and add to share buffers. 
+ auto builder = HloComputation::Builder(TestName()); + auto paramscalar = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec100_, paramscalar, {})); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec100_, "p1")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec100_, "p2")); + Shape f32vec100_color1 = + ShapeUtil::MakeShapeWithLayout(F32, {100}, {0}, {}, 0, 1); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_color1, HloOpcode::kMultiply, broadcast, param0)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_color1, HloOpcode::kAdd, mul, param1)); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec100_, HloOpcode::kSubtract, add, param1)); + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(builder.Build()); + + auto preset_assignments = absl::make_unique(); + preset_assignments->add_chunk({mul, {}}, {/*offset=*/100, /*size=*/400}); + preset_assignments->add_chunk({add, {}}, {/*offset=*/550, /*size=*/400}); + preset_assignments->add_size(/*memory_space=*/1, /*size=*/950); + + auto buffers = RunBufferAssignmentWithPresetAssignments( + module.get(), std::move(preset_assignments)); + + // Distinct input buffers were assigned for parameters. + BufferAllocation paramscalar_buffer = + GetAssignedInputAllocation(*buffers, paramscalar); + BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0); + BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1); + EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index()); + EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index()); + EXPECT_EQ(paramscalar_buffer.color(), LogicalBuffer::Color(0)); + EXPECT_NE(param0_buffer.index(), param1_buffer.index()); + EXPECT_EQ(param0_buffer.color(), LogicalBuffer::Color(0)); + + // The mul and add use the same preset buffer. Ensure it has the correct color + // and offsets. + const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); + const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add); + EXPECT_EQ(mul_buffer, add_buffer); + EXPECT_NE(mul_buffer.index(), param0_buffer.index()); + EXPECT_EQ(mul_buffer.color(), LogicalBuffer::Color(1)); + + EXPECT_EQ(mul_buffer.assigned_buffers().size(), 2); + for (const auto& value_and_offsetsize : mul_buffer.assigned_buffers()) { + if (value_and_offsetsize.first->instruction() == mul) { + EXPECT_EQ(value_and_offsetsize.second.offset, 100); + EXPECT_EQ(value_and_offsetsize.second.size, 400); + } else { + EXPECT_EQ(value_and_offsetsize.first->instruction(), add); + EXPECT_EQ(value_and_offsetsize.second.offset, 550); + EXPECT_EQ(value_and_offsetsize.second.size, 400); + } + } + + // The sub node has a valid output buffer assigned. 
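A quick arithmetic check of the preset layout used in this test: the two chunks occupy [100, 500) and [550, 950), so they are disjoint and both end within the declared 950-byte memory space. A hypothetical fit check over such chunks (plain ints, not the PresetAssignments API):

```cpp
#include <cstdint>
#include <vector>

struct Chunk {
  int64_t offset;
  int64_t size;
};

// Every chunk must lie inside [0, declared_size) of its memory space.
bool ChunksFit(const std::vector<Chunk>& chunks, int64_t declared_size) {
  for (const Chunk& c : chunks) {
    if (c.offset < 0 || c.offset + c.size > declared_size) return false;
  }
  return true;
}

// For the test above: ChunksFit({{100, 400}, {550, 400}}, 950) == true.
```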
+ GetAssignedOutputAllocation(*buffers, sub); } TEST_F(BufferAssignmentTest, MultipleUsersForNode) { @@ -1482,7 +1583,7 @@ TEST_F(BufferAssignmentTest, BitcastAsOutput) { auto param = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {42}), "param")); auto bitcast = builder.AddInstruction( - HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param)); + HloInstruction::CreateBitcast(param->shape(), param)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cholesky_expander.cc b/tensorflow/compiler/xla/service/cholesky_expander.cc index 27b1dcca2bd..74fc15a3eed 100644 --- a/tensorflow/compiler/xla/service/cholesky_expander.cc +++ b/tensorflow/compiler/xla/service/cholesky_expander.cc @@ -46,11 +46,12 @@ namespace { // n = a.shape[-2] // l = np.zeros_like(a) // for j in xrange(n): -// row = l[..., j, :j] -// row_t = np.swapaxes(row, -1, -2) -// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(row, row_t)) -// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / -// l[..., j, j] +// mask = np.zeros_like(a) +// mask[i, k] == 1 when i >= k and k == j +// l_square = np.dot(l, l_t) +// temp = a - l_square +// l[..., j, j] = temp(j, j) +// l = temp / l[..., j, j) * mask + l // return l // Returns a (result, error) pair. std::pair CholeskyUnblocked( @@ -65,6 +66,11 @@ std::pair CholeskyUnblocked( /*pos=*/0, /*len=*/n_dims - 2); + auto matrix_dims = AsInt64Slice(a_shape.dimensions()) + .subspan( + /*pos=*/0, + /*len=*/n_dims); + XlaOp l = ZerosLike(a); // Construct the for loop body to iterate over rows. @@ -73,63 +79,33 @@ std::pair CholeskyUnblocked( XlaBuilder* body_builder) -> StatusOr> { std::vector row_shape_dims(major_dims.begin(), major_dims.end()); std::vector col_shape_dims(major_dims.begin(), major_dims.end()); - row_shape_dims.push_back(1); - row_shape_dims.push_back(n); - auto mask_zeros_row = - Zeros(body_builder, - ShapeUtil::MakeShape(a_shape.element_type(), row_shape_dims)); - - col_shape_dims.push_back(n); - col_shape_dims.push_back(1); - auto mask_zeros_col = - Zeros(body_builder, - ShapeUtil::MakeShape(a_shape.element_type(), col_shape_dims)); - - auto mask_range_row = - Iota(body_builder, ShapeUtil::MakeShape(S32, row_shape_dims), - /*iota_dimension=*/n_dims - 1); - auto mask_range_col = - Iota(body_builder, ShapeUtil::MakeShape(S32, col_shape_dims), - /*iota_dimension=*/n_dims - 2); auto body_a = loop_vars[0]; auto body_l = loop_vars[1]; auto seen_error = loop_vars[2]; + auto iota_row = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), + n_dims - 1); + auto iota_col = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), + n_dims - 2); + + auto mask_pred = Ge(iota_col, iota_row); + mask_pred = And(mask_pred, Eq(iota_row, i)); + auto mask_zeros = + Zeros(body_builder, + ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims)); + // L * L.T, This matrix has of a lot of multiplying with zero + // (namely, L[:, j:] = 0) and redudant computation, but it is faster + // than slice. 
+ auto l_square = BatchDot(body_l, false, body_l, true, precision); + + // A - L*L.T + l_square = body_a - l_square; + auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1}); + l_ii = Sqrt(l_ii); + // L = (A - L*L.T) / l_ii * mask + L + body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l; - // row = l[..., i, :i] - // select the whole i-th row, then mask out all columns past i-1 - auto zero = ConstantR0(body_builder, 0); - auto l_i = DynamicSliceInMinorDims(body_l, {i, zero}, {1, n}); - auto row = Select(Ge(mask_range_row, i), mask_zeros_row, l_i); - // a[..., i, i] - auto a_ii = DynamicSliceInMinorDims(body_a, {i, i}, {1, 1}); - // np.dot(row, np.swapaxes(row, -1, -2)) - auto diag_dot = BatchDot(row, false, row, true, precision); - // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, - // np.swapaxes(row, -1, -2))) - auto l_ii = a_ii - diag_dot; seen_error = Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii)))); - l_ii = Sqrt(l_ii); - - // a[..., i+1:, i] - // select the whole i-th column, then mask out all rows above i+1 - auto a_0i = DynamicSliceInMinorDims(body_a, {i}, {1}); - auto a_ip1i = Select(Le(mask_range_col, i), mask_zeros_col, a_0i); - - // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / - // l[..., i, i] - // The columns in [i, n] are zeroed out in `row`, so we just have to - // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], - // r.T) - auto dot = BatchDot(body_l, false, row, true, precision); - // np.dot(l[..., i+1:, :i], r.T) - auto dot_ip1 = Select(Le(mask_range_col, i), mask_zeros_col, dot); - - body_l = - DynamicUpdateSliceInMinorDims(body_l, (a_ip1i - dot_ip1) / l_ii, {i}); - // Assign the diagonal after the rest of the column because otherwise the - // column assign will wrap around and overwrite the diagonal assign. - body_l = DynamicUpdateSliceInMinorDims(body_l, l_ii, {i, i}); return std::vector{body_a, body_l, seen_error}; }; diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 631a7dd7e6a..eee2e26ec9f 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -151,13 +151,6 @@ class Compiler { std::unique_ptr module, se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) = 0; - // Optimizes a HLO module group, a set of module which runs concurrently on - // multiple devices potentially communicating data between the modules. - virtual Status RunHloPassesOnModuleGroup( - HloModuleGroup* module_group, - absl::Span executors, - se::DeviceMemoryAllocator* device_allocator) = 0; - // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are // applied to module. Generally a module should be passed through RunHloPasses @@ -172,14 +165,6 @@ class Compiler { std::unique_ptr module, se::StreamExecutor* executor, se::DeviceMemoryAllocator* device_allocator) = 0; - // Compiles a set of HLO modules that can run in parallel, potentially - // communicating data between the modules. - virtual StatusOr>> - RunBackendOnModuleGroup( - std::unique_ptr module_group, - std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) = 0; - // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. 
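The CholeskyUnblocked rewrite in cholesky_expander.cc above replaces the row-slice formulation with a masked whole-matrix update: each iteration recomputes L·Lᵀ for the full (batched) matrix, forms temp = A − L·Lᵀ, takes the new diagonal entry from temp(j, j), and writes column j through a mask set where i ≥ k and k == j. That spends redundant FLOPs on multiplications by zero but avoids dynamic slices. A scalar, single-matrix sketch of the same update order (plain loops, no batching or mask machinery):

```cpp
#include <cmath>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Unblocked Cholesky in the same update order as the rewritten expander:
// at step j, form temp = A - L * L^T (only column j is actually needed,
// since all later columns of L are still zero), read the new diagonal
// entry from temp(j, j), then fill column j at and below the diagonal.
Matrix CholeskyLower(const Matrix& a) {
  const int n = static_cast<int>(a.size());
  Matrix l(n, std::vector<double>(n, 0.0));
  for (int j = 0; j < n; ++j) {
    std::vector<double> temp_col(n, 0.0);
    for (int i = j; i < n; ++i) {
      double dot = 0.0;  // sum_k l(i, k) * l(j, k), nonzero only for k < j
      for (int k = 0; k < j; ++k) dot += l[i][k] * l[j][k];
      temp_col[i] = a[i][j] - dot;
    }
    const double l_jj = std::sqrt(temp_col[j]);  // the expander flags <= 0 or NaN here
    for (int i = j; i < n; ++i) l[i][j] = temp_col[i] / l_jj;
    // Note l[j][j] = temp_col[j] / l_jj == l_jj, matching the masked update.
  }
  return l;
}
```

The XLA version does the full L·Lᵀ BatchDot on purpose: as the comment in the diff notes, the redundant work is cheaper on the target than repeated dynamic slices, and the error tracking via seen_error is unchanged.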
diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index 92d1ca4ba5d..863fd030d35 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -66,4 +67,23 @@ ProgramShape ComputationLayout::ComputeProgramShape() const { return program_shape; } +bool ComputationLayout::operator==(const ComputationLayout& other) const { + return result_layout() == other.result_layout() && + parameter_layouts() == other.parameter_layouts(); +} + +bool ComputationLayout::operator!=(const ComputationLayout& other) const { + return result_layout() != other.result_layout() || + parameter_layouts() != other.parameter_layouts(); +} + +uint64 ComputationLayout::Hash() const { + uint64 hash_value = ShapeUtil::Hash(result_layout_.shape()); + for (const auto& parameter_layout : parameter_layouts_) { + hash_value = tensorflow::Hash64Combine( + hash_value, ShapeUtil::Hash(parameter_layout.shape())); + } + return hash_value; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index a2fb656677f..5aab1a5fd42 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -87,6 +87,10 @@ class ComputationLayout { // within this object. ProgramShape ComputeProgramShape() const; + bool operator==(const ComputationLayout& other) const; + bool operator!=(const ComputationLayout& other) const; + uint64 Hash() const; + private: std::vector parameter_layouts_; ShapeLayout result_layout_; diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index f1936035fed..985603b08e4 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -253,6 +253,31 @@ StatusOr TryRemoveUnusedConditionalOperands( } return true; } + +// Replaces the roots of all branches with an empty tuple if the conditional op +// has no users. Returns if anything is changed. 
+bool ReplaceRootWithEmptyTupleIfNoUsers(HloInstruction* conditional_op) { + const Shape empty_tuple = ShapeUtil::MakeTupleShape({}); + if (conditional_op->user_count() == 0 && + conditional_op != conditional_op->parent()->root_instruction() && + !ShapeUtil::Compatible(empty_tuple, conditional_op->shape())) { + for (int64 branch_id = 0; branch_id < conditional_op->branch_count(); + ++branch_id) { + auto branch_computation = + conditional_op->GetModule()->AddEmbeddedComputation( + conditional_op->branch_computation(branch_id)->Clone()); + conditional_op->set_branch_computation(branch_id, branch_computation); + auto new_empty_root = + branch_computation->AddInstruction(HloInstruction::CreateTuple({})); + branch_computation->set_root_instruction(new_empty_root, + /*accept_different_shape=*/true); + } + *conditional_op->mutable_shape() = empty_tuple; + return true; + } + return false; +} + } // namespace StatusOr ConditionalSimplifier::Run(HloModule* module) { @@ -274,6 +299,7 @@ StatusOr ConditionalSimplifier::Run(HloModule* module) { std::map> changed_computations; for (HloInstruction* conditional_op : conditional_ops) { + changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op); TF_ASSIGN_OR_RETURN(bool result, TryRemoveConditional(conditional_op)); if (!result) { TF_ASSIGN_OR_RETURN(result, TryRemoveUnusedConditionalOperands( diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index 58659156a75..d409e22463e 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -285,6 +285,49 @@ TEST_F(ConditionalSimplifierTest, EXPECT_TRUE(ConditionalSimplifier().Run(module.get()).ValueOrDie()); } +TEST_F(ConditionalSimplifierTest, RemoveDeadRoots) { + absl::string_view hlo_string = + R"( +HloModule RemoveDeadRoots +on_false { + t = (f32[20,40], f32[40,40]) parameter(0) + lhs = f32[20,40] get-tuple-element(t), index=0 + rhs = f32[40,40] get-tuple-element(t), index=1 + dot = f32[20,40] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={0} + after-all = token[] after-all() + outfeed = token[] outfeed(dot, after-all) + ROOT result = (f32[20,40]) tuple(dot) +} + +on_true { + t = (f32[20,40], f32[40,40]) parameter(0) + lhs = f32[20,40] get-tuple-element(t), index=0 + add = f32[20,40] add(lhs, lhs) + ROOT result = (f32[20,40]) tuple(add) +} + +ENTRY main { + c0_0 = f32[20,40] parameter(0) + c0_1 = f32[40,40] parameter(1) + p = pred[] parameter(2) + t = (f32[20,40], f32[40,40]) tuple(c0_0, c0_1) + conditional = (f32[20, 40]) conditional(p,t,t), false_computation=on_false, true_computation=on_true + ROOT result = () tuple() +} +)"; + auto status = ParseAndReturnUnverifiedModule(hlo_string); + TF_ASSERT_OK(status.status()); + HloVerifier v(false, false); + TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); + EXPECT_TRUE( + ConditionalSimplifier().Run(status.ValueOrDie().get()).ValueOrDie()); + TF_ASSERT_OK(v.Run(status.ValueOrDie().get()).status()); + HloInstruction* conditional = + FindInstruction(status.ValueOrDie().get(), "conditional"); + // The conditional root should be replaced with an empty tuple. 
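ReplaceRootWithEmptyTupleIfNoUsers above follows a clone-before-mutate pattern: each branch computation is cloned and re-attached before its root is swapped for an empty tuple, so any other caller that happens to share the original computation is unaffected, and the conditional's own shape then shrinks to (). A toy sketch of that pattern (illustrative structs, not the HloComputation API):

```cpp
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for a computation and a conditional that references branches.
struct Computation {
  std::string root;  // name of the root value; "()" means empty tuple
};

struct Conditional {
  int user_count = 0;
  std::vector<Computation*> branches;  // possibly shared with other callers
};

// `owned` keeps the clones alive so the originals stay usable elsewhere.
bool ReplaceDeadRootsWithEmptyTuple(
    Conditional* cond, std::vector<std::unique_ptr<Computation>>* owned) {
  if (cond->user_count != 0) return false;
  for (Computation*& branch : cond->branches) {
    owned->push_back(std::make_unique<Computation>(*branch));  // clone first
    branch = owned->back().get();
    branch->root = "()";  // then mutate only the private clone
  }
  return true;
}
```

The RemoveDeadRoots test exercises exactly this: the side-effecting outfeed inside the branch survives while the conditional's result type collapses to an empty tuple.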
+ EXPECT_EQ(ShapeUtil::TupleElementCount(conditional->shape()), 0); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc index ff75f0f2469..20ebafcf780 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc @@ -355,7 +355,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { } // We want to repeat 'filter' in the 'input_feature_dim' dimension // 'group_count' times. - if (filter_expansion_) { + if (!is_cost_viable_(convolution) || filter_expansion_) { Shape reshaped_filter_shape = ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape()); auto reshaped_filter = diff --git a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc index d2eea14896e..85c54d31582 100644 --- a/tensorflow/compiler/xla/service/convolution_group_converter_test.cc +++ b/tensorflow/compiler/xla/service/convolution_group_converter_test.cc @@ -49,7 +49,8 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,2], filter: f32[1,1,2]) -> f32[1,2 auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - ConvolutionGroupConverter converter(nullptr, /*convert_batch_groups_only=*/ + auto cost_model = [](HloInstruction* conv) { return true; }; + ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/ false); ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); @@ -80,7 +81,8 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,4], filter: f32[1,2,2]) -> f32[1,2 auto computation = module->entry_computation(); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); - ConvolutionGroupConverter converter(nullptr, /*convert_batch_groups_only=*/ + auto cost_model = [](HloInstruction* conv) { return true; }; + ConvolutionGroupConverter converter(cost_model, /*convert_batch_groups_only=*/ false); ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 6fa3161e578..f0ac579a387 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -235,8 +235,8 @@ TEST_F(CopyInsertionTest, BitcastParameter) { auto builder = HloComputation::Builder(TestName()); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x")); - HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); + HloInstruction* bitcast = builder.AddInstruction( + HloInstruction::CreateBitcast(ShapeUtil::MakeShape(F32, {2, 2}), x)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); @@ -258,8 +258,9 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* constant = builder.AddInstruction(HloInstruction::CreateConstant( LiteralUtil::CreateR1({1.0, 42.0}))); - HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); + HloInstruction* 
bitcast = + builder.AddInstruction(HloInstruction::CreateBitcast( + ShapeUtil::MakeShape(F32, {2, 2}), constant)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); @@ -279,8 +280,8 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { auto builder = HloComputation::Builder(TestName()); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x")); - HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x)); + HloInstruction* bitcast = builder.AddInstruction( + HloInstruction::CreateBitcast(ShapeUtil::MakeShape(F32, {2, 2}), x)); builder.AddInstruction(HloInstruction::CreateTuple({bitcast})); auto module = CreateNewVerifiedModule(); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 37baf0e36df..8a5bbc4248d 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -35,6 +35,7 @@ cc_library( srcs = ["cpu_transfer_manager.cc"], hdrs = ["cpu_transfer_manager.h"], deps = [ + ":cpu_runtime", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -45,7 +46,6 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", @@ -95,8 +95,10 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:map_inliner", + "//tensorflow/compiler/xla/service:tree_reduction_rewriter", "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:conditional_to_select", + "//tensorflow/compiler/xla/service:slow_operation_alarm", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla/service:slice_sinker", "//tensorflow/compiler/xla:cpu_function_runtime", @@ -1012,3 +1014,19 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], ) + +tf_cc_test( + name = "vectorized_reduce_with_no_vector_registers_test", + size = "small", + srcs = ["vectorized_reduce_with_no_vector_registers_test.cc"], + deps = [ + ":cpu_compiler", + ":cpu_transfer_manager", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@llvm//:core", + "@llvm//:support", + "@llvm//:target", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 9f8f74344af..e7371c79b39 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -99,8 +99,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/rng_expander.h" #include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/slice_sinker.h" +#include "tensorflow/compiler/xla/service/slow_operation_alarm.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" +#include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h" #include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -300,6 +302,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pass.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); + pass.AddPass(); pass.AddPass(); pass.AddPass( /*rewrite_training_op=*/true, @@ -606,6 +609,7 @@ StatusOr> CpuCompiler::RunBackend( VLOG(1) << "Compiling: " << module->name(); XLA_SCOPED_LOGGING_TIMER( absl::StrFormat("Compiling [%s] for CPU using JIT", module->name())); + auto slow_compile_alarm = SlowCompilationAlarm(); TF_RET_CHECK(stream_exec != nullptr); std::call_once(llvm_command_line_options_initialized, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 476579883f3..9b79e8ca8d7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -194,13 +194,13 @@ Status CpuExecutable::ExecuteComputeFunction( uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - { - tensorflow::mutex_lock lock(mutex_); + if (run_options->execution_profile()) { const double nanoseconds = (end_micros - start_micros) * 1000.0; - execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); + run_options->execution_profile()->set_compute_time_ns( + std::max(nanoseconds, 1.0)); // If hlo profiling was disabled then the cycle count is left empty. 
if (hlo_execution_profile) { - execution_profile_.set_compute_cycle_count( + run_options->execution_profile()->set_compute_cycle_count( hlo_execution_profile->total_cycles_executed( *module().entry_computation())); } @@ -268,29 +268,7 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( return std::move(result_buffer); } -StatusOr CpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - absl::Span arguments, - HloExecutionProfile* hlo_execution_profile) { - TF_ASSIGN_OR_RETURN( - auto result, - ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile)); - TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone()); - return std::move(result); -} - StatusOr CpuExecutable::ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - absl::Span arguments) { - if (hlo_profiling_enabled()) { - return Unimplemented( - "Asynchronous execution on stream with hlo profiling is not yet " - "supported on CPU."); - } - return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr); -} - -StatusOr CpuExecutable::ExecuteAsyncOnStreamImpl( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 169acdeffd4..37af630a2d9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -55,15 +55,11 @@ class CpuExecutable : public Executable { std::unique_ptr hlo_profile_index_map); ~CpuExecutable() override {} - StatusOr ExecuteOnStream( + StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - absl::Span arguments) override; - // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -86,16 +82,6 @@ class CpuExecutable : public Executable { const BufferAssignment& buffer_assignment() const { return *assignment_; } private: - // This is for sharing the code between ExecuteOnStream and - // ExecuteAsyncOnStream. - // - // Notice that it's tricky to use correctly, as the profile object (when it - // exists) must out-live the task. - StatusOr ExecuteAsyncOnStreamImpl( - const ServiceExecutableRunOptions* run_options, - absl::Span arguments, - HloExecutionProfile* hlo_execution_profile); - // Creates an array suitable for passing as the "buffer_table" argument to the // JIT compiled function pointer. 
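With ExecuteOnStream and ExecuteAsyncOnStreamImpl removed, CpuExecutable exposes only the asynchronous entry point, and synchronous execution becomes "launch, then block on the stream" at a higher layer. The same layering, sketched with std::future in place of streams (purely an analogy, not the Executable interface):

```cpp
#include <future>

// "Async" work product; stands in for a result tied to an in-flight stream.
using AsyncResult = std::future<int>;

AsyncResult ExecuteAsync() {
  return std::async(std::launch::async, [] { return 42; });  // enqueue work
}

// Synchronous execution is a thin wrapper: launch, then block until done,
// mirroring "ExecuteAsyncOnStream followed by BlockHostUntilDone".
int ExecuteSync() {
  AsyncResult result = ExecuteAsync();
  return result.get();  // blocks until the async work completes
}
```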
// diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 6620a9620b5..a6f960a5cb6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -40,10 +40,11 @@ bool CanBeLoopFused(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } -bool IsNonComplexMatrixVectorDot(const HloInstruction* hlo) { +bool IsNonComplexNonBatchedMatrixVectorDot(const HloInstruction* hlo) { const Shape& hlo_shape = hlo->shape(); return !ShapeUtil::ElementIsComplex(hlo_shape) && - hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() <= 1; + hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() <= 1 && + hlo->dot_dimension_numbers().lhs_batch_dimensions_size() == 0; } bool HasExactlyOneUse(const HloInstruction& hlo_instr) { @@ -54,7 +55,7 @@ bool HasExactlyOneUse(const HloInstruction& hlo_instr) { bool CanBeOutputFused(const HloInstruction* producer, const HloInstruction* consumer) { return consumer->opcode() == HloOpcode::kAdd && - IsNonComplexMatrixVectorDot(producer) && + IsNonComplexNonBatchedMatrixVectorDot(producer) && HasExactlyOneUse(*producer) == 1; } @@ -74,10 +75,13 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, constexpr int kFusionThresholdBytes = 16 * 1024; if (CanBeOutputFused(producer, consumer)) { + VLOG(2) << "Fusion OK: Can create output fusion."; return true; } if (CanBeOutputFusedIntoSomeOperand(producer)) { + VLOG(2) + << "Bailing because producer can be output-fused into some operand."; return false; } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 1509da6f7ec..f0d7461e5e7 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1027,10 +1027,13 @@ StatusOr IrEmitter::EmitElementalConvolution( PrimitiveType lhs_element_type = lhs->shape().element_type(); llvm::Type* lhs_llvm_type = llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_); + // Upcast the accumulator to F32 from F16 for increased precision. + llvm::Type* accumulator_type = + lhs_element_type == F16 ? 
b_.getFloatTy() : lhs_llvm_type; llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - lhs_llvm_type, "convolution_sum_address", &b_, + accumulator_type, "convolution_sum_address", &b_, MinimumAlignmentForPrimitiveType(lhs_element_type)); - llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type); + llvm::Value* constant_zero = llvm::Constant::getNullValue(accumulator_type); Store(constant_zero, sum_address); llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_); @@ -1139,11 +1142,11 @@ StatusOr IrEmitter::EmitElementalConvolution( TF_ASSIGN_OR_RETURN(llvm::Value* const kernel_value, kernel_generator(kernel_index)); llvm::Value* product = FMul(input_value, kernel_value); - llvm::Value* sum = FAdd(Load(sum_address), product); + llvm::Value* sum = FAdd(Load(sum_address), FPCast(product, accumulator_type)); Store(sum, sum_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return Load(sum_address); + return FPCast(Load(sum_address), lhs_llvm_type); } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { @@ -1736,6 +1739,16 @@ StatusOr IrEmitter::EmitVectorizedReduce( return false; } + int vector_register_size_in_elements = + target_machine_features_.vector_register_byte_size( + *compute_function_->function()) / + ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()); + if (vector_register_size_in_elements == 0) { + // Either we don't know the vector register width for the target or the + // vector register is smaller than the size of the primitive type. + return false; + } + int vectorization_factor_in_bytes = target_machine_features_.vectorization_factor_in_bytes(); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index d3e2e2bea95..19b0bb3f4dc 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -128,6 +128,30 @@ tf_cc_test( ], ) +tf_cc_test( + name = "tree_reduction_rewriter_test", + srcs = ["tree_reduction_rewriter_test.cc"], + deps = [ + ":cpu_codegen_test", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/tests:codegen_test_base", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "cpu_infeed_test", srcs = ["cpu_infeed_test.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/tests/tree_reduction_rewriter_test.cc b/tensorflow/compiler/xla/service/cpu/tests/tree_reduction_rewriter_test.cc new file mode 100644 index 00000000000..bcb7da0e6cf --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/tree_reduction_rewriter_test.cc @@ -0,0 +1,67 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/llvm_irgen_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace cpu { + +namespace { + +class TreeReductionRewriterTest : public CpuCodegenTest {}; + +TEST_F(TreeReductionRewriterTest, SimpleRewrite) { + const char* hlo_text = R"( +HloModule SimpleReduction + +add { + acc = f32[] parameter(1) + op = f32[] parameter(0) + ROOT out = f32[] add(acc, op) +} + +ENTRY main { + input = f32[1000] parameter(0) + zero = f32[] constant(0) + ROOT out = f32[] reduce(input, zero), dimensions={0}, to_apply=add +} + )"; + + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %main (input: f32[1000]) -> f32[] { +; CHECK-NEXT: %input = f32[1000]{0} parameter(0) +; CHECK-NEXT: %zero = f32[] constant(0) +; CHECK-NEXT: %reduce-window = f32[32]{0} reduce-window(%input, %zero) +; CHECK-NEXT: %reduce-window.1 = f32[1]{0} reduce-window(%reduce-window, %zero), window={size=32 stride=32}, to_apply=%add +; CHECK-NEXT: ROOT %bitcast = f32[] bitcast(%reduce-window.1) + )"); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc b/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc new file mode 100644 index 00000000000..2918c886f08 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc @@ -0,0 +1,106 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/IR/Function.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Target/TargetMachine.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { +class CodegenReduceOnArchWithNoVectorRegisters : public HloTestBase {}; + +StatusOr GetTargetVectorRegisterByteSize(std::string triple) { + // Unfortunately we need a lot of boilerplate to get to an + // llvm::TargetMachine. + + std::string error; + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget(triple, error); + if (target == nullptr) { + return InternalError("TargetRegistry::lookupTarget failed: %s", error); + } + + llvm::LLVMContext context; + std::unique_ptr function = + absl::WrapUnique(llvm::Function::Create( + llvm::FunctionType::get(llvm::Type::getVoidTy(context), {}), + llvm::GlobalValue::ExternalLinkage, "test")); + + std::unique_ptr target_machine = + absl::WrapUnique(target->createTargetMachine( + /*TT=*/triple, /*CPU=*/"", /*Features=*/"", llvm::TargetOptions{}, + /*RM=*/llvm::None)); + cpu::LLVMTargetMachineFeatures target_machine_features(target_machine.get()); + return target_machine_features.vector_register_byte_size(*function); +} + +TEST_F(CodegenReduceOnArchWithNoVectorRegisters, Test) { + absl::string_view text = R"( +HloModule Reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY main { + input = f32[1000,1000] parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[1000] reduce(input, constant), dimensions={0}, to_apply=add +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(text)); + cpu::CpuCompiler cpu_compiler; + auto module_group = absl::make_unique("group"); + module_group->push_back(std::move(hlo_module)); + + // Check that the GetTargetVectorRegisterByteSize is itself working. + TF_ASSERT_OK_AND_ASSIGN(unsigned vector_register_byte_size_for_x86_64, + GetTargetVectorRegisterByteSize("x86_64-pc-linux")); + ASSERT_EQ(vector_register_byte_size_for_x86_64, 16); + + std::string triple = "i686-none-android"; + + TF_ASSERT_OK_AND_ASSIGN(unsigned vector_register_byte_size, + GetTargetVectorRegisterByteSize(triple)); + + // This test is supposed to check whether the XLA CPU vectorized reduction + // codegen works correctly for architectures that do not have vector + // registers. So first ASSERT that `triple` is actually a target with no + // vector registers, as otherwise the test isn't actually testing anything + // interesting. 
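The new bail-out in EmitVectorizedReduce, which this test exercises end to end, is simple integer arithmetic: divide the target's vector register width in bytes by the element size and give up on the vectorized path when the quotient is zero (unknown register width, or a register narrower than one element). A sketch of just that guard:

```cpp
#include <cstdint>

// Returns how many elements of `element_size_bytes` fit in one vector
// register, or 0 when the vectorized reduction path should be skipped.
int64_t VectorRegisterSizeInElements(int64_t register_size_bytes,
                                     int64_t element_size_bytes) {
  if (element_size_bytes <= 0) return 0;
  return register_size_bytes / element_size_bytes;
}

// From the test: x86_64 reports 16-byte vector registers, so 16 / 4 == 4
// f32 elements per register, while the i686-none-android target reports 0
// bytes and therefore falls back to the scalar reduction path.
```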
+ + ASSERT_EQ(vector_register_byte_size, 0); + + cpu::CpuAotCompilationOptions aot_compilation_options( + /*triple=*/triple, /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"main", + cpu::CpuAotCompilationOptions::RelocationModel::BigPic); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector> aot_compilation_result, + cpu_compiler.CompileAheadOfTime(std::move(module_group), + aot_compilation_options)); + EXPECT_EQ(aot_compilation_result.size(), 1); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/depthwise_convolution_converter.cc b/tensorflow/compiler/xla/service/depthwise_convolution_converter.cc new file mode 100755 index 00000000000..37a1d1346a7 --- /dev/null +++ b/tensorflow/compiler/xla/service/depthwise_convolution_converter.cc @@ -0,0 +1,215 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/depthwise_convolution_converter.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +class ConvolutionVisitor : public DfsHloVisitorWithDefault { + public: + // Default visitor action is to do nothing and return OK. + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleConvolution(HloInstruction* convolution) override; + + Status HandleBackwardFilterBatchGroupConvolution(HloInstruction* convolution); + + // Runs the visitor on a computation. + static bool Run(HloComputation* computation, + std::function is_cost_viable); + + // Returns whether any convolution ops were rewritten. + const bool changed() const { return changed_; } + + ~ConvolutionVisitor() override = default; + + private: + explicit ConvolutionVisitor( + HloComputation* computation, + std::function is_cost_viable) + : computation_(computation), is_cost_viable_(is_cost_viable) {} + + // Current HloComputation instance the ConvolutionVisitor is traversing. + HloComputation* computation_; + + // Whether rewrite has occurred. 
+ bool changed_ = false; + + std::function is_cost_viable_; +}; + +bool ConvolutionVisitor::Run( + HloComputation* computation, + std::function is_cost_viable) { + ConvolutionVisitor visitor(computation, is_cost_viable); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; +} + +namespace { +Shape SwapInputOutputFeatureDims(const Shape& shape, int64 input_feature_dim, + int64 output_feature_dim) { + int64 num_dims = shape.dimensions_size(); + CHECK_GE(num_dims, 2); + Shape transformed_shape = shape; + transformed_shape.set_dimensions(input_feature_dim, + shape.dimensions(output_feature_dim)); + transformed_shape.set_dimensions(output_feature_dim, + shape.dimensions(input_feature_dim)); + return transformed_shape; +} +} // namespace + +// This function handles batch_group_counts which are relevant only for +// depthwise backprop filter convolutions. +Status ConvolutionVisitor::HandleBackwardFilterBatchGroupConvolution( + HloInstruction* convolution) { + auto dim_numbers = convolution->convolution_dimension_numbers(); + auto lhs = convolution->mutable_operand(0); + auto rhs = convolution->mutable_operand(1); + int64 batch_group_count = convolution->batch_group_count(); + + if (batch_group_count == 1) { + return Status::OK(); + } + + VLOG(2) << "Dealing with batch_group_count " << batch_group_count + << " for convolution " << convolution->ToString() << "\n"; + + int64 output_batch_dimension = dim_numbers.output_batch_dimension(); + int64 output_feature_dimension = dim_numbers.output_feature_dimension(); + + // When mapping depthwise conv backward filter to batch grouped convolution, + // tf2xla bridge needs to swap the output batch and feature dimension. Since + // we want to use grouped convolution APIs, this swap needs to be reverted. 
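Both the SwapInputOutputFeatureDims helper above and the dimension-number updates that follow are the same two-element exchange. As a standalone sketch on a plain dimension vector (illustration only; std::vector stands in for xla::Shape):

#include <cstdint>
#include <utility>
#include <vector>

// Exchange the extents of the two feature dimensions, leaving every other
// dimension untouched -- the same operation SwapInputOutputFeatureDims
// performs on a Shape.
std::vector<int64_t> SwapDims(std::vector<int64_t> dims, int input_feature_dim,
                              int output_feature_dim) {
  std::swap(dims[input_feature_dim], dims[output_feature_dim]);
  return dims;
}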
+ dim_numbers.set_output_batch_dimension(output_feature_dimension); + dim_numbers.set_output_feature_dimension(output_batch_dimension); + + if (!is_cost_viable_(convolution)) { + Shape transformed_filter_grad_shape = SwapInputOutputFeatureDims( + convolution->shape(), dim_numbers.output_batch_dimension(), + dim_numbers.output_feature_dimension()); + + int64 num_groups = convolution->batch_group_count(); + int64 input_batch_dimension = dim_numbers.input_batch_dimension(); + int64 input_batch = lhs->shape().dimensions(input_batch_dimension); + int64 input_feature_dimension = dim_numbers.input_feature_dimension(); + int64 input_feature = lhs->shape().dimensions(input_feature_dimension); + + CHECK_EQ(input_batch, num_groups) + << "Feature group count should be equal to number of input features " + "for depthwise convolution"; + + auto add = [&](std::unique_ptr inst) { + return computation_->AddInstruction(std::move(inst)); + }; + // Reshape batch_dim C -> [G, C/G] - Batch and feature dims have been + // swapped in tf2xla bridge + std::vector reshape_dims = lhs->shape().dimensions(); + reshape_dims[input_batch_dimension] = + reshape_dims[input_batch_dimension] / num_groups; + reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, + num_groups); + lhs = add(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), lhs)); + + // Transpose G to the axis before N, For eg: [G, C/G, H, W, N ] -> [C/G, H, + // W, G, N] + std::vector transpose_dims(lhs->shape().dimensions_size()); + std::iota(transpose_dims.begin(), transpose_dims.end(), 0); + transpose_dims.erase(transpose_dims.begin() + input_batch_dimension); + transpose_dims.insert(transpose_dims.begin() + input_feature_dimension, + input_batch_dimension); + std::vector transpose_reshape_dims = + ComposePermutations(lhs->shape().dimensions(), transpose_dims); + lhs = add(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(lhs->shape().element_type(), + transpose_reshape_dims), + lhs, transpose_dims)); + + // Merge [G,N] -> [N*G] + Shape new_shape = lhs->shape(); + new_shape.DeleteDimension(input_feature_dimension); + new_shape.set_dimensions(input_feature_dimension, + input_feature * num_groups); + lhs = add(HloInstruction::CreateReshape(new_shape, lhs)); + + std::vector new_operands = {lhs, rhs}; + auto new_conv = convolution->CloneWithNewOperands( + transformed_filter_grad_shape, new_operands); + new_conv->set_feature_group_count(num_groups); + new_conv->set_batch_group_count(1); + new_conv->set_convolution_dimension_numbers(dim_numbers); + auto new_convolution = computation_->AddInstruction(std::move(new_conv)); + + // Another reshape is required since the filter grad shape as a result of + // the 'new convolution` will be [kh, kw, C_i/G = 1, C_o = C_i = G ] but the + // expected shape is [kh, kw, C_i = G, DM=1] assuming the Depth-Multiplier + // (DM) is 1 and number of input features = G as required by the depthwise + // conv semantics + auto reshape = + HloInstruction::CreateReshape(convolution->shape(), new_convolution); + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + convolution, std::move(reshape))); + changed_ = true; + } + + return Status::OK(); +} + +Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { + return HandleBackwardFilterBatchGroupConvolution(convolution); +} + +} // namespace + +StatusOr DepthwiseConvolutionConverter::Run(HloModule* module) { + XLA_VLOG_LINES(2, "DepthwiseConvolutionConverter::Run(), before:\n" + + 
module->ToString()); + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + if (ConvolutionVisitor::Run(comp, is_cost_viable_)) { + changed = true; + } + } + XLA_VLOG_LINES( + 2, "DepthwiseConvolutionConverter::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/depthwise_convolution_converter.h b/tensorflow/compiler/xla/service/depthwise_convolution_converter.h new file mode 100755 index 00000000000..a71b2b0d45d --- /dev/null +++ b/tensorflow/compiler/xla/service/depthwise_convolution_converter.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DEPTHWISE_CONVOLUTION_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DEPTHWISE_CONVOLUTION_CONVERTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +class DepthwiseConvolutionConverter : public HloModulePass { + public: + explicit DepthwiseConvolutionConverter( + std::function is_cost_viable) + : is_cost_viable_(is_cost_viable) {} + + absl::string_view name() const override { + return "depthwise-convolution-converter"; + } + + // Run convolution rewriting on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; + + // Lambda containing cost model that decides whether to expand + // batch_group_count. + std::function is_cost_viable_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DEPTHWISE_CONVOLUTION_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/depthwise_convolution_converter_test.cc b/tensorflow/compiler/xla/service/depthwise_convolution_converter_test.cc new file mode 100755 index 00000000000..cbf748bd5c9 --- /dev/null +++ b/tensorflow/compiler/xla/service/depthwise_convolution_converter_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/depthwise_convolution_converter.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace { + +using DepthwiseConvolutionConverterTest = HloTestBase; + +TEST_F(DepthwiseConvolutionConverterTest, + ConvertBatchGroupCountToFeatureGroupCount) { + string hlo_string = R"(HloModule Convolve1D1Window_0_module + +ENTRY %Convolve1D1Window_0.v3 (input: f32[16,19,19,512]{3,2,1,0}, filter: f32[16,19,19,512]{3,2,1,0}) -> f32[3,3,512,1]{3,2,1,0} { + %input = f32[16,19,19,512]{3,2,1,0} parameter(0) + %filter = f32[16,19,19,512]{3,2,1,0} parameter(1) + ROOT %convolution = f32[3,3,512,1]{3,2,1,0} convolution(f32[16,19,19,512]{3,2,1,0} %input, f32[16,19,19,512]{3,2,1,0} %filter), window={size=19x19 pad=1_1x1_1}, dim_labels=f01b_i01o->01fb, batch_group_count=512 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto computation = module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + auto batch_group_count = root->batch_group_count(); + EXPECT_EQ(root->opcode(), HloOpcode::kConvolution); + auto conv_dim_num = root->convolution_dimension_numbers(); + int64 out_batch_dim = conv_dim_num.output_batch_dimension(); + int64 out_feature_dim = conv_dim_num.output_feature_dimension(); + auto cost_model = [](HloInstruction*) { return false; }; + DepthwiseConvolutionConverter converter(cost_model); + ASSERT_TRUE(converter.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + // Verify that the convolution is replaced by a reshape. 
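  // As a worked example of the shapes involved (restating the comments in
  // HandleBackwardFilterBatchGroupConvolution): for the module above,
  // G = batch_group_count = 512, kh = kw = 3, and depth multiplier DM = 1.
  // The original convolution produces the filter gradient directly as
  // f32[3,3,512,1] = [kh, kw, C_i = G, DM]. The rewritten convolution with
  // feature_group_count = 512 instead yields
  // f32[3,3,1,512] = [kh, kw, C_i/G = 1, C_o = G], which is why the pass
  // wraps it in a final reshape back to f32[3,3,512,1] -- the reshape this
  // test expects at the root.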
+ EXPECT_EQ(root->opcode(), HloOpcode::kReshape) + << HloOpcodeString(root->opcode()) << " vs Reshape"; + + // Verify that the operand to the reshape is the new convolution + // with feature_group_count = batch_group_count + auto new_conv = root->operand(0); + EXPECT_EQ(new_conv->opcode(), HloOpcode::kConvolution) + << HloOpcodeString(new_conv->opcode()) << " vs Convolution"; + EXPECT_EQ(new_conv->feature_group_count(), batch_group_count); + // Verify that the output_batch_dim and output_feature_dim + // have been swapped back (tf2xla swaps these dimensions to make use + // of batch_group convolution for computing filter grad for depthwise + // convolutions) + EXPECT_EQ(new_conv->convolution_dimension_numbers().output_batch_dimension(), + out_feature_dim); + EXPECT_EQ( + new_conv->convolution_dimension_numbers().output_feature_dimension(), + out_batch_dim); + + // Verify that the operand to conv is a reshape + auto reshape_1 = new_conv->operand(0); + EXPECT_EQ(reshape_1->opcode(), HloOpcode::kReshape) + << HloOpcodeString(reshape_1->opcode()) << " vs Reshape"; + + // Verify that the operand to reshape_1 is transpose + auto transpose = reshape_1->operand(0); + EXPECT_EQ(transpose->opcode(), HloOpcode::kTranspose) + << HloOpcodeString(transpose->opcode()) << " vs Transpose"; + + // Verify that the operand to transpose is reshape + auto reshape_2 = transpose->operand(0); + EXPECT_EQ(reshape_2->opcode(), HloOpcode::kReshape) + << HloOpcodeString(reshape_2->opcode()) << " vs Reshape"; +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 1341535aad4..94a99c77a5a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -300,7 +300,18 @@ class DfsHloVisitorBase { // Useful when we want to visit the same computation more than once with the // same visitor. - void ResetVisitStates() { visit_state_.clear(); } + void ResetVisitStates() { + // Clear the map, but don't resize the capacity across uses -- Calculating + // and reserving space could be expensive, and we always use the same + // module->instruction_count() as the capacity. + visit_state_.erase(visit_state_.begin(), visit_state_.end()); + } + + // Useful when we want to free up the memory used by the visit state without + // destroying the actual visitor subclass. 
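The distinction between the two helpers -- ResetVisitStates() above and the DestroyVisitState() added just below -- comes down to whether the absl::flat_hash_map keeps its backing storage. A small standalone sketch of that behaviour (illustration only):

#include <cstdio>

#include "absl/container/flat_hash_map.h"

int main() {
  absl::flat_hash_map<int, int> visit_state;
  for (int i = 0; i < 1024; ++i) visit_state[i] = i;

  // Like ResetVisitStates(): drop the entries but keep the allocated table,
  // so the next traversal reuses it without growing from scratch.
  visit_state.erase(visit_state.begin(), visit_state.end());
  std::printf("after erase: size=%zu capacity=%zu\n", visit_state.size(),
              visit_state.capacity());

  // Like DestroyVisitState(): swap in a freshly constructed map, releasing
  // the memory entirely.
  visit_state = absl::flat_hash_map<int, int>{};
  std::printf("after reassign: size=%zu capacity=%zu\n", visit_state.size(),
              visit_state.capacity());
}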
+ void DestroyVisitState() { + visit_state_ = absl::flat_hash_map{}; + } void SetVisitState(int id, VisitState state) { visit_state_[id] = state; } diff --git a/tensorflow/compiler/xla/service/dump.cc b/tensorflow/compiler/xla/service/dump.cc index 6a4837211e8..331c935bdc9 100644 --- a/tensorflow/compiler/xla/service/dump.cc +++ b/tensorflow/compiler/xla/service/dump.cc @@ -136,10 +136,6 @@ struct CanonicalDebugOptions { bool dump_snapshots; }; -string FilenameFor(const HloModule& module, string_view suffix) { - return StrFormat("module_%04d.%s", module.unique_id(), suffix); -} - void DumpToFileInDirImpl(string_view filename, string_view contents, const CanonicalDebugOptions& opts) { if (opts.dumping_to_stdout()) { @@ -263,6 +259,10 @@ static auto& module_id_to_step_number GUARDED_BY(mu) = } // namespace +string FilenameFor(const HloModule& module, string_view suffix) { + return StrFormat("module_%04d.%s", module.unique_id(), suffix); +} + void DumpToFileInDir(const HloModule& module, string_view suffix, string_view contents) { DumpToFileInDirImpl(FilenameFor(module, suffix), contents, diff --git a/tensorflow/compiler/xla/service/dump.h b/tensorflow/compiler/xla/service/dump.h index 6edc9b28dde..d245ad582c4 100644 --- a/tensorflow/compiler/xla/service/dump.h +++ b/tensorflow/compiler/xla/service/dump.h @@ -33,6 +33,9 @@ class BufferAssignment; class HloExecutionProfile; class HloSnapshot; +// Create the filename we will use to dump in DumpToFileInDir. +string FilenameFor(const HloModule& module, absl::string_view suffix); + // Writes the given string to a file in the xla_dump_to directory specified by // module's DebugOptions. // diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 3925eeb7f62..1f7d41c7b94 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -17,8 +17,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -53,6 +55,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status HandleReshape(HloInstruction* hlo) override; + Status HandleSort(HloInstruction* hlo) override; + Status HandlePad(HloInstruction* hlo) override; Status HandleBroadcast(HloInstruction* hlo) override; @@ -161,6 +165,29 @@ Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) { }); } +Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) { + return ForEachOperandDynamicDimension( + hlo, [&](HloInstruction* operand, ShapeIndex index, + int64 dynamic_dimension, int64 operand_index, + HloInstruction* dynamic_size, DimensionConstraint constraint) { + HloSortInstruction* sort = Cast(hlo); + int64 sort_dimension = sort->sort_dimension(); + if (sort_dimension == dynamic_dimension) { + return Unimplemented( + "Dynamic dimension on sorting dimension is not supported"); + } + if (sort->values_count() == 0) { + parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size, + constraint); + } else { + parent_->SetDynamicSize(hlo, {operand_index}, dynamic_dimension, + dynamic_size, constraint); + } + + return Status::OK(); + }); +} + Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index 5821e89612b..7a13307ffbf 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -912,6 +912,78 @@ TEST_F(DynamicDimensionInferenceTest, DynamicSliceTest) { EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 0), size_param); } +TEST_F(DynamicDimensionInferenceTest, SortTest) { + auto builder = HloComputation::Builder(TestName()); + + auto data_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto compare_builder = HloComputation::Builder("condition"); + compare_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "param1")); + compare_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "param2")); + compare_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* compare = + module_->AddEmbeddedComputation(compare_builder.Build()); + + auto* sort = builder.AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeShape(F32, {5, 7}), 1, {data_param}, compare, + /*is_stable=*/false)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. 
+ TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(sort, {}, 0), size_param); +} + +TEST_F(DynamicDimensionInferenceTest, MultiValueSortTest) { + auto builder = HloComputation::Builder(TestName()); + + auto shape = ShapeUtil::MakeShape(F32, {5, 7}); + + auto data_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "data_param")); + auto size_param = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "size_param")); + + auto compare_builder = HloComputation::Builder("condition"); + compare_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "param1")); + compare_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "param2")); + compare_builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(F32, {}), "param3")); + compare_builder.AddInstruction(HloInstruction::CreateParameter( + 3, ShapeUtil::MakeShape(F32, {}), "param4")); + compare_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloComputation* compare = + module_->AddEmbeddedComputation(compare_builder.Build()); + + auto* sort = builder.AddInstruction( + HloInstruction::CreateSort(ShapeUtil::MakeTupleShape({shape, shape}), 1, + {data_param, data_param}, compare, + /*is_stable=*/false)); + + module_->AddEntryComputation(builder.Build()); + // Set up dynamic parameter binding. + TF_CHECK_OK(module_->dynamic_parameter_binding().Bind( + DynamicParameterBinding::DynamicParameter{1, {}}, + DynamicParameterBinding::DynamicDimension{0, {}, 0})); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(sort, {0}, 0), size_param); + EXPECT_EQ(inference_->GetDynamicSize(sort, {1}, 0), size_param); +} + TEST_F(DynamicDimensionInferenceTest, DynamicSliceSingleElementTest) { // Slicing out a single element from a dynamic dimension terminates the // dynamic dimension. 
diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index 4eed3b8a560..5fea5d823de 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -90,6 +90,7 @@ StatusOr ChooseIdentityValue(HloInstruction* inst, case HloOpcode::kAllReduce: case HloOpcode::kBroadcast: case HloOpcode::kTranspose: + case HloOpcode::kSort: case HloOpcode::kSlice: return nullptr; default: diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 48559bf5fc3..63d7f3b1c0d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -515,15 +515,14 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( : input_type; switch (op->opcode()) { case HloOpcode::kLog: { - // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + // log(a+bi) = log(abs(a+bi)) + i*atan2(b,a) auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); - llvm::Type* llvm_ty = a->getType(); - auto sum_sq = FAdd(FMul(a, a), FMul(b, b)); - TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); - TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); - auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); + TF_ASSIGN_OR_RETURN(llvm::Value * angle, EmitAtan2(component_type, b, a)); + TF_ASSIGN_OR_RETURN(llvm::Value * abs, + EmitComplexAbs(component_type, operand_value)); + TF_ASSIGN_OR_RETURN(llvm::Value * log_abs, EmitLog(component_type, abs)); + return EmitComposeComplex(op, log_abs, angle); } case HloOpcode::kLog1p: { // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) @@ -639,32 +638,128 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + =(e^(2a)-e^(-2a) + + i*[cos(b)sin(b)(e^(2a)+2+e^(-2a))-cos(b)sin(b)(e^(2a)-2+e^(2a)))] + / (cos(b)^2*(e^(2a)+2+e^(-2a)) + sin(b)^2*(e^(2a)-2+e^(2a)) + =(e^(2a)-e^(-2a) + + i*cos(b)sin(b)*[e^(2a)+2+e^(-2a)-e^(2a)+2-e^(-2a)]) / + ([cos(b)^2 + sin(b)^2][e^(2a)+e^(-2a)])+2*[cos(b)^2 - sin(b)^2]) + =(e^(2a)-e^(-2a) + i*cos(b)sin(b)*4) / + (e^(2a)+e^(-2a)+2*[cos(b)^2 - sin(b)^2]) + =(e^(2a)-e^(-2a) + i*[sin(2b)/2]*4) / + (e^(2a)+e^(-2a)+2*[cos(2b)]) + =(e^(2a)-e^(-2a) + i*2*sin(2b)) / (e^(2a) + e^(-2a) + 2*cos(2b)) */ - auto a = EmitExtractReal(operand_value); - auto b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); - TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); - TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); - auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = - FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = FMul(cos_b, cos_b); - auto sin_b_sq = FMul(sin_b, sin_b); - auto real_num = FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - FMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = FMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a); - auto exp_a_plus_exp_neg_a_sq = - FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a); - auto exp_a_minus_exp_neg_a_sq = - FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = FMul( - 
cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); - auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, FDiv(real_num, denom), - FDiv(imag_num, denom)); + llvm::Value* a = EmitExtractReal(operand_value); + llvm::Value* b = EmitExtractImag(operand_value); + + llvm::Type* type = a->getType(); + + llvm::Value* neg_one = llvm::ConstantFP::get(type, -1.F); + llvm::Value* two_a = FAdd(a, a); + llvm::Value* neg_2a = FMul(neg_one, two_a); + + // When we are calculating the real numerator, e^(2a)-e^(-2a), for small + // values of `a`, we will get a ULP of 2^-23 using the exp function. Using + // expm1 to calculate e^(2a)-e^(-2a) = [e^(2a)-1] - [e^(-2a)-1] allows our + // ULP to be arbitrarily small. For larger values of `a`, calculating the + // numerator as Exp(2a)-Exp(-2a) vs Expm1(2a)-Expm1(-2a) return virtually + // identical results. + TF_ASSIGN_OR_RETURN(llvm::Value * exp_2a_m1, + EmitExpm1(component_type, two_a)); + TF_ASSIGN_OR_RETURN(llvm::Value * exp_neg_2a_m1, + EmitExpm1(component_type, neg_2a)); + llvm::Value* real_numerator = FSub(exp_2a_m1, exp_neg_2a_m1); + + // We can use the identity cos(2b)+1 = cos(b)^2-sin(b)^2+cos(b)^2+sin(b)^2 + // = 2cos(b)^2. This gives us the ability to be more precise when the + // denominator is close to zero. + TF_ASSIGN_OR_RETURN(llvm::Value * cos_b, EmitCos(component_type, b)); + llvm::Value* four = llvm::ConstantFP::get(type, 4.F); + llvm::Value* cos_b_sq = FMul(cos_b, cos_b); + llvm::Value* two_cos_2b_p2 = FMul(cos_b_sq, four); + + // Similarly we can compute sin(2b) with the formula sin(2b) = + // 2*sin(b)*cos(b). + TF_ASSIGN_OR_RETURN(llvm::Value * sin_b, EmitSin(component_type, b)); + llvm::Value* imag_numerator = FMul(four, FMul(cos_b, sin_b)); + + // Expm1(x) is about x for small values of x, but exp_sum_m2 is about x^2 + // for small value of x. As a result, due to floating point precission + // issues, x^2 is a better approximation than Expm1(x) + Expm1(x) for + // small values of x. + llvm::Value* a_sqr = FMul(a, a); + llvm::Value* use_approx_cutoff = llvm::ConstantFP::get(type, 1e-8); + llvm::Value* use_approx = FCmpOLT(a_sqr, use_approx_cutoff); + + llvm::Value* exp_sum_m2 = + Select(use_approx, a_sqr, FAdd(exp_2a_m1, exp_neg_2a_m1)); + llvm::Value* denom = FAdd(exp_sum_m2, two_cos_2b_p2); + + // As `a` grows toward +inf and -inf, the real numerator will grow towards + // +inf and -inf respectively, while the denominator will always grow + // towards +inf. The result is real_numerator/denom = NaN, when it should + // equal +1 and -1 respectively. Therefore, if our denominator is +inf, + // we just hardcode the limits for the real numbers. + llvm::Value* inf = llvm::ConstantFP::getInfinity(type); + llvm::Value* is_inf = FCmpOEQ(exp_sum_m2, inf); + llvm::Value* real_limit = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::copysign, {neg_one, a}, {type}, b_); + + llvm::Value* real = + Select(is_inf, real_limit, FDiv(real_numerator, denom)); + llvm::Value* imag = FDiv(imag_numerator, denom); + + // The complex tanh functions have a few corner cases: + // 1. (+0, +0) => (+0, +0) - Handled normally + // 2. (x, +Inf) => (NaN, NaN) - See below + // 3. (x, NaN) => (NaN, NaN) - See below + // 4. (+inf, y) => (1, +0) - Handled normally + // 5. (+Inf, +Inf) => (1, +/-0) - See below + // 6. (+Inf, NaN) => (1, +/-0) - See below + // 7. (NaN, +0) => (NaN, +0) - See below + // 8. (NaN, y) => (NaN, NaN) - Handled normally + // 9. 
(NaN, NaN) => (NaN, NaN) - Handled normally + // + // For the cases that aren't handled normally: + // 2/3) Part of the calculation we do is that if exp(a) + exp(-a) = +inf, + // then we return (+/-1, +/-0). However, this is only true if we + // assume that a is infinity or b is finite. In the event that both a + // is finite and b is either +/-Inf or NaN, then our normal + // calculation would end up returing (+/-1, NaN), as opposed to (NaN, + // NaN). + // 5/6) We always calculate the imagninary value as sin(2b)/denominator. + // When the denominator is infinity, this assures us that the zero is + // the correct sign. However if our imaginary input results in + // sin(2b) = NaN, we calculate our imaginary result as NaN. + // 7) In the event that a is NaN, the denominator will be NaN. + // Therefore, the normal calculation gives (NaN, NaN) while we need + // (NaN, +0). + if (!(b_->getFastMathFlags().noNaNs() && + b_->getFastMathFlags().noInfs())) { + llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, + {a}, {type}, b_); + llvm::Value* zero = llvm::ConstantFP::get(type, 0.F); + llvm::Value* nan = llvm::ConstantFP::getNaN(type); + + llvm::Value* a_is_inf = FCmpOEQ(abs_a, inf); + llvm::Value* b_is_zero = FCmpOEQ(b, zero); + + // imag_numerator = 2sin(2b), so sin(2b) is NaN if and only if + // imag_numerator is NaN. + llvm::Value* sin_2b_is_nan = + b_->CreateFCmpUNO(imag_numerator, imag_numerator); + + llvm::Value* real_is_nan = + b_->CreateAnd(sin_2b_is_nan, b_->CreateNot(a_is_inf)); + llvm::Value* imag_is_zero = + b_->CreateOr(b_is_zero, b_->CreateAnd(a_is_inf, sin_2b_is_nan)); + + real = Select(real_is_nan, nan, real); + imag = Select(imag_is_zero, zero, imag); + } + + return EmitComposeComplex(op, real, imag); } case HloOpcode::kAbs: { return EmitComplexAbs(component_type, operand_value); @@ -681,18 +776,10 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( FDiv(EmitExtractImag(operand_value), cplx_abs))); } case HloOpcode::kSqrt: { - auto a = EmitExtractReal(operand_value); - auto b = EmitExtractImag(operand_value); - auto c = llvm::ConstantFP::get(a->getType(), 0.5); - auto d = llvm::ConstantFP::get(b->getType(), 0.0); - return EmitComplexPower(op, a, b, c, d); + return EmitComplexSqrt(op, component_type, operand_value); } case HloOpcode::kRsqrt: { - auto a = EmitExtractReal(operand_value); - auto b = EmitExtractImag(operand_value); - auto c = llvm::ConstantFP::get(a->getType(), -0.5); - auto d = llvm::ConstantFP::get(b->getType(), 0.0); - return EmitComplexPower(op, a, b, c, d); + return EmitComplexRsqrt(op, component_type, operand_value); } case HloOpcode::kNegate: return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)), @@ -783,25 +870,209 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( // Using sqrt(a^2 + b^2) can cause overflow errors. Therefore we can use // sqrt(a^2 + b^2) = sqrt(a^2 * (1 + b^2/a^2)) // = |a| * sqrt(1 + (b/a)^2) -// With the assumption that |a| >= |b| +// With the assumption that |a| >= |b|. +// +// This method returns the min, max, and sqrt term for this calculation. This is +// done to prevent potential overflow errors that can occur from multiplying the +// max with the sqrt term. (i.e. when calculating the sqrt of the absolute +// value, we can take the sqrt of the max and the sqrt term before multiplying +// them together.) If return_sqrt is false, it returns 1 + (b/a)^2 instead of +// sqrt(1 + (b/a)^2). 
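The identity in use here, written out once: sqrt(a^2 + b^2) = |a| * sqrt(1 + (b/a)^2) when |a| >= |b|, so the only squared quantity is (b/a)^2 <= 1 and the intermediate never overflows even when a^2 + b^2 would. A scalar C++ sketch of the same computation (illustration only, not the IR emitter itself):

#include <algorithm>
#include <cmath>

float ComplexAbsNoOverflow(float a, float b) {
  const float abs_a = std::fabs(a);
  const float abs_b = std::fabs(b);
  const float max = std::max(abs_a, abs_b);
  const float min = std::min(abs_a, abs_b);
  if (max == 0.0f) return 0.0f;    // both components zero: avoid 0/0
  const float ratio = min / max;   // <= 1 by construction
  return max * std::sqrt(1.0f + ratio * ratio);
}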
+StatusOr> +ElementalIrEmitter::EmitComplexAbsHelper(PrimitiveType prim_type, + llvm::Value* operand_value, + bool return_sqrt) { + llvm::Value* real = EmitExtractReal(operand_value); + llvm::Value* imag = EmitExtractImag(operand_value); + llvm::Value* abs_real = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::fabs, {real}, {real->getType()}, b_); + llvm::Value* abs_imag = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::fabs, {imag}, {imag->getType()}, b_); + llvm::Value* max = EmitFloatMax(abs_real, abs_imag); + llvm::Value* min = EmitFloatMin(abs_real, abs_imag); + + llvm::Value* div = FDiv(min, max); + llvm::Value* div_sq = FMul(div, div); + llvm::Value* one = llvm::ConstantFP::get(max->getType(), 1); + llvm::Value* one_p_div_sq = FAdd(one, div_sq); + TF_ASSIGN_OR_RETURN(llvm::Value * sqrt, EmitSqrt(prim_type, one_p_div_sq)); + return std::make_tuple(min, max, return_sqrt ? sqrt : one_p_div_sq); +} + StatusOr ElementalIrEmitter::EmitComplexAbs( PrimitiveType prim_type, llvm::Value* operand_value) { - auto real = EmitExtractReal(operand_value); - auto imag = EmitExtractImag(operand_value); - auto abs_real = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {real}, - {real->getType()}, b_); - auto abs_imag = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {imag}, - {imag->getType()}, b_); - auto max = EmitFloatMax(abs_real, abs_imag); - auto min = EmitFloatMin(abs_real, abs_imag); + llvm::Value* min; + llvm::Value* max; + llvm::Value* sqrt; + TF_ASSIGN_OR_RETURN( + std::tie(min, max, sqrt), + EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true)); + llvm::Value* result = FMul(max, sqrt); + // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN. + // In such cases, we return `min` instead of `result`. + return Select(FCmpUNO(result, result), min, result); +} - auto div = FDiv(min, max); - auto div_sq = FMul(div, div); - auto one = llvm::ConstantFP::get(max->getType(), 1); - TF_ASSIGN_OR_RETURN(auto sqrt, EmitSqrt(prim_type, FAdd(one, div_sq))); +// Calculates ComplexAbs in the same way, except using: +// sqrt(|a| * sqrt(1 + (b/a)^2)) = sqrt(|a|) * pow(1 + (b/a)^2, .25) +StatusOr ElementalIrEmitter::EmitSqrtComplexAbs( + PrimitiveType prim_type, llvm::Value* operand_value) { + llvm::Value* min; + llvm::Value* max; + llvm::Value* one_p_div_sq; + TF_ASSIGN_OR_RETURN( + std::tie(min, max, one_p_div_sq), + EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/false)); + TF_ASSIGN_OR_RETURN(llvm::Value * sqrt_max, EmitSqrt(prim_type, max)); + TF_ASSIGN_OR_RETURN(llvm::Value * pow, + EmitPow(prim_type, one_p_div_sq, + llvm::ConstantFP::get(max->getType(), .25))); + llvm::Value* result = FMul(sqrt_max, pow); + // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN. + // In such cases, we return `min` instead of `result`. 
+ return Select(FCmpUNO(result, result), min, result); +} - auto zero = llvm::ConstantFP::get(max->getType(), 0); - return Select(FCmpOEQ(max, zero), zero, FMul(max, sqrt)); +// Calculates ComplexAbs in the same way, except using: +// rsqrt(|a| * sqrt(1 + (b/a)^2)) = rsqrt(|a|) * rsqrt(sqrt(1 + (b/a)^2)) +StatusOr ElementalIrEmitter::EmitRsqrtComplexAbs( + PrimitiveType prim_type, llvm::Value* operand_value) { + llvm::Value* min; + llvm::Value* max; + llvm::Value* sqrt; + TF_ASSIGN_OR_RETURN( + std::tie(min, max, sqrt), + EmitComplexAbsHelper(prim_type, operand_value, /*return_sqrt=*/true)); + TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_max, EmitRsqrt(prim_type, max)); + TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_sqrt, EmitRsqrt(prim_type, sqrt)); + llvm::Value* result = FMul(rsqrt_max, rsqrt_sqrt); + TF_ASSIGN_OR_RETURN(llvm::Value * rsqrt_min, EmitRsqrt(prim_type, min)); + // When (min, max) are (0, 0), (inf, inf), or (NaN, ...), `result` is NaN. + // In such cases, we return rsqrt(min) instead of `result`. + return Select(FCmpUNO(result, result), rsqrt_min, result); +} + +// Using our EmitComplexPower formula, but setting c=0.5 and d=0, we get: +// e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)] +// = e^[ln(r)*0.5] * [cos(t*0.5) + i*sin(t*0.5)] +// = r^0.5 * [cos(t/2) + i*sin(t/2)] +// = sqrt(r) * [cos(t/2) + i*sin(t/2)] +// where r = |a+bi| and t = atan2(b,a) +// TODO(bixia): See doc for implementation without atan2. +StatusOr ElementalIrEmitter::EmitComplexSqrt( + const HloInstruction* op, PrimitiveType prim_type, + llvm::Value* operand_value) { + llvm::Type* type = static_cast(operand_value->getType()) + ->getElementType(0); + + TF_ASSIGN_OR_RETURN(llvm::Value * r, + EmitSqrtComplexAbs(prim_type, operand_value)); + + llvm::Value* a = EmitExtractReal(operand_value); + llvm::Value* b = EmitExtractImag(operand_value); + TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a)); + + llvm::Value* c = llvm::ConstantFP::get(type, 0.5); + llvm::Value* angle = FMul(t, c); + TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle)); + TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle)); + + llvm::Value* real_part; + llvm::Value* imag_part; + + llvm::Value* zero = llvm::ConstantFP::get(type, 0); + + if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) { + llvm::Value* inf = llvm::ConstantFP::getInfinity(type); + llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true); + llvm::Value* nan = llvm::ConstantFP::getNaN(type); + llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, + {b}, {b->getType()}, b_); + + real_part = Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, inf)), inf, + Select(And(FCmpOEQ(a, neg_inf), FCmpONE(abs_b, inf)), + zero, FMul(r, cos))); + + llvm::Value* b_signed_inf = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::copysign, {inf, b}, {b->getType()}, b_); + imag_part = + Select(Or(FCmpOEQ(abs_b, inf), FCmpOEQ(a, neg_inf)), b_signed_inf, + Select(FCmpUNO(r, r), nan, + Select(FCmpOEQ(sin, zero), sin, FMul(r, sin)))); + } else { + real_part = FMul(r, cos); + imag_part = Select(FCmpOEQ(sin, zero), sin, FMul(r, sin)); + } + + return Select(FCmpOEQ(r, zero), EmitComposeComplex(op, zero, zero), + EmitComposeComplex(op, real_part, imag_part)); +} + +// Similar to Sqrt, we can use our EmitComplexPower formula, but set +// c=-0.5 and d=0. 
We get: +// e^[ln(r)*c - t*d] * [cos(ln(r)*d + t*c) + i*sin(ln(r)*d + t*c)] +// = e^[ln(r)*-0.5] * [cos(t*-0.5) + i*sin(t*-0.5)] +// = r^(-0.5) * [cos(-t/2) + i*sin(-t/2)] +// = rsqrt(r) * [cos(-t/2) + i*sin(-t/2)] +// where r = |a+bi| and t = atan2(b,a). +StatusOr ElementalIrEmitter::EmitComplexRsqrt( + const HloInstruction* op, PrimitiveType prim_type, + llvm::Value* operand_value) { + llvm::Type* type = static_cast(operand_value->getType()) + ->getElementType(0); + + TF_ASSIGN_OR_RETURN(llvm::Value * r, + EmitRsqrtComplexAbs(prim_type, operand_value)); + + llvm::Value* a = EmitExtractReal(operand_value); + llvm::Value* b = EmitExtractImag(operand_value); + TF_ASSIGN_OR_RETURN(llvm::Value * t, EmitAtan2(prim_type, b, a)); + + llvm::Value* c = llvm::ConstantFP::get(type, -0.5); + llvm::Value* angle = FMul(t, c); + TF_ASSIGN_OR_RETURN(llvm::Value * cos, EmitCos(prim_type, angle)); + TF_ASSIGN_OR_RETURN(llvm::Value * sin, EmitSin(prim_type, angle)); + + llvm::Value* real_part = FMul(r, cos); + llvm::Value* imag_part = FMul(r, sin); + + if (!(b_->getFastMathFlags().noNaNs() && b_->getFastMathFlags().noInfs())) { + llvm::Value* zero = llvm::ConstantFP::get(type, 0); + llvm::Value* neg_one = llvm::ConstantFP::get(type, -1); + llvm::Value* inf = llvm::ConstantFP::getInfinity(type); + llvm::Value* nan = llvm::ConstantFP::getNaN(type); + // llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true); + llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_); + llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::copysign, {zero, b}, {b->getType()}, b_); + llvm::Value* neg_b_signed_zero = FMul(b_signed_zero, neg_one); + + llvm::Value* abs_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, + {a}, {a->getType()}, b_); + llvm::Value* abs_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, + {b}, {b->getType()}, b_); + + llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero)); + real_part = Select( + is_zero_zero, inf, + Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)), + a_signed_zero, FMul(r, cos))); + imag_part = Select( + is_zero_zero, nan, + Select(Or(And(FCmpOEQ(abs_b, inf), FCmpUNO(a, a)), FCmpOEQ(abs_a, inf)), + neg_b_signed_zero, FMul(r, sin))); + } else { + llvm::Value* zero = llvm::ConstantFP::get(type, 0); + llvm::Value* inf = llvm::ConstantFP::getInfinity(type); + llvm::Value* nan = llvm::ConstantFP::getNaN(type); + + llvm::Value* is_zero_zero = And(FCmpOEQ(b, zero), FCmpOEQ(a, zero)); + real_part = Select(is_zero_zero, inf, FMul(r, cos)); + imag_part = Select(is_zero_zero, nan, FMul(r, sin)); + } + + return EmitComposeComplex(op, real_part, imag_part); } // (a+bi)^(c+di) = @@ -1051,7 +1322,7 @@ StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, return Select(x_is_small, for_small_x, for_large_x); } -StatusOr ElementalIrEmitter::EmitSqrt(PrimitiveType prim_type, +StatusOr ElementalIrEmitter::EmitSqrt(PrimitiveType, llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {value}, {value->getType()}, b_); @@ -1097,7 +1368,10 @@ StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, auto x_squared = FMul(x, x); auto x_squared_over_two = FMul(x_squared, half); auto for_small_x = FAdd(x, x_squared_over_two); - const auto kExponentIsSmallThreshold = 1e-5; + // At this point, the relative errors due to floating point precision loss of + // calculating exp(x) - 1 and the polynomial exp(x)-1 = x + x^2/2 are about + 
// equal, with a value of approximetely 2^-16. + const auto kExponentIsSmallThreshold = 0.009; auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); auto x_is_small = @@ -1433,7 +1707,8 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( source_index_phis[operand_id] = PHI(source_index.GetType(), operand_usage_count[operand_id]); std::vector operand_multi_index = source_index.multidim(); - operand_multi_index[concat_dim] = source_index_phis[operand_id]; + operand_multi_index[concat_dim] = + NSWSub(operand_multi_index[concat_dim], source_index_phis[operand_id]); // Create the terminator of the block before calling operand generators, // because they require non-degenerate basic blocks. @@ -1447,25 +1722,24 @@ StatusOr ElementalIrEmitter::EmitElementalConcatenate( b_->SetInsertPoint(init_block, saved_insert_point); } - std::vector source_multi_index = source_index.multidim(); + int64 concat_dim_size = 0; for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { const HloInstruction* operand = hlo->operand(operand_idx); auto false_block = llvm_ir::CreateBasicBlock( exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_); - auto concat_dim_size = source_index.GetConstantWithIndexType( - operand->shape().dimensions(concat_dim)); int64 operand_id = to_unique_operand_id[operand]; - source_index_phis[operand_id]->addIncoming(source_multi_index[concat_dim], - b_->GetInsertBlock()); - CondBr(ICmpULT(source_multi_index[concat_dim], concat_dim_size), + source_index_phis[operand_id]->addIncoming( + source_index.GetConstantWithIndexType(concat_dim_size), + b_->GetInsertBlock()); + concat_dim_size += operand->shape().dimensions(concat_dim); + CondBr(ICmpULT(source_index[concat_dim], + source_index.GetConstantWithIndexType(concat_dim_size)), emit_operand_blocks[operand_id], false_block); // Subtract the size of the concat dimension of the current operand // from the source index. 
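The rewritten concatenate indexing above amounts to a running prefix sum over operand sizes: a position i along the concat dimension belongs to the first operand whose cumulative size exceeds i, and the PHI node carries the cumulative offset that is subtracted off. A plain-C++ sketch of that mapping (illustration only):

#include <cstdint>
#include <utility>
#include <vector>

// Returns {operand index, index within that operand} for position `i` along
// the concatenated dimension, or {-1, -1} if `i` is out of range.
std::pair<int, int64_t> MapConcatIndex(
    int64_t i, const std::vector<int64_t>& operand_sizes) {
  int64_t cumulative = 0;
  for (int operand = 0; operand < static_cast<int>(operand_sizes.size());
       ++operand) {
    const int64_t offset = cumulative;  // what the PHI node feeds in
    cumulative += operand_sizes[operand];
    if (i < cumulative) {
      return {operand, i - offset};
    }
  }
  return {-1, -1};
}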
b_->SetInsertPoint(false_block); - source_multi_index[concat_dim] = - Sub(source_multi_index[concat_dim], concat_dim_size); } Unreachable(); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 3ba669c5365..99833a5525f 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -143,9 +143,26 @@ class ElementalIrEmitter : public IrBuilderMixin { virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x); + virtual StatusOr> + EmitComplexAbsHelper(PrimitiveType prim_type, llvm::Value* operand_value, + bool return_sqrt); + virtual StatusOr EmitComplexAbs(PrimitiveType prim_type, llvm::Value* operand_value); + virtual StatusOr EmitSqrtComplexAbs(PrimitiveType prim_type, + llvm::Value* operand_value); + virtual StatusOr EmitRsqrtComplexAbs( + PrimitiveType prim_type, llvm::Value* operand_value); + + virtual StatusOr EmitComplexSqrt(const HloInstruction* op, + PrimitiveType prim_type, + llvm::Value* operand_value); + + virtual StatusOr EmitComplexRsqrt(const HloInstruction* op, + PrimitiveType prim_type, + llvm::Value* operand_value); + virtual llvm::Value* EmitExtractReal(llvm::Value* value); virtual llvm::Value* EmitExtractImag(llvm::Value* value); diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 7b60c983b30..c45ecc7c2c4 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -26,9 +26,42 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/stream_executor/device_description.h" namespace xla { +StatusOr Executable::ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + absl::Span arguments, + HloExecutionProfile* hlo_execution_profile) { + StatusOr result = + ExecuteAsyncOnStream(run_options, arguments, hlo_execution_profile); + Status blocking_status = run_options->stream()->BlockHostUntilDone(); + TF_RETURN_IF_ERROR(result.status()); + TF_RETURN_IF_ERROR(blocking_status); + return result; +} + +StatusOr Executable::ExecuteOnStream( + const ServiceExecutableRunOptions* run_options, + std::vector> arguments, + HloExecutionProfile* hlo_execution_profile) { + StatusOr result = ExecuteAsyncOnStream( + run_options, std::move(arguments), hlo_execution_profile); + Status blocking_status = run_options->stream()->BlockHostUntilDone(); + TF_RETURN_IF_ERROR(result.status()); + TF_RETURN_IF_ERROR(blocking_status); + return result; +} + +StatusOr Executable::ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* /*run_options*/, + std::vector> /*arguments*/, + HloExecutionProfile* /*hlo_execution_profile*/) { + return Unimplemented( + "MaybeOwningDeviceMemory version of overload is not implemented "); +} + StatusOr> Executable::ExecuteOnStreams( absl::Span run_options, absl::Span> arguments) { @@ -49,8 +82,9 @@ StatusOr> Executable::ExecuteOnStreams( // We cannot BlockHostUntilDone() on the already-launched executions in case // of error, since if the executions communicate, the initially launched // executions may never complete if not all executions are running. 
- TF_ASSIGN_OR_RETURN(auto rv, - ExecuteAsyncOnStream(&run_options[i], arguments[i])); + TF_ASSIGN_OR_RETURN( + auto rv, ExecuteAsyncOnStream(&run_options[i], arguments[i], + /*hlo_execution_profile=*/nullptr)); return_values.push_back(std::move(rv)); } for (const auto& options : run_options) { @@ -61,27 +95,39 @@ StatusOr> Executable::ExecuteOnStreams( } StatusOr Executable::ExecuteOnStreamWrapper( - const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, + const ServiceExecutableRunOptions* run_options, + absl::Span arguments) { + StatusOr result = + ExecuteAsyncOnStreamWrapper(run_options, arguments); + Status block_status = run_options->stream()->BlockHostUntilDone(); + TF_RETURN_IF_ERROR(result.status()); + TF_RETURN_IF_ERROR(block_status); + return result; +} + +StatusOr Executable::ExecuteAsyncOnStreamWrapper( + const ServiceExecutableRunOptions* run_options, absl::Span arguments) { se::Stream* stream = run_options->stream(); - std::unique_ptr timer; + std::shared_ptr timer; + ExecutionProfile* profile = run_options->run_options().execution_profile(); if (profile != nullptr) { - timer.reset(new se::Timer(stream->parent())); + timer = std::make_shared(stream->parent()); stream->InitTimer(timer.get()).ThenStartTimer(timer.get()); } VLOG(1) << "enqueueing executable on stream..."; // If the profiling flag isn't enabled, we pass nullptr as the profile to // indicate profiling is not requested. - std::unique_ptr profile_ptr = + std::shared_ptr profile_ptr = module_config().debug_options().xla_hlo_profile() && hlo_profiling_enabled() - ? absl::make_unique(&hlo_profile_printer_data(), - &hlo_profile_index_map()) + ? std::make_shared(&hlo_profile_printer_data(), + &hlo_profile_index_map()) : nullptr; StatusOr return_value = - ExecuteOnStream(run_options, arguments, profile_ptr.get()); + ExecuteAsyncOnStream(run_options, arguments, profile_ptr.get()); if (!return_value.status().ok()) { if (profile != nullptr) { // Ensure the ThenStartTimer call has completed before we destroy timer. @@ -96,30 +142,19 @@ StatusOr Executable::ExecuteOnStreamWrapper( } if (profile != nullptr) { - VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; + VLOG(1) << "enqueueing 'stop timer' and profiling callback..."; stream->ThenStopTimer(timer.get()); - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - VLOG(1) << "done with block-host-until-done"; + // We block instead of using an async callback because reading the timer + // value may call back into the driver on GPU, which is not allowed. + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + + const int64 executable_size_in_bytes = SizeOfGeneratedCodeInBytes(); // Merge in run-time profile information from execution_profile. - // - // TODO(b/71713097): This is buggy -- even though the mutex takes care of - // C++ level races, some other concurrent ExecuteOnStreamWrapper call could - // have rewritten the execution_profile before we get to it. - profile->MergeFrom(execution_profile()); // Overall execution time (in nanoseconds) from the executor timer. - if (stream->ok()) { - // Don't read timer->Nanoseconds() if the stream isn't OK -- that's - // illegal. - profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); - } + profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); - // TODO(b/28123297): On GPU we end up including transfer time in - // the compute time this way. Instead, we should get the correct - // value by measuring it. 
Setting the field here at least lets - // benchmarks provide *some* value for GPU computations. - // // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually // the compute time without the transfer time, so this way we get the // correct compute time. We should instead have the correct value for @@ -128,21 +163,23 @@ StatusOr Executable::ExecuteOnStreamWrapper( profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); } - const int64 executable_size_in_bytes = SizeInBytes(); if (executable_size_in_bytes != 0) { profile->set_executable_size_in_bytes(executable_size_in_bytes); } } if (profile_ptr != nullptr) { - XLA_LOG_LINES( - tensorflow::INFO, - profile_ptr->ToString(stream->parent()->GetDeviceDescription())); + const se::DeviceDescription* device_description = + &stream->parent()->GetDeviceDescription(); + stream->ThenDoHostCallback([profile_ptr, device_description]() { + XLA_LOG_LINES(tensorflow::INFO, + profile_ptr->ToString(*device_description)); + }); } return return_value; } -int64 Executable::SizeInBytes() { return -1; } +int64 Executable::SizeOfGeneratedCodeInBytes() { return -1; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 492ea72228d..223832271ec 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -123,16 +123,10 @@ class Executable { // enabled. // // Returns a shaped buffer containing the result of the computation. - virtual StatusOr ExecuteOnStream( + StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, - HloExecutionProfile* hlo_execution_profile) = 0; - - // Same as ExecuteOnStream(), but this call is non-blocking and returns as - // soon as all of the operations are enqueued for launch on the stream. - virtual StatusOr ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - absl::Span arguments) = 0; + HloExecutionProfile* hlo_execution_profile); // Starts the given program executing on the given stream/executor. // @@ -143,20 +137,31 @@ class Executable { // // If an input is donated to XLA but is not reused as output, it is returned // as an leftover buffer for the caller to release. - virtual StatusOr ExecuteOnStream( + // + // This call should be non-blocking and may return as soon as all of the + // operations are enqueued for launch on the stream. Note that some + // implementations may in fact block or may block in some circumstances (e.g., + // when profiling); i.e., asynchronous is a "may" not a "must". + // + // If the hlo_execution_profile is provided as non-nullptr, profiling will be + // enabled. Note that profiling is tricky to use correctly, as the profiling + // objects (when they exist) must out-live the task. + virtual StatusOr ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + absl::Span arguments, + HloExecutionProfile* hlo_execution_profile) = 0; + + // Same as ExecuteAsyncOnStream(), but blocks waiting for the computation to + // complete. 
+ StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, std::vector> arguments, - HloExecutionProfile* hlo_execution_profile) { - return Unimplemented( - "MaybeOwningDeviceMemory version of overload is not implemented "); - } + HloExecutionProfile* hlo_execution_profile); virtual StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - std::vector> arguments) { - return Unimplemented( - "MaybeOwningDeviceMemory version of overload is not implemented "); - } + std::vector> arguments, + HloExecutionProfile* hlo_execution_profile); // Same as ExecuteOnStream(), but runs this executable on multiple // streams. arguments[i] contains the arguments to the execution on @@ -171,6 +176,7 @@ class Executable { // called explicitly for other (async, for example) variants after the stream // has completed. virtual Status PopulateExecutionProfile( + ExecutionProfile* execution_profile, HloExecutionProfile* hlo_execution_profile, se::Stream* stream) { return Status::OK(); } @@ -179,15 +185,12 @@ class Executable { // timer for the execution, sets up HLO profiling if enabled, and fills in the // given ExecutionProfile if non-null. StatusOr ExecuteOnStreamWrapper( - const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, + const ServiceExecutableRunOptions* run_options, absl::Span arguments); - // Returns the ExecutionProfile from executing on the device. This includes - // the number of cycles taken for the computation or the compilation time. - ExecutionProfile execution_profile() const { - tensorflow::mutex_lock lock(mutex_); - return execution_profile_; - } + StatusOr ExecuteAsyncOnStreamWrapper( + const ServiceExecutableRunOptions* run_options, + absl::Span arguments); const HloProfilePrinterData& hlo_profile_printer_data() const { CHECK(hlo_profiling_enabled()); @@ -219,30 +222,27 @@ class Executable { return hlo_module_->config().entry_computation_layout().result_shape(); } - // Returns the size of the executable in bytes. Returns -1 by default if the - // method is not overridden to support this kind of query. - virtual int64 SizeInBytes(); + // Returns the size of the executable in bytes. Returns -1 if this query is + // not supported by the executable. + // + // Does not include the size of used libraries (e.g. cuDNN, Eigen, etc.). + virtual int64 SizeOfGeneratedCodeInBytes(); // Dumping helpers. - void set_hlo_snapshot(std::unique_ptr hlo_snapshot) { - hlo_snapshot_ = std::move(hlo_snapshot); + void set_hlo_proto(std::unique_ptr hlo_proto) { + hlo_proto_ = std::move(hlo_proto); } - bool dumping_snapshot() const { return hlo_snapshot_ != nullptr; } - HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); } + bool dumping_snapshot() const { return hlo_proto_ != nullptr; } + HloProto const* hlo_proto() const { return hlo_proto_.get(); } protected: - mutable tensorflow::mutex mutex_; - - // Execution profile data on the device. - ExecutionProfile execution_profile_ GUARDED_BY(mutex_); - // HloModule this was compiled from. BufferAssignment keeps pointers to // HloInstructions owned by the HloModule so we need to keep the HloModule // around. const std::shared_ptr hlo_module_; - // HloSnapshot this was compiled from. Null if not dumping executions. - std::unique_ptr hlo_snapshot_; + // The serialized HLO proto. Non-null only if dumping snapshots is enabled. + std::unique_ptr hlo_proto_; // Execution count, used to generate a unique filename for each dumped // execution. 
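The synchronous ExecuteOnStream entry points in this change all reduce to the same wrapper: enqueue via the virtual ExecuteAsyncOnStream, block on the stream, and surface the enqueue failure ahead of the blocking failure. A simplified standalone sketch of that pattern (hypothetical Status and functor types, illustration only):

#include <functional>
#include <string>

// Minimal stand-in for a status type, for illustration.
struct Status {
  bool ok = true;
  std::string message;
};

// `enqueue` launches work asynchronously; `block_host_until_done` waits for
// the stream to drain. The stream is always drained, even if the launch
// failed, and the launch error takes precedence -- mirroring the behaviour
// of the synchronous wrappers in this change.
Status RunSynchronously(const std::function<Status()>& enqueue,
                        const std::function<Status()>& block_host_until_done) {
  const Status launch_status = enqueue();
  const Status blocking_status = block_host_until_done();
  if (!launch_status.ok) return launch_status;
  return blocking_status;
}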
diff --git a/tensorflow/compiler/xla/service/fusion_queue.h b/tensorflow/compiler/xla/service/fusion_queue.h index 4ddb96c5539..3eec47ee205 100644 --- a/tensorflow/compiler/xla/service/fusion_queue.h +++ b/tensorflow/compiler/xla/service/fusion_queue.h @@ -15,8 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ -#include +#include +#include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -25,15 +27,11 @@ namespace xla { using FusionConfig = std::vector>; // Converts fusion config to string format. -static string FusionConfigToString(const FusionConfig& config) { - string s = ""; - for (auto& edge_list : config) { - for (auto edge : edge_list) { - if (edge) { - s += "1"; - } else { - s += "0"; - } +static std::string FusionConfigToString(const FusionConfig& config) { + std::string s; + for (const auto& edge_list : config) { + for (bool edge : edge_list) { + absl::StrAppend(&s, edge ? "1" : "0"); } } return s; diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 2eae159861c..d65083d701a 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -53,7 +53,7 @@ Status GenericTransferManager::WriteSingleTupleIndexTable( TF_RETURN_IF_ERROR(TransferBufferToDevice( stream, GetByteSizeRequirement(shape), element_pointers->data(), region)); // Ensure the buffer is transferred before we destroy element_pointers. - stream->ThenDoHostCallback([element_pointers]() { + stream->ThenRunAfterNextBlockHostUntilDone([element_pointers]() { /* holds reference to element_pointers in closure */ }); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD old mode 100644 new mode 100755 index a5fc6e80cec..053c3051aea --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -3,12 +3,24 @@ load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "if_static", + "//tensorflow/core/platform:default/build_config_root.bzl", "tf_cuda_tests_tags", ) -load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_copts", "tf_cuda_library") +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_copts", + "tf_cuda_library", +) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load( + "//tensorflow/core/platform:default/cuda_build_defs.bzl", + "if_cuda_is_configured", +) +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", +) package( default_visibility = [":friends"], @@ -186,6 +198,26 @@ cc_library( ], ) +cc_library( + name = "thunk_emitter", + srcs = ["thunk_emitter.cc"], + hdrs = ["thunk_emitter.h"], + deps = [ + ":backend_configs", + ":buffer_allocations", + ":gpu_constants", + ":gpu_executable", + ":ir_emission_utils", + ":nccl_all_reduce_thunk", + ":thunk", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:custom_call_target_registry", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + ], +) + cc_library( name = "ir_emitter", srcs = [ @@ -213,6 +245,7 @@ cc_library( ":partition_assignment", ":target_util", ":thunk", + ":thunk_emitter", 
"//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -222,7 +255,6 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:buffer_assignment", - "//tensorflow/compiler/xla/service:custom_call_target_registry", "//tensorflow/compiler/xla/service:elemental_ir_emitter", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", @@ -260,6 +292,7 @@ cc_library( hdrs = ["parallel_loop_emitter.h"], deps = [ ":partition_assignment", + ":target_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", @@ -410,7 +443,6 @@ tf_cc_test( cc_library( name = "gpu_executable", srcs = [ - "cholesky_thunk.cc", "collective_permute_thunk.cc", "conditional_thunk.cc", "convolution_thunk.cc", @@ -431,9 +463,10 @@ cc_library( "triangular_solve_thunk.cc", "tuple_thunk.cc", "while_thunk.cc", - ], + ] + if_cuda_is_configured([ + "cholesky_thunk.cc", + ]), hdrs = [ - "cholesky_thunk.h", "collective_permute_thunk.h", "conditional_thunk.h", "convolution_thunk.h", @@ -454,12 +487,13 @@ cc_library( "triangular_solve_thunk.h", "tuple_thunk.h", "while_thunk.h", - ], + ] + if_cuda_is_configured([ + "cholesky_thunk.h", + ]), deps = [ ":backend_configs", ":buffer_allocations", ":cudnn_conv_runner", - ":cusolver_context", ":gpu_debug_info_manager", ":gpu_types", ":hlo_execution_profiler", @@ -495,17 +529,12 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", - "//tensorflow/core/platform/default/build_config:cublas_plugin", - "//tensorflow/core/platform/default/build_config:cudnn_plugin", - "//tensorflow/core/platform/default/build_config:cufft_plugin", - "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor", "//tensorflow/stream_executor:blas", "//tensorflow/stream_executor:device_memory", "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:kernel", - "//tensorflow/stream_executor/cuda:cuda_stream", "//tensorflow/stream_executor/gpu:gpu_stream", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -516,8 +545,18 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", + ] + if_cuda_is_configured([ + ":cusolver_context", + "//tensorflow/stream_executor/cuda:cuda_stream", + "//tensorflow/core/platform/default/build_config:cublas_plugin", + "//tensorflow/core/platform/default/build_config:cudnn_plugin", + "//tensorflow/core/platform/default/build_config:cufft_plugin", + "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "@local_config_cuda//cuda:cuda_headers", - ], + ]) + if_rocm_is_configured([ + "//tensorflow/core/platform/default/build_config:stream_executor_rocm", + "@local_config_rocm//rocm:rocm_headers", + ]), ) cc_library( @@ -596,6 +635,7 @@ cc_library( ":cudnn_conv_runner", ":gpu_autotuning_proto", ":gpu_executable", + ":hlo_algorithm_blacklist", ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:literal_util", @@ -620,18 +660,6 @@ cc_library( ], ) -cc_library( - name = "scratch_allocator", - srcs = ["scratch_allocator.cc"], - hdrs = ["scratch_allocator.h"], - 
deps = [ - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:stream_executor_no_cuda", - "//tensorflow/stream_executor:device_memory_allocator", - ], -) - cc_library( name = "cudnn_conv_runner", srcs = ["cudnn_conv_runner.cc"], @@ -703,10 +731,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor:blas", - ] + if_static( - ["@local_config_cuda//cuda:cusolver"], - ["//tensorflow/stream_executor/cuda:cusolver_stub"], - ), + "//tensorflow/stream_executor/cuda:cusolver_lib", + ], ) cc_library( @@ -939,6 +965,38 @@ tf_cc_test( ], ) +cc_library( + name = "cublas_gemm_pad_for_tensor_cores", + srcs = ["cublas_gemm_pad_for_tensor_cores.cc"], + hdrs = ["cublas_gemm_pad_for_tensor_cores.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core/platform:types", + ], +) + +tf_cc_test( + name = "cublas_gemm_pad_for_tensor_cores_test", + srcs = ["cublas_gemm_pad_for_tensor_cores_test.cc"], + deps = [ + ":cublas_gemm_pad_for_tensor_cores", + ":ir_emission_utils", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep + ], +) + cc_library( name = "target_constants", hdrs = ["target_constants.h"], @@ -972,20 +1030,19 @@ cc_library( ) cc_library( - name = "nvptx_compiler_impl", - srcs = ["nvptx_compiler.cc"], - hdrs = ["nvptx_compiler.h"], + name = "gpu_compiler", + srcs = [ + "gpu_compiler.cc", + ], + hdrs = [ + "gpu_compiler.h", + ], deps = [ ":cudnn_batchnorm_rewriter", ":cudnn_conv_algorithm_picker", - ":cudnn_conv_pad_for_tensor_cores", ":cudnn_conv_padding_legalization", ":cudnn_conv_rewriter", - ":cudnn_fused_conv_rewriter", - ":cusolver_rewriter", ":fusion_merger", - ":gemm_algorithm_picker", - ":gemm_rewriter", ":gpu_constants", ":gpu_copy_insertion", ":gpu_executable", @@ -1013,7 +1070,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:call_inliner", "//tensorflow/compiler/xla/service:conditional_simplifier", - "//tensorflow/compiler/xla/service:convolution_group_converter", + "//tensorflow/compiler/xla/service:depthwise_convolution_converter", "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:dump", "//tensorflow/compiler/xla/service:dynamic_index_splitter", @@ -1038,6 +1095,7 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:rng_expander", "//tensorflow/compiler/xla/service:slice_sinker", + "//tensorflow/compiler/xla/service:slow_operation_alarm", "//tensorflow/compiler/xla/service:sort_simplifier", "//tensorflow/compiler/xla/service:stable_sort_expander", "//tensorflow/compiler/xla/service:transpose_folding", @@ -1048,15 +1106,12 @@ cc_library( "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:cuda_libdevice_path", 
"//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/stream_executor:stream_executor_headers", - "//tensorflow/stream_executor/cuda:cuda_diagnostics", - "//tensorflow/stream_executor/cuda:ptxas_utils", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1068,11 +1123,108 @@ cc_library( cc_library( name = "nvptx_compiler", - srcs = ["nvptx_compiler_registration.cc"], - deps = [":nvptx_compiler_impl"], + srcs = [ + "nvptx_compiler_registration.cc", + ], + deps = [ + ":nvptx_compiler_impl", + ], alwayslink = True, # Contains compiler registration ) +cc_library( + name = "nvptx_compiler_impl", + srcs = [ + "nvptx_compiler.cc", + ], + hdrs = [ + "nvptx_compiler.h", + ], + deps = [ + ":cudnn_conv_algorithm_picker", + ":cudnn_conv_pad_for_tensor_cores", + ":cudnn_conv_padding_legalization", + ":cudnn_conv_rewriter", + ":cudnn_fused_conv_rewriter", + ":cusolver_rewriter", + ":gemm_algorithm_picker", + ":gemm_rewriter", + ":gpu_compiler", + ":gpu_layout_assignment", + ":stream_executor_util", + ":target_constants", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:dump", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_constant_folding", + "//tensorflow/compiler/xla/service:hlo_cse", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:llvm_compiler", + "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:cuda_libdevice_path", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:stream_executor_headers", + "//tensorflow/stream_executor/cuda:cuda_diagnostics", + "//tensorflow/stream_executor/cuda:ptxas_utils", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "amdgpu_compiler", + srcs = [ + "amdgpu_compiler_registration.cc", + ], + deps = [ + ":amdgpu_compiler_impl", + ], + alwayslink = True, # Contains compiler registration +) + +cc_library( + name = "amdgpu_compiler_impl", + srcs = [ + "amdgpu_compiler.cc", + ], + hdrs = [ + "amdgpu_compiler.h", + ], + deps = [ + ":cudnn_conv_padding_legalization", + ":cudnn_conv_rewriter", + ":gpu_compiler", + ":gpu_layout_assignment", + ":target_constants", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_constant_folding", + "//tensorflow/compiler/xla/service:hlo_cse", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:llvm_compiler", + "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", + 
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/platform:rocm_rocdl_path", + ], +) + cc_library( name = "cudnn_batchnorm_rewriter", srcs = ["cudnn_batchnorm_rewriter.cc"], @@ -1411,3 +1563,30 @@ xla_proto_library( "//tensorflow/core:autotuning_proto_cc", ], ) + +cc_library( + name = "hlo_algorithm_blacklist", + srcs = ["hlo_algorithm_blacklist.cc"], + hdrs = ["hlo_algorithm_blacklist.h"], + deps = [ + ":gpu_autotuning_proto", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/core:autotuning_proto_cc", + "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "hlo_algorithm_blacklist_test", + srcs = ["hlo_algorithm_blacklist_test.cc"], + data = ["data/hlo_algorithm_blacklist.pbtxt"], + deps = [ + ":hlo_algorithm_blacklist", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor:dnn", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc new file mode 100644 index 00000000000..949707a22e6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc @@ -0,0 +1,156 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h" + +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" +// TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged. +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/hlo_constant_folding.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/core/platform/rocm_rocdl_path.h" + +namespace xla { +namespace gpu { + +namespace { + +// Returns the directory containing ROCm-Device-Libs files. This function is +// called in AMDGPUCompiler's constructor, so can't return an error. But +// AMDGPUCompiler::Compile will return an error when the wanted rocdl file +// doesn't exist in the folder this function returns. 
+string GetROCDLDir(const HloModuleConfig& config) { + std::vector potential_rocdl_dirs; + const string datadir = config.debug_options().xla_gpu_cuda_data_dir(); + if (!datadir.empty()) { + potential_rocdl_dirs.push_back(datadir); + } + potential_rocdl_dirs.push_back(tensorflow::RocdlRoot()); + + // Tries all potential ROCDL directories in the order they are inserted. + // Returns the first directory that exists in the file system. + for (const string& potential_rocdl_dir : potential_rocdl_dirs) { + if (tensorflow::Env::Default()->IsDirectory(potential_rocdl_dir).ok()) { + VLOG(2) << "Found ROCm-Device-Libs dir " << potential_rocdl_dir; + return potential_rocdl_dir; + } + VLOG(2) << "Unable to find potential ROCm-Device-Libs dir " + << potential_rocdl_dir; + } + + // Last resort: maybe in the current folder. + return "."; +} + +} // namespace + +Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( + HloModule* hlo_module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + // Convert convolutions into CustomCalls to MIOpen, then canonicalize them + // (PadInsertion). + HloPassPipeline pipeline("conv_canonicalization"); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + pipeline.AddPass(); + pipeline.AddPass(); + + pipeline.AddPass(); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + + return Status::OK(); +} + +Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment( + HloModule* hlo_module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + HloPassPipeline pipeline("post-layout_assignment"); + pipeline.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); + + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + pipeline.AddPass>(options); + + // TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged. + + // Clean up new_tuple described above. + pipeline.AddPass(); + + pipeline.AddPass(/*is_layout_sensitive=*/true); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + + return Status::OK(); +} + +AMDGPUCompiler::AMDGPUCompiler() + : GpuCompiler(stream_executor::rocm::kROCmPlatformId, amdgpu::kTargetTriple, + amdgpu::kDataLayout) {} + +GpuVersion AMDGPUCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) { + int isa_version = 0; + if (!stream_exec->GetDeviceDescription().rocm_amdgpu_isa_version( + &isa_version)) { + LOG(WARNING) + << "Couldn't get AMDGPU ISA version for device; assuming gfx803."; + isa_version = 803; + } + + return isa_version; +} + +StatusOr>> +AMDGPUCompiler::CompileTargetBinary(const HloModule* module, + llvm::Module* llvm_module, + GpuVersion gpu_version, + se::StreamExecutor* stream_exec) { + if (rocdl_dir_.empty()) { + // Compute rocdl_dir_ just once and cache it in this member. 
+ rocdl_dir_ = GetROCDLDir(module->config()); + } + + std::vector hsaco; + { + XLA_SCOPED_LOGGING_TIMER( + "AMDGPUCompiler::CompileTargetBinary - CompileToHsaco"); + TF_ASSIGN_OR_RETURN(hsaco, + amdgpu::CompileToHsaco(llvm_module, gpu_version, + module->config(), rocdl_dir_)); + } + + llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false); + + if (user_post_optimization_hook_) { + user_post_optimization_hook_(*llvm_module); + } + + return std::pair>("", std::move(hsaco)); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h new file mode 100644 index 00000000000..d1a74a7822e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h @@ -0,0 +1,60 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { +namespace gpu { + +// AMDGPUCompiler generates efficient GPU executables for AMDGPU target. +class AMDGPUCompiler : public GpuCompiler { + public: + AMDGPUCompiler(); + ~AMDGPUCompiler() override {} + + Status OptimizeHloConvolutionCanonicalization( + HloModule* hlo_module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) override; + + Status OptimizeHloPostLayoutAssignment( + HloModule* hlo_module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) override; + + GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override; + + StatusOr>> CompileTargetBinary( + const HloModule* hlo_module, llvm::Module* llvm_module, + GpuVersion gpu_version, se::StreamExecutor* stream_exec) override; + + private: + // The parent directory of ROCm-Device-Libs IR libraries. + string rocdl_dir_; + + TF_DISALLOW_COPY_AND_ASSIGN(AMDGPUCompiler); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_AMDGPU_COMPILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler_registration.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler_registration.cc new file mode 100644 index 00000000000..3d6d19fe980 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler_registration.cc @@ -0,0 +1,24 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h" + +static bool InitModule() { + xla::Compiler::RegisterCompilerFactory( + stream_executor::rocm::kROCmPlatformId, + []() { return absl::make_unique(); }); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 30108315e4d..e9b371e33d8 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -34,7 +34,7 @@ namespace gpu { static constexpr double kTolerance = 0.1f; -// Comparison kernel code: compare two buffers of fp16/fp32/fp64 of length +// Comparison kernel code: compare two buffers of fp16/fp32/fp64/int8 of length // buffer_length where the relative error does not exceed the passed // rel_error_threshold. Write the number of mismatches into out parameter // mismatch_count. @@ -46,12 +46,20 @@ static constexpr double kTolerance = 0.1f; // // #include // extern "C" { // avoid name mangling -// __device__ float canonicalize(float input) { +// __device__ float __xla_buffer_comparator_canonicalize(float input) { // // All fp16 infinities are treated as 65505 or -65505, in order to avoid // // differences due to overflows. // return isnan(input) ? 
input : max(-65505.0f, min(input, 65505.0f)); // } -// + +// __device__ float __xla_buffer_comparator_extract_int8(int pack) { +// // Extract the lower 8 bits from pack and convert it to float +// const unsigned int bit_mask = 0xff; +// unsigned int bits = pack & bit_mask; +// char* int8_ptr = (char*)&bits; +// return __int2float_rn(*int8_ptr); +// } + // __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, @@ -60,15 +68,15 @@ static constexpr double kTolerance = 0.1f; // if (idx >= buffer_length) return; // float elem_a = __half2float(buffer_a[idx]); // float elem_b = __half2float(buffer_b[idx]); -// elem_a = canonicalize(elem_a); -// elem_b = canonicalize(elem_b); +// elem_a = __xla_buffer_comparator_canonicalize(elem_a); +// elem_b = __xla_buffer_comparator_canonicalize(elem_b); // if (isnan(elem_a) && isnan(elem_b)) return; // float rel_error = abs(elem_a - elem_b) // / (max(abs(elem_a), abs(elem_b)) + 1); // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } -// + // __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, @@ -85,7 +93,7 @@ static constexpr double kTolerance = 0.1f; // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } -// + // __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b, // float rel_error_threshold, // unsigned long long buffer_length, @@ -102,234 +110,440 @@ static constexpr double kTolerance = 0.1f; // if (rel_error > rel_error_threshold || isnan(rel_error)) // atomicAdd(mismatch_count, 1); // } + +// __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b, +// float rel_error_threshold, +// unsigned long long buffer_length, +// int* mismatch_count) { +// int idx = threadIdx.x + blockIdx.x * blockDim.x; +// if (idx >= buffer_length) return; +// int pack_a = buffer_a[idx]; +// int pack_b = buffer_b[idx]; +// for(int i = 0; i < 4; ++i) { +// float elem_a = __xla_buffer_comparator_extract_int8(pack_a); +// float elem_b = __xla_buffer_comparator_extract_int8(pack_b); +// float rel_error = abs(elem_a - elem_b) +// / (max(abs(elem_a), abs(elem_b)) + 1); +// if (rel_error > rel_error_threshold || isnan(rel_error)) +// atomicAdd(mismatch_count, 1); +// pack_a >>= 8; +// pack_b >>= 8; +// } +// } // } // end extern declaration. 
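The comment block above documents the new __xla_int8_comparison kernel that the PTX string below implements: each 32-bit word packs four int8 values, which are unpacked, widened to float, and checked with the same relative-error rule used by the fp16/fp32/fp64 kernels. For illustration only, a host-side C++ equivalent of that per-word check (the function name is made up, not part of the diff) looks like this:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Returns true if any of the four int8 lanes packed into the two 32-bit words
// differ by more than rel_error_threshold under the relative-error rule.
static bool Int8PackMismatches(int32_t pack_a, int32_t pack_b,
                               float rel_error_threshold) {
  for (int i = 0; i < 4; ++i) {
    // Extract the low byte as a signed 8-bit value and widen it to float,
    // mirroring __xla_buffer_comparator_extract_int8 in the comment above.
    float elem_a = static_cast<float>(static_cast<int8_t>(pack_a & 0xff));
    float elem_b = static_cast<float>(static_cast<int8_t>(pack_b & 0xff));
    float rel_error = std::abs(elem_a - elem_b) /
                      (std::max(std::abs(elem_a), std::abs(elem_b)) + 1);
    if (rel_error > rel_error_threshold || std::isnan(rel_error)) return true;
    pack_a >>= 8;
    pack_b >>= 8;
  }
  return false;
}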
static const char* buffer_compare_ptx = R"( -.version 4.2 +.version 6.4 .target sm_30 .address_size 64 + // .globl __xla_fp16_comparison + .visible .entry __xla_fp16_comparison( - .param .u64 __xla_fp16_comparison_param_0, - .param .u64 __xla_fp16_comparison_param_1, - .param .f32 __xla_fp16_comparison_param_2, - .param .u64 __xla_fp16_comparison_param_3, - .param .u64 __xla_fp16_comparison_param_4 + .param .u64 __xla_fp16_comparison_param_0, + .param .u64 __xla_fp16_comparison_param_1, + .param .f32 __xla_fp16_comparison_param_2, + .param .u64 __xla_fp16_comparison_param_3, + .param .u64 __xla_fp16_comparison_param_4 ) { - .reg .pred %p<10>; - .reg .b16 %rs<3>; - .reg .f32 %f<20>; - .reg .b32 %r<6>; - .reg .b64 %rd<12>; - ld.param.u64 %rd8, [__xla_fp16_comparison_param_3]; - mov.u32 %r1, %tid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mad.lo.s32 %r4, %r3, %r2, %r1; - cvt.s64.s32 %rd4, %r4; - setp.ge.u64 %p1, %rd4, %rd8; - @%p1 bra LBB7_4; - ld.param.u64 %rd5, [__xla_fp16_comparison_param_0]; - ld.param.u64 %rd7, [__xla_fp16_comparison_param_1]; - cvta.to.global.u64 %rd2, %rd7; - cvta.to.global.u64 %rd3, %rd5; - shl.b64 %rd9, %rd4, 1; - add.s64 %rd10, %rd3, %rd9; - ld.global.u16 %rs1, [%rd10]; - // begin inline asm - { cvt.f32.f16 %f6, %rs1;} + .reg .pred %p<9>; + .reg .b16 %rs<3>; + .reg .f32 %f<28>; + .reg .b32 %r<6>; + .reg .b64 %rd<12>; - // end inline asm - add.s64 %rd11, %rd2, %rd9; - ld.global.u16 %rs2, [%rd11]; - // begin inline asm - { cvt.f32.f16 %f7, %rs2;} - // end inline asm - abs.f32 %f8, %f6; - setp.gtu.f32 %p2, %f8, 0f7F800000; - min.f32 %f9, %f6, 0f477FE100; - max.f32 %f10, %f9, 0fC77FE100; - selp.f32 %f1, %f6, %f10, %p2; - abs.f32 %f11, %f7; - setp.gtu.f32 %p3, %f11, 0f7F800000; - min.f32 %f12, %f7, 0f477FE100; - max.f32 %f13, %f12, 0fC77FE100; - selp.f32 %f2, %f7, %f13, %p3; - abs.f32 %f3, %f1; - setp.gtu.f32 %p4, %f3, 0f7F800000; - abs.f32 %f4, %f2; - setp.gtu.f32 %p5, %f4, 0f7F800000; - and.pred %p6, %p4, %p5; - @%p6 bra LBB7_4; - ld.param.f32 %f5, [__xla_fp16_comparison_param_2]; - sub.f32 %f14, %f1, %f2; - abs.f32 %f15, %f14; - max.f32 %f16, %f3, %f4; - add.f32 %f17, %f16, 0f3F800000; - div.rn.f32 %f18, %f15, %f17; - setp.leu.f32 %p7, %f18, %f5; - abs.f32 %f19, %f18; - setp.le.f32 %p8, %f19, 0f7F800000; - and.pred %p9, %p7, %p8; - @%p9 bra LBB7_4; - ld.param.u64 %rd6, [__xla_fp16_comparison_param_4]; - cvta.to.global.u64 %rd1, %rd6; - atom.global.add.u32 %r5, [%rd1], 1; -LBB7_4: - ret; + ld.param.u64 %rd1, [__xla_fp16_comparison_param_0]; + ld.param.u64 %rd2, [__xla_fp16_comparison_param_1]; + ld.param.f32 %f10, [__xla_fp16_comparison_param_2]; + ld.param.u64 %rd4, [__xla_fp16_comparison_param_3]; + ld.param.u64 %rd3, [__xla_fp16_comparison_param_4]; + mov.u32 %r2, %ntid.x; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %tid.x; + mad.lo.s32 %r1, %r2, %r3, %r4; + cvt.s64.s32 %rd5, %r1; + setp.ge.u64 %p1, %rd5, %rd4; + @%p1 bra BB0_9; + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 2; + add.s64 %rd8, %rd6, %rd7; + ld.global.u16 %rs1, [%rd8]; + // inline asm + { cvt.f32.f16 %f26, %rs1;} + + // inline asm + cvta.to.global.u64 %rd9, %rd2; + add.s64 %rd10, %rd9, %rd7; + ld.global.u16 %rs2, [%rd10]; + // inline asm + { cvt.f32.f16 %f27, %rs2;} + + // inline asm + abs.f32 %f13, %f26; + setp.gtu.f32 %p2, %f13, 0f7F800000; + @%p2 bra BB0_3; + + mov.f32 %f14, 0f477FE100; + min.f32 %f15, %f26, %f14; + mov.f32 %f16, 0fC77FE100; + max.f32 %f26, %f16, %f15; + +BB0_3: + abs.f32 %f17, %f27; + setp.gtu.f32 %p3, %f17, 0f7F800000; + @%p3 bra BB0_5; + + mov.f32 %f18, 
0f477FE100; + min.f32 %f19, %f27, %f18; + mov.f32 %f20, 0fC77FE100; + max.f32 %f27, %f20, %f19; + +BB0_5: + abs.f32 %f7, %f26; + setp.gtu.f32 %p4, %f7, 0f7F800000; + abs.f32 %f8, %f27; + setp.gtu.f32 %p5, %f8, 0f7F800000; + and.pred %p6, %p4, %p5; + @%p6 bra BB0_9; + + sub.f32 %f21, %f26, %f27; + abs.f32 %f22, %f21; + max.f32 %f23, %f7, %f8; + add.f32 %f24, %f23, 0f3F800000; + div.rn.f32 %f9, %f22, %f24; + setp.gt.f32 %p7, %f9, %f10; + @%p7 bra BB0_8; + + abs.f32 %f25, %f9; + setp.le.f32 %p8, %f25, 0f7F800000; + @%p8 bra BB0_9; + +BB0_8: + cvta.to.global.u64 %rd11, %rd3; + atom.global.add.u32 %r5, [%rd11], 1; + +BB0_9: + ret; } - // .globl __xla_fp32_comparison + + // .globl __xla_fp32_comparison .visible .entry __xla_fp32_comparison( - .param .u64 __xla_fp32_comparison_param_0, - .param .u64 __xla_fp32_comparison_param_1, - .param .f32 __xla_fp32_comparison_param_2, - .param .u64 __xla_fp32_comparison_param_3, - .param .u64 __xla_fp32_comparison_param_4 + .param .u64 __xla_fp32_comparison_param_0, + .param .u64 __xla_fp32_comparison_param_1, + .param .f32 __xla_fp32_comparison_param_2, + .param .u64 __xla_fp32_comparison_param_3, + .param .u64 __xla_fp32_comparison_param_4 ) { - .reg .pred %p<12>; - .reg .f32 %f<12>; - .reg .b32 %r<9>; - .reg .b64 %rd<12>; + .reg .pred %p<10>; + .reg .b16 %rs<3>; + .reg .f32 %f<13>; + .reg .b32 %r<10>; + .reg .b64 %rd<12>; - ld.param.u64 %rd8, [__xla_fp32_comparison_param_3]; - mov.u32 %r1, %tid.x; - mov.u32 %r2, %ctaid.x; - mov.u32 %r3, %ntid.x; - mad.lo.s32 %r4, %r3, %r2, %r1; - cvt.s64.s32 %rd4, %r4; - setp.ge.u64 %p1, %rd4, %rd8; - @%p1 bra LBB8_6; - ld.param.u64 %rd5, [__xla_fp32_comparison_param_0]; - ld.param.u64 %rd7, [__xla_fp32_comparison_param_1]; - cvta.to.global.u64 %rd2, %rd7; - cvta.to.global.u64 %rd3, %rd5; - shl.b64 %rd9, %rd4, 2; - add.s64 %rd10, %rd3, %rd9; - ld.global.f32 %f1, [%rd10]; - add.s64 %rd11, %rd2, %rd9; - ld.global.f32 %f2, [%rd11]; - abs.f32 %f3, %f1; - setp.gtu.f32 %p2, %f3, 0f7F800000; - abs.f32 %f4, %f2; - setp.gtu.f32 %p3, %f4, 0f7F800000; - and.pred %p4, %p2, %p3; - @%p4 bra LBB8_6; - setp.neu.f32 %p5, %f3, 0f7F800000; - setp.neu.f32 %p6, %f4, 0f7F800000; - or.pred %p7, %p5, %p6; - @%p7 bra LBB8_4; - mov.b32 %r5, %f1; - mov.b32 %r6, %f2; - xor.b32 %r7, %r6, %r5; - setp.gt.s32 %p8, %r7, -1; - @%p8 bra LBB8_6; -LBB8_4: - ld.param.f32 %f5, [__xla_fp32_comparison_param_2]; - sub.f32 %f6, %f1, %f2; - abs.f32 %f7, %f6; - max.f32 %f8, %f3, %f4; - add.f32 %f9, %f8, 0f3F800000; - div.rn.f32 %f10, %f7, %f9; - setp.leu.f32 %p9, %f10, %f5; - abs.f32 %f11, %f10; - setp.le.f32 %p10, %f11, 0f7F800000; - and.pred %p11, %p9, %p10; - @%p11 bra LBB8_6; - ld.param.u64 %rd6, [__xla_fp32_comparison_param_4]; - cvta.to.global.u64 %rd1, %rd6; - atom.global.add.u32 %r8, [%rd1], 1; -LBB8_6: - ret; + ld.param.u64 %rd1, [__xla_fp32_comparison_param_0]; + ld.param.u64 %rd2, [__xla_fp32_comparison_param_1]; + ld.param.f32 %f6, [__xla_fp32_comparison_param_2]; + ld.param.u64 %rd4, [__xla_fp32_comparison_param_3]; + ld.param.u64 %rd3, [__xla_fp32_comparison_param_4]; + mov.u32 %r2, %ntid.x; + mov.u32 %r3, %ctaid.x; + mov.u32 %r4, %tid.x; + mad.lo.s32 %r1, %r2, %r3, %r4; + cvt.s64.s32 %rd5, %r1; + setp.ge.u64 %p1, %rd5, %rd4; + @%p1 bra BB1_8; + + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 4; + add.s64 %rd8, %rd6, %rd7; + cvta.to.global.u64 %rd9, %rd2; + add.s64 %rd10, %rd9, %rd7; + ld.global.f32 %f1, [%rd10]; + ld.global.f32 %f2, [%rd8]; + abs.f32 %f3, %f2; + setp.le.f32 %p2, %f3, 0f7F800000; + @%p2 bra BB1_3; + + abs.f32 %f7, %f1; + 
setp.gtu.f32 %p3, %f7, 0f7F800000; + @%p3 bra BB1_8; + +BB1_3: + setp.neu.f32 %p4, %f3, 0f7F800000; + abs.f32 %f4, %f1; + setp.neu.f32 %p5, %f4, 0f7F800000; + or.pred %p6, %p4, %p5; + @%p6 bra BB1_5; + + mov.b32 %r5, %f2; + shr.u32 %r6, %r5, 31; + cvt.u16.u32 %rs1, %r6; + mov.b32 %r7, %f1; + shr.u32 %r8, %r7, 31; + cvt.u16.u32 %rs2, %r8; + setp.eq.s16 %p7, %rs1, %rs2; + @%p7 bra BB1_8; + +BB1_5: + sub.f32 %f8, %f2, %f1; + abs.f32 %f9, %f8; + max.f32 %f10, %f3, %f4; + add.f32 %f11, %f10, 0f3F800000; + div.rn.f32 %f5, %f9, %f11; + setp.gt.f32 %p8, %f5, %f6; + @%p8 bra BB1_7; + + abs.f32 %f12, %f5; + setp.le.f32 %p9, %f12, 0f7F800000; + @%p9 bra BB1_8; + +BB1_7: + cvta.to.global.u64 %rd11, %rd3; + atom.global.add.u32 %r9, [%rd11], 1; + +BB1_8: + ret; } - // .globl __xla_fp64_comparison + + // .globl __xla_fp64_comparison .visible .entry __xla_fp64_comparison( - .param .u64 __xla_fp64_comparison_param_0, - .param .u64 __xla_fp64_comparison_param_1, - .param .f32 __xla_fp64_comparison_param_2, - .param .u64 __xla_fp64_comparison_param_3, - .param .u64 __xla_fp64_comparison_param_4 + .param .u64 __xla_fp64_comparison_param_0, + .param .u64 __xla_fp64_comparison_param_1, + .param .f32 __xla_fp64_comparison_param_2, + .param .u64 __xla_fp64_comparison_param_3, + .param .u64 __xla_fp64_comparison_param_4 ) { - .reg .pred %p<16>; - .reg .f32 %f<2>; - .reg .b32 %r<13>; - .reg .f64 %fd<12>; - .reg .b64 %rd<12>; + .reg .pred %p<11>; + .reg .b16 %rs<3>; + .reg .f32 %f<2>; + .reg .b32 %r<14>; + .reg .f64 %fd<13>; + .reg .b64 %rd<12>; - ld.param.u64 %rd8, [__xla_fp64_comparison_param_3]; - mov.u32 %r2, %tid.x; - mov.u32 %r3, %ctaid.x; - mov.u32 %r4, %ntid.x; - mad.lo.s32 %r5, %r4, %r3, %r2; - cvt.s64.s32 %rd4, %r5; - setp.ge.u64 %p1, %rd4, %rd8; - @%p1 bra LBB9_6; - ld.param.u64 %rd5, [__xla_fp64_comparison_param_0]; - ld.param.u64 %rd7, [__xla_fp64_comparison_param_1]; - cvta.to.global.u64 %rd2, %rd7; - cvta.to.global.u64 %rd3, %rd5; - shl.b64 %rd9, %rd4, 3; - add.s64 %rd10, %rd3, %rd9; - ld.global.f64 %fd1, [%rd10]; - add.s64 %rd11, %rd2, %rd9; - ld.global.f64 %fd2, [%rd11]; - abs.f64 %fd3, %fd1; - setp.gtu.f64 %p2, %fd3, 0d7FF0000000000000; - abs.f64 %fd4, %fd2; - setp.gtu.f64 %p3, %fd4, 0d7FF0000000000000; - and.pred %p4, %p2, %p3; - @%p4 bra LBB9_6; - { - .reg .b32 %temp; - mov.b64 {%r6, %temp}, %fd1; - } - { - .reg .b32 %temp; - mov.b64 {%temp, %r1}, %fd1; - } - and.b32 %r7, %r1, 2147483647; - setp.ne.s32 %p5, %r7, 2146435072; - setp.ne.s32 %p6, %r6, 0; - or.pred %p7, %p6, %p5; - @%p7 bra LBB9_4; - { - .reg .b32 %temp; - mov.b64 {%r8, %temp}, %fd2; - } - { - .reg .b32 %temp; - mov.b64 {%temp, %r9}, %fd2; - } - and.b32 %r10, %r9, 2147483647; - setp.eq.s32 %p8, %r10, 2146435072; - setp.eq.s32 %p9, %r8, 0; - and.pred %p10, %p8, %p9; - xor.b32 %r11, %r9, %r1; - setp.gt.s32 %p11, %r11, -1; - and.pred %p12, %p11, %p10; - @%p12 bra LBB9_6; -LBB9_4: - ld.param.f32 %f1, [__xla_fp64_comparison_param_2]; - sub.f64 %fd5, %fd1, %fd2; - abs.f64 %fd6, %fd5; - max.f64 %fd7, %fd3, %fd4; - add.f64 %fd8, %fd7, 0d3FF0000000000000; - div.rn.f64 %fd9, %fd6, %fd8; - cvt.f64.f32 %fd10, %f1; - setp.leu.f64 %p13, %fd9, %fd10; - abs.f64 %fd11, %fd9; - setp.le.f64 %p14, %fd11, 0d7FF0000000000000; - and.pred %p15, %p13, %p14; - @%p15 bra LBB9_6; - ld.param.u64 %rd6, [__xla_fp64_comparison_param_4]; - cvta.to.global.u64 %rd1, %rd6; - atom.global.add.u32 %r12, [%rd1], 1; -LBB9_6: - ret; + + ld.param.u64 %rd1, [__xla_fp64_comparison_param_0]; + ld.param.u64 %rd2, [__xla_fp64_comparison_param_1]; + ld.param.f32 %f1, 
[__xla_fp64_comparison_param_2]; + ld.param.u64 %rd4, [__xla_fp64_comparison_param_3]; + ld.param.u64 %rd3, [__xla_fp64_comparison_param_4]; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %ctaid.x; + mov.u32 %r6, %tid.x; + mad.lo.s32 %r1, %r4, %r5, %r6; + cvt.s64.s32 %rd5, %r1; + setp.ge.u64 %p1, %rd5, %rd4; + @%p1 bra BB2_11; + + cvta.to.global.u64 %rd6, %rd1; + mul.wide.s32 %rd7, %r1, 8; + add.s64 %rd8, %rd6, %rd7; + cvta.to.global.u64 %rd9, %rd2; + add.s64 %rd10, %rd9, %rd7; + ld.global.f64 %fd1, [%rd10]; + ld.global.f64 %fd2, [%rd8]; + abs.f64 %fd3, %fd2; + setp.le.f64 %p2, %fd3, 0d7FF0000000000000; + @%p2 bra BB2_3; + + abs.f64 %fd5, %fd1; + setp.gtu.f64 %p3, %fd5, 0d7FF0000000000000; + @%p3 bra BB2_11; + +BB2_3: + { + .reg .b32 %temp; + mov.b64 {%temp, %r2}, %fd2; + } + and.b32 %r7, %r2, 2147483647; + setp.ne.s32 %p4, %r7, 2146435072; + @%p4 bra BB2_8; + + { + .reg .b32 %temp; + mov.b64 {%r8, %temp}, %fd2; + } + setp.ne.s32 %p5, %r8, 0; + @%p5 bra BB2_8; + + { + .reg .b32 %temp; + mov.b64 {%temp, %r3}, %fd1; + } + and.b32 %r9, %r3, 2147483647; + setp.ne.s32 %p6, %r9, 2146435072; + @%p6 bra BB2_8; + + { + .reg .b32 %temp; + mov.b64 {%r10, %temp}, %fd1; + } + setp.ne.s32 %p7, %r10, 0; + @%p7 bra BB2_8; + + shr.u32 %r11, %r2, 31; + cvt.u16.u32 %rs1, %r11; + shr.u32 %r12, %r3, 31; + cvt.u16.u32 %rs2, %r12; + setp.eq.s16 %p8, %rs1, %rs2; + @%p8 bra BB2_11; + +BB2_8: + sub.f64 %fd6, %fd2, %fd1; + abs.f64 %fd7, %fd6; + abs.f64 %fd8, %fd1; + max.f64 %fd9, %fd3, %fd8; + add.f64 %fd10, %fd9, 0d3FF0000000000000; + div.rn.f64 %fd4, %fd7, %fd10; + cvt.f64.f32 %fd11, %f1; + setp.gt.f64 %p9, %fd4, %fd11; + @%p9 bra BB2_10; + + abs.f64 %fd12, %fd4; + setp.le.f64 %p10, %fd12, 0d7FF0000000000000; + @%p10 bra BB2_11; + +BB2_10: + cvta.to.global.u64 %rd11, %rd3; + atom.global.add.u32 %r13, [%rd11], 1; + +BB2_11: + ret; +} + + // .globl __xla_int8_comparison +.visible .entry __xla_int8_comparison( + .param .u64 __xla_int8_comparison_param_0, + .param .u64 __xla_int8_comparison_param_1, + .param .f32 __xla_int8_comparison_param_2, + .param .u64 __xla_int8_comparison_param_3, + .param .u64 __xla_int8_comparison_param_4 +) +{ + .reg .pred %p<10>; + .reg .f32 %f<42>; + .reg .b32 %r<23>; + .reg .b64 %rd<12>; + + + ld.param.u64 %rd2, [__xla_int8_comparison_param_0]; + ld.param.u64 %rd3, [__xla_int8_comparison_param_1]; + ld.param.f32 %f5, [__xla_int8_comparison_param_2]; + ld.param.u64 %rd4, [__xla_int8_comparison_param_3]; + ld.param.u64 %rd5, [__xla_int8_comparison_param_4]; + cvta.to.global.u64 %rd1, %rd5; + mov.u32 %r4, %ntid.x; + mov.u32 %r5, %ctaid.x; + mov.u32 %r6, %tid.x; + mad.lo.s32 %r1, %r4, %r5, %r6; + cvt.s64.s32 %rd6, %r1; + setp.ge.u64 %p1, %rd6, %rd4; + @%p1 bra BB3_13; + + cvta.to.global.u64 %rd7, %rd2; + mul.wide.s32 %rd8, %r1, 4; + add.s64 %rd9, %rd7, %rd8; + cvta.to.global.u64 %rd10, %rd3; + add.s64 %rd11, %rd10, %rd8; + ld.global.u32 %r2, [%rd9]; + cvt.s32.s8 %r7, %r2; + cvt.rn.f32.s32 %f6, %r7; + ld.global.u32 %r3, [%rd11]; + cvt.s32.s8 %r8, %r3; + cvt.rn.f32.s32 %f7, %r8; + sub.f32 %f8, %f6, %f7; + abs.f32 %f9, %f8; + abs.f32 %f10, %f6; + abs.f32 %f11, %f7; + max.f32 %f12, %f10, %f11; + add.f32 %f13, %f12, 0f3F800000; + div.rn.f32 %f1, %f9, %f13; + setp.gt.f32 %p2, %f1, %f5; + @%p2 bra BB3_3; + + abs.f32 %f14, %f1; + setp.le.f32 %p3, %f14, 0f7F800000; + @%p3 bra BB3_4; + +BB3_3: + atom.global.add.u32 %r9, [%rd1], 1; + +BB3_4: + shr.u32 %r10, %r3, 8; + shr.u32 %r11, %r2, 8; + cvt.s32.s8 %r12, %r11; + cvt.rn.f32.s32 %f15, %r12; + cvt.s32.s8 %r13, %r10; + cvt.rn.f32.s32 %f16, %r13; + sub.f32 %f17, 
%f15, %f16; + abs.f32 %f18, %f17; + abs.f32 %f19, %f15; + abs.f32 %f20, %f16; + max.f32 %f21, %f19, %f20; + add.f32 %f22, %f21, 0f3F800000; + div.rn.f32 %f2, %f18, %f22; + setp.gt.f32 %p4, %f2, %f5; + @%p4 bra BB3_6; + + abs.f32 %f23, %f2; + setp.le.f32 %p5, %f23, 0f7F800000; + @%p5 bra BB3_7; + +BB3_6: + atom.global.add.u32 %r14, [%rd1], 1; + +BB3_7: + shr.u32 %r15, %r3, 16; + shr.u32 %r16, %r2, 16; + cvt.s32.s8 %r17, %r16; + cvt.rn.f32.s32 %f24, %r17; + cvt.s32.s8 %r18, %r15; + cvt.rn.f32.s32 %f25, %r18; + sub.f32 %f26, %f24, %f25; + abs.f32 %f27, %f26; + abs.f32 %f28, %f24; + abs.f32 %f29, %f25; + max.f32 %f30, %f28, %f29; + add.f32 %f31, %f30, 0f3F800000; + div.rn.f32 %f3, %f27, %f31; + setp.gt.f32 %p6, %f3, %f5; + @%p6 bra BB3_9; + + abs.f32 %f32, %f3; + setp.le.f32 %p7, %f32, 0f7F800000; + @%p7 bra BB3_10; + +BB3_9: + atom.global.add.u32 %r19, [%rd1], 1; + +BB3_10: + shr.s32 %r20, %r2, 24; + cvt.rn.f32.s32 %f33, %r20; + shr.s32 %r21, %r3, 24; + cvt.rn.f32.s32 %f34, %r21; + sub.f32 %f35, %f33, %f34; + abs.f32 %f36, %f35; + abs.f32 %f37, %f33; + abs.f32 %f38, %f34; + max.f32 %f39, %f37, %f38; + add.f32 %f40, %f39, 0f3F800000; + div.rn.f32 %f4, %f36, %f40; + setp.gt.f32 %p8, %f4, %f5; + @%p8 bra BB3_12; + + abs.f32 %f41, %f4; + setp.le.f32 %p9, %f41, 0f7F800000; + @%p9 bra BB3_13; + +BB3_12: + atom.global.add.u32 %r22, [%rd1], 1; + +BB3_13: + ret; } )"; @@ -405,11 +619,13 @@ StatusOr HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs, const auto canonicalize = [](ComparisonType a) -> ComparisonType { if (std::is_same::value && a) { - constexpr ComparisonType kMaxFp16Value = 65505.; + constexpr ComparisonType kMaxFp16Value = + std::is_same::value ? 65505. : 0; if (std::isnan(a)) { return a; } - return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value)); + return std::max(static_cast(-kMaxFp16Value), + static_cast(std::min(a, kMaxFp16Value))); } return a; }; @@ -472,6 +688,9 @@ StatusOr BufferComparator::CompareEqual(se::Stream* stream, case xla::F64: return CompareEqualParameterized( stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison"); + case xla::S8: + return CompareEqualParameterized( + stream, lhs, rhs, shape_, config_, "__xla_int8_comparison"); default: return Unimplemented("Unimplemented element type"); } diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc index 139e4204304..0f547111096 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc @@ -178,6 +178,13 @@ TEST_F(BufferComparatorTest, TestNumbers) { EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); + + EXPECT_TRUE(CompareEqualFloatBuffers({200}, {201})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({90}, {100})); + EXPECT_TRUE(CompareEqualFloatBuffers({100}, {90})); + EXPECT_FALSE(CompareEqualFloatBuffers({-128}, {127})); } TEST_F(BufferComparatorTest, TestMultiple) { @@ -231,6 +238,23 @@ TEST_F(BufferComparatorTest, TestMultiple) { rhs[i] = 0; } } + + { + EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60}, + {21, 31, 41, 51, 61})); + std::vector lhs(200); + std::vector rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + 
EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } + } } } // namespace diff --git a/tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.cc new file mode 100644 index 00000000000..f2885e243e2 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.cc @@ -0,0 +1,133 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +static StatusOr PadForTensorCores(HloDotInstruction* dot) { + auto* lhs = dot->mutable_operand(0); + auto* rhs = dot->mutable_operand(1); + + Shape lshape = lhs->shape(); + Shape rshape = rhs->shape(); + Shape result_shape = dot->shape(); + + if (lshape.element_type() != PrimitiveType::F16 || + rshape.element_type() != PrimitiveType::F16) { + return false; + } + + auto pad_dim = [](Shape& s, int64 dim) { + s.set_dimensions(dim, RoundUpToNearest(s.dimensions(dim), 8)); + }; + + auto pad_matrix_dims = [&pad_dim](Shape s) { + pad_dim(s, 0); + pad_dim(s, 1); + return s; + }; + + Shape new_lshape = pad_matrix_dims(lshape); + Shape new_rshape = pad_matrix_dims(rshape); + Shape new_result_shape = pad_matrix_dims(result_shape); + + if (new_lshape == lshape && new_rshape == rshape) { + return false; + } + + VLOG(3) << "old shape: " << lshape << " " << rshape << " " << result_shape; + VLOG(3) << "new shape: " << new_lshape << " " << new_rshape << " " + << new_result_shape; + + auto create_padding_config = [](Shape& shape, Shape& new_shape) { + PaddingConfig padding_config; + for (int i = 0; i < shape.rank(); ++i) { + auto dimension = padding_config.add_dimensions(); + dimension->set_edge_padding_high(new_shape.dimensions()[i] - + shape.dimensions()[i]); + dimension->set_edge_padding_low(0); + dimension->set_interior_padding(0); + } + return padding_config; + }; + + auto l_padding_config = create_padding_config(lshape, new_lshape); + auto r_padding_config = create_padding_config(rshape, new_rshape); + + HloComputation* parent = dot->parent(); + + HloInstruction* zero_float = parent->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0((half)0.0))); + zero_float->set_metadata(dot->metadata()); + + HloInstruction* lpad = parent->AddInstruction( + HloInstruction::CreatePad(new_lshape, lhs, zero_float, l_padding_config)); + lpad->set_metadata(dot->metadata()); 
+ + HloInstruction* rpad = parent->AddInstruction( + HloInstruction::CreatePad(new_rshape, rhs, zero_float, r_padding_config)); + rpad->set_metadata(dot->metadata()); + + HloInstruction* new_dot = parent->AddInstruction( + dot->CloneWithNewOperands(new_result_shape, {lpad, rpad})); + + HloInstruction* slice = parent->AddInstruction(HloInstruction::CreateSlice( + result_shape, new_dot, {0, 0}, result_shape.dimensions(), {1, 1})); + slice->set_metadata(dot->metadata()); + + bool is_root = dot->user_count() == 0; + + TF_CHECK_OK(parent->ReplaceInstruction(dot, slice)); + + if (is_root) { + parent->set_root_instruction(slice); + } + + return true; +} + +static std::vector GetRelevantDots(HloComputation* comp) { + std::vector convs; + for (HloInstruction* instr : comp->instructions()) { + if (IsMatrixMultiplication(*instr)) { + convs.push_back(Cast(instr)); + } + } + return convs; +} + +StatusOr CublasGemmPadForTensorCores::Run(HloModule* module) { + bool changed = false; + for (HloComputation* comp : module->MakeNonfusionComputations()) { + for (HloDotInstruction* dot : GetRelevantDots(comp)) { + TF_ASSIGN_OR_RETURN(bool result, PadForTensorCores(dot)); + changed |= result; + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h new file mode 100644 index 00000000000..339e7e3dce6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_GEMM_PAD_FOR_TENSOR_CORES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_GEMM_PAD_FOR_TENSOR_CORES_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Adds padding to dot operations to make them run faster on GPUs with +// tensor cores (https://devblogs.nvidia.com/programming-tensor-cores-cuda-9/). +// +// f16 dots are padded to have input/output shapes with dimensions that +// are multiples of 8, so that we can use tensor cores. +// +// Don't run this pass on GPUs without tensor cores -- it will make them slower! 
+class CublasGemmPadForTensorCores : public HloModulePass { + public: + absl::string_view name() const override { + return "cublas-gemm-pad-for-speed"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUBLAS_GEMM_PAD_FOR_TENSOR_CORES_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores_test.cc new file mode 100644 index 00000000000..df1ba164bef --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores_test.cc @@ -0,0 +1,223 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h" + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/util.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace gpu { +namespace { + +class CublasGemmPadForTensorCoresTest : public HloTestBase {}; + +TEST_F(CublasGemmPadForTensorCoresTest, OneDotRootComputation) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + %param1 = f16[2048,1024] parameter(0) + %param2 = f16[1024,33708] parameter(1) + ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1, + f16[1024,33708]{0,1} %param2), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + })") + .ValueOrDie(); + + EXPECT_TRUE(CublasGemmPadForTensorCores().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + auto* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf( + op::Shape("f16[2048, 33708]"), + op::Slice(AllOf( + op::Shape("f16[2048, 33712]"), + op::Dot(AllOf(op::Shape("f16[2048, 1024]"), + op::Pad(AllOf(op::Shape("f16[2048, 1024]"), + op::Parameter()), + AllOf(op::Shape("f16[]"), op::Constant()))), + AllOf(op::Shape("f16[1024, 33712]"), + op::Pad(AllOf(op::Shape("f16[1024, 33708]"), + op::Parameter()), + AllOf(op::Shape("f16[]"), op::Constant()))), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/0))))); +} + +TEST_F(CublasGemmPadForTensorCoresTest, TwoDotsComputation) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + %param1 = f16[2048,1024] parameter(0) + %param2 = f16[1024,33708] parameter(1) + %param3 = f16[33708, 1] parameter(2) + %dot1 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1, + f16[1024,33708]{0,1} %param2), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT %dot2 = 
f16[2048, 1]{1,0} dot(f16[2048,33708]{1,0} %dot1, + f16[33708, 1]{0,1} %param3), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + })") + .ValueOrDie(); + + EXPECT_TRUE(CublasGemmPadForTensorCores().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + auto* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf( + op::Shape("f16[2048, 1]"), + op::Slice(AllOf( + op::Shape("f16[2048, 8]"), + op::Dot( + AllOf( + op::Shape("f16[2048, 33712]"), + AllOf( + op::Shape("f16[2048, 33712]"), + AllOf( + op::Shape("f16[2048, 33712]"), + op::Pad( + AllOf(op::Shape("f16[2048, 33708]"), + op::Slice(AllOf( + op::Shape("f16[2048, 33712]"), + op::Dot( + AllOf(op::Shape( + "f16[2048, 1024]"), + op::Pad()), + AllOf(op::Shape( + "f16[1024, 33712]"), + op::Pad()), + 1, 0)))), + AllOf(op::Shape("f16[]"), op::Constant()))))), + AllOf(op::Shape("f16[33712, 8]"), + AllOf(op::Shape("f16[33712, 8]"), + op::Pad( + AllOf(op::Shape("f16[33708, 1]"), + op::Parameter()), + AllOf(op::Shape("f16[]"), op::Constant())))), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))))); + + auto* dot2 = root->operand(0)->operand(0)->operand(0)->operand(0); + EXPECT_THAT( + dot2, + AllOf(op::Dot( + AllOf(op::Shape("f16[2048, 1024]"), + op::Pad(AllOf(op::Shape("f16[2048, 1024]"), op::Parameter()), + AllOf(op::Shape("f16[]"), op::Constant()))), + AllOf(op::Shape("f16[1024, 33712]"), + op::Pad(AllOf(op::Shape("f16[1024, 33708]"), op::Parameter()), + AllOf(op::Shape("f16[]"), op::Constant()))), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))); +} + +TEST_F(CublasGemmPadForTensorCoresTest, NoDotComputation) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %maximum = f32[] maximum(f32[] %x, f32[] %y) + })") + .ValueOrDie(); + + EXPECT_FALSE(CublasGemmPadForTensorCores().Run(module.get()).ValueOrDie()); +} + +TEST_F(CublasGemmPadForTensorCoresTest, F32DotComputation) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + %param1 = f32[2048,1024] parameter(0) + %param2 = f32[1024,33708] parameter(1) + ROOT %dot.2309 = f32[2048,33708]{1,0} dot(f32[2048,1024]{1,0} %param1, + f32[1024,33708]{0,1} %param2), + lhs_contracting_dims={1}, rhs_contracting_dims={0}})") + .ValueOrDie(); + + EXPECT_FALSE(CublasGemmPadForTensorCores().Run(module.get()).ValueOrDie()); +} + +TEST_F(CublasGemmPadForTensorCoresTest, F64DotComputation) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + %param1 = f64[2048,1024] parameter(0) + %param2 = f64[1024,33708] parameter(1) + ROOT %dot.2309 = f64[2048,33708]{1,0} dot(f64[2048,1024]{1,0} %param1, + f64[1024,33708]{0,1} %param2), + lhs_contracting_dims={1}, rhs_contracting_dims={0}})") + .ValueOrDie(); + + EXPECT_FALSE(CublasGemmPadForTensorCores().Run(module.get()).ValueOrDie()); +} + +TEST_F(CublasGemmPadForTensorCoresTest, MultiplesOf8DotComputation) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + %param1 = f16[2048,1024] parameter(0) + %param2 = f16[1024,33712] parameter(1) + ROOT %dot.2309 = f16[2048,33712]{1,0} dot(f16[2048,1024]{1,0} %param1, + f16[1024,33712]{0,1} %param2), + lhs_contracting_dims={1}, rhs_contracting_dims={0}})") + .ValueOrDie(); + + EXPECT_FALSE(CublasGemmPadForTensorCores().Run(module.get()).ValueOrDie()); +} + +TEST_F(CublasGemmPadForTensorCoresTest, 
CheckSavingMetadata) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule TestModule + + ENTRY TestComputation { + %param1 = f16[2048,1024] parameter(0) + %param2 = f16[1024,33708] parameter(1) + ROOT %dot.2309 = f16[2048,33708]{1,0} dot(f16[2048,1024]{1,0} %param1, + f16[1024,33708]{0,1} %param2), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + metadata={op_type="MatMul" op_name="transformer_v2/Transformer/decode/embedding_shared_weights_1/presoftmax_linear/MatMul"} + })") + .ValueOrDie(); + + SCOPED_TRACE(module->ToString()); + + EXPECT_TRUE(CublasGemmPadForTensorCores().Run(module.get()).ValueOrDie()); + auto metadata = module->entry_computation()->root_instruction()->metadata(); + EXPECT_EQ("MatMul", metadata.op_type()); + EXPECT_EQ( + "transformer_v2/Transformer/decode/embedding_shared_weights_1/" + "presoftmax_linear/MatMul", + metadata.op_name()); +} + +} // anonymous namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index ce17e0253c9..7a7ab6ba05f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -142,10 +143,8 @@ StatusOr CheckRedzones(const se::cuda::RedzoneAllocator& allocator, XLA_SCOPED_LOGGING_TIMER_LEVEL("CudnnConvAlgorithmPicker checking redzones", 2); using RedzoneCheckStatus = se::cuda::RedzoneAllocator::RedzoneCheckStatus; - TF_ASSIGN_OR_RETURN(RedzoneCheckStatus redzone_check, - allocator.CheckRedzones(stream)); - + allocator.CheckRedzones()); if (redzone_check.ok()) { return true; } @@ -235,7 +234,6 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithm( return result_or; } - StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( const HloCustomCallInstruction* instr) { XLA_SCOPED_LOGGING_TIMER( @@ -250,11 +248,6 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( return InternalError("Failed to synchronize GPU for autotuning."); } - // Create a stream for us to do our work on. - se::Stream stream{stream_exec_}; - stream.Init(); - const auto device_ordinal = stream_exec_->device_ordinal(); - // allocator either points to this->allocator_ or, if that's null, to a // se::StreamExecutorMemoryAllocator for stream_exec_. 
se::DeviceMemoryAllocator* allocator; @@ -266,11 +259,21 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( allocator = &*se_allocator; } + absl::optional stream_opt; + se::Stream* stream = [&] { + if (allocator->GetStream()) { + return allocator->GetStream(); + } + stream_opt.emplace(stream_exec_); + stream_opt->Init(); + return &stream_opt.value(); + }(); + int64 rng_state = 0; - const auto initialize_buffer = [&stream, &result_shape, + const auto initialize_buffer = [stream, &result_shape, &rng_state](DeviceMemoryBase buffer) { - InitializeFloatBuffer(&stream, result_shape.element_type(), &rng_state, + InitializeFloatBuffer(stream, result_shape.element_type(), &rng_state, buffer); }; @@ -278,18 +281,18 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( // Allocate space for the input, filter, and output of the convolution. se::cuda::RedzoneAllocator input_output_allocator( - device_ordinal, allocator, PtxOptsFromConfig(hlo_module_config)); + stream, allocator, PtxOptsFromConfig(hlo_module_config)); std::vector operand_buffers; for (const auto* operand : instr->operands()) { TF_ASSIGN_OR_RETURN(auto buffer, input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(operand->shape()))); + ShapeUtil::ByteSizeOf(operand->shape()))); initialize_buffer(buffer); operand_buffers.push_back(buffer); } TF_ASSIGN_OR_RETURN(auto result_buffer, input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(result_shape))); + ShapeUtil::ByteSizeOf(result_shape))); initialize_buffer(result_buffer); TF_ASSIGN_OR_RETURN(auto backend_config, @@ -311,14 +314,33 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( const bool crash_on_checking_failure = debug_options.xla_gpu_crash_on_verification_failures(); + const auto canonical_hlo = + std::get<1>(AutotuneCacheKeyfromInstruction(instr, stream_exec_)); + + string blas_version; + if (auto* blas = stream_exec_->AsBlas()) { + (void)blas->GetVersion(&blas_version); + } + + absl::Span blacklisted_algos = + GetBlacklistedConvAlgorithms(GetComputeCapability(stream_exec_), + GetCudnnVersion(stream_exec_), blas_version, + canonical_hlo); + for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { XLA_SCOPED_LOGGING_TIMER_LEVEL( absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ", AlgorithmToString(alg)), 2); + if (absl::c_linear_search(blacklisted_algos, alg)) { + LOG(INFO) << "Omitted potentially buggy algorithm " + << AlgorithmToString(alg) << " for conv " << instr->ToString(); + continue; + } + se::cuda::RedzoneAllocator scratch_allocator( - device_ordinal, allocator, PtxOptsFromConfig(hlo_module_config)); + stream, allocator, PtxOptsFromConfig(hlo_module_config)); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); @@ -329,7 +351,7 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( options.algo_override = alg; Status launch_status = RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, - &scratch_allocator, &stream, options); + &scratch_allocator, stream, options); if (!launch_status.ok()) { continue; @@ -352,22 +374,39 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( // Check for writes to redzones. 
TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear, - CheckRedzones(input_output_allocator, &stream, + CheckRedzones(input_output_allocator, stream, "input/output", instr, &result)); TF_ASSIGN_OR_RETURN( bool scratch_allocator_redzone_clear, - CheckRedzones(scratch_allocator, &stream, "scratch", instr, &result)); + CheckRedzones(scratch_allocator, stream, "scratch", instr, &result)); if (!input_output_allocator_redzone_clear || !scratch_allocator_redzone_clear) { + AlgorithmBlacklist proto; + auto entry = proto.add_entries(); + entry->set_hlo(canonical_hlo); + *entry->mutable_cc() = GetComputeCapability(stream_exec_); + *entry->mutable_cudnn_version() = GetCudnnVersion(stream_exec_); + entry->set_blas_version(blas_version); + auto algo = entry->add_algos(); + algo->set_id(alg.algo_id()); + algo->set_tensor_ops(alg.tensor_ops_enabled()); + + LOG(ERROR) + << "To blacklist this algorithm for this convolution, " + "copy-paste the following " + "proto to the blacklist file pointed by XLA_FLAGS " + "--xla_gpu_algorithm_blacklist_path=" + << GetDebugOptionsFromFlags().xla_gpu_algorithm_blacklist_path() + << " : " << proto.ShortDebugString(); continue; } if (comparator.has_value()) { XLA_SCOPED_LOGGING_TIMER_LEVEL("BufferComparator::CompareEqual", 2); StatusOr compare_result = comparator->CompareEqual( - &stream, reference_result_buffer, result_buffer); + stream, reference_result_buffer, result_buffer); if (!compare_result.ok()) { LOG(ERROR) << "Unable to compare " << AlgorithmToString(first_algorithm) << " against " << AlgorithmToString(alg) << " for " @@ -385,7 +424,7 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( << instr->ToString() << " for " << AlgorithmToString(first_algorithm) << " vs " << AlgorithmToString(alg); - PrintPlatformInfo(&stream); + PrintPlatformInfo(stream); VLOG(1) << "Full module on failure: \n" << instr->GetModule()->ToString(); auto* fail = result.mutable_failure(); @@ -402,9 +441,9 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( comparator.emplace(result_shape, hlo_module_config); TF_ASSIGN_OR_RETURN( reference_result_buffer, - input_output_allocator.AllocateBytes(&stream, result_buffer.size())); - stream.ThenMemcpy(&reference_result_buffer, result_buffer, - result_buffer.size()); + input_output_allocator.AllocateBytes(result_buffer.size())); + stream->ThenMemcpy(&reference_result_buffer, result_buffer, + result_buffer.size()); first_algorithm = alg; } } @@ -431,6 +470,7 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache( *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec_); log.set_device_pci_bus_id( stream_exec_->GetDeviceDescription().pci_bus_id()); + log.set_blas_version(blas_version); VLOG(1) << "Autotuning result: " << log.ShortDebugString(); // If we crash on checking failure, we are in a testing/benchmark mode, thus // omitting logging through the logger. diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc old mode 100644 new mode 100755 index e81850db69e..fc44a9947b4 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.cc @@ -89,13 +89,11 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { // Try to match a backward filter pattern that contains "conv". // Precondition: "conv" is a kConvolution. 
-std::tuple MatchBackwardFilter( - HloInstruction* conv) { +std::tuple +MatchBackwardFilter(HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); - if (conv->feature_group_count() > 1) { - return no_match_result; - } + std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); + // Step 1: match the instruction pattern without considering the paddings and // dimension numbers just yet. We may need some generic pattern matcher // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h @@ -155,6 +153,15 @@ std::tuple MatchBackwardFilter( "to fold it to a backward filter convolution."; return no_match_result; } + auto rhs_in = + conv->mutable_operand(1)->shape().dimensions(kernel_input_feature_dim); + if (conv->feature_group_count() > 1 && rhs_in == 1 && + input_batch_dim == output_batch_dim) { + VLOG(1) << conv->ToString() + << " is a depthwise forward convolution. No need to fold to " + "backward filter."; + return no_match_result; + } // Step 3: fuse the matched HLOs into a backward convolution instruction. // @@ -248,7 +255,62 @@ std::tuple MatchBackwardFilter( backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); } - return std::make_tuple(true, backward_conv_window, backward_conv_dnums); + HloInstruction* lhs = conv->mutable_operand(0); + if (conv->feature_group_count() == 1) { + return std::make_tuple(true, backward_conv_window, backward_conv_dnums, + lhs); + } + + int64 input_batch_dimension = backward_conv_dnums.input_batch_dimension(); + int64 input_feature_dimension = backward_conv_dnums.input_feature_dimension(); + + int64 input_batch = lhs->shape().dimensions(input_batch_dimension); + int64 input_feature = lhs->shape().dimensions(input_feature_dimension); + + // Reshape batch_dim G*N -> [G,N] + std::vector reshape_dims = lhs->shape().dimensions(); + auto num_groups = conv->feature_group_count(); + CHECK_EQ(input_batch % num_groups, 0) + << "Input batch should be an exact multiple of feature group count"; + reshape_dims[input_batch_dimension] = + reshape_dims[input_batch_dimension] / num_groups; + reshape_dims.insert(reshape_dims.begin() + input_batch_dimension, num_groups); + + HloComputation* c = conv->parent(); + HloInstruction* lhs_reshape_1 = + c->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(lhs->shape().element_type(), reshape_dims), + lhs)); + + // Transpose G to the axis before C/G, For eg: [G, N, C/G, H, W] -> [N, G, + // C/G, H, W] + std::vector transpose_dims(lhs_reshape_1->shape().dimensions_size()); + std::iota(transpose_dims.begin(), transpose_dims.end(), 0); + transpose_dims.erase(transpose_dims.begin() + input_batch_dimension); + transpose_dims.insert(transpose_dims.begin() + input_feature_dimension, + input_batch_dimension); + std::vector transpose_reshape_dims = + lhs_reshape_1->shape().dimensions(); + transpose_reshape_dims.erase(transpose_reshape_dims.begin() + + input_batch_dimension); + transpose_reshape_dims.insert( + transpose_reshape_dims.begin() + input_feature_dimension, num_groups); + + HloInstruction* lhs_transpose = + c->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(lhs_reshape_1->shape().element_type(), + transpose_reshape_dims), + lhs_reshape_1, transpose_dims)); + + // Merge [G,C/G] -> [C] + Shape new_shape = lhs_transpose->shape(); + new_shape.DeleteDimension(input_feature_dimension); + new_shape.set_dimensions(input_feature_dimension, + input_feature * conv->feature_group_count()); + 
HloInstruction* lhs_reshape_2 = c->AddInstruction( + HloInstruction::CreateReshape(new_shape, lhs_transpose)); + return std::make_tuple(true, backward_conv_window, backward_conv_dnums, + lhs_reshape_2); } // Try to match a backward input pattern that contains "conv". @@ -258,9 +320,11 @@ MatchBackwardInput(HloInstruction* conv) { const auto no_match_result = std::make_tuple(false, Window(), ConvolutionDimensionNumbers(), nullptr); - // TODO(b/119479517): Theoretically cuDNN supports grouped convolutions also - // for the backward input convolution, but at least for now with version 7.1.4 - // it is slower. This needs to be re-evaluated for future cuDNN versions. + // TODO: Theoretically cuDNN supports grouped convolutions also + // for the backward input convolution, but based on the cudnn's current state + // there is not much performance improvement when using the + // cudnn backward input API for grouped conv. + // This needs to be re-evaluated for future cuDNN versions. // Note that we already have the necessary code down below, the only thing to // enable it is to remove the following early return. if (conv->feature_group_count() > 1) { @@ -272,6 +336,22 @@ MatchBackwardInput(HloInstruction* conv) { HloInstruction* reverse_filter = conv->mutable_operand(1); ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers(); + // Match BackwardInput for a depthwise convolution and thunk it to forward + // convolution Output feature dimension and input feature dimension has been + // swapped in the bridge. Hence to get the actual input features we need to + // query the output feature dimension + auto kernel_out_feature_dim = dnums.kernel_output_feature_dimension(); + auto kernel_out_features = + reverse_filter->shape().dimensions(kernel_out_feature_dim); + + // For a depthwise convolution, the input features must be equal to the + // feature_group_count. We can leverage this property to match a depthwise + // convolution and thunk it to forward conv + if (conv->feature_group_count() > 1 && + kernel_out_features == conv->feature_group_count()) { + return no_match_result; + } + // We pattern-match to a backwards input conv if: // // - all spatial dims of the filter are reversed @@ -333,9 +413,8 @@ MatchBackwardInput(HloInstruction* conv) { Window new_window = old_window; for (size_t i = 0; i < input_spatial_dims.size(); ++i) { // Restore backward convolution's padding config from the matched pattern. - // See the comment in tensorflow/core/kernels/conv_grad_tuple_ops.cc - // for how we convert backward input convolution to a variant of forward - // convolution. + // See the comment in tensorflow/core/kernels/conv_grad_ops.h for how we + // convert backward input convolution to a variant of forward convolution. // // The stride of the backward convolution // = the base dilation factor of the forward convolution @@ -429,11 +508,23 @@ MatchBackwardInput(HloInstruction* conv) { } // OK, it's a match! Switch the input feature dimension with the output - // feature dimension. This is the way cuDNN expects it to be. + // feature dimension. Also switch the output with the input. This is the way + // cuDNN expects it to be. 
+ auto conv_dnums = conv->convolution_dimension_numbers(); dnums.set_kernel_input_feature_dimension( - conv->convolution_dimension_numbers().kernel_output_feature_dimension()); + conv_dnums.kernel_output_feature_dimension()); dnums.set_kernel_output_feature_dimension( - conv->convolution_dimension_numbers().kernel_input_feature_dimension()); + conv_dnums.kernel_input_feature_dimension()); + for (int i = 0; i < input_spatial_dims.size(); ++i) { + dnums.set_input_spatial_dimensions(i, + conv_dnums.output_spatial_dimensions(i)); + dnums.set_output_spatial_dimensions(i, + conv_dnums.input_spatial_dimensions(i)); + } + dnums.set_input_feature_dimension(conv_dnums.output_feature_dimension()); + dnums.set_input_batch_dimension(conv_dnums.output_batch_dimension()); + dnums.set_output_feature_dimension(conv_dnums.input_feature_dimension()); + dnums.set_output_batch_dimension(conv_dnums.input_batch_dimension()); // If we matched against a constant, we need to add a reverse op that can be // subsumed by the cuDNN call. algebraic-simplifier will later remove any @@ -469,7 +560,6 @@ MatchBackwardInput(HloInstruction* conv) { // dimensions, we need to divide the new 'kernel_input_feature_dimension' by // 'feature_group_count' and multiply the new // 'kernel_output_feature_dimension' by 'feature_group_count'. - Shape new_shape = rhs->shape(); int64 input_feature_dimension = dnums.kernel_input_feature_dimension(); int64 output_feature_dimension = dnums.kernel_output_feature_dimension(); @@ -477,13 +567,47 @@ MatchBackwardInput(HloInstruction* conv) { // feature dimensions, and we are guaranteed that the spatial dimensions are // adjacent. CHECK_EQ(std::abs(input_feature_dimension - output_feature_dimension), 1LL); - int64 input_features = new_shape.dimensions(input_feature_dimension); - int64 output_features = new_shape.dimensions(output_feature_dimension); - new_shape.set_dimensions(input_feature_dimension, - input_features / conv->feature_group_count()); - new_shape.set_dimensions(output_feature_dimension, - output_features * conv->feature_group_count()); + int64 input_features = rhs->shape().dimensions(input_feature_dimension); + int64 output_features = rhs->shape().dimensions(output_feature_dimension); + + // Reshape [H, W, ..., in_depth, out_depth / G] -> [H, W, ..., G, in_depth/G, + // out_depth / G] + std::vector reshape_dims = rhs->shape().dimensions(); + auto num_groups = conv->feature_group_count(); + CHECK_EQ(input_features % num_groups, 0) + << "Input feature count should be an exact multiple of feature group " + "count"; + reshape_dims[input_feature_dimension] = + reshape_dims[input_feature_dimension] / num_groups; + reshape_dims.insert(reshape_dims.begin() + input_feature_dimension, + num_groups); + HloComputation* c = conv->parent(); + rhs = c->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(rhs->shape().element_type(), reshape_dims), rhs)); + + // Transpose [H, W, ..., G, in_depth/G, out_depth / G] -> [H, W, ..., + // in_depth/G, G, out_depth / G] + std::vector transpose_dims(rhs->shape().dimensions_size()); + std::iota(transpose_dims.begin(), transpose_dims.end(), 0); + transpose_dims.erase(transpose_dims.begin() + input_feature_dimension); + transpose_dims.insert(transpose_dims.begin() + output_feature_dimension, + input_feature_dimension); + std::vector transpose_reshape_dims = rhs->shape().dimensions(); + transpose_reshape_dims.erase(transpose_reshape_dims.begin() + + input_feature_dimension); + transpose_reshape_dims.insert( + transpose_reshape_dims.begin() + 
output_feature_dimension, num_groups); + rhs = c->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(rhs->shape().element_type(), transpose_reshape_dims), + rhs, transpose_dims)); + + // Reshape [H, W, ..., in_depth/G, G, out_depth / G] -> [H, W, ..., + // in_depth/G, out_depth] + Shape new_shape = rhs->shape(); + new_shape.DeleteDimension(output_feature_dimension); + new_shape.set_dimensions(output_feature_dimension, + output_features * num_groups); rhs = c->AddInstruction(HloInstruction::CreateReshape(new_shape, rhs)); return std::make_tuple(true, new_window, dnums, rhs); } @@ -503,14 +627,7 @@ StatusOr RunOnInstruction(HloInstruction* conv) { Window window; ConvolutionDimensionNumbers dnums; HloInstruction* rhs; - - std::tie(match, window, dnums) = MatchBackwardFilter(conv); - if (match) { - return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), - conv->mutable_operand(0), conv->mutable_operand(1), - window, dnums, conv->feature_group_count(), - conv->metadata()); - } + HloInstruction* lhs; std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); if (match) { @@ -519,6 +636,13 @@ StatusOr RunOnInstruction(HloInstruction* conv) { conv->feature_group_count(), conv->metadata()); } + std::tie(match, window, dnums, lhs) = MatchBackwardFilter(conv); + if (match) { + return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), + lhs, conv->mutable_operand(1), window, dnums, + conv->feature_group_count(), conv->metadata()); + } + // If all else fails, try a forward convolution. if (CanImplementAsCudnnForwardConv(conv)) { return CreateCudnnConv(kCudnnConvForwardCallTarget, conv->shape(), diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc index dbcdc2b075b..362d8d13aab 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter_test.cc @@ -135,6 +135,86 @@ TEST_F(CudnnConvRewriterTest, BackwardFilterConvolve) { << md_after_opt.DebugString() << " vs " << metadata.DebugString(); } +TEST_F(CudnnConvRewriterTest, BackwardFilterGroupConvolve) { + // In a nutshell, before pass: + // Input->batch_dim: 3 input_shape(3) = 4 + // Input->feature_dim: 0 input_shape(0) = 32 + // Kernel(gradient)->kernel_input_feature_dim (gradient_batch_dimension): 0 + // Kernel(gradient)->kernel_output_feature_dim (gradient_feature_dimension): 3 + // Output(dkernel)->output_batch_dim (dkernel_input_feature_dim): 2 + // Output(dkernel)->output_feature_dim (dkernel_output_feature_dim): 3 + + // After pass: All shapes and dimension layout is brought + // back to normal as would be acceptable by cudnn + // Input->batch_dim: 0 input_shape(0) = 8 + // Input->feature_dim: 3 input_shape(3) = 16 + // Kernel(gradient)->kernel_input_feature_dim (gradient_batch_dimension): 2 + // Kernel(gradient)->kernel_output_feature_dim (gradient_feature_dimension): 3 + // Output(dkernel)->output_batch_dim (dkernel_input_feature_dim): 0 + // Output(dkernel)->output_feature_dim (dkernel_output_feature_dim): 3 + HloComputation::Builder builder(TestName()); + HloInstruction* activations = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {32, 1, 3, 4}), "activations")); + HloInstruction* gradients = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {8, 1, 2, 16}), "gradients")); + Window conv_window = default_conv_window_; + 
conv_window.mutable_dimensions(1)->set_size(2); + conv_window.mutable_dimensions(1)->set_window_dilation(2); + auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/4, + /*batch_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_) + .ConsumeValueOrDie(), + activations, gradients, /*feature_group_count=*/4, + /*batch_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); + OpMetadata metadata; + metadata.set_op_name("bar"); + conv->set_metadata(metadata); + auto module = CreateNewVerifiedModule(); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); + // Check that metadata was preserved. + const auto& md_after_opt = + entry_computation->root_instruction()->operand(0)->metadata(); + EXPECT_TRUE(protobuf_util::ProtobufEquals(md_after_opt, metadata)) + << md_after_opt.DebugString() << " vs " << metadata.DebugString(); + const HloInstruction* custom_call = + entry_computation->root_instruction()->operand(0); + const ConvolutionDimensionNumbers conv_dim = + custom_call->convolution_dimension_numbers(); + const auto lhs_a = custom_call->operand(0); + const auto input_shape = lhs_a->shape(); + // The input (lhs) batch_dim(dim 0 in the original NHWC layout) gets mapped to + // be the feature_dim(dim 3) with a value of N*g = 32 in tf2xla. As described + // in conv_grad_ops.h, this swap is required to implement backprop using fwd + // conv. After the pass the batch_dim gets remapped to dim 0. The batch_dim + // value gets scaled to N = N*g/g = 32/4 = 8 to be compatible with cudnn + EXPECT_EQ(0, conv_dim.input_batch_dimension()); + EXPECT_EQ(8, input_shape.dimensions(conv_dim.input_batch_dimension())); + // Similarly, the input (lhs) feature_dim(dim 3 in the original NHWC layout) + // gets mapped to be the batch_dim(dim 0) with a value of C/g = 4 in tf2xla. + // After the pass the batch_dim gets remapped to dim 0. The feature_dim value + // gets scaled to C = C/g*g = 4*4 = 16 to be compatible with cudnn + EXPECT_EQ(3, conv_dim.input_feature_dimension()); + EXPECT_EQ(16, input_shape.dimensions(conv_dim.input_feature_dimension())); + // Similarly, the feature and batch dims of the incoming gradients (used as + // rhs) and the in/out dims of the output of convolution i.e, dgrad have been + // been modified in tf2xla (as described in conv_grad_ops.h). This pass remaps + // everything back for the layout to be compatible with cudnn backprop APIs. 
+ EXPECT_EQ(2, conv_dim.kernel_input_feature_dimension()); + EXPECT_EQ(3, conv_dim.kernel_output_feature_dimension()); + EXPECT_EQ(0, conv_dim.output_batch_dimension()); + EXPECT_EQ(3, conv_dim.output_feature_dimension()); +} + TEST_F(CudnnConvRewriterTest, BackwardFilterConvolveEquivalentToForwardConvolution) { HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc index c2817e36466..2c380c9860e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.cc @@ -48,12 +48,10 @@ class ScratchBufAllocator : public se::ScratchAllocator { ~ScratchBufAllocator() override = default; - int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override { - return scratch_.size(); - } + int64 GetMemoryLimitInBytes() override { return scratch_.size(); } se::port::StatusOr> AllocateBytes( - se::Stream* stream, int64 byte_size) override { + int64 byte_size) override { if (allocated_) { return se::port::InternalError( "Can't allocate twice from a ScratchBufAllocator."); @@ -73,31 +71,91 @@ class ScratchBufAllocator : public se::ScratchAllocator { bool allocated_ = false; }; -template -Status RunCudnnConvImpl(const CudnnConvParams& params, - se::ScratchAllocator* scratch_allocator, - se::Stream* stream, RunConvOptions options) { - auto input_buf = se::DeviceMemory(params.input_buf); - auto filter_buf = se::DeviceMemory(params.filter_buf); - auto output_buf = se::DeviceMemory(params.output_buf); - AlgorithmConfig algorithm = params.algorithm; +template +Status RunCudnnConvForward(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, RunConvOptions options, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + AlgorithmConfig algorithm) { + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } + stream->ThenConvolveWithAlgorithm( + params.input_descriptor, input_buf, params.filter_descriptor, filter_buf, + params.conv_desc, params.output_descriptor, &output_buf, + scratch_allocator, algorithm, options.profile_result); + return Status::OK(); +} - if (options.algo_override) { - algorithm = AlgorithmConfig(*options.algo_override); +template +Status RunCudnnConvForwardActivation(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, RunConvOptions options, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + AlgorithmConfig algorithm) { + BatchDescriptor bias_desc; + bias_desc.set_count(1) + .set_height(1) + .set_width(1) + .set_feature_map_count(params.output_descriptor.feature_map_count()) + .set_layout(params.output_descriptor.layout()); + + se::DeviceMemory side_input(params.fusion->side_input_buf); + // If there is no side input, use output as the side input. + if (side_input.is_null()) { + if (params.fusion->side_input_scale != 0) { + return InternalError( + "Side input scale is not 0, yet no side input buffer is " + "provided"); + } + // Since side-input scale is 0, the values in the side input don't + // matter. The simplest thing to do would be to pass in a null buffer + // for the side input, but cudnn doesn't allow this. cudnn does promise + // that if side-input-scale is 0 the side input won't be read, so we + // just pass in the output buffer, since it's handy and has the correct + // size. 
+ side_input = output_buf; } + stream->ThenFusedConvolveWithAlgorithm( + params.input_descriptor, input_buf, params.conv_result_scale, + params.filter_descriptor, filter_buf, params.conv_desc, side_input, + params.fusion->side_input_scale, bias_desc, + DeviceMemory(params.fusion->bias_buf), params.fusion->mode, + params.output_descriptor, &output_buf, scratch_allocator, algorithm, + options.profile_result); + + return Status::OK(); +} + +// StreamExecutor supports various data types via overloading, and the support +// is maintained on-demand. To avoid calling into non-exist overloads, we have +// to carefully not call into them by using enable_if. +// TODO(timshen): Ideally, to avoid such complication in the runner, we can turn +// StreamExecutor overloadings to template functions, and for unsupported data +// types return runtime errors. +// This is the specialization for double, float, and half types. All kinds of +// convolutions are supported here. +template ::value>::type* = nullptr> +Status RunCudnnConvInternalImpl(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, RunConvOptions options, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + AlgorithmConfig algorithm) { switch (params.kind) { case CudnnConvKind::kForward: - if (params.conv_result_scale != 1) { - return InternalError( - "StreamExecutor doesn't support scaled convolution: %lf.", - params.conv_result_scale); - } - stream->ThenConvolveWithAlgorithm( - params.input_descriptor, input_buf, params.filter_descriptor, - filter_buf, params.conv_desc, params.output_descriptor, &output_buf, - scratch_allocator, algorithm, options.profile_result); - break; + return RunCudnnConvForward(params, scratch_allocator, stream, options, + input_buf, filter_buf, output_buf, algorithm); case CudnnConvKind::kBackwardInput: if (params.conv_result_scale != 1) { return InternalError( @@ -121,46 +179,70 @@ Status RunCudnnConvImpl(const CudnnConvParams& params, scratch_allocator, algorithm, options.profile_result); break; case CudnnConvKind::kForwardActivation: { - BatchDescriptor bias_desc; - bias_desc.set_count(1) - .set_height(1) - .set_width(1) - .set_feature_map_count(params.output_descriptor.feature_map_count()) - .set_layout(params.output_descriptor.layout()); - - se::DeviceMemory side_input(params.fusion->side_input_buf); - // If there is no side input, use output as the side input. - if (side_input.is_null()) { - if (params.fusion->side_input_scale != 0) { - return InternalError( - "Side input scale is not 0, yet no side input buffer is " - "provided"); - } - // Since side-input scale is 0, the values in the side input don't - // matter. The simplest thing to do would be to pass in a null buffer - // for the side input, but cudnn doesn't allow this. cudnn does promise - // that if side-input-scale is 0 the side input won't be read, so we - // just pass in the output buffer, since it's handy and has the correct - // size. 
- side_input = output_buf; - } - - stream->ThenFusedConvolveWithAlgorithm( - params.input_descriptor, input_buf, params.conv_result_scale, - params.filter_descriptor, filter_buf, params.conv_desc, side_input, - params.fusion->side_input_scale, bias_desc, - DeviceMemory(params.fusion->bias_buf), params.fusion->mode, - params.output_descriptor, &output_buf, scratch_allocator, algorithm, - options.profile_result); - break; + return RunCudnnConvForwardActivation( + params, scratch_allocator, stream, options, input_buf, filter_buf, + output_buf, algorithm); } } + return Status::OK(); +} + +// Specialization for integer types. Only two forward convolutions are allowed. +template ::value>::type* = + nullptr> +Status RunCudnnConvInternalImpl(CudnnConvParams params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, RunConvOptions options, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + AlgorithmConfig algorithm) { + switch (params.kind) { + case CudnnConvKind::kForward: + return RunCudnnConvForward(params, scratch_allocator, stream, options, + input_buf, filter_buf, output_buf, algorithm); + case CudnnConvKind::kForwardActivation: + return RunCudnnConvForwardActivation( + params, scratch_allocator, stream, options, input_buf, filter_buf, + output_buf, algorithm); + default: + return InternalError( + "Only convolution kinds kForward and kForwardActivation are " + "supported for integer types"); + } + return Status::OK(); +} + +template +Status RunCudnnConvImpl(const CudnnConvParams& params, + se::ScratchAllocator* scratch_allocator, + se::Stream* stream, RunConvOptions options) { + auto input_buf = se::DeviceMemory(params.input_buf); + auto filter_buf = se::DeviceMemory(params.filter_buf); + auto output_buf = se::DeviceMemory(params.output_buf); + AlgorithmConfig algorithm = params.algorithm; + + if (options.algo_override) { + algorithm = AlgorithmConfig(*options.algo_override); + } + + Status run_status = + RunCudnnConvInternalImpl( + params, scratch_allocator, stream, options, input_buf, filter_buf, + output_buf, algorithm); + + if (run_status != Status::OK()) { + return run_status; + } if (!stream->ok()) { return InternalError( - "Unable to launch convolution with type %s and algorithm (%d, %d)", + "Unable to launch convolution with type %s and algorithm (%d, %s)", CudnnConvKindToString(params.kind), algorithm.algorithm()->algo_id(), - algorithm.algorithm_no_scratch()->algo_id()); + algorithm.algorithm_no_scratch().has_value() + ? 
absl::StrCat(algorithm.algorithm_no_scratch()->algo_id()) + : "none"); } return Status::OK(); } @@ -372,18 +454,31 @@ Status RunCudnnConv(const HloCustomCallInstruction* conv, TF_ASSIGN_OR_RETURN(CudnnConvParams params, GetCudnnConvParams(conv, operand_buffers, result_buffer)); - PrimitiveType output_primitive_type = - conv->shape().tuple_shapes(0).element_type(); - switch (output_primitive_type) { + PrimitiveType input_primitive_type = conv->operand(0)->shape().element_type(); + switch (input_primitive_type) { case F16: - return RunCudnnConvImpl(params, scratch_allocator, stream, - options); + return RunCudnnConvImpl( + params, scratch_allocator, stream, options); case F32: - return RunCudnnConvImpl(params, scratch_allocator, stream, - options); + return RunCudnnConvImpl(params, scratch_allocator, + stream, options); case F64: - return RunCudnnConvImpl(params, scratch_allocator, stream, - options); + return RunCudnnConvImpl(params, scratch_allocator, + stream, options); + case S8: { + PrimitiveType output_primitive_type = + conv->shape().tuple_shapes(0).element_type(); + switch (output_primitive_type) { + case F32: + return RunCudnnConvImpl(params, scratch_allocator, + stream, options); + case S8: + return RunCudnnConvImpl(params, scratch_allocator, + stream, options); + default: + LOG(FATAL) << conv->ToString(); + } + } default: LOG(FATAL) << conv->ToString(); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc index dee257a5d97..aca7307e0c2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.cc @@ -223,6 +223,7 @@ StatusOr> TryRewriteToCudnnForwardRelu( } auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall( conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget)); + new_conv->set_feature_group_count(conv->feature_group_count()); new_conv->set_window(conv->window()); new_conv->set_convolution_dimension_numbers( conv->convolution_dimension_numbers()); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc index 7aa442d3bff..b621880f639 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc @@ -163,6 +163,26 @@ TEST_F(CudnnFusedConvRewriterTest, TestScaledConv) { })"); } +TEST_F(CudnnFusedConvRewriterTest, TestNoCrashOnInf) { + EXPECT_TRUE(RunAndCompare(R"( + HloModule Test + + ENTRY Test { + zero = f32[] constant(inf) + zeros = f32[1,32,9,9] broadcast(zero), dimensions={} + alpha_conv_scalar = f32[] constant(0.999994934) + + input = f32[1,17,9,9] parameter(0) + filter = f32[3,3,17,32] parameter(1) + + conv = f32[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + alpha_conv = f32[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={} + scaled_conv = f32[1,32,9,9] multiply(conv, alpha_conv) + ROOT relu = f32[1,32,9,9] maximum(zeros, scaled_conv) + })", + ErrorSpec{0.01})); +} + TEST_F(CudnnFusedConvRewriterTest, TestScaledConvAndSideInput) { // max(0, conv(x, w) + 0.899994934 * side_input); TestMatchWithAllTypes(R"( @@ -305,6 +325,30 @@ TEST_F(CudnnFusedConvRewriterTest, PreservesMetadata) { ::testing::ContainsRegex(R"(custom-call.*metadata=\{op_type="foo"\})")); } +TEST_F(CudnnFusedConvRewriterTest, 
TestPreservesFeatureGroupCount) { + // The convolution below would crash if feature_count is not preserved. + const char* kHloString = R"( + HloModule jaxpr_computation__6.19 + + primitive_computation__1.4 { + parameter.5 = f32[] parameter(0) + parameter.6 = f32[] parameter(1) + ROOT add.7 = f32[] add(parameter.5, parameter.6) + } + + ENTRY jaxpr_computation__7.8 { + parameter.11 = f32[2,64,64,53]{3,2,1,0} parameter(1) + parameter.10 = f32[3,3,1,53]{3,2,1,0} parameter(0) + convolution.12 = f32[2,64,64,53]{3,2,1,0} convolution(parameter.11, parameter.10), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=53 + constant.13 = f32[] constant(0) + broadcast.14 = f32[2,64,64,53]{3,2,1,0} broadcast(constant.13), dimensions={} + maximum.15 = f32[2,64,64,53]{3,2,1,0} maximum(convolution.12, broadcast.14) + ROOT reduce.17 = f32[] reduce(maximum.15, constant.13), dimensions={0,1,2,3}, to_apply=primitive_computation__1.4 + } + )"; + EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.01})); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_test.cc b/tensorflow/compiler/xla/service/gpu/custom_call_test.cc index c04f6fb7bf5..53a3ca14400 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_test.cc +++ b/tensorflow/compiler/xla/service/gpu/custom_call_test.cc @@ -90,67 +90,25 @@ void Callback_SubBuffers(CUstream stream, void** buffers, const char* /*opaque*/, size_t /*opaque_len*/) { // `buffers` is a flat array containing device pointers to the following. // - // 0: root tuple of param 0 - // 1: param 0 at tuple index {0}, shape f32[128] - // 2: param 0 at tuple index {1}, shape f32[256] - // 3: root tuple of param 1 - // 4: param 1 at tuple index {0}, shape f32[1024] - // 5: param 1 at tuple index {1}, shape f32[8] - // 6: root tuple of custom-call result - // 7: result at tuple index {0}, shape f32[8] - // 8: result at tuple index {1}, shape (f32[128], f32[256]) - // 9: result at tuple index {1, 0}, shape f32[128] - // 10: result at tuple index {1, 1}, shape f32[256] - // 11: result at tuple index {2}, shape f32[1024] + // 0: param 0 at tuple index {0}, shape f32[128] + // 1: param 0 at tuple index {1}, shape f32[256] + // 2: param 1 at tuple index {0}, shape f32[1024] + // 3: param 1 at tuple index {1}, shape f32[8] + // 4: result at tuple index {0}, shape f32[8] + // 5: result at tuple index {1, 0}, shape f32[128] + // 6: result at tuple index {1, 1}, shape f32[256] + // 7: result at tuple index {2}, shape f32[1024] // - // It's the contract of custom-call that the non-root pointers (i.e. - // everything other than indices 0, 3, and 6) may be null, if XLA is unable to - // analyze the program well enough to determine for sure what's in those - // buffers. For this simple example, all of the buffers should be non-null. - // Check the param 0 tuple, namely that - // - // (*buffers[0])[0] == buffers[1] and - // (*buffers[0])[1] == buffers[2]. - // - // because buffers contains pointers to device memory, we have to retrieve - // these values via cudaMemcpy. - void* p0[2]; - cudaMemcpy(p0, buffers[0], 2 * sizeof(void*), cudaMemcpyDeviceToHost); - ASSERT_EQ(p0[0], buffers[1]); - ASSERT_EQ(p0[1], buffers[2]); - - // Check the param 1 tuple, namely that - // - // (*buffers[3])[0] == buffers[4] - // (*buffers[3])[1] == buffers[5]. 
- void* p1[2]; - cudaMemcpy(p1, buffers[3], 2 * sizeof(void*), cudaMemcpyDeviceToHost); - ASSERT_EQ(p1[0], buffers[4]); - ASSERT_EQ(p1[1], buffers[5]); - - // We don't have an equivalent check for the output tuple (i.e. we don't check - // (*buffers[6])[0] == buffers[7]) because it's up to us to set the tuple - // as part of this custom-call. - - // Write the results. First set the root tuple output buffer to {b7, b8, - // b11}. - void* root[3] = {buffers[7], buffers[8], buffers[11]}; - cudaMemcpy(buffers[6], root, 3 * sizeof(void*), cudaMemcpyHostToDevice); - - // Now set the sub-tuple output buffer at index 8 to {b9, b10}. - void* sub_tuple[2] = {buffers[9], buffers[10]}; - cudaMemcpy(buffers[8], sub_tuple, 2 * sizeof(void*), cudaMemcpyDeviceToHost); - - // Now set output leaf buffers 7, 9, 10, and 11, copying data from the - // corresponding same-sized inputs. - cudaMemcpyAsync(buffers[7], buffers[5], 8 * sizeof(float), + // Set output leaf buffers, copying data from the corresponding same-sized + // inputs. + cudaMemcpyAsync(buffers[4], buffers[3], 8 * sizeof(float), cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(buffers[9], buffers[1], 128 * sizeof(float), + cudaMemcpyAsync(buffers[5], buffers[0], 128 * sizeof(float), cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(buffers[10], buffers[2], 256 * sizeof(float), + cudaMemcpyAsync(buffers[6], buffers[1], 256 * sizeof(float), cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(buffers[11], buffers[4], 1024 * sizeof(float), + cudaMemcpyAsync(buffers[7], buffers[2], 1024 * sizeof(float), cudaMemcpyDeviceToDevice, stream); } XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, "CUDA"); @@ -185,5 +143,45 @@ TEST_F(CustomCallTest, SubBuffers) { EXPECT_THAT(result.data({2}), ::testing::Each(3)); } +void Callback_TupleSelect(CUstream stream, void** buffers, + const char* /*opaque*/, size_t /*opaque_len*/) { + // Set the two output leaf buffers equal to the two input leaf buffers. + cudaMemcpyAsync(buffers[2], buffers[0], 10 * sizeof(float), + cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(buffers[3], buffers[1], 10 * sizeof(float), + cudaMemcpyDeviceToDevice, stream); +} +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_TupleSelect, "CUDA"); +// Tuple-shaped select is a case where XLA can't know all buffer assignments +// statically ahead of time and has to walk the on-device tuple sub-buffers. +TEST_F(CustomCallTest, TupleSelect) { + XlaBuilder b(TestName()); + auto tuple_shape = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {10}), + ShapeUtil::MakeShape(F32, {10}), + }); + auto p0 = AddParam(LiteralUtil::CreateR0(false), &b); + auto p1 = + AddParam(LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR1(std::vector(10, 1.0f)), + LiteralUtil::CreateR1(std::vector(10, 2.0f))), + &b); + auto p2 = + AddParam(LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR1(std::vector(10, 10.0f)), + LiteralUtil::CreateR1(std::vector(10, 20.0f))), + &b); + auto cc = CustomCall(&b, "Callback_TupleSelect", + /*operands=*/{Select(p0, p1, p2)}, tuple_shape, + /*opaque=*/""); + + // Do a tuple-select on the custom-call result to ensure that the custom-call + // sets its output tuple index buffers. 
+ Select(p0, p1, cc); + TF_ASSERT_OK_AND_ASSIGN(auto result, ComputeAndTransfer(&b, {})); + EXPECT_THAT(result.data({0}), ::testing::Each(10)); + EXPECT_THAT(result.data({1}), ::testing::Each(20)); +} + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc index 5fba64e90ed..65673106391 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc @@ -48,8 +48,83 @@ CustomCallThunk::CustomCallThunk( instr->shape().ToString(), result_slices.shape().ToString()); } +// For each leaf in a preorder traversal of `slices`, appends its device address +// to `buffers`. +// +// In the common case, this is trivial; simply iterate over the ShapeTree and +// add every leaf to `buffers`. But under some circumstances XLA doesn't +// statically know the address of a leaf buffer and has to derive it by walking +// the on-device tuple. +static Status AppendBuffersFor(const ShapeTree& slices, + const BufferAllocations* buffer_allocations, + se::Stream* stream, + std::vector* buffers) { + // Buffer addresses we've retrieved by following device tuples. + ShapeTree retrieved_addrs(slices.shape()); + + // We make this lambda an std::function so it can capture itself. + std::function(const ShapeIndexView&)> get_addr_for = + [&](ShapeIndexView index) -> StatusOr { + auto slice = slices.element(index); + + // If we know the address of this sub-buffer statically, return it. + if (slice.allocation() != nullptr) { + return buffer_allocations->GetDeviceAddress(slice).opaque(); + } + // If we've already pulled the address for this sub-buffer down from the + // GPU, return it. + if (retrieved_addrs.element(index) != nullptr) { + return retrieved_addrs.element(index); + } + + // Recurse to get the address of the parent sub-buffer. + CHECK(!index.empty()) << "Address of tuple root cannot be unknown!"; + TF_ASSIGN_OR_RETURN(void* parent_buffer, get_addr_for(index.ConsumeBack())); + + // Pull down the entirety of parent_buffer from the GPU, getting the address + // we're interested in plus all of its siblings. (Perhaps only some of the + // siblings are unknown and we could get away without retrieving all of + // them. But in practice, getting them all in one fell swoop should be just + // as fast as getting just one.) + // + // TODO(jlebar): This is not as efficient as possible. In particular, at + // the expense of some complexity we could batch up multiple parallel D2H + // copies (say for multiple unrelated sub-buffers, maybe even across + // different parameters) and do just one BlockHostUntilDone. Hopefully the + // case when we have to do any copies at all is uncommon. + int64 num_siblings = + ShapeUtil::GetSubshape(slices.shape(), index.ConsumeBack()) + .tuple_shapes_size(); + std::vector sibling_addrs(num_siblings); + TF_RETURN_IF_ERROR( + stream + ->ThenMemcpy(sibling_addrs.data(), + se::DeviceMemoryBase(parent_buffer, sizeof(void*)), + num_siblings * sizeof(void*)) + .BlockHostUntilDone()); + + // Save the data we retrieved into retrieved_addrs. 
+ for (int64 i = 0; i < num_siblings; ++i) { + ShapeIndex sibling_index(index.ConsumeBack()); + sibling_index.push_back(i); + *retrieved_addrs.mutable_element(sibling_index) = sibling_addrs[i]; + } + return sibling_addrs[index.back()]; + }; + + return slices.ForEachElementWithStatus( + [&](const ShapeIndex& index, const BufferAllocation::Slice&) { + if (slices.IsLeaf(index)) { + TF_ASSIGN_OR_RETURN(void* addr, get_addr_for(index)); + buffers->push_back(addr); + } + return Status::OK(); + }); +} + Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { // gpu_stream is CUstream or e.g. the equivalent type in ROCm. + se::Stream* stream = params.stream; auto gpu_stream = se::gpu::AsGpuStreamValue(params.stream); auto typed_call_target = reinterpret_cast buffers; - auto append_buffers = [&](const ShapeTree& slices) { - slices.ForEachElement([&](const ShapeIndex& /*index*/, - const BufferAllocation::Slice& slice) { - if (slice.allocation() == nullptr) { - buffers.push_back(nullptr); - } - buffers.push_back( - params.buffer_allocations->GetDeviceAddress(slice).opaque()); - }); - }; for (const auto& slices : operand_slices_) { - append_buffers(slices); + TF_RETURN_IF_ERROR( + AppendBuffersFor(slices, params.buffer_allocations, stream, &buffers)); } - append_buffers(result_slices_); + TF_RETURN_IF_ERROR(AppendBuffersFor(result_slices_, params.buffer_allocations, + stream, &buffers)); typed_call_target(gpu_stream, buffers.data(), opaque_.data(), opaque_.size()); - return Status::OK(); + + // If the custom-call returns a tuple, populate the result tuple index + // buffers. + return result_slices_.ForEachElementWithStatus( + [&](const ShapeIndex& index, const BufferAllocation::Slice& slice) { + const Shape& subshape = + ShapeUtil::GetSubshape(result_slices_.shape(), index); + auto n = subshape.tuple_shapes_size(); + if (!subshape.IsTuple() || n == 0) { + return Status::OK(); + } + auto tuple_ptrs = absl::make_unique(n); + ShapeIndex subindex(index); + for (int i = 0; i < n; ++i) { + subindex.push_back(i); + tuple_ptrs[i] = + params.buffer_allocations + ->GetDeviceAddress(result_slices_.element(subindex)) + .opaque(); + subindex.pop_back(); + } + SafeH2DMemcpy(se::DeviceMemory( + params.buffer_allocations->GetDeviceAddress(slice)), + std::move(tuple_ptrs), n, stream); + return Status::OK(); + }); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/data/hlo_algorithm_blacklist.pbtxt b/tensorflow/compiler/xla/service/gpu/data/hlo_algorithm_blacklist.pbtxt new file mode 100644 index 00000000000..5f22429962c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/data/hlo_algorithm_blacklist.pbtxt @@ -0,0 +1,17 @@ +entries { + hlo: '(f16[256,112,112,64]{3,2,1,0}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[7,7,4,64]{2,1,0,3}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}"' + cc: {major: 7, minor: 0} + cudnn_version: {major: 7, minor: 6, patch: 0} + blas_version: "9000" + algos: [{}, {tensor_ops: true}, {id: 1}, {id:1, tensor_ops: true}] +} + +entries { + hlo: '(f16[256,112,112,64]{3,2,1,0}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[7,7,4,64]{2,1,0,3}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}"' + cc: {major: 7, minor: 0} + cudnn_version: {major: 7, minor: 6, patch: 2} + blas_version: "9000" + algos: [{}, {tensor_ops: true}, {id: 1}, {id:1, 
tensor_ops: true}] +} + + diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index c0cd4addc7e..c6df786fb51 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" #include + #include #include @@ -144,7 +145,7 @@ StatusOr GpuElementalIrEmitter::EmitMathCall( // Binary math functions transform are of type [T] -> T. for (PrimitiveType input_type : input_types) { if (output_type != input_type) { - return Unimplemented("Input type ≠ output type: %s ≠ %s", + return Unimplemented("Input type != output type: %s != %s", PrimitiveType_Name(input_type), PrimitiveType_Name(output_type)); } @@ -152,7 +153,7 @@ StatusOr GpuElementalIrEmitter::EmitMathCall( return EmitDeviceFunctionCall( callee_name, operands, input_types, output_type, - {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}); + {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind}, b_); } StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( @@ -269,8 +270,19 @@ StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, // Upcast F16 to F32 if necessary. llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType(); llvm::Value* input = FPCast(value, type); + + // If |value| >= kMaxValue, tanh() is set to -1.0 or 1.0. + constexpr double kMaxValue = 20.0; + auto max_value = llvm::ConstantFP::get(type, kMaxValue); + llvm::Value* abs_value = + llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {input}, {type}, b_); + llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); - return FPCast(fast_tanh, value->getType()); + auto one = llvm::ConstantFP::get(type, 1.0); + auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign, + {one, input}, {type}, b_); + return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign), + value->getType()); } StatusOr GpuElementalIrEmitter::EmitComplexAbs( @@ -280,47 +292,16 @@ StatusOr GpuElementalIrEmitter::EmitComplexAbs( {prim_type, prim_type}, prim_type); } -llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( - const string& callee_name, absl::Span operands, - absl::Span input_types, PrimitiveType output_type, - absl::Span attributes) { - std::vector ir_input_types; - for (PrimitiveType input_type : input_types) { - ir_input_types.push_back( - llvm_ir::PrimitiveTypeToIrType(input_type, module_)); - } - llvm::FunctionType* callee_type = llvm::FunctionType::get( - llvm_ir::PrimitiveTypeToIrType(output_type, module_), // Return type. - ir_input_types, // Parameter types. - false); // No variadic arguments. - - // Declares the callee if it is not declared already. 
- llvm::Function* callee = llvm::dyn_cast( - b_->GetInsertBlock() - ->getModule() - ->getOrInsertFunction(callee_name, callee_type) - .getCallee()); - - for (auto attribute : attributes) { - callee->addFnAttr(attribute); - } - - return Call(callee, llvm_ir::AsArrayRef(operands)); -} - llvm::Value* GpuElementalIrEmitter::EmitThreadId() { - llvm::Value* block_id = - IntCast(llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "block.id"); - llvm::Value* thread_id_in_block = - IntCast(llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); - llvm::Value* threads_per_block = - IntCast(llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); + llvm::Value* block_id = IntCast( + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "block.id"); + llvm::Value* thread_id_in_block = IntCast( + EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); + llvm::Value* threads_per_block = IntCast( + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockDimx, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); } @@ -408,7 +389,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( SDiv(input_multi_index[i], index_typed_const(window.dimensions(i).base_dilation())); - // We must check whether 0 ≤ input_multi_index[i] < bound, as + // We must check whether 0 <= input_multi_index[i] < bound, as // otherwise we are in the pad and so can skip the computation. This // comparison is equivalent to the unsigned comparison // input_multi_index[i] < bound, as a negative value wraps to a large diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index db4918c5890..c8a58a21980 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -100,13 +100,6 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm::Value* lhs_value, llvm::Value* rhs_value); - // Emits IR to call a device function named "callee_name" on the given - // operand. Returns the IR value that represents the return value. - llvm::Value* EmitDeviceFunctionCall( - const string& callee_name, absl::Span operands, - absl::Span input_type, PrimitiveType output_type, - absl::Span attributes); - // Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the // return value of the function. diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index da90ba989dc..991a463f2a0 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -32,20 +32,20 @@ FftScratchAllocator::FftScratchAllocator( int device_ordinal, se::DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} -int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { +int64 FftScratchAllocator::GetMemoryLimitInBytes() { constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default. 
return kFftScratchSize; } StatusOr> FftScratchAllocator::AllocateBytes( - se::Stream* stream, int64 byte_size) { + int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; - if (byte_size > GetMemoryLimitInBytes(stream)) { + if (byte_size > GetMemoryLimitInBytes()) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, absl::StrFormat( "Allocating %d bytes exceeds the memory limit of %d bytes.", - byte_size, GetMemoryLimitInBytes(stream))); + byte_size, GetMemoryLimitInBytes())); } TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer, diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index be77df1eb77..95186c7f219 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -40,12 +40,12 @@ class FftScratchAllocator : public se::ScratchAllocator { FftScratchAllocator(int device_ordinal, se::DeviceMemoryAllocator* memory_allocator); - int64 GetMemoryLimitInBytes(se::Stream* stream) override; + int64 GetMemoryLimitInBytes() override; int64 TotalAllocatedBytes() { return total_allocated_bytes_; } se::port::StatusOr> AllocateBytes( - se::Stream* stream, int64 byte_size) override; + int64 byte_size) override; private: const int device_ordinal_; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index 626bef76b98..98d8d00b62c 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -65,6 +65,9 @@ static StatusOr> DoUncachedGemmAutotune( return InternalError("Failed to synchronize GPU for autotuning."); } + GemmBackendConfig backend_config = + gemm->backend_config().ValueOrDie(); + VLOG(3) << "Starting autotune of GemmThunk " << gemm->ToString(); std::vector algorithms; @@ -76,7 +79,7 @@ static StatusOr> DoUncachedGemmAutotune( for (se::blas::AlgorithmType algorithm : algorithms) { // Make sure the output buffer always has the same value if we use // the bias parameter. - if (gemm->backend_config().ValueOrDie().beta() != 0) { + if (backend_config.beta() != 0) { int64 rng_state = 0; InitializeFloatBuffer(stream, gemm->shape().element_type(), &rng_state, output_buffer); @@ -87,7 +90,8 @@ static StatusOr> DoUncachedGemmAutotune( // for all algorithms if we're targeting < sm_50. But because we pass a // non-null ProfileResult, DoGemmWithAlgorithm should always return true, // and the actual success-ness is returned in ProfileResult::is_valid. 
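The FftScratchAllocator hunks above only drop the unused se::Stream* parameter; the limit check itself is unchanged. A host-side analogue of that check, with standard-library placeholders rather than the StreamExecutor allocator API:

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

class ScratchAllocatorSketch {
 public:
  // Same fixed budget as kFftScratchSize: 4 GB.
  static constexpr int64_t kLimitBytes = int64_t{1} << 32;

  // Returns nullopt where the real allocator returns RESOURCE_EXHAUSTED.
  std::optional<std::vector<uint8_t>> AllocateBytes(int64_t byte_size) {
    if (byte_size < 0 || byte_size > kLimitBytes) {
      return std::nullopt;
    }
    total_allocated_bytes_ += byte_size;
    return std::vector<uint8_t>(static_cast<size_t>(byte_size));
  }

  int64_t TotalAllocatedBytes() const { return total_allocated_bytes_; }

 private:
  int64_t total_allocated_bytes_ = 0;
};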
- CHECK(RunGemm(gemm, lhs_buffer, rhs_buffer, output_buffer, stream, + CHECK(RunGemm(gemm, backend_config, lhs_buffer, rhs_buffer, output_buffer, + stream, /*implements_whole_instruction=*/true, /*profiler=*/nullptr, /*profile_result=*/&profile_result, algorithm) @@ -110,7 +114,7 @@ static StatusOr> DoUncachedGemmAutotune( TF_ASSIGN_OR_RETURN( se::cuda::RedzoneAllocator::RedzoneCheckStatus rz_check_status, - allocator.CheckRedzones(stream)); + allocator.CheckRedzones()); if (!rz_check_status.ok()) { result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); *result.mutable_failure()->mutable_msg() = @@ -235,17 +239,22 @@ static StatusOr> DoGemmAutotune( static StatusOr RunOnInstruction(HloInstruction* instr, se::StreamExecutor* executor, se::DeviceMemoryAllocator* allocator) { - se::Stream stream{executor}; - stream.Init(); - if (allocator == nullptr) { allocator = executor->GetAllocator(); } + absl::optional stream_opt; + se::Stream* stream = [&]() { + if (allocator->GetStream()) { + return allocator->GetStream(); + } + stream_opt.emplace(executor); + stream_opt->Init(); + return &stream_opt.value(); + }(); const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); se::cuda::RedzoneAllocator input_output_allocator( - executor->device_ordinal(), allocator, - PtxOptsFromConfig(hlo_module_config)); + stream, allocator, PtxOptsFromConfig(hlo_module_config)); BufferComparator comparator(instr->shape(), hlo_module_config); @@ -254,8 +263,8 @@ static StatusOr RunOnInstruction(HloInstruction* instr, [&](const HloInstruction* op) -> StatusOr { TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(op->shape()))); - InitializeFloatBuffer(&stream, op->shape().element_type(), &rng_state, + ShapeUtil::ByteSizeOf(op->shape()))); + InitializeFloatBuffer(stream, op->shape().element_type(), &rng_state, buffer); return buffer; }; @@ -280,11 +289,11 @@ static StatusOr RunOnInstruction(HloInstruction* instr, const bool crash_on_checking_failure = debug_options.xla_gpu_crash_on_verification_failures(); - TF_ASSIGN_OR_RETURN(absl::optional gemm_algorithm, - DoGemmAutotune(instr, lhs, rhs, lhs_buffer, rhs_buffer, - output_buffer, reference_result_buffer, - &stream, crash_on_checking_failure, - input_output_allocator, comparator)); + TF_ASSIGN_OR_RETURN( + absl::optional gemm_algorithm, + DoGemmAutotune(instr, lhs, rhs, lhs_buffer, rhs_buffer, output_buffer, + reference_result_buffer, stream, crash_on_checking_failure, + input_output_allocator, comparator)); // We update instruction->backend_config(); if no algorithms are supported, // a different API is used, which does not require specifying an algorithm. 
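The RunOnInstruction hunk above stops constructing a fresh se::Stream unconditionally and instead borrows the allocator's stream when it has one, falling back to a locally owned stream otherwise. The same pattern with placeholder types (Stream and Allocator here are stand-ins, not the StreamExecutor classes):

#include <optional>

struct Stream {
  void Init() {}
};

struct Allocator {
  Stream* stream = nullptr;  // may be null
  Stream* GetStream() { return stream; }
};

// Returns a usable stream: the allocator's own if present, otherwise a stream
// created into `local_stream`, which the caller keeps alive for the duration.
Stream* AcquireStream(Allocator* allocator,
                      std::optional<Stream>& local_stream) {
  if (Stream* borrowed = allocator->GetStream()) {
    return borrowed;
  }
  local_stream.emplace();
  local_stream->Init();
  return &local_stream.value();
}

Keeping the fallback stream in a caller-provided std::optional mirrors the diff's stream_opt: the object only exists when it is actually needed.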
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc index df7ee3cdc69..bdf697acfba 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_rewriter.cc @@ -32,23 +32,6 @@ namespace gpu { namespace m = match; -static complex128 GetScalarConstantAsComplex(const Literal &literal) { - switch (literal.shape().element_type()) { - case F16: - return {static_cast(literal.Get({})), 0}; - case F32: - return {literal.Get({}), 0}; - case F64: - return {literal.Get({}), 0}; - case C64: - return literal.Get({}); - case C128: - return literal.Get({}); - default: - LOG(FATAL) << "Unexpected type: " << literal.shape(); - } -} - // The rewriting proceeds in a bottom-up way: // // (kDot A B) is rewritten into a (kCustomCall:gemm A B) @@ -103,7 +86,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (config.beta() == 0.0 && existing_gemm->user_count() == 1) { complex128 prev_alpha = {config.alpha_real(), config.alpha_imag()}; complex128 new_alpha = - GetScalarConstantAsComplex(alpha->literal()) * prev_alpha; + *alpha->literal().GetAsComplex128({}) * prev_alpha; config.set_alpha_real(new_alpha.real()); config.set_alpha_imag(new_alpha.imag()); TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config)); diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index eddc2474830..d52e5410dab 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -37,12 +37,14 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice &lhs_buffer, const BufferAllocation::Slice &rhs_buffer, const BufferAllocation::Slice &output_buffer, bool implements_whole_instruction, - const HloInstruction *hlo_instruction) + const HloInstruction *hlo_instruction, + const GemmBackendConfig &backend_config) : Thunk(Kind::kGemm, hlo_instruction), lhs_buffer_(lhs_buffer), rhs_buffer_(rhs_buffer), output_buffer_(output_buffer), - implements_whole_instruction_(implements_whole_instruction) {} + implements_whole_instruction_(implements_whole_instruction), + backend_config_(backend_config) {} Status GemmThunk::ExecuteOnStream(const ExecuteParams ¶ms) { auto get_device_address = [&](const BufferAllocation::Slice &slice) { @@ -53,8 +55,9 @@ Status GemmThunk::ExecuteOnStream(const ExecuteParams ¶ms) { se::DeviceMemoryBase lhs_data = get_device_address(lhs_buffer_); se::DeviceMemoryBase rhs_data = get_device_address(rhs_buffer_); se::DeviceMemoryBase output_data = get_device_address(output_buffer_); - return RunGemm(hlo_instruction(), lhs_data, rhs_data, output_data, - params.stream, implements_whole_instruction_, params.profiler); + return RunGemm(hlo_instruction(), backend_config_, lhs_data, rhs_data, + output_data, params.stream, implements_whole_instruction_, + params.profiler); } // This struct contains the metadata of a matrix, e.g., its base address and @@ -152,8 +155,9 @@ static bool DoGemmWithAlgorithm( .ok(); } -Status RunGemm(const HloInstruction *gemm, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, +Status RunGemm(const HloInstruction *gemm, + const GemmBackendConfig &backend_config, + se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, se::Stream *stream, bool implements_whole_instruction, HloExecutionProfiler *profiler, @@ -162,8 +166,6 @@ Status RunGemm(const HloInstruction *gemm, se::DeviceMemoryBase lhs_buffer, 
VLOG(2) << "Executing a GemmThunk"; CHECK(IsCublasGemm(*gemm)); - TF_ASSIGN_OR_RETURN(GemmBackendConfig backend_config, - gemm->backend_config()); const Shape &output_shape = gemm->shape(); const HloInstruction *lhs = gemm->operand(0); const HloInstruction *rhs = gemm->operand(1); diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index adf2fa853b7..b44cc40d295 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_THUNK_H_ #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" @@ -42,7 +43,8 @@ class GemmThunk : public Thunk { const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, bool implements_whole_instruction, - const HloInstruction* hlo_instruction); + const HloInstruction* hlo_instruction, + const GemmBackendConfig& backend_config); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; @@ -54,23 +56,23 @@ class GemmThunk : public Thunk { const BufferAllocation::Slice rhs_buffer_; const BufferAllocation::Slice output_buffer_; bool implements_whole_instruction_; + GemmBackendConfig backend_config_; }; // Run the given GEMM instruction `gemm` subject to the configuration -// stored inside it's backend_config and the passed buffers. +// in `backend_config` and the passed buffers. // // `implements_whole_instruction` is used for the default profiler creation // if the `profiler` is not supplied. False value indicates that the created // profiler will not specifically profile the `gemm` instruction. // -// If `algorithm` is provided, it overrides the one specified in backend_config -// of gemm. -// +// If `algorithm` is provided, it overrides the one specified in +// `backend_config`. 
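The GemmThunk changes above move backend-config parsing out of the execution path: the config is deserialized once when the thunk is constructed and stored as a member, so ExecuteOnStream and RunGemm only read a cached value. A generic sketch of that parse-once choice (all types here are placeholders, not the XLA classes):

struct GemmConfig {
  double beta = 0.0;
};

struct Instruction {
  // Stands in for the backend_config accessor, which re-parses a serialized
  // proto on every call.
  GemmConfig ParseBackendConfig() const { return GemmConfig{}; }
};

class GemmThunkSketch {
 public:
  explicit GemmThunkSketch(const Instruction& instr)
      : config_(instr.ParseBackendConfig()) {}  // parse once, at build time

  double beta() const { return config_.beta; }  // hot path: a member read

 private:
  GemmConfig config_;
};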
Status RunGemm( - const HloInstruction* gemm, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - se::Stream* stream, bool implements_whole_instruction, - HloExecutionProfiler* profiler = nullptr, + const HloInstruction* gemm, const GemmBackendConfig& backend_config, + se::DeviceMemoryBase lhs_buffer, se::DeviceMemoryBase rhs_buffer, + se::DeviceMemoryBase output_buffer, se::Stream* stream, + bool implements_whole_instruction, HloExecutionProfiler* profiler = nullptr, se::blas::ProfileResult* profile_result = nullptr, absl::optional algorithm = absl::nullopt); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto b/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto index 6ed72437bec..35b5cfacb2d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto +++ b/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto @@ -6,6 +6,7 @@ package xla.gpu; import "tensorflow/compiler/xla/service/hlo.proto"; import "tensorflow/compiler/xla/xla_data.proto"; +import "tensorflow/core/protobuf/autotuning.proto"; message ConvInstructionLog { xla.HloInstructionProto instruction = 1; @@ -13,3 +14,20 @@ message ConvInstructionLog { uint64 result_address = 3; repeated uint64 operand_addresses = 4; } + +message BlacklistedAlgorithm { + int64 id = 1; + bool tensor_ops = 2; +} + +message AlgorithmBlacklistEntry { + string hlo = 1; + tensorflow.ComputeCapability cc = 2; + tensorflow.CudnnVersion cudnn_version = 3; + string blas_version = 5; + repeated BlacklistedAlgorithm algos = 4; +} + +message AlgorithmBlacklist { + repeated AlgorithmBlacklistEntry entries = 1; +} diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc new file mode 100755 index 00000000000..de3b1efd03a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -0,0 +1,474 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" + +#include + +#include +#include +#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
+#include + +#include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/DiagnosticPrinter.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/conditional_simplifier.h" +#include "tensorflow/compiler/xla/service/depthwise_convolution_converter.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" +#include "tensorflow/compiler/xla/service/dump.h" +#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" +#include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h" +#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_constant_folding.h" +#include "tensorflow/compiler/xla/service/hlo_cse.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" +#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_pass_fix.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/mem_wasted_on_passthrough_params.h" +#include 
"tensorflow/compiler/xla/service/reduce_precision_insertion.h" +#include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/rng_expander.h" +#include "tensorflow/compiler/xla/service/slice_sinker.h" +#include "tensorflow/compiler/xla/service/slow_operation_alarm.h" +#include "tensorflow/compiler/xla/service/sort_simplifier.h" +#include "tensorflow/compiler/xla/service/stable_sort_expander.h" +#include "tensorflow/compiler/xla/service/transpose_folding.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h" +#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/profiler/lib/traceme.h" + +namespace xla { +namespace gpu { + +GpuCompiler::GpuCompiler(se::Platform::Id platform_id, + const char* target_triple, const char* data_layout) + : platform_id_(platform_id), + target_triple_(target_triple), + data_layout_(data_layout), + pointer_size_(llvm::DataLayout(data_layout) + .getPointerSize(0 /* default address space */)) {} + +// Runs optimization passes on the given HLO module. +Status GpuCompiler::OptimizeHloModule( + HloModule* hlo_module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + { + HloPassPipeline pipeline("optimization"); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + + // Expand random number generation. + pipeline.AddPass(); + + // Remove zero-sized HLO from the input so that other passes don't have to + // handle it. + pipeline.AddPass(); + + pipeline.AddPass(); + + pipeline.AddPass(); + pipeline.AddPass(); + ReducePrecisionInsertion::AddPasses( + &pipeline, hlo_module->config().debug_options(), + ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); + + // TODO(b/64094172): make Call work on GPU instead of inlining. + pipeline.AddPass(); + auto cost_model = [](HloInstruction* conv) { + // We need a cost model for GPUs. Currently, do nothing. + return false; + }; + pipeline.AddPass(); + pipeline.AddPass(cost_model); + // Expand the sort op to support stable sorting if required. + pipeline.AddPass(); + // Convert BF16 operations to F32 operations so that the GPU backend can + // support BF16 operations without directly implementing a BF16 lowering for + // most ops. + pipeline.AddPass(BF16, F32); + + { + auto& pass = + pipeline.AddPass>("simplification"); + pass.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + + // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls + // where possible. Not every batchnorm op can be implemented as a call to + // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs. 
+ if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) { + pass.AddPass(); + } + pass.AddPass( + /*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, + /*rewrite_grad_op=*/true); + + pipeline.AddPass(); + + // BatchNormExpander can create zero-sized ops, so zero-sized HLO + // elimination has to come after that pass. + pipeline.AddPass(); + + AlgebraicSimplifierOptions options; + pass.AddPass(options); + pass.AddPass(); + pass.AddPass(); + pass.AddPass(); + pass.AddPass(); + + // TODO(b/134075051): Re-enable after b/134075051 is fixed. + // pass.AddPass(); + + pass.AddPass(); + pass.AddPass(); + pass.AddPass(); + pass.AddPass(); + } + + pipeline.AddPass( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return IsMatrixMultiplication(dot) + ? candidate_operands + : TransposeFolding::OperandIndices{}; + }, + TransposeFolding::NeverFoldTranspose); + pipeline.AddPass(/*is_layout_sensitive=*/false); + pipeline.AddPass(); + + // Run WhileLoopTripCountAnnotator at the end of the simplification + // pipeline, before layout assignment and fusion. This pass does some + // pattern-matching on while bodies/conditions, and this is where the HLO is + // "nicest". + // + // It's important that we don't make semantic changes (e.g. unrolling) to + // any `while` loops after this point, because otherwise the trip-count + // annotations added by this pass may not be correct after the + // modifications. + pipeline.AddPass(); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + // Run target-specific HLO optimization passes for convolution + // canonicalization. + TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization( + hlo_module, stream_exec, device_allocator)); + + { + // Run layout assignment in a separate pipeline from + // "post-layout-assignment" because we want everything after layout + // assignment to have a layout-sensitive invariant-checker, but + // HloPassPipeline also runs its invariant checker before any passes are + // run, meaning, the pipeline that contains layout assignment cannot contain + // a layout-sensitive verifier! + HloPassPipeline pipeline("layout assignment"); + pipeline.AddPass( + hlo_module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, stream_exec); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + // Run target-specific HLO optimization passes after layout assignment. + TF_RETURN_IF_ERROR(OptimizeHloPostLayoutAssignment(hlo_module, stream_exec, + device_allocator)); + + { + HloPassFix fusion("fusion"); + // We try to split variadic ops with many parameters into several such ops + // to avoid exceeding the parameter space. + fusion.AddPass(); + /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after + * fixing the ticket. */ + fusion.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); + fusion.AddPass(/*may_duplicate=*/false); + fusion.AddPass(/*may_duplicate=*/true); + fusion.AddPass(); + fusion.AddPass(); + fusion.AddPass(/*is_layout_sensitive=*/true, + /*only_fusion_computations=*/true); + fusion.AddPass(); + TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); + + HloPassPipeline reduce_pipeline("reduce-precision"); + /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after + * fixing the ticket. 
*/ + reduce_pipeline.AddInvariantChecker( + /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); + ReducePrecisionInsertion::AddPasses( + &reduce_pipeline, hlo_module->config().debug_options(), + ReducePrecisionInsertion::PassTiming::AFTER_FUSION); + StatusOr reduce_result = reduce_pipeline.Run(hlo_module); + TF_RETURN_IF_ERROR(reduce_result.status()); + + if (reduce_result.ValueOrDie()) { + // Do another fusion pass, with the expectation that we may be able to + // fuse the new ReducePrecision operations. + TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); + } + } + + return Status::OK(); +} + +// Modifies the given HLO module so that it will be accepted by IrEmitter. +// Unlike optimization passes, the passes are necessary for correctness. +Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { + // In some cases, we have to place the result of an instruction in a temporary + // buffer. For instance, the buffer that holds an external parameter is + // assumed immutable at this point, and should not be reused for output + // (b/27180329). Therefore, in that case, we set the output to be a copy of + // the parameter. + HloPassPipeline pipeline("GPU-ir-emit-prepare"); + /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after + * fixing the ticket. */ + pipeline.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); + + // Copy insertion should be performed immediately before IR emission to avoid + // inserting unnecessary copies (later pass adds an instruction which + // materializes the value) or missing a necessary copy (later pass removes an + // instruction which materializes a value). DCE must be run immediately before + // (and sometime after) copy insertion, to avoid dead code from interfering + // with the rewrites. + pipeline.AddPass(); + pipeline.AddPass(); + // The following pass LOGs memory waste. Add it when VLOGing is enabled only. + if (VLOG_IS_ON(2)) { + pipeline.AddPass(); + } + pipeline.AddPass(GetCanShareBuffer()); + pipeline.AddPass(); + return pipeline.Run(hlo_module).status(); +} + +StatusOr> GpuCompiler::RunHloPasses( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + // We dump the post-optimization HLO in RunBackend so no need to dump it here. 
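The simplification, fusion, and reduce-precision groups above are wrapped in HloPassFix, i.e. repeated until a full sweep reports no change. The driver loop, reduced to standard C++ (Module and Pass are placeholders for HloModule and the pass interface):

#include <functional>
#include <vector>

struct Module {};                            // stands in for HloModule
using Pass = std::function<bool(Module&)>;   // returns true if it changed IR

// HloPassFix in miniature: rerun the group until nothing changes.
bool RunToFixedPoint(const std::vector<Pass>& passes, Module& module) {
  bool changed_any = false;
  bool changed = true;
  while (changed) {
    changed = false;
    for (const Pass& pass : passes) {
      changed |= pass(module);
    }
    changed_any |= changed;
  }
  return changed_any;
}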
+ XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); + tensorflow::profiler::TraceMe activity( + [&] { return absl::StrCat("HLO Transforms:", module->name()); }, + tensorflow::profiler::TraceMeLevel::kInfo); + TF_RETURN_IF_ERROR( + OptimizeHloModule(module.get(), stream_exec, device_allocator)); + + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); + + return std::move(module); +} + +StatusOr> GpuCompiler::RunBackend( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend"); + auto slow_compile_alarm = SlowCompilationAlarm(); + + TF_RET_CHECK(stream_exec != nullptr); + + llvm::LLVMContext llvm_context; + std::string buffer; + llvm::raw_string_ostream error(buffer); + llvm::DiagnosticPrinterRawOStream printer(error); + auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info, + void* Context) { + auto printer = static_cast(Context); + diag_info.print(*printer); + }; + llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer); + + llvm::Module llvm_module(module->name().c_str(), llvm_context); + // Set the target triple and the data layout. + llvm_module.setTargetTriple(target_triple_); + llvm_module.setDataLayout(data_layout_); + + // Determine the HLO schedule, which is an ordering of HLO instructions. This + // is used by buffer assignment to enable buffer reuse, and the same ordering + // must also be used to determine the thunk launch schedule. + std::unique_ptr stream_assignment = AssignStreams(*module); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_schedule, + GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); + + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. + TF_ASSIGN_OR_RETURN( + std::unique_ptr buffer_assignment, + BufferAssigner::Run( + module.get(), hlo_schedule->ConsumeHloOrdering(), + BufferSizeBytesFunction(), + /*color_alignment=*/ + [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, + /*allocate_buffers_for_constants=*/true, + /*colorer=*/BufferAssigner::DefaultColorer(), + /*must_not_live_out=*/{}, GetCanShareBuffer())); + DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations"); + + IrEmitterContext ir_emitter_context( + module.get(), buffer_assignment.get(), stream_exec->platform(), + &stream_exec->GetDeviceDescription(), &llvm_module); + + HloComputation* entry_computation = module->entry_computation(); + IrEmitterUnnested ir_emitter(module->config(), entry_computation, + &ir_emitter_context); + + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); + TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter)); + } + + if (user_pre_optimization_hook_) { + user_pre_optimization_hook_(llvm_module); + } + string ir_module_string_before_opt; + const bool embed_ir_in_executable = + module->config().debug_options().xla_embed_ir_in_executable(); + if (embed_ir_in_executable) { + ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); + } + + llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/false); + + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier"); + + std::string err; + llvm::raw_string_ostream err_stream(err); + + // verifyModule() returns true if the module is broken. 
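RunBackend above checks the freshly emitted module with llvm::verifyModule, which follows LLVM's inverted convention of returning true when the module is broken. A standalone helper showing that contract (the LLVM calls are real; the TF_RET_CHECK wrapper is XLA's own):

#include <string>

#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"

// Returns true if `module` passes the verifier; on failure the diagnostics
// are copied into *error_out.
bool ModuleIsValid(const llvm::Module& module, std::string* error_out) {
  std::string err;
  llvm::raw_string_ostream err_stream(err);
  const bool broken = llvm::verifyModule(module, &err_stream);
  if (broken && error_out != nullptr) {
    *error_out = err_stream.str();
  }
  return !broken;
}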
+ TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) + << "Invalid LLVM IR before optimizations:\n" + << err_stream.str() + << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " + "Rerun with --xla_dump_to to get the IR. "; + } + + GpuVersion gpu_version = GetGpuVersion(stream_exec); + + using BackendCompileResult = std::pair>; + TF_ASSIGN_OR_RETURN(BackendCompileResult backend_result, + CompileTargetBinary(module.get(), &llvm_module, + gpu_version, stream_exec)); + + auto thunk_schedule = absl::make_unique( + ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), + hlo_schedule->ThunkLaunchOrder()); + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDirOrStdout(*module, "thunk_schedule", + thunk_schedule->ToString()); + } + + std::unique_ptr profile_index_map; + std::unique_ptr profile_printer; + + if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) { + HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); + cost_analysis.set_bytes_per_second( + stream_exec->GetDeviceDescription().memory_bandwidth()); + TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); + VLOG(1) << "HLO memory read+written: " + << tensorflow::strings::HumanReadableNumBytes( + cost_analysis.bytes_accessed()); + if (module->config().hlo_profiling_enabled()) { + profile_index_map = absl::make_unique(*module); + profile_printer = CreateHloProfilePrinterData( + *profile_index_map, cost_analysis, entry_computation->name()); + } + } + + auto* gpu_executable = new GpuExecutable( + backend_result.first, backend_result.second, gpu_version, + std::move(thunk_schedule), std::move(module), + std::move(buffer_assignment), std::move(profile_printer), + std::move(profile_index_map)); + if (embed_ir_in_executable) { + DCHECK_NE("", ir_module_string_before_opt); + gpu_executable->set_ir_module_string(ir_module_string_before_opt); + } + return std::unique_ptr(gpu_executable); +} + +StatusOr>> +GpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, + const AotCompilationOptions& options) { + return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime"); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h new file mode 100644 index 00000000000..901d994d4ad --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -0,0 +1,120 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/stream_executor/stream_executor_pimpl.h" + +namespace xla { +namespace gpu { + +// The GPU compiler generates efficient GPU executables. +class GpuCompiler : public LLVMCompiler { + public: + GpuCompiler(se::Platform::Id platform_id, const char* target_triple, + const char* data_layout); + ~GpuCompiler() override {} + + // Bring in + // StatusOr>> Compile( + // std::vector> modules, + // std::vector> + // stream_execs) + using LLVMCompiler::Compile; + + StatusOr> RunHloPasses( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) override; + + Status OptimizeHloModule(HloModule* hlo_module, + se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator); + + virtual Status OptimizeHloConvolutionCanonicalization( + HloModule* hlo_module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) = 0; + + virtual Status OptimizeHloPostLayoutAssignment( + HloModule* hlo_module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) = 0; + + virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() { + return + [](const HloInstruction*, const HloInstruction*, + const ShapeIndex&) -> absl::optional { return absl::nullopt; }; + } + + virtual GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) = 0; + + virtual StatusOr>> + CompileTargetBinary(const HloModule* hlo_module, llvm::Module* llvm_module, + GpuVersion gpu_version, + se::StreamExecutor* stream_exec) = 0; + + Status PrepareHloModuleForIrEmitting(HloModule* hlo_module); + + StatusOr> RunBackend( + std::unique_ptr module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) override; + + StatusOr>> + CompileAheadOfTime(std::unique_ptr module_group, + AotCompilationOptions const& options) override; + + se::Platform::Id PlatformId() const override { return platform_id_; } + + HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { + // Capture just the pointer size, not the entire GpuCompiler object. + int64 pointer_size = pointer_size_; + return [pointer_size](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, pointer_size); + }; + } + + private: + se::Platform::Id platform_id_; + + // The triple that represents our target. + const char* target_triple_; + + // The data layout of the emitted module. + const char* data_layout_; + + // The size in bytes of a pointer. Used by ShapeSizeBytesFunction. 
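ShapeSizeBytesFunction above copies pointer_size_ into a local before forming the lambda so the returned callable does not capture `this` and stays valid after the compiler object is gone. The same capture choice in isolation (placeholder class and a hypothetical byte-size formula, not the XLA shape-size logic):

#include <cstdint>
#include <functional>

class SizeSourceSketch {
 public:
  std::function<int64_t(int64_t)> ByteSizeFn() const {
    // Capture the one value we need by copy; capturing [this] would tie the
    // lambda's validity to this object's lifetime.
    const int64_t pointer_size = pointer_size_;
    return [pointer_size](int64_t num_elements) {
      return num_elements * pointer_size;
    };
  }

 private:
  int64_t pointer_size_ = 8;
};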
+ const int64 pointer_size_; + + TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index e4942bd76a6..abf2cd1f23f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -81,11 +81,12 @@ void GpuExecutable::ComputeThunkAnnotations() { for (Thunk* thunk : thunk_schedule_->TotalOrder()) { const HloInstruction* hlo = thunk->hlo_instruction(); CHECK(hlo); - thunk_annotations_[thunk] = absl::StrFormat( - "%s:#tf_op=%s,hlo_op=%s,hlo_module=%s#", - hlo->ToStringWithCanonicalNameMap(HloPrintOptions::Canonical(), - &canonical_name_map), - hlo->metadata().op_name(), hlo->name(), hlo->GetModule()->name()); + thunk_annotations_[thunk] = + absl::StrFormat("%s:#tf_op=%s:%s,hlo_op=%s,hlo_module=%s#", + hlo->ToStringWithCanonicalNameMap( + HloPrintOptions::Canonical(), &canonical_name_map), + hlo->metadata().op_name(), hlo->metadata().op_type(), + hlo->name(), hlo->GetModule()->name()); } } @@ -195,10 +196,11 @@ Status GpuExecutable::ExecuteThunks( } main_stream->ThenWaitFor(&sub_streams); - // Make sure kernels are completed before deallocating temporary buffers. + // Make sure kernels are completed before deallocating temporary buffers or + // the profiler state. // TODO(b/30100571): we could potentially postpone deallocating the temp // buffers until a different computation is executed. - if (block_host_until_done) { + if (do_profile || block_host_until_done) { Status block_status = main_stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError( @@ -207,17 +209,20 @@ Status GpuExecutable::ExecuteThunks( } } + // FinishExecution() blocks until main_stream has completed if profiling is + // enabled; we therefore do not need to defer profile collection onto a + // stream. profiler.FinishExecution(); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - { - tensorflow::mutex_lock lock(mutex_); + if (run_options->run_options().execution_profile()) { + ExecutionProfile* profile = run_options->run_options().execution_profile(); const double nanoseconds = (end_micros - start_micros) * 1000.0; - execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); + profile->set_compute_time_ns(std::max(nanoseconds, 1.0)); // If hlo profiling was disabled then the cycle count is left empty. if (do_profile) { - execution_profile_.set_compute_cycle_count( + profile->set_compute_cycle_count( hlo_execution_profile->total_cycles_executed( *module().entry_computation())); } @@ -241,8 +246,14 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { module_spec.AddCudaPtxInMemory(text().c_str()); absl::flat_hash_map globals; + if (executor->platform_kind() == se::PlatformKind::kCuda && + module_spec.cuda_ptx_in_memory() == nullptr) { + // No custom PTX => no globals. 
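The ExecuteThunks profiling hunk above now writes wall time into the caller-provided ExecutionProfile. The conversion it performs, in isolation: elapsed microseconds are scaled to nanoseconds and clamped so a zero-length measurement still reports at least 1 ns.

#include <algorithm>
#include <cstdint>

double ComputeTimeNs(uint64_t start_micros, uint64_t end_micros) {
  const double nanoseconds =
      static_cast<double>(end_micros - start_micros) * 1000.0;
  return std::max(nanoseconds, 1.0);  // never report a zero compute time
}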
+ return &module_globals_.emplace(executor, std::move(globals)).first->second; + } + se::ModuleHandle module_handle; - executor->LoadModule(module_spec, &module_handle); + TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle)); for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); ++i) { @@ -402,25 +413,16 @@ StatusOr GpuExecutable::Execute( return std::move(shaped_buffer); } -StatusOr GpuExecutable::ExecuteOnStream( +StatusOr GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { - // TODO(b/134086343): ExecuteOnStream should not be async according to the - // documentation, instead ExecuteAsyncOnStream should be used. - return Execute(run_options, arguments, hlo_execution_profile, - /*block_host_until_done=*/ - !run_options->allocator()->AllowsAsynchronousDeallocation()); -} - -StatusOr GpuExecutable::ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - absl::Span arguments) { se::DeviceMemoryAllocator* memory_allocator = run_options->allocator(); // Force synchronous execution if the allocator requires it. bool block_host_until_done = !memory_allocator->AllowsAsynchronousDeallocation(); - return Execute(run_options, arguments, nullptr, block_host_until_done); + return Execute(run_options, arguments, hlo_execution_profile, + block_host_until_done); } const InstructionValueSet& GpuExecutable::GetRootValueSet() const { @@ -428,5 +430,14 @@ const InstructionValueSet& GpuExecutable::GetRootValueSet() const { module().entry_computation()->root_instruction()); } +int64 GpuExecutable::SizeOfGeneratedCodeInBytes() { + // Non-empty PTX but empty cubin: compilation must have failed, return + // "unknown". + if (binary().empty() && !text_.empty()) { + return -1; + } + return binary().size(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 5f9fe3e71ef..0175e31568c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -61,6 +61,8 @@ class GpuExecutable : public Executable { std::unique_ptr hlo_profile_index_map); ~GpuExecutable() override; + int64 SizeOfGeneratedCodeInBytes() override; + // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -78,17 +80,13 @@ class GpuExecutable : public Executable { // compilation is left up to the GPU driver. const std::vector& binary() const { return binary_; } - // ExecuteOnStream will fail if the compute capability of the stream doesn't - // match the compute capability passed to this object's constructor. - StatusOr ExecuteOnStream( + // ExecuteAsyncOnStream will fail if the compute capability of the stream + // doesn't match the compute capability passed to this object's constructor. 
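SizeOfGeneratedCodeInBytes above reports -1 ("unknown") when PTX text exists but no cubin was produced, and the cubin size otherwise. The same decision as a free function (std::string and std::vector stand in for text() and binary()):

#include <cstdint>
#include <string>
#include <vector>

int64_t SizeOfGeneratedCode(const std::string& ptx,
                            const std::vector<uint8_t>& cubin) {
  if (cubin.empty() && !ptx.empty()) {
    return -1;  // PTX was emitted but compilation to a binary did not happen
  }
  return static_cast<int64_t>(cubin.size());
}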
+ StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - absl::Span arguments) override; - std::shared_ptr GetBufferAssignment() const { return assignment_; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 2d266b9bc73..c5c79f63e81 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -16,9 +16,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include +#include #include #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape.h" @@ -26,8 +28,8 @@ limitations under the License. namespace xla { namespace gpu { - namespace { + void AppendParams(const HloInstruction& instr, std::vector* params) { if (instr.opcode() == HloOpcode::kFusion) { @@ -39,6 +41,25 @@ void AppendParams(const HloInstruction& instr, } } } + +bool CodegensIntoLoop(const HloInstruction& instr) { + CHECK_NE(instr.opcode(), HloOpcode::kFusion) << "`instr` has to be unfused."; + if (instr.opcode() == HloOpcode::kReduce && + !IsReductionFromOrToContiguousDimensions(instr)) { + return true; + } + // Reduce window codegens into loop only when windows overlap, i.e. stride is + // less than window size. + if (instr.opcode() == HloOpcode::kReduceWindow) { + for (const auto& dim : instr.window().dimensions()) { + if (dim.size() > dim.stride()) { + return true; + } + } + } + return false; +} + } // namespace bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer, @@ -202,19 +223,16 @@ bool IsProducerConsumerFusible(const HloInstruction& producer, if (!IsLoopFusible(producer) || !IsFusible(consumer)) { return false; } - // Skip multiple output fusion. It's not yet supported. if (producer.IsMultiOutputFusion()) { return false; } - // Do not fuse into reduce input fusions if the resulting kernel would suffer // from poor data locality (due to unfriendly input layouts). if (IsInputFusibleReduction(consumer) && !LayoutsAreReduceInputFusionFriendly(producer, consumer)) { return false; } - // We can't fuse library calls, so if a user of such an op could become a // bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for // further rationale. @@ -222,7 +240,6 @@ bool IsProducerConsumerFusible(const HloInstruction& producer, ImplementedAsLibraryCall(*producer.operand(0))) { return false; } - // Fuse scalar constants into loop fusion nodes. This reduces the number of // parameters and makes matching scalar broadcasts easier. 
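CodegensIntoLoop above treats a reduce-window as loop-generating only when its windows overlap, i.e. some dimension's stride is smaller than its window size. That test by itself (WindowDim is a stand-in for the window dimension proto):

#include <cstdint>
#include <vector>

struct WindowDim {
  int64_t size;
  int64_t stride;
};

bool WindowsOverlap(const std::vector<WindowDim>& dims) {
  for (const WindowDim& dim : dims) {
    if (dim.size > dim.stride) {  // e.g. size=3, stride=1: neighbours overlap
      return true;
    }
  }
  return false;
}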
// @@ -235,7 +252,6 @@ bool IsProducerConsumerFusible(const HloInstruction& producer, return ShapeUtil::IsEffectiveScalar(producer.shape()) && consumer.opcode() == HloOpcode::kFusion; } - return true; } @@ -249,15 +265,12 @@ bool IsProducerConsumerMultiOutputFusible(const HloInstruction& producer, if (!IsLoopFusible(producer) || !IsFusibleAsMultiOutputFusionRoot(consumer)) { return false; } - if (!ShapesCompatibleForMultiOutputFusion(producer, consumer)) { return false; } - if (!LayoutsAreReduceInputFusionFriendly(producer, consumer)) { return false; } - return true; } @@ -323,6 +336,71 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1, return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion; } +bool CreatesNestedLoop(const HloInstruction& producer, + const HloInstruction& consumer) { + // If producer does not have an instruction that codegens a loop then there is + // nothing to do. + auto producer_has_loop_codegen = [&](const HloInstruction& instr) { + if (producer.opcode() != HloOpcode::kFusion) { + return CodegensIntoLoop(producer); + } + for (const auto& instr : producer.fused_instructions()) { + if (CodegensIntoLoop(*instr)) { + return true; + } + } + return false; + }; + if (!producer_has_loop_codegen(producer)) { + return false; + } + + // If consumer is a non-fusion instruction then we have to check if it + // generates a loop. + if (consumer.opcode() != HloOpcode::kFusion) { + return CodegensIntoLoop(consumer); + } + + // If consumer is a fusion then we have to check if the output of producer is + // used directly or indirectly as an input to an HLO instruction that + // generates a loop, i.e. there is a path in the graph from an operand + // corresponding to the producer to an HLO instruction generating a loop in + // the consumer. + for (const HloInstruction* operand : consumer.operands()) { + if (operand != &producer) { + continue; + } + + const HloInstruction* root = + consumer.fused_instructions_computation()->parameter_instruction( + consumer.operand_index(operand)); + + std::stack dfs; + dfs.push(root); + absl::flat_hash_set visited; + while (!dfs.empty()) { + const HloInstruction* cur = dfs.top(); + dfs.pop(); + + if (visited.contains(cur)) { + continue; + } + visited.insert(cur); + + if (CodegensIntoLoop(*cur)) { + return true; + } + for (const auto& user : cur->users()) { + if (visited.contains(user)) { + continue; + } + dfs.push(user); + } + } + } + return false; +} + bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) { // We can fuse reduces and loop fusions. Elementwise instructions can be fused // with any other instruction. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index 4956bf096a0..145975e6f49 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -67,6 +67,11 @@ bool IsInputFusibleScatter(const HloInstruction& instr); bool FusionWouldBeTooLarge(const HloInstruction& instr1, const HloInstruction& instr2); +// Check if fusing producer and consumer will generate a nested loop, e.g. both +// producer and consumer are `reduce-window` HLO instructions. +bool CreatesNestedLoop(const HloInstruction& producer, + const HloInstruction& consumer); + // Whether instruction shapes are compatible for multi-output fusion, i.e. // whether the emitters support lowering the resulting fusion. 
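The consumer-side half of CreatesNestedLoop above is a reachability walk: starting from the fusion parameter fed by the producer, follow user edges until an instruction that codegens into a loop is found. The walk reduced to standard containers (Node and the predicate are placeholders for HloInstruction and CodegensIntoLoop):

#include <functional>
#include <stack>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<const Node*> users;
};

// True if some node reachable from `start` (inclusive) satisfies `pred`.
bool ReachesMatchingNode(const Node* start,
                         const std::function<bool(const Node&)>& pred) {
  std::stack<const Node*> dfs;
  std::unordered_set<const Node*> visited;
  dfs.push(start);
  while (!dfs.empty()) {
    const Node* cur = dfs.top();
    dfs.pop();
    if (!visited.insert(cur).second) {
      continue;  // already expanded
    }
    if (pred(*cur)) {
      return true;
    }
    for (const Node* user : cur->users) {
      if (visited.count(user) == 0) {
        dfs.push(user);
      }
    }
  }
  return false;
}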
// This function works for both, sibling and producer-consumer multi-output diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 2879acecbce..550f4662b55 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -166,7 +166,7 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( // instr->operand(2), if exists, is the bias buffer. There is no need to // assign layout to it, as it has only one dimension. - // instr->opernad(3), if exists, is the side input buffer. + // instr->operand(3), if exists, is the side input buffer. if (instr->operand_count() == 4) { if (kind != CudnnConvKind::kForwardActivation) { return InternalError( diff --git a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.cc b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.cc new file mode 100644 index 00000000000..013fffe4fa8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.cc @@ -0,0 +1,88 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h" + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h" + +namespace xla { +namespace gpu { + +constexpr absl::string_view kDefaultBlacklist = R"pb( + entries { + hlo: "(f16[256,112,112,64]{3,2,1,0}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[7,7,4,64]{2,1,0,3}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convForward\", backend_config=\"{conv_result_scale:1}\"" + cc { major: 7 } + cudnn_version { major: 7 minor: 6 patch: 2 } + blas_version: "10201" + algos { id: 1 tensor_ops: true } + } + entries { + hlo: "(f16[7,7,4,64]{2,1,0,3}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[256,112,112,64]{3,2,1,0}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convBackwardFilter\", backend_config=\"{conv_result_scale:1}\"" + cc { major: 7 } + cudnn_version { major: 7 minor: 6 patch: 2 } + blas_version: "10201" + algos { id: 1 tensor_ops: true } + })pb"; + +absl::Span +GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc, + tensorflow::CudnnVersion cudnn_version, + absl::string_view blas_version, + absl::string_view hlo) { + // Key is the tuple of canonicalized hlo, compute capability major/minor, + // cudnn version major/minor/patch, blas version. 
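The lookup described in the key comment above amounts to a map keyed on (hlo text, compute capability, cuDNN version, BLAS version). A shrunken sketch using std::map and a std::tuple key; the real code uses absl::flat_hash_map and stream_executor::dnn::AlgorithmDesc values:

#include <cstdint>
#include <map>
#include <string>
#include <tuple>
#include <vector>

// (hlo, cc_major, cc_minor, cudnn_major, cudnn_minor, cudnn_patch, blas)
using BlacklistKey =
    std::tuple<std::string, int, int, int, int, int, std::string>;

struct AlgoId {
  int64_t id;
  bool tensor_ops;
};

using Blacklist = std::map<BlacklistKey, std::vector<AlgoId>>;

std::vector<AlgoId> LookupBlacklistedAlgos(const Blacklist& blacklist,
                                           const BlacklistKey& key) {
  auto it = blacklist.find(key);
  return it == blacklist.end() ? std::vector<AlgoId>{} : it->second;
}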
+ using MapType = absl::flat_hash_map< + std::tuple, + std::vector>; + + static MapType* blacklist = [] { + MapType* list = new MapType(); + AlgorithmBlacklist proto; + std::string file_path = + GetDebugOptionsFromFlags().xla_gpu_algorithm_blacklist_path(); + if (!file_path.empty()) { + TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), + file_path, &proto)); + } else { + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + std::string(kDefaultBlacklist), &proto)); + } + for (const auto& entry : proto.entries()) { + for (const auto& algo : entry.algos()) { + (*list)[std::make_tuple( + std::string(entry.hlo()), entry.cc().major(), + entry.cc().minor(), entry.cudnn_version().major(), + entry.cudnn_version().minor(), + entry.cudnn_version().patch(), entry.blas_version())] + .push_back({algo.id(), algo.tensor_ops()}); + } + } + return list; + }(); + + auto iter = blacklist->find(std::make_tuple( + std::string(hlo), cc.major(), cc.minor(), cudnn_version.major(), + cudnn_version.minor(), cudnn_version.patch(), std::string(blas_version))); + if (iter != blacklist->end()) { + return iter->second; + } + return {}; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h new file mode 100644 index 00000000000..0120879e9d7 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_ + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/protobuf/autotuning.pb.h" + +namespace xla { +namespace gpu { + +absl::Span +GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc, + tensorflow::CudnnVersion cudnn_version, + absl::string_view blas_version, + absl::string_view hlo); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist_test.cc new file mode 100644 index 00000000000..2f2782bd4dc --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist_test.cc @@ -0,0 +1,73 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h" + +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/dnn.h" + +namespace xla { +namespace gpu { +namespace { + +class BlacklistTest : public testing::Test { + protected: + BlacklistTest() { + setenv("XLA_FLAGS", + absl::StrCat( + "--xla_gpu_algorithm_blacklist_path=", + tensorflow::io::JoinPath( + tensorflow::testing::TensorFlowSrcRoot(), "compiler", "xla", + "service", "gpu", "data", "hlo_algorithm_blacklist.pbtxt")) + .data(), + 0); + } +}; + +TEST_F(BlacklistTest, DefaultTest) { + tensorflow::ComputeCapability cc; + cc.set_major(7); + cc.set_minor(0); + tensorflow::CudnnVersion cudnn_version; + cudnn_version.set_major(7); + cudnn_version.set_minor(6); + cudnn_version.set_patch(2); + auto list = GetBlacklistedConvAlgorithms( + cc, cudnn_version, /*blas_version=*/"9000", + R"((f16[256,112,112,64]{3,2,1,0}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[7,7,4,64]{2,1,0,3}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}")"); + ASSERT_EQ(4, list.size()); + EXPECT_EQ(stream_executor::dnn::AlgorithmDesc(0, false), list[0]); + EXPECT_EQ(stream_executor::dnn::AlgorithmDesc(0, true), list[1]); + EXPECT_EQ(stream_executor::dnn::AlgorithmDesc(1, false), list[2]); + EXPECT_EQ(stream_executor::dnn::AlgorithmDesc(1, true), list[3]); +} + +TEST_F(BlacklistTest, NegativeTest) { + tensorflow::ComputeCapability cc; + cc.set_major(7); + cc.set_minor(0); + tensorflow::CudnnVersion cudnn_version; + cudnn_version.set_major(7); + cudnn_version.set_minor(6); + cudnn_version.set_minor(2); + auto list = + GetBlacklistedConvAlgorithms(cc, cudnn_version, "9000", R"(invalid hlo)"); + ASSERT_EQ(0, list.size()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 404d3347772..78f8e22a857 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -220,9 +220,9 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { } // For column reduction, the tile block is tize_size_y x tile_size_x, and we - // are reducing along tile_size_y. Both tile_size_x and tile_size_y need to be + // are reducing along tile_size_y. Only tile_size_y needs to be // large enough to make the tiling implementation efficient. 
- return dims_in_elem[2] >= kWarpSize && dims_in_elem[1] >= kWarpSize; + return dims_in_elem[1] >= kWarpSize; } std::pair GetReductionKindAndContiguousComponents( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index f380aee9d3c..16dc9cd284f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -134,14 +134,6 @@ class IrEmitter : public DfsHloVisitorWithDefault, std::vector ConstructIrArrayForOutputs( const HloInstruction& hlo); - // A convenient helper for calling BufferAssignment::GetUniqueSlice. - BufferAllocation::Slice GetAllocationSlice( - const HloInstruction& hlo, const ShapeIndex& index = {}) const { - return ir_emitter_context_->buffer_assignment() - .GetUniqueSlice(&hlo, index) - .ConsumeValueOrDie(); - } - // Emit a singlethreaded or multithreaded loop that computes every element in // the result of the given HLO instruction. This produces a series of nested // loops (e.g. one for each dimension of the `hlo`'s shape). The body of the diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 168156edf8e..0435daee143 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" @@ -37,37 +38,28 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" -#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h" #include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" -#include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" -#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include 
"tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/replica_id_thunk.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/target_util.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" -#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h" #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -98,10 +90,6 @@ namespace xla { namespace gpu { using llvm_ir::KernelMappingScheme; -using EmitElementFunction = - std::function; - namespace { using absl::InlinedVector; @@ -358,238 +346,15 @@ Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { } Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { - // A CustomCall on the GPU backend can either be a custom-call to a - // user-supplied kernel, or a call into a library like cudnn. - - // Lower custom-calls to cudnn batchnorm ops to specialized thunks. It's part - // of the contract of these cudnn batchnorm calls that the epsilon and - // feature_index operands be constants. - if (custom_call->custom_call_target() == - kCudnnBatchNormForwardInferenceCallTarget) { - const HloInstruction* epsilon = custom_call->operand(5); - CHECK(epsilon->IsConstant()); - float epsilon_value = epsilon->literal().Get({}); - - const HloInstruction* feature_index = custom_call->operand(6); - CHECK(feature_index->IsConstant()); - int64 feature_index_value = feature_index->literal().Get({}); - - AddThunkToThunkSequence( - absl::make_unique( - /*operand=*/GetAllocationSlice(*custom_call->operand(0)), - /*scale=*/GetAllocationSlice(*custom_call->operand(1)), - /*offset=*/GetAllocationSlice(*custom_call->operand(2)), - /*mean=*/GetAllocationSlice(*custom_call->operand(3)), - /*variance=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, - /*output=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); - return Status::OK(); - } - - if (custom_call->custom_call_target() == - kCudnnBatchNormForwardTrainingCallTarget) { - const HloInstruction* epsilon = custom_call->operand(3); - CHECK(epsilon->IsConstant()); - float epsilon_value = epsilon->literal().Get({}); - - const HloInstruction* feature_index = custom_call->operand(4); - CHECK(feature_index->IsConstant()); - int64 feature_index_value = feature_index->literal().Get({}); - - // BatchNormTraining returns a tuple of three elements: data, calculated - // mean, and calculated 1/sqrt(variance + epsilon). 
- const auto& assn = ir_emitter_context_->buffer_assignment(); - auto output_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); - auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); - AddThunkToThunkSequence( - absl::make_unique( - /*operand=*/GetAllocationSlice(*custom_call->operand(0)), - /*scale=*/GetAllocationSlice(*custom_call->operand(1)), - /*offset=*/GetAllocationSlice(*custom_call->operand(2)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, - /*output_data=*/output_data, - /*output_mean=*/output_mean, - /*output_inv_stddev=*/output_inv_stddev, - /*output_tuple=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); - return Status::OK(); - } - - if (custom_call->custom_call_target() == kCudnnBatchNormBackwardCallTarget) { - const HloInstruction* epsilon = custom_call->operand(5); - CHECK(epsilon->IsConstant()); - float epsilon_value = epsilon->literal().Get({}); - - const HloInstruction* feature_index = custom_call->operand(6); - CHECK(feature_index->IsConstant()); - int64 feature_index_value = feature_index->literal().Get({}); - - // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale, - // grad_offset. - const auto& assn = ir_emitter_context_->buffer_assignment(); - auto output_grad_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); - auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - auto output_grad_offset = - assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); - AddThunkToThunkSequence(absl::make_unique( - /*operand=*/GetAllocationSlice(*custom_call->operand(0)), - /*scale=*/GetAllocationSlice(*custom_call->operand(1)), - /*mean=*/GetAllocationSlice(*custom_call->operand(2)), - /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), - /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, - /*output_grad_data=*/output_grad_data, - /*output_grad_scale=*/output_grad_scale, - /*output_grad_offset=*/output_grad_offset, - /*output_tuple=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); - return Status::OK(); - } - - if (IsCustomCallToDnnConvolution(*custom_call)) { - const auto& assn = ir_emitter_context_->buffer_assignment(); - std::vector operand_slices; - operand_slices.reserve(custom_call->operand_count()); - for (const auto* operand : custom_call->operands()) { - operand_slices.push_back(GetAllocationSlice(*operand)); - } - auto tuple_result_slice = GetAllocationSlice(*custom_call); - auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); - auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - - AddThunkToThunkSequence(absl::make_unique( - Cast(custom_call), std::move(operand_slices), - conv_result_slice, scratch_slice, tuple_result_slice)); - return Status::OK(); - } - - if (custom_call->custom_call_target() == kCusolverCholeskyCallTarget) { - TF_ASSIGN_OR_RETURN(CholeskyOptions options, - custom_call->backend_config()); - - const Shape& shape = custom_call->operand(0)->shape(); - int ndim = shape.dimensions_size(); - CHECK_GE(ndim, 2); - int64 n = shape.dimensions(ndim - 1); - - const auto& dims = shape.dimensions(); - int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1}, - [](int64 a, int64 b) { return a * b; }); - - auto operand_buffer = GetAllocationSlice(*custom_call->operand(0)); - - const auto& assn = 
ir_emitter_context_->buffer_assignment(); - auto a_buffer = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); - auto workspace_buffer = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - auto info_buffer = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); - - std::vector> thunks; - - if (operand_buffer != a_buffer) { - thunks.push_back(absl::make_unique( - /*source_address=*/operand_buffer, - /*destination_buffer=*/a_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call)); - } - - thunks.push_back(absl::make_unique( - options, a_buffer, workspace_buffer, info_buffer, - custom_call->operand(0)->shape().element_type(), batch_size, n, - custom_call)); - - // Elide the sequential thunk if there's no copy. - if (thunks.size() == 1) { - AddThunkToThunkSequence(std::move(thunks[0])); - } else { - AddThunkToThunkSequence( - absl::make_unique(std::move(thunks), custom_call)); - } - - return Status::OK(); - } - - if (IsCublasGemm(*custom_call)) { - AddThunkToThunkSequence(BuildGemmThunk(custom_call)); - return Status::OK(); - } - - if (void* call_target = CustomCallTargetRegistry::Global()->Lookup( - custom_call->custom_call_target(), - ir_emitter_context_->platform()->Name())) { - const auto& assn = ir_emitter_context_->buffer_assignment(); - auto get_slices_for_instr = [&](const HloInstruction* instr) { - ShapeTree slices(instr->shape()); - slices.ForEachMutableElement([&](const ShapeIndex& index, - BufferAllocation::Slice* slice) { - StatusOr s = assn.GetUniqueSlice(instr, index); - if (s.ok()) { - *slice = s.ValueOrDie(); - } - }); - return slices; - }; - std::vector> operand_slices; - for (const auto* operand : custom_call->operands()) { - operand_slices.push_back(get_slices_for_instr(operand)); - } - ShapeTree result_slices = - get_slices_for_instr(custom_call); - AddThunkToThunkSequence(absl::make_unique( - call_target, std::move(operand_slices), std::move(result_slices), - Cast(custom_call)->opaque(), custom_call)); - return Status::OK(); - } - - return Unimplemented("No registered implementation for custom call to \"%s\"", - custom_call->custom_call_target()); + return ThunkEmitter(this).HandleCustomCall(custom_call); } Status IrEmitterUnnested::HandleFft(HloInstruction* fft) { - TF_RET_CHECK( - LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout())); - TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); - AddThunkToThunkSequence(BuildFftThunk(fft)); - return Status::OK(); + return ThunkEmitter(this).HandleFft(fft); } Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) { - auto has_fortran_layout = [](const Layout& layout) { - int n = layout.minor_to_major_size(); - return layout.minor_to_major(0) == n - 2 && - layout.minor_to_major(1) == n - 1; - }; - TF_RET_CHECK(has_fortran_layout(hlo->operand(0)->shape().layout())); - TF_RET_CHECK(has_fortran_layout(hlo->operand(1)->shape().layout())); - TF_RET_CHECK(has_fortran_layout(hlo->shape().layout())); - - std::vector> thunks; - - // Triangular solve is in-place on 'b', so copy 'b' to the output if they - // aren't the same buffer. 
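Both the Cholesky lowering above and the triangular-solve lowering treat the operand as a batch of matrices: the trailing two dimensions form the matrix, and everything in front of them is folded into the batch size with std::accumulate. Restated on a plain dimension vector (assuming, as the real code CHECKs, at least two dimensions):

    #include <cstdint>
    #include <numeric>
    #include <vector>

    // For a shape like [b0, b1, ..., m, n], the batch size is b0 * b1 * ...
    int64_t BatchSize(const std::vector<int64_t>& dims) {
      return std::accumulate(dims.begin(), dims.end() - 2, int64_t{1},
                             [](int64_t a, int64_t b) { return a * b; });
    }
    // e.g. BatchSize({2, 3, 128, 128}) == 6: six independent 128x128 matrices.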
- auto operand_buffer = GetAllocationSlice(*hlo->operand(1)); - auto destination_buffer = GetAllocationSlice(*hlo); - if (operand_buffer != destination_buffer) { - thunks.push_back(absl::make_unique( - /*source_address=*/operand_buffer, - /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo)); - } - - thunks.push_back(BuildTriangularSolveThunk(hlo)); - - // Elide the sequential thunk if there's no copy. - if (thunks.size() == 1) { - AddThunkToThunkSequence(std::move(thunks[0])); - } else { - AddThunkToThunkSequence( - absl::make_unique(std::move(thunks), hlo)); - } - return Status::OK(); + return ThunkEmitter(this).HandleTriangularSolve(hlo); } Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { @@ -605,7 +370,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { int unroll_factor = ComputeMaxUnrollFactor(fusion); thunks.push_back(BuildKernelThunk( fusion, /*implements_whole_instruction=*/false, unroll_factor)); - GpuElementalIrEmitter operand_elemental_emitter( hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); @@ -710,7 +474,16 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { if (LayoutUtil::Equal(copy->operand(0)->shape().layout(), copy->shape().layout()) && buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) { - AddThunkToThunkSequence(BuildDeviceToDeviceCopyThunk(copy)); + // Copy the operand into the output if it's not the same buffer already. + auto operand_buffer = GetAllocationSlice(*copy->operand(0)); + auto destination_buffer = GetAllocationSlice(*copy); + if (operand_buffer != destination_buffer) { + AddThunkToThunkSequence(absl::make_unique( + /*source_address=*/operand_buffer, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ + ByteSizeOf(copy->operand(0)->shape()), copy)); + } return Status::OK(); } if (CheckAndEmitHloWithTile021(copy)) { @@ -1048,7 +821,8 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { thunks.push_back(absl::make_unique( /*source_address=*/operand_buffer, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), scatter)); + /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), + /*hlo_instruction=*/nullptr)); } thunks.push_back( @@ -1486,17 +1260,15 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { return Status::OK(); } -Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) { - return Status::OK(); -} - -Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { - AddThunkToThunkSequence(BuildInfeedThunk(infeed)); - return Status::OK(); +Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) { + return ThunkEmitter(this).HandleInfeed(xla_infeed); } Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) { - AddThunkToThunkSequence(BuildOutfeedThunk(outfeed)); + return ThunkEmitter(this).HandleOutfeed(outfeed); +} + +Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) { return Status::OK(); } @@ -1720,131 +1492,6 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( implements_whole_instruction ? 
inst : nullptr, unroll_factor); } -std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( - const HloInstruction* inst) { - const HloInstruction* operand = inst->operand(0); - CHECK_EQ(HloOpcode::kConstant, operand->opcode()); - return absl::make_unique( - /*source_address=*/operand->literal().untyped_data(), - /*destination_buffer=*/GetAllocationSlice(*inst), - /*mem_size=*/ - llvm_ir::ByteSizeOf(operand->shape(), - ir_emitter_context_->llvm_module()->getDataLayout()), - inst); -} - -std::unique_ptr IrEmitterUnnested::BuildDeviceToDeviceCopyThunk( - const HloInstruction* inst) { - const HloInstruction* operand = inst->operand(0); - return absl::make_unique( - /*source_address=*/GetAllocationSlice(*operand), - /*destination_buffer=*/GetAllocationSlice(*inst), - /*mem_size=*/ - llvm_ir::ByteSizeOf(operand->shape(), - ir_emitter_context_->llvm_module()->getDataLayout()), - inst); -} - -std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( - const HloInstruction* inst) { - CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); - - ShapeTree slices(inst->shape()); - slices.ForEachMutableElement( - [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { - *slice = ir_emitter_context_->buffer_assignment() - .GetUniqueSlice(inst, index) - .ConsumeValueOrDie(); - }); - return absl::make_unique(slices, inst); -} - -std::unique_ptr IrEmitterUnnested::BuildOutfeedThunk( - const HloInstruction* inst) { - CHECK_EQ(HloOpcode::kOutfeed, inst->opcode()); - - ShapeTree slices(inst->operand(0)->shape()); - slices.ForEachMutableElement( - [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { - auto status_or_slice = - ir_emitter_context_->buffer_assignment().GetUniqueSlice( - inst->operand(0), index); - if (status_or_slice.ok()) { - *slice = status_or_slice.ConsumeValueOrDie(); - } - }); - return absl::make_unique(std::move(slices), inst); -} - -std::unique_ptr IrEmitterUnnested::BuildGemmThunk( - const HloInstruction* inst) { - auto config_or = inst->backend_config(); - GemmBackendConfig gemm_config = std::move(config_or.ValueOrDie()); - const HloInstruction* lhs = inst->operand(0); - const HloInstruction* rhs = inst->operand(1); - - // The bias is passed inside the output buffer. If those buffers are shared - // we can just use it, otherwise copy the bias values into the output buffer - // first. - if (gemm_config.beta() != 0.0) { - const HloInstruction* bias = inst->operand(2); - CHECK_EQ(bias->shape(), inst->shape()); - if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) { - std::vector> thunks; - thunks.push_back(absl::make_unique( - /*source_buffer=*/GetAllocationSlice(*bias), - /*destination_buffer=*/GetAllocationSlice(*inst), - /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr)); - thunks.push_back(absl::make_unique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*inst), // The output buffer. - /*implements_whole_instruction=*/false, inst)); - return absl::make_unique(std::move(thunks), inst); - } - } - - return absl::make_unique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*inst), // The output buffer. 
- /*implements_whole_instruction=*/true, inst); -} - -std::unique_ptr IrEmitterUnnested::BuildFftThunk( - const HloInstruction* inst) { - const HloInstruction* operand = inst->operand(0); - return absl::make_unique( - inst->fft_type(), inst->fft_length(), - /*input_buffer=*/GetAllocationSlice(*operand), - /*output_buffer=*/GetAllocationSlice(*inst), - /*input_shape=*/operand->shape(), - /*output_shape=*/inst->shape(), inst); -} - -std::unique_ptr IrEmitterUnnested::BuildTriangularSolveThunk( - const HloInstruction* inst) { - const HloInstruction* a = inst->operand(0); - const HloInstruction* b = inst->operand(1); - int64 m = b->shape().dimensions(b->shape().rank() - 2); - int64 n = b->shape().dimensions(b->shape().rank() - 1); - int64 batch_size = std::accumulate( - b->shape().dimensions().begin(), b->shape().dimensions().end() - 2, - int64{1}, [](int64 a, int64 b) { return a * b; }); - int64 elem_size = - ShapeUtil::ByteSizeOfPrimitiveType(inst->shape().element_type()); - int64 a_batch_stride = inst->triangular_solve_options().left_side() - ? m * m * elem_size - : n * n * elem_size; - int64 b_batch_stride = m * n * elem_size; - return absl::make_unique( - inst->triangular_solve_options(), - /*a_input_buffer=*/GetAllocationSlice(*a), - /*b_input_buffer=*/GetAllocationSlice(*inst), - inst->shape().element_type(), batch_size, m, n, a_batch_stride, - b_batch_stride, inst); -} - StatusOr> IrEmitterUnnested::BuildInitializerThunk( HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); @@ -2200,41 +1847,6 @@ Status IrEmitterUnnested::EmitTargetElementLoop( return emit_status; } -std::vector IrEmitterUnnested::ConstructIrArrayForInputs( - const HloInstruction& hlo) { - std::vector param_arrays; - param_arrays.reserve(hlo.operands().size()); - for (const HloInstruction* param : hlo.operands()) { - param_arrays.push_back(GetIrArray(*param, hlo)); - } - return param_arrays; -} - -int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape( - const HloInstruction& hlo, const std::vector& param_arrays, - const std::vector& param_buffers, - absl::Span reduced_output_dims, - std::vector* param_reduced_shapes, - std::vector* param_in_reduced_shape_arrays) { - int64 num_params = hlo.operands().size(); - param_in_reduced_shape_arrays->reserve(num_params); - param_reduced_shapes->reserve(num_params); - for (int64 id = 0; id < num_params; ++id) { - if (param_buffers[id] == nullptr) { - param_reduced_shapes->push_back(Shape()); - param_in_reduced_shape_arrays->push_back(IrArray()); - continue; - } - const HloInstruction* param = hlo.operand(id); - param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout( - param->shape().element_type(), - Permute({0, 2, 1}, reduced_output_dims))); - param_in_reduced_shape_arrays->push_back( - param_arrays[id].CastToShape((*param_reduced_shapes)[id], &b_)); - } - return num_params; -} - namespace { std::tuple GetStartOffsetAndStepForX( @@ -2254,12 +1866,12 @@ std::tuple GetStartOffsetAndStepForX( return std::make_tuple(start_offset_x, step_x); } -void EmitFullElementalTile(const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, - const string& loop_name, KernelSupportLibrary* ksl, - llvm::IRBuilder<>* builder, llvm::Value* y, - llvm::Value* x, llvm::Type* index_ty, - const EmitElementFunction& emit_elem_function) { +void EmitFullElementalTile( + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + 
KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Type* index_ty, + const IrEmitterUnnested::EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); @@ -2292,14 +1904,13 @@ void EmitFullElementalTile(const KernelMappingScheme* mapping_scheme, }); } -void EmitPartialElementalTile(const KernelMappingScheme* mapping_scheme, - const IrArray::Index& tile_origin_index, - const string& loop_name, - KernelSupportLibrary* ksl, - llvm::IRBuilder<>* builder, llvm::Value* y, - llvm::Value* x, llvm::Value* tile_height, - llvm::Value* tile_width, llvm::Type* index_ty, - const EmitElementFunction& emit_elem_function) { +void EmitPartialElementalTile( + const KernelMappingScheme* mapping_scheme, + const IrArray::Index& tile_origin_index, const string& loop_name, + KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, + llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, + llvm::Type* index_ty, + const IrEmitterUnnested::EmitElementFunction& emit_elem_function) { int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX(); int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY(); int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); @@ -2361,7 +1972,7 @@ void EmitTiledElementalCodeWithBoundsCheck( const IrArray::Index& tile_origin_index, const string& loop_name, KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y, llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width, - const EmitElementFunction& emit_elem_function) { + const IrEmitterUnnested::EmitElementFunction& emit_elem_function) { int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX(); int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY(); llvm::Type* index_ty = tile_width->getType(); @@ -2397,13 +2008,11 @@ void EmitTiledElementalCodeWithBoundsCheck( void IrEmitterUnnested::EmitTileElementForCopy( HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc, int64 /*x_iter_num*/) { - llvm_ir::TiledParameterInfo* tiled_param_info = - kernel_info->GetTiledParameterInfo(); + llvm::Value* x_loc, int64 /*x_iter_num*/, + absl::Span param_shmem_buffers) { // TODO(jlebar): Add AA metadata to this load. 
llvm::Instruction* load_from_shmem_buffer = - Load(GEP(tiled_param_info->GetBufferForParameter(0), - {b_.getInt64(0), x_loc, y_loc}), + Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}), "output_element"); llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo); Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( @@ -2427,17 +2036,15 @@ void IrEmitterUnnested::EmitTileElementForCopy( void IrEmitterUnnested::EmitTileElementForFusion( HloInstruction* hlo, const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc, int64 /*x_iter_num*/) { - llvm_ir::TiledParameterInfo* tiled_param_info = - kernel_info->GetTiledParameterInfo(); + llvm::Value* x_loc, int64 /*x_iter_num*/, + absl::Span param_shmem_buffers) { std::vector output_arrays = ConstructIrArrayForOutputs(*hlo); GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, GetNestedComputer()); FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo), - &elem_emitter); - tiled_param_info->set_y(y_loc); - tiled_param_info->set_x(x_loc); - fused_emitter.SetTiledParameterInfo(tiled_param_info); + &elem_emitter, x_loc, y_loc, + param_shmem_buffers); + TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter)); IrArray::Index untiled_index = kernel_info->GetKernelMappingScheme()->GetUnnormalizedIndex( @@ -2501,19 +2108,6 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { return reduction_input_addresses_; } - InlinedVector* GetMutableReducers() { return &reducers_; } - const InlinedVector& GetReducers() const { - return reducers_; - } - int GetNumberOfReduces() const { return reducers_.size(); } - - InlinedVector* GetMutableReductionOutputShapeIndices() { - return &reduction_output_shape_indices_; - } - absl::Span GetReductionOutputShapeIndices() const { - return reduction_output_shape_indices_; - } - bool IsRowReduction() const { return is_row_reduction_; } // Return the dimension that is being reduced between DimX and DimY. @@ -2560,8 +2154,6 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { private: AddressVector partial_result_addresses_; AddressVector reduction_input_addresses_; - InlinedVector reducers_; - InlinedVector reduction_output_shape_indices_; // The address of the memory that stores the linear index of the current // output, assuming that the output doesn't change the layout of the kept // elements in the reduction input. @@ -2570,48 +2162,10 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo { bool is_row_reduction_; }; -namespace { -// Returns a group of instructions that generate the output for the kernel -// containing the given HLO instruction. The result may be an unnested kReduce -// HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple -// for a multiple output fusion. -absl::Span GetOutputInstructions( - HloInstruction* const* reduce_or_tuple_pointer) { - HloOpcode opcode = (*reduce_or_tuple_pointer)->opcode(); - CHECK(opcode == HloOpcode::kReduce || opcode == HloOpcode::kTuple); - return opcode == HloOpcode::kTuple - ? 
(*reduce_or_tuple_pointer)->operands() - : absl::Span(reduce_or_tuple_pointer, 1); -} - -const HloInstruction* GetFirstReduceInstruction( - absl::Span instructions) { - auto first_reduce_iter = - absl::c_find_if(instructions, [](const HloInstruction* inst) { - return IsReductionFromOrToContiguousDimensions(*inst); - }); - CHECK_NE(first_reduce_iter, instructions.end()); - return *first_reduce_iter; -} - -}; // namespace - void IrEmitterUnnested::EmitPrologueForOneReduction( HloInstruction* unnested_hlo, HloInstruction* reduce_inst, int reduce_idx, - KernelCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter, - ShapeIndex output_shape_index) { - ReductionCodegenInfo* reduction_info = - static_cast(kernel_info); - - InlinedVector* reducers = - reduction_info->GetMutableReducers(); - CHECK(IsReductionFromOrToContiguousDimensions(*reduce_inst)); - reducers->push_back(reduce_inst->to_apply()); - - InlinedVector* reduction_output_shape_indices = - reduction_info->GetMutableReductionOutputShapeIndices(); - reduction_output_shape_indices->push_back(std::move(output_shape_index)); - + KernelCodegenInfo* kernel_info, GpuElementalIrEmitter* elemental_emitter) { + auto reduction_info = static_cast(kernel_info); AddressVector* reduction_input_addresses = reduction_info->GetMutableReductionInputAddresses(); llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( @@ -2652,38 +2206,23 @@ void IrEmitterUnnested::EmitPrologueForOneReduction( } void IrEmitterUnnested::EmitPrologueForReduction( - HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info, + absl::Span reduce_instructions) { VLOG(10) << "Emit prologue for reduction " << unnested_hlo->ToString(); - // Find the unnested kReduce or the tuple that contains a list of kReduce. - HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion - ? 
unnested_hlo->fused_expression_root() - : unnested_hlo; - absl::Span output_instructions = - GetOutputInstructions(&reduce_or_tuple); - ReductionCodegenInfo* reduction_info = - static_cast(kernel_info); + auto reduction_info = static_cast(kernel_info); GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), &b_, GetNestedComputer()); const HloInstruction* first_reduce = nullptr; - for (int i = 0, e = output_instructions.size(); i != e; ++i) { - if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { - continue; - } - HloInstruction* reduce_inst = output_instructions[i]; + for (int i = 0; i < reduce_instructions.size(); i++) { + HloInstruction* reduce_inst = reduce_instructions[i]; if (first_reduce == nullptr) { first_reduce = reduce_inst; } else { CHECK(first_reduce->dimensions() == reduce_inst->dimensions()); } - ShapeIndex output_shape_index; - if (reduce_or_tuple->opcode() == HloOpcode::kTuple) { - output_shape_index = {i}; - } - EmitPrologueForOneReduction(unnested_hlo, reduce_inst, i, kernel_info, - &elemental_emitter, - std::move(output_shape_index)); + &elemental_emitter); } int num_partial_results = reduction_info->GetNumberOfPartialResults(); @@ -2733,17 +2272,14 @@ void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces( } void IrEmitterUnnested::EmitEpilogueForReduction( - HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info) { - ReductionCodegenInfo* reduction_info = - static_cast(kernel_info); - int num_reduces = reduction_info->GetNumberOfReduces(); + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info, + absl::Span reduce_instructions, + absl::Span reduction_output_shape_indices, + absl::Span reducers) { + auto reduction_info = static_cast(kernel_info); + int num_reduces = reducers.size(); absl::Span partial_result_addresses = reduction_info->GetPartialResultAddresses(); - const InlinedVector& reducers = - reduction_info->GetReducers(); - absl::Span reduction_output_shape_indices = - reduction_info->GetReductionOutputShapeIndices(); - if (reduction_info->IsRowReduction()) { EmitFullWarpShuffleDownLoopForAllReduces(reducers, partial_result_addresses); @@ -2763,16 +2299,6 @@ void IrEmitterUnnested::EmitEpilogueForReduction( llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_); } - HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion - ? unnested_hlo->fused_expression_root() - : unnested_hlo; - std::vector reduce_instructions; - absl::c_for_each(GetOutputInstructions(&reduce_or_tuple), - [&](const HloInstruction* instr) { - if (IsReductionFromOrToContiguousDimensions(*instr)) { - reduce_instructions.push_back(instr); - } - }); int num_partial_results = reduction_info->GetNumberOfPartialResults(); // Emit an atomic operation that accumulates the partial reduction to the @@ -2837,21 +2363,16 @@ void IrEmitterUnnested::EmitEpilogueForReduction( } void IrEmitterUnnested::EmitTileElementForReduction( - HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc, int64 x_iter_num) { + HloInstruction* unnested_hlo, const Shape& reduction_operand_shape, + absl::Span output_instructions, + const llvm_ir::IrArray::Index& index, const KernelCodegenInfo* kernel_info, + absl::Span reducers, int64 x_iter_num) { VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion ? 
unnested_hlo->fused_expression_root() : unnested_hlo; - llvm_ir::TiledParameterInfo* tiled_param_info = - kernel_info->GetTiledParameterInfo(); - tiled_param_info->set_y(y_loc); - tiled_param_info->set_x(x_loc); - // Record the untransposed output linear address for the reduction. - const ReductionCodegenInfo* reduction_info = - dynamic_cast(kernel_info); + auto reduction_info = dynamic_cast(kernel_info); int partial_result_index = reduction_info->IsRowReduction() ? 0 : x_iter_num; Store(reduction_info->GetUntransposedOutputLinearAddress(&b_, index), InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(), @@ -2871,12 +2392,9 @@ void IrEmitterUnnested::EmitTileElementForReduction( GetNestedComputer()); FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo), &elem_emitter); - absl::Span output_instructions = - GetOutputInstructions(&reduce_or_tuple); // Construct the ElementGenerator for each reduction and extra output in the // the group of output instructions. if (unnested_hlo->opcode() == HloOpcode::kFusion) { - fused_emitter.SetTiledParameterInfo(tiled_param_info); TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter)); for (int i = 0, e = output_instructions.size(); i != e; ++i) { @@ -2899,8 +2417,6 @@ void IrEmitterUnnested::EmitTileElementForReduction( }); } - Shape reduction_operand_shape = - GetFirstReduceInstruction(output_instructions)->operand(0)->shape(); IrArray::Index input_index = reduction_info->GetKernelMappingScheme()->GetUnnormalizedIndex( index, reduction_operand_shape); @@ -2915,9 +2431,6 @@ void IrEmitterUnnested::EmitTileElementForReduction( reduction_info->GetPartialResultAddresses(); absl::Span reduction_input_addresses = reduction_info->GetReductionInputAddresses(); - const InlinedVector& reducers = - reduction_info->GetReducers(); - // Emit code to generate the input and perform the reduction computation for // each reduction instruction. for (int i = 0; i != reducers.size(); ++i) { @@ -2942,10 +2455,11 @@ void IrEmitterUnnested::EmitTileElementForReduction( } // Emits a kernel for the hlo instruction using the given tiling scheme. -void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, - KernelCodegenInfo* kernel_info, - KernelSupportLibrary* ksl, - llvm::Type* index_ty) { +void IrEmitterUnnested::EmitBlock(KernelCodegenInfo* kernel_info, + KernelSupportLibrary* ksl, llvm::Value* y, + llvm::Value* x, + TileElementGenerator tile_generator) { + llvm::Type* index_ty = kernel_info->GetIndexType(); KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); absl::Span dims_in_tile = mapping_scheme->GetDimensionsInTiles(); absl::Span dims_in_block = @@ -2990,11 +2504,9 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, absl::Span reduced_dims = mapping_scheme->GetDimensionsInElements(); - const bool block_contains_multi_tiles = - mapping_scheme->GetNumberOfTilesInOneBlock() > 1; // Emit the tile with a given tile_index, by calculating the tight bounds for - // each dimension of the tile and then calling emit_one_tile. + // each dimension of the tile and then calling tile_generator. 
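The comment above describes computing a tight bound per tile dimension before invoking tile_generator, so that the trailing, partial tiles do not step outside the tensor. The bound the emitted IR computes is essentially this, written with plain integers rather than llvm::Value* (an illustration of the arithmetic, not the exact emitted sequence):

    #include <algorithm>
    #include <cstdint>

    // Elements actually covered by tile `tile_index` along one dimension when
    // the dimension has `dim_size` elements split into tiles of `tile_size`.
    int64_t TightTileBound(int64_t tile_index, int64_t dim_size,
                           int64_t tile_size) {
      const int64_t tile_origin = tile_index * tile_size;
      return std::min(tile_size, dim_size - tile_origin);
    }
    // A full interior tile gets tile_size; the last, partial tile gets the
    // remainder.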
auto emit_one_tile_for_tile_index = [&](const IrArray::Index& tile_index) { std::vector output_tile_bounds(3); for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot; @@ -3012,7 +2524,8 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, IrArray::Index tile_origin = mapping_scheme->GetElementIndexForTileOrigin(tile_index); - emit_one_tile(tile_origin, output_tile_bounds, block_contains_multi_tiles); + tile_generator(y, x, tile_origin, "output", output_tile_bounds[1], + output_tile_bounds[2], ksl); }; const IrArray::Index starting_block = @@ -3036,79 +2549,34 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile, // Emits a kernel for the hlo instruction using the given kernel mapping scheme. // +// The emitted code is written into the member variable b_, which corresponds to +// the kernel thunk currently being constructed (previous call to +// BuildKernelThunk). +// // unnested_hlo: The unnested hlo instruction for which the kernel is generated. // Currently, these hlo instructions are supported: kLoop fusion, kCopy. -// tiled_param_ids: The IDs for the parameters that are 0-2-1 transpose of -// other tensors with the same dimensions and are safe to be tranposed via -// the shared memory tranpose implementation. // mapping_scheme: The tiling scheme to use. // kernel_generator: Contains function objects for code generation, such as // element generator, block prologue and epilogue generators. // kernel_info: Represent other information to support the code generation // of the tiled kernel for the hlo. -LaunchDimensions IrEmitterUnnested::EmitKernel( - HloInstruction* unnested_hlo, absl::Span tiled_param_ids, - const KernelCodeGenerator& kernel_generator, - KernelCodegenInfo* kernel_info) { +void IrEmitterUnnested::EmitKernel( + HloInstruction* unnested_hlo, Thunk* kernel_thunk, + KernelCodegenInfo* kernel_info, TileElementGenerator tile_element_generator, + BlockPrologueGenerator block_prologue_generator, + BlockEpilogueGenerator block_epilogue_generator) { KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme(); - - std::vector param_arrays = ConstructIrArrayForInputs(*unnested_hlo); - int64 num_params = param_arrays.size(); - // Allocate shared memory buffers to store the tiled inputs. - std::vector param_shmem_buffers(num_params, nullptr); - for (int64 id : tiled_param_ids) { - const HloInstruction* param = unnested_hlo->operand(id); - param_shmem_buffers[id] = - mapping_scheme->GetSharedMemoryBufferForElementType( - llvm_ir::PrimitiveTypeToIrType(param->shape().element_type(), - module_), - IrName(unnested_hlo, StrCat("tile", id))); - VLOG(3) << "Added shmem buffer for parameter " << id << ": " - << llvm_ir::DumpToString(*param_shmem_buffers[id]); - } - - const ReductionCodegenInfo* reduction_info = - dynamic_cast(kernel_info); - bool is_column_reduction = - (reduction_info && !reduction_info->IsRowReduction()); - - LaunchDimensions launch_dimensions = - LaunchDimensions(mapping_scheme->GetNumberOfBlocks(), - mapping_scheme->GetThreadsPerBlock()); + LaunchDimensions launch_dimensions(mapping_scheme->GetNumberOfBlocks(), + mapping_scheme->GetThreadsPerBlock()); // TODO(b/110211620): Enable int32 index type for column reduction. + auto reduction_info = dynamic_cast(kernel_info); llvm::Type* index_ty = - is_column_reduction + (reduction_info && !reduction_info->IsRowReduction()) ? 
b_.getInt64Ty() : GetIndexTypeForKernel(unnested_hlo, launch_dimensions.launch_bound(), &b_); - - auto index_typed_constant = [&](uint64 c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - // For multioutput fusion, one thread needs to output a tuple with pointers to - // all the individual outputs. We could do this at any point in the kernel, - // but we do it at the beginning in the hopes of reducing register pressure, - // since we touch threadIdx.x and blockIdx.x at the beginning of the kernel - // *anyway*. - if (!reduction_info && unnested_hlo->IsMultiOutputFusion()) { - KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { - llvm_ir::EmitTuple(GetIrArray(*unnested_hlo, *unnested_hlo), - ConstructIrArrayForOutputs(*unnested_hlo), &b_); - }); - } - - // For each tiled parameter, cast its input IrArray to the corresponding - // reduced shape and keep the reduced shape live during IR emission. - std::vector param_in_reduced_shape_arrays; - std::vector param_reduced_shapes; - absl::Span reduced_dims = - mapping_scheme->GetDimensionsInElements(); - int num_shapes = ConstructInputReducedShapeAndCastInputIrArrayToShape( - *unnested_hlo, param_arrays, param_shmem_buffers, reduced_dims, - ¶m_reduced_shapes, ¶m_in_reduced_shape_arrays); - DCHECK_EQ(num_shapes, num_params); + kernel_info->SetIndexType(index_ty); // Calculate the starting element coordinate within a tile for the current // thread, (y, x) from thread_id. @@ -3119,102 +2587,20 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( kernel_info->SetLaneId( mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x : nullptr); - kernel_info->SetIndexType(index_ty); - KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); - // Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck. - auto emit_tiled_elemental_code_with_bounds_check = - [&](const IrArray::Index& index, const string& loop_name, - llvm::Value* tile_height, llvm::Value* tile_width, - const EmitElementFunction& emit_elem_function) { - EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name, - &ksl, &b_, y, x, tile_height, - tile_width, emit_elem_function); - }; - auto emit_one_tile = [&](const IrArray::Index& output_tile_origin, - absl::Span output_tile_bounds, - bool block_contains_multi_tiles) { - // Calculate the input tile origin from the output tile origin. - const IrArray::Index input_tile_origin( - Permute({0, 2, 1}, output_tile_origin.multidim()), - Permute({0, 2, 1}, output_tile_origin.dims()), - output_tile_origin.GetType()); - - // If shared memory transpose is needed, wait for all threads to reach this - // point, lest we copy a value from tile to output before the other thread - // copies it from input to tile. This is `__syncthreads` in CUDA. - if (!tiled_param_ids.empty()) { - // Copy input parameter values to shared memory buffers: - // tile[y, x] = input[index] - // Note that tile_width and tile_height are flipped here because we are - // reading a transposed tile. - emit_tiled_elemental_code_with_bounds_check( - input_tile_origin, "input", output_tile_bounds[2], - output_tile_bounds[1], - [&](const IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc, int64 /*x_iter_num*/) { - for (int64 id : tiled_param_ids) { - IrArray& input_in_logical_shape = - param_in_reduced_shape_arrays[id]; - llvm::Value* shmem_buffer = param_shmem_buffers[id]; - // TODO(jlebar): Add AA metadata to this store. 
Tile buffers are - // global variables, so LLVM can't infer much about it. - Store(input_in_logical_shape.EmitReadArrayElement( - index, &b_, "input_element"), - GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc})); - } - }); - - // Wait for all threads to reach this point using `__syncthreads` in CUDA. - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_); - } - - llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x); - kernel_info->SetTiledParamInfo(&tiled_param_info); - - // Write to output[index] by emitting code like normal, except that values - // for the tiled parameters are read from the shmem buffers. - emit_tiled_elemental_code_with_bounds_check( - output_tile_origin, "output", output_tile_bounds[1], - output_tile_bounds[2], - [&](const IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, - int64 x_iter_num) { - kernel_generator.GetTileElementGenerator()( - unnested_hlo, index, kernel_info, y_loc, x_loc, x_iter_num); - }); - - // If a tile block contains multiple tiles and shared memory buffers are - // used, we need to wait for all threads to finish using the shared memory - // buffer for the current tile before we move on to process the next tile - // and overwrite the shared memory buffers. - if (block_contains_multi_tiles && !tiled_param_ids.empty()) { - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_); - } - }; - - const BlockPrologueGenerator& block_prologue_generator = - kernel_generator.GetBlockPrologueGenerator(); - if (block_prologue_generator) { - block_prologue_generator(unnested_hlo, kernel_info); - } - - EmitBlock(std::move(emit_one_tile), kernel_info, &ksl, index_ty); - - const BlockEpilogueGenerator& block_epilogue_generator = - kernel_generator.GetBlockEpilogueGenerator(); - if (block_epilogue_generator) { - block_epilogue_generator(unnested_hlo, kernel_info); - } - - return launch_dimensions; + block_prologue_generator(unnested_hlo, kernel_info); + EmitBlock(kernel_info, &ksl, y, x, tile_element_generator); + block_epilogue_generator(unnested_hlo, kernel_info); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk, + ir_emitter_context_->llvm_module()); } // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose // algorithm to improve the memory access patterns for the input parameters // with a shape that is a 0-2-1 transpose of the output tensor shape. The caller // is responsible for making sure that it is safe to apply the shared memory -// tranpose on the input parameters. +// transpose on the input parameters. // // // For the purpose of tiling, the output tensors have a logical shape of three @@ -3234,37 +2620,136 @@ LaunchDimensions IrEmitterUnnested::EmitKernel( // // TODO(b/33320379): Here each block transposes 1 tile. It may be more // efficient to launch fewer blocks so each transposes many tiles. 
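Stripped of the GPU specifics (thread blocks, shared-memory staging, and the __syncthreads barriers that EmitHlo021Tile below emits), the per-tile data movement of the 0-2-1 transpose described above amounts to the following host-side sketch over a flattened [Z, Y, X] array (out must be sized Z*Y*X as well):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // out[z][x][y] = in[z][y][x], processed one (y, x) tile at a time so that
    // both the reads and the writes within a tile stay close together.
    void Transpose021(const std::vector<float>& in, int64_t Z, int64_t Y,
                      int64_t X, int64_t tile, std::vector<float>& out) {
      for (int64_t z = 0; z < Z; ++z) {
        for (int64_t y0 = 0; y0 < Y; y0 += tile) {
          for (int64_t x0 = 0; x0 < X; x0 += tile) {
            for (int64_t y = y0; y < std::min(y0 + tile, Y); ++y) {
              for (int64_t x = x0; x < std::min(x0 + tile, X); ++x) {
                out[(z * X + x) * Y + y] = in[(z * Y + y) * X + x];
              }
            }
          }
        }
      }
    }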
-LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( - HloInstruction* hlo, absl::Span reduced_output_dims, +void IrEmitterUnnested::EmitHlo021Tile( + HloInstruction* hlo, Thunk* kernel_thunk, + absl::Span reduced_output_dims, absl::Span tiled_param_ids) { constexpr int kNumRows = 4; KernelMappingScheme mapping_scheme( reduced_output_dims, /*tile_size_y=*/kWarpSize, - /*tile_size_x=*/kWarpSize, /*req_block_sizes=*/{1, 1, 1}, + /*tile_size_x=*/kWarpSize, /*block_size_z=*/1, /*num_threads_y=*/kNumRows, - /*num_threads_x=*/kWarpSize, &b_); - TileElementGenerator element_generator; - if (hlo->opcode() == HloOpcode::kCopy) { - element_generator = [&](HloInstruction* hlo, - const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc, - int64 x_iter_num) { - EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc, x_iter_num); - }; - } else { - DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion); - element_generator = - [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, - llvm::Value* x_loc, int64 x_iter_num) { - EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc, - x_iter_num); - }; - } + /*num_threads_x=*/kWarpSize, /*is_dilated_x=*/false, &b_); KernelCodegenInfo kernel_info(&mapping_scheme); - KernelCodeGenerator kernel_generator(std::move(element_generator)); - return EmitKernel(hlo, tiled_param_ids, kernel_generator, &kernel_info); + + std::vector param_arrays; + + // For each tiled parameter, cast its input IrArray to the corresponding + // reduced shape and keep the reduced shape live during IR emission. + std::vector param_in_reduced_shape_arrays; + std::vector param_shmem_buffers(hlo->operand_count(), nullptr); + + for (int64 id = 0; id < hlo->operand_count(); id++) { + const HloInstruction* param = hlo->operand(id); + param_arrays.push_back(GetIrArray(*param, *hlo)); + + if (absl::c_linear_search(tiled_param_ids, id)) { + param_shmem_buffers[id] = + mapping_scheme.GetSharedMemoryBufferForElementType( + llvm_ir::PrimitiveTypeToIrType(param->shape().element_type(), + module_), + IrName(hlo, StrCat("tile", id))); + VLOG(3) << "Added shmem buffer for parameter " << id << ": " + << llvm_ir::DumpToString(*param_shmem_buffers[id]); + Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( + param->shape().element_type(), + Permute({0, 2, 1}, reduced_output_dims)); + param_in_reduced_shape_arrays.push_back( + param_arrays[id].CastToShape(reduced_shape, &b_)); + } else { + param_in_reduced_shape_arrays.push_back(IrArray()); + } + } + + EmitElementFunction element_generator = + [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc, int64 x_iter_num) { + if (hlo->opcode() == HloOpcode::kCopy) { + EmitTileElementForCopy(hlo, index, &kernel_info, y_loc, x_loc, + x_iter_num, param_shmem_buffers); + } else { + CHECK_EQ(hlo->opcode(), HloOpcode::kFusion); + EmitTileElementForFusion(hlo, index, &kernel_info, y_loc, x_loc, + x_iter_num, param_shmem_buffers); + } + }; + + TileElementGenerator tile_generator = + [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index, + const string& loop_name, llvm::Value* tile_height, + llvm::Value* tile_width, KernelSupportLibrary* ksl) { + // If shared memory transpose is needed, wait for all threads to reach + // this point, lest we copy a value from tile to output before the other + // thread copies it from input to tile. This is `__syncthreads` in CUDA. 
+ if (!tiled_param_ids.empty()) { + // Calculate the input tile origin from the output tile origin. + const IrArray::Index input_tile_origin( + Permute({0, 2, 1}, index.multidim()), + Permute({0, 2, 1}, index.dims()), index.GetType()); + + // Copy input parameter values to shared memory buffers: + // tile[y, x] = input[index] + // Note that tile_width and tile_height are flipped here because we + // are reading a transposed tile. + EmitTiledElementalCodeWithBoundsCheck( + &mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x, + tile_width, tile_height, + [&](const IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc, int64 /*x_iter_num*/) { + for (int64 id : tiled_param_ids) { + IrArray& input_in_logical_shape = + param_in_reduced_shape_arrays[id]; + + llvm::Value* shmem_buffer = param_shmem_buffers[id]; + llvm::Value* zero = + llvm::ConstantInt::get(kernel_info.GetIndexType(), 0); + // TODO(jlebar): Add AA metadata to this store. Tile buffers + // are global variables, so LLVM can't infer much about it. + Store(input_in_logical_shape.EmitReadArrayElement( + index, &b_, "input_element"), + GEP(shmem_buffer, {zero, y_loc, x_loc})); + } + }); + + // Wait for all threads to reach this point using `__syncthreads` in + // CUDA. + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_); + } + + EmitTiledElementalCodeWithBoundsCheck(&mapping_scheme, index, loop_name, + ksl, &b_, y, x, tile_height, + tile_width, element_generator); + bool block_contains_multi_tiles = + mapping_scheme.GetNumberOfTilesInOneBlock() > 1; + + // If a tile block contains multiple tiles and shared memory buffers are + // used, we need to wait for all threads to finish using the shared + // memory buffer for the current tile before we move on to process the + // next tile and overwrite the shared memory buffers. + if (block_contains_multi_tiles && !tiled_param_ids.empty()) { + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_); + } + }; + + BlockPrologueGenerator hlo021_prologue = [&](HloInstruction* hlo, + KernelCodegenInfo* kernel_info) { + // For multioutput fusion, one thread needs to output a tuple + // with pointers to all the individual outputs. We could do this + // at any point in the kernel, but we do it at the beginning in + // the hopes of reducing register pressure, since we touch + // threadIdx.x and blockIdx.x at the beginning of the kernel + // *anyway*. + if (hlo->IsMultiOutputFusion()) { + KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { + llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), + ConstructIrArrayForOutputs(*hlo), &b_); + }); + } + }; + BlockEpilogueGenerator epilogue_generator = [](HloInstruction*, + KernelCodegenInfo*) {}; + EmitKernel(hlo, kernel_thunk, &kernel_info, tile_generator, hlo021_prologue, + epilogue_generator); } namespace { @@ -3282,7 +2767,7 @@ namespace { // the preload tile. If this is not true, we can't use a shmem transpose for P. // // If the computation of output element [z, y, x] only requires the element of -// P with the same indices, the shmem tranpose implementation can be applied +// P with the same indices, the shmem transpose implementation can be applied // to P safely. This is a sufficient but not necessary condition. We check all // the transitive users of P to see if we can find a user that may cause an // exception to the situation. If such a user is not found, we conclude that P @@ -3302,7 +2787,7 @@ namespace { // block. 
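As the comment describes, safety for the shared-memory transpose is decided by walking the transitive users of the parameter and rejecting it if any user may combine elements with different indices. A generic worklist version of that idea over an invented Node type (purely illustrative; the real IsInstructionSafeForShmemTranspose inspects HloInstruction opcodes rather than a boolean flag):

    #include <queue>
    #include <unordered_set>
    #include <vector>

    // Illustrative node type; the real code walks HloInstruction users.
    struct Node {
      std::vector<const Node*> users;
      bool elementwise;  // stand-in for "uses only the element at the same index"
    };

    bool TransitiveUsersKeepIndexing(const Node& param) {
      std::queue<const Node*> worklist;
      std::unordered_set<const Node*> visited;
      for (const Node* u : param.users) {
        if (visited.insert(u).second) worklist.push(u);
      }
      while (!worklist.empty()) {
        const Node* n = worklist.front();
        worklist.pop();
        if (!n->elementwise) return false;  // a user that may mix indices
        for (const Node* u : n->users) {
          if (visited.insert(u).second) worklist.push(u);
        }
      }
      return true;
    }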
// // TODO(bixia): In order to extend this for kInput fusion, that is reduction -// with tranpose, we only need to end the use-chain checking with the input of +// with transpose, we only need to end the use-chain checking with the input of // a reduce operations. In this case, the above description on "output" apply // to the result of such a use-chain, which provides the input to the reduce // operation. @@ -3334,9 +2819,9 @@ bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) { } } -// Given a group of input parameters that are 0-2-1 tranpose of the outputs of +// Given a group of input parameters that are 0-2-1 transpose of the outputs of // a fusion kernel, returns the input parameters that are safe for the shared -// memory tranpose implementation. +// memory transpose implementation. // // When a tile based shared memory transpose is used to implement an input with // 0-2-1 transpose, we preload a tile of the input elements @@ -3354,8 +2839,7 @@ std::vector FilterInputsForShmemTranspose(const HloInstruction* fusion, if (IsInstructionSafeForShmemTranspose(input)) { filtered_input_ids.push_back(input_ids[i]); } else { - VLOG(10) << "Input not safe for shmem transpose " << input->ToString() - << "\n"; + VLOG(10) << "Input not safe for shmem transpose " << input->ToString(); } } return filtered_input_ids; @@ -3446,15 +2930,15 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { } } + if (params_012.empty()) { + return false; + } + VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString(); std::unique_ptr kernel_thunk = BuildKernelThunk(hlo, /*implements_whole_instruction=*/true); - const LaunchDimensions launch_dimensions = - EmitHlo021Tile(hlo, *reduced_dims_021, params_012); - UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), - ir_emitter_context_->llvm_module()); + EmitHlo021Tile(hlo, kernel_thunk.get(), *reduced_dims_021, params_012); AddThunkToThunkSequence(std::move(kernel_thunk)); - return true; } @@ -3578,7 +3062,7 @@ bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo, } // namespace -std::tuple +std::pair IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) { const Shape& input_shape = first_reduce->operand(0)->shape(); @@ -3637,12 +3121,10 @@ IrEmitterUnnested::ComputeMappingSchemeAndReductionKind( tile_size_y = kNumElementsPerPartialSum; } - DimensionVector req_block_sizes{block_size_z, 1, 1}; llvm_ir::KernelMappingScheme mapping_scheme( - dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y, - num_threads_x, &b_); - mapping_scheme.SetDilatedX(dilated_x); - return std::make_tuple(mapping_scheme, is_row_reduction); + dims_in_elem, tile_size_y, tile_size_x, block_size_z, num_threads_y, + num_threads_x, dilated_x, &b_); + return std::make_pair(mapping_scheme, is_row_reduction); } Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( @@ -3652,11 +3134,36 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion ? unnested_hlo->fused_expression_root() : unnested_hlo; - absl::Span output_instructions = - GetOutputInstructions(&reduce_or_tuple); - const HloInstruction* first_reduce = - GetFirstReduceInstruction(output_instructions); + // A group of instructions that generate the output for the kernel + // containing the given HLO instruction. 
The result may be an unnested kReduce + // HLO, a nested kReduce HLO of a kInput fusion, or the operands of the tuple + // for a multiple output fusion. + bool returns_tuple = false; + auto output_instructions = ([&]() -> absl::Span { + if (reduce_or_tuple->opcode() == HloOpcode::kReduce) { + return absl::Span(&reduce_or_tuple, 1); + } + CHECK(reduce_or_tuple->opcode() == HloOpcode::kTuple); + returns_tuple = true; + return reduce_or_tuple->operands(); + })(); + std::vector reduce_instructions; + InlinedVector reduction_output_shape_indices; + InlinedVector reducers; + for (int i = 0; i < output_instructions.size(); i++) { + HloInstruction* output_instruction = output_instructions[i]; + if (IsReductionFromOrToContiguousDimensions(*output_instruction)) { + reduce_instructions.push_back(output_instruction); + ShapeIndex idx; + if (returns_tuple) { + idx = {i}; + } + reduction_output_shape_indices.push_back(idx); + reducers.push_back(output_instruction->to_apply()); + } + } + const HloInstruction* first_reduce = reduce_instructions.at(0); if (output_instructions.size() > 1) { TF_RETURN_IF_ERROR( AreFusedReductionOutputsConsistent(output_instructions, first_reduce)); @@ -3688,35 +3195,41 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( "doesn't set the input layout of " << first_reduce->ToString(); - bool is_row_reduction; - llvm_ir::KernelMappingScheme mapping_scheme; - std::tie(mapping_scheme, is_row_reduction) = + auto mapping_scheme_pair = ComputeMappingSchemeAndReductionKind(unnested_hlo, first_reduce); - ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction); - KernelCodeGenerator kernel_generator( - /*tile_element_generator=*/ - [&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + bool is_row_reduction = mapping_scheme_pair.second; + ReductionCodegenInfo reduction_info(&mapping_scheme_pair.first, + is_row_reduction); + EmitElementFunction emit_reduction_tile = + [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, int64 x_iter_num) { - EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc, - x_iter_num); + EmitTileElementForReduction(unnested_hlo, input_shape, + output_instructions, index, &reduction_info, + reducers, x_iter_num); + }; + + EmitKernel( + unnested_hlo, kernel_thunk.get(), &reduction_info, + /*tile_element_generator=*/ + [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index, + const string& loop_name, llvm::Value* tile_height, + llvm::Value* tile_width, KernelSupportLibrary* ksl) { + EmitTiledElementalCodeWithBoundsCheck( + &mapping_scheme_pair.first, index, loop_name, ksl, &b_, y, x, + tile_height, tile_width, emit_reduction_tile); }, /*block_prologue_generator=*/ [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { - EmitPrologueForReduction(hlo, kernel_info); + EmitPrologueForReduction(hlo, kernel_info, reduce_instructions); }, - /*block_epilogue_generator*/ + /*block_epilogue_generator=*/ [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) { - EmitEpilogueForReduction(hlo, kernel_info); + EmitEpilogueForReduction(hlo, kernel_info, reduce_instructions, + reduction_output_shape_indices, reducers); }); - LaunchDimensions launch_dimensions = - EmitKernel(unnested_hlo, {}, kernel_generator, &reduction_info); - UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), - ir_emitter_context_->llvm_module()); - thunks.push_back(std::move(kernel_thunk)); - std::unique_ptr sequential_thunk = + auto 
sequential_thunk = absl::make_unique(std::move(thunks), unnested_hlo); AddThunkToThunkSequence(std::move(sequential_thunk)); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index e5177c28484..efc3f8f3ff6 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h" @@ -47,16 +49,9 @@ namespace gpu { // within a kernel function using FusedIrEmitter. (FusedIrEmitter is not // really an IrEmitter, but is more an "IR generator generator".) // -class IrEmitterUnnested : public IrEmitter { +class IrEmitterUnnested : public IrEmitter, + private ThunkEmitter::EmissionContext { public: - // Parameter block_contains_multi_tiles indicates whether a tile block - // consists of multiple tiles or not. If the tile block contains only one - // tile, there is no need to use atomic operation to accumulate a local result - // to a global result to implement reduction. - using TileGenerator = - std::function output_tile_bounds, - bool block_contains_multi_tiles)>; // KernelCodegenInfo records the common information to support the code // generation for a kernel to process tensor elements by blocks. A block of // tensor elements may contain one or multiple tiles. The code generators that @@ -68,29 +63,21 @@ class IrEmitterUnnested : public IrEmitter { public: explicit KernelCodegenInfo(llvm_ir::KernelMappingScheme* mapping_scheme) : mapping_scheme_(mapping_scheme), - tiled_param_info_(nullptr), lane_id_(nullptr), index_ty_(nullptr) {} virtual ~KernelCodegenInfo() {} void SetLaneId(llvm::Value* v) { lane_id_ = v; } void SetIndexType(llvm::Type* t) { index_ty_ = t; } - void SetTiledParamInfo(llvm_ir::TiledParameterInfo* tiled_param_info) { - tiled_param_info_ = tiled_param_info; - } llvm::Value* GetLaneId() const { return lane_id_; } llvm_ir::KernelMappingScheme* GetKernelMappingScheme() const { return mapping_scheme_; } - llvm_ir::TiledParameterInfo* GetTiledParameterInfo() const { - return tiled_param_info_; - } llvm::Type* GetIndexType() const { return index_ty_; } protected: llvm_ir::KernelMappingScheme* mapping_scheme_; - llvm_ir::TiledParameterInfo* tiled_param_info_; llvm::Value* lane_id_; llvm::Type* index_ty_; }; @@ -101,6 +88,7 @@ class IrEmitterUnnested : public IrEmitter { // A function object to finalize the code generation for a tile block. using BlockEpilogueGenerator = std::function; + // A function object to generate code to process one element in a tile. // // hlo: the instruction for which the code is generated for. @@ -110,38 +98,14 @@ class IrEmitterUnnested : public IrEmitter { // kernel_info: Other information to support the kernel code generation. // x_iter_num: When a thread process N elements in the X dimension, x_iter_num // has a value of 0..N-1 to identify the element being process. 
- using TileElementGenerator = std::function; - // KernelCodeGenerator records the code generator objects that generate code - // for tile elements or tile block prologue/epilogue. - class KernelCodeGenerator { - public: - explicit KernelCodeGenerator( - TileElementGenerator tile_element_generator, - BlockPrologueGenerator block_prologue_generator = {}, - BlockEpilogueGenerator block_epilogue_generator = {}) - : tile_element_generator_(std::move(tile_element_generator)), - block_prologue_generator_(std::move(block_prologue_generator)), - block_epilogue_generator_(std::move(block_epilogue_generator)) {} - - const TileElementGenerator& GetTileElementGenerator() const { - return tile_element_generator_; - } - const BlockPrologueGenerator& GetBlockPrologueGenerator() const { - return block_prologue_generator_; - } - const BlockEpilogueGenerator& GetBlockEpilogueGenerator() const { - return block_epilogue_generator_; - } - - private: - TileElementGenerator tile_element_generator_; - BlockPrologueGenerator block_prologue_generator_; - BlockEpilogueGenerator block_epilogue_generator_; - }; + using TileElementGenerator = std::function; IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, @@ -157,7 +121,8 @@ class IrEmitterUnnested : public IrEmitter { Status DefaultAction(HloInstruction* hlo) override; // IrEmitterUnnested handles the following instructions differently from - // IrEmitter. + // IrEmitter. It also mixes in some special handling for custom kernels + // via the ThunkEmitter. Status HandleCopy(HloInstruction* copy) override; Status HandleConditional(HloInstruction* conditional) override; Status HandleConvolution(HloInstruction* convolution) override; @@ -199,10 +164,30 @@ class IrEmitterUnnested : public IrEmitter { private: // Add a owning Thunk object to the thunk sequence. - void AddThunkToThunkSequence(std::unique_ptr thunk) { + void AddThunkToThunkSequence(std::unique_ptr thunk) override { thunk_sequence_->emplace_back(std::move(thunk)); } + // A convenient helper for calling BufferAssignment::GetUniqueSlice. + StatusOr MaybeGetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index) const override { + return ir_emitter_context_->buffer_assignment().GetUniqueSlice(&hlo, index); + } + + BufferAllocation::Slice GetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index = {}) const { + return MaybeGetAllocationSlice(hlo, index).ConsumeValueOrDie(); + } + + int64 ByteSizeOf(const Shape& shape) const override { + return llvm_ir::ByteSizeOf( + shape, ir_emitter_context_->llvm_module()->getDataLayout()); + } + + const se::Platform* platform() const override { + return ir_emitter_context_->platform(); + } + // Builds the prototype of the IR kernel for `inst` and adds it to the module. // This kernel takes as arguments pointers to the given buffer allocations. llvm::Function* BuildKernelPrototype( @@ -227,7 +212,7 @@ class IrEmitterUnnested : public IrEmitter { // and first_reduce are the same instruction. For a kInput fusion, // unnested_hlo is the fusion instruction while first_reduce is the first // reduce op. - std::tuple + std::pair ComputeMappingSchemeAndReductionKind(const HloInstruction* unnested_hlo, const HloInstruction* first_reduce); @@ -242,76 +227,72 @@ class IrEmitterUnnested : public IrEmitter { // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel // for the hlo instruction. 
bool CheckAndEmitHloWithTile021(HloInstruction* hlo); + // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and - // returns the launch dimensions for the kernel. This is a helper to support + // sets the corresponding launch dimensions. This is a helper to support // the implementation of CheckAndEmitHloWithTile021. - LaunchDimensions EmitHlo021Tile(HloInstruction* hlo, - absl::Span reduced_output_dims, - absl::Span tiled_param_ids); - // Emits a kernel for an unnested HLO instruction. - LaunchDimensions EmitKernel(HloInstruction* unnested_hlo, - absl::Span param_ids, - const KernelCodeGenerator& kernel_generator, - KernelCodegenInfo* kernel_info); - void EmitBlock(const TileGenerator& emit_one_tile, - KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl, - llvm::Type* index_ty); + void EmitHlo021Tile(HloInstruction* hlo, Thunk* kernel_thunk, + absl::Span reduced_output_dims, + absl::Span tiled_param_ids); + + // Emits a kernel for an unnested HLO instruction, set the `kernel_thunk` + // launch dimensions. + void EmitKernel(HloInstruction* unnested_hlo, Thunk* kernel_thunk, + KernelCodegenInfo* kernel_info, + TileElementGenerator tile_element_generator, + BlockPrologueGenerator block_prologue_generator, + BlockEpilogueGenerator block_epilogue_generator); + + void EmitBlock(KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl, + llvm::Value* y, llvm::Value* x, + TileElementGenerator tile_generator); + // Emits code to process a tensor element in a tile for the given kCopy HLO // that performs a 0-2-1 transpose. - void EmitTileElementForCopy(HloInstruction* hlo, - const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc, - int64 x_iter_num); + void EmitTileElementForCopy( + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc, int64 x_iter_num, + absl::Span param_shmem_buffers); + // Emits code to process a tensor element in a tile for the given kLoop fusion // HLO containing parameters that are 0-2-1 transpose of its outputs. - void EmitTileElementForFusion(HloInstruction* hlo, - const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc, - int64 x_iter_num); + void EmitTileElementForFusion( + HloInstruction* hlo, const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, llvm::Value* y_loc, + llvm::Value* x_loc, int64 x_iter_num, + absl::Span param_shmem_buffers); + // Emits code to process a tensor element in a tile for the given input hlo // that is either a unnested kReduce or a kInput fusion. - void EmitTileElementForReduction(HloInstruction* unnested_hlo, - const llvm_ir::IrArray::Index& index, - const KernelCodegenInfo* kernel_info, - llvm::Value* y_loc, llvm::Value* x_loc, - int64 x_iter_num); + void EmitTileElementForReduction( + HloInstruction* unnested_hlo, const Shape& reduction_operand_shape, + absl::Span output_instructions, + const llvm_ir::IrArray::Index& index, + const KernelCodegenInfo* kernel_info, + absl::Span reducers, int64 x_iter_num); + // Prepares for the code generation for a tile block of a reduction kernel. 
- void EmitPrologueForReduction(HloInstruction* unnested_hlo, - KernelCodegenInfo* kernel_info); + void EmitPrologueForReduction( + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info, + absl::Span reduce_instructions); + void EmitPrologueForOneReduction(HloInstruction* unnested_hlo, HloInstruction* reduce_inst, int reduce_idx, KernelCodegenInfo* kernel_info, - GpuElementalIrEmitter* elemental_emitter, - ShapeIndex output_shape_index); + GpuElementalIrEmitter* elemental_emitter); // Wraps up the code generation for a tile block of a reduction kernel. - void EmitEpilogueForReduction(HloInstruction* unnested_hlo, - KernelCodegenInfo* kernel_info); + void EmitEpilogueForReduction( + HloInstruction* unnested_hlo, KernelCodegenInfo* kernel_info, + absl::Span reduce_instructions, + absl::Span reduction_output_shape_indices, + absl::Span reducers); // For each reducer, emits the shuffle-down loop to accumulate the partial // result to the global result. void EmitFullWarpShuffleDownLoopForAllReduces( absl::Span reducers, absl::Span partial_result_addresses); - // Generates the IrArray for each input of an hlo and returns a vector that - // constains such IrArrays. - std::vector ConstructIrArrayForInputs( - const HloInstruction& hlo); - - // For each input of the `hlo` instruction, checks its value in - // `param_buffers` to find out whether the input has a reduced shape. If the - // input has a reduced shape, constructs the reduced shape for the input and - // casts the original input IrArray in `param_arrays` to the reduced shape. - // Return the total number of inputs. - int ConstructInputReducedShapeAndCastInputIrArrayToShape( - const HloInstruction& hlo, - const std::vector& param_arrays, - const std::vector& param_buffers, - absl::Span reduced_output_dims, - std::vector* param_reduced_shapes, - std::vector* param_in_reduced_shape_arrays); - // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned // Thunk object. The kernel implementation will be unrolled if unroll_factor @@ -322,39 +303,11 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction* inst, bool implements_whole_instruction, int unroll_factor = 1); - // Returns a FftThunk that calls cuFFT to implement `inst`. - std::unique_ptr BuildFftThunk(const HloInstruction* inst); - - // Returns a CholeskyThunk that calls cuSolver to implement `inst`. - std::unique_ptr BuildCholeskyThunk(const HloInstruction* inst); - - // Returns a TriangularSolveThunk that calls cuBlas to implement `inst`. - std::unique_ptr BuildTriangularSolveThunk(const HloInstruction* inst); - - // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs - // to make sure `inst` outlives the lifetime of the returned Thunk object. - std::unique_ptr BuildGemmThunk(const HloInstruction* inst); - // Returns a thunk that, given a reduce or select-and-scatter op, initializes // its memory to the appropriate initial value. StatusOr> BuildInitializerThunk( HloInstruction* hlo, const ShapeIndex& index = {}); - // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. - std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); - - // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`. - std::unique_ptr BuildDeviceToDeviceCopyThunk( - const HloInstruction* inst); - - // Returns an InfeedThunk that performs a host-to-device memcpy to implement - // `inst`. 
- std::unique_ptr BuildInfeedThunk(const HloInstruction* inst); - - // Returns an OutfeedThunk that performs a device-to-host memcpy to implement - // `inst`. - std::unique_ptr BuildOutfeedThunk(const HloInstruction* inst); - // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction 'hlo'. std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 2f73fd0b3d4..db26d36c71a 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -16,12 +16,12 @@ cc_library( name = "llvm_gpu_backend", srcs = [ "dump_ir_pass.cc", - "nvptx_backend_lib.cc", + "gpu_backend_lib.cc", "utils.cc", ], hdrs = [ "dump_ir_pass.h", - "nvptx_backend_lib.h", + "gpu_backend_lib.h", "utils.h", ], deps = [ @@ -30,6 +30,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service/gpu:gpu_types", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc similarity index 54% rename from tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc rename to tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 9f52f09004b..84616f3a37b 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include #include #include #include @@ -40,6 +41,7 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/FormattedStream.h" +#include "llvm/Support/Program.h" #include "llvm/Support/TargetRegistry.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/ToolOutputFile.h" @@ -65,6 +67,9 @@ namespace xla { namespace gpu { namespace { +// Inline threshold value to use in LLVM AMDGPU backend. +const int kAMDGPUInlineThreshold = 0x100000; + // Default inline threshold value to use in llvm. const int kDefaultInlineThreshold = 1100; @@ -124,7 +129,7 @@ void InitializePasses(llvm::PassRegistry* pass_registry) { // Returns the TargetMachine, given a triple. 
std::unique_ptr GetTargetMachine( llvm::Triple triple, absl::string_view cpu_name, - const HloModuleConfig& hlo_module_config) { + const HloModuleConfig& hlo_module_config, absl::string_view feature_str) { std::string error; const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error); if (target == nullptr) { @@ -155,8 +160,9 @@ std::unique_ptr GetTargetMachine( codegen_opt_level = CodeGenOpt::None; } return absl::WrapUnique(target->createTargetMachine( - triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, - getRelocModel(), getCodeModel(), codegen_opt_level)); + triple.str(), llvm_ir::AsStringRef(cpu_name), + llvm_ir::AsStringRef(feature_str), target_options, getRelocModel(), + getCodeModel(), codegen_opt_level)); } // Adds the standard LLVM optimization passes, based on the speed optimization @@ -166,13 +172,14 @@ std::unique_ptr GetTargetMachine( void AddOptimizationPasses(unsigned opt_level, unsigned size_level, llvm::TargetMachine* target_machine, llvm::legacy::PassManagerBase* module_passes, - llvm::legacy::FunctionPassManager* function_passes) { + llvm::legacy::FunctionPassManager* function_passes, + int inline_threshold) { PassManagerBuilder builder; builder.OptLevel = opt_level; builder.SizeLevel = size_level; if (opt_level > 1) { - builder.Inliner = llvm::createFunctionInliningPass(kDefaultInlineThreshold); + builder.Inliner = llvm::createFunctionInliningPass(inline_threshold); } else { // Only inline functions marked with "alwaysinline". builder.Inliner = llvm::createAlwaysInlinerLegacyPass(); @@ -240,13 +247,13 @@ void FeedLLVMWithFlags(const std::vector& cl_opts) { llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]); } -// Returns whether the module could use any libdevice functions. This function -// may have false positives -- the module might not use libdevice even if this -// function returns true. -bool CouldNeedLibdevice(const llvm::Module& module) { +// Returns whether the module could use any device bitcode library functions. +// This function may have false positives -- the module might not use libdevice +// on NVPTX or ROCm-Device-Libs on AMDGPU even if this function returns true. +bool CouldNeedDeviceBitcode(const llvm::Module& module) { for (const llvm::Function& function : module.functions()) { // This is a conservative approximation -- not all such functions are in - // libdevice. + // libdevice or ROCm-Device-Libs. if (!function.isIntrinsic() && function.isDeclaration()) { return true; } @@ -254,11 +261,41 @@ bool CouldNeedLibdevice(const llvm::Module& module) { return false; } +// Links the module with a vector of path to bitcode modules. +// The caller must guarantee that the paths exist. 
+Status LinkWithBitcodeVector(llvm::Module* module, + const std::vector& bitcode_path_vector) { + llvm::Linker linker(*module); + + for (auto& bitcode_path : bitcode_path_vector) { + if (!tensorflow::Env::Default()->FileExists(bitcode_path).ok()) { + LOG(ERROR) << "bitcode module is required by this HLO module but was " + "not found at " + << bitcode_path; + return xla::InternalError("bitcode module not found at %s", bitcode_path); + } + + std::unique_ptr bitcode_module = + LoadIRModule(bitcode_path, &module->getContext()); + if (linker.linkInModule( + std::move(bitcode_module), llvm::Linker::Flags::LinkOnlyNeeded, + [](Module& M, const StringSet<>& GVS) { + internalizeModule(M, [&GVS](const GlobalValue& GV) { + return !GV.hasName() || (GVS.count(GV.getName()) == 0); + }); + })) { + return xla::InternalError("Error linking bitcode module from %s", + bitcode_path); + } + } + return Status::OK(); +} + // Links libdevice into the given module if the module needs libdevice. Status LinkLibdeviceIfNecessary(llvm::Module* module, std::pair compute_capability, const string& libdevice_dir_path) { - if (!CouldNeedLibdevice(*module)) { + if (!CouldNeedDeviceBitcode(*module)) { return Status::OK(); } @@ -274,38 +311,20 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, } VLOG(1) << "Linking with libdevice from: " << libdevice_path; - std::unique_ptr libdevice_module = - LoadIRModule(libdevice_path, &module->getContext()); - - llvm::Linker linker(*module); - if (linker.linkInModule( - std::move(libdevice_module), llvm::Linker::Flags::LinkOnlyNeeded, - [](Module& M, const StringSet<>& GVS) { - internalizeModule(M, [&GVS](const GlobalValue& GV) { - return !GV.hasName() || (GVS.count(GV.getName()) == 0); - }); - })) { - return xla::InternalError("Error linking libdevice from %s", - libdevice_path); - } - return Status::OK(); + return LinkWithBitcodeVector(module, {libdevice_path}); } -StatusOr CompileModuleToPtx(llvm::Module* module, - std::pair compute_capability, - const HloModuleConfig& hlo_module_config, - const string& libdevice_dir_path) { - // If the module has no functions or globals, there's nothing to compile. Just - // return an empty string. - if (module->empty() && module->global_empty()) { - VLOG(2) << "Module '" << module->getName().str() - << "' is empty. Skipping compilation."; - return string(); - } +Status NVPTXTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version, + const HloModuleConfig& hlo_module_config, + const string& device_bitcode_dir_path) { // Link the input module with libdevice, to pull in implementations of some // builtins. - TF_RETURN_IF_ERROR( - LinkLibdeviceIfNecessary(module, compute_capability, libdevice_dir_path)); + auto compute_capability = absl::get_if>(&gpu_version); + if (!compute_capability) { + return xla::InternalError("Incompatible compute capability was specified."); + } + TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(module, *compute_capability, + device_bitcode_dir_path)); // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass // can access it. @@ -319,6 +338,31 @@ StatusOr CompileModuleToPtx(llvm::Module* module, } } + return Status::OK(); +} + +std::unique_ptr NVPTXGetTargetMachine( + llvm::Triple target_triple, std::pair compute_capability, + const HloModuleConfig& hlo_module_config) { + // Figure out the exact name of the processor as known to the NVPTX backend + // from the gpu_architecture flag. 
+ return GetTargetMachine(target_triple, GetSmName(compute_capability), + hlo_module_config, "+ptx60"); +} + +using TargetModuleLinker = std::function; + +Status LinkAndOptimizeModule(llvm::Module* module, GpuVersion gpu_version, + const HloModuleConfig& hlo_module_config, + const string& device_bitcode_dir_path, + TargetModuleLinker module_linker, + llvm::Triple default_target_triple, + llvm::TargetMachine* target_machine, + int inline_threshold) { + TF_RETURN_IF_ERROR(module_linker(module, gpu_version, hlo_module_config, + device_bitcode_dir_path)); + IrDumpingPassManager module_passes(module->getModuleIdentifier(), "", false); // Add an appropriate TargetLibraryInfo pass for the module's triple. @@ -332,13 +376,9 @@ StatusOr CompileModuleToPtx(llvm::Module* module, llvm::Triple target_triple = llvm::Triple(module->getTargetTriple()); if (target_triple.getArch() == llvm::Triple::UnknownArch) { LOG(WARNING) << "target triple not found in the module"; - target_triple = llvm::Triple("nvptx64-unknown-unknown"); + target_triple = default_target_triple; } - // Figure out the exact name of the processor as known to the NVPTX backend - // from the gpu_architecture flag. - std::unique_ptr target_machine = GetTargetMachine( - target_triple, GetSmName(compute_capability), hlo_module_config); module_passes.add(llvm::createTargetTransformInfoWrapperPass( target_machine->getTargetIRAnalysis())); @@ -365,9 +405,10 @@ StatusOr CompileModuleToPtx(llvm::Module* module, LOG(ERROR) << std::string(80, '*'); } + // Add optimization passes, and set inliner threshold. AddOptimizationPasses(opt_level, - /*size_level=*/0, target_machine.get(), &module_passes, - &function_passes); + /*size_level=*/0, target_machine, &module_passes, + &function_passes, inline_threshold); // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA // again after the standard optimization passes [http://b/13329423]. @@ -394,13 +435,12 @@ StatusOr CompileModuleToPtx(llvm::Module* module, function_passes.doFinalization(); module_passes.run(*module); - // Finally, produce PTX. - return EmitModuleToPTX(module, target_machine.get()); + return Status::OK(); } // One-time module initializer. // Must be called only once -- DO NOT CALL DIRECTLY. -void GPUBackendInit(const HloModuleConfig& hlo_module_config) { +void NVPTXBackendInit(const HloModuleConfig& hlo_module_config) { // Feed all customized flags here, so we can override them with llvm_cl_opts // without redeploy the compiler for development purpose. @@ -446,25 +486,267 @@ void GPUBackendInit(const HloModuleConfig& hlo_module_config) { } // namespace -StatusOr CompileToPtx(llvm::Module* module, - std::pair compute_capability, +namespace nvptx { + +StatusOr CompileToPtx(llvm::Module* module, GpuVersion gpu_version, const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path) { static std::once_flag backend_init_flag; - std::call_once(backend_init_flag, GPUBackendInit, hlo_module_config); + std::call_once(backend_init_flag, NVPTXBackendInit, hlo_module_config); string ptx; + std::unique_ptr target_machine; { tensorflow::profiler::TraceMe activity( [&] { return absl::StrCat("Compiling IR:", module->getName().str()); }, tensorflow::profiler::TraceMeLevel::kInfo); XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str()); - TF_ASSIGN_OR_RETURN( - ptx, CompileModuleToPtx(module, compute_capability, hlo_module_config, - libdevice_dir_path)); + + // If the module has no functions or globals, there's nothing to compile. 
+ // Just return an empty string. + if (module->empty() && module->global_empty()) { + VLOG(2) << "Module '" << module->getName().str() + << "' is empty. Skipping compilation."; + return string(); + } + + auto compute_capability = absl::get_if>(&gpu_version); + if (!compute_capability) { + return xla::InternalError( + "Incompatible compute capability was specified."); + } + + llvm::Triple default_target_triple("nvptx64-unknown-unknown"); + // Construct LLVM TargetMachine for NVPTX. + std::unique_ptr target_machine = NVPTXGetTargetMachine( + default_target_triple, *compute_capability, hlo_module_config); + + // Link with libdeivce, and optimize the LLVM module. + TF_RETURN_IF_ERROR(LinkAndOptimizeModule( + module, gpu_version, hlo_module_config, libdevice_dir_path, + NVPTXTargetModuleLinker, default_target_triple, target_machine.get(), + kDefaultInlineThreshold)); + + // Lower optimized LLVM module to PTX. + ptx = EmitModuleToPTX(module, target_machine.get()); } return ptx; } +} // namespace nvptx + +namespace { + +// Gets the ROCm-Device-Libs filenames for a particular AMDGPU version. +static std::vector GetROCDLPaths(int amdgpu_version, + const string& rocdl_dir_path) { + // AMDGPU version-neutral bitcodes. + static std::vector* rocdl_filenames = new std::vector( + {"hc.amdgcn.bc", "opencl.amdgcn.bc", "ocml.amdgcn.bc", "ockl.amdgcn.bc", + "oclc_finite_only_off.amdgcn.bc", "oclc_daz_opt_off.amdgcn.bc", + "oclc_correctly_rounded_sqrt_on.amdgcn.bc", + "oclc_unsafe_math_off.amdgcn.bc"}); + + // Construct full path to ROCDL bitcode libraries. + std::vector result; + for (auto& filename : *rocdl_filenames) { + result.push_back(tensorflow::io::JoinPath(rocdl_dir_path, filename)); + } + + // Add AMDGPU version-specific bitcodes. + result.push_back(tensorflow::io::JoinPath( + rocdl_dir_path, + absl::StrCat("oclc_isa_version_", amdgpu_version, ".amdgcn.bc"))); + return result; +} + +// Emits the given module to HSA Code Object. target_machine is an initialized +// TargetMachine for the AMDGPU target. +StatusOr> EmitModuleToHsaco( + Module* module, llvm::TargetMachine* target_machine) { + auto* env = tensorflow::Env::Default(); + std::vector tempdir_vector; + env->GetLocalTempDirectories(&tempdir_vector); + if (tempdir_vector.empty()) { + return xla::InternalError( + "Unable to locate a temporary directory for compile-time artifacts."); + } + std::string tempdir_name = tempdir_vector.front(); + VLOG(1) << "Compile-time artifacts located at: " << tempdir_name; + + // Prepare filenames for all stages of compilation: + // IR, binary ISA, and HSACO. + std::string ir_filename = absl::StrCat(module->getModuleIdentifier(), ".ll"); + std::string ir_path = tensorflow::io::JoinPath(tempdir_name, ir_filename); + + std::string isabin_filename = + absl::StrCat(module->getModuleIdentifier(), ".o"); + std::string isabin_path = + tensorflow::io::JoinPath(tempdir_name, isabin_filename); + + std::string hsaco_filename = + absl::StrCat(module->getModuleIdentifier(), ".hsaco"); + std::string hsaco_path = + tensorflow::io::JoinPath(tempdir_name, hsaco_filename); + + std::error_code ec; + + // Dump LLVM IR. + std::unique_ptr ir_fs( + new llvm::raw_fd_ostream(ir_path, ec, llvm::sys::fs::F_None)); + module->print(*ir_fs, nullptr); + ir_fs->flush(); + + // Emit GCN ISA binary. + // The extension is stripped by IrDumpingPassManager, so we need to + // get creative to add a suffix. 
+ std::string module_id = module->getModuleIdentifier(); + IrDumpingPassManager codegen_passes( + ReplaceFilenameExtension(tensorflow::io::Basename(module_id), + "-amdgpu.dummy"), + "", false); + codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( + llvm::Triple(module->getTargetTriple()))); + llvm::SmallVector stream; + llvm::raw_svector_ostream pstream(stream); + std::unique_ptr isabin_fs( + new llvm::raw_fd_ostream(isabin_path, ec, llvm::sys::fs::F_Text)); + module->setDataLayout(target_machine->createDataLayout()); + target_machine->addPassesToEmitFile(codegen_passes, *isabin_fs, nullptr, + llvm::TargetMachine::CGFT_ObjectFile); + codegen_passes.run(*module); + isabin_fs->flush(); + + // Locate lld. + // TODO(whchung@gmail.com): change to tensorflow::ROCmRoot() after + // ROCm-Device-Libs PR. + std::string lld_path = tensorflow::io::JoinPath("/opt/rocm", "hcc/bin"); + auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path}); + if (!lld_program) { + return xla::InternalError("unable to find ld.lld in PATH: %s", + lld_program.getError().message()); + } + std::vector lld_args{ + llvm_ir::AsStringRef("ld.lld"), + llvm_ir::AsStringRef("-flavor"), + llvm_ir::AsStringRef("gnu"), + llvm_ir::AsStringRef("-shared"), + llvm_ir::AsStringRef(isabin_path), + llvm_ir::AsStringRef("-o"), + llvm_ir::AsStringRef(hsaco_path), + }; + + std::string error_message; + int lld_result = + llvm::sys::ExecuteAndWait(*lld_program, llvm_ir::AsArrayRef(lld_args), + llvm::None, {}, 0, 0, &error_message); + + if (lld_result) { + return xla::InternalError("ld.lld execute fail: %s", error_message); + } + + // Read HSACO. + std::ifstream hsaco_file(hsaco_path, std::ios::binary | std::ios::ate); + std::ifstream::pos_type hsaco_file_size = hsaco_file.tellg(); + + std::vector hsaco(hsaco_file_size); + hsaco_file.seekg(0, std::ios::beg); + hsaco_file.read(reinterpret_cast(&hsaco[0]), hsaco_file_size); + return hsaco; +} + +// Links ROCm-Device-Libs into the given module if the module needs it. +Status LinkROCDLIfNecessary(llvm::Module* module, int amdgpu_version, + const string& rocdl_dir_path) { + if (!CouldNeedDeviceBitcode(*module)) { + return Status::OK(); + } + + return LinkWithBitcodeVector(module, + GetROCDLPaths(amdgpu_version, rocdl_dir_path)); +} + +Status AMDGPUTargetModuleLinker(llvm::Module* module, GpuVersion gpu_version, + const HloModuleConfig& hlo_module_config, + const string& device_bitcode_dir_path) { + // Link the input module with ROCDL. + auto amdgpu_version = absl::get_if(&gpu_version); + if (!amdgpu_version) { + return xla::InternalError( + "Incompatible AMD GCN ISA version was specified."); + } + TF_RETURN_IF_ERROR( + LinkROCDLIfNecessary(module, *amdgpu_version, device_bitcode_dir_path)); + + return Status::OK(); +} + +std::unique_ptr AMDGPUGetTargetMachine( + llvm::Triple target_triple, int amdgpu_version, + const HloModuleConfig& hlo_module_config) { + return GetTargetMachine(target_triple, absl::StrCat("gfx", amdgpu_version), + hlo_module_config, "-code-object-v3"); +} + +void AMDGPUBackendInit(const HloModuleConfig& hlo_module_config) { + llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config); + + // Initialize the AMDGPU target; it's the only target we link with, so call + // its specific initialization functions instead of the catch-all + // InitializeAll*. 
+#if TENSORFLOW_USE_ROCM + LLVMInitializeAMDGPUTarget(); + LLVMInitializeAMDGPUTargetInfo(); + LLVMInitializeAMDGPUTargetMC(); + LLVMInitializeAMDGPUAsmPrinter(); +#endif + + llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry(); + InitializePasses(registry); +} + +} // namespace + +namespace amdgpu { +StatusOr> CompileToHsaco( + llvm::Module* module, GpuVersion gpu_version, + const HloModuleConfig& hlo_module_config, const string& rocdl_dir_path) { + static std::once_flag backend_init_flag; + std::call_once(backend_init_flag, AMDGPUBackendInit, hlo_module_config); + + std::vector hsaco; + std::unique_ptr target_machine; + { + tensorflow::profiler::TraceMe activity( + [&] { return absl::StrCat("Compiling IR", module->getName().str()); }, + tensorflow::profiler::TraceMeLevel::kInfo); + XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str()); + + auto amdgpu_version = absl::get_if(&gpu_version); + if (!amdgpu_version) { + return xla::InternalError( + "Incompatible AMD GCN ISA version was specified."); + } + + llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz"); + // Construct LLVM TargetMachine for AMDGPU. + std::unique_ptr target_machine = + AMDGPUGetTargetMachine(default_target_triple, *amdgpu_version, + hlo_module_config); + + // Link with ROCm-Device-Libs, and optimize the LLVM module. + TF_RETURN_IF_ERROR(LinkAndOptimizeModule( + module, gpu_version, hlo_module_config, rocdl_dir_path, + AMDGPUTargetModuleLinker, default_target_triple, target_machine.get(), + kAMDGPUInlineThreshold)); + + // Lower optimized LLVM module to HSA code object. + TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get())); + } + return hsaco; +} + +} // namespace amdgpu + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h similarity index 67% rename from tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h rename to tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h index 9654175bfaf..526621de7a5 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h @@ -14,14 +14,15 @@ limitations under the License. ==============================================================================*/ // LLVM-based compiler backend. -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ #include #include #include "absl/strings/string_view.h" #include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -29,6 +30,7 @@ limitations under the License. namespace xla { namespace gpu { +namespace nvptx { // Compiles the argument module and returns it. libdevice_dir_path is the parent // directory of the libdevice bitcode libraries. The contents of the module may // be changed. 
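// --- Editor's note: illustrative sketch, not part of this change. ---
// With this rename the header exposes two backends under xla::gpu:
// nvptx::CompileToPtx (PTX text for NVIDIA GPUs) and amdgpu::CompileToHsaco
// (an HSA code object for AMD GPUs). The code above treats GpuVersion as an
// absl::variant holding either a (major, minor) CUDA compute capability or an
// integer AMDGPU ISA version, so a caller might dispatch as sketched below.
// The helper name CompileForGpuSketch and the hard-coded bitcode directories
// are hypothetical; the sketch assumes it is compiled inside namespace
// xla::gpu with this header and xla/status_macros.h included.
StatusOr<std::string> CompileForGpuSketch(llvm::Module* module,
                                          GpuVersion gpu_version,
                                          const HloModuleConfig& config) {
  if (absl::holds_alternative<std::pair<int, int>>(gpu_version)) {
    // NVIDIA path: PTX is returned as text and compiled further downstream.
    return nvptx::CompileToPtx(
        module, gpu_version, config,
        /*libdevice_dir_path=*/"/usr/local/cuda/nvvm/libdevice");
  }
  // AMD path: CompileToHsaco returns binary bytes; re-encode them as a string
  // here only so this sketch has a single return type.
  TF_ASSIGN_OR_RETURN(std::vector<uint8> hsaco,
                      amdgpu::CompileToHsaco(module, gpu_version, config,
                                             /*rocdl_dir_path=*/"/opt/rocm/lib"));
  return std::string(hsaco.begin(), hsaco.end());
}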
@@ -36,12 +38,21 @@ namespace gpu { // The Compile.* interfaces each create their own llvm::LLVMContext objects for // thread safety, but note that LLVM's multithreaded support is very // preliminary; multithreaded use is not recommended at this time. -StatusOr CompileToPtx(llvm::Module* module, - std::pair compute_capability, +StatusOr CompileToPtx(llvm::Module* module, GpuVersion gpu_version, const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path); +} // namespace nvptx + +namespace amdgpu { +// Compiles the argument module and returns it with LLVM AMDGPU backend. +// rocdl_dir_path is the parent directory of ROCm-Device-Libs bitcode libraries. +// The contents of the module may be changed. +StatusOr> CompileToHsaco( + llvm::Module* module, GpuVersion gpu_version, + const HloModuleConfig& hlo_module_config, const string& rocdl_dir_path); +} // namespace amdgpu } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_BACKEND_LIB_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_ diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index 536b11a00a9..9c86f7cd2a2 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -17,12 +17,7 @@ limitations under the License. #include -#include -#include -#include #include -#include -#include #include #include "absl/algorithm/container.h" @@ -55,17 +50,15 @@ int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, HloInstruction* instr2) { absl::flat_hash_set in_list; for (auto instr : instr1->operands()) { - if (!IsProfitableOperand(instr)) { - continue; + if (IsProfitableOperand(instr)) { + in_list.insert(instr); } - in_list.insert(instr); } int64 profit = 0; for (auto instr : instr2->operands()) { - if (!IsProfitableOperand(instr) || !in_list.contains(instr)) { - continue; + if (IsProfitableOperand(instr) && in_list.contains(instr)) { + profit += ShapeUtil::ByteSizeOf(instr->shape()); } - profit += ShapeUtil::ByteSizeOf(instr->shape()); } VLOG(2) << "Fusing instr1=" << instr1->name() << " instr2=" << instr2->name() << ", the profit is =" << profit; @@ -77,7 +70,6 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, if (!MultiOutputFusion::LegalToFuse(instr1, instr2)) { return false; } - // If we're fusing fusions only do it if the fusion kind matches. Loop fusions // merge into bigger loop fusions and input (reduce) fusions become fusions // with multiple reduce outputs. We could fuse reduce and loop fusions @@ -91,7 +83,6 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, instr1->IsLoopFusion())) { return false; } - // The emitter only supports in-place DUS for fusions with a single DUS at the // root. Don't sibling fuse DUS for now. // TODO(b/119178699): Multi-output fusing DUS can improve performance if we @@ -103,15 +94,15 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, HloOpcode::kDynamicUpdateSlice)) { return false; } - // Do this check last, as it may be expensive. return !FusionWouldBeTooLarge(*instr1, *instr2); } namespace { + // We prefer multi-output fusions over other fusions over unfused ops, because // we want to preserve fusion opportunities if possible. 
-HloInstruction* GetPreferredFusionCandidate( +HloInstruction* SelectPreferredFusionCandidate( const std::vector candidates) { for (auto* candidate : candidates) { if (candidate->IsMultiOutputFusion()) { @@ -123,8 +114,54 @@ HloInstruction* GetPreferredFusionCandidate( return candidate; } } - return candidates.empty() ? nullptr : candidates[0]; + return candidates.empty() ? nullptr : candidates.front(); } + +std::vector GetProducerConsumerMultiOutputFusionCandidates( + const HloInstruction* producer, const HloReachabilityMap& reachability) { + std::vector fusion_candidates; + for (HloInstruction* consumer : producer->users()) { + VLOG(3) << "Looking at producer " << producer->name() + << " and its consumer " << consumer->name(); + if (!IsInputFusibleReduction(*consumer)) { + VLOG(3) << "Consumer " << consumer->name() + << " is not an input-fusible reduction.."; + continue; + } + if (!IsProducerConsumerMultiOutputFusible(*producer, *consumer)) { + VLOG(3) << producer->name() << " and " << consumer->name() + << " are not fusible."; + continue; + } + // Do not fuse a producer if the other operands of the fusion are + // reachable from the producer, this would create a cycle. + auto operand_reachable_from_producer = [&](const HloInstruction* operand) { + // If a get-tuple-elment instruction is not in the reachability + // map, it has been created by fusion in this pass. Simply move + // on to its operand, which is in the reachability map. + if (!reachability.IsPresent(operand) && + operand->opcode() == HloOpcode::kGetTupleElement) { + operand = operand->operand(0); + } + CHECK(reachability.IsPresent(operand) && reachability.IsPresent(producer)) + << "Reachability map is incomplete. This should never " + "happen."; + return producer != operand && reachability.IsReachable(producer, operand); + }; + if (absl::c_any_of(consumer->operands(), operand_reachable_from_producer)) { + VLOG(3) << producer->name() << " would introduce a cycle when fused."; + continue; + } + if (FusionWouldBeTooLarge(*producer, *consumer)) { + VLOG(3) << producer->name() << " and " << consumer->name() + << " would be too large of a fusion."; + continue; + } + fusion_candidates.push_back(consumer); + } + return fusion_candidates; +} + } // namespace bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { @@ -144,86 +181,43 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { VLOG(3) << producer->name() << " is a constant."; continue; } - std::vector fusion_candidates; - for (HloInstruction* consumer : producer->users()) { - VLOG(3) << "Looking at producer " << producer->name() - << " and its consumer " << consumer->name(); - // TODO(b/136623068): Use IsFusibleAsMultiOutputFusionRoot(...) to lift - // the restriction to input-fusible reductions. - if (!IsInputFusibleReduction(*consumer)) { - VLOG(3) << "Consumer " << consumer->name() - << " is not an input-fusible reduction."; - continue; - } - if (!IsProducerConsumerMultiOutputFusible(*producer, *consumer)) { - VLOG(3) << producer->name() << " and " << consumer->name() - << " are not fusible."; - continue; - } - // Do not fuse a producer if the other operands of the fusion are - // reachable from the producer, this would create a cycle. - if (absl::c_any_of( - consumer->operands(), [&](const HloInstruction* operand) { - // If a get-tuple-elment instruction is not in the reachability - // map, it has been created by fusion in this pass. Simply move - // on to its operand, which is in the reachability map. 
- if (!reachability()->IsPresent(operand) && - operand->opcode() == HloOpcode::kGetTupleElement) { - operand = operand->operand(0); - } - CHECK(reachability()->IsPresent(operand) && - reachability()->IsPresent(producer)) - << "Reachability map is incomplete. This should never " - "happen."; - return producer != operand && - reachability()->IsReachable(producer, operand); - })) { - VLOG(3) << producer->name() << " would introduce a cycle when fused."; - continue; - } - if (FusionWouldBeTooLarge(*producer, *consumer)) { - VLOG(3) << producer->name() << " and " << consumer->name() - << " would be too large of a fusion."; - continue; - } - fusion_candidates.push_back(consumer); + const auto candidates = GetProducerConsumerMultiOutputFusionCandidates( + producer, *reachability()); + auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates); + if (consumer_for_fusion == nullptr) { + continue; } - auto* consumer_for_fusion = GetPreferredFusionCandidate(fusion_candidates); - if (consumer_for_fusion != nullptr) { - changed = true; - if (consumer_for_fusion->opcode() == HloOpcode::kFusion) { - VLOG(2) << "Fuse producer " << producer->name() << " into its consumer " - << consumer_for_fusion->name(); - if (producer->opcode() == HloOpcode::kFusion) { - consumer_for_fusion->MergeFusionInstructionIntoMultiOutput(producer); - } else { - consumer_for_fusion->FuseInstructionIntoMultiOutput(producer); - CHECK_EQ(0, producer->user_count()); - TF_CHECK_OK(computation()->RemoveInstruction(producer)); - } + changed = true; + if (consumer_for_fusion->opcode() == HloOpcode::kFusion) { + VLOG(2) << "Fuse producer " << producer->name() << " into its consumer " + << consumer_for_fusion->name(); + if (producer->opcode() == HloOpcode::kFusion) { + consumer_for_fusion->MergeFusionInstructionIntoMultiOutput(producer); } else { - HloInstruction* input_fusion = - computation()->AddInstruction(HloInstruction::CreateFusion( - consumer_for_fusion->shape(), - ChooseFusionKind(*producer, *consumer_for_fusion), - consumer_for_fusion)); - VLOG(2) << "Fuse producer " << producer->name() << " and its consumer " - << consumer_for_fusion->name() << " into " - << input_fusion->name(); - reachability()->Replace(consumer_for_fusion, input_fusion); - TF_CHECK_OK(computation()->ReplaceInstruction(consumer_for_fusion, - input_fusion)); - if (producer->opcode() == HloOpcode::kFusion) { - input_fusion->MergeFusionInstructionIntoMultiOutput(producer); - } else { - input_fusion->FuseInstructionIntoMultiOutput(producer); - CHECK_EQ(0, producer->user_count()); - TF_CHECK_OK(computation()->RemoveInstruction(producer)); - } + consumer_for_fusion->FuseInstructionIntoMultiOutput(producer); + CHECK_EQ(0, producer->user_count()); + TF_CHECK_OK(computation()->RemoveInstruction(producer)); } + continue; + } + HloInstruction* input_fusion = + computation()->AddInstruction(HloInstruction::CreateFusion( + consumer_for_fusion->shape(), + ChooseFusionKind(*producer, *consumer_for_fusion), + consumer_for_fusion)); + VLOG(2) << "Fuse producer " << producer->name() << " and its consumer " + << consumer_for_fusion->name() << " into " << input_fusion->name(); + reachability()->Replace(consumer_for_fusion, input_fusion); + TF_CHECK_OK( + computation()->ReplaceInstruction(consumer_for_fusion, input_fusion)); + if (producer->opcode() == HloOpcode::kFusion) { + input_fusion->MergeFusionInstructionIntoMultiOutput(producer); + } else { + input_fusion->FuseInstructionIntoMultiOutput(producer); + CHECK_EQ(0, producer->user_count()); + 
TF_CHECK_OK(computation()->RemoveInstruction(producer)); } } - return changed; } diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc old mode 100644 new mode 100755 index 20b3d64c417..2f2a2efcecb --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -17,100 +17,35 @@ limitations under the License. #include -#include -#include -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include +#include -#include "absl/memory/memory.h" -#include "absl/strings/numbers.h" -#include "absl/strings/str_cat.h" -#include "llvm/IR/DiagnosticInfo.h" -#include "llvm/IR/DiagnosticPrinter.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Verifier.h" -#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" -#include "tensorflow/compiler/xla/service/batchnorm_expander.h" -#include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/call_inliner.h" -#include "tensorflow/compiler/xla/service/conditional_simplifier.h" -#include "tensorflow/compiler/xla/service/convolution_group_converter.h" -#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/dump.h" -#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" -#include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_pad_for_tensor_cores.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_padding_legalization.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h" -#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h" -#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" -#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h" -#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" -#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" -#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include 
"tensorflow/compiler/xla/service/gpu/target_constants.h" -#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" -#include "tensorflow/compiler/xla/service/gpu/variadic_op_splitter.h" -#include "tensorflow/compiler/xla/service/hlo.pb.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" -#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" -#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" -#include "tensorflow/compiler/xla/service/hlo_proto_util.h" -#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/mem_wasted_on_passthrough_params.h" -#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" -#include "tensorflow/compiler/xla/service/reshape_mover.h" -#include "tensorflow/compiler/xla/service/rng_expander.h" -#include "tensorflow/compiler/xla/service/slice_sinker.h" -#include "tensorflow/compiler/xla/service/sort_simplifier.h" -#include "tensorflow/compiler/xla/service/stable_sort_expander.h" -#include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" -#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" -#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" -#include "tensorflow/compiler/xla/service/while_loop_trip_count_annotator.h" -#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/regexp.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" @@ -165,6 +100,109 @@ string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) { return "."; } +} // namespace + +Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( + HloModule* hlo_module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + // Convert convolutions into CustomCalls to cudnn, then canonicalize them + // (CudnnConvPaddingLegalization). Also expand cuSolver calls. 
+ HloPassPipeline pipeline("conv_canonicalization"); + pipeline.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); + if (IsVoltaOrLater(*stream_exec)) { + pipeline.AddPass(); + // CudnnConvPadForTensorCores leaves behind unnecessary + // tuple/get-tuple-element pairs that TupleSimplifier fixes. + pipeline.AddPass(); + } + + // tf2xla bridge, DepthwiseConvolutionConverter and CudnnConvRewriter + // introduces reshapes and transposes that can be eliminated using + // AlgebraicSimplifier + { + auto& pass = pipeline.AddPass>( + "algebraic_simplification_post_conv_rewriter"); + pass.AddInvariantChecker(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); + + AlgebraicSimplifierOptions options; + pass.AddPass(options); + } + + // CudnnConvRewriter, CudnnConvPaddingLegalization and + // CudnnConvPadForTensorCores may add instructions which can be simplified + // by constant folding. + pipeline.AddPass(); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + + return Status::OK(); +} + +Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( + HloModule* hlo_module, se::StreamExecutor* stream_exec, + se::DeviceMemoryAllocator* device_allocator) { + HloPassPipeline pipeline("post-layout_assignment"); + /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after + * fixing the ticket. */ + pipeline.AddInvariantChecker( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); + + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + pipeline.AddPass>(options); + + // Rewrite GEMMs into custom calls. + pipeline.AddPass(); + + // Choose the fastest algorithm for each conv. + // + // We pick the algorithm before fusion so we can generate better HLO. After + // CudnnConvRewriter, our convolutions are CustomCalls which return a + // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of + // scratch: + // + // customcall = (f32[...], f32[0]) + // return gte(customcall, 0) + // + // The algorithm picker then chooses the best algorithm, and potentially + // increases the scratch space. It replaces customcall with new_tuple, + // giving us the following: + // + // new_customcall = (f32[...], f32[N]) + // new_tuple = tuple(gte(new_customcall, 0), constant f32[0]) + // return gte(new_tuple, 0) + // + // The new tuple and gte instructions then be simplified away, because + // nobody is expected to use the scratch value. + // + // However, if we were to run CudnnConvAlgorithmPicker after fusion + // the gte(customcall, 0) would probably already be into a fusion node. We + // can't simplify across HloComputation boundaries, so in this case we + // wouldn't be able to simplify away the new_tuple bits. + pipeline.AddPass(stream_exec, device_allocator); + + // Find the fastest algorithm for GEMMs. + pipeline.AddPass(stream_exec, device_allocator); + + // Clean up new_tuple described above. 
+ pipeline.AddPass(); + + pipeline.AddPass(/*is_layout_sensitive=*/true); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + + return Status::OK(); +} + +namespace { absl::optional CanShareBufferHint(const HloInstruction* user, const HloInstruction* operand, const ShapeIndex& user_index) { @@ -222,387 +260,71 @@ void WarnIfBadDriverJITVersion() { }); } +// Try to load ptx from files defined in the FLAGS. If successful, return true. +bool MaybeLoadPtxFromFile(const HloModule* module, std::string* ptx) { + // If the xla_gpu_ptx_file options is set, be explicit when a file is used + // and warn when a file is not used to ease catching typo in filename. + std::string prefix = xla::FilenameFor(*module, *ptx); + std::string matched_filename; + for (const string filename : + module->config().debug_options().xla_gpu_ptx_file()) { + // To ease comparing many PTX versions, accept different suffixes then + // the original filename. + if (absl::StartsWith(filename, prefix)) { + matched_filename = filename; + VLOG(0) << "RunBackend() - Will load PTX from file: " << filename; + break; + } + } + if (module->config().debug_options().xla_gpu_ptx_file().size() > 0 && + matched_filename.empty()) { + VLOG(0) << "RunBackend() - For module with prefix '" << prefix + << "', we did not found a PTX file to load."; + } + + if (!matched_filename.empty()) { + std::ifstream ifs(matched_filename, std::ifstream::in); + *ptx = std::string(std::istreambuf_iterator(ifs), + std::istreambuf_iterator()); + CHECK(!ptx->empty()) << "Empty or non existing PTX file: " + << matched_filename; + return true; + } + return false; +} + } // namespace -// Runs optimization passes on the given HLO module. -Status impl::OptimizeHloModule(HloModule* hlo_module, - se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - { - HloPassPipeline pipeline("optimization"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false); - - // Expand random number generation. - pipeline.AddPass(); - - // Remove zero-sized HLO from the input so that other passes don't have to - // handle it. - pipeline.AddPass(); - - pipeline.AddPass(); - - pipeline.AddPass(); - pipeline.AddPass(); - ReducePrecisionInsertion::AddPasses( - &pipeline, hlo_module->config().debug_options(), - ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); - - // TODO(b/64094172): make Call work on GPU instead of inlining. - pipeline.AddPass(); - auto cost_model = [](HloInstruction* conv) { - // We need a cost model for GPUs. Currently, do nothing. - return false; - }; - pipeline.AddPass(); - pipeline.AddPass( - cost_model, - /*convert_batch_groups_only=*/true); - // Expand the sort op to support stable sorting if required. - pipeline.AddPass(); - // Convert BF16 operations to F32 operations so that the GPU backend can - // support BF16 operations without directly implementing a BF16 lowering for - // most ops. - pipeline.AddPass(BF16, F32); - - { - auto& pass = - pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false); - - // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls - // where possible. Not every batchnorm op can be implemented as a call to - // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs. 
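For the MaybeLoadPtxFromFile helper introduced above, here is a simplified, standard-library-only sketch of the same prefix-match-and-read logic. The prefix argument stands in for xla::FilenameFor(module, ptx) and candidate_files for the values of the xla_gpu_ptx_file debug option; the real helper uses absl::StartsWith, VLOG, and CHECK-fails on an empty file, none of which are reproduced here.

#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

// Returns true and fills `ptx` if some candidate filename starts with the
// module's filename prefix; different suffixes are accepted so several PTX
// variants of the same module can be compared easily.
bool MaybeLoadPtxFromFile(const std::string& prefix,
                          const std::vector<std::string>& candidate_files,
                          std::string* ptx) {
  std::string matched_filename;
  for (const std::string& filename : candidate_files) {
    if (filename.compare(0, prefix.size(), prefix) == 0) {
      matched_filename = filename;
      std::cout << "Will load PTX from file: " << filename << "\n";
      break;
    }
  }
  if (!candidate_files.empty() && matched_filename.empty()) {
    std::cout << "No PTX file found for prefix '" << prefix << "'\n";
  }
  if (matched_filename.empty()) return false;

  std::ifstream ifs(matched_filename);
  *ptx = std::string(std::istreambuf_iterator<char>(ifs),
                     std::istreambuf_iterator<char>());
  return !ptx->empty();
}

int main() {
  std::string ptx;
  if (MaybeLoadPtxFromFile("module_0001", {"module_0001.v2.ptx"}, &ptx)) {
    std::cout << "Loaded " << ptx.size() << " bytes of PTX\n";
  }
}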
- if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) { - pass.AddPass(); - } - pass.AddPass( - /*rewrite_training_op=*/true, - /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true); - - pipeline.AddPass(); - - // BatchNormExpander can create zero-sized ops, so zero-sized HLO - // elimination has to come after that pass. - pipeline.AddPass(); - - AlgebraicSimplifierOptions options; - pass.AddPass(options); - pass.AddPass(); - pass.AddPass(); - pass.AddPass(); - pass.AddPass(); - - // TODO(b/134075051): Re-enable after b/134075051 is fixed. - // pass.AddPass(); - - pass.AddPass(); - pass.AddPass(); - pass.AddPass(); - pass.AddPass(); - } - - pipeline.AddPass( - [](const HloInstruction& dot, - const TransposeFolding::OperandIndices& candidate_operands) { - return IsMatrixMultiplication(dot) - ? candidate_operands - : TransposeFolding::OperandIndices{}; - }, - TransposeFolding::NeverFoldTranspose); - pipeline.AddPass(/*is_layout_sensitive=*/false); - pipeline.AddPass(); - - // Run WhileLoopTripCountAnnotator at the end of the simplification - // pipeline, before layout assignment and fusion. This pass does some - // pattern-matching on while bodies/conditions, and this is where the HLO is - // "nicest". - // - // It's important that we don't make semantic changes (e.g. unrolling) to - // any `while` loops after this point, because otherwise the trip-count - // annotations added by this pass may not be correct after the - // modifications. - pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } - - { - // Convert convolutions into CustomCalls to cudnn, then canonicalize them - // (CudnnConvPaddingLegalization). Also expand cuSolver calls. - HloPassPipeline pipeline("conv_canonicalization"); - pipeline.AddInvariantChecker(/*layout_sensitive=*/false, - /*allow_mixed_precision=*/false); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); - if (IsVoltaOrLater(*stream_exec)) { - pipeline.AddPass(); - // CudnnConvPadForTensorCores leaves behind unnecessary - // tuple/get-tuple-element pairs that TupleSimplifier fixes. - pipeline.AddPass(); - } - // CudnnConvRewriter, CudnnConvPaddingLegalization and - // CudnnConvPadForTensorCores may add instructions which can be simplified - // by constant folding. - pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } - - { - // Run layout assignment in a separate pipeline from - // "post-layout-assignment" because we want everything after layout - // assignment to have a layout-sensitive invariant-checker, but - // HloPassPipeline also runs its invariant checker before any passes are - // run, meaning, the pipeline that contains layout assignment cannot contain - // a layout-sensitive verifier! - HloPassPipeline pipeline("layout assignment"); - pipeline.AddPass( - hlo_module->mutable_entry_computation_layout(), - LayoutAssignment::InstructionCanChangeLayout, stream_exec); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } - - { - HloPassPipeline pipeline("post-layout_assignment"); - /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after - * fixing the ticket. */ - pipeline.AddInvariantChecker( - /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false, - LayoutAssignment::InstructionCanChangeLayout); - - // The LayoutAssignment pass may leave behind kCopy instructions which are - // duplicate or NOPs, so remove them with algebraic simplification and CSE. 
- AlgebraicSimplifierOptions options; - options.set_is_layout_sensitive(true); - pipeline.AddPass>(options); - - // Rewrite GEMMs into custom calls. - pipeline.AddPass(); - - // Choose the fastest algorithm for each conv. - // - // We pick the algorithm before fusion so we can generate better HLO. After - // CudnnConvRewriter, our convolutions are CustomCalls which return a - // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of - // scratch: - // - // customcall = (f32[...], f32[0]) - // return gte(customcall, 0) - // - // The algorithm picker then chooses the best algorithm, and potentially - // increases the scratch space. It replaces customcall with new_tuple, - // giving us the following: - // - // new_customcall = (f32[...], f32[N]) - // new_tuple = tuple(gte(new_customcall, 0), constant f32[0]) - // return gte(new_tuple, 0) - // - // The new tuple and gte instructions then be simplified away, because - // nobody is expected to use the scratch value. - // - // However, if we were to run CudnnConvAlgorithmPicker after fusion - // the gte(customcall, 0) would probably already be into a fusion node. We - // can't simplify across HloComputation boundaries, so in this case we - // wouldn't be able to simplify away the new_tuple bits. - pipeline.AddPass(stream_exec, device_allocator); - - // Find the fastest algorithm for GEMMs. - pipeline.AddPass(stream_exec, device_allocator); - - // Clean up new_tuple described above. - pipeline.AddPass(); - - pipeline.AddPass(/*is_layout_sensitive=*/true); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } - - { - HloPassFix fusion("fusion"); - // We try to split variadic ops with many parameters into several such ops - // to avoid exceeding the parameter space. - fusion.AddPass(); - /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after - * fixing the ticket. */ - fusion.AddInvariantChecker( - /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false, - LayoutAssignment::InstructionCanChangeLayout); - fusion.AddPass(/*may_duplicate=*/false); - fusion.AddPass(/*may_duplicate=*/true); - fusion.AddPass(); - fusion.AddPass(); - fusion.AddPass(/*is_layout_sensitive=*/true, - /*only_fusion_computations=*/true); - fusion.AddPass(); - TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); - - HloPassPipeline reduce_pipeline("reduce-precision"); - /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after - * fixing the ticket. */ - reduce_pipeline.AddInvariantChecker( - /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false, - LayoutAssignment::InstructionCanChangeLayout); - ReducePrecisionInsertion::AddPasses( - &reduce_pipeline, hlo_module->config().debug_options(), - ReducePrecisionInsertion::PassTiming::AFTER_FUSION); - StatusOr reduce_result = reduce_pipeline.Run(hlo_module); - TF_RETURN_IF_ERROR(reduce_result.status()); - - if (reduce_result.ValueOrDie()) { - // Do another fusion pass, with the expectation that we may be able to - // fuse the new ReducePrecision operations. - TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); - } - } - - return Status::OK(); -} - -// Modifies the given HLO module so that it will be accepted by IrEmitter. -// Unlike optimization passes, the passes are necessary for correctness. -Status impl::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { - // In some cases, we have to place the result of an instruction in a temporary - // buffer. 
For instance, the buffer that holds an external parameter is - // assumed immutable at this point, and should not be reused for output - // (b/27180329). Therefore, in that case, we set the output to be a copy of - // the parameter. - HloPassPipeline pipeline("GPU-ir-emit-prepare"); - /* TODO(b/117531509): Use LayoutAssignment::InstructionCanChangeLayout after - * fixing the ticket. */ - pipeline.AddInvariantChecker( - /*layout_sensitive=*/true, - /*allow_mixed_precision=*/false, - LayoutAssignment::InstructionCanChangeLayout); - - // Copy insertion should be performed immediately before IR emission to avoid - // inserting unnecessary copies (later pass adds an instruction which - // materializes the value) or missing a necessary copy (later pass removes an - // instruction which materializes a value). DCE must be run immediately before - // (and sometime after) copy insertion, to avoid dead code from interfering - // with the rewrites. - pipeline.AddPass(); - pipeline.AddPass(); - // The following pass LOGs memory waste. Add it when VLOGing is enabled only. - if (VLOG_IS_ON(2)) { - pipeline.AddPass(); - } - pipeline.AddPass(&CanShareBufferHint); - pipeline.AddPass(); - return pipeline.Run(hlo_module).status(); -} - NVPTXCompiler::NVPTXCompiler() - : pointer_size_(llvm::DataLayout(nvptx::kDataLayout) - .getPointerSize(0 /* default address space */)) {} + : GpuCompiler(stream_executor::cuda::kCudaPlatformId, nvptx::kTargetTriple, + nvptx::kDataLayout) {} -StatusOr> NVPTXCompiler::RunHloPasses( - std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - // We dump the post-optimization HLO in RunBackend so no need to dump it here. - XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); - tensorflow::profiler::TraceMe activity( - [&] { return absl::StrCat("HLO Transforms:", module->name()); }, - tensorflow::profiler::TraceMeLevel::kInfo); - TF_RETURN_IF_ERROR( - impl::OptimizeHloModule(module.get(), stream_exec, device_allocator)); - - TF_RETURN_IF_ERROR(impl::PrepareHloModuleForIrEmitting(module.get())); - - return std::move(module); +HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer() { + return &CanShareBufferHint; } -StatusOr> NVPTXCompiler::RunBackend( - std::unique_ptr module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend"); - - TF_RET_CHECK(stream_exec != nullptr); - - llvm::LLVMContext llvm_context; - std::string buffer; - llvm::raw_string_ostream error(buffer); - llvm::DiagnosticPrinterRawOStream printer(error); - auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info, - void* Context) { - auto printer = static_cast(Context); - diag_info.print(*printer); - }; - llvm_context.setDiagnosticHandlerCallBack(DiagnosticHandler, &printer); - - llvm::Module llvm_module(module->name().c_str(), llvm_context); - // Set the target triple and the data layout. - llvm_module.setTargetTriple(nvptx::kTargetTriple); - llvm_module.setDataLayout(nvptx::kDataLayout); - - // Determine the HLO schedule, which is an ordering of HLO instructions. This - // is used by buffer assignment to enable buffer reuse, and the same ordering - // must also be used to determine the thunk launch schedule. - std::unique_ptr stream_assignment = AssignStreams(*module); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_schedule, - GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); - - // Run buffer analysis on the HLO graph. 
This analysis figures out which - // temporary buffers are required to run the computation. - TF_ASSIGN_OR_RETURN( - std::unique_ptr buffer_assignment, - BufferAssigner::Run( - module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), - /*color_alignment=*/ - [](LogicalBuffer::Color) { return kXlaAllocatedBufferAlignBytes; }, - /*allocate_buffers_for_constants=*/true, - /*colorer=*/BufferAssigner::DefaultColorer(), - /*must_not_live_out=*/{}, &CanShareBufferHint)); - DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations"); - - IrEmitterContext ir_emitter_context( - module.get(), buffer_assignment.get(), stream_exec->platform(), - &stream_exec->GetDeviceDescription(), &llvm_module); - - HloComputation* entry_computation = module->entry_computation(); - IrEmitterUnnested ir_emitter(module->config(), entry_computation, - &ir_emitter_context); - - TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); - - { - XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - IR emission"); - TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter)); +GpuVersion NVPTXCompiler::GetGpuVersion(se::StreamExecutor* stream_exec) { + int cc_major, cc_minor; + if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor)) { + LOG(WARNING) + << "Couldn't get compute capability for device; assuming sm_20."; + cc_major = 2; + cc_minor = 0; } - if (user_pre_optimization_hook_) { - user_pre_optimization_hook_(llvm_module); - } - string ir_module_string_before_opt; - const bool embed_ir_in_executable = - module->config().debug_options().xla_embed_ir_in_executable(); - if (embed_ir_in_executable) { - ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); - } + return std::make_pair(cc_major, cc_minor); +} - llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/false); +StatusOr>> +NVPTXCompiler::CompileTargetBinary(const HloModule* module, + llvm::Module* llvm_module, + GpuVersion gpu_version, + se::StreamExecutor* stream_exec) { + std::pair compute_capability = + absl::get>(gpu_version); - { - XLA_SCOPED_LOGGING_TIMER( - "NVPTXCompiler::RunBackend - Running LLVM verifier"); - - std::string err; - llvm::raw_string_ostream err_stream(err); - - // verifyModule() returns true if the module is broken. - TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) - << "Invalid LLVM IR before optimizations:\n" - << err_stream.str() - << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " - "Rerun with --xla_dump_to to get the IR. 
"; - } - - string libdevice_dir; + std::string libdevice_dir; { tensorflow::mutex_lock lock(mutex_); @@ -616,70 +338,31 @@ StatusOr> NVPTXCompiler::RunBackend( } VLOG(2) << "Libdevice dir = " << libdevice_dir << "\n"; - int cc_major, cc_minor; - if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major, - &cc_minor)) { - LOG(WARNING) - << "Couldn't get compute capability for device; assuming sm_20."; - cc_major = 2; - cc_minor = 0; - } - string ptx; - { - XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunBackend - CompileToPtx"); - TF_ASSIGN_OR_RETURN(ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, - module->config(), libdevice_dir)); + if (!MaybeLoadPtxFromFile(module, &ptx)) { + XLA_SCOPED_LOGGING_TIMER( + "NVPTXCompiler::CompileTargetBinary - CompileToPtx"); + TF_ASSIGN_OR_RETURN( + ptx, nvptx::CompileToPtx(llvm_module, gpu_version, module->config(), + libdevice_dir)); } - llvm_ir::DumpIrIfEnabled(*module, llvm_module, /*optimized=*/true); + llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/true); if (user_post_optimization_hook_) { - user_post_optimization_hook_(llvm_module); + user_post_optimization_hook_(*llvm_module); } // Write PTX to IR dump directory, if IR dumping was requested. if (DumpingEnabledForHloModule(*module)) { DumpToFileInDirOrStdout(*module, "ptx", ptx); } - const std::vector cubin = CompilePtxOrGetCachedResult( - stream_exec, ptx, cc_major, cc_minor, module->config()); + std::vector cubin = + CompilePtxOrGetCachedResult(stream_exec, ptx, compute_capability.first, + compute_capability.second, module->config()); - auto thunk_schedule = absl::make_unique( - ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), - hlo_schedule->ThunkLaunchOrder()); - if (DumpingEnabledForHloModule(*module)) { - DumpToFileInDirOrStdout(*module, "thunk_schedule", - thunk_schedule->ToString()); - } - - std::unique_ptr profile_index_map; - std::unique_ptr profile_printer; - - if (module->config().hlo_profiling_enabled() || VLOG_IS_ON(1)) { - HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); - cost_analysis.set_bytes_per_second( - stream_exec->GetDeviceDescription().memory_bandwidth()); - TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); - VLOG(1) << "HLO memory read+written: " - << tensorflow::strings::HumanReadableNumBytes( - cost_analysis.bytes_accessed()); - if (module->config().hlo_profiling_enabled()) { - profile_index_map = absl::make_unique(*module); - profile_printer = CreateHloProfilePrinterData( - *profile_index_map, cost_analysis, entry_computation->name()); - } - } - - auto* gpu_executable = new GpuExecutable( - ptx, cubin, std::make_pair(cc_major, cc_minor), std::move(thunk_schedule), - std::move(module), std::move(buffer_assignment), - std::move(profile_printer), std::move(profile_index_map)); - if (embed_ir_in_executable) { - DCHECK_NE("", ir_module_string_before_opt); - gpu_executable->set_ir_module_string(ir_module_string_before_opt); - } - return std::unique_ptr(gpu_executable); + return std::pair>(std::move(ptx), + std::move(cubin)); } std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( @@ -761,16 +444,5 @@ std::vector NVPTXCompiler::CompilePtxOrGetCachedResult( return cache_value->cubin_data; } -StatusOr>> -NVPTXCompiler::CompileAheadOfTime(std::unique_ptr module_group, - const AotCompilationOptions& options) { - return Unimplemented( - "not yet implemented: NVPTXCompiler::CompileAheadOfTime"); -} - -se::Platform::Id NVPTXCompiler::PlatformId() const { - return se::cuda::kCudaPlatformId; -} - } // 
namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index 980c00ac7da..a7b38afb8ec 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -22,72 +22,37 @@ limitations under the License. #include "absl/container/node_hash_map.h" #include "absl/types/optional.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/stream_executor/stream_executor_pimpl.h" namespace xla { namespace gpu { -// Temporarily expose the optimization pipeline for the GPU backend for reuse -// in the MLIR GPU backend. -// TODO(b/137624192): Remove once MLIR backend uses tailored optimizations. -namespace impl { - -Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, - se::DeviceMemoryAllocator* device_allocator); -Status PrepareHloModuleForIrEmitting(HloModule* hlo_module); - -} // namespace impl - -// The GPU compiler generates efficient GPU executables. -class NVPTXCompiler : public LLVMCompiler { +// NVPTXCompiler generates efficient GPU executables for NVPTX target. +class NVPTXCompiler : public GpuCompiler { public: NVPTXCompiler(); ~NVPTXCompiler() override {} - // Bring in - // StatusOr>> Compile( - // std::vector> modules, - // std::vector> - // stream_execs) - using LLVMCompiler::Compile; - - StatusOr> RunHloPasses( - std::unique_ptr module, se::StreamExecutor* stream_exec, + Status OptimizeHloConvolutionCanonicalization( + HloModule* hlo_module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) override; - StatusOr> RunBackend( - std::unique_ptr module, se::StreamExecutor* stream_exec, + Status OptimizeHloPostLayoutAssignment( + HloModule* hlo_module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) override; - StatusOr>> - CompileAheadOfTime(std::unique_ptr module_group, - AotCompilationOptions const& options) override; + HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() override; - se::Platform::Id PlatformId() const override; + GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) override; - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { - // Capture just the pointer size, not the entire NVPTXCompiler object. - int64 pointer_size = pointer_size_; - return [pointer_size](const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, pointer_size); - }; - } + StatusOr>> CompileTargetBinary( + const HloModule* hlo_module, llvm::Module* llvm_module, + GpuVersion gpu_version, se::StreamExecutor* stream_exec) override; private: - // The size in bytes of a pointer. Used by ShapeSizeBytesFunction. 
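The new NVPTXCompiler declaration above shows the shape of the refactoring: the target-independent driver now lives in the GpuCompiler base class, and the CUDA-specific work (conv canonicalization, post-layout passes, GPU version query, PTX/cubin emission) is expressed as overridden hooks. Below is a minimal template-method sketch of that split; GpuCompilerBase and NvptxLikeCompiler are invented stand-in names, and the hook signatures are simplified relative to the real HloModule/StreamExecutor-based ones.

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Module { std::string name; };
using GpuVersion = std::pair<int, int>;

class GpuCompilerBase {
 public:
  virtual ~GpuCompilerBase() = default;

  // Template method: the shared driver calls the target-specific hooks in a
  // fixed order.
  std::vector<uint8_t> Compile(Module* module) {
    OptimizeConvolutions(module);
    OptimizePostLayoutAssignment(module);
    return CompileTargetBinary(module, GetGpuVersion());
  }

 protected:
  virtual void OptimizeConvolutions(Module* module) = 0;
  virtual void OptimizePostLayoutAssignment(Module* module) = 0;
  virtual GpuVersion GetGpuVersion() = 0;
  virtual std::vector<uint8_t> CompileTargetBinary(Module* module,
                                                   GpuVersion version) = 0;
};

class NvptxLikeCompiler : public GpuCompilerBase {
 protected:
  void OptimizeConvolutions(Module*) override {
    std::cout << "rewrite convs to cudnn custom-calls\n";
  }
  void OptimizePostLayoutAssignment(Module*) override {
    std::cout << "pick conv/gemm algorithms\n";
  }
  GpuVersion GetGpuVersion() override { return {7, 0}; }  // e.g. sm_70
  std::vector<uint8_t> CompileTargetBinary(Module*, GpuVersion v) override {
    std::cout << "emit PTX for sm_" << v.first << v.second << "\n";
    return {0xde, 0xad};  // Placeholder bytes standing in for a cubin.
  }
};

int main() {
  Module m{"example"};
  NvptxLikeCompiler compiler;
  compiler.Compile(&m);
}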
- const int64 pointer_size_; - tensorflow::mutex mutex_; // When compiling an HLO module, we need to find a path to the nvvm libdevice diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index cb012649200..f9937ba77de 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -22,6 +22,7 @@ limitations under the License. // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/gpu/target_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -72,8 +73,8 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, VLOG(3) << "EmitIndexAndSetExitBasicBlock unroll_factor " << unroll_factor_; CHECK_NE(index_type, nullptr); std::vector array_indices; - llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_); + llvm::Value* block_id = + EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), static_cast(block_id)); block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id"); @@ -82,8 +83,8 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, // "It is guaranteed that [...] 0 <= %tid.x < %ntid.x" // // %ntid.x is currently specified as 1024. - llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_); + llvm::Value* thread_id = + EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_); llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), static_cast(thread_id)); thread_id = b_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index 10bc82488ff..2276807d74f 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -82,6 +82,11 @@ LaunchDimensions CalculateLaunchDimensions( // TODO(jlebar): Investigate this further, and tune this heuristic so we can // run faster on the few benchmarks where smaller block size helps. int64 threads_per_block = ThreadsPerBlockLimit(device_desc); + // We unroll kernels to make use of vectorized loads/stores. This means we + // need more registers to hold intermediate values. Reduce the number of + // blocks per thread to increase the number of registers available to ptxas. + // Make sure we still have a multiple of 32. + threads_per_block = RoundUpToNearest(threads_per_block / unroll_factor, 32LL); if (num_elements < threads_per_block) { threads_per_block = num_elements; VLOG(2) << "Update # of threads per block to the element count (" diff --git a/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc b/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc deleted file mode 100644 index 5793051771f..00000000000 --- a/tensorflow/compiler/xla/service/gpu/scratch_allocator.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h" - -namespace xla { -namespace gpu { - -StatusOr> ScratchAllocator::AllocateBytes( - se::Stream* stream, int64 byte_size) { - CHECK_GE(byte_size, 0) << "byte_size must be positive."; - if (byte_size > GetMemoryLimitInBytes(stream)) { - return se::port::Status( - se::port::error::RESOURCE_EXHAUSTED, - absl::StrFormat( - "Allocating %d bytes exceeds the memory limit of %d bytes.", - byte_size, GetMemoryLimitInBytes(stream))); - } - - TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer, - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false)); - total_allocated_bytes_ += byte_size; - - se::DeviceMemoryBase buffer_addr = *allocated_buffer; - allocated_buffers_.push_back(std::move(allocated_buffer)); - return se::DeviceMemory(buffer_addr); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/scratch_allocator.h b/tensorflow/compiler/xla/service/gpu/scratch_allocator.h deleted file mode 100644 index 9654237956a..00000000000 --- a/tensorflow/compiler/xla/service/gpu/scratch_allocator.h +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_ - -#include - -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/stream_executor/device_memory_allocator.h" - -namespace xla { -namespace gpu { - -class ScratchAllocator : public se::ScratchAllocator { - public: - ScratchAllocator(int device_ordinal, - se::DeviceMemoryAllocator* memory_allocator) - : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} - - int64 GetMemoryLimitInBytes(se::Stream* stream) override { - return 1LL << 32; // 4GB. TODO(jlebar): Tune this? 
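Returning to the CalculateLaunchDimensions change in partition_assignment.cc above: the new RoundUpToNearest call shrinks the block size by the unroll factor to free registers for ptxas while keeping it a multiple of the 32-thread warp size. A standalone sketch of that arithmetic, approximating RoundUpToNearest with plain integer math:

#include <cstdint>
#include <iostream>

// Divide the thread limit by the unroll factor, then round up to the nearest
// multiple of 32.
int64_t AdjustThreadsPerBlock(int64_t threads_per_block,
                              int64_t unroll_factor) {
  const int64_t reduced = threads_per_block / unroll_factor;
  return (reduced + 31) / 32 * 32;
}

int main() {
  // With a 1024-thread limit and 4x unrolling we get 256 threads per block.
  std::cout << AdjustThreadsPerBlock(1024, 4) << "\n";  // 256
  // A limit that divides to a non-multiple of 32 still rounds up to one.
  std::cout << AdjustThreadsPerBlock(1000, 4) << "\n";  // 250 -> 256
}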
- } - int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - - StatusOr> AllocateBytes(se::Stream* stream, - int64 byte_size) override; - - template - StatusOr> Allocate(se::Stream* stream, - int64 num_elements) { - TF_ASSIGN_OR_RETURN(se::DeviceMemory bytes, - AllocateBytes(stream, num_elements * sizeof(T))); - return se::DeviceMemory(bytes); - } - - private: - const int device_ordinal_; - se::DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; - int64 total_allocated_bytes_ = 0; -}; - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SCRATCH_ALLOCATOR_H_ diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index 1cdf9752390..117931e3398 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" @@ -201,10 +202,7 @@ StatusOr> CreateKernel( } auto kernel_base = absl::make_unique(stream_exec); - if (!stream_exec->GetKernel(loader_spec, kernel_base.get())) { - return InternalError("Unable to load kernel '%s'", kernel_name); - } - + TF_RETURN_IF_ERROR(stream_exec->GetKernel(loader_spec, kernel_base.get())); return std::move(kernel_base); } @@ -217,13 +215,9 @@ Status ExecuteKernelOnStream(const se::KernelBase& kernel, for (const se::DeviceMemoryBase& buf : args) { kernel_args->add_device_memory_argument(buf); } - - if (!stream->parent()->Launch(stream, se::ThreadDim(threads_per_block), - se::BlockDim(block_count), kernel, - *kernel_args)) { - return InternalError("Unable to launch kernel"); - } - return Status::OK(); + return stream->parent()->Launch(stream, se::ThreadDim(threads_per_block), + se::BlockDim(block_count), kernel, + *kernel_args); } se::cuda::PtxCompilationOptions PtxOptsFromConfig( diff --git a/tensorflow/compiler/xla/service/gpu/target_util.cc b/tensorflow/compiler/xla/service/gpu/target_util.cc index 31f989bd58c..48c703183fc 100644 --- a/tensorflow/compiler/xla/service/gpu/target_util.cc +++ b/tensorflow/compiler/xla/service/gpu/target_util.cc @@ -29,9 +29,14 @@ namespace { using absl::StrCat; // Wrapper structure for carrying llvm intrinsic ids for NVPTX/AMDGPU platforms. +// On AMDGPU, some of these operations are made as device functions instead of +// intrinsics. Therefore a variant type is used to wrap the lambda to call +// those device functions. 
struct TargetIntrinsics { llvm::Intrinsic::ID nvptx_intrinsic; - llvm::Intrinsic::ID amdgpu_intrinsic; + absl::variant*)>> + amdgpu_intrinsic_or_function; }; // Gets the llvm intrinsic ids on different platforms (NVPTX, AMDGPU) @@ -66,6 +71,30 @@ struct TargetIntrinsics GetIntrinsic(TargetIntrinsicID intrin) { return {llvm::Intrinsic::nvvm_barrier0, llvm::Intrinsic::amdgcn_s_barrier}; } + case TargetIntrinsicID::kBlockDimx: { + return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall("__ockl_get_local_size", + {b_->getInt32(0)}, {U32}, U64, {}, + b_); + }}; + } + case TargetIntrinsicID::kBlockDimy: { + return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_y, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall("__ockl_get_local_size", + {b_->getInt32(1)}, {U32}, U64, {}, + b_); + }}; + } + case TargetIntrinsicID::kBlockDimz: { + return {llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_z, + [](llvm::IRBuilder<>* b_) -> llvm::CallInst* { + return EmitDeviceFunctionCall("__ockl_get_local_size", + {b_->getInt32(2)}, {U32}, U64, {}, + b_); + }}; + } } } @@ -156,6 +185,36 @@ string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id, } } +llvm::CallInst* EmitDeviceFunctionCall( + const string& callee_name, absl::Span operands, + absl::Span input_types, PrimitiveType output_type, + absl::Span attributes, + llvm::IRBuilder<>* b) { + std::vector ir_input_types; + llvm::Module* module = b->GetInsertBlock()->getModule(); + for (PrimitiveType input_type : input_types) { + ir_input_types.push_back( + llvm_ir::PrimitiveTypeToIrType(input_type, module)); + } + llvm::FunctionType* callee_type = llvm::FunctionType::get( + llvm_ir::PrimitiveTypeToIrType(output_type, module), // Return type. + ir_input_types, // Parameter types. + false); // No variadic arguments. + + // Declares the callee if it is not declared already. + llvm::Function* callee = llvm::dyn_cast( + b->GetInsertBlock() + ->getModule() + ->getOrInsertFunction(callee_name, callee_type) + .getCallee()); + + for (auto attribute : attributes) { + callee->addFnAttr(attribute); + } + + return b->CreateCall(callee, llvm_ir::AsArrayRef(operands)); +} + llvm::CallInst* EmitCallToTargetIntrinsic( TargetIntrinsicID intrinsic_id, absl::Span operands, absl::Span overloaded_types, llvm::IRBuilder<>* b) { @@ -166,7 +225,17 @@ llvm::CallInst* EmitCallToTargetIntrinsic( if (target_triple.isNVPTX()) { llvm_intrinsic_id = gpu_intrinsic_id.nvptx_intrinsic; } else if (target_triple.getArch() == llvm::Triple::amdgcn) { - llvm_intrinsic_id = gpu_intrinsic_id.amdgpu_intrinsic; + llvm::Intrinsic::ID* llvm_intrinsic_id_ptr = + absl::get_if( + &gpu_intrinsic_id.amdgpu_intrinsic_or_function); + if (llvm_intrinsic_id_ptr) { + llvm_intrinsic_id = *llvm_intrinsic_id_ptr; + } else { + std::function*)>* builder_func = + absl::get_if*)>>( + &gpu_intrinsic_id.amdgpu_intrinsic_or_function); + return (*builder_func)(b); + } } else { LOG(FATAL) << "Invalid triple " << target_triple.str(); } diff --git a/tensorflow/compiler/xla/service/gpu/target_util.h b/tensorflow/compiler/xla/service/gpu/target_util.h index d50529e395e..4355ed21136 100644 --- a/tensorflow/compiler/xla/service/gpu/target_util.h +++ b/tensorflow/compiler/xla/service/gpu/target_util.h @@ -39,6 +39,9 @@ enum class TargetIntrinsicID { kBlockIdy, kBlockIdz, kBarrierId, + kBlockDimx, + kBlockDimy, + kBlockDimz, }; // Enumeration to get target specific device math function. 
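The amdgpu_intrinsic_or_function variant above lets a table entry be either a plain intrinsic ID or a callback that emits a device-function call such as __ockl_get_local_size. The sketch below reproduces that dispatch shape with std::variant and strings in place of absl::variant, llvm::Intrinsic::ID, and the IRBuilder-based callbacks; it is illustrative only.

#include <functional>
#include <iostream>
#include <string>
#include <variant>

using EmitFn = std::function<std::string()>;

struct TargetIntrinsics {
  int nvptx_intrinsic;
  std::variant<int, EmitFn> amdgpu_intrinsic_or_function;
};

std::string EmitForAmdgpu(const TargetIntrinsics& entry) {
  if (const int* id = std::get_if<int>(&entry.amdgpu_intrinsic_or_function)) {
    return "call intrinsic #" + std::to_string(*id);
  }
  // No intrinsic exists for this operation; fall back to the callback that
  // emits a device-function call.
  const EmitFn* fn = std::get_if<EmitFn>(&entry.amdgpu_intrinsic_or_function);
  return (*fn)();
}

int main() {
  TargetIntrinsics barrier{/*nvptx=*/1, /*amdgpu=*/2};
  TargetIntrinsics block_dim_x{
      /*nvptx=*/3,
      EmitFn([] { return std::string("call __ockl_get_local_size(0)"); })};

  std::cout << EmitForAmdgpu(barrier) << "\n";      // call intrinsic #2
  std::cout << EmitForAmdgpu(block_dim_x) << "\n";  // call __ockl_get_local_size(0)
}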
@@ -59,8 +62,15 @@ enum class TargetDeviceFunctionID { kHypot }; -// Emits a call to the specified target intrinsic with the given operands. +// Emits IR to call a device function named "callee_name" on the given +// operand. Returns the IR value that represents the return value. +llvm::CallInst* EmitDeviceFunctionCall( + const std::string& callee_name, absl::Span operands, + absl::Span input_type, PrimitiveType output_type, + absl::Span attributes, + llvm::IRBuilder<>* b); +// Emits a call to the specified target intrinsic with the given operands. // Overloaded intrinsics (for example, "minnum") must include a type // in overloaded_types for each overloaded type. Typically, overloaded // intrinsics have only a single overloaded type. diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index a9b52d985af..67051b153b1 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -7,7 +7,7 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load( - "//tensorflow/core:platform/default/build_config_root.bzl", + "//tensorflow/core/platform:default/build_config_root.bzl", "tf_cuda_tests_tags", ) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 83fb6ebb443..7491949fa59 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -52,28 +52,5 @@ void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr hlo_module, EXPECT_TRUE(filecheck_result.ValueOrDie()); } -void GpuCodegenTest::MatchOptimizedHlo(absl::string_view hlo, - absl::string_view pattern, - bool print_operand_shape) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - GetOptimizedModule(hlo)); - HloPrintOptions print_opts; - print_opts.set_print_operand_shape(print_operand_shape); - StatusOr filecheck_result = - RunFileCheck(optimized_module->ToString(print_opts), pattern); - TF_ASSERT_OK(filecheck_result.status()); - EXPECT_TRUE(filecheck_result.ValueOrDie()); -} - -StatusOr> GpuCodegenTest::GetOptimizedModule( - absl::string_view hlo) { - HloModuleConfig config; - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo, config)); - return backend().compiler()->RunHloPasses( - std::move(module), backend().default_stream_executor(), - backend().default_stream_executor()->GetAllocator()); -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h index c3c6586d12a..59fba6325ec 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h @@ -34,21 +34,6 @@ class GpuCodegenTest : public LlvmIrGenTestBase { // FileCheck pattern. (See http://llvm.org/docs/CommandGuide/FileCheck.html). void CompileAndVerifyPtx(std::unique_ptr hlo_module, absl::string_view pattern); - - // Compiles the given `hlo` with optimizations, and verifies that optimized - // HLO matches the given FileCheck pattern. - void MatchOptimizedHlo(absl::string_view hlo, absl::string_view pattern, - bool print_operand_shape = false); - - // LikeMatchOptimizedHlo, but checks operand shapes as well. 
- void MatchOptimizedHloWithShapes(absl::string_view hlo, - absl::string_view pattern) { - MatchOptimizedHlo(hlo, pattern, /*print_operand_shape=*/true); - } - - // Compiles and returns module with optimizations from a given HLO. - StatusOr> GetOptimizedModule( - absl::string_view hlo); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index a12932f573b..92bb84065a2 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -99,6 +99,22 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { /*match_optimized_ir=*/true); } +TEST_F(GpuKernelTilingTest, UnnestedTransposeC128TypeRun) { + const char *const kHloString = R"( + HloModule unnested_transpose_3 + + ENTRY unnested_transpose_3 { + para0 = c128[65,65]{1,0} parameter(0) + ROOT copy1 = c128[65,65]{0,1} copy(para0) + })"; + + // With the current implementation for the available hardwares, we bail out + // from the tiled transpose implementation at the last minute. Instead of + // checking the transpose is not tiled, we only check the module compiled and + // run in this test. + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.0})); +} + TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { const char *const kHloString = R"( HloModule multiple_output_fusion_1 @@ -520,6 +536,51 @@ TEST_F(GpuKernelTilingTest, EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); } +TEST_F(GpuKernelTilingTest, ColumnReductionSmallTileSizeX) { + const char *const kHloString = R"( + HloModule Test + + scalar_add_computation.1 { + scalar_lhs.1 = f32[] parameter(0) + scalar_rhs.1 = f32[] parameter(1) + ROOT add.6 = f32[] add(scalar_lhs.1, scalar_rhs.1) + } + ENTRY Test { + param_3.241 = f16[512,2,9,9]{1,3,2,0} parameter(3) + constant_661 = f16[] constant(0) + broadcast.695 = f16[512,2,9,9]{1,3,2,0} broadcast(constant_661), dimensions={} + compare.42 = pred[512,2,9,9]{1,3,2,0} compare(param_3.241, broadcast.695), direction=GT + param_2.401 = f16[512,2,9,9]{1,3,2,0} parameter(2) + select.40 = f16[512,2,9,9]{1,3,2,0} select(compare.42, param_2.401, broadcast.695) + convert.196 = f32[512,2,9,9]{1,3,2,0} convert(select.40) + param_1.809 = f16[512,2,9,9]{1,3,2,0} parameter(1) + copy.335 = f16[512,2,9,9]{1,3,2,0} copy(param_1.809) + convert.218 = f32[512,2,9,9]{1,3,2,0} convert(copy.335) + param_0.668 = f32[2]{0} parameter(0) + broadcast.687 = f32[512,2,9,9]{1,3,2,0} broadcast(param_0.668), dimensions={1} + subtract.136 = f32[512,2,9,9]{1,3,2,0} subtract(convert.218, broadcast.687) + multiply.579 = f32[512,2,9,9]{1,3,2,0} multiply(convert.196, subtract.136) + constant_485 = f32[] constant(0) + reduce.139 = f32[2]{0} reduce(multiply.579, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1 + reduce.140.clone.1 = f32[2]{0} reduce(convert.196, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1 + ROOT tuple.102 = (f32[2]{0}, f32[2]{0}) tuple(reduce.139, reduce.140.clone.1) + })"; + + // Check that no loop is generated for reduction. + auto hlo_module = + ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) + .ValueOrDie(); + CompileAndVerifyIr(std::move(hlo_module), + R"( +; CHECK-LABEL: define void @fusion +; CHECK-NOT: reduce.0.loop_header +; CHECK: } +)", + /*match_optimized_ir=*/true); + // Check that the kernel runs correctly. 
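The new tests above compare results against a reference with tolerances such as ErrorSpec{1.0e-5, 1.0e-5}. Below is a minimal sketch of such a comparison, assuming the usual interpretation that an element passes when either the absolute or the relative error is within its bound; the real xla::ErrorSpec and literal comparison have additional options not shown here.

#include <algorithm>
#include <cmath>
#include <iostream>

bool WithinError(double expected, double actual, double abs_tol,
                 double rel_tol) {
  const double abs_err = std::fabs(expected - actual);
  const double rel_err = abs_err / std::max(std::fabs(expected), 1e-30);
  return abs_err <= abs_tol || rel_err <= rel_tol;
}

int main() {
  // Large magnitudes pass on the relative bound even when the absolute
  // difference exceeds 1e-5.
  std::cout << WithinError(100.0, 100.0005, 1e-5, 1e-5) << "\n";  // 1
  // Small magnitudes must satisfy one of the two bounds; this one fails both.
  std::cout << WithinError(1.0, 1.0005, 1e-5, 1e-5) << "\n";      // 0
}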
+ EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +} + TEST_F(GpuKernelTilingTest, RowReductionWithSmallDimensionNotTiled) { const char *const kHloString = R"( HloModule reduction diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc new file mode 100644 index 00000000000..13d32672a95 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc @@ -0,0 +1,373 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h" + +#include "tensorflow/compiler/xla/service/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" + +namespace xla { +namespace gpu { + +std::unique_ptr ThunkEmitter::BuildFftThunk(const HloInstruction* inst) { + const HloInstruction* operand = inst->operand(0); + return absl::make_unique( + inst->fft_type(), inst->fft_length(), + /*input_buffer=*/GetAllocationSlice(*operand), + /*output_buffer=*/GetAllocationSlice(*inst), + /*input_shape=*/operand->shape(), + /*output_shape=*/inst->shape(), inst); +} + +std::unique_ptr ThunkEmitter::BuildTriangularSolveThunk( + const HloInstruction* inst) { + const HloInstruction* a = inst->operand(0); + const HloInstruction* b = inst->operand(1); + int64 m = b->shape().dimensions(b->shape().rank() - 2); + int64 n = b->shape().dimensions(b->shape().rank() - 1); + int64 batch_size = std::accumulate( + b->shape().dimensions().begin(), b->shape().dimensions().end() - 2, + int64{1}, [](int64 a, int64 b) { return a * b; }); + int64 elem_size = + ShapeUtil::ByteSizeOfPrimitiveType(inst->shape().element_type()); + int64 a_batch_stride = inst->triangular_solve_options().left_side() + ? 
m * m * elem_size + : n * n * elem_size; + int64 b_batch_stride = m * n * elem_size; + return absl::make_unique( + inst->triangular_solve_options(), + /*a_input_buffer=*/GetAllocationSlice(*a), + /*b_input_buffer=*/GetAllocationSlice(*inst), + inst->shape().element_type(), batch_size, m, n, a_batch_stride, + b_batch_stride, inst); +} + +std::unique_ptr ThunkEmitter::BuildGemmThunk( + const HloInstruction* inst) { + auto config_or = inst->backend_config(); + GemmBackendConfig gemm_config = std::move(config_or.ValueOrDie()); + const HloInstruction* lhs = inst->operand(0); + const HloInstruction* rhs = inst->operand(1); + + // The bias is passed inside the output buffer. If those buffers are shared + // we can just use it, otherwise copy the bias values into the output buffer + // first. + if (gemm_config.beta() != 0.0) { + const HloInstruction* bias = inst->operand(2); + CHECK_EQ(bias->shape(), inst->shape()); + if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) { + std::vector> thunks; + thunks.push_back(absl::make_unique( + /*source_buffer=*/GetAllocationSlice(*bias), + /*destination_buffer=*/GetAllocationSlice(*inst), + /*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape()), nullptr)); + thunks.push_back(absl::make_unique( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + /*implements_whole_instruction=*/false, inst, + std::move(gemm_config))); + return absl::make_unique(std::move(thunks), inst); + } + } + + return absl::make_unique( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + /*implements_whole_instruction=*/true, inst, std::move(gemm_config)); +} + +std::unique_ptr ThunkEmitter::BuildInfeedThunk( + const HloInstruction* inst) { + CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); + + ShapeTree slices(inst->shape()); + slices.ForEachMutableElement( + [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { + *slice = GetAllocationSlice(*inst, index); + }); + return absl::make_unique(slices, inst); +} + +std::unique_ptr ThunkEmitter::BuildOutfeedThunk( + const HloInstruction* inst) { + CHECK_EQ(HloOpcode::kOutfeed, inst->opcode()); + + ShapeTree slices(inst->operand(0)->shape()); + slices.ForEachMutableElement([&](const ShapeIndex& index, + BufferAllocation::Slice* slice) { + auto status_or_slice = MaybeGetAllocationSlice(*inst->operand(0), index); + if (status_or_slice.ok()) { + *slice = status_or_slice.ValueOrDie(); + } + }); + return absl::make_unique(std::move(slices), inst); +} + +Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) { + // A CustomCall on the GPU backend can either be a custom-call to a + // user-supplied kernel, or a call into a library like cudnn. + + // Lower custom-calls to cudnn batchnorm ops to specialized thunks. It's part + // of the contract of these cudnn batchnorm calls that the epsilon and + // feature_index operands be constants. 
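The batch and stride arithmetic in BuildTriangularSolveThunk above is easier to see with concrete numbers: every dimension of b except the last two is a batch dimension, and the stride of the triangular operand depends on whether it multiplies from the left (m x m) or the right (n x n). A standalone sketch:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

struct SolveDims {
  int64_t batch_size;
  int64_t m, n;
  int64_t a_batch_stride_bytes;
  int64_t b_batch_stride_bytes;
};

SolveDims ComputeDims(const std::vector<int64_t>& b_dims, bool left_side,
                      int64_t elem_size) {
  const int64_t rank = static_cast<int64_t>(b_dims.size());
  const int64_t m = b_dims[rank - 2];
  const int64_t n = b_dims[rank - 1];
  // Product of all dimensions except the trailing matrix dimensions.
  const int64_t batch_size =
      std::accumulate(b_dims.begin(), b_dims.end() - 2, int64_t{1},
                      [](int64_t a, int64_t b) { return a * b; });
  return {batch_size, m, n,
          (left_side ? m * m : n * n) * elem_size,  // A is m x m or n x n.
          m * n * elem_size};
}

int main() {
  // b shape f32[2,3,64,32]: a batch of 6 solves, each with a 64x32 rhs.
  SolveDims d = ComputeDims({2, 3, 64, 32}, /*left_side=*/true, sizeof(float));
  std::cout << d.batch_size << " solves of " << d.m << "x" << d.n
            << ", a_stride=" << d.a_batch_stride_bytes
            << ", b_stride=" << d.b_batch_stride_bytes << "\n";
  // Prints: 6 solves of 64x32, a_stride=16384, b_stride=8192
}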
+ if (custom_call->custom_call_target() == + kCudnnBatchNormForwardInferenceCallTarget) { + const HloInstruction* epsilon = custom_call->operand(5); + CHECK(epsilon->IsConstant()); + float epsilon_value = epsilon->literal().Get({}); + + const HloInstruction* feature_index = custom_call->operand(6); + CHECK(feature_index->IsConstant()); + int64 feature_index_value = feature_index->literal().Get({}); + + AddThunkToThunkSequence( + absl::make_unique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*offset=*/GetAllocationSlice(*custom_call->operand(2)), + /*mean=*/GetAllocationSlice(*custom_call->operand(3)), + /*variance=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); + return Status::OK(); + } + + if (custom_call->custom_call_target() == + kCudnnBatchNormForwardTrainingCallTarget) { + const HloInstruction* epsilon = custom_call->operand(3); + CHECK(epsilon->IsConstant()); + float epsilon_value = epsilon->literal().Get({}); + + const HloInstruction* feature_index = custom_call->operand(4); + CHECK(feature_index->IsConstant()); + int64 feature_index_value = feature_index->literal().Get({}); + + // BatchNormTraining returns a tuple of three elements: data, calculated + // mean, and calculated 1/sqrt(variance + epsilon). + auto output_data = GetAllocationSlice(*custom_call, {0}); + auto output_mean = GetAllocationSlice(*custom_call, {1}); + auto output_inv_stddev = GetAllocationSlice(*custom_call, {2}); + AddThunkToThunkSequence( + absl::make_unique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*offset=*/GetAllocationSlice(*custom_call->operand(2)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_data=*/output_data, + /*output_mean=*/output_mean, + /*output_inv_stddev=*/output_inv_stddev, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); + return Status::OK(); + } + + if (custom_call->custom_call_target() == kCudnnBatchNormBackwardCallTarget) { + const HloInstruction* epsilon = custom_call->operand(5); + CHECK(epsilon->IsConstant()); + float epsilon_value = epsilon->literal().Get({}); + + const HloInstruction* feature_index = custom_call->operand(6); + CHECK(feature_index->IsConstant()); + int64 feature_index_value = feature_index->literal().Get({}); + + // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale, + // grad_offset. 
+ auto output_grad_data = GetAllocationSlice(*custom_call, {0}); + auto output_grad_scale = GetAllocationSlice(*custom_call, {1}); + auto output_grad_offset = GetAllocationSlice(*custom_call, {2}); + AddThunkToThunkSequence(absl::make_unique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*mean=*/GetAllocationSlice(*custom_call->operand(2)), + /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), + /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_grad_data=*/output_grad_data, + /*output_grad_scale=*/output_grad_scale, + /*output_grad_offset=*/output_grad_offset, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); + return Status::OK(); + } + + if (IsCustomCallToDnnConvolution(*custom_call)) { + std::vector operand_slices; + operand_slices.reserve(custom_call->operand_count()); + for (const auto* operand : custom_call->operands()) { + operand_slices.push_back(GetAllocationSlice(*operand)); + } + auto tuple_result_slice = GetAllocationSlice(*custom_call); + auto conv_result_slice = GetAllocationSlice(*custom_call, {0}); + auto scratch_slice = GetAllocationSlice(*custom_call, {1}); + + AddThunkToThunkSequence(absl::make_unique( + Cast(custom_call), std::move(operand_slices), + conv_result_slice, scratch_slice, tuple_result_slice)); + return Status::OK(); + } + + if (custom_call->custom_call_target() == kCusolverCholeskyCallTarget) { + TF_ASSIGN_OR_RETURN(CholeskyOptions options, + custom_call->backend_config()); + + const Shape& shape = custom_call->operand(0)->shape(); + int ndim = shape.dimensions_size(); + CHECK_GE(ndim, 2); + int64 n = shape.dimensions(ndim - 1); + + const auto& dims = shape.dimensions(); + int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1}, + [](int64 a, int64 b) { return a * b; }); + + auto operand_buffer = GetAllocationSlice(*custom_call->operand(0)); + + auto a_buffer = GetAllocationSlice(*custom_call, {0}); + auto workspace_buffer = GetAllocationSlice(*custom_call, {1}); + auto info_buffer = GetAllocationSlice(*custom_call, {2}); + + std::vector> thunks; + + if (operand_buffer != a_buffer) { + thunks.push_back(absl::make_unique( + /*source_address=*/operand_buffer, + /*destination_buffer=*/a_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(shape), custom_call)); + } + + thunks.push_back(absl::make_unique( + options, a_buffer, workspace_buffer, info_buffer, + custom_call->operand(0)->shape().element_type(), batch_size, n, + custom_call)); + + // Elide the sequential thunk if there's no copy. 
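The cholesky handling above, like the GEMM bias and triangular-solve cases, follows one recurring pattern: the library call works in place on the destination buffer, so a copy thunk is emitted only when operand and destination buffers differ, and the SequentialThunk wrapper is elided when a single thunk remains. A sketch of that pattern, with thunks modelled as closures instead of the real xla::gpu::Thunk classes:

#include <functional>
#include <iostream>
#include <vector>

using Thunk = std::function<void()>;

Thunk MakeInPlaceOpThunk() {
  return [] { std::cout << "run in-place library call\n"; };
}

Thunk EmitWithOptionalCopy(int source_buffer, int destination_buffer) {
  std::vector<Thunk> thunks;
  if (source_buffer != destination_buffer) {
    thunks.push_back([] { std::cout << "copy operand to destination\n"; });
  }
  thunks.push_back(MakeInPlaceOpThunk());

  if (thunks.size() == 1) return thunks[0];  // Elide the sequential wrapper.
  return [thunks] {
    for (const Thunk& t : thunks) t();       // Sequential execution.
  };
}

int main() {
  std::cout << "-- aliased buffers --\n";
  EmitWithOptionalCopy(/*source_buffer=*/7, /*destination_buffer=*/7)();
  std::cout << "-- distinct buffers --\n";
  EmitWithOptionalCopy(/*source_buffer=*/7, /*destination_buffer=*/8)();
}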
+ if (thunks.size() == 1) { + AddThunkToThunkSequence(std::move(thunks[0])); + } else { + AddThunkToThunkSequence( + absl::make_unique(std::move(thunks), custom_call)); + } + + return Status::OK(); + } + + if (IsCublasGemm(*custom_call)) { + AddThunkToThunkSequence(BuildGemmThunk(custom_call)); + return Status::OK(); + } + + if (void* call_target = CustomCallTargetRegistry::Global()->Lookup( + custom_call->custom_call_target(), platform()->Name())) { + auto get_slices_for_instr = [&](const HloInstruction* instr) { + ShapeTree slices(instr->shape()); + slices.ForEachMutableElement( + [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { + StatusOr s = + MaybeGetAllocationSlice(*instr, index); + if (s.ok()) { + *slice = s.ValueOrDie(); + } + }); + return slices; + }; + std::vector> operand_slices; + for (const auto* operand : custom_call->operands()) { + operand_slices.push_back(get_slices_for_instr(operand)); + } + ShapeTree result_slices = + get_slices_for_instr(custom_call); + AddThunkToThunkSequence(absl::make_unique( + call_target, std::move(operand_slices), std::move(result_slices), + Cast(custom_call)->opaque(), custom_call)); + return Status::OK(); + } + + return Unimplemented("No registered implementation for custom call to \"%s\"", + custom_call->custom_call_target()); +} + +Status ThunkEmitter::HandleFft(HloInstruction* fft) { + TF_RET_CHECK( + LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout())); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); + AddThunkToThunkSequence(BuildFftThunk(fft)); + return Status::OK(); +} + +Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) { + auto has_fortran_layout = [](const Layout& layout) { + int n = layout.minor_to_major_size(); + return layout.minor_to_major(0) == n - 2 && + layout.minor_to_major(1) == n - 1; + }; + TF_RET_CHECK(has_fortran_layout(hlo->operand(0)->shape().layout())); + TF_RET_CHECK(has_fortran_layout(hlo->operand(1)->shape().layout())); + TF_RET_CHECK(has_fortran_layout(hlo->shape().layout())); + + std::vector> thunks; + + // Triangular solve is in-place on 'b', so copy 'b' to the output if they + // aren't the same buffer. + auto operand_buffer = GetAllocationSlice(*hlo->operand(1)); + auto destination_buffer = GetAllocationSlice(*hlo); + if (operand_buffer != destination_buffer) { + thunks.push_back(absl::make_unique( + /*source_address=*/operand_buffer, + /*destination_buffer=*/destination_buffer, + /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape()), hlo)); + } + + thunks.push_back(BuildTriangularSolveThunk(hlo)); + + // Elide the sequential thunk if there's no copy. + if (thunks.size() == 1) { + AddThunkToThunkSequence(std::move(thunks[0])); + } else { + AddThunkToThunkSequence( + absl::make_unique(std::move(thunks), hlo)); + } + return Status::OK(); +} + +Status ThunkEmitter::HandleInfeed(HloInstruction* infeed) { + AddThunkToThunkSequence(BuildInfeedThunk(infeed)); + return Status::OK(); +} + +Status ThunkEmitter::HandleOutfeed(HloInstruction* outfeed) { + AddThunkToThunkSequence(BuildOutfeedThunk(outfeed)); + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.h b/tensorflow/compiler/xla/service/gpu/thunk_emitter.h new file mode 100644 index 00000000000..55d92c74794 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.h @@ -0,0 +1,97 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_EMITTER_H_
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace gpu {
+
+// Implements handling of GPU execution for HLO operations that are handed off
+// to specialized thunks that do not require code generation. Intended to be
+// mixed into GPU emitters.
+class ThunkEmitter {
+ public:
+  class EmissionContext {
+   public:
+    virtual void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) = 0;
+    virtual StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
+        const HloInstruction& hlo, const ShapeIndex& index) const = 0;
+    virtual int64 ByteSizeOf(const Shape& shape) const = 0;
+    virtual const se::Platform* platform() const = 0;
+
+    virtual ~EmissionContext() = default;
+  };
+
+  explicit ThunkEmitter(EmissionContext* context) : context_(context) {}
+
+  Status HandleCustomCall(HloInstruction* custom_call);
+  Status HandleFft(HloInstruction* fft);
+  Status HandleTriangularSolve(HloInstruction* hlo);
+  Status HandleInfeed(HloInstruction* xla_infeed);
+  Status HandleOutfeed(HloInstruction* outfeed);
+
+ private:
+  EmissionContext* context_;
+
+  void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
+    return context_->AddThunkToThunkSequence(std::move(thunk));
+  }
+
+  StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
+      const HloInstruction& hlo, const ShapeIndex& index) const {
+    return context_->MaybeGetAllocationSlice(hlo, index);
+  }
+
+  int64 ByteSizeOf(const Shape& shape) { return context_->ByteSizeOf(shape); }
+
+  const se::Platform* platform() const { return context_->platform(); }
+
+  BufferAllocation::Slice GetAllocationSlice(
+      const HloInstruction& hlo, const ShapeIndex& index = {}) const {
+    return MaybeGetAllocationSlice(hlo, index).ValueOrDie();
+  }
+
+  // Returns a FftThunk that calls cuFFT to implement `inst`.
+  std::unique_ptr<Thunk> BuildFftThunk(const HloInstruction* inst);
+
+  // Returns a CholeskyThunk that calls cuSolver to implement `inst`.
+  std::unique_ptr<Thunk> BuildCholeskyThunk(const HloInstruction* inst);
+
+  // Returns a TriangularSolveThunk that calls cuBlas to implement `inst`.
+  std::unique_ptr<Thunk> BuildTriangularSolveThunk(const HloInstruction* inst);
+
+  // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
+  // to make sure `inst` outlives the lifetime of the returned Thunk object.
+  std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);
+
+  // Returns an InfeedThunk that performs a host-to-device memcpy to implement
+  // `inst`.
+  std::unique_ptr<Thunk> BuildInfeedThunk(const HloInstruction* inst);
+
+  // Returns an OutfeedThunk that performs a device-to-host memcpy to implement
+  // `inst`.
+ std::unique_ptr BuildOutfeedThunk(const HloInstruction* inst); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 83894f17445..8d9ddb97d9e 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -22,6 +22,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_live_range.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -29,199 +31,6 @@ namespace xla { using absl::flat_hash_map; using absl::flat_hash_set; -namespace { -// FlattenSchedule walks through the instruction, and recurse into each called -// computations. As it walks it also tracks down the ordinal number of each -// instruction in the schedule and store it in the `instruction_schedule`. The -// end of each computation is tracked in `computation_schedule`. -int64 FlattenSchedule( - const HloComputation& computation, - const HloInstructionSequence& instruction_sequence, - const HloSchedule* schedule, int64 start_time, - absl::flat_hash_map* instruction_schedule, - absl::flat_hash_map* computation_schedule) { - int64 time = start_time; - for (const HloInstruction* instruction : - instruction_sequence.instructions()) { - if (schedule != nullptr) { - // Recurse into sub computations if we have a module-scoped schedule. - if (instruction->opcode() == HloOpcode::kCall || - instruction->opcode() == HloOpcode::kConditional) { - for (const HloComputation* called_computation : - instruction->called_computations()) { - const HloInstructionSequence& called_sequence = - schedule->sequence(called_computation); - time = - FlattenSchedule(*called_computation, called_sequence, schedule, - time, instruction_schedule, computation_schedule); - computation_schedule->insert({called_computation, time}); - } - } - if (instruction->opcode() == HloOpcode::kWhile) { - const HloInstructionSequence& condition_sequence = - schedule->sequence(instruction->while_condition()); - time = FlattenSchedule(*instruction->while_condition(), - condition_sequence, schedule, time, - instruction_schedule, computation_schedule); - computation_schedule->insert({instruction->while_condition(), time}); - const HloInstructionSequence& body_sequence = - schedule->sequence(instruction->while_body()); - time = - FlattenSchedule(*instruction->while_body(), body_sequence, schedule, - time, instruction_schedule, computation_schedule); - } - } - if (instruction_schedule->count(instruction) != 0) { - continue; - } - instruction_schedule->insert({instruction, time++}); - } - computation_schedule->insert({&computation, time}); - return time; -} - -// The aliased buffers could have overlapping live ranges. -// NormalizeAliasedBuffers normalizes the buffer such that each alias buffer has -// disjoint live range while keeping the live range union the same. This avoid -// double counting aliased buffer sizes. 
-// -// Before(buffer1 and 2 are aliased): -// -// +----+ live range of buffer1 -// +------------------+ live range of buffer2 -// -// After: -// -// +----------+ live range of buffer1 -// +------+ live range of buffer2 -// -// Before(buffer1 and 2 are aliased): -// -// +----------+ live range of buffer1 -// +------------+ live range of buffer2 -// -// After: -// -// +----------+ live range of buffer1 -// +------+ live range of buffer2 -// -// Before(buffer1 and 2 are aliased): -// -// +----------+ live range of buffer1 -// +---+ live range of buffer2 -// -// After(unchanged): -// -// +----------+ live range of buffer1 -// +---+ live range of buffer2 -// -// As another example, imagine we have the following code sequence with live -// ranges of each while-aliased buffers: -// -// a p1 p2 e b -// a = ... + -// | -// { | -// p1 = param | + -// ROOT true | | -// } | + -// { // body | -// p2 = param + + -// c = p2 + 1 + -// d = c + 1 -// ROOT e = d + 1 + -// } | -// | -// b = while (a) + + -// | -// f = b + 1 + -// -// After normalization it becomes: -// -// a p1 p2 e b -// a = ... + -// | -// { + -// p1 = param + -// ROOT true | -// } + -// { // body -// p2 = param + -// c = p2 + 1 + -// d = c + 1 -// ROOT e = d + 1 + -// } | -// | -// b = while (a) + -// + -// f = b + 1 + -// -// Note there is no overlap of live ranges after normalization. -void NormalizeAliasedBuffers( - absl::flat_hash_map* buffer_start_map, - absl::flat_hash_map* buffer_end_map, - const std::vector& values_to_assign, - const HloAliasAnalysis& alias_analysis) { - absl::flat_hash_set values_to_assign_set( - values_to_assign.begin(), values_to_assign.end()); - for (const HloBuffer& hlo_buffer : alias_analysis.buffers()) { - std::vector aliased_buffers; - for (const HloValue* hlo_value : hlo_buffer.values()) { - if (values_to_assign_set.count(hlo_value) != 0) { - aliased_buffers.push_back(hlo_value); - CHECK_NE(buffer_start_map->count(hlo_value), 0); - CHECK_NE(buffer_end_map->count(hlo_value), 0); - } - } - absl::c_sort( - aliased_buffers, [&](const HloValue* value1, const HloValue* value2) { - if ((*buffer_start_map)[value1] != (*buffer_start_map)[value2]) { - return (*buffer_start_map)[value1] < (*buffer_start_map)[value2]; - } - return (*buffer_end_map)[value1] < (*buffer_end_map)[value2]; - }); - - for (int64 i = 0; i < aliased_buffers.size(); ++i) { - // We can't use aliased_buffers.size() - 1 since aliased_buffers.size() is - // an unsigned integer and can be 0. - if (i + 1 == aliased_buffers.size()) { - break; - } - - const HloValue* value1 = aliased_buffers[i]; - const HloValue* value2 = aliased_buffers[i + 1]; - if ((*buffer_start_map)[value1] == (*buffer_start_map)[value2]) { - // If value1 has the same start time as value2, make value1 disappear by - // setting the end time same as start time: - // - // Before: - // +----+ value1 - // +----------+ value2 - // - // After: - // + value1 - // +----------+ value2 - // - // Note that only when heap simulator runs before copy insertion can - // this happen where one instruction defines multiple aliased buffers -- - // This is illegle to execute and can be fixed by copy insertion later. 
- (*buffer_end_map)[value1] = (*buffer_start_map)[value1]; - continue; - } - - if ((*buffer_end_map)[value1] < (*buffer_start_map)[value2]) { - continue; - } - - if ((*buffer_end_map)[value1] > (*buffer_end_map)[value2]) { - (*buffer_end_map)[value2] = (*buffer_end_map)[value1]; - } - (*buffer_end_map)[value1] = (*buffer_start_map)[value2] - 1; - } - } -} -} // namespace - /*static*/ StatusOr HeapSimulator::MinimumMemoryForModule( const HloSchedule& schedule, @@ -283,8 +92,12 @@ StatusOr HeapSimulator::Run( const HloComputation* entry_computation = module.entry_computation(); const HloInstructionSequence& instruction_sequence = schedule.sequence(entry_computation); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_live_range, + HloLiveRange::Run(schedule, alias_analysis, entry_computation)); TF_RETURN_IF_ERROR(heap.RunComputation(*entry_computation, - instruction_sequence, alias_analysis)); + instruction_sequence, alias_analysis, + hlo_live_range.get())); return heap.Finish(); } @@ -298,8 +111,13 @@ StatusOr HeapSimulator::Run( memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, /*schedule=*/nullptr, memory_by_computation); - TF_RETURN_IF_ERROR( - heap.RunComputation(computation, instruction_sequence, alias_analysis)); + HloSchedule schedule(computation.parent()); + schedule.set_sequence(&computation, instruction_sequence); + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, + HloLiveRange::Run(schedule, alias_analysis, &computation, + /*module_scoped_analysis=*/false)); + TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, + alias_analysis, hlo_live_range.get())); return heap.Finish(); } @@ -312,8 +130,11 @@ StatusOr HeapSimulator::Run( const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, /*schedule=*/schedule, nullptr); - TF_RETURN_IF_ERROR( - heap.RunComputation(computation, instruction_sequence, alias_analysis)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_live_range, + HloLiveRange::Run(*schedule, alias_analysis, &computation)); + TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, + alias_analysis, hlo_live_range.get())); return heap.Finish(); } @@ -322,36 +143,24 @@ StatusOr HeapSimulator::Run( Status HeapSimulator::RunComputation( const HloComputation& computation, const HloInstructionSequence& instruction_sequence, - const HloAliasAnalysis& alias_analysis) { + const HloAliasAnalysis& alias_analysis, HloLiveRange* hlo_live_range) { XLA_VLOG_LINES(1, computation.parent()->ToString()); XLA_VLOG_LINES(2, computation.ToString()); + VLOG(1) << hlo_live_range->ToString(); + HloDataflowAnalysis& dataflow_analysis = alias_analysis.dataflow_analysis(); - // instruction_schedule and computation_schedule are the maps that track each - // instruction/computation and their ordinal in the schedule. - absl::flat_hash_map instruction_schedule; - absl::flat_hash_map computation_schedule; - - // program_end_time is the time of the last instruction scheduled. It is equal - // to the number of instructions in a computation. - int64 program_end_time = - FlattenSchedule(computation, instruction_sequence, schedule_, 0, - &instruction_schedule, &computation_schedule); - - VLOG(1) << "Program end time: " << program_end_time; - - // We track the definition and free events for each buffer, then we go through - // each step and reply those events in program order. 
- absl::flat_hash_map buffer_start_map; - absl::flat_hash_map buffer_end_map; + algorithm_->SetSchedules(&hlo_live_range->flattened_instruction_sequence(), + &hlo_live_range->instruction_schedule()); // Record the buffer define/free event for each time step. We free all // remaining buffers (entry parameter, etc) after the program has finished // running, so we set the size of to program_end_time + 1. - std::vector> buffers_defined(program_end_time + - 1); - std::vector> buffers_freed(program_end_time + 1); + std::vector> buffers_defined( + hlo_live_range->schedule_end_time() + 1); + std::vector> buffers_freed( + hlo_live_range->schedule_end_time() + 1); // values_to_assign tracks the HloValues that we need to assign a buffer to. // Note that we only need to assign a buffer to a value when both of the @@ -364,106 +173,49 @@ Status HeapSimulator::RunComputation( // - If the instruction is in a nested call of the current computation, only // assign a buffer if we are doing global heap simulation. std::vector values_to_assign; + values_to_assign.reserve(dataflow_analysis.values().size()); - // Keeps track of buffer start time and buffer end time. for (const HloValue* value : dataflow_analysis.values()) { - // Ignore buffers that are not defined. - if (instruction_schedule.count(value->defining_instruction()) == 0) { + // Ignore buffers that are not tracked. + if (hlo_live_range->instruction_schedule().count( + value->defining_instruction()) == 0) { continue; } if (IgnoreBuffer(value)) { continue; } values_to_assign.push_back(value); - int64 buffer_start_time = instruction_schedule[value->instruction()]; - - int64 buffer_end_time = -1; - // A buffer's live range ends when the last user finishes executing. - for (const HloUse& use : value->uses()) { - const HloInstruction* used = use.instruction; - // As an optimization, we deem a while's init value's live range ends as - // soon as the loop body starts. This optimization is only applicable to - // the whole module simulation. - if (schedule_ != nullptr && used->opcode() == HloOpcode::kWhile) { - // The current live range is at the end of the while, move it to the - // beginning of the body. - used = used->while_body()->parameter_instruction(0); - VLOG(1) << "Moved value " << value->ToShortString() - << " to while param: " << used->ToString(); - } - if (instruction_schedule.count(used) == 0) { - // We didn't track the instruction `used`. This happens when we do - // computation scope (versus module scope) heap simulation and when the - // used instruction is outside of the computation being simulated. - continue; - } - buffer_end_time = std::max(buffer_end_time, instruction_schedule[used]); - } - - if (buffer_end_time == -1) { - buffer_end_time = buffer_start_time; - } - - for (const HloPosition& position : value->positions()) { - const HloComputation* position_comp = position.instruction->parent(); - // If this instruction lives out, the live range of the instruction should - // be extended to the end of the computation. - if (position.instruction == position_comp->root_instruction()) { - if (schedule_ == nullptr && &computation != position_comp) { - continue; - } - if (computation_schedule.count(position_comp) == 0) { - continue; - } - buffer_end_time = - std::max(buffer_end_time, computation_schedule[position_comp]); - } - } - - // Entry parameters live across whole computation. 
- if (value->instruction()->opcode() == HloOpcode::kParameter && - value->instruction()->parent() == - computation.parent()->entry_computation()) { - buffer_end_time = program_end_time; - } - - CHECK(buffer_start_time <= buffer_end_time); - - buffer_start_map[value] = buffer_start_time; - buffer_end_map[value] = buffer_end_time; } - NormalizeAliasedBuffers(&buffer_start_map, &buffer_end_map, values_to_assign, - alias_analysis); + auto& buffer_live_ranges = hlo_live_range->buffer_live_ranges(); absl::c_sort(values_to_assign, [&](const HloValue* value1, const HloValue* value2) { - if (buffer_start_map[value1] != buffer_start_map[value2]) { - return buffer_start_map[value1] < buffer_start_map[value2]; - } - - if (buffer_end_map[value1] != buffer_end_map[value2]) { - return buffer_end_map[value1] < buffer_end_map[value2]; - } - return value1->id() < value2->id(); + const auto& live_range1 = buffer_live_ranges.at(value1); + const auto& live_range2 = buffer_live_ranges.at(value2); + return std::forward_as_tuple(live_range1.start, + live_range1.end, value1->id()) < + std::forward_as_tuple(live_range2.start, + live_range2.end, value2->id()); }); // For each value that we need to assign a buffer to, add the define and free // events. for (const HloValue* value : values_to_assign) { - buffers_defined[buffer_start_map[value]].push_back(value); - buffers_freed[buffer_end_map[value]].push_back(value); + auto live_range = buffer_live_ranges.at(value); + buffers_defined[live_range.start].push_back(value); + buffers_freed[live_range.end].push_back(value); } // All HloValues in a hlo buffer should be allocated to the same address. This // map tracks the first value that got allocated in a buffer. absl::flat_hash_map first_allocated_value; - VLOG(1) << "Program time" << program_end_time; + VLOG(1) << "Program time" << hlo_live_range->schedule_end_time(); // Go through each step in the program and replay each buffer define and free // events. - for (int64 i = 0; i < program_end_time + 1; ++i) { + for (int64 i = 0; i < hlo_live_range->schedule_end_time() + 1; ++i) { VLOG(1) << "Time step: " << i; for (const HloValue* value : buffers_defined[i]) { @@ -495,11 +247,21 @@ Status HeapSimulator::RunComputation( if (operand_buffer->values().size() > 1) { continue; } - if (buffer_end_map.count(operand_value) == 0) { + auto it = buffer_live_ranges.find(operand_value); + if (it == buffer_live_ranges.end()) { continue; } + + auto& operand_live_range = it->second; + + auto& user_live_range = buffer_live_ranges[value]; + // Can only share buffers that are about to be freed. - if (buffer_end_map[operand_value] != i) { + if (operand_live_range.end != i) { + continue; + } + + if (IgnoreBuffer(operand_value)) { continue; } @@ -522,7 +284,7 @@ Status HeapSimulator::RunComputation( ShareBuffer(value, operand_value, value->instruction()); // The live range of the operand buffer is now extended to the end // of the current instruction. - buffer_end_map[operand_value] = buffer_end_map[value]; + operand_live_range.end = user_live_range.end; VLOG(1) << "Sharing " << value->ToShortString() << " with " << operand_value->ToShortString() << ", size:" << size_fn_(*value); @@ -866,29 +628,27 @@ GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const { // start of the first buffer and the end of the last co-located // buffer. There could be "holes" in the live ranges of each co-located // buffers, but in this heuristics we think they are contiguous. 
- absl::c_sort(sorted_buffer_intervals, - [&](const BufferInterval& x, const BufferInterval& y) { - int64 x_end = x.end; - for (auto colocation : GetTransitiveColocations(x)) { - x_end = - std::max(x_end, buffer_intervals_.at(colocation).end); - } + absl::c_sort(sorted_buffer_intervals, [&](const BufferInterval& x, + const BufferInterval& y) { + int64 x_end = x.end; + for (auto colocation : GetTransitiveColocations(x)) { + x_end = std::max(x_end, buffer_intervals_.at(colocation).end); + } - int64 y_end = y.end; - for (auto colocation : GetTransitiveColocations(y)) { - y_end = - std::max(y_end, buffer_intervals_.at(colocation).end); - } + int64 y_end = y.end; + for (auto colocation : GetTransitiveColocations(y)) { + y_end = std::max(y_end, buffer_intervals_.at(colocation).end); + } - if (x_end - x.start != y_end - y.start) { - return x_end - x.start > y_end - y.start; - } + if (x_end - x.start != y_end - y.start) { + return x_end - x.start > y_end - y.start; + } - if (x.size != y.size) { - return x.size > y.size; - } - return x.buffer->id() < y.buffer->id(); - }); + if (x.size != y.size) { + return x.size > y.size; + } + return x.buffer->id() < y.buffer->id(); + }); } else { // Sort by spatial size. We don't look at co-locates as they should have the // same size. @@ -910,8 +670,8 @@ GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const { GlobalDecreasingSizeBestFitHeap::ChunkCandidate GlobalDecreasingSizeBestFitHeap::FindChunkCandidate( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval) - const { + const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval, + int64 preferred_offset) const { VLOG(1) << "Finding chunks for buffer: " << buffer_interval.buffer->ToString(); VLOG(1) << "Size " << buffer_interval.size << ", start " @@ -960,7 +720,16 @@ GlobalDecreasingSizeBestFitHeap::FindChunkCandidate( return; } - if (free_size < min_fit_chunk.size) { + // If a preferred offset is provided, pick that offset. + if (free_offset <= preferred_offset && + free_offset + free_size >= preferred_offset + buffer_interval.size) { + min_fit_chunk = {preferred_offset, buffer_interval.size}; + } + + // Pick the min-fit chunk only if we didn't have a preferred offset or a + // chunk at the preferred offset hasn't been found. + if ((preferred_offset < 0 || min_fit_chunk.offset != preferred_offset) && + free_size < min_fit_chunk.size) { min_fit_chunk = {free_offset, free_size}; } }; @@ -973,6 +742,12 @@ GlobalDecreasingSizeBestFitHeap::FindChunkCandidate( offset = std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_)); } use_free_chunk_if_smaller(offset, result_.heap_size - offset); + // When preferred offset is provided and the preferred offset is larger than + // the current heap size, simply use the preferred offset provided. + if (result_.heap_size <= preferred_offset) { + chunk_candidate.heap_size = preferred_offset + buffer_interval.size; + min_fit_chunk = {preferred_offset, buffer_interval.size}; + } if (min_fit_chunk.offset == -1) { // Increase the heap size to fit in the last free chunk. 
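The new preferred_offset path above lets a caller ask for a specific placement and fall back to the normal best-fit search when that offset is not available. A minimal call sketch from inside an algorithm derived from GlobalDecreasingSizeBestFitHeap; the `interval` value and the `max_heap_size` bound are illustrative and not taken from this patch:

  // Try to place `interval` at offset 1024, falling back to best fit.
  ChunkCandidate candidate =
      FindChunkCandidate(interval, /*preferred_offset=*/1024);
  if (candidate.heap_size <= max_heap_size) {  // max_heap_size is assumed.
    CommitChunk(interval, candidate);
  }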
@@ -993,16 +768,18 @@ void GlobalDecreasingSizeBestFitHeap::CommitChunk( interval_tree_.Add(buffer_interval.start, buffer_interval.end, chunk_candidate.chunk); for (auto colocation : GetTransitiveColocations(buffer_interval)) { - const auto emplace_result = - result_.chunk_map.emplace(colocation, chunk_candidate.chunk); - DCHECK(emplace_result.second); + AddToChunkMap(colocation, chunk_candidate.chunk); auto colocation_interval = buffer_intervals_[colocation]; interval_tree_.Add(colocation_interval.start, colocation_interval.end, chunk_candidate.chunk); } - const auto emplace_result = - result_.chunk_map.emplace(buffer_interval.buffer, chunk_candidate.chunk); + AddToChunkMap(buffer_interval.buffer, chunk_candidate.chunk); +} + +void GlobalDecreasingSizeBestFitHeap::AddToChunkMap(const HloValue* buffer, + Chunk chunk) { + const auto emplace_result = result_.chunk_map.emplace(buffer, chunk); DCHECK(emplace_result.second); } diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 4d6de377813..00a748fc1e1 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_live_range.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -165,7 +166,8 @@ class HeapSimulator { Status RunComputation(const HloComputation& computation, const HloInstructionSequence& instruction_sequence, - const HloAliasAnalysis& alias_analysis); + const HloAliasAnalysis& alias_analysis, + HloLiveRange* live_range); bool IgnoreBuffer(const HloValue* buffer) const; void Alloc(const HloValue* buffer, const HloInstruction* instruction); @@ -255,6 +257,22 @@ class HeapAlgorithm { // Finish collects the buffer offset assignment results. Free may only be // called once, after the Alloc and Free calls. virtual Result Finish() = 0; + + // Heap algorithms can optionally make use of the instruction/computation + // schedule. These data structures are guaranteed to be valid while Finish() + // is being called. + virtual void SetSchedules( + const HloInstructionSequence* flattened_instruction_sequence, + const absl::flat_hash_map* + instruction_schedule) { + flattened_instruction_sequence_ = flattened_instruction_sequence; + instruction_schedule_ = instruction_schedule; + } + + protected: + const HloInstructionSequence* flattened_instruction_sequence_; + const absl::flat_hash_map* + instruction_schedule_; }; // NoFragmentationStatsHeap computes the heap size assuming no fragmentation; @@ -370,19 +388,24 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // These two methods below are exposed to other heap algorithms that inherit // from this class. The Finish() method tries to find a candidate chunk for - // each BufferInterval, after calling GetSortedBufferIntervals. The - // ChunkCandidate returns the chunk and the final heap size if it chunk is to - // be committed. The Finish() method can then call CommitChunk to associate - // the chunk with the BufferInterval, if the final heap size is within the - // limits. 
- ChunkCandidate FindChunkCandidate( - const BufferInterval& buffer_interval) const; + // each BufferInterval, after calling GetSortedBufferIntervals. If a + // non-negative preferred_offset is provided, FindChunkCandidate attempts + // finding a chunk at this offset. The ChunkCandidate returns the chunk and + // the final heap size if it chunk is to be committed. The Finish() method can + // then call CommitChunk to associate the chunk with the BufferInterval, if + // the final heap size is within the limits. + ChunkCandidate FindChunkCandidate(const BufferInterval& buffer_interval, + int64 preferred_offset = -1) const; void CommitChunk(const BufferInterval& buffer_interval, ChunkCandidate chunk_candidate); + // Adds the buffer and the chunk to the result chunk map. + virtual void AddToChunkMap(const HloValue* buffer, Chunk chunk); + + absl::flat_hash_map buffer_intervals_; + Result result_; private: int64 alignment_; - Result result_; Type type_; // The current time represented as an integer. It increments by 1 at each @@ -396,7 +419,6 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // returns all three of them. absl::flat_hash_set GetTransitiveColocations( const BufferInterval& interval) const; - absl::flat_hash_map buffer_intervals_; }; // A heap algorithm that chooses the best results from other algorithms added to diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 4f7daa84782..80a047142b4 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -442,8 +442,8 @@ TEST_F(HeapSimulatorTest, MultiplyAdd) { tracker.ExpectCallSequence({ {kAlloc, tracker.BufferAt(paramA, {})}, {kAlloc, tracker.BufferAt(paramX, {})}, - {kAlloc, tracker.BufferAt(mul, {})}, {kAlloc, tracker.BufferAt(paramY, {})}, + {kAlloc, tracker.BufferAt(mul, {})}, {kFree, tracker.BufferAt(mul, {})}, {kShare, tracker.BufferAt(add, {})}, // All params and outputs are freed at the end. @@ -516,8 +516,8 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { tracker.ExpectCallSequence({ {kAlloc, tracker.BufferAt(paramA, {})}, {kAlloc, tracker.BufferAt(paramX, {})}, - {kAlloc, tracker.BufferAt(mul, {})}, {kAlloc, tracker.BufferAt(paramY, {})}, + {kAlloc, tracker.BufferAt(mul, {})}, {kAlloc, tracker.BufferAt(dot, {})}, // All params and outputs are freed at the end. 
{kFree, tracker.BufferAt(mul, {})}, @@ -554,8 +554,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { tracker.ExpectCallSequence({ {kAlloc, tracker.BufferAt(paramA, {})}, {kAlloc, tracker.BufferAt(paramX, {})}, - {kAlloc, tracker.BufferAt(mul, {})}, {kAlloc, tracker.BufferAt(paramY, {})}, + {kAlloc, tracker.BufferAt(mul, {})}, {kAlloc, tracker.BufferAt(dot, {})}, {kFree, tracker.BufferAt(mul, {})}, {kFree, tracker.BufferAt(dot, {})}, @@ -596,8 +596,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { tracker.ExpectCallSequence({ {kAlloc, tracker.BufferAt(paramA, {})}, {kAlloc, tracker.BufferAt(paramX, {})}, - {kAlloc, tracker.BufferAt(mul, {})}, {kAlloc, tracker.BufferAt(paramY, {})}, + {kAlloc, tracker.BufferAt(mul, {})}, {kAlloc, tracker.BufferAt(dot0, {})}, {kFree, tracker.BufferAt(mul, {})}, // mul no longer used {kAlloc, tracker.BufferAt(dot1, {})}, @@ -640,8 +640,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { tracker.ExpectCallSequence({ {kAlloc, tracker.BufferAt(paramA, {})}, {kAlloc, tracker.BufferAt(paramX, {})}, - {kAlloc, tracker.BufferAt(mul, {})}, {kAlloc, tracker.BufferAt(paramY, {})}, + {kAlloc, tracker.BufferAt(mul, {})}, {kAlloc, tracker.BufferAt(dot0, {})}, {kFree, tracker.BufferAt(mul, {})}, // mul no longer used {kAlloc, tracker.BufferAt(dot1, {})}, diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 331bbcb7836..61e562c7eda 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 67 +// Next ID: 69 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -230,6 +230,13 @@ message HloInstructionProto { // The delta value for kRngGetAndUpdateState. int64 delta = 66; + + // Specifies if the gather/scatter indices are guaranteed to be sorted by the + // caller. + bool indices_are_sorted = 67; + + // Frontend attributes to pass to the XLA backend. + xla.FrontendAttributes frontend_attributes = 68; } // Serialization of HloComputation. 
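Both new fields are plumbed through the HloInstruction API later in this diff (set_frontend_attributes / frontend_attributes, and the indices_are_sorted arguments to CreateGather and CreateScatter). A minimal sketch of tagging an instruction with a frontend attribute; the key, value, and `instr` pointer are illustrative only:

  xla::FrontendAttributes attrs;
  (*attrs.mutable_map())["placement_hint"] = "fast_memory";  // assumed key/value
  instr->set_frontend_attributes(attrs);
  // The map round-trips through HloInstructionProto.frontend_attributes
  // (field 68) and prints as frontend_attributes={placement_hint=fast_memory}.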
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 0c020daec30..1ef007cc817 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -1008,8 +1008,8 @@ TEST_F(HloAliasAnalysisTest, Bitcast) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( - scalar_shape_, HloOpcode::kBitcast, constant)); + auto bitcast = builder.AddInstruction( + HloInstruction::CreateBitcast(scalar_shape_, constant)); module_->AddEntryComputation(builder.Build()); SCOPED_TRACE(module_->ToString()); @@ -1076,8 +1076,8 @@ TEST_F(HloAliasAnalysisTest, BitcastInterference) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( - scalar_shape_, HloOpcode::kBitcast, constant)); + auto bitcast = builder.AddInstruction( + HloInstruction::CreateBitcast(scalar_shape_, constant)); builder.AddInstruction(HloInstruction::CreateTuple({constant, bitcast})); module_->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 639e853ada7..cbdada0b46b 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -532,11 +532,12 @@ string HloComputation::ToString( if (options.print_percent()) { s << "%"; } - s << name() << " "; + s << PrintName(name(), options.print_ids()) << " "; } if (options.print_program_shape()) { - s << ShapeUtil::HumanString(ComputeProgramShape()) << " "; + s << ShapeUtil::HumanString(ComputeProgramShape(options.print_ids())) + << " "; } s << "{\n"; { @@ -753,12 +754,13 @@ StatusOr HloComputation::DeepCopyInstructionWithCustomCopier( return DeepCopyHelper(instruction, &index, copy_leaf); } -ProgramShape HloComputation::ComputeProgramShape() const { +ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const { ProgramShape program_shape; for (auto* param_instruction : param_instructions_) { *program_shape.add_parameters() = param_instruction->shape(); - *program_shape.add_parameter_names() = param_instruction->name(); + *program_shape.add_parameter_names() = + PrintName(param_instruction->name(), include_ids); } *program_shape.mutable_result() = root_instruction_->shape(); @@ -835,6 +837,18 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, if (new_instruction->metadata().op_name().empty()) { new_instruction->set_metadata(old_instruction->metadata()); } + if (new_instruction->frontend_attributes().map().empty()) { + new_instruction->set_frontend_attributes( + old_instruction->frontend_attributes()); + } + + // Like the metadata above, if the user didn't specify any sharding + // information on the new instruction we should copy the old sharding + // information (if any). 
+ if (!new_instruction->has_sharding()) { + new_instruction->set_sharding(old_instruction->sharding_ptr()); + } + TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction)); return RemoveInstructionAndUnusedOperands(old_instruction); } @@ -856,25 +870,6 @@ std::vector HloComputation::CollectUnreachableRoots() const { return unreachable_roots; } -template -Status HloComputation::Accept( - DfsHloVisitorBase* visitor) const { - // Visit unreachable roots. Beware that the visitor might delete the currently - // visited root, which would invalidate iterators if the unreachable roots - // weren't computed ahead of time. - for (HloInstruction* root : CollectUnreachableRoots()) { - VLOG(3) << "Traversing unreachable root: " << root->ToString(); - // Call FinishVisit only at the end. - TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false)); - } - // Visit the computation root instruction last. - return root_instruction()->Accept(visitor, /*call_finish_visit=*/true); -} - -// Explicit instantiations. -template Status HloComputation::Accept(DfsHloVisitor* visitor) const; -template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const; - Status HloComputation::AcceptWithOperandOrder( DfsHloVisitor* visitor, const HloInstruction::CompareFunction& operand_order) const { @@ -891,42 +886,6 @@ Status HloComputation::AcceptWithOperandOrder( /*call_finish_visit=*/true); } -template -Status HloComputation::AcceptOrdered( - DfsHloVisitorBase* visitor, - absl::Span order) const { - VLOG(3) << "Accepting visitor with order."; - for (HloInstruction* root : CollectUnreachableRoots()) { - TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString(); - } - TF_RET_CHECK(order.size() == instruction_count()); - absl::flat_hash_set visited; - for (const HloInstruction* instruction : order) { - VLOG(3) << "Visiting ordered: " << instruction->ToString(); - TF_RET_CHECK(instruction_iterators_.contains(instruction)) - << "Instruction " << instruction->name() << " is not in computation " - << name(); - TF_RET_CHECK(!visited.contains(instruction)) - << "Instruction " << instruction->name() - << " appears more than once in order"; - HloInstruction* mutable_instruction = - const_cast(instruction); - TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction)); - TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor)); - visitor->SetVisited(*mutable_instruction); - TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction)); - visited.insert(instruction); - } - TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction())); - return Status::OK(); -} - -// Explicit instantiations. -template Status HloComputation::AcceptOrdered( - DfsHloVisitor*, absl::Span) const; -template Status HloComputation::AcceptOrdered( - ConstDfsHloVisitor*, absl::Span) const; - std::unique_ptr HloComputation::Clone( const string& suffix, HloCloneContext* context) { return CloneWithReplacements( diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 111b28a8610..34ff957c876 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -288,7 +288,7 @@ class HloComputation { // Computes and returns the ProgramShape of this computation (shape of // parameters and result with layout). - ProgramShape ComputeProgramShape() const; + ProgramShape ComputeProgramShape(bool include_ids = true) const; // Return whether `*this` and `other` are functionally equivalent. 
bool Equal(const HloComputation& other, bool is_layout_sensitive) const; @@ -314,6 +314,8 @@ class HloComputation { // Replace old instruction with new instruction. Updates uses and root // instruction. Removes old instruction from computation. Precondition: // old_instruction and new_instruction must have the compatible shapes. + // If |new_instruction| doesn't have any sharding information it will + // recieve the sharding information of |old_instruction|. Status ReplaceInstruction(HloInstruction* old_instruction, HloInstruction* new_instruction); @@ -511,6 +513,61 @@ class HloComputation { TF_DISALLOW_COPY_AND_ASSIGN(HloComputation); }; +template +Status HloComputation::Accept( + DfsHloVisitorBase* visitor) const { + // Visit unreachable roots. Beware that the visitor might delete the currently + // visited root, which would invalidate iterators if the unreachable roots + // weren't computed ahead of time. + for (HloInstruction* root : CollectUnreachableRoots()) { + VLOG(3) << "Traversing unreachable root: " << root->ToString(); + // Call FinishVisit only at the end. + TF_RETURN_IF_ERROR(root->Accept(visitor, /*call_finish_visit=*/false)); + } + // Visit the computation root instruction last. + return root_instruction()->Accept(visitor, /*call_finish_visit=*/true); +} + +// Explicit instantiations. +template Status HloComputation::Accept(DfsHloVisitor* visitor) const; +template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const; + +template +Status HloComputation::AcceptOrdered( + DfsHloVisitorBase* visitor, + absl::Span order) const { + VLOG(3) << "Accepting visitor with order."; + for (HloInstruction* root : CollectUnreachableRoots()) { + TF_RET_CHECK(absl::c_linear_search(order, root)) << root->ToString(); + } + TF_RET_CHECK(order.size() == instruction_count()); + absl::flat_hash_set visited; + for (const HloInstruction* instruction : order) { + VLOG(3) << "Visiting ordered: " << instruction->ToString(); + TF_RET_CHECK(instruction_iterators_.contains(instruction)) + << "Instruction " << instruction->name() << " is not in computation " + << name(); + TF_RET_CHECK(!visited.contains(instruction)) + << "Instruction " << instruction->name() + << " appears more than once in order"; + HloInstruction* mutable_instruction = + const_cast(instruction); + TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction)); + TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor)); + visitor->SetVisited(*mutable_instruction); + TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction)); + visited.insert(instruction); + } + TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction())); + return Status::OK(); +} + +// Explicit instantiations. +template Status HloComputation::AcceptOrdered( + DfsHloVisitor*, absl::Span) const; +template Status HloComputation::AcceptOrdered( + ConstDfsHloVisitor*, absl::Span) const; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 311b8a15504..90af8b1f487 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -154,6 +154,12 @@ int64 HloCostAnalysis::FusionParameterReadBytes( size += hlo == user->operand(0) ? GetShapeSize(user->shape()) : GetShapeSize(hlo->shape()); break; + case HloOpcode::kDynamicUpdateSlice: + // Uses the same shape as 'update' which is operand 1. + size += hlo == user->operand(0) + ? 
GetShapeSize(user->operand(1)->shape()) + : GetShapeSize(hlo->shape()); + break; case HloOpcode::kBroadcast: case HloOpcode::kReshape: size += GetShapeSize(hlo->shape()); @@ -699,7 +705,7 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { if (fusion->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { current_properties_[kBytesAccessedKey] += GetShapeSize( - fusion->fused_expression_root()->operand(0)->shape()); + fusion->fused_expression_root()->operand(1)->shape()); return; } } else if (shape_index.size() == 1) { @@ -710,7 +716,7 @@ Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { current_properties_[kBytesAccessedKey] += GetShapeSize(fusion->fused_expression_root() ->operand(shape_index[0]) - ->operand(0) + ->operand(1) ->shape()); return; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 407dfe796d8..ed4bac22a9f 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1105,8 +1105,8 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( - scalar_shape_, HloOpcode::kBitcast, constant)); + auto bitcast = builder.AddInstruction( + HloInstruction::CreateBitcast(scalar_shape_, constant)); module_->AddEntryComputation(builder.Build()); SCOPED_TRACE(module_->ToString()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index a7e1d3a80d7..9a9898fdeee 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -1543,8 +1543,9 @@ class OutputBatchIndexToInputIndex { int64 index_vector_dim = dim_numbers_.index_vector_dim(); for (int64 i = 0, e = index_vector_.size(); i < e; i++) { index_vector_index_[index_vector_dim] = i; - TF_ASSIGN_OR_RETURN(index_vector_[i], - start_indices_.GetIntegralAsS64(index_vector_index_)); + // TODO(george): OK what should happen here? + // seems OK to crash though. 
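+      // Dereferencing the StatusOr directly CHECK-fails if the index element
+      // cannot be read as an integral value, which is the "crash" referred to
+      // above.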
+ index_vector_[i] = *start_indices_.GetIntegralAsS64(index_vector_index_); } return Status::OK(); } @@ -2295,12 +2296,10 @@ static StatusOr GenerateReduceOutputElement( } if (use_fast_add) { - TF_ASSIGN_OR_RETURN(double computed_result, - init_values[0]->GetAsDouble({})); + double computed_result = *init_values[0]->GetAsDouble({}); auto reduction_step = [&](absl::Span input_index) -> StatusOr { - TF_ASSIGN_OR_RETURN(double argument, - input_args[0]->GetAsDouble(input_index)); + double argument = *input_args[0]->GetAsDouble(input_index); computed_result += argument; return true; }; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 9fcc6274866..9487d955f31 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -2035,8 +2035,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { int64 index_vector_dim = dim_numbers_.index_vector_dim(); for (int64 i = 0, e = index_vector_.size(); i < e; i++) { index_vector_index_[index_vector_dim] = i; - TF_ASSIGN_OR_RETURN(index_vector_[i], scatter_indices_.GetIntegralAsS64( - index_vector_index_)); + index_vector_[i] = + *scatter_indices_.GetIntegralAsS64(index_vector_index_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc index ad58bdb11b5..1c5b166a801 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -103,9 +103,13 @@ StatusOr HloInputOutputAliasConfig::CreateFromProto( return result; } +const Shape& HloInputOutputAliasConfig::shape() const { return alias_.shape(); } + string HloInputOutputAliasConfig::ToString() const { std::vector pieces; pieces.push_back("HloInputOutputAliasConfig"); + pieces.push_back( + absl::StrFormat(" Output shape: %s", alias_.shape().ToString())); ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) { const char* kind = alias.kind == AliasKind::kUserAlias ? "USER" : "SYSTEM"; diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h index e80567abe0a..6bd34f8a127 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -117,6 +117,9 @@ class HloInputOutputAliasConfig { Status ForEachAliasWithStatus(AliasFnWithStatus fn) const; + // Returns the shape of the output of the alias config. 
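+  // (Equivalently, the shape of the underlying alias ShapeTree.)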
+ const Shape& shape() const; + string ToString() const; private: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc old mode 100644 new mode 100755 index ddfcdcfd293..dabd7ab2836 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -550,7 +550,8 @@ StatusOr> HloInstruction::CreateFromProto( gather_slice_sizes.push_back(bound); } instruction = CreateGather(shape, operands(0), operands(1), - *gather_dimension_numbers, gather_slice_sizes); + *gather_dimension_numbers, gather_slice_sizes, + proto.indices_are_sorted()); break; } case HloOpcode::kScatter: { @@ -563,7 +564,8 @@ StatusOr> HloInstruction::CreateFromProto( absl::make_unique( proto.scatter_dimension_numbers()); instruction = CreateScatter(shape, operands(0), operands(1), operands(2), - computations(0), *scatter_dimension_numbers); + computations(0), *scatter_dimension_numbers, + proto.indices_are_sorted()); break; } case HloOpcode::kIota: @@ -672,6 +674,10 @@ StatusOr> HloInstruction::CreateFromProto( instruction->set_sharding(sharding); } + if (proto.has_frontend_attributes()) { + instruction->set_frontend_attributes(proto.frontend_attributes()); + } + return std::move(instruction); } @@ -1192,6 +1198,7 @@ HloInstruction::CreateBroadcastSequence( if (operand->has_sharding()) { broadcast->set_sharding(operand->sharding()); } + broadcast->set_frontend_attributes(operand->frontend_attributes()); return broadcast; } // Do explicit broadcast for degenerate broadcast. @@ -1217,6 +1224,7 @@ HloInstruction::CreateBroadcastSequence( if (operand->has_sharding()) { reshaped_operand->set_sharding(operand->sharding()); } + reshaped_operand->set_frontend_attributes(operand->frontend_attributes()); // Broadcast 'reshape' up to the larger size. 
auto broadcast = HloInstruction::CreateBroadcast( broadcast_shape, reshaped_operand, broadcast_dimensions); @@ -1224,6 +1232,7 @@ HloInstruction::CreateBroadcastSequence( if (operand->has_sharding()) { broadcast->set_sharding(operand->sharding()); } + broadcast->set_frontend_attributes(operand->frontend_attributes()); return broadcast; } @@ -1294,6 +1303,7 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->clear_sharding(); } derived_instruction->set_metadata(metadata_); + derived_instruction->set_frontend_attributes(frontend_attributes_); } bool HloInstruction::HasSideEffectNoRecurse() const { @@ -1372,19 +1382,21 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr HloInstruction::CreateGather( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - absl::Span slice_sizes) { + absl::Span slice_sizes, bool indices_are_sorted) { return absl::make_unique( - shape, operand, start_indices, gather_dim_numbers, slice_sizes); + shape, operand, start_indices, gather_dim_numbers, slice_sizes, + indices_are_sorted); } /* static */ std::unique_ptr HloInstruction::CreateScatter( const Shape& shape, HloInstruction* operand, HloInstruction* scatter_indices, HloInstruction* updates, HloComputation* update_computation, - const ScatterDimensionNumbers& scatter_dim_numbers) { + const ScatterDimensionNumbers& scatter_dim_numbers, + bool indices_are_sorted) { return absl::make_unique( shape, operand, scatter_indices, updates, update_computation, - scatter_dim_numbers); + scatter_dim_numbers, indices_are_sorted); } /* static */ std::unique_ptr HloInstruction::CreateDomain( @@ -2179,10 +2191,20 @@ string HloInstruction::SignatureString() const { return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); } +string PrintName(const string& name, bool print_ids) { + if (print_ids) { + return name; + } else { + auto dot_position = name.find_first_of("."); + return name.substr(0, dot_position); + } +} + namespace { -string PrintName(const string& name, const HloPrintOptions& options) { - return StrCat(options.print_percent() ? "%" : "", name); +string PrintNameInternal(const string& name, const HloPrintOptions& options) { + return StrCat(options.print_percent() ? "%" : "", + PrintName(name, options.print_ids())); } } // namespace @@ -2277,11 +2299,12 @@ string HloInstruction::ToStringWithCanonicalNameMap( // If we are canonicalizing instruction names and this is a top-level // HloInstruction::ToString() call, don't print an instruction name. StrAppend(&result, - PrintName(canonical_name_map->LookupOrInsert(name()), options), + PrintNameInternal(canonical_name_map->LookupOrInsert(name()), + options), " = "); } } else { - StrAppend(&result, PrintName(name(), options), " = "); + StrAppend(&result, PrintNameInternal(name(), options), " = "); } // Print shape. @@ -2347,10 +2370,10 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( // part of the canonical string. 
if (options.canonicalize_instruction_names() && options.is_in_nested_computation()) { - str.push_back(PrintName( + str.push_back(PrintNameInternal( canonical_name_map->LookupOrInsert(operand->name()), options)); } else if (options.print_operand_names()) { - str.push_back(PrintName(operand->name(), options)); + str.push_back(PrintNameInternal(operand->name(), options)); } StrAppend(out, StrJoin(str, " ")); }); @@ -2368,27 +2391,30 @@ std::vector HloInstruction::ExtraAttributesToString( if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { + extra.push_back(StrCat( + "condition=", PrintNameInternal(while_condition()->name(), options))); extra.push_back( - StrCat("condition=", PrintName(while_condition()->name(), options))); - extra.push_back( - StrCat("body=", PrintName(while_body()->name(), options))); + StrCat("body=", PrintNameInternal(while_body()->name(), options))); } else if (opcode() == HloOpcode::kSelectAndScatter) { - extra.push_back(StrCat("select=", PrintName(select()->name(), options))); extra.push_back( - StrCat("scatter=", PrintName(scatter()->name(), options))); + StrCat("select=", PrintNameInternal(select()->name(), options))); + extra.push_back( + StrCat("scatter=", PrintNameInternal(scatter()->name(), options))); } else if (opcode() == HloOpcode::kConditional) { if (operand(0)->shape().element_type() == PRED) { - extra.push_back(StrCat("true_computation=", - PrintName(true_computation()->name(), options))); + extra.push_back( + StrCat("true_computation=", + PrintNameInternal(true_computation()->name(), options))); extra.push_back( StrCat("false_computation=", - PrintName(false_computation()->name(), options))); + PrintNameInternal(false_computation()->name(), options))); } else { extra.push_back(StrCat( "branch_computations={", StrJoin(branch_computations(), ", ", [&](string* out, const HloComputation* computation) { - StrAppend(out, PrintName(computation->name(), options)); + StrAppend( + out, PrintNameInternal(computation->name(), options)); }), "}")); } @@ -2399,13 +2425,14 @@ std::vector HloInstruction::ExtraAttributesToString( opcode() == HloOpcode::kScatter || opcode() == HloOpcode::kSort) { extra.push_back( - StrCat("to_apply=", PrintName(to_apply()->name(), options))); + StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options))); } else if (!called_computations().empty()) { extra.push_back(StrCat( "calls=", StrJoin(called_computations(), ", ", [&](string* out, const HloComputation* computation) { - StrAppend(out, PrintName(computation->name(), options)); + StrAppend(out, + PrintNameInternal(computation->name(), options)); }))); } } else if (options.print_subcomputation_mode() == @@ -2464,6 +2491,10 @@ std::vector HloInstruction::ExtraAttributesToString( if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } + if (!frontend_attributes_.map().empty()) { + extra.push_back(StrCat("frontend_attributes=", + FrontendAttributesToString(frontend_attributes_))); + } if (!outer_dimension_partitions_.empty()) { extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}", StrJoin(outer_dimension_partitions_, ","))); @@ -2473,8 +2504,8 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("control-predecessors={", StrJoin(control_predecessors_, ", ", [&](string* out, HloInstruction* pre) { - StrAppend(out, - PrintName(pre->name(), options)); + StrAppend(out, PrintNameInternal( + pre->name(), options)); }), "}")); } @@ -2524,6 
+2555,8 @@ HloInstructionProto HloInstruction::ToProto() const { } } + *proto.mutable_frontend_attributes() = frontend_attributes_; + return proto; } @@ -2573,6 +2606,9 @@ bool HloInstruction::IsFusible() const { switch (opcode_) { case HloOpcode::kDomain: case HloOpcode::kParameter: + case HloOpcode::kWhile: + case HloOpcode::kConditional: + case HloOpcode::kCall: return false; // Side effecting instrutions cannot be fused. default: @@ -3175,6 +3211,15 @@ StatusOr StringToFusionKind( return InvalidArgument("Unknown fusion kind: %s", kind_name); } +string FrontendAttributesToString( + const FrontendAttributes& frontend_attributes) { + std::vector> sorted_attributes( + frontend_attributes.map().begin(), frontend_attributes.map().end()); + absl::c_sort(sorted_attributes); + return absl::StrFormat( + "{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("="))); +} + string PaddingConfigToString(const PaddingConfig& padding) { bool has_interior_padding = absl::c_any_of(padding.dimensions(), @@ -3652,6 +3697,9 @@ int64 HloInstruction::feature_group_count() const { } void HloInstruction::set_feature_group_count(int64 feature_group_count) { + if (auto convolution = DynCast(this)) { + return convolution->set_feature_group_count(feature_group_count); + } Cast(this)->set_feature_group_count( feature_group_count); } @@ -3664,6 +3712,9 @@ int64 HloInstruction::batch_group_count() const { } void HloInstruction::set_batch_group_count(int64 batch_group_count) { + if (auto convolution = DynCast(this)) { + return convolution->set_batch_group_count(batch_group_count); + } Cast(this)->set_batch_group_count( batch_group_count); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index fbaeb5d5f66..3119b52e377 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -63,6 +63,8 @@ namespace xla { class HloComputation; class HloModule; +string PrintName(const string& name, bool print_ids); + // A bunch of switches that control how the hlo text should be printed. class HloPrintOptions { public: @@ -88,7 +90,8 @@ class HloPrintOptions { print_control_dependencies_(true), canonicalize_instruction_names_(false), indent_amount_(0), - is_in_nested_computation_(false) {} + is_in_nested_computation_(false), + print_ids_(true) {} static HloPrintOptions ShortParsable() { return HloPrintOptions() @@ -118,6 +121,22 @@ class HloPrintOptions { .set_canonicalize_instruction_names(true); } + // Options to produce a fingerprint of an HLO. + static HloPrintOptions Fingerprint() { + return HloPrintOptions() + .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly) + .set_print_metadata(false) + .set_print_backend_config(false) + .set_compact_operands(true) + .set_print_operand_names(false) + .set_print_operand_shape(true) + .set_print_program_shape(false) + .set_print_percent(false) + .set_print_control_dependencies(false) + .set_canonicalize_instruction_names(true) + .set_print_ids(false); + } + // If true, large constants will be printed out. HloPrintOptions& set_print_large_constants(bool value) { print_large_constants_ = value; @@ -154,6 +173,12 @@ class HloPrintOptions { return *this; } + // If true, all printed names include unique identifiers. + HloPrintOptions& set_print_ids(bool value) { + print_ids_ = value; + return *this; + } + // If true, program shape of hlo computations will be printed. 
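For context on how the new HloPrintOptions::Fingerprint() preset is meant to be consumed, here is a hedged sketch of deriving a module cache key from the id-free, canonicalized text. The FingerprintKey helper is illustrative and not part of this change; it only relies on HloModule::ToString(options) and tensorflow::Hash64.

```cpp
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/core/lib/hash/hash.h"

namespace xla {

// Illustrative helper: two modules that differ only in the numeric ".<id>"
// suffixes of instruction names stringify identically under Fingerprint()
// (print_ids=false, canonical names), so they map to the same key.
uint64 FingerprintKey(const HloModule& module) {
  return tensorflow::Hash64(module.ToString(HloPrintOptions::Fingerprint()));
}

}  // namespace xla
```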
HloPrintOptions& set_print_program_shape(bool value) { print_program_shape_ = value; @@ -216,6 +241,7 @@ class HloPrintOptions { bool include_layout_in_shapes() const { return include_layout_in_shapes_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_operand_names() const { return print_operand_names_; } + bool print_ids() const { return print_ids_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } bool print_control_dependencies() const { @@ -242,6 +268,7 @@ class HloPrintOptions { bool canonicalize_instruction_names_; int indent_amount_; bool is_in_nested_computation_; + bool print_ids_; }; // For canonical string output, we need to have a canonical way to rename @@ -767,13 +794,14 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - absl::Span slice_sizes); + absl::Span slice_sizes, bool indices_are_sorted); static std::unique_ptr CreateScatter( const Shape& shape, HloInstruction* operand, HloInstruction* scatter_indices, HloInstruction* updates, HloComputation* update_computation, - const ScatterDimensionNumbers& scatter_dim_numbers); + const ScatterDimensionNumbers& scatter_dim_numbers, + bool indices_are_sorted); // Creates a kDomain instruction which delimits an HLO domain which have // the provided user and operand side metadata. @@ -1357,6 +1385,14 @@ class HloInstruction { } Status set_backend_config(const tensorflow::protobuf::Message& proto); + void set_frontend_attributes(FrontendAttributes frontend_attributes) { + frontend_attributes_ = std::move(frontend_attributes); + } + + const FrontendAttributes& frontend_attributes() const { + return frontend_attributes_; + } + // Getter/setter for raw JSON-encoded backend config. Prefer the // functions above that deal in proto Messages where possible. const string& raw_backend_config_string() const { return backend_config_; } @@ -1851,6 +1887,18 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; + // Attributes passed from the frontend to give hints to the backend about + // how to compile this HLO. + // HLO -> HLO transforms are expected to preserve these attributes on a + // "best effort" basis only. + // For example: + // x = const(10, frontend_attributes={x} + // y = const(10, frontend_attributes={y} + // z = add(x,y), frontend_attributes={y} + // Could be simplified to: + // z' = const(20), frontend_attributes={?} + FrontendAttributes frontend_attributes_; + // This field is assigned to true when backend_config_ is assigned to // a default configuration. bool is_default_config_ = false; @@ -1881,6 +1929,8 @@ StatusOr StringToFusionKind( // Custom (de)stringification functions for protos that live inside // HloInstruction. 
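To show how a frontend is expected to use the new attribute plumbing, here is a short sketch; the attribute key "example_hint", its value, and the parameter-building boilerplate are made up for illustration.

```cpp
#include <memory>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

// Illustrative sketch: attach a frontend attribute to an instruction and read
// it back. The key/value pair is arbitrary, not one any backend is known to
// consume.
void FrontendAttributesSketch() {
  Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
  std::unique_ptr<HloInstruction> param =
      HloInstruction::CreateParameter(0, shape, "p0");

  FrontendAttributes attrs;
  (*attrs.mutable_map())["example_hint"] = "value";
  param->set_frontend_attributes(attrs);

  // The attribute now shows up in ToString() output as
  // frontend_attributes={example_hint=value}, is serialized by ToProto(), and
  // is copied to derived instructions by SetupDerivedInstruction() on a
  // best-effort basis.
  CHECK_EQ(param->frontend_attributes().map().at("example_hint"), "value");
}

}  // namespace xla
```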
string PaddingConfigToString(const PaddingConfig& padding); +string FrontendAttributesToString( + const FrontendAttributes& frontend_attributes); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); string PrecisionToString(const PrecisionConfig::Precision& precision); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 80de1d5e0bc..0a50ed04af7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -1440,7 +1442,8 @@ TEST_F(HloInstructionTest, StringifyGather_0) { /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/4), - /*slice_sizes=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26}, + /*indices_are_sorted=*/false)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); @@ -1475,7 +1478,8 @@ TEST_F(HloInstructionTest, StringifyGather_1) { /*collapsed_slice_dims=*/{}, /*start_index_map=*/{0, 1, 2, 3, 4}, /*index_vector_dim=*/2), - /*slice_sizes=*/{30, 29, 28, 27, 26})); + /*slice_sizes=*/{30, 29, 28, 27, 26}, + /*indices_are_sorted=*/false)); auto module = CreateNewVerifiedModule(); module->AddEntryComputation(builder.Build()); @@ -1524,7 +1528,8 @@ TEST_F(HloInstructionTest, StringifyScatter) { /*update_window_dims=*/{4, 5, 6, 7, 8}, /*inserted_window_dims=*/{}, /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, - /*index_vector_dim=*/2))); + /*index_vector_dim=*/2), + /*indices_are_sorted=*/false)); module->AddEntryComputation(builder.Build()); EXPECT_EQ( @@ -1956,5 +1961,26 @@ TEST_F(HloInstructionTest, GatherDoesNotReuseElements) { EXPECT_FALSE(root->ReusesOperandElements(1)); } +TEST_F(HloInstructionTest, BackendConfigCanContainNonFiniteFloats) { + HloComputation::Builder b(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto p0 = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = b.AddInstruction(HloInstruction::CreateDot( + shape, p0, p0, dot_dnums, DefaultPrecisionConfig(2))); + + gpu::GemmBackendConfig orig_config; + orig_config.set_alpha_real(std::numeric_limits::infinity()); + orig_config.set_alpha_imag(std::numeric_limits::quiet_NaN()); + TF_ASSERT_OK(dot->set_backend_config(orig_config)); + + TF_ASSERT_OK_AND_ASSIGN(auto new_config, + dot->backend_config()); + EXPECT_GT(new_config.alpha_real(), std::numeric_limits::max()); + EXPECT_NE(new_config.alpha_imag(), new_config.alpha_imag()); +} + } // namespace } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 52d8c7a43ce..183967941bf 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1706,7 +1706,7 @@ bool HloRngInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& eq_computations) const { - return false; + return true; } std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( @@ -1737,7 +1737,7 @@ HloInstructionProto HloParameterInstruction::ToProto() const { } std::vector HloParameterInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& /*options*/) const { + const HloPrintOptions& options) const { std::vector result; if (!parameter_replicated_at_leaf_buffers_) { return result; @@ -1746,8 +1746,10 @@ std::vector HloParameterInstruction::ExtraAttributesToStringImpl( for (bool replicated : *parameter_replicated_at_leaf_buffers_) { buffers_replicated_strs.push_back(replicated ? "true" : "false"); } - result.push_back(StrCat("parameter_replication={", - StrJoin(buffers_replicated_strs, ","), "}")); + if (options.print_ids()) { + result.push_back(StrCat("parameter_replication={", + StrJoin(buffers_replicated_strs, ","), "}")); + } return result; } @@ -2397,8 +2399,9 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( HloGatherInstruction::HloGatherInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - absl::Span slice_sizes) - : HloInstruction(HloOpcode::kGather, shape) { + absl::Span slice_sizes, bool indices_are_sorted) + : HloInstruction(HloOpcode::kGather, shape), + indices_are_sorted_(indices_are_sorted) { AppendOperand(operand); AppendOperand(start_indices); gather_dimension_numbers_ = @@ -2450,13 +2453,19 @@ HloInstructionProto HloGatherInstruction::ToProto() const { for (int64 bound : gather_slice_sizes()) { proto.add_gather_slice_sizes(bound); } + proto.set_indices_are_sorted(indices_are_sorted()); return proto; } std::vector HloGatherInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {GatherDimensionNumbersToString(gather_dimension_numbers()), - StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")}; + std::vector attrs{ + GatherDimensionNumbersToString(gather_dimension_numbers()), + StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")}; + if (indices_are_sorted()) { + attrs.push_back("indices_are_sorted=true"); + } + return attrs; } bool HloGatherInstruction::IdenticalSlowPath( @@ -2467,7 +2476,8 @@ bool HloGatherInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals( gather_dimension_numbers(), casted_other.gather_dimension_numbers()) && - gather_slice_sizes() == casted_other.gather_slice_sizes(); + gather_slice_sizes() == casted_other.gather_slice_sizes() && + indices_are_sorted() == casted_other.indices_are_sorted(); } std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( @@ -2476,15 +2486,16 @@ std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( CHECK_EQ(new_operands.size(), 2); return absl::make_unique( shape, new_operands[0], new_operands[1], gather_dimension_numbers(), - gather_slice_sizes()); + gather_slice_sizes(), indices_are_sorted()); } HloScatterInstruction::HloScatterInstruction( const Shape& shape, HloInstruction* operand, HloInstruction* scatter_indices, HloInstruction* updates, HloComputation* update_computation, - const ScatterDimensionNumbers& 
scatter_dim_numbers) - : HloInstruction(HloOpcode::kScatter, shape) { + const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted) + : HloInstruction(HloOpcode::kScatter, shape), + indices_are_sorted_(indices_are_sorted) { AppendOperand(operand); AppendOperand(scatter_indices); AppendOperand(updates); @@ -2538,12 +2549,18 @@ HloScatterInstruction::MakeScatterDimNumbers( HloInstructionProto HloScatterInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers(); + proto.set_indices_are_sorted(indices_are_sorted()); return proto; } std::vector HloScatterInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {ScatterDimensionNumbersToString(scatter_dimension_numbers())}; + std::vector attrs{ + ScatterDimensionNumbersToString(scatter_dimension_numbers())}; + if (indices_are_sorted()) { + attrs.push_back("indices_are_sorted=true"); + } + return attrs; } bool HloScatterInstruction::IdenticalSlowPath( @@ -2554,7 +2571,8 @@ bool HloScatterInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals( scatter_dimension_numbers(), casted_other.scatter_dimension_numbers()) && - eq_computations(to_apply(), casted_other.to_apply()); + eq_computations(to_apply(), casted_other.to_apply()) && + indices_are_sorted() == casted_other.indices_are_sorted(); } std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( @@ -2563,7 +2581,7 @@ std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( CHECK_EQ(new_operands.size(), 3); return absl::make_unique( shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), - scatter_dimension_numbers()); + scatter_dimension_numbers(), indices_are_sorted()); } HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension) diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h old mode 100644 new mode 100755 index 8e6f024e5d2..0de050108b7 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1077,10 +1077,15 @@ class HloConvolutionInstruction : public HloInstruction { // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count() const { return feature_group_count_; } - + void set_feature_group_count(int64 num_feature_groups) { + feature_group_count_ = num_feature_groups; + } // The number of feature groups. Must be a divisor of the input batch // dimension. int64 batch_group_count() const { return batch_group_count_; } + void set_batch_group_count(int64 num_batch_groups) { + batch_group_count_ = num_batch_groups; + } // Returns the information used to tell the implementation information about // what sort of precision is requested. 
The meaning of the field is backend @@ -1401,7 +1406,7 @@ class HloGatherInstruction : public HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, - absl::Span slice_sizes); + absl::Span slice_sizes, bool indices_are_sorted); const GatherDimensionNumbers& gather_dimension_numbers() const { CHECK(gather_dimension_numbers_ != nullptr); return *gather_dimension_numbers_; @@ -1409,6 +1414,10 @@ class HloGatherInstruction : public HloInstruction { absl::Span gather_slice_sizes() const { return gather_slice_sizes_; } + bool indices_are_sorted() const { return indices_are_sorted_; } + void set_indices_are_sorted(bool indices_are_sorted) { + indices_are_sorted_ = indices_are_sorted; + } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1434,6 +1443,7 @@ class HloGatherInstruction : public HloInstruction { std::unique_ptr gather_dimension_numbers_; std::vector gather_slice_sizes_; + bool indices_are_sorted_; }; class HloScatterInstruction : public HloInstruction { @@ -1442,11 +1452,16 @@ class HloScatterInstruction : public HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* scatter_indices, HloInstruction* updates, HloComputation* update_computation, - const ScatterDimensionNumbers& scatter_dim_numbers); + const ScatterDimensionNumbers& scatter_dim_numbers, + bool indices_are_sorted); const ScatterDimensionNumbers& scatter_dimension_numbers() const { CHECK(scatter_dimension_numbers_ != nullptr); return *scatter_dimension_numbers_; } + bool indices_are_sorted() const { return indices_are_sorted_; } + void set_indices_are_sorted(bool indices_are_sorted) { + indices_are_sorted_ = indices_are_sorted; + } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -1473,6 +1488,7 @@ class HloScatterInstruction : public HloInstruction { HloCloneContext* context) const override; std::unique_ptr scatter_dimension_numbers_; + bool indices_are_sorted_; }; class HloIotaInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_live_range.cc b/tensorflow/compiler/xla/service/hlo_live_range.cc new file mode 100644 index 00000000000..8ec437ec250 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_live_range.cc @@ -0,0 +1,235 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_live_range.h" + +#include "absl/strings/str_format.h" + +namespace xla { +/*static*/ +StatusOr> HloLiveRange::Run( + const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, + const HloComputation* computation, bool module_scoped_analysis) { + std::unique_ptr hlo_live_range( + new HloLiveRange(schedule, alias_analysis, module_scoped_analysis)); + hlo_live_range->schedule_end_time_ = + hlo_live_range->FlattenSchedule(*computation, 0); + hlo_live_range->CalculateBufferStartEndMap(); + hlo_live_range->NormalizeAliasedBuffers(); + return std::move(hlo_live_range); +} + +void HloLiveRange::NormalizeAliasedBuffers() { + for (const HloBuffer& hlo_buffer : alias_analysis_.buffers()) { + std::vector aliased_buffers; + for (const HloValue* hlo_value : hlo_buffer.values()) { + if (buffer_live_ranges_.contains(hlo_value)) { + aliased_buffers.push_back(hlo_value); + } + } + absl::c_sort( + aliased_buffers, [&](const HloValue* value1, const HloValue* value2) { + const TimeBound& live_range1 = buffer_live_ranges_.at(value1); + const TimeBound& live_range2 = buffer_live_ranges_.at(value2); + + return std::forward_as_tuple(live_range1.start, live_range1.end) < + std::forward_as_tuple(live_range2.start, live_range2.end); + }); + + for (int64 i = 0; i + 1 < aliased_buffers.size(); ++i) { + const HloValue* value1 = aliased_buffers[i]; + const HloValue* value2 = aliased_buffers[i + 1]; + TimeBound& live_range1 = buffer_live_ranges_[value1]; + TimeBound& live_range2 = buffer_live_ranges_[value2]; + if (live_range1.start == live_range2.start) { + // If value1 has the same start time as value2, make value1 disappear + // by setting the end time same as start time: + // + // Before: + // +----+ value1 + // +----------+ value2 + // + // After: + // + value1 + // +----------+ value2 + // + // Note that only when heap simulator runs before copy insertion can + // this happen where one instruction defines multiple aliased buffers + // -- This is illegle to execute and can be fixed by copy insertion + // later. + live_range1.end = live_range2.end; + continue; + } + + if (live_range1.end < live_range2.start) { + continue; + } + + if (live_range1.end > live_range2.end) { + live_range2.end = live_range1.end; + } + live_range1.end = live_range2.start - 1; + } + } +} + +// FlattenSchedule walks through the computation and tracks down the ordinal +// number of each instruction in the schedule. +int64 HloLiveRange::FlattenSchedule(const HloComputation& computation, + int64 start_time) { + if (!schedule_.is_computation_scheduled(&computation)) { + total_order_scheduled_ = false; + return start_time; + } + + const HloInstructionSequence& instruction_sequence = + schedule_.sequence(&computation); + int64 time = start_time; + for (HloInstruction* instruction : instruction_sequence.instructions()) { + if (module_scoped_analysis_) { + // Recurse into sub computations if running with module scoped analysis + // mode. 
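Since the interval arithmetic in NormalizeAliasedBuffers is easy to misread, here is a standalone toy sketch of the general overlap-trimming case, using a plain struct in place of HloValue/TimeBound; it deliberately omits the equal-start special case handled above.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Simplified stand-in for HloLiveRange::TimeBound.
struct Bound {
  int64_t start;
  int64_t end;
};

// Trims two aliased, start-sorted live ranges so they become disjoint while
// their union stays the same, mirroring the general overlap case above.
void TrimOverlap(Bound& r1, Bound& r2) {
  if (r1.end < r2.start) return;      // already disjoint, nothing to do
  r2.end = std::max(r1.end, r2.end);  // r2 absorbs the tail of the union
  r1.end = r2.start - 1;              // r1 now ends just before r2 begins
}

int main() {
  Bound a{0, 10}, b{4, 8};  // a fully covers b before normalization
  TrimOverlap(a, b);
  std::cout << a.start << "-" << a.end << ", "   // 0-3
            << b.start << "-" << b.end << "\n";  // 4-10
}
```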
+ if (instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kConditional) { + for (const HloComputation* called_computation : + instruction->called_computations()) { + time = FlattenSchedule(*called_computation, time); + } + } + if (instruction->opcode() == HloOpcode::kWhile) { + time = FlattenSchedule(*instruction->while_condition(), time); + time++; + time = FlattenSchedule(*instruction->while_body(), time); + } + } + if (instruction_schedule_.count(instruction) != 0) { + continue; + } + instruction_schedule_.insert({instruction, time++}); + flattened_instruction_sequence_.push_back(instruction); + } + computation_span_times_.try_emplace(&computation, + TimeBound{start_time, time}); + DCHECK_EQ(instruction_schedule_.size(), + flattened_instruction_sequence_.size()); + DCHECK_LE(instruction_schedule_.size(), time); + return time; +} + +void HloLiveRange::CalculateBufferStartEndMap() { + for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) { + // Ignore buffers that are not defined. + if (instruction_schedule_.count(value->defining_instruction()) == 0) { + continue; + } + + int64 buffer_start_time = instruction_schedule_[value->instruction()]; + + int64 buffer_end_time = -1; + for (const HloUse& use : value->uses()) { + const HloInstruction* used = use.instruction; + // As an optimization, we deem a while's init value's live range ends as + // soon as the loop body starts. This optimization is only applicable in + // module scoped mode. + if (module_scoped_analysis_ && used->opcode() == HloOpcode::kWhile) { + // The current live range is at the end of the while, move it to the + // beginning of the body. + used = used->while_body()->parameter_instruction(0); + VLOG(1) << "Moved value " << value->ToShortString() + << " to while param: " << used->ToString(); + } + if (instruction_schedule_.count(used) == 0) { + // We didn't track the instruction `used`. This happens when we do + // computation scope (versus module scope) heap simulation and when + // the used instruction is outside of the computation being simulated. + continue; + } + buffer_end_time = std::max(buffer_end_time, instruction_schedule_[used]); + } + + // Parameters are defined at the beginning of the computation. This prevents + // any instruction that's scheduled before the parameter clobbers the + // parameter's buffer. + if (value->instruction()->opcode() == HloOpcode::kParameter) { + const HloComputation* computation = value->instruction()->parent(); + auto it = computation_span_times_.find(computation); + if (it != computation_span_times_.end()) { + buffer_start_time = std::min(buffer_start_time, it->second.start); + } + } + + if (buffer_end_time == -1) { + buffer_end_time = buffer_start_time; + } + + for (const HloPosition& position : value->positions()) { + const HloComputation* position_comp = position.instruction->parent(); + // If this instruction lives out, the live range of the instruction + // should be extended to the end of the computation. + if (position.instruction == position_comp->root_instruction()) { + auto it = computation_span_times_.find(position_comp); + if (it == computation_span_times_.end()) { + continue; + } + buffer_end_time = std::max(buffer_end_time, it->second.end); + } + } + + const HloModule* module = value->instruction()->parent()->parent(); + + // Readonly entry parameters (parameters that don't alias) live across whole + // computation. 
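A toy, self-contained illustration of the start/end rule CalculateBufferStartEndMap applies: a value starts at its defining instruction's logical time and ends at its last scheduled use, falling back to its own time when it is never used. The schedule and instruction names below are invented.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Invented flattened schedule: instruction name -> logical time.
  std::unordered_map<std::string, int64_t> schedule = {
      {"paramA", 0}, {"paramX", 1}, {"mul", 2}, {"add", 3}};

  // Live range of "mul": defined at its own slot, dies at its last use.
  int64_t start = schedule.at("mul");
  int64_t end = start;  // unused values end where they start
  std::vector<std::string> uses_of_mul = {"add"};
  for (const std::string& use : uses_of_mul) {
    end = std::max(end, schedule.at(use));
  }
  std::cout << "mul lives [" << start << ", " << end << "]\n";  // [2, 3]
}
```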
+ if (value->instruction()->opcode() == HloOpcode::kParameter && + value->instruction()->parent() == module->entry_computation() && + !module->input_output_alias_config().ParameterHasAlias( + value->instruction()->parameter_number(), value->index())) { + buffer_end_time = schedule_end_time_; + } + + CHECK(buffer_start_time <= buffer_end_time) + << buffer_start_time << ", " << buffer_end_time + << value->instruction()->ToString(); + + auto& live_range = buffer_live_ranges_[value]; + live_range.start = buffer_start_time; + live_range.end = buffer_end_time; + } +} + +std::string HloLiveRange::ToString() const { + std::string output; + absl::StrAppendFormat(&output, "HloLiveRange (max %d):\n", + schedule_end_time_); + absl::StrAppendFormat(&output, " InstructionSequence:\n"); + auto& instructions = flattened_instruction_sequence().instructions(); + for (int64 i = 0; i < instructions.size(); ++i) { + absl::StrAppendFormat(&output, " %d:%s\n", i, instructions[i]->name()); + } + + absl::StrAppendFormat(&output, " BufferLiveRange:\n"); + + for (const HloValue* value : alias_analysis_.dataflow_analysis().values()) { + auto it = buffer_live_ranges_.find(value); + if (it != buffer_live_ranges_.end()) { + absl::StrAppendFormat( + &output, " %s%s:%d-%d\n", value->instruction()->name(), + value->index().ToString(), it->second.start, it->second.end); + } + } + + return output; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_live_range.h b/tensorflow/compiler/xla/service/hlo_live_range.h new file mode 100644 index 00000000000..cc0445acd1e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_live_range.h @@ -0,0 +1,206 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +License for the specific language governing permissions and limitations under +the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_buffer.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// Class which computes live range of the output buffers of HLOs and their +// interference by flattening all computations. The live range is only available +// when all global computations (while, if, call, etc) have total order +// sequential orders. +class HloLiveRange { + public: + // Constructs a hlo live range object for the given module and computation + // assuming the given HLO instruction ordering. 
+ static StatusOr> Run( + const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, + const HloComputation* computation, bool module_scoped_analysis = true); + + // LogicalTime represents the time in a virtual clock. Each instruction has + // one monotonically increasing logical time assigned according to the + // schedule. + using LogicalTime = int64; + + struct TimeBound { + LogicalTime start; + LogicalTime end; + + bool friend operator==(const TimeBound& a, const TimeBound& b) { + return a.start == b.start && a.end == b.end; + } + bool friend operator!=(const TimeBound& a, const TimeBound& b) { + return !(a == b); + } + }; + + std::string ToString() const; + + const HloInstructionSequence& flattened_instruction_sequence() const { + return flattened_instruction_sequence_; + } + + // Returns the map from instruction to the end time of that instruction. + const absl::flat_hash_map& + instruction_schedule() const { + return instruction_schedule_; + } + + // Returns the map from a hlo value to the definition time of that hlo value. + const absl::flat_hash_map& buffer_live_ranges() + const { + return buffer_live_ranges_; + } + + absl::flat_hash_map& buffer_live_ranges() { + return buffer_live_ranges_; + } + + // Returns the time stamp of the end of the program. + LogicalTime schedule_end_time() const { return schedule_end_time_; } + + // Returns whether hlo live range is available on this entire module. Hlo live + // range is not available if the module is partially ordered. + bool total_order_scheduled() const { return total_order_scheduled_; } + + private: + explicit HloLiveRange(const HloSchedule& schedule, + const HloAliasAnalysis& alias_analysis, + bool module_scoped_analysis) + : schedule_(schedule), + alias_analysis_(alias_analysis), + module_scoped_analysis_(module_scoped_analysis) {} + + // FlattenSchedule walks through the instructions in `computation`, and + // recurse into each called computations in module_scoped_analysis mode. As it + // walks it also tracks down the ordinal number of each instruction in the + // schedule and store it in the `instruction_schedule` and + // 'flattened_instruction_sequence`. The end of each computation is tracked in + // `computation_end_time`. + int64 FlattenSchedule(const HloComputation& computation, int64 start_time); + + // Based on the flattened schedule, calculate the start and end of each + // buffer. + void CalculateBufferStartEndMap(); + + // The aliased buffers could have overlapping live ranges. + // NormalizeAliasedBuffers normalizes the buffer such that each alias buffer + // has disjoint live range while keeping the live range union the same. This + // avoid double counting aliased buffer sizes. 
+ // + // Before(buffer1 and 2 are aliased): + // + // +----+ live range of buffer1 + // +------------------+ live range of buffer2 + // + // After: + // + // +----------+ live range of buffer1 + // +------+ live range of buffer2 + // + // Before(buffer1 and 2 are aliased): + // + // +----------+ live range of buffer1 + // +------------+ live range of buffer2 + // + // After: + // + // +----------+ live range of buffer1 + // +------+ live range of buffer2 + // + // Before(buffer1 and 2 are aliased): + // + // +----------+ live range of buffer1 + // +---+ live range of buffer2 + // + // After(unchanged): + // + // +----------+ live range of buffer1 + // +---+ live range of buffer2 + // + // As another example, imagine we have the following code sequence with live + // ranges of each while-aliased buffers: + // + // a p1 p2 e b + // a = ... + + // | + // { | + // p1 = param | + + // ROOT true | | + // } | + + // { // body | + // p2 = param + + + // c = p2 + 1 + + // d = c + 1 + // ROOT e = d + 1 + + // } | + // | + // b = while (a) + + + // | + // f = b + 1 + + // + // After normalization it becomes: + // + // a p1 p2 e b + // a = ... + + // | + // { + + // p1 = param + + // ROOT true | + // } + + // { // body + // p2 = param + + // c = p2 + 1 + + // d = c + 1 + // ROOT e = d + 1 + + // } | + // | + // b = while (a) + + // + + // f = b + 1 + + // + // Note there is no overlap of live ranges after normalization. + void NormalizeAliasedBuffers(); + + const HloSchedule& schedule_; + const HloAliasAnalysis& alias_analysis_; + bool module_scoped_analysis_; + bool total_order_scheduled_ = true; + + HloInstructionSequence flattened_instruction_sequence_; + absl::flat_hash_map instruction_schedule_; + absl::flat_hash_map computation_span_times_; + absl::flat_hash_map buffer_live_ranges_; + LogicalTime schedule_end_time_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVE_RANGE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_live_range_test.cc b/tensorflow/compiler/xla/service/hlo_live_range_test.cc new file mode 100644 index 00000000000..d524d9f0c82 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_live_range_test.cc @@ -0,0 +1,239 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/xla/service/hlo_live_range.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using TimeBound = HloLiveRange::TimeBound; +class HloLiveRangeTest : public HloTestBase { + protected: + HloLiveRangeTest() : module_(CreateNewVerifiedModule()) {} + ~HloLiveRangeTest() override {} + + void Analyze(const HloSchedule& schedule) { + alias_analysis_ = HloAliasAnalysis::Run(module_.get()).ValueOrDie(); + hlo_live_range_ = HloLiveRange::Run(schedule, *alias_analysis_, + module_->entry_computation()) + .ValueOrDie(); + } + + std::unique_ptr module_; + std::unique_ptr hlo_live_range_; + std::unique_ptr alias_analysis_; + // Shapes for use in the examples. + Shape f32scalar_ = ShapeUtil::MakeShape(xla::F32, {}); + Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4}); + + // Returns the buffer defined at the given instruction and index. + const HloValue* BufferAt(const HloInstruction* instruction, + const ShapeIndex& index) const { + return &alias_analysis_->dataflow_analysis().GetUniqueValueAt(instruction, + index); + } + + HloLiveRange::TimeBound LiveRangeAt(const HloInstruction* instruction, + const ShapeIndex& index = {}) const { + auto* value = BufferAt(instruction, index); + return hlo_live_range_->buffer_live_ranges().at(value); + } +}; + +TEST_F(HloLiveRangeTest, Multiply) { + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec4_, "paramA")); + auto paramX = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec4_, "paramX")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + module_->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module_.get()); + + schedule.set_sequence(module_->entry_computation(), {paramA, paramX, mul}); + + Analyze(schedule); + + // Parameters live from beginning to end. + EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 3})); + EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 3})); + // Mul lives after parameters are defined to the end. 
+ EXPECT_EQ(LiveRangeAt(mul), TimeBound({2, 3})); +} + +TEST_F(HloLiveRangeTest, MultiplyAdd) { + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec4_, "paramA")); + auto paramX = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec4_, "paramX")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + auto paramY = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec4_, "paramY")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY)); + module_->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module_.get()); + + schedule.set_sequence(module_->entry_computation(), + {paramA, paramX, mul, paramY, add}); + + Analyze(schedule); + + // Parameters live from beginning to end. + EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 5})); + EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 5})); + EXPECT_EQ(LiveRangeAt(paramY), TimeBound({0, 5})); + // Mul starts after parameter are defined (Note: all parameters are defined at + // 0, mul starts at 2 which is an arbitrary number). + EXPECT_EQ(LiveRangeAt(mul), TimeBound({2, 4})); + // Add lives after mul is defined to the end of the program. + EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 5})); +} + +TEST_F(HloLiveRangeTest, LiveOutBuffers) { + // If a buffer is live out, its life range is extened to the end of + // computation. + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec4_, "paramA")); + auto paramX = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec4_, "paramX")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + auto paramY = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec4_, "paramY")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add})); + module_->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module_.get()); + + schedule.set_sequence(module_->entry_computation(), + {paramA, paramX, mul, paramY, add, tuple}); + + Analyze(schedule); + + // Parameters live from beginning to end. + EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 6})); + EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 6})); + EXPECT_EQ(LiveRangeAt(paramY), TimeBound({0, 6})); + // Mul starts after parameter are defined (Note: all parameters are defined at + // 0, mul starts at 2 which is an arbitrary number). + EXPECT_EQ(LiveRangeAt(mul), TimeBound({2, 6})); + // Add lives after mul is defined to the end of the program. + EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 6})); +} + +TEST_F(HloLiveRangeTest, InstructionScheduledAfterRoot) { + // If a buffer is live out, its life range is extened to the end of + // computation. 
+ auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec4_, "paramA")); + auto paramX = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec4_, "paramX")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + auto paramY = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec4_, "paramY")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY)); + auto add2 = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add})); + module_->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module_.get()); + + // Schedule another instruction after root. + schedule.set_sequence(module_->entry_computation(), + {paramA, paramX, mul, paramY, add, tuple, add2}); + + Analyze(schedule); + + // Parameters live from beginning to end. + EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 7})); + EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 7})); + EXPECT_EQ(LiveRangeAt(paramY), TimeBound({0, 7})); + // Live out buffers live through the computation. + + EXPECT_EQ(LiveRangeAt(mul), TimeBound({2, 7})); + EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 7})); + EXPECT_EQ(LiveRangeAt(tuple), TimeBound({5, 7})); + EXPECT_EQ(LiveRangeAt(add2), TimeBound({6, 6})); +} + +TEST_F(HloLiveRangeTest, AliasedParameter) { + // If a parameter is non-readonly(non-aliased), its live range can end in the + // middle of the program. + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32vec4_, "paramA")); + auto paramX = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32vec4_, "paramX")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + auto paramY = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32vec4_, "paramY")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY)); + module_->AddEntryComputation(builder.Build()); + // Set up alias of the first parameter. + TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias( + {}, 0, {}, HloInputOutputAliasConfig::kUserAlias)); + + HloSchedule schedule(module_.get()); + + schedule.set_sequence(module_->entry_computation(), + {paramA, paramX, mul, paramY, add}); + + Analyze(schedule); + + // Non-readonly parameter live like other normal buffers. + EXPECT_EQ(LiveRangeAt(paramA), TimeBound({0, 2})); + + // Readonly parameters live from beginning to end. + EXPECT_EQ(LiveRangeAt(paramX), TimeBound({0, 5})); + EXPECT_EQ(LiveRangeAt(paramY), TimeBound({0, 5})); + // Mul starts after parameter are defined (Note: all parameters are defined at + // 0, mul starts at 2 which is an arbitrary number). + EXPECT_EQ(LiveRangeAt(mul), TimeBound({2, 4})); + // Add lives after mul is defined to the end of the program. 
+ EXPECT_EQ(LiveRangeAt(add), TimeBound({4, 5})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index a75fc0bbc3f..789ec5d21a9 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -215,6 +215,8 @@ HLO_MATCHER(Constant); HLO_MATCHER(Convert); HLO_MATCHER(Convolution); HLO_MATCHER(Copy); +HLO_MATCHER(CopyDone); +HLO_MATCHER(CopyStart); HLO_MATCHER(AllReduce); HLO_MATCHER(CollectivePermute); HLO_MATCHER(Divide); diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index fbef51c4ce6..ac74d5b0f65 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -215,7 +216,7 @@ void HloModule::ReplaceComputations( string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << name(); + s << "HloModule " << PrintName(name(), options.print_ids()); if (has_schedule()) { TF_CHECK_OK(schedule().Verify()); s << ", is_scheduled=true"; @@ -661,6 +662,12 @@ HloComputation* HloModule::GetComputationWithName(absl::string_view name) { return it == computations_in_module.end() ? nullptr : *it; } +uint64 HloModule::Hash() const { + return tensorflow::Hash64Combine( + entry_computation_layout().Hash(), + entry_computation()->root_instruction()->Hash()); +} + /* static */ std::atomic HloModule::next_unique_module_id_(0); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 950c7a72f45..b6a72db434a 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -146,9 +146,7 @@ class HloModule { // information on opcode, shape, operands, and typically a root instruction. // This function returns the same hash value for equivalent HLO modules, // with respect to HloInstruction::Identical() method. - uint64 Hash() const { - return entry_computation()->root_instruction()->Hash(); - } + uint64 Hash() const; // Gets the computations in this module. // @@ -300,6 +298,38 @@ class HloModule { return &fusion_config_; } + // Checks if this config has a list of entry parameters' HLO shardings for + // SPMD. + bool has_spmd_parameters_shardings() const { + return spmd_parameters_shardings_.has_value(); + } + + // Getter and setter for the list of entry parameters' HLO shardings for SPMD. + const std::vector& spmd_parameters_shardings() const { + CHECK(spmd_parameters_shardings_.has_value()); + return *spmd_parameters_shardings_; + } + void set_spmd_parameters_shardings( + const std::vector& shardings) { + spmd_parameters_shardings_ = shardings; + } + + // Checks if this config has the entry computation output's HLO sharding for + // SPMD. + bool has_spmd_output_sharding() const { + return spmd_output_sharding_.has_value(); + } + + // Getter and setter for the entry computation output's HLO shardings for + // SPMD. 
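A hedged usage sketch for the new optional SPMD sharding fields (assuming, as this hunk suggests, that the accessors live on HloModule): callers are expected to test the has_*() predicates before calling the CHECK-ing getters; the getter for the output sharding appears just below. The LogSpmdShardings helper is illustrative only.

```cpp
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/core/platform/logging.h"

// Illustrative only: read the optional SPMD shardings, guarding each access
// with its has_*() predicate since the getters CHECK on absence.
void LogSpmdShardings(const xla::HloModule& module) {
  if (module.has_spmd_parameters_shardings()) {
    for (const xla::HloSharding& sharding :
         module.spmd_parameters_shardings()) {
      LOG(INFO) << "SPMD parameter sharding: " << sharding.ToString();
    }
  }
  if (module.has_spmd_output_sharding()) {
    LOG(INFO) << "SPMD output sharding: "
              << module.spmd_output_sharding().ToString();
  }
}
```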
+ const HloSharding& spmd_output_sharding() const { + CHECK(spmd_output_sharding_.has_value()); + return *spmd_output_sharding_; + } + void set_spmd_output_sharding(const HloSharding& sharding) { + spmd_output_sharding_ = sharding; + } + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -342,6 +372,14 @@ class HloModule { // Fusion configuration. std::vector> fusion_config_; + + // The HLO shardings of the entry computation's parameters for + // SPMD-partitioned programs. + absl::optional> spmd_parameters_shardings_; + + // The HLO sharding of the entry computation's output (root) for + // SPMD-partitioned programs. + absl::optional spmd_output_sharding_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index d8ded5f7641..de4df445ac5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -34,6 +34,26 @@ namespace xla { // executable. class HloModuleConfig { public: + // Represents a pair of input and output of the entry computation that can be + // considered as the original and updated values of a variable maintained by + // the caller, and that can be transparently sharded by XLA as an internal + // optimization. If sharded, XLA will create separate sharding/unsharding + // programs, and the caller is responsible to call the XLA-generated + // sharding/unsharding programs before and after the sharded main program. + // + // The sharding/unsharding programs will include all the input/output pairs in + // shardable_value_update_pairs() as a flat tuple in their inputs/outputs, + // sorted by (input_parameter_number, parameter_shape_index). + // + // A typical usage pattern is to shard the variables first, then repeatedly + // invoke the main program, and finally invoke the unsharding program before + // they are used in full-shape. + struct ShardableValueUpdatePair { + int64 input_parameter_number; + ShapeIndex parameter_shape_index; + ShapeIndex output_shape_index; + }; + // A configuration can be created either with, or without an entry // ComputationLayout. The default ctor creates it without -- in this case // accessing entry_computation_layout will CHECK-fail. The ctor accepting a @@ -118,6 +138,15 @@ class HloModuleConfig { static_device_assignment_ = device_assignment; } + const std::vector shardable_value_update_pairs() + const { + return shardable_value_update_pairs_; + } + void set_shardable_value_update_pairs( + std::vector pairs) { + shardable_value_update_pairs_ = std::move(pairs); + } + private: // If you add new members, be sure to update compilation_cache_key. @@ -137,6 +166,8 @@ class HloModuleConfig { // Compile-time known device assignment. absl::optional static_device_assignment_; + + std::vector shardable_value_update_pairs_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 2589de633d0..c96bfb15187 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -88,6 +88,7 @@ class HloParser { // Stand alone parsing utils for various aggregate data types. 
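To make the intended use of ShardableValueUpdatePair concrete, here is a hedged sketch of registering one pair on a config. The parameter number and output tuple index are invented for illustration and not taken from this change.

```cpp
#include "tensorflow/compiler/xla/service/hlo_module_config.h"

namespace xla {

// Illustrative only: declare that entry parameter 1 (the whole parameter,
// empty shape index) is a caller-maintained variable whose updated value is
// produced at output tuple index {0}, so XLA may keep it sharded between
// invocations of the main program.
void MarkShardableVariable(HloModuleConfig* config) {
  HloModuleConfig::ShardableValueUpdatePair pair;
  pair.input_parameter_number = 1;  // invented parameter number
  pair.parameter_shape_index = {};  // whole parameter (already the default)
  pair.output_shape_index = {0};    // invented output tuple index
  config->set_shardable_value_update_pairs({pair});
}

}  // namespace xla
```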
StatusOr ParseShapeOnly(); StatusOr ParseShardingOnly(); + StatusOr ParseFrontendAttributesOnly(); StatusOr> ParseParameterReplicationOnly(); StatusOr ParseWindowOnly(); StatusOr ParseConvolutionDimensionNumbersOnly(); @@ -192,6 +193,7 @@ class HloParser { kWindow, kConvolutionDimensionNumbers, kSharding, + kFrontendAttributes, kParameterReplication, kInstructionList, kSliceRanges, @@ -271,6 +273,7 @@ class HloParser { bool ParsePaddingConfig(PaddingConfig* padding); bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); + bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); bool ParseParameterReplication(ParameterReplication* parameter_replication); bool ParseReplicaGroupsOnly(std::vector* replica_groups); @@ -677,7 +680,10 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, // Add optional attributes. std::unordered_map attrs; optional sharding; + optional frontend_attributes; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; + attrs["frontend_attributes"] = { + /*required=*/false, AttrTy::kFrontendAttributes, &frontend_attributes}; optional parameter_replication; attrs["parameter_replication"] = {/*required=*/false, AttrTy::kParameterReplication, @@ -1678,6 +1684,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional> slice_sizes; attrs["slice_sizes"] = {/*required=*/true, AttrTy::kBracedInt64List, &slice_sizes}; + optional indices_are_sorted = false; + attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool, + &indices_are_sorted}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { @@ -1693,7 +1702,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, instruction = builder->AddInstruction(HloInstruction::CreateGather( shape, /*operand=*/operands[0], /*start_indices=*/operands[1], - dim_numbers, *slice_sizes)); + dim_numbers, *slice_sizes, indices_are_sorted.value())); break; } case HloOpcode::kScatter: { @@ -1714,6 +1723,9 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, optional update_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &update_computation}; + optional indices_are_sorted = false; + attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool, + &indices_are_sorted}; if (!ParseOperands(&operands, /*expected_size=*/3) || !ParseAttributes(attrs)) { @@ -1729,7 +1741,8 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, instruction = builder->AddInstruction(HloInstruction::CreateScatter( shape, /*operand=*/operands[0], /*scatter_indices=*/operands[1], - /*updates=*/operands[2], *update_computation, dim_numbers)); + /*updates=*/operands[2], *update_computation, dim_numbers, + indices_are_sorted.value())); break; } case HloOpcode::kDomain: { @@ -1838,6 +1851,36 @@ bool HloParser::ParseSharding(OpSharding* sharding) { return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute"); } +// frontend_attributes ::= '{' attributes '}' +// attributes +// ::= /*empty*/ +// ::= attribute '=' value (',' attribute '=' value)* +bool HloParser::ParseFrontendAttributes( + FrontendAttributes* frontend_attributes) { + CHECK(frontend_attributes != nullptr); + if (!ParseToken(TokKind::kLbrace, + "expected '{' to start frontend attributes")) { + return false; + } + if (lexer_.GetKind() == TokKind::kRbrace) { + // empty + } else { + do { + string attribute; + if 
(!ParseAttributeName(&attribute)) { + return false; + } + if (lexer_.GetKind() != TokKind::kIdent) { + return false; + } + (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal(); + lexer_.Lex(); + } while (EatIfPresent(TokKind::kComma)); + } + return ParseToken(TokKind::kRbrace, + "expects '}' at the end of frontend attributes"); +} + // ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? // ('devices=' ('[' dims ']')* device_list)? '}' // dims ::= int_list device_list ::= int_list @@ -2857,6 +2900,15 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(sharding); return true; } + case AttrTy::kFrontendAttributes: { + FrontendAttributes frontend_attributes; + if (!ParseFrontendAttributes(&frontend_attributes)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(frontend_attributes); + return true; + } case AttrTy::kParameterReplication: { ParameterReplication parameter_replication; if (!ParseParameterReplication(¶meter_replication)) { @@ -4113,6 +4165,19 @@ StatusOr HloParser::ParseShardingOnly() { return HloSharding::FromProto(op_sharding); } +StatusOr HloParser::ParseFrontendAttributesOnly() { + lexer_.Lex(); + FrontendAttributes attributes; + if (!ParseFrontendAttributes(&attributes)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after frontend attributes"); + } + return attributes; +} + StatusOr> HloParser::ParseParameterReplicationOnly() { lexer_.Lex(); ParameterReplication parameter_replication; @@ -4261,6 +4326,11 @@ StatusOr ParseSharding(absl::string_view str) { return parser.ParseShardingOnly(); } +StatusOr ParseFrontendAttributes(absl::string_view str) { + HloParser parser(str); + return parser.ParseFrontendAttributesOnly(); +} + StatusOr> ParseParameterReplication(absl::string_view str) { HloParser parser(str); return parser.ParseParameterReplicationOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index e4214c1e6b5..91ce79ec982 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -54,6 +54,12 @@ Status ParseHloString(absl::string_view str, HloModule* module); // "{replicated}". StatusOr ParseSharding(absl::string_view str); +// Parses frontend attributes from str. str is supposed to contain the body of +// the frontend attributes , i.e. just the rhs of the +// "frontend_attributes={...}" attribute string, e.g., +// "{attr_a=a,attr_b=b}". +StatusOr ParseFrontendAttributes(absl::string_view str); + // Parses parameter replication from str. str is supposed to contain the body of // the parameter replication, i.e. just the rhs of the // "parameter_replication={...}" attribute string, e.g., "{true, false}". 
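As a usage note for the new public entry point, the round trip below mirrors the ParseFrontendAttributes test added in hlo_parser_test.cc further down; the wrapper function itself is illustrative. Keys are emitted in sorted order by FrontendAttributesToString, so an already-sorted input string survives the round trip unchanged.

```cpp
#include "tensorflow/compiler/xla/service/hlo_instruction.h"  // FrontendAttributesToString
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

// Illustrative round trip: parse the textual frontend_attributes body and
// stringify it back.
void RoundTripFrontendAttributes() {
  StatusOr<FrontendAttributes> parsed =
      ParseFrontendAttributes("{attr_a=test_a,attr_b=b}");
  CHECK(parsed.ok());
  CHECK_EQ(FrontendAttributesToString(parsed.ValueOrDie()),
           "{attr_a=test_a,attr_b=b}");
}

}  // namespace xla
```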
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index b9a017ada43..c913784cd13 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -875,7 +875,7 @@ ENTRY %sparse_f32_r1 () -> f32[9] { )" }, { -"gather", +"Gather", R"(HloModule StringifyGather ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { @@ -887,7 +887,19 @@ ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5] )" }, { -"scatter", +"SortedGather", +R"(HloModule StringifyGather + +ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}, indices_are_sorted=true +} + +)" +}, +{ +"Scatter", R"(HloModule StringifyScatter %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { @@ -903,6 +915,25 @@ ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7 ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3 } +)" +}, +{ +"SortedScatter", +R"(HloModule StringifySortedScatter + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2) + ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, indices_are_sorted=true, to_apply=%add_F32.v3 +} + )" }, { @@ -2327,6 +2358,13 @@ TEST_F(HloParserTest, ParseSharding) { EXPECT_EQ(sharding.ToString(), original); } +TEST_F(HloParserTest, ParseFrontendAttributes) { + const string original = "{attr_a=test_a,attr_b=b}"; + TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes, + ParseFrontendAttributes(original)); + EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original); +} + TEST_F(HloParserTest, ParseWindow) { Window original = window_util::MakeWindow({1, 2, 3}); TF_ASSERT_OK_AND_ASSIGN(Window parsed, diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 603371d830f..445a3ea97d2 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ 
-100,6 +100,17 @@ bool CanBeRematerialized( using BufferId = int64; using BufferIdList = absl::InlinedVector; +struct RematStrategy { + enum { + // Recompute the node at a later program point. + kRecompute, + // Change the layout into a compact form and uncompress it back at a later + // program point. + kCompress, + } kind; + Shape compact_shape; +}; + // We wrap HloInstruction* with an Item that holds auxiliary // per-instruction state. struct Item { @@ -117,6 +128,10 @@ struct Item { // The buffers defined by this instruction. BufferIdList buffers_defined; + // Output buffers of this instruction. This is used to track outputs by GTE + // instructions (where the instruction doesn't define a buffer). + BufferIdList buffers_output; + // The buffers used by this instruction. BufferIdList buffers_used; @@ -251,6 +266,32 @@ class InstructionList { return InsertBefore(to_insert, min_position_item); } + void InsertAfterInstructions(Item* to_insert, + absl::Span after_instructions) { + VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name() + << " after {" + << absl::StrJoin(after_instructions, ", ", + [](string* out, Item* item) { + absl::StrAppend(out, item->instruction->name()); + }) + << "}"; + + // Find the max position number of any instruction in + // 'after_instructions'. + CHECK(!after_instructions.empty()); + Item* max_position_item = nullptr; + for (Item* item : after_instructions) { + if (max_position_item == nullptr || + item->position > max_position_item->position) { + max_position_item = item; + } + } + // No rematerializable instruction should be inserted at the end of the + // computation. + CHECK(max_position_item->next != nullptr); + InsertBeforeInstructions(to_insert, {max_position_item->next}); + } + void Blacklist(const HloInstruction* inst) { GetItem(inst)->blacklisted = true; } @@ -327,6 +368,7 @@ class MemoryUsageTracker { MemoryUsageTracker( const HloComputation* computation, const HloRematerialization::ShapeSizeFunction& size_function, + const HloRematerialization::CompactShapeFunction& compact_shape_function, const TuplePointsToAnalysis& points_to_analysis, const InstructionList& instruction_list); @@ -338,6 +380,22 @@ class MemoryUsageTracker { // EndInstruction memory for dead operand(s) is freed. Status BeginInstruction(Item* item); + int64 RematerializationCost(const HloInstruction* instruction, + int64 memory_reduced, int64 memory_limit_bytes) { + // If none of the users of 'instruction' have been placed in the sequence + // (as tracked by memory_tracker), then rematerialization of 'instruction' + // is a zero-cost move of 'instruction' in the sequence. + if (!absl::c_any_of( + instruction->users(), + [this](const HloInstruction* inst) { return IsPlaced(inst); })) { + return 0; + } + + CHECK_GT(memory_reduced, 0); + // Return the inverse of the benefit of rematerialization. + return memory_limit_bytes / memory_reduced; + } + // Finishes the placement of the current instruction. This frees any dead // operands or dead result of the instruction. This must be called after // each call to BeginInstruction. @@ -347,17 +405,28 @@ class MemoryUsageTracker { // if the given instruction is rematerialized. int64 MemoryReducedIfRematerialized(Item* item) const; + // Returns the number of bytes that the current memory usage will be reduced + // if the given instruction is compact. 
+ int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const; + // Returns the number of bytes that the current memory usage will be reduced // by if the given sequence of instructions is rematerialized. int64 MemoryReducedIfRematerialized(const absl::Span& items) const; + Status AddCompressInstructions(Item* original_item, Item* compressed_item, + Item* uncompressed_item); + // Adjusts memory usage to account for the rematerialization of // original_item for all remaining unplaced uses. The rematerialization // is remat_item. This method should be called after the HLO graph has - // been transformed (rematerialization instruction created and connected to - // uses). + // been transformed (rematerialization instruction created and connected + // to uses). Status AddRematerializedInstruction(Item* original_item, Item* remat_item); + std::pair PickRematerializationCandidate( + const InstructionList& instruction_list, int64 memory_limit_bytes, + absl::flat_hash_map* remat_able); + // Returns whether the given instruction has been placed (BeginInstruction // has been called with 'instruction' as the argument). bool IsPlaced(const HloInstruction* instruction) const { @@ -390,6 +459,9 @@ class MemoryUsageTracker { // The materialized size of the buffer in bytes. const int64 size; + // Shape of the buffer. + Shape shape; + // Whether this buffer is live-out of the computation. bool live_out; @@ -412,19 +484,21 @@ class MemoryUsageTracker { } }; + // Get the compact shape of given hlo instruction. An internal cache is used + // to avoid computing the shape multiple times. + StatusOr GetCompactShape(const HloInstruction* hlo); + // Creates a Buffer representing the given logical buffer. The buffer is added // to buffers_ and a reference is returned. Buffer& CreateBufferFromLogicalBuffer( const LogicalBuffer* logical_buffer, - const TuplePointsToAnalysis& points_to_analysis, - const HloRematerialization::ShapeSizeFunction& size_function, - bool live_out) { + const TuplePointsToAnalysis& points_to_analysis, bool live_out) { bool has_indirect_uses = false; ItemList users = GetUsers(instruction_list_, logical_buffer, points_to_analysis, &has_indirect_uses); return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()), - size_function(logical_buffer->shape()), std::move(users), - live_out, has_indirect_uses); + logical_buffer->shape(), std::move(users), live_out, + has_indirect_uses); } // Create a new buffer representing a rematerialization of given buffer for @@ -438,7 +512,7 @@ class MemoryUsageTracker { for (Item* use : rematerialized_uses) { CHECK(!use->placed) << use->instruction->name(); } - return NewBuffer(remat_item, original_buffer.size, + return NewBuffer(remat_item, original_buffer.shape, std::move(rematerialized_uses), /*live_out=*/false, /*has_indirect_uses=*/false); } @@ -449,7 +523,8 @@ class MemoryUsageTracker { // different computation. int64 AllocatedSize(BufferId buffer_id) const { const Buffer& buffer = buffers_.at(buffer_id); - HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode(); + HloInstruction* inst = buffer.defining_instruction->instruction; + HloOpcode def_opcode = inst->opcode(); if (buffer.live_out || def_opcode == HloOpcode::kParameter) { return 0; } else { @@ -473,7 +548,7 @@ class MemoryUsageTracker { return absl::c_linear_search(in_progress_uses, buffer_id); } - // Returns whether the given instruction is live at the current program + // Returns whether the given buffer is live at the current program // point. 
bool IsCurrentlyLive(BufferId buffer_id) const { const Buffer& buffer = buffers_[buffer_id]; @@ -481,13 +556,30 @@ buffer.unfinished_user_count > 0); } + // Returns whether the given instruction is live at the current program + // point. + bool IsInstructionCurrentlyLive(Item* instruction) const { + // If the instruction has not started yet, it is not alive. + if (!IsPlaced(instruction->instruction)) { + return false; + } + for (const HloInstruction* user : instruction->instruction->users()) { + if (!IsPlaced(user)) { + // If there is an unplaced user, consider this instruction currently + // live. + return true; + } + } + return false; + } + // Create a new buffer, add it to buffers_, and return a reference. - Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users, - bool live_out, bool has_indirect_uses) { + Buffer& NewBuffer(Item* defining_instruction, const Shape& shape, + ItemList&& users, bool live_out, bool has_indirect_uses) { int buffer_id = buffers_.size(); - buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out, - has_indirect_uses, users, - static_cast<int64>(users.size())}); + buffers_.push_back(Buffer{ + buffer_id, defining_instruction, size_function_(shape), shape, live_out, + has_indirect_uses, users, static_cast<int64>(users.size())}); return buffers_.back(); } @@ -498,6 +590,16 @@ class MemoryUsageTracker { // (BeginInstruction/EndInstruction calls). const InstructionList& instruction_list_; + // Size function that returns the size in bytes of a given buffer. + const HloRematerialization::ShapeSizeFunction& size_function_; + + // Converts a shape into compact form; returns the same shape if the shape is + // already considered compact. + const HloRematerialization::CompactShapeFunction& compact_shape_function_; + + // A map that caches the known compact shape for each instruction. + absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_; + // Memory usage at the currently placed instruction. int64 memory_usage_ = 0; @@ -512,9 +614,13 @@ MemoryUsageTracker::MemoryUsageTracker( const HloComputation* computation, const HloRematerialization::ShapeSizeFunction& size_function, + const HloRematerialization::CompactShapeFunction& compact_shape_function, const TuplePointsToAnalysis& points_to_analysis, const InstructionList& instruction_list) - : computation_(computation), instruction_list_(instruction_list) { + : computation_(computation), + instruction_list_(instruction_list), + size_function_(size_function), + compact_shape_function_(compact_shape_function) { PointsToSet::BufferSet live_out_set = points_to_analysis.GetPointsToSet(computation_->root_instruction()) .CreateFlattenedSet(); @@ -556,7 +662,7 @@ MemoryUsageTracker::MemoryUsageTracker( } } else { buffer = &CreateBufferFromLogicalBuffer( - logical_buffer, points_to_analysis, size_function, + logical_buffer, points_to_analysis, ContainsKey(live_out_set, logical_buffer)); item->buffers_defined.push_back(buffer->id); for (Item* user : buffer->users) { @@ -566,6 +672,14 @@ MemoryUsageTracker::MemoryUsageTracker( logical_buffer_to_buffer_id[logical_buffer] = buffer->id; } + + // Trace the output of each instruction. This is so that we can properly + // track which outputs GTE instructions have. 
+ for (const LogicalBuffer* logical_buffer : + points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) { + item->buffers_output.push_back( + logical_buffer_to_buffer_id[logical_buffer]); + } } XLA_VLOG_LINES(10, ToString()); DCHECK(Check()); @@ -611,7 +725,8 @@ Status MemoryUsageTracker::EndInstruction() { // Buffer is now dead. VLOG(3) << " " << buffer.ToString() << " is now dead."; memory_usage_ -= AllocatedSize(buffer_id); - CHECK_GE(memory_usage_, 0); + // The memory usage can become negative inside the computation as we can + // free up the parameter space and reuse it for other tensors. } } @@ -622,7 +737,8 @@ Status MemoryUsageTracker::EndInstruction() { if (buffer.unfinished_user_count == 0) { VLOG(3) << " " << buffer.ToString() << " is immediately dead."; memory_usage_ -= AllocatedSize(buffer_id); - CHECK_GE(memory_usage_, 0); + // The memory usage can become negative inside the computation as we can + // free up the parameter space and reuse it for other tensors. } } @@ -637,6 +753,30 @@ Status MemoryUsageTracker::EndInstruction() { return Status::OK(); } +int64 MemoryUsageTracker::MemoryReducedIfCompressed( + Item* item, const Shape& compact_shape) const { + CHECK_NE(in_progress_item_, nullptr); + if (!item->placed || item == in_progress_item_) { + return 0; + } + + int64 memory_reduced = 0; + + // We only compress a single piece of an output at one time. + CHECK_EQ(item->buffers_output.size(), 1); + BufferId buffer_id = item->buffers_output[0]; + if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id) && + IsInstructionCurrentlyLive(item)) { + const Buffer& buffer = buffers_.at(buffer_id); + memory_reduced += buffer.size; + + int64 compact_shape_size = size_function_(compact_shape); + // Account for buffers that are compressed after instruction. + memory_reduced -= compact_shape_size; + } + return memory_reduced; +} + int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const { CHECK_NE(in_progress_item_, nullptr); if (!item->placed || item == in_progress_item_) { @@ -736,6 +876,56 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized( return memory_reduced; } +Status MemoryUsageTracker::AddCompressInstructions(Item* original_item, + Item* compressed_item, + Item* uncompressed_item) { + // Original buffer is now dead. + memory_usage_ -= size_function_(original_item->instruction->shape()); + // Compressed buffer is now alive. 
+ memory_usage_ += size_function_(compressed_item->instruction->shape()); + + ItemList placed_users; + ItemList unplaced_users; + CHECK_EQ(original_item->buffers_output.size(), 1); + BufferId original_buffer_id = original_item->buffers_output[0]; + Buffer& original_buffer = buffers_.at(original_buffer_id); + for (Item* user : original_buffer.users) { + if (user->placed) { + CHECK(IsFinished(user)) << user->instruction->name(); + placed_users.push_back(user); + } else { + unplaced_users.push_back(user); + } + } + original_buffer.users = std::move(placed_users); + original_buffer.unfinished_user_count = 0; + original_buffer.users.push_back(compressed_item); + Buffer& compressed_buffer = + NewBuffer(compressed_item, compressed_item->instruction->shape(), + {uncompressed_item}, /*live_out=*/false, + /*has_indirect_uses=*/false); + compressed_item->buffers_used = original_item->buffers_output; + compressed_item->buffers_output = {compressed_buffer.id}; + compressed_item->buffers_defined.push_back(compressed_buffer.id); + + Buffer& uncompressed_buffer = + NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(), + std::move(unplaced_users), /*live_out=*/false, + /*has_indirect_uses=*/false); + + uncompressed_item->buffers_used = {compressed_item->buffers_output[0]}; + uncompressed_item->buffers_output = {uncompressed_buffer.id}; + uncompressed_item->buffers_defined = {uncompressed_buffer.id}; + + for (Item* user : uncompressed_buffer.users) { + BufferIdList& buffers_used = user->buffers_used; + std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id, + uncompressed_buffer.id); + } + + return Status::OK(); +} + Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, Item* remat_item) { VLOG(3) << "AddRematerializedInstruction: original_instruction = " @@ -831,6 +1021,17 @@ string MemoryUsageTracker::ToString() const { return output; } +StatusOr MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) { + auto it = compact_shape_.find(hlo); + if (it != compact_shape_.end()) { + return it->second; + } + const Shape& original_shape = hlo->shape(); + TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape)); + compact_shape_[hlo] = min_shape; + return min_shape; +} + bool MemoryUsageTracker::Check() const { auto elements_are_unique = [](const BufferIdList& vec) { return vec.size() == std::set(vec.begin(), vec.end()).size(); @@ -917,12 +1118,15 @@ int64 RematerializationCost(const HloInstruction* instruction, // candidate which reduce memory use at the program point of the current // instruction as indicated by memory_tracker. nullptr is returned if no // candidate can be found. -Item* PickRematerializationCandidate( - const MemoryUsageTracker& memory_tracker, +std::pair +MemoryUsageTracker::PickRematerializationCandidate( const InstructionList& instruction_list, int64 memory_limit_bytes, absl::flat_hash_map* remat_able) { Item* best_item = nullptr; int64 best_cost = 0; + RematStrategy best_strategy; + + VLOG(5) << "Picking candidate"; // TODO(b/35244891): This is currently quadratic in the number of HLO // instructions. @@ -947,44 +1151,215 @@ Item* PickRematerializationCandidate( if (!CanBeRematerialized(candidate, remat_able)) { VLOG(5) << "candidate " << candidate->name() << " not viable: is not rematerializable"; + continue; } - // If any of the candidate's control successor has been placed, we need to - // skip this candidate. Otherwise we will violate control dependency. 
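The kCompress path above trades a full-size buffer for a copy into a compact layout plus a copy back, and PickRematerializationCandidate scores it as memory_limit_bytes / memory_reduced, where memory_reduced is the byte size of the original shape minus that of the compact shape. A minimal sketch of that arithmetic follows, assuming a padded size function like the ShapeSizePadMinorTo64 helper used by the tests later in this patch; PaddedBytesF32 and the concrete dimensions are illustrative only, not the pass's API.

```c++
// Sketch of the saving MemoryReducedIfCompressed credits for one buffer.
#include <cstdint>
#include <iostream>

// Bytes of an f32[d0,d1]{1,0} buffer when the most-minor dimension is padded
// up to a multiple of 64, mirroring ShapeSizePadMinorTo64 in the tests.
int64_t PaddedBytesF32(int64_t d0, int64_t d1) {
  int64_t padded_minor = ((d1 + 63) / 64) * 64;
  return 4 * d0 * padded_minor;
}

int main() {
  // f32[64,2]{1,0}: padding the minor dimension 2 up to 64 wastes most bytes.
  int64_t original = PaddedBytesF32(64, 2);     // 64 * 64 * 4 = 16384
  // Compact form swaps the two minor dimensions: effectively f32[2,64]{1,0}.
  int64_t compact = PaddedBytesF32(2, 64);      // 2 * 64 * 4 = 512
  int64_t memory_reduced = original - compact;  // what the tracker credits
  int64_t memory_limit_bytes = 30 * 1024;
  // Lower cost == better candidate, as in PickRematerializationCandidate.
  int64_t cost = memory_limit_bytes / memory_reduced;
  std::cout << "saved " << memory_reduced << " bytes, cost " << cost << "\n";
  return 0;
}
```

Under these sizes, a live pair of f32[64,2] buffers already exceeds the 30 * 1024 limit, which is consistent with the SingleRemat test below expecting reduce.1 to read broadcast.0 through a copy/copy pair.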
- bool control_successor_placed = - std::any_of(candidate->control_successors().begin(), - candidate->control_successors().end(), - [&memory_tracker](const HloInstruction* inst) { - return memory_tracker.IsPlaced(inst); - }); + if (item->buffers_output.size() == 1) { + // Only consider compressing single output instruction. + const Buffer& output_buffer = buffers_.at(item->buffers_output[0]); + + if (item->placed && item != in_progress_item_ && + !output_buffer.live_out) { + const Shape& original_shape = item->instruction->shape(); + if (original_shape.IsArray()) { + Shape compact_shape = GetCompactShape(item->instruction).ValueOrDie(); + const int64 memory_reduced = + MemoryReducedIfCompressed(item, compact_shape); + if (memory_reduced > 0) { + const int64 cost = memory_limit_bytes / memory_reduced; + if (best_item == nullptr || cost < best_cost) { + VLOG(3) << "candidate " << candidate->name() << "(" + << candidate->ToShortString() << ")" + << " now best when compressed into " + << compact_shape.ToString(true); + RematStrategy strategy; + strategy.kind = RematStrategy::kCompress; + best_strategy = strategy; + best_strategy.compact_shape = compact_shape; + best_item = item; + best_cost = cost; + } + } + } + } + } + + // If any of the candidate's control successor has been placed, we need + // to skip this candidate. Otherwise we will violate control dependency. + bool control_successor_placed = std::any_of( + candidate->control_successors().begin(), + candidate->control_successors().end(), + [this](const HloInstruction* inst) { return IsPlaced(inst); }); if (control_successor_placed) { continue; } - const int64 memory_reduced = - memory_tracker.MemoryReducedIfRematerialized(item); + const int64 memory_reduced = MemoryReducedIfRematerialized(item); - if (memory_reduced <= 0) { - VLOG(5) << "candidate " << candidate->name() - << " memory reduced = " << memory_reduced << " <= 0"; - continue; - } + if (memory_reduced > 0) { + const int cost = + RematerializationCost(candidate, memory_reduced, memory_limit_bytes); - const int cost = RematerializationCost(candidate, memory_tracker, - memory_reduced, memory_limit_bytes); + VLOG(5) << "candidate " << candidate->name() << ", memory reduced " + << memory_reduced << ", cost per byte " << cost; - VLOG(5) << "candidate " << candidate->name() << ", memory reduced " - << memory_reduced << ", cost per byte " << cost; - - if (best_item == nullptr || cost < best_cost) { - VLOG(5) << "candidate " << candidate->name() << " now best"; - best_item = item; - best_cost = cost; + if (best_item == nullptr || cost < best_cost) { + VLOG(5) << "candidate " << candidate->name() << " now best"; + best_strategy.kind = RematStrategy::kRecompute; + best_item = item; + best_cost = cost; + } } } - return best_item; + return {best_item, best_strategy}; +} + +StatusOr RematerializeInstruction( + MemoryUsageTracker* memory_tracker, Item* best_item, + absl::flat_hash_set* remat_move_instructions, + InstructionList* instruction_list) { + HloInstruction* best = best_item->instruction; + VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " + << HumanReadableNumBytes( + memory_tracker->MemoryReducedIfRematerialized(best_item)) + << ")"; + + int64 net_instructions_added = 0; + + HloComputation* computation = best->parent(); + + HloInstruction* remat = + computation->AddInstruction(best->Clone(/*suffix=*/"remat")); + + // Add control dependencies to the new operation. 
+ for (auto successor : best->control_successors()) { + TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor)); + } + for (auto predecessor : best->control_predecessors()) { + TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat)); + } + + Item* remat_item = instruction_list->CreateItem(remat); + + // Replace each remaining use of 'best' with the rematerialization. + std::vector best_users_copy = best->users(); + for (HloInstruction* user : best_users_copy) { + if (!memory_tracker->IsPlaced(user)) { + VLOG(2) << " Replacing use of " << best->name() << " in " << user->name() + << " with " << remat->name(); + TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat)); + } + } + + // Account for the rematerialization in the memory tracker. + TF_RETURN_IF_ERROR( + memory_tracker->AddRematerializedInstruction(best_item, remat_item)); + + // Insert rematerialized instruction right before the earliest unplaced + // use of the instruction *and* the earliest unplaced last use of any + // operands of remat. Unplaced uses of the remat's operands are included + // because we don't want to extend the live range of remat's operands as + // this could increase memory usage. + ItemList place_before; + for (auto user : remat->users()) { + place_before.push_back(instruction_list->GetItem(user)); + } + for (auto* operand : remat->operands()) { + for (auto* operand_user : operand->users()) { + if (operand_user != remat) { + Item* operand_user_item = instruction_list->GetItem(operand_user); + if (!operand_user_item->placed) { + place_before.push_back(operand_user_item); + } + } + } + } + // Insert rematerialized instruction before any of its successors to + // preserve ordering regarding control dependency. + for (auto successor : remat->control_successors()) { + Item* successor_item = instruction_list->GetItem(successor); + // Assert to make sure we never remat an operation with control + // successor already placed. + CHECK(!successor_item->placed) << successor_item->instruction->name(); + place_before.push_back(successor_item); + } + instruction_list->InsertBeforeInstructions(remat_item, place_before); + + // If the rematerialized instruction is dead then rematerialization is + // essentially a move. Don't delete the instruction now because we don't + // want duplicate HloInstruction* values during the course of the + // transformation because we keep maps with HloInstruction* values as + // keys. + if (best->users().empty()) { + VLOG(2) << best->name() << " is now dead"; + if (ContainsKey(*remat_move_instructions, best)) { + // Previously, 'best' was a rematerialization which killed the + // instruction it was a copying of. Now 'remat' is a rematerialization + // of 'best' and kills 'best'. Stop rematerializing this instruction + // to avoid an infinite loop. 
+ instruction_list->Blacklist(remat); + } + remat_move_instructions->insert(remat); + + } else { + net_instructions_added++; + } + return net_instructions_added; +} + +StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, + Item* best_item, const Shape& compact_shape, + InstructionList* instruction_list) { + HloInstruction* best = best_item->instruction; + VLOG(5) << "Transposing instruction " << best->name() << " (saving " + << HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed( + best_item, compact_shape)) + << ") to" << compact_shape.ToString(true); + + HloComputation* computation = best->parent(); + + HloInstruction* compressed = computation->AddInstruction( + HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best)); + + HloInstruction* uncompressed = computation->AddInstruction( + HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed)); + + Item* compressed_item = instruction_list->CreateItem(compressed); + compressed_item->placed = true; + + Item* uncompressed_item = instruction_list->CreateItem(uncompressed); + + // Replace each remaining use of 'best' with the uncompressed. + std::vector best_users_copy = best->users(); + for (HloInstruction* user : best_users_copy) { + if (!memory_tracker->IsPlaced(user)) { + VLOG(5) << " Replacing use of " << best->name() << " in " << user->name() + << " with " << uncompressed->name(); + TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed)); + } + } + + // Account for the rematerialization in the memory tracker. + TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions( + best_item, compressed_item, uncompressed_item)); + + // Insert rematerialized instruction right before the earliest unplaced + // use of the instruction. + ItemList place_before; + for (auto user : uncompressed->users()) { + place_before.push_back(instruction_list->GetItem(user)); + } + + instruction_list->Blacklist(compressed_item->instruction); + instruction_list->Blacklist(uncompressed_item->instruction); + + instruction_list->InsertBeforeInstructions(uncompressed_item, place_before); + + instruction_list->InsertAfterInstructions(compressed_item, {best_item}); + + return 2; } } // namespace @@ -993,7 +1368,8 @@ StatusOr HloRematerialization::ComputePeakMemory( const HloComputation* computation, const HloInstructionSequence& order) const { InstructionList instruction_list(order); - MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_, + MemoryUsageTracker tracker(computation, size_function_, + compact_shape_function_, *points_to_analysis_, instruction_list); int64 peak_memory = tracker.memory_usage(); for (auto* item = instruction_list.first(); item != nullptr; @@ -1037,6 +1413,7 @@ StatusOr HloRematerialization::RematerializeComputation( InstructionList instruction_list(schedule->sequence(computation)); MemoryUsageTracker memory_tracker(computation, size_function_, + compact_shape_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -1086,8 +1463,11 @@ StatusOr HloRematerialization::RematerializeComputation( callee_usage) << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); - Item* best_item = PickRematerializationCandidate( - memory_tracker, instruction_list, memory_limit_bytes, &remat_able); + Item* best_item; + RematStrategy best_strategy; + std::tie(best_item, best_strategy) = + memory_tracker.PickRematerializationCandidate( + instruction_list, memory_limit_bytes, &remat_able); if (best_item == nullptr) { VLOG(3) << "Unable to find rematerialization 
candidate at program " @@ -1099,88 +1479,33 @@ StatusOr HloRematerialization::RematerializeComputation( } HloInstruction* best = best_item->instruction; - VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " - << HumanReadableNumBytes( - memory_tracker.MemoryReducedIfRematerialized(best_item)) - << ")"; changed = true; remat_count++; - HloInstruction* remat = - computation->AddInstruction(best->Clone(/*suffix=*/"remat")); + int64 added_instruction = 0; + if (best_strategy.kind == RematStrategy::kCompress) { + VLOG(1) << "Compressing instruction " << best->name() << " (saving " + << HumanReadableNumBytes( + memory_tracker.MemoryReducedIfCompressed( + best_item, best_strategy.compact_shape)) + << ")"; - // Add control dependencies to the new operation. - for (auto successor : best->control_successors()) { - TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor)); - } - for (auto predecessor : best->control_predecessors()) { - TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat)); - } - - Item* remat_item = instruction_list.CreateItem(remat); - - // Replace each remaining use of 'best' with the rematerialization. - std::vector best_users_copy = best->users(); - for (HloInstruction* user : best_users_copy) { - if (!memory_tracker.IsPlaced(user)) { - VLOG(2) << " Replacing use of " << best->name() << " in " - << user->name() << " with " << remat->name(); - TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat)); - } - } - - // Account for the rematerialization in the memory tracker. - TF_RETURN_IF_ERROR( - memory_tracker.AddRematerializedInstruction(best_item, remat_item)); - - // Insert rematerialized instruction right before the earliest unplaced - // use of the instruction *and* the earliest unplaced last use of any - // operands of remat. Unplaced uses of the remat's operands are included - // because we don't want to extend the live range of remat's operands as - // this could increase memory usage. - ItemList place_before; - for (auto user : remat->users()) { - place_before.push_back(instruction_list.GetItem(user)); - } - for (auto* operand : remat->operands()) { - for (auto* operand_user : operand->users()) { - if (operand_user != remat) { - Item* operand_user_item = instruction_list.GetItem(operand_user); - if (!operand_user_item->placed) { - place_before.push_back(operand_user_item); - } - } - } - } - // Insert rematerialized instruction before any of its successors to - // preserve ordering regarding control dependency. - for (auto successor : remat->control_successors()) { - Item* successor_item = instruction_list.GetItem(successor); - // Assert to make sure we never remat an operation with control - // successor already placed. - CHECK(!successor_item->placed) << successor_item->instruction->name(); - place_before.push_back(successor_item); - } - instruction_list.InsertBeforeInstructions(remat_item, place_before); - - // If the rematerialized instruction is dead then rematerialization is - // essentially a move. Don't delete the instruction now because we don't - // want duplicate HloInstruction* values during the course of the - // transformation because we keep maps with HloInstruction* values as - // keys. - if (best->users().empty()) { - VLOG(2) << best->name() << " is now dead"; - if (ContainsKey(remat_move_instructions, best)) { - // Previously, 'best' was a rematerialization which killed the - // instruction it was a copying of. Now 'remat' is a rematerialization - // of 'best' and kills 'best'. 
Stop rematerializing this instruction - // to avoid an infinite loop. - instruction_list.Blacklist(remat); - } - remat_move_instructions.insert(remat); + TF_ASSIGN_OR_RETURN(added_instruction, + CompressInstruction(&memory_tracker, best_item, + best_strategy.compact_shape, + &instruction_list)); } else { - net_instructions_added++; + VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " + << HumanReadableNumBytes( + memory_tracker.MemoryReducedIfRematerialized(best_item)) + << ")"; + + TF_ASSIGN_OR_RETURN(added_instruction, + RematerializeInstruction(&memory_tracker, best_item, + &remat_move_instructions, + &instruction_list)); } + net_instructions_added += added_instruction; VLOG(1) << "memory_usage after rematerialization = " << HumanReadableNumBytes(memory_tracker.memory_usage()); @@ -1226,7 +1551,6 @@ StatusOr HloRematerialization::RematerializeComputation( } // Verify some invariants on the memory tracker. - CHECK_EQ(memory_tracker.memory_usage(), 0); for (auto* instruction : computation->instructions()) { CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name(); } @@ -1281,11 +1605,7 @@ StatusOr HloRematerialization::Run(HloModule* module) { module->result_shape(), [&module_output_size, module, this](const Shape& subshape, const ShapeIndex& output_index) { - if (!module->input_output_alias_config().OutputHasAlias(output_index)) { - // Only account for non-aliased outputs to avoid double counting a - // parameter buffer twice. - module_output_size += size_function_(subshape); - } + module_output_size += size_function_(subshape); }); const int64 adjusted_memory_limit_bytes = @@ -1361,7 +1681,7 @@ StatusOr HloRematerialization::Run(HloModule* module) { sizes_->after_bytes = current_peak_memory; } - XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); + XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString()); if (current_peak_memory > memory_limit_bytes_) { LOG(WARNING) << absl::StrFormat( diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 350cf0f8e8f..9ab34b4862d 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -24,6 +24,8 @@ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/statusor.h" namespace xla { @@ -38,6 +40,8 @@ class HloRematerialization : public HloModulePass { public: using ShapeSizeFunction = std::function; + using CompactShapeFunction = std::function(const Shape&)>; + // Helper struct that communicates the before / after sizes for the // rematerialization process. struct RematerializationSizes { @@ -45,23 +49,34 @@ class HloRematerialization : public HloModulePass { int64 after_bytes; }; + static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; } + // Constructor parameters: // // size_function: Function which returns the size in bytes of the top-level // buffer of the given shape. // // memory_limit_bytes: The threshold number of bytes to reduce memory use to - // via rematerialization. + // via rematerialization. Size of aliased outputs should be subtracted + // from this. // // sizes: Pointer to data structure which records the peak memory usage of // the HLO module before/after rematerialization. 
Values are set during // Run(). Can be nullptr. - HloRematerialization(const ShapeSizeFunction& size_function, - int64 memory_limit_bytes, RematerializationSizes* sizes) + // + // compact_shape_function: Function which returns the compact form of a + // shape. If nullptr is provided, a default identity function is used. + explicit HloRematerialization( + const ShapeSizeFunction& size_function, int64 memory_limit_bytes, + RematerializationSizes* sizes, + CompactShapeFunction compact_shape_function = nullptr) : size_function_(size_function), memory_limit_bytes_(memory_limit_bytes), - sizes_(sizes) {} - ~HloRematerialization() {} + sizes_(sizes), + compact_shape_function_(compact_shape_function == nullptr + ? DefaultCompactShapeFunction + : std::move(compact_shape_function)) {} + ~HloRematerialization() override = default; absl::string_view name() const override { return "rematerialization"; } @@ -108,6 +123,10 @@ class HloRematerialization : public HloModulePass { // module before/after rematerialization RematerializationSizes* sizes_; + // Converts a shape into compact form; returns the same shape if the shape is + // already considered compact. + const CompactShapeFunction compact_shape_function_; + // Call graph of the hlo_module. std::unique_ptr<CallGraph> call_graph_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 987177e40b8..dabd9d20f64 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -534,6 +533,142 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest, ::testing::Values(true, false)); +class CompressingRematerializationTest : public RematerializationTestBase { + protected: + // A special shape size function, which pads the most minor dimension to 64. + static int64 ShapeSizePadMinorTo64(const Shape& shape) { + if (shape.IsTuple()) { + // Size of a tuple is 4 bytes. + return 4; + } + Shape descending_shape = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape); + int64 size = + ShapeUtil::ByteSizeOfPrimitiveType(descending_shape.element_type()); + for (int64 i = 0; i < descending_shape.rank(); ++i) { + int64 dim = shape.dimensions(i); + if (i == descending_shape.rank() - 1) { + dim = RoundUpToNearest(dim, 64); + } + size *= dim; + } + return size; + } + + // Swap the two most-minor dimensions if the second-minor dimension is bigger + // than the most-minor dimension. 
+ static StatusOr ChooseCompactLayoutForShape(const Shape& shape) { + Shape result = shape; + Layout layout = result.layout(); + int64 most_minor_index = layout.minor_to_major()[0]; + int64 second_minor_index = layout.minor_to_major()[1]; + int64 most_minor = result.dimensions(most_minor_index); + int64 second_minor = result.dimensions(second_minor_index); + if (most_minor < second_minor) { + result.set_dimensions(most_minor_index, second_minor); + result.set_dimensions(second_minor_index, most_minor); + } + return result; + } + + StatusOr RunHloRematerialization(int64 memory_limit_bytes, + HloModule* module) { + TF_EXPECT_OK(verifier().Run(module).status()); + HloRematerialization remat(ShapeSizePadMinorTo64, memory_limit_bytes, + /*sizes=*/nullptr, ChooseCompactLayoutForShape); + return remat.Run(module); + } +}; + +// Test rematerialization of a single instruction. +TEST_F(CompressingRematerializationTest, SingleRemat) { + const string& hlo_string = R"( +HloModule fusion, is_scheduled=true + +%add_float { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY %entry { + %param.0 = f32[] parameter(0) + %constant = f32[] constant(0) + %broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={} + %negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0) + %reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %add = f32[] add(f32[] %reduce.0, f32[] %reduce.1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/30 * 1024, module.get())); + EXPECT_TRUE(changed); + HloInstruction* broadcast = + module->entry_computation()->GetInstructionWithName("broadcast.0"); + HloInstruction* reduce = + module->entry_computation()->GetInstructionWithName("reduce.1"); + EXPECT_THAT(reduce, + op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant())); +} + +TEST_F(CompressingRematerializationTest, AllUsersUseSameCopy) { + const string& hlo_string = R"( +HloModule fusion, is_scheduled=true + +%add_float { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY %entry { + %param.0 = f32[] parameter(0) + %constant = f32[] constant(0) + %broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={} + %negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0) + %reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %reduce.1 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %reduce.2 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %add = f32[] add(f32[] %reduce.0, f32[] %reduce.1) + %reduce.3 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float + %add.2 = f32[] add(f32[] %reduce.2, f32[] %reduce.3) + ROOT %tuple = (f32[], f32[]) tuple (f32[] add, f32[] add.2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloRematerialization( + /*memory_limit_bytes=*/30 * 1024, module.get())); + EXPECT_TRUE(changed); + + HloInstruction* broadcast = 
+ module->entry_computation()->GetInstructionWithName("broadcast.0"); + + // Both reduces reuse the same copy instruction. + HloInstruction* reduce_2 = + module->entry_computation()->GetInstructionWithName("reduce.2"); + + HloInstruction* reduce_3 = + module->entry_computation()->GetInstructionWithName("reduce.3"); + + EXPECT_THAT(reduce_2, + op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant())); + + EXPECT_THAT(reduce_3, + op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 154cf7fc44f..daeb5943fda 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -208,13 +208,13 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( ServiceExecutableRunOptions service_run_options = GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, nullptr, RunId()); + service_run_options.mutable_run_options()->set_execution_profile(profile); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, CreateExecutable(std::move(module), run_hlo_passes)); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer retval, - executable->ExecuteOnStreamWrapper(&service_run_options, - /*profile=*/profile, arguments)); + executable->ExecuteOnStreamWrapper(&service_run_options, arguments)); TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); return std::move(retval); } @@ -244,11 +244,11 @@ StatusOr HloRunner::ExecuteWithDeviceBuffers( ServiceExecutableRunOptions service_run_options = GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, nullptr, RunId()); + service_run_options.mutable_run_options()->set_execution_profile(profile); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer retval, - executable->ExecuteOnStreamWrapper(&service_run_options, - /*profile=*/profile, arguments)); + executable->ExecuteOnStreamWrapper(&service_run_options, arguments)); TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); return std::move(retval); } diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index ae7ccadbf97..1551870f734 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -1,5 +1,5 @@ load( - "//tensorflow/core:platform/default/build_config_root.bzl", + "//tensorflow/core/platform:default/build_config_root.bzl", "if_static", ) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 80a3ebccff1..85768225892 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -102,13 +102,6 @@ StatusOr> InterpreterCompiler::RunHloPasses( return std::move(hlo_module); } -Status InterpreterCompiler::RunHloPassesOnModuleGroup( - HloModuleGroup* module_group, - absl::Span executors, - se::DeviceMemoryAllocator* device_allocator) { - return Unimplemented("Module group compilation not supported on Interpreter"); -} - StatusOr> InterpreterCompiler::RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* /*device_allocator*/) { @@ -133,15 +126,6 @@ StatusOr> InterpreterCompiler::RunBackend( return std::move(executable); } -StatusOr>> -InterpreterCompiler::RunBackendOnModuleGroup( - std::unique_ptr module_group, - std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - return Unimplemented( - 
"Module group compilation is not supported on Interpreter."); -} - StatusOr>> InterpreterCompiler::Compile( std::unique_ptr module_group, std::vector> stream_exec, diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index dc83295b527..824594dfd84 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -46,19 +46,9 @@ class InterpreterCompiler : public Compiler { StatusOr> RunHloPasses( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) override; - Status RunHloPassesOnModuleGroup( - HloModuleGroup* module_group, - absl::Span executors, - se::DeviceMemoryAllocator* device_allocator) override; - StatusOr> RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) override; - StatusOr>> RunBackendOnModuleGroup( - std::unique_ptr module_group, - std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; - StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_exec, diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 167a013408b..0dab86d986c 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -45,7 +45,7 @@ InterpreterExecutable::InterpreterExecutable( InterpreterExecutable::~InterpreterExecutable() {} -StatusOr InterpreterExecutable::ExecuteOnStream( +StatusOr InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { @@ -113,22 +113,15 @@ StatusOr InterpreterExecutable::ExecuteOnStream( uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - { - tensorflow::mutex_lock lock(mutex_); + ExecutionProfile* profile = run_options->run_options().execution_profile(); + if (profile) { const double nanoseconds = (end_micros - start_micros) * 1000.0; - execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); + profile->set_compute_time_ns(std::max(nanoseconds, 1.0)); } return std::move(result); } -StatusOr InterpreterExecutable::ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - absl::Span arguments) { - return tensorflow::errors::Unimplemented( - "ExecuteAsyncOnStream is not yet supported on Interpreter."); -} - /*static*/ int64 InterpreterExecutable::ShapeSizeBytes(const Shape& shape) { if (shape.IsOpaque()) { return sizeof(void*); diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index bda13d37636..ba010de76bd 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -46,16 +46,12 @@ class InterpreterExecutable : public Executable { std::unique_ptr evaluator); ~InterpreterExecutable() override; - StatusOr ExecuteOnStream( + StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile) override LOCKS_EXCLUDED(evaluator_lock_); - StatusOr ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - absl::Span arguments) override; - static int64 ShapeSizeBytes(const Shape& shape); protected: diff --git 
a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 6d337688a94..43493b6e154 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -58,14 +58,14 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { return port::Status::OK(); } - bool GetKernel(const MultiKernelLoaderSpec &spec, - KernelBase *kernel) override { - return false; + port::Status GetKernel(const MultiKernelLoaderSpec &spec, + KernelBase *kernel) override { + return port::UnimplementedError("Not Implemented"); } - bool Launch(Stream *stream, const ThreadDim &thread_dims, - const BlockDim &block_dims, const KernelBase &kernel, - const KernelArgsArrayBase &args) override { - return false; + port::Status Launch(Stream *stream, const ThreadDim &thread_dims, + const BlockDim &block_dims, const KernelBase &kernel, + const KernelArgsArrayBase &args) override { + return port::UnimplementedError("Not Implemented"); } void *Allocate(uint64 size) override; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 72ffcd26a72..bf1df58f0b8 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -619,8 +619,9 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RET_CHECK(instruction->branch_computation(j)->num_parameters() == 1); ComputationLayout& branch_computation_layout = FindOrDie(computation_layouts_, instruction->branch_computation(k)); - if (branch_computation_layout.result_layout() != - best_branch_computation_layout.result_layout()) { + if (!branch_computation_layout.result_layout().MatchesLayoutInShape( + best_branch_computation_layout.result_layout().shape(), + /*minor_to_major_only=*/true)) { computation_layouts_.erase(instruction->branch_computation(k)); InsertOrDie(&conditional_mismatch_, instruction->branch_computation(k), @@ -715,8 +716,10 @@ Status CheckConditionalLayout( absl::Span branch_computation_layouts) { for (int j = 0; j < instruction->branch_count(); ++j) { const HloInstruction* branch_operand = instruction->operand(j + 1); - TF_RET_CHECK(branch_computation_layouts[0].result_layout() == - branch_computation_layouts[j].result_layout()); + TF_RET_CHECK( + branch_computation_layouts[0].result_layout().MatchesLayoutInShape( + branch_computation_layouts[j].result_layout().shape(), + /*minor_to_major_only=*/true)); TF_RET_CHECK( branch_computation_layouts[j].result_layout().MatchesLayoutInShape( instruction->shape(), /*minor_to_major_only=*/true)); @@ -853,6 +856,30 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer( VLOG(4) << "Operand " << operand->ToString() << " layout does not match " << operand_layout.ToString() << " in " << instruction->ToString(); + // If the operand is only used by a conditional, do the copy inside the branch + // to avoid overhead for other branches. 
+ if (instruction->opcode() == HloOpcode::kConditional && operand_no > 0 && + instruction->operand(operand_no)->user_count() == 1) { + auto branch_comp = instruction->branch_computation(operand_no - 1); + auto param = branch_comp->parameter_instruction(0); + *param->mutable_shape() = operand->shape(); + auto param_users = param->users(); + TF_ASSIGN_OR_RETURN(HloInstruction * param_copy, + CreateCopyWithNewLayout(operand_layout.shape(), param)); + for (auto user : param_users) { + TF_RETURN_IF_ERROR(param->ReplaceUseWithDifferentShape(user, param_copy)); + } + VLOG(4) << "New copy of " << operand->ToString() << " is " + << param_copy->ToString(); + if (param == branch_comp->root_instruction()) { + branch_comp->set_root_instruction(param_copy, + /*accept_different_shape=*/true); + } + *FindOrDie(computation_layouts_, branch_comp).mutable_parameter_layout(0) = + ShapeLayout(operand->shape()); + return Status::OK(); + } + TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, CreateCopyWithNewLayout(operand_layout.shape(), operand)); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 046ffde7616..7d5a3b6623f 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -819,8 +819,8 @@ TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { auto constant0 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); - builder.AddInstruction(HloInstruction::CreateUnary( - constant0->shape(), HloOpcode::kBitcast, constant0)); + builder.AddInstruction( + HloInstruction::CreateBitcast(constant0->shape(), constant0)); auto m = CreateNewVerifiedModule(); m->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index 82e955c818e..aa759b26226 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -21,23 +21,6 @@ limitations under the License. 
#endif namespace xla { -Status LLVMCompiler::RunHloPassesOnModuleGroup( - HloModuleGroup* module_group, - absl::Span executors, - se::DeviceMemoryAllocator* device_allocator) { - return Unimplemented( - "Model partitioning not implemented for the CPU/GPU compilers!"); -} - -StatusOr>> -LLVMCompiler::RunBackendOnModuleGroup( - std::unique_ptr module_group, - std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - return Unimplemented( - "Model partitioning not implemented for the CPU/GPU compilers!"); -} - StatusOr>> LLVMCompiler::Compile( std::unique_ptr module_group, std::vector> stream_execs, diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index 888815bea3d..bddda50d3e1 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -69,16 +69,6 @@ class LLVMCompiler : public Compiler { using Compiler::RunBackend; using Compiler::RunHloPasses; - Status RunHloPassesOnModuleGroup( - HloModuleGroup* module_group, - absl::Span executors, - se::DeviceMemoryAllocator* device_allocator) override; - - StatusOr>> RunBackendOnModuleGroup( - std::unique_ptr module_group, - std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; - StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_execs, diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index ffb2df99e9c..9ffb120bb2d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -151,10 +151,9 @@ Status FusedIrEmitter::HandleGetTupleElement( Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { indexed_generators_[parameter] = [=](const IrArray::Index& index) -> llvm::Value* { - if (tiled_parameter_info_) { - if (llvm::Value* param_tile_buffer = - tiled_parameter_info_->GetBufferForParameter( - parameter->parameter_number())) { + int64 param_num = parameter->parameter_number(); + if (param_shmem_buffers_.size() > param_num) { + if (llvm::Value* param_tile_buffer = param_shmem_buffers_[param_num]) { // TODO(jlebar): Add AA metadata to this load. Tile buffers are global // variables, so LLVM's points-to analysis doesn't help us much. And we // want the AA info to be present before address spaces are inferred @@ -162,13 +161,12 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { // address-space-based AA in LLVM, it wouldn't help us much here. 
return b_->CreateLoad( b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0), - tiled_parameter_info_->x(), - tiled_parameter_info_->y()}), + tile_param_x_, tile_param_y_}), "tiled_buffer"); } } - return GetIrArrayForFusedParameter(parameter->parameter_number()) - .EmitReadArrayElement(index, b_); + return GetIrArrayForFusedParameter(param_num).EmitReadArrayElement(index, + b_); }; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index b1aa6d59634..9b027144cd8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -60,10 +60,16 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { std::function()>; FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator, - ElementalIrEmitter* elemental_emitter) + ElementalIrEmitter* elemental_emitter, + llvm::Value* tile_param_x = nullptr, + llvm::Value* tile_param_y = nullptr, + absl::Span param_shmem_buffers = {}) : operand_arrays_(), operand_arrays_generator_(std::move(operand_arrays_generator)), - tiled_parameter_info_(nullptr), + tile_param_x_(tile_param_x), + tile_param_y_(tile_param_y), + param_shmem_buffers_(param_shmem_buffers.begin(), + param_shmem_buffers.end()), elemental_emitter_(elemental_emitter), b_(elemental_emitter->b()), module_(elemental_emitter->module()) {} @@ -87,10 +93,6 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { // Returns the generator function for the given instruction. IndexedGenerator GetGenerator(const HloInstruction* instruction) const; - void SetTiledParameterInfo(const llvm_ir::TiledParameterInfo* info) { - tiled_parameter_info_ = info; - } - // Evaluates whether fusing 'producer' into 'consumer' might cause exponential // behavior in FusedIrEmitter. We currently can have exponential time/memory // requirements for emitting certain fusion kernels, in which case we don't @@ -118,7 +120,15 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { absl::optional> operand_arrays_; GeneratorForOperandIrArrays operand_arrays_generator_; - const llvm_ir::TiledParameterInfo* tiled_parameter_info_; + // The x coordinate within a tile. + llvm::Value* tile_param_x_; + + // The y coordinate within a tile. + llvm::Value* tile_param_y_; + + // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr + // if the parameter is not tiled. + std::vector param_shmem_buffers_; ElementalIrEmitter* elemental_emitter_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h index 02c719502ee..5014aa9c8ae 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -249,11 +249,26 @@ class IrBuilderMixin { return mixin_builder()->CreateFCmpOEQ(std::forward(args)...); } + template + llvm::Value* FCmpOGT(Args&&... args) { + return mixin_builder()->CreateFCmpOGT(std::forward(args)...); + } + + template + llvm::Value* FCmpOGE(Args&&... args) { + return mixin_builder()->CreateFCmpOGE(std::forward(args)...); + } + template llvm::Value* FCmpOLT(Args&&... args) { return mixin_builder()->CreateFCmpOLT(std::forward(args)...); } + template + llvm::Value* FCmpULT(Args&&... args) { + return mixin_builder()->CreateFCmpULT(std::forward(args)...); + } + template llvm::Value* FCmpOLE(Args&&... 
args) { return mixin_builder()->CreateFCmpOLE(std::forward(args)...); diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index 2ef844ffa62..f586ee4bd4b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -54,6 +54,15 @@ Shape MergeDimensions(absl::Span segs, const Shape& shape) { dimensions); } +std::array ElementWiseCeilOfRatio(std::array dividends, + std::array divisors) { + std::array out; + for (int i = 0; i < 3; i++) { + out[i] = CeilOfRatio(dividends.at(i), divisors.at(i)); + } + return out; +} + } // namespace absl::optional > FindTranspose021(const Shape& a, @@ -94,35 +103,36 @@ absl::optional > FindTranspose021(const Shape& a, return absl::nullopt; } -KernelMappingScheme::KernelMappingScheme( - absl::Span dims_in_elems, int64 tile_size_y, int64 tile_size_x, - absl::Span req_block_sizes, int64 num_threads_y, - int64 num_threads_x, llvm::IRBuilder<>* b) +KernelMappingScheme::KernelMappingScheme(absl::Span dims_in_elems, + int64 tile_size_y, int64 tile_size_x, + int64 block_size_z, + int64 num_threads_y, + int64 num_threads_x, bool is_dilated_x, + llvm::IRBuilder<>* b) : b_(b), - dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()), + dims_in_elems_{dims_in_elems[0], dims_in_elems[1], dims_in_elems[2]}, tile_sizes_{1, tile_size_y, tile_size_x}, + dims_in_tiles_{dims_in_elems[0], + CeilOfRatio(dims_in_elems[1], tile_size_y), + CeilOfRatio(dims_in_elems[2], tile_size_x)}, + block_sizes_{block_size_z, 1, 1}, + dims_in_blocks_{CeilOfRatio(dims_in_elems[0], block_sizes_[0]), + dims_in_tiles_[1], dims_in_tiles_[2]}, num_threads_x_(num_threads_x), num_threads_y_(num_threads_y), - dilated_x_(true) { - DCHECK_EQ(dims_in_elems_.size(), 3); - DCHECK_EQ(req_block_sizes.size(), 3); - + dilated_x_(is_dilated_x) { DCHECK_EQ(tile_size_y % num_threads_y_, 0); DCHECK_EQ(tile_size_x % num_threads_x_, 0); - - dims_in_tiles_ = ElementWiseCeilOfRatio(dims_in_elems_, tile_sizes_); - block_sizes_.reserve(req_block_sizes.size()); - absl::c_transform(req_block_sizes, dims_in_tiles_, - std::back_inserter(block_sizes_), - [](const int64 requested_size, const int64 max_size) { - return std::min(requested_size, max_size); - }); - dims_in_blocks_ = ElementWiseCeilOfRatio(dims_in_tiles_, block_sizes_); - + CHECK_EQ((dims_in_elems[0] % block_size_z), 0); VLOG(10) << "dims_in_elems_ = [" << absl::StrJoin(dims_in_elems_, ",") << "]"; VLOG(10) << "dims_in_tiles_ = [" << absl::StrJoin(dims_in_tiles_, ",") << "]"; VLOG(10) << "dims_in_blocks_ = [" << absl::StrJoin(dims_in_blocks_, ",") << "]"; + if (!dilated_x_) { + // dilated_x_=false is for the purpose of vectorization, which requires + // GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_. + CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0); + } } IrArray::Index KernelMappingScheme::GetUnnormalizedIndex( diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index f802cc27d51..46561dd3252 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -90,23 +90,24 @@ class KernelMappingScheme { enum { DimZ = 0, DimY, DimX, DimTot }; public: - KernelMappingScheme() {} // dims_in_elems: the normalized tensor dimensions. - // req_block_sizes: the requested block size in number of tiles for each - // dimension. 
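As a minimal standalone sketch of the tiling arithmetic in the new KernelMappingScheme constructor above: the tile and block counts are just repeated ceiling divisions. CeilOfRatio is re-implemented locally here, and the shape/tile values in the comment are illustrative only, not taken from the patch.

#include <array>
#include <cstdint>

// Local stand-in for the CeilOfRatio helper used by the constructor above.
int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

// For a hypothetical normalized {Z, Y, X} shape of {4, 100, 130} with 32x32
// tiles and block_size_z = 2:
//   dims_in_tiles  = {4, ceil(100/32), ceil(130/32)} = {4, 4, 5}
//   dims_in_blocks = {ceil(4/2), 4, 5}               = {2, 4, 5}
std::array<int64_t, 3> DimsInTiles(const std::array<int64_t, 3>& dims_in_elems,
                                   int64_t tile_size_y, int64_t tile_size_x) {
  return {dims_in_elems[0], CeilOfRatio(dims_in_elems[1], tile_size_y),
          CeilOfRatio(dims_in_elems[2], tile_size_x)};
}

std::array<int64_t, 3> DimsInBlocks(const std::array<int64_t, 3>& dims_in_tiles,
                                    int64_t dim_z_in_elems,
                                    int64_t block_size_z) {
  return {CeilOfRatio(dim_z_in_elems, block_size_z), dims_in_tiles[1],
          dims_in_tiles[2]};
}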
The actual block size is set to min(req_block_size, - // dims_in_number_of_blocks). KernelMappingScheme(absl::Span dims_in_elems, int64 tile_size_y, - int64 tile_size_x, - absl::Span req_block_sizes, + int64 tile_size_x, int64 block_size_z, int64 num_threads_y, int64 num_threads_x, - llvm::IRBuilder<>* b); + bool is_dilated_x, llvm::IRBuilder<>* b); + // Number of elements in each dimension (Z/Y/X respectively). absl::Span GetDimensionsInElements() const { return dims_in_elems_; } + + // Ratio of elements in each dimension over tile sizes for Z/Y/X + // respectively. absl::Span GetDimensionsInTiles() const { return dims_in_tiles_; } + + // Ratio of dimensions per tile over block sizes. absl::Span GetDimensionsInBlocks() const { return dims_in_blocks_; } @@ -125,10 +126,7 @@ class KernelMappingScheme { return absl::c_accumulate(dims_in_blocks_, 1, std::multiplies()); } - int64 GetTileSizeForDimension(int d) const { - DCHECK(d >= DimZ && d <= DimX); - return tile_sizes_[d]; - } + int64 GetTileSizeForDimension(int d) const { return tile_sizes_.at(d); } int64 GetTileSizeForDimensionX() const { return GetTileSizeForDimension(DimX); } @@ -138,8 +136,7 @@ class KernelMappingScheme { absl::Span GetBlockSizes() const { return block_sizes_; } int64 GetTileBlockSizeForDimension(int d) const { - DCHECK(d >= DimZ && d <= DimX); - return dims_in_blocks_[d]; + return dims_in_blocks_.at(d); } int64 GetNumberOfThreadsForDimensionX() const { return num_threads_x_; } @@ -151,14 +148,6 @@ class KernelMappingScheme { } bool DilatedX() const { return dilated_x_; } - void SetDilatedX(bool v) { - dilated_x_ = v; - if (!dilated_x_) { - // dilated_x_=false is for the purpose of vectorization, which requires - // GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_. - CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0); - } - } IrArray::Index EmitBlockIndex(llvm::Type* index_ty); // Returns the index for the first tile in the block with the given block @@ -181,19 +170,19 @@ class KernelMappingScheme { private: llvm::IRBuilder<>* b_; // The number of elements in each dimension. - std::vector dims_in_elems_; + std::array dims_in_elems_; // The number of elements for each dimension of a tile. - std::vector tile_sizes_; + std::array tile_sizes_; // The number of tiles in each dimension. It is computed from dims_in_elem_ // and tile_sizes_. - std::vector dims_in_tiles_; + std::array dims_in_tiles_; // The number of tiles for each dimension of a tile block. - std::vector block_sizes_; + std::array block_sizes_; // The number of blocks in each dimension of a tile block. It is computed from // dims_in_tile_ and block_sizes_. - std::vector dims_in_blocks_; + std::array dims_in_blocks_; // Number of threads used to process elements in the X direction of a tile. int64 num_threads_x_; @@ -208,34 +197,6 @@ class KernelMappingScheme { bool dilated_x_; }; -// A class to represent information for tiled parameters to support IR emission -// for 021 transpose. -class TiledParameterInfo { - public: - TiledParameterInfo(absl::Span param_buffers, - llvm::Value* y, llvm::Value* x) - : param_buffers_(param_buffers), y_(y), x_(x) {} - - llvm::Value* x() const { return x_; } - llvm::Value* y() const { return y_; } - - void set_x(llvm::Value* x) { x_ = x; } - void set_y(llvm::Value* y) { y_ = y; } - - llvm::Value* GetBufferForParameter(int64 index) const { - return param_buffers_[index]; - } - - private: - // Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr - // if the parameter is not tiled. 
- absl::Span param_buffers_; - // The y coordinate within a tile. - llvm::Value* y_; - // The x coordinate within a tile. - llvm::Value* x_; -}; - } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index aa07bed443a..c9d86f059b4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -513,6 +513,7 @@ llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config) { flags.setNoNaNs(!options.xla_cpu_fast_math_honor_nans()); flags.setNoInfs(!options.xla_cpu_fast_math_honor_infs()); flags.setAllowReciprocal(!options.xla_cpu_fast_math_honor_division()); + flags.setApproxFunc(!options.xla_cpu_fast_math_honor_functions()); return flags; } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc new file mode 100644 index 00000000000..7dd6686bcea --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -0,0 +1,719 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_assignment.h" + +namespace xla { + +namespace { +// Define a dummy chunk for chunks that will be allocated in the default memory +// space and for keeping track of number of asynchronous copies. +const HeapSimulator::Chunk kDummyChunk{-1, -1}; +} // namespace + +std::vector +AlternateMemoryBestFitHeap::GetSortedColocatedIntervals( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { + std::vector colocated_intervals; + std::vector worklist = {&interval}; + while (!worklist.empty()) { + const BufferInterval* item = worklist.back(); + worklist.pop_back(); + colocated_intervals.push_back(item); + for (const HloValue* buffer_colocated : item->colocations) { + worklist.push_back(&buffer_intervals_.at(buffer_colocated)); + } + } + + absl::c_sort(colocated_intervals, [&](const BufferInterval* x, + const BufferInterval* y) { + return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end); + }); + return colocated_intervals; +} + +HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { + std::vector sorted_buffer_intervals = + GetSortedBufferIntervals(); + + VLOG(1) << "Assigning buffers to alternate memory. Max heap size = " + << max_size_in_bytes_ + << ", min prefetch interval = " << min_prefetch_interval_ + << ", max prefetch interval = " << max_prefetch_interval_; + + for (auto& interval : sorted_buffer_intervals) { + if (!interval.need_allocation) { + continue; + } + + // Skip if we have already allocated for this buffer. 
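A minimal standalone sketch of the (start, end) ordering used by GetSortedColocatedIntervals above; the Interval struct here is a hypothetical stand-in for BufferInterval.

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

struct Interval {
  int64_t start;
  int64_t end;
};

// Sorts colocated intervals by start time, breaking ties by end time, which is
// what the absl::c_sort comparator above expresses via std::make_pair.
void SortByScheduleTime(std::vector<Interval>* intervals) {
  std::sort(intervals->begin(), intervals->end(),
            [](const Interval& x, const Interval& y) {
              return std::make_pair(x.start, x.end) <
                     std::make_pair(y.start, y.end);
            });
}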
+ const HloBuffer& buffer = + alias_analysis_.GetBufferContainingValue(*interval.buffer); + if (allocation_map_->contains(&buffer)) { + continue; + } + + // If the buffer is a tuple, don't use this algorithm for now. The buffers + // that are pointed to by the tuple will still use this algorithm. + // TODO(berkin): Because tuples are cheap to place in the alternate memory + // (they are just pointers) we don't need to use prefetch/evict logic. + if (buffer.values()[0]->shape().IsTuple()) { + VLOG(4) << "Keeping buffer " << buffer.ToString() + << " in default mem because it is a tuple."; + continue; + } + + auto colocated_intervals = GetSortedColocatedIntervals(interval); + bool keep_in_default_memory = false; + for (const BufferInterval* colocated_interval : colocated_intervals) { + const HloValue* value = colocated_interval->buffer; + // If any of the colocated values are phi buffers, we keep them in the + // default memory for now. + if (value->is_phi()) { + keep_in_default_memory = true; + VLOG(4) << "Keeping value " << value->ToShortString() + << " because it contains a phi node."; + break; + } + } + + MemorySpaceAssignment::AllocationSequence* allocation_sequence = + &(*allocation_map_)[&buffer]; + + // At this point, none of the colocated buffers contain any phi buffers. + for (const BufferInterval* colocated_interval : colocated_intervals) { + if (keep_in_default_memory) { + break; + } + const HloValue* value = colocated_interval->buffer; + int64 definition_time = + instruction_schedule_->at(value->defining_instruction()); + // Sort the uses by the use time. + std::vector uses = value->uses(); + absl::c_sort(uses, [&](HloUse use1, HloUse use2) { + return instruction_schedule_->at(use1.instruction) < + instruction_schedule_->at(use2.instruction); + }); + // Iterate over the uses. + for (HloUse use : uses) { + int64 use_time = instruction_schedule_->at(use.instruction); + + // Bitcasts don't define buffers and don't directly consume buffers. + // Skip allocating buffers for bitcast uses. The uses that feed from + // bitcasts will be handled specially. + if (use.instruction->opcode() != HloOpcode::kBitcast) { + if (!FindAllocation(definition_time, use_time, + value->defining_position(), use, value, + colocated_interval->size, allocation_sequence)) { + // If the allocation finding failed (e.g., due to running out of + // asynchronous copies), then fall back to allocating the buffer + // entirely in the default memory. + pending_chunks_.clear(); + pending_async_copies_.clear(); + allocation_sequence->clear(); + keep_in_default_memory = true; + break; + } + + // If there are multiple uses, they can try using the memory + // allocation already at the alternate memory. 
+ definition_time = use_time; + } + } + } + + CommitPendingChunks(); + } + + if (VLOG_IS_ON(3)) { + for (const auto& alloc_pair : *allocation_map_) { + VLOG(3) << "Allocation for " << alloc_pair.first->ToString(); + for (const auto& alloc : alloc_pair.second) { + std::string addr_str = ": default"; + if (alloc->memory_space() == MemorySpace::kAlternate) { + addr_str = absl::StrCat(": alt ", alloc->chunk().offset); + } + + VLOG(3) << " " << alloc->start_time() << "-" << alloc->end_time() + << addr_str << ", " << alloc->uses().size() << " uses"; + } + } + } + + return result_; +} + +HloInstruction* AlternateMemoryBestFitHeap::GetInstructionAt(int64 time) const { + return flattened_instruction_sequence_->instructions()[time]; +} + +void AlternateMemoryBestFitHeap::CommitPendingChunks() { + for (auto interval_and_chunk : pending_chunks_) { + VLOG(3) << "Committing chunk: " << interval_and_chunk.first.start << "-" + << interval_and_chunk.first.end << " : [" + << interval_and_chunk.second.chunk.offset << ", " + << interval_and_chunk.second.chunk.size << "]"; + CommitChunk(interval_and_chunk.first, interval_and_chunk.second); + } + pending_chunks_.clear(); + // Also add the pending async copies to the interval tree. + if (max_outstanding_async_copies_ >= 0) { + for (auto interval : pending_async_copies_) { + async_copy_interval_tree_.Add(interval.first, interval.second, + kDummyChunk); + } + } + pending_async_copies_.clear(); +} + +void AlternateMemoryBestFitHeap::AddToPendingChunks( + const BufferInterval& buffer_interval, + const ChunkCandidate& chunk_candidate) { + pending_chunks_.emplace_back(buffer_interval, chunk_candidate); +} + +bool AlternateMemoryBestFitHeap::FindAllocation( + int64 start_time, int64 end_time, HloPosition defining_position, HloUse use, + const HloValue* buffer, int64 size, + MemorySpaceAssignment::AllocationSequence* allocations) { + HloInstruction* operand = + use.instruction->mutable_operand(use.operand_number); + // If the operand is a bitcast, we look at bitcast's operand until we find a + // non-bitcast operand. + HloInstruction* non_bitcast_operand = operand; + while (non_bitcast_operand->opcode() == HloOpcode::kBitcast) { + non_bitcast_operand = non_bitcast_operand->mutable_operand(0); + } + // Create an alternate memory interval that starts at the earliest + // possible position, given by max_prefetch_interval. + BufferInterval alternate_mem_interval; + alternate_mem_interval.buffer = buffer; + alternate_mem_interval.size = size; + alternate_mem_interval.start = + std::max(start_time, end_time - max_prefetch_interval_); + alternate_mem_interval.end = end_time; + + VLOG(2) << "Finding allocation for " << buffer->ToShortString() << " (" + << start_time << ", " << end_time << "). Size = " << size + << ", def pos = " << defining_position.ToString() + << ", operand = " << operand->ToString() + << (non_bitcast_operand != operand + ? ", non_bitcast_operand = " + non_bitcast_operand->ToString() + : ""); + CHECK_LT(start_time, end_time); + + // First try keeping the allocation entirely in the alternate memory. + if (TryAllocatingInAlternateMemoryNoCopy( + start_time, end_time, defining_position, use, alternate_mem_interval, + non_bitcast_operand, allocations)) { + return true; + } + + MemorySpaceAssignment::Allocation* prev_allocation = nullptr; + if (!allocations->empty()) { + prev_allocation = allocations->back().get(); + } + + // Since copies couldn't be removed, create an allocation in the default + // memory space. 
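A minimal sketch of the bitcast peeling performed at the top of FindAllocation above, using a hypothetical node type in place of HloInstruction.

// Hypothetical stand-in for an HLO node; only what the peeling loop needs.
struct Node {
  bool is_bitcast;
  Node* operand;  // The single operand followed when peeling bitcasts.
};

// Walks through chained bitcasts until a non-bitcast producer is found, as
// FindAllocation does before deciding where the buffer should live.
Node* PeelBitcasts(Node* node) {
  while (node->is_bitcast) {
    node = node->operand;
  }
  return node;
}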
+ if (prev_allocation != nullptr && + prev_allocation->memory_space() == MemorySpace::kAlternate && + prev_allocation->instruction() == non_bitcast_operand) { + // If there was an allocation for this HloValue that was in the alternate + // memory space, we also need to perform an eviction. + // TODO(berkin): For now evictions happen relative to the most recent + // allocation in the alternate memory. We can potentially start evictions + // earlier and end later. + VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " (" + << prev_allocation->start_time() << ", " + << prev_allocation->end_time() << ")"; + + // See if this interval would violate the asynchronous copy limit. + if (!ViolatesMaximumOutstandingAsyncCopies(prev_allocation->start_time(), + prev_allocation->end_time())) { + AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk, + prev_allocation->start_time(), prev_allocation->end_time(), + allocations); + + } else { + VLOG(3) << "This violates the maximum async copies."; + // If the original interval violated the limit, try sub-intervals within + // this interval. + bool eviction_scheduled = false; + for (int64 time = prev_allocation->start_time(); + time <= prev_allocation->end_time(); ++time) { + VLOG(3) << "Try evicting (" << time << ", " << time << ")"; + if (!ViolatesMaximumOutstandingAsyncCopies(time, time)) { + VLOG(3) << "Eviction successful."; + AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, kDummyChunk, + time, time, allocations); + eviction_scheduled = true; + break; + } + } + + if (!eviction_scheduled) { + // If the eviction couldn't be scheduled, then fail. This buffer will be + // kept in the default memory. + VLOG(3) << "Bailing: Could not evict " << use.ToString() + << " because we hit the limit of maximum asynchronous copies " + << "between " + << GetInstructionAt(prev_allocation->start_time())->ToString() + << " and " + << GetInstructionAt(prev_allocation->end_time())->ToString(); + return false; + } + } + } else if (prev_allocation != nullptr && + prev_allocation->memory_space() == MemorySpace::kDefault && + prev_allocation->instruction() == non_bitcast_operand) { + // If the previous allocation was in the default memory space and was + // defined by the same instruction, extend that. Otherwise, create a new + // allocation. + prev_allocation->Extend(end_time); + } else { + allocations->push_back(absl::make_unique( + non_bitcast_operand, defining_position, MemorySpace::kDefault, + kDummyChunk, start_time, end_time)); + } + + // Try partially placing the buffer in the alternate space. The time that is + // overlapped will be used to asynchronously copy the buffer from the + // default memory to the alternate memory. + // + // start end + // time time + // X---------------------X + // Alternate: +------+ + // Default: +---------------------+ + // ^ ^ + // Copy Copy + // Start Done + for (alternate_mem_interval.start = + std::max(start_time, end_time - max_prefetch_interval_); + alternate_mem_interval.end - alternate_mem_interval.start > + min_prefetch_interval_; + ++alternate_mem_interval.start) { + VLOG(4) << "Trying alternate memory allocation (" + << alternate_mem_interval.start << ", " + << alternate_mem_interval.end << ")"; + // If this additional asynchronous copy would violate the limit, try a + // different interval. 
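A standalone sketch of the prefetch-start search in the loop above: start as early as max_prefetch_interval allows and slide the CopyStart later until a feasible slot is found or the window shrinks below min_prefetch_interval. The `fits` callback is a hypothetical stand-in bundling the heap-size and async-copy checks.

#include <algorithm>
#include <cstdint>
#include <functional>

int64_t FindPrefetchStart(int64_t start_time, int64_t end_time,
                          int64_t min_prefetch_interval,
                          int64_t max_prefetch_interval,
                          const std::function<bool(int64_t, int64_t)>& fits) {
  for (int64_t start = std::max(start_time, end_time - max_prefetch_interval);
       end_time - start > min_prefetch_interval; ++start) {
    if (fits(start, end_time)) {
      return start;  // The prefetch copy will live over [start, end_time].
    }
  }
  return -1;  // No copy inserted; the use keeps reading from default memory.
}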
+ if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start, + alternate_mem_interval.end)) { + VLOG(4) << "This would violate the outstanding async copy limit."; + continue; + } + ChunkCandidate chunk_candidate = FindChunkCandidate(alternate_mem_interval); + // Check if the new heap size fits within limits. + if (chunk_candidate.heap_size < max_size_in_bytes_) { + VLOG(3) << "Move the buffer to alternate memory at " + << alternate_mem_interval.start + << ". Offset = " << chunk_candidate.chunk.offset + << ", size = " << chunk_candidate.chunk.size + << ", heap_size = " << chunk_candidate.heap_size; + AddToPendingChunks(alternate_mem_interval, chunk_candidate); + + AddAsyncCopy(*allocations->back().get(), MemorySpace::kAlternate, + chunk_candidate.chunk, alternate_mem_interval.start, + end_time, allocations); + + allocations->back()->AddUse(use); + return true; + } + } + + // If a copy wasn't inserted, then add this use to the latest allocation. + allocations->back()->AddUse(use); + return true; +} + +void AlternateMemoryBestFitHeap::AddAsyncCopy( + const MemorySpaceAssignment::Allocation& prev_allocation, + MemorySpace memory_space, Chunk chunk, int64 start_time, int64 end_time, + MemorySpaceAssignment::AllocationSequence* allocations) { + HloInstruction* earliest_instruction = GetInstructionAt(start_time); + HloInstruction* latest_instruction = GetInstructionAt(end_time); + + VLOG(3) << "Copy to " + << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault + ? "default" + : "alternate") + << " memory between instructions " << earliest_instruction->ToString() + << " - " << latest_instruction->ToString(); + + allocations->push_back( + absl::make_unique( + prev_allocation, memory_space, chunk, start_time, end_time, + earliest_instruction, latest_instruction)); + + // Register the additional async copy with the interval tree to keep track of + // the limit at any given time. + pending_async_copies_.emplace_back(start_time, end_time); +} + +bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies( + int64 start_time, int64 end_time) const { + if (max_outstanding_async_copies_ < 0) { + return false; + } + + // Count both the asynchronous copies in the interval tree as well as the + // pending asynchronous copies belonging to this buffer. + int64 num_async_copies = + async_copy_interval_tree_.ChunksOverlappingInTime(start_time, end_time) + .size(); + + for (auto interval : pending_async_copies_) { + if (interval.second > start_time && interval.first < end_time) { + num_async_copies++; + } + } + // Add one because we are checking if adding an additional asynchronous copy + // would violate the limit. + return num_async_copies + 1 > max_outstanding_async_copies_; +} + +bool AlternateMemoryBestFitHeap::TryAllocatingInAlternateMemoryNoCopy( + int64 start_time, int64 end_time, HloPosition defining_position, HloUse use, + BufferInterval alternate_mem_interval, HloInstruction* non_bitcast_operand, + MemorySpaceAssignment::AllocationSequence* allocations) { + MemorySpaceAssignment::Allocation* prev_allocation = nullptr; + bool can_eliminate_copy = false; + if (allocations->empty()) { + // There hasn't been any allocations for this interval so far. We can + // eliminate copy if the value can be placed in the alternate memory. + can_eliminate_copy = + is_allowed_in_alternate_mem_(*alternate_mem_interval.buffer); + } else { + // If there has been a previous allocation, we can eliminate the copy if the + // previous allocation was also in the alternate memory. 
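A minimal sketch of the overlap test in ViolatesMaximumOutstandingAsyncCopies above, with pending copies kept as (start, end) pairs as in pending_async_copies_.

#include <cstdint>
#include <utility>
#include <vector>

// A pending copy over [first, second] overlaps the query window iff it starts
// before the window ends and ends after the window starts.
int64_t CountOverlappingCopies(
    const std::vector<std::pair<int64_t, int64_t>>& pending_copies,
    int64_t start_time, int64_t end_time) {
  int64_t num = 0;
  for (const auto& interval : pending_copies) {
    if (interval.second > start_time && interval.first < end_time) {
      ++num;
    }
  }
  return num;
}

// Adding one more copy violates the limit when num + 1 > max_outstanding
// (a negative max_outstanding means unlimited).
bool WouldViolateLimit(int64_t num, int64_t max_outstanding) {
  return max_outstanding >= 0 && num + 1 > max_outstanding;
}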
+ prev_allocation = allocations->back().get(); + can_eliminate_copy = + (prev_allocation->memory_space() == MemorySpace::kAlternate); + } + + if (!can_eliminate_copy) { + return false; + } + + if (alternate_mem_interval.start != start_time) { + return false; + } + + // Prefer the offset that was previously used for the previous allocation. + int64 preferred_offset = -1; + if (prev_allocation != nullptr) { + preferred_offset = prev_allocation->chunk().offset; + // If there is a previous allocation, set the start time one after the end + // of the previous allocation's end. + alternate_mem_interval.start = prev_allocation->end_time() + 1; + } + + VLOG(4) << "We can eliminate copy to alternate memory. Preferred offset = " + << preferred_offset; + ChunkCandidate chunk_candidate = + FindChunkCandidate(alternate_mem_interval, preferred_offset); + // Check if the new heap size fits within limits. Also ensure if a + // preferred offset was provided, that offset was used. + if (chunk_candidate.heap_size < max_size_in_bytes_ && + (preferred_offset == -1 || + preferred_offset == chunk_candidate.chunk.offset)) { + VLOG(3) << "Keep the buffer in alternate memory. Offset = " + << chunk_candidate.chunk.offset + << ", size = " << chunk_candidate.chunk.size + << ", heap_size = " << chunk_candidate.heap_size; + AddToPendingChunks(alternate_mem_interval, chunk_candidate); + + // If there was a previous allocation, the buffer location is the + // same as the previous. Otherwise, it is the operand. + if (prev_allocation != nullptr && + prev_allocation->instruction() == non_bitcast_operand) { + prev_allocation->Extend(end_time); + } else { + allocations->push_back( + absl::make_unique( + non_bitcast_operand, defining_position, MemorySpace::kAlternate, + chunk_candidate.chunk, start_time, end_time)); + } + allocations->back()->AddUse(use); + return true; + } + return false; +} + +/*static*/ int64 MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies( + const HloModule& module) { + int64 max_copies = 0; + int64 current_copies = 0; + for (HloInstruction* instruction : + module.schedule().sequence(module.entry_computation()).instructions()) { + if (instruction->opcode() == HloOpcode::kCopyStart) { + current_copies++; + } else if (instruction->opcode() == HloOpcode::kCopyDone) { + current_copies--; + } + max_copies = std::max(max_copies, current_copies); + } + return max_copies; +} + +/*static*/ StatusOr> +MemorySpaceAssignment::Run( + HloModule* module, int64 alternate_memory_space, int64 max_size_in_bytes, + int64 min_prefetch_interval, int64 max_prefetch_interval, + int64 alternate_memory_space_alignment_in_bytes, + BufferValue::SizeFunction size_fn, + AlternateMemoryBestFitHeap::IsAllowedInAlternateMemoryFunction + is_allowed_in_alternate_mem, + int64 max_outstanding_async_copies) { + CHECK(module->has_schedule()); + VLOG(4) << "Module before memory space assignment: "; + XLA_VLOG_LINES(4, module->ToString()); + VLOG(4) << "Schedule: " << module->schedule().ToString(); + TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module)); + + MemorySpaceAssignment memory_space_assignment(module, alternate_memory_space); + // TODO(berkin): Explore heap algorithms other than kSpatial. 
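A standalone sketch of the running-counter scan in CountMaximumOutstandingAsyncCopies above, over a hypothetical flattened opcode sequence.

#include <algorithm>
#include <cstdint>
#include <vector>

enum class Op { kCopyStart, kCopyDone, kOther };

// Walk the schedule, +1 at each CopyStart, -1 at each CopyDone, and track the
// peak number of copies in flight.
int64_t MaxOutstandingCopies(const std::vector<Op>& schedule) {
  int64_t max_copies = 0;
  int64_t current = 0;
  for (Op op : schedule) {
    if (op == Op::kCopyStart) {
      ++current;
    } else if (op == Op::kCopyDone) {
      --current;
    }
    max_copies = std::max(max_copies, current);
  }
  return max_copies;
}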
+ auto algorithm = absl::make_unique( + &memory_space_assignment.allocation_map_, max_size_in_bytes, + min_prefetch_interval, max_prefetch_interval, *alias_analysis, + alternate_memory_space_alignment_in_bytes, + GlobalDecreasingSizeBestFitHeap::Type::kSpatial, + is_allowed_in_alternate_mem, max_outstanding_async_copies); + + TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module, + module->schedule(), + *alias_analysis.get(), size_fn) + .status()); + + TF_RETURN_IF_ERROR(memory_space_assignment.Process()); + TF_RETURN_IF_ERROR(memory_space_assignment.FixSchedule()); + + VLOG(4) << "Module after memory space assignment: "; + XLA_VLOG_LINES(4, module->ToString()); + TF_CHECK_OK(module->schedule().Verify()); + VLOG(1) << "Maximum number of outstanding async copies: " + << CountMaximumOutstandingAsyncCopies(*module); + + return std::move(memory_space_assignment.preset_assignments_); +} + +void MemorySpaceAssignment::Allocation::AddUse(HloUse use) { + HloInstruction* operand = + use.instruction->mutable_operand(use.operand_number); + // When the operand of a use is a bitcast, we place the bitcast in a separate + // data structure. + if (operand->opcode() == HloOpcode::kBitcast) { + bitcasts_.push_back(operand); + } else { + uses_.push_back(use); + } +} + +Status MemorySpaceAssignment::Allocation::PropagateMemorySpaceToBitcasts( + const MemorySpaceAssignment& memory_space_assignment) { + for (HloInstruction* bitcast : bitcasts_) { + if (memory_space_ == MemorySpace::kAlternate) { + Layout* bitcast_layout = bitcast->mutable_shape()->mutable_layout(); + bitcast_layout->set_memory_space( + memory_space_assignment.alternate_memory_space_); + } + } + return Status::OK(); +} + +Status MemorySpaceAssignment::Allocation::Process( + MemorySpaceAssignment* memory_space_assignment) { + // For non-copy allocations, all we need to do is to update the output memory + // space if placed in the alternate memory. + if (memory_space_ == MemorySpace::kAlternate) { + Layout* layout = instruction_->mutable_shape()->mutable_layout(); + layout->set_memory_space(memory_space_assignment->alternate_memory_space_); + } + TF_RETURN_IF_ERROR(PropagateMemorySpaceToBitcasts(*memory_space_assignment)); + return Status::OK(); +} + +Status MemorySpaceAssignment::CopyAllocation::Process( + MemorySpaceAssignment* memory_space_assignment) { + // Copy allocations need to insert asynchronous copy nodes. + HloInstruction* producing_instruction = instruction(); + CHECK_NE(producing_instruction, nullptr); + + Shape shape = producing_instruction->shape(); + HloComputation* computation = producing_instruction->parent(); + + // Set the layout to include the memory space. + Layout* layout = shape.mutable_layout(); + if (memory_space_ == MemorySpace::kAlternate) { + layout->set_memory_space(memory_space_assignment->alternate_memory_space_); + } else { + layout->set_memory_space(0); + } + + HloInstruction* copy_start = + computation->AddInstruction(HloInstruction::CreateUnary( + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), + HloOpcode::kCopyStart, producing_instruction)); + HloInstruction* copy_done = computation->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start)); + // Update the allocation with the copy done instruction so that if there + // are further copies from it, it can find the correct instruction. + instruction_ = copy_done; + // Also update the defining position. Note that the output of CopyDone is + // actually defined in the item {0} of CopyStart. 
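A minimal sketch of the use rewiring performed just below, with hypothetical node/use types: every recorded use of the original producer is redirected to the newly created CopyDone.

#include <vector>

// Hypothetical mini-IR: a use is (consumer, operand index).
struct MiniInst {
  std::vector<MiniInst*> operands;
};
struct MiniUse {
  MiniInst* instruction;
  int operand_number;
};

// Redirects each recorded use to the new producer (the CopyDone above).
void RedirectUses(const std::vector<MiniUse>& uses, MiniInst* new_producer) {
  for (const MiniUse& use : uses) {
    use.instruction->operands[use.operand_number] = new_producer;
  }
}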
+ defining_position_ = HloPosition{copy_start, {0}}; + + // Replace all the uses with the new copy instruction. + for (HloUse use : uses_) { + TF_RETURN_IF_ERROR( + use.instruction->ReplaceOperandWith(use.operand_number, copy_done)); + } + + // Replace all the bitcasts with the new copy instruction. Note that if there + // is a chain of bitcasts, their operands will be replaced with copy done. + // For example: + // + // a = Foo() + // b = Bitcast(a) + // c = Bitcast(b) + // + // If a is moved to the alternate memory asynchronously, the graph will be + // changed into: + // + // a = Foo() + // cs = CopyStart(a) + // cd = CopyDone(cs) + // b = Bitcast(cd) + // c = Bitcast(cd) + // + // Because of the potential shape change in the operand (b -> cd), we use + // ReplaceOperandWithDifferentShape. + for (HloInstruction* bitcast : bitcasts_) { + TF_RETURN_IF_ERROR(bitcast->ReplaceOperandWithDifferentShape( + /*operand_num=*/0, instruction_)); + } + + // Propagate the memory space to all bitcasts. + TF_RETURN_IF_ERROR(PropagateMemorySpaceToBitcasts(*memory_space_assignment)); + + // Insert the new instructions at the appropriate places in the schedule. + // FixSchedule will process the maps to actually insert them. + memory_space_assignment->ScheduleAsynchronousCopy( + copy_start, copy_start_schedule_after_, copy_done, + copy_done_schedule_before_); + return Status::OK(); +} + +Status MemorySpaceAssignment::Process() { + // Insert CopyStart/CopyDone pairs. + int64 alternate_memory_size = 0; + for (auto& buffer_and_sequence : allocation_map_) { + for (auto& allocation : buffer_and_sequence.second) { + TF_RETURN_IF_ERROR(allocation->Process(this)); + // Add the offset and size of the allocation in the alternate memory to + // the output map. Special case for bitcast: since bitcast doesn't define + // its own buffer, that shouldn't be exported as a preset chunk. 
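A minimal sketch of the size bookkeeping just below, assuming (as with the heap simulator's chunks) that a chunk's end is its offset plus its size; the Chunk struct here is a local stand-in.

#include <algorithm>
#include <cstdint>
#include <vector>

struct Chunk {
  int64_t offset;
  int64_t size;
  int64_t chunk_end() const { return offset + size; }
};

// The reserved size of the alternate memory is the maximum chunk end over all
// chunks placed there.
int64_t AlternateMemorySize(const std::vector<Chunk>& chunks) {
  int64_t max_end = 0;
  for (const Chunk& c : chunks) {
    max_end = std::max(max_end, c.chunk_end());
  }
  return max_end;
}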
+ if (allocation->memory_space() == MemorySpace::kAlternate && + allocation->instruction()->opcode() != HloOpcode::kBitcast) { + preset_assignments_->add_chunk(allocation->defining_position(), + allocation->chunk()); + alternate_memory_size = + std::max(alternate_memory_size, allocation->chunk().chunk_end()); + } + } + } + + if (!preset_assignments_->chunks().empty()) { + preset_assignments_->add_size(alternate_memory_space_, + alternate_memory_size); + } + + if (VLOG_IS_ON(3)) { + VLOG(3) << "Exported alternate memory allocations:"; + for (auto& pair : preset_assignments_->chunks()) { + VLOG(3) << " [" << pair.second.offset << ", " << pair.second.size + << "] : " << pair.first.ToString(); + } + VLOG(3) << "Exported alternate memory sizes:"; + for (auto& pair : preset_assignments_->sizes()) { + VLOG(3) << " space: " << pair.first << ", size: " << pair.second; + } + } + return Status::OK(); +} + +void MemorySpaceAssignment::ScheduleAsynchronousCopy( + HloInstruction* copy_start, HloInstruction* copy_start_schedule_after, + HloInstruction* copy_done, HloInstruction* copy_done_schedule_before) { + schedule_after_[copy_start_schedule_after].push_back(copy_start); + schedule_before_[copy_done_schedule_before].push_back(copy_done); +} + +void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted( + HloInstruction* new_instruction, HloInstructionSequence* new_sequence, + absl::flat_hash_set* inserted_instructions) const { + if (inserted_instructions->contains(new_instruction)) { + return; + } + for (HloInstruction* operand : new_instruction->operands()) { + EnsureInstructionAndOperandsInserted(operand, new_sequence, + inserted_instructions); + } + VLOG(4) << "inserting: " << new_instruction->ToString(); + new_sequence->push_back(new_instruction); + inserted_instructions->insert(new_instruction); +} + +Status MemorySpaceAssignment::FixSchedule() { + CHECK(module_->has_schedule()); + HloSchedule& schedule = module_->schedule(); + for (const HloComputation* computation : + module_->MakeNonfusionComputations()) { + CHECK(schedule.is_computation_scheduled(computation)); + const HloInstructionSequence& sequence = schedule.sequence(computation); + HloInstructionSequence new_sequence; + + absl::flat_hash_set inserted_instructions; + + for (HloInstruction* instruction : sequence.instructions()) { + auto insts_before_iter = schedule_before_.find(instruction); + if (insts_before_iter != schedule_before_.end()) { + for (HloInstruction* new_instruction : insts_before_iter->second) { + EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence, + &inserted_instructions); + } + } + // Insert only if not previously inserted. + if (!inserted_instructions.contains(instruction)) { + EnsureInstructionAndOperandsInserted(instruction, &new_sequence, + &inserted_instructions); + } + auto insts_after_iter = schedule_after_.find(instruction); + if (insts_after_iter != schedule_after_.end()) { + for (HloInstruction* new_instruction : insts_after_iter->second) { + EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence, + &inserted_instructions); + } + } + } + schedule.set_sequence(computation, new_sequence); + } + + return Status::OK(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h new file mode 100644 index 00000000000..71ed39ded04 --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -0,0 +1,367 @@ +/* Copyright 2019 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_ + +#include "tensorflow/compiler/xla/service/heap_simulator.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This class contains pre-set assignments determined by memory space +// assignment. It contains two data structures: (1) a chunks vector that maps a +// defining HloPosition to a Chunk (offset and size), and (2) a sizes vector +// that maps the memory space to its size. If there is only one alternate memory +// space like there is currently, there will be one entry in sizes. +class PresetAssignments { + public: + PresetAssignments() = default; + + void add_chunk(const HloPosition& position, + const HeapSimulator::Chunk& chunk) { + chunks_.emplace_back(position, chunk); + } + + void add_size(int64 memory_space, int64 size) { + sizes_.emplace_back(memory_space, size); + } + + absl::Span> + chunks() const { + return chunks_; + } + + absl::Span> sizes() const { return sizes_; } + + private: + std::vector> chunks_; + std::vector> sizes_; +}; + +// MemorySpaceAssignment assigns memory spaces (default or alternate) to each +// instruction in the module. It will greedily try placing as as many values in +// the alternate memory space as possible. It uses the heap simulator to +// determine the actual allocation offsets of values in the alternate memory +// space to account for fragmentation. The default memory space is assumed to be +// large enough to hold the values that could not be placed in the alternate +// memory space. +class MemorySpaceAssignment { + public: + using Chunk = HeapSimulator::Chunk; + + // MemorySpaceAssignment uses a notion of a slow and large default memory + // space and a fast and small alternate memory space. + enum class MemorySpace { kDefault, kAlternate }; + + // This class represents an allocation that might either be in the default or + // alternate memory. An HloValue might live in multiple different allocations + // over its lifetime. The lifetimes of the allocations are defined using + // start_time and end_time, which corresponds to the instruction indexes in + // the flattened schedule. Each of these allocations might partially overlap + // with each other. CopyAllocation defined below represents asynchronous + // copies between Allocations. + // + // Consider an instruction Foo, and its users Bar and Baz, and the times given + // in terms of the flattened schedule of the entire module: + // + // Foo:10 + // / \ + // Bar:14 \ + // Baz:25 + // + // A valid memory space assignment could be like the following: + // + // Time: 10 ... 14 ... 
25 + // Foo Bar Baz + // Alternate +-------+ +-----+ + // Default +---------------------+ + // ^ ^ ^ ^ + // | | | | + // evict evict prefetch prefetch + // start end start end + // + // This would be represented with: + // - Allocation(memory_space=kAlternate, start_time=10, end_time=14) + // - CopyAllocation(memory_space=kDefault, start_time=12, end_time=25) + // - CopyAllocation(memory_space=kAlternate, start_time=22, end_time=25) + class Allocation { + public: + Allocation(HloInstruction* instruction, HloPosition defining_position, + MemorySpace memory_space, Chunk chunk, int64 start_time, + int64 end_time) + : instruction_(instruction), + defining_position_(defining_position), + memory_space_(memory_space), + chunk_(chunk), + start_time_(start_time), + end_time_(end_time) {} + virtual ~Allocation() = default; + + // Adds a use to this allocation. + void AddUse(HloUse use); + + // Extends the end time of this allocation. + void Extend(int64 end_time) { end_time_ = end_time; } + + // After all of the time ranges for the allocations have been assigned, + // Process morphs the instructions affected to assign the memory spaces and + // insert asynchronous copy instructions if necessary. + virtual Status Process(MemorySpaceAssignment* memory_space_assignment); + + // Returns the instruction that produces this allocation. It might be + // different than the instruction in defining_position (e.g., a + // GetTupleElement instruction does not define the buffer). + virtual HloInstruction* instruction() const { return instruction_; } + + // Returns the defining position for this allocation. + HloPosition defining_position() const { return defining_position_; } + + const std::vector& uses() const { return uses_; } + MemorySpace memory_space() const { return memory_space_; } + Chunk chunk() const { return chunk_; } + int64 start_time() const { return start_time_; } + int64 end_time() const { return end_time_; } + + protected: + // Bitcasts are treated specially because they do not define buffers. This + // method propagates the memory space for the bitcasts of this allocation. + Status PropagateMemorySpaceToBitcasts( + const MemorySpaceAssignment& memory_space_assignment); + + HloInstruction* instruction_; + HloPosition defining_position_; + std::vector uses_; + std::vector bitcasts_; + MemorySpace memory_space_; + Chunk chunk_; + int64 start_time_; + int64 end_time_; + }; + + // This class represents an allocation as a result of an asynchronous copy. + class CopyAllocation : public Allocation { + public: + CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space, + Chunk chunk, int64 start_time, int64 end_time, + HloInstruction* copy_start_schedule_after, + HloInstruction* copy_done_schedule_before) + : Allocation(/*instruction=*/nullptr, + /*defining_position=*/{nullptr, {}}, memory_space, chunk, + start_time, end_time), + prev_allocation_(prev_allocation), + copy_start_schedule_after_(copy_start_schedule_after), + copy_done_schedule_before_(copy_done_schedule_before) {} + + Status Process(MemorySpaceAssignment* memory_space_assignment) override; + + HloInstruction* instruction() const override { + // Unless explicitly set, the instruction of a copy allocation in + // retrieved from the previous allocation. + if (instruction_ != nullptr) { + return instruction_; + } else { + return prev_allocation_.instruction(); + } + } + + private: + const Allocation& prev_allocation_; + // These variables define the scheduling boundaries where CopyStart and + // CopyDone can be scheduled. 
The earliest CopyStart can be scheduled is + // after copy_start_schedule_after_ and the latest CopyDone can be scheduled + // is before copy_done_schedule_before_. + HloInstruction* copy_start_schedule_after_; + HloInstruction* copy_done_schedule_before_; + }; + + using AllocationSequence = std::list>; + using AllocationMap = + absl::flat_hash_map; + + // Runs the MemorySpaceAssignment pass. alternate_memory_space is the + // architecture-specific integer value that describes the alternate memory. + // max_size_in_bytes is the maximum size of the alternate memory. + // min/max_prefetch_interval define min/max number of independent instructions + // that can be overlapped while prefetching to decide how early can prefetch + // begin. alternate_memory_space_alignment_in_bytes is the alignment required + // in the alternate memory space, size_fn is the size function for buffer + // values, and is_allowed_in_alternate_mem can be used to prevent certain + // HloValues (e.g., based on the opcode) to be placed on the alternate memory. + // max_outstanding_async_copies specifies the upper bound for number of + // outstanding asynchronous copies, -1 for unlimited. + // TODO(berkin): Use the cost model instead of using number of instructions to + // decide how early to prefetch. + static StatusOr> Run( + HloModule* module, int64 alternate_memory_space, int64 max_size_in_bytes, + int64 min_prefetch_interval, int64 max_prefetch_interval, + int64 alternate_memory_space_alignment_in_bytes, + BufferValue::SizeFunction size_fn, + std::function is_allowed_in_alternate_mem, + int64 max_outstanding_async_copies = -1); + + // Returns the maximum number of outstanding asynchronous copies in the + // module. + static int64 CountMaximumOutstandingAsyncCopies(const HloModule& module); + + private: + MemorySpaceAssignment(HloModule* module, int64 alternate_memory_space) + : module_(module), + alternate_memory_space_(alternate_memory_space), + preset_assignments_(absl::make_unique()) {} + + // Process calls Process methods of the allocations after the allocations have + // been finalized. + Status Process(); + + // FixSchedule inserts asynchronous copies in the schedule. + Status FixSchedule(); + + // Insert an instruction to the schedule, and make sure its dependencies + // (operands) are already in the schedule. If not, insert these operands + // before the instruction. + void EnsureInstructionAndOperandsInserted( + HloInstruction* new_instruction, HloInstructionSequence* new_sequence, + absl::flat_hash_set* inserted_instructions) const; + + // Schedules a pair of asynchronous copy instructions (copy_start and + // copy_done) where copy_start will be scheduled after the instruction in + // copy_start_schedule_after and copy_done will be scheduled before the + // instruction in copy_done_schedule_before. + void ScheduleAsynchronousCopy(HloInstruction* copy_start, + HloInstruction* copy_start_schedule_after, + HloInstruction* copy_done, + HloInstruction* copy_done_schedule_before); + + HloModule* module_; + int64 alternate_memory_space_; + AllocationMap allocation_map_; + std::unique_ptr preset_assignments_; + + // These maps hold vectors of new instructions that need to be scheduled after + // (or before) the instruction in the key. FixSchedule uses these maps to + // modify and fix the schedule. + absl::flat_hash_map> + schedule_after_; + absl::flat_hash_map> + schedule_before_; +}; + +// This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of +// maximum size. 
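A standalone sketch of the operand-first insertion described above for EnsureInstructionAndOperandsInserted, with a hypothetical instruction type: before appending an instruction to the new sequence, recursively append any operand not already present, so definitions always precede uses.

#include <unordered_set>
#include <vector>

struct Inst {
  std::vector<Inst*> operands;
};

void InsertWithOperands(Inst* inst, std::vector<Inst*>* sequence,
                        std::unordered_set<Inst*>* inserted) {
  if (inserted->count(inst)) {
    return;
  }
  for (Inst* operand : inst->operands) {
    InsertWithOperands(operand, sequence, inserted);
  }
  sequence->push_back(inst);
  inserted->insert(inst);
}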
+class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { + public: + using IsAllowedInAlternateMemoryFunction = + std::function; + using MemorySpace = MemorySpaceAssignment::MemorySpace; + + AlternateMemoryBestFitHeap( + MemorySpaceAssignment::AllocationMap* allocation_map, + int64 max_size_in_bytes, int64 min_prefetch_interval, + int64 max_prefetch_interval, const HloAliasAnalysis& alias_analysis, + int64 alignment, GlobalDecreasingSizeBestFitHeap::Type type, + IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem, + int64 max_outstanding_async_copies) + : GlobalDecreasingSizeBestFitHeap(alignment, type), + allocation_map_(allocation_map), + max_size_in_bytes_(max_size_in_bytes), + min_prefetch_interval_(min_prefetch_interval), + max_prefetch_interval_(max_prefetch_interval), + alias_analysis_(alias_analysis), + is_allowed_in_alternate_mem_(is_allowed_in_alternate_mem), + max_outstanding_async_copies_(max_outstanding_async_copies) {} + + HeapSimulator::Result Finish() override; + + private: + // Finds an allocation for the given interval. Internally, it will attempt to + // find a suitable chunk candidate within the heap size and prefetch interval + // limits, and append the new allocation(s) to allocations. The new + // allocations can be in default or alternate memory spaces, or can be + // prefetches or evictions. Returns true if successful. + bool FindAllocation(int64 start_time, int64 end_time, + HloPosition defining_position, HloUse use, + const HloValue* buffer, int64 size, + MemorySpaceAssignment::AllocationSequence* allocations); + + // Try allocating in alternate memory without any copies. Returns true if + // successful. + bool TryAllocatingInAlternateMemoryNoCopy( + int64 start_time, int64 end_time, HloPosition defining_position, + HloUse use, BufferInterval alternate_mem_interval, + HloInstruction* non_bitcast_operand, + MemorySpaceAssignment::AllocationSequence* allocations); + + // Returns the instruction at a particular time in the flattened instruction + // schedule. + HloInstruction* GetInstructionAt(int64 time) const; + + // Given a buffer interval, returns the colocated intervals. Unlike the + // similar GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations, it + // returns the colocated intervals sorted by scheduled time. + std::vector GetSortedColocatedIntervals( + const BufferInterval& interval) const; + + // Since the allocations are recorded to the AllocationMap, we don't maintain + // result_ in GlobalDecreasingSizeBestFitHeap. Override AddToChunkMap to avoid + // unnecessarily adding the chunk to the chunk map. + void AddToChunkMap(const HloValue* buffer, Chunk chunk) override {} + + // Returns true if the addition of an asynchronous copy in the given time + // interval would violate the maximum number of asynchronous copies. + bool ViolatesMaximumOutstandingAsyncCopies(int64 start_time, + int64 end_time) const; + + // Adds an asynchronous copy to the allocations. + void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation, + MemorySpace memory_space, Chunk chunk, int64 start_time, + int64 end_time, + MemorySpaceAssignment::AllocationSequence* allocations); + + // These methods are used for delaying committing the chunk candidate until + // the entire live range of the buffer has been considered. 
+ void AddToPendingChunks(const BufferInterval& buffer_interval, + const ChunkCandidate& chunk_candidate); + void CommitPendingChunks(); + + MemorySpaceAssignment::AllocationMap* allocation_map_; + int64 max_size_in_bytes_; + // The min and max prefetch intervals decribe the number of independent HLOs + // overlapped while a value is being prefetched into the alternate memory + // (between CopyStart and CopyDone HLO instructions). max_prefetch_interval + // attempts to prevent bringing tensors into the alternate memory too eagerly + // and hence occupying the space for other tensors which might use it. + // min_prefetch_interval attempts to prevent cases where tensors are + // prefetched into the alternate memory without sufficient time for the copy + // to take place. In those cases, it's just better to keep the tensor in the + // default memory instead of hurting the critical path with this copy that + // likely won't finish in time. + // TODO(berkin): Explore heuristics that take into account the cost of copying + // tensors between alternate and default memories. + int64 min_prefetch_interval_; + int64 max_prefetch_interval_; + const HloAliasAnalysis& alias_analysis_; + IsAllowedInAlternateMemoryFunction is_allowed_in_alternate_mem_; + // We use a interval tree to keep track of the number of outstanding + // asynchronous copies. + BufferIntervalTree async_copy_interval_tree_; + int64 max_outstanding_async_copies_; + std::vector> pending_chunks_; + std::vector> pending_async_copies_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc new file mode 100644 index 00000000000..99ce46c0799 --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -0,0 +1,583 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_assignment.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class MemorySpaceAssignmentTest : public HloTestBase { + protected: + // We use the following two memory space values to describe the default (slow + // and large) and alternate (fast and small) memory spaces. 
+ const int64 kDefaultMemorySpace = 0; + const int64 kAlternateMemorySpace = 1; + + std::unique_ptr AssignMemorySpace( + HloModule* module, int64 max_outstanding_async_copies = -1) { + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + auto is_allowed_in_alternate_mem = [](const HloValue& value) { + // Check if the value belongs to the entry computation. + HloInstruction* instruction = value.instruction(); + HloComputation* computation = instruction->parent(); + bool in_entry_computation = + (computation == computation->parent()->entry_computation()); + if (in_entry_computation && + instruction->opcode() == HloOpcode::kParameter) { + return false; + } + return true; + }; + + std::unique_ptr preset_assignments = + MemorySpaceAssignment::Run( + module, kAlternateMemorySpace, + /*max_size_in_bytes=*/128, + /*min_prefetch_interval=*/2, + /*max_prefetch_interval=*/10, + /*alternate_memory_space_alignment_in_bytes=*/8, size_fn, + is_allowed_in_alternate_mem, max_outstanding_async_copies) + .ValueOrDie(); + CheckPresetAssignments(preset_assignments.get()); + return preset_assignments; + } + + void CheckPresetAssignments(const PresetAssignments* preset_assignments) { + // Ensure that the exported preset assignments point to layouts in the + // alternate memory. Also ensure that the positions are unique. Note that + // we're using a std::set instead of absl::flat_hash_set because we can make + // use of HloPosition's comparator logic instead of providing a hasher. + std::set positions_in_preset_assignments; + for (auto& position_and_chunk : preset_assignments->chunks()) { + HloPosition position = position_and_chunk.first; + EXPECT_EQ(positions_in_preset_assignments.find(position), + positions_in_preset_assignments.end()); + positions_in_preset_assignments.insert(position); + const Shape& subshape = + ShapeUtil::GetSubshape(position.instruction->shape(), position.index); + EXPECT_EQ(subshape.layout().memory_space(), kAlternateMemorySpace) + << "Exported position is not in alternate mem: " + << position.ToString(); + } + } + + std::unique_ptr CreateEvictAndPrefetchModule() { + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + HloInstruction* p1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + HloInstruction* tanh = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0)); + // tanh should be placed in the alternate memory since there isn't much + // contention in the beginning. However, tanh has another consumer at the + // end. So it should be kicked out to default memory and prefetched back in. + // The graph below is meant to increase the contention to force + // eviction/prefetch behavior. 
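As a worked check of the test configuration above: each dense f32[2,3] buffer is 2 * 3 * 4 = 24 bytes, which is already a multiple of the 8-byte alignment, so at most five such buffers can be resident in the 128-byte alternate memory at once. A minimal sketch of that arithmetic:

#include <cstdint>

// Byte size of a dense f32[2,3] buffer as used throughout these tests.
constexpr int64_t kF32Bytes = 4;
constexpr int64_t kTestBufferBytes = 2 * 3 * kF32Bytes;  // 24 bytes.

// With the 128-byte alternate memory and 8-byte alignment configured above,
// at most five such 24-byte buffers fit at the same time.
constexpr int64_t kMaxResidentTestBuffers = 128 / kTestBufferBytes;

static_assert(kTestBufferBytes % 8 == 0, "already aligned to 8 bytes");
static_assert(kMaxResidentTestBuffers == 5, "128 / 24 == 5");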
+ HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh)); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1)); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1)); + HloInstruction* d = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1)); + HloInstruction* e = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b)); + HloInstruction* f = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c)); + HloInstruction* g = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d)); + HloInstruction* h = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c)); + HloInstruction* i = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d)); + HloInstruction* j = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d)); + HloInstruction* k = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f)); + HloInstruction* l = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h)); + HloInstruction* m = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j)); + HloInstruction* n = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l)); + HloInstruction* o = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m)); + // tanh is being used at the root instruction, and this should be + // prefetched. + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i, + j, k, l, m, n, o, add}); + TF_CHECK_OK(module->set_schedule(schedule)); + return module; + } +}; + +TEST_F(MemorySpaceAssignmentTest, ParameterOnly) { + // A module consisting of a single parameter. Inputs/outputs are currently + // excluded from memory space assignment. + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {p0}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + EXPECT_THAT(p0, op::ShapeWithLayout(shape)); +} + +TEST_F(MemorySpaceAssignmentTest, Simple) { + // A simple module with a few simple instructions. Expect this to be + // transformed with CopyStart and CopyDone instructions inserted after inputs + // and before outputs. 
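+  // Besides layouts, the test checks the exported preset assignments: every
+  // value placed in the alternate memory gets a chunk (an offset within the
+  // 128-byte space), so add and sub below should end up with two distinct
+  // offsets.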
+ HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + HloInstruction* p1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p1)); + HloInstruction* sub = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1)); + HloInstruction* mul = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, sub)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {p0, p1, add, sub, mul}); + TF_CHECK_OK(module->set_schedule(schedule)); + + auto preset_assignments = AssignMemorySpace(module.get()); + + // Inputs and outputs are currently placed in the default memory. Everything + // else should be in the alternate memory. + Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout( + F32, {2, 3}, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + kAlternateMemorySpace); + EXPECT_THAT(p0, op::ShapeWithLayout(shape)); + EXPECT_THAT(p1, op::ShapeWithLayout(shape)); + EXPECT_THAT(mul, op::ShapeWithLayout(shape)); + EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem)); + EXPECT_THAT(sub, op::ShapeWithLayout(shape_in_alternate_mem)); + + // Make sure the preset assignments is sane. + EXPECT_EQ(preset_assignments->chunks().size(), 2); + EXPECT_EQ(preset_assignments->sizes().size(), 1); + // Ensure the offset assigned to add and sub are different. + EXPECT_NE(preset_assignments->chunks()[0].second.offset, + preset_assignments->chunks()[1].second.offset); +} + +TEST_F(MemorySpaceAssignmentTest, NegateChain) { + // The negate chain is long enough for asynchronous copy to be inserted + // between p1 and add. 
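+  // For reference: the asynchronous copy inserted by the pass is a
+  // CopyStart/CopyDone pair, and the op::AsyncCopy(dest_space, src_space,
+  // operand) matcher used below matches a CopyDone fed by such a CopyStart.
+  // The seven negates that the copy overlaps put it comfortably inside the
+  // [min_prefetch_interval=2, max_prefetch_interval=10] window configured in
+  // AssignMemorySpace above.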
+ HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + HloInstruction* p1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + HloInstruction* negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); + HloInstruction* negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); + HloInstruction* negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); + HloInstruction* negate3 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); + HloInstruction* negate4 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); + HloInstruction* negate5 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4)); + HloInstruction* negate6 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5)); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2, + negate3, negate4, negate5, negate6, add}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace, + kDefaultMemorySpace, + op::Parameter(1)))); + // Parameters are in the default memory space. + EXPECT_THAT(p0, op::ShapeWithLayout(shape)); + EXPECT_THAT(p1, op::ShapeWithLayout(shape)); + // Negate instructions are in the alternate memory space (1). + Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout( + F32, {2, 3}, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + kAlternateMemorySpace); + EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem)); + EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem)); + EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem)); + EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem)); + EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem)); + EXPECT_THAT(negate5, op::ShapeWithLayout(shape_in_alternate_mem)); + EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem)); + // Ensure the CopyStart/CopyDone schedules. 
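+  // The copy should overlap the whole negate chain: CopyStart is placed right
+  // after the two parameters (index 2) and CopyDone right before its use at
+  // add (index 10).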
+ const HloInstructionSequence& sequence = + module->schedule().sequence(computation); + EXPECT_THAT(sequence.instructions()[0], op::Parameter(0)); + EXPECT_THAT(sequence.instructions()[1], op::Parameter(1)); + EXPECT_THAT(sequence.instructions()[2], op::CopyStart()); + EXPECT_THAT(sequence.instructions()[10], op::CopyDone()); +} + +TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetch) { + std::unique_ptr module = CreateEvictAndPrefetchModule(); + + AssignMemorySpace(module.get()); + + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Add(op::Add(), + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, + op::AsyncCopy(kDefaultMemorySpace, + kAlternateMemorySpace, op::Tanh())))); + + EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), + 2); +} + +TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) { + std::unique_ptr module = CreateEvictAndPrefetchModule(); + + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0); + + EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), + 0); +} + +TEST_F(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) { + std::unique_ptr module = CreateEvictAndPrefetchModule(); + + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1); + + EXPECT_EQ(MemorySpaceAssignment::CountMaximumOutstandingAsyncCopies(*module), + 1); +} + +TEST_F(MemorySpaceAssignmentTest, While) { + auto module = CreateNewVerifiedModule(); + Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3}); + Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + HloInstruction* cond_limit = cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(50.f))); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter, + cond_limit, ComparisonDirection::kLt)); + HloComputation* cond_computation = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloInstruction* body_iter = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1)); + HloInstruction* body_data = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, body_param, 0)); + HloInstruction* body_iter_increment = body_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.f))); + HloInstruction* body_iter_next = + body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment)); + HloInstruction* body_data_increment = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}}))); + HloInstruction* body_data_mul = + 
body_builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kMultiply, body_data, body_data)); + HloInstruction* body_data_add = + body_builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, body_data, body_data_increment)); + HloInstruction* body_data_next = + body_builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, body_data_add, body_data_mul)); + HloInstruction* body_out = body_builder.AddInstruction( + HloInstruction::CreateTuple({body_data_next, body_iter_next})); + HloComputation* body_computation = + module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param_iter")); + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({data, iter})); + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(cond_computation, + {cond_param, cond_iter, cond_limit, cond_lt}); + schedule.set_sequence(body_computation, + {body_param, body_iter, body_data, body_iter_increment, + body_iter_next, body_data_increment, body_data_mul, + body_data_add, body_data_next, body_out}); + schedule.set_sequence(entry_computation, {iter, data, tuple, while_op}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + // Ensure the tuple value and buffers used in the while instruction are + // exempted from using the alternate memory. However, body_data_mul is + // independent and can be safely be placed in the alternate memory. 
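+  // The exemption is decided by the pass itself. A caller could additionally
+  // veto such values through the is_allowed_in_alternate_mem hook passed to
+  // MemorySpaceAssignment::Run; a hypothetical filter (not used by this test)
+  // might look like:
+  //   auto no_while_operands = [](const HloValue& value) {
+  //     for (const HloUse& use : value.uses()) {
+  //       if (use.instruction->opcode() == HloOpcode::kWhile) return false;
+  //     }
+  //     return true;
+  //   };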
+ EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape)); + EXPECT_THAT(data, op::ShapeWithLayout(shape)); + EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape)); + EXPECT_THAT(body_data, op::ShapeWithLayout(shape)); + EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape)); + EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape)); + Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout( + F32, {2, 3}, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + kAlternateMemorySpace); + EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem)); +} + +TEST_F(MemorySpaceAssignmentTest, Tuple) { + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({shape}); + Shape tuple_shape = + ShapeUtil::MakeTupleShape({shape, shape, inner_tuple_shape}); + HloInstruction* p = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p")); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, p, 0)); + HloInstruction* negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); + HloInstruction* negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); + HloInstruction* negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); + HloInstruction* negate3 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); + HloInstruction* negate4 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); + HloInstruction* negate5 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4)); + HloInstruction* negate6 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5)); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, p, 1)); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1)); + HloInstruction* p2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(inner_tuple_shape, p, 2)); + HloInstruction* p2_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, p2, 0)); + HloInstruction* mul = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, p2_0)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence( + computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5, + negate6, p1, add, p2, p2_0, mul}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + EXPECT_THAT( + mul, + op::Multiply(op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace, + kDefaultMemorySpace, + op::GetTupleElement())), + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, + op::GetTupleElement(op::GetTupleElement())))); +} + +TEST_F(MemorySpaceAssignmentTest, Bitcast) { + // Bitcasts can cause the position in the alternate memory to appear multiple + // times in the preset assignments. This test ensure the preset assignments + // refer to unique positions. 
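+  // A bitcast does not get a buffer of its own; it aliases its operand, so
+  // the same chunk in the alternate memory can show up both at the operand's
+  // position and at the bitcast's position. CheckPresetAssignments (in the
+  // fixture above) verifies that no position is exported twice.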
+ HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + HloInstruction* p1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + HloInstruction* negate = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); + HloInstruction* bitcast = + builder.AddInstruction(HloInstruction::CreateBitcast(shape, negate)); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, p1)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {p0, p1, negate, bitcast, add}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace); +} + +TEST_F(MemorySpaceAssignmentTest, Bitcast2) { + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); + Shape param_shape = ShapeUtil::MakeShape(F32, {6}); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "p1")); + HloInstruction* negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0)); + HloInstruction* negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0)); + HloInstruction* negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1)); + HloInstruction* negate3 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2)); + HloInstruction* negate4 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3)); + HloInstruction* bitcast = + builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1)); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate4)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2, + negate3, negate4, bitcast, add}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace); +} + +TEST_F(MemorySpaceAssignmentTest, Bitcast3) { + HloComputation::Builder builder(TestName()); + Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3}); + Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); + Shape shape3 = ShapeUtil::MakeShape(F32, {1, 6}); + Shape param_shape = ShapeUtil::MakeShape(F32, {6}); + HloInstruction* p0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape1, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "p1")); + HloInstruction* negate0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, p0)); + HloInstruction* negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate0)); + HloInstruction* negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate1)); + 
HloInstruction* negate3 = builder.AddInstruction( + HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate2)); + HloInstruction* negate4 = builder.AddInstruction( + HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate3)); + HloInstruction* bitcast1 = + builder.AddInstruction(HloInstruction::CreateBitcast(shape1, p1)); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, bitcast1, negate4)); + HloInstruction* bitcast2 = + builder.AddInstruction(HloInstruction::CreateBitcast(shape3, p1)); + HloInstruction* bitcast3 = + builder.AddInstruction(HloInstruction::CreateBitcast(shape2, bitcast2)); + HloInstruction* bitcast4 = + builder.AddInstruction(HloInstruction::CreateBitcast(shape2, add)); + HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( + shape2, HloOpcode::kMultiply, bitcast3, bitcast4)); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, + {p0, p1, negate0, negate1, negate2, negate3, negate4, + bitcast1, add, bitcast2, bitcast3, bitcast4, mul}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + // We expect one bitcast on the LHS of multiply since bitcast(bitcast(foo)) is + // converted to bitcast(foo). + EXPECT_THAT( + mul, + op::Multiply( + op::Bitcast(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, + op::Parameter(1))), + op::Bitcast(op::Add( + op::Bitcast(op::AsyncCopy(kAlternateMemorySpace, + kDefaultMemorySpace, op::Parameter(1))), + op::Negate())))); + EXPECT_EQ(bitcast1->shape().layout().memory_space(), kAlternateMemorySpace); + EXPECT_EQ(add->shape().layout().memory_space(), kAlternateMemorySpace); + // bitcast2 will no longer have a consumer and should get DCE'd, so we don't + // care about its memory space. 
+ EXPECT_EQ(bitcast3->shape().layout().memory_space(), kAlternateMemorySpace); + EXPECT_EQ(bitcast4->shape().layout().memory_space(), kAlternateMemorySpace); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 72ca402427e..5a26ea1be22 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -38,14 +38,59 @@ cc_library( hdrs = ["mlir_compiler.h"], deps = [ ":failover_compiler", + ":lhlo_dialect_emitter", "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:dump", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/gpu:gpu_constants", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/service/gpu:gpu_hlo_schedule", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl", + "//tensorflow/compiler/xla/service/gpu:stream_assignment", "//tensorflow/compiler/xla/service/gpu:target_constants", "//tensorflow/core:lib", + "//tensorflow/stream_executor:stream_executor_headers", "@local_config_mlir//:IR", "@local_config_mlir//:LLVMDialect", ], alwayslink = True, # Contains compiler registration ) + +cc_library( + name = "lhlo_dialect_emitter", + srcs = ["lhlo_dialect_emitter.cc"], + hdrs = ["lhlo_dialect_emitter.h"], + deps = [ + "//tensorflow/compiler/mlir/xla:lhlo", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/gpu:thunk", + "//tensorflow/compiler/xla/service/gpu:thunk_emitter", + "//tensorflow/core:lib", + "//tensorflow/stream_executor:stream_executor_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@local_config_mlir//:IR", + "@local_config_mlir//:LLVMDialect", + "@local_config_mlir//:StandardOps", + ], +) + +cc_library( + name = "mlir_irgen_test_base", + testonly = True, + srcs = ["mlir_irgen_test_base.cc"], + hdrs = ["mlir_irgen_test_base.h"], + deps = [ + ":failover_compiler", + ":mlir_compiler", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:codegen_test_base", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/core:test", + "@llvm//:support", + "@local_config_mlir//:IR", + ], +) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.cc index f225e92bd30..4107d92da7e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.cc @@ -50,25 +50,6 @@ StatusOr> FailoverCompiler::RunBackend( return result; } -Status FailoverCompiler::RunHloPassesOnModuleGroup( - HloModuleGroup* module_group, - absl::Span executors, - se::DeviceMemoryAllocator* device_allocator) { - // This is not supported by GPU compiler anyway. - return Unimplemented( - "Model partitioning not implemented for the failover compiler!"); -} - -StatusOr>> -FailoverCompiler::RunBackendOnModuleGroup( - std::unique_ptr module_group, - std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - // This is not supported by GPU compiler anyway. 
- return Unimplemented( - "Model partitioning not implemented for the failover compiler!"); -} - StatusOr>> FailoverCompiler::Compile( std::unique_ptr module_group, std::vector> stream_execs, diff --git a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h index cfa542f2e38..05badaa98e1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h @@ -57,16 +57,6 @@ class FailoverCompiler final : public Compiler { std::unique_ptr module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) override; - Status RunHloPassesOnModuleGroup( - HloModuleGroup* module_group, - absl::Span executors, - se::DeviceMemoryAllocator* device_allocator) override; - - StatusOr>> RunBackendOnModuleGroup( - std::unique_ptr module_group, - std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; - StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_execs, @@ -78,6 +68,9 @@ class FailoverCompiler final : public Compiler { HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; + Compiler* GetPrimary() const { return primary_.get(); } + Compiler* GetSecondary() const { return secondary_.get(); } + private: std::unique_ptr primary_; std::unique_ptr secondary_; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc new file mode 100644 index 00000000000..1f8241aeda3 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -0,0 +1,223 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:local_config_mlir +#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/Identifier.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace mlir_gpu { +namespace { + +using ::mlir::ArrayRef; +using ::mlir::Attribute; +using ::mlir::Builder; +using ::mlir::FuncOp; +using ::mlir::Identifier; +using ::mlir::Location; +using ::mlir::ModuleOp; +using ::mlir::NamedAttribute; +using ::mlir::OpBuilder; +using ::mlir::Type; +using ::mlir::Value; +using ::mlir::LLVM::LLVMDialect; +using ::xla::gpu::Thunk; +using ::xla::gpu::ThunkEmitter; +using ::xla::gpu::ThunkSequence; + +Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc, + ArrayRef rets, ArrayRef args, + ArrayRef> attrs) { + switch (opcode) { + case HloOpcode::kAdd: + func_builder.create<::mlir::xla_lhlo::AddOp>(loc, rets, args, attrs); + break; + case HloOpcode::kMultiply: + func_builder.create<::mlir::xla_lhlo::MulOp>(loc, rets, args, attrs); + break; + case HloOpcode::kSubtract: + func_builder.create<::mlir::xla_lhlo::SubOp>(loc, rets, args, attrs); + break; + case HloOpcode::kDivide: + func_builder.create<::mlir::xla_lhlo::DivOp>(loc, rets, args, attrs); + break; + case HloOpcode::kAnd: + func_builder.create<::mlir::xla_lhlo::AndOp>(loc, rets, args, attrs); + break; + case HloOpcode::kMinimum: + func_builder.create<::mlir::xla_lhlo::MinOp>(loc, rets, args, attrs); + break; + case HloOpcode::kMaximum: + func_builder.create<::mlir::xla_lhlo::MaxOp>(loc, rets, args, attrs); + break; + default: + return tensorflow::errors::Internal(absl::StrCat( + "Opcode ", HloOpcodeString(opcode), " is not supported.")); + } + return Status::OK(); +} + +StatusOr<::mlir::MemRefType> ConvertTensorType(const Shape& shape, + Builder builder) { + llvm::SmallVector array; + array.reserve(shape.dimensions_size()); + for (const auto dim : shape.dimensions()) { + array.push_back(dim); + } + switch (shape.element_type()) { + case PrimitiveType::PRED: + return builder.getMemRefType(array, builder.getI1Type()); + case PrimitiveType::F16: + return builder.getMemRefType(array, builder.getF16Type()); + case PrimitiveType::F32: + return builder.getMemRefType(array, builder.getF32Type()); + case PrimitiveType::F64: + return builder.getMemRefType(array, builder.getF64Type()); + case PrimitiveType::S8: + return builder.getMemRefType(array, builder.getIntegerType(8)); + case PrimitiveType::S16: + return builder.getMemRefType(array, builder.getIntegerType(16)); + case PrimitiveType::S32: + return builder.getMemRefType(array, 
builder.getIntegerType(32)); + case PrimitiveType::S64: + return builder.getMemRefType(array, builder.getIntegerType(64)); + default: + return tensorflow::errors::Internal(absl::StrCat( + "Unsupported type: ", PrimitiveType_Name(shape.element_type()))); + } +} + +StatusOr ConvertType(const Shape& shape, Builder builder) { + if (shape.IsTuple()) { + Type mlir_type; + llvm::SmallVector contents; + contents.reserve(shape.tuple_shapes_size()); + for (const auto& subtype : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN(auto mlir_subtype, ConvertType(subtype, builder)); + contents.push_back(mlir_subtype); + } + return builder.getTupleType(contents); + } + return ConvertTensorType(shape, builder); +} + +StatusOr> GetInstructionArgTypes( + const HloInstruction& instruction, Builder builder) { + llvm::SmallVector arg_types; + for (auto operand : instruction.operands()) { + TF_ASSIGN_OR_RETURN(auto operand_type, + ConvertType(operand->shape(), builder)); + arg_types.push_back(operand_type); + } + TF_ASSIGN_OR_RETURN(auto operand_type, + ConvertType(instruction.shape(), builder)); + arg_types.push_back(operand_type); + return arg_types; +} + +} // namespace + +LhloDialectEmitter::LhloDialectEmitter(const HloModule& hlo_module, + const BufferAssignment& assignment, + const se::Platform* platform, + ModuleOp mlir_module) + : mlir_module_(mlir_module), + builder_(mlir_module_.getContext()), + buffer_assignment_(assignment), + platform_(platform), + thunk_sequence_(new ThunkSequence()) { + LLVMDialect* llvmDialect = + mlir_module.getContext()->getRegisteredDialect(); + pointer_size_ = llvmDialect->getLLVMModule().getDataLayout().getPointerSize(); +} + +void LhloDialectEmitter::AddThunkToThunkSequence(std::unique_ptr thunk) { + thunk_sequence_->push_back(std::move(thunk)); +} + +StatusOr LhloDialectEmitter::MaybeGetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index) const { + return buffer_assignment_.GetUniqueSlice(&hlo, index); +} + +int64 LhloDialectEmitter::ByteSizeOf(const Shape& shape) const { + return ShapeUtil::ByteSizeOf(shape, pointer_size_); +} + +const se::Platform* LhloDialectEmitter::platform() const { return platform_; } + +Status LhloDialectEmitter::EmitComputation(const HloComputation& computation) { + return computation.root_instruction()->Accept(this); +} + +StatusOr LhloDialectEmitter::CreateFunction( + const HloInstruction& instr) { + TF_ASSIGN_OR_RETURN(auto args, GetInstructionArgTypes(instr, builder_)); + auto function_type = builder_.getFunctionType(args, {}); + auto function = + FuncOp::create(builder_.getUnknownLoc(), instr.name(), function_type); + mlir_module_.push_back(function); + function.addEntryBlock(); + instruction_to_mlir_func_[&instr] = function; + return function; +} + +Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) { + TF_ASSIGN_OR_RETURN(auto function, CreateFunction(*instr)); + OpBuilder func_builder(function.getBody()); + llvm::SmallVector arg_values{function.args_begin(), + function.args_end()}; + llvm::SmallVector attributes{ + builder_.getNamedAttr("name", builder_.getStringAttr(instr->name()))}; + TF_RETURN_IF_ERROR(InsertMlirOp(instr->opcode(), func_builder, + builder_.getUnknownLoc(), ArrayRef{}, + arg_values, attributes)); + return Status::OK(); +} + +Status LhloDialectEmitter::HandleFusion(HloInstruction* fusion) { + LOG(FATAL) << "Not implemented yet."; +} + +Status LhloDialectEmitter::HandleCustomCall(HloInstruction* custom_call) { + return ThunkEmitter(this).HandleCustomCall(custom_call); +} + +Status 
LhloDialectEmitter::HandleParameter(HloInstruction* parameter) { + return Status::OK(); +} + +Status LhloDialectEmitter::FinishVisit(HloInstruction* root) { + return Status::OK(); +} + +} // namespace mlir_gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h new file mode 100644 index 00000000000..7d0c818068a --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h @@ -0,0 +1,91 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_LHLO_DIALECT_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_LHLO_DIALECT_EMITTER_H_ + +#include "absl/container/flat_hash_map.h" +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { +namespace mlir_gpu { + +// Implementation for the translation of HLO instructions to a ThunkSequence +// via MLIR using the LHLO dialect. +// Implements the DfsHloVisitor interface, emits LHLO computations as MLIR IR +// functions and transforms them into gpu::Thunk. +class LhloDialectEmitter : public DfsHloVisitorWithDefault, + private gpu::ThunkEmitter::EmissionContext { + public: + LhloDialectEmitter(const HloModule& hlo_module, + const BufferAssignment& assignment, + const se::Platform* platform, + ::mlir::ModuleOp mlir_module); + ~LhloDialectEmitter() override = default; + + Status EmitComputation(const HloComputation& computation); + + // The following methods implement the DfsHloVisitor interface. + // + // Default action which emits code for most operations. Operations which are + // special in some way are handled explicitly in HandleFoo methods. + Status DefaultAction(HloInstruction* instr) override; + + Status HandleFusion(HloInstruction* fusion) override; + Status HandleCustomCall(HloInstruction* custom_call) override; + Status HandleParameter(HloInstruction* parameter) override; + + Status FinishVisit(HloInstruction* root) override; + + // Transfers the ownship of thunk_sequence_ out. 
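+  // Typical use, mirroring MlirCompiler::RunBackend in this change: construct
+  // the emitter, call EmitComputation() on the entry computation, then pass
+  // the result of ConsumeThunkSequence() to a ThunkSchedule.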
+ std::unique_ptr ConsumeThunkSequence() { + return std::move(thunk_sequence_); + } + + private: + StatusOr<::mlir::FuncOp> CreateFunction(const HloInstruction& instr); + // Interface required by ThunkEmitter + void AddThunkToThunkSequence(std::unique_ptr thunk) override; + StatusOr MaybeGetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index) const override; + int64 ByteSizeOf(const Shape& shape) const override; + const se::Platform* platform() const override; + + ::mlir::ModuleOp mlir_module_; + ::mlir::Builder builder_; + absl::flat_hash_map + instruction_to_mlir_func_; + const BufferAssignment& buffer_assignment_; + const se::Platform* platform_; + // Cached pointer size extracted from the mlir module. + unsigned pointer_size_; + // The thunk sequence this IrEmitter generates for the input computation. + std::unique_ptr thunk_sequence_; + + TF_DISALLOW_COPY_AND_ASSIGN(LhloDialectEmitter); +}; + +} // namespace mlir_gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_LHLO_DIALECT_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index 5421a3ae093..d240003b039 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -15,21 +15,41 @@ limitations under the License. #include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h" -#include "mlir/LLVMIR/LLVMDialect.h" // TF:local_config_mlir +#include + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/nvptx_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/gpu/target_constants.h" +#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { -namespace mlir { +namespace mlir_gpu { +namespace { using ::mlir::MLIRContext; +using ::mlir::ModuleOp; +using ::mlir::OwningModuleRef; +using ::mlir::UnknownLoc; using ::mlir::LLVM::LLVMDialect; +using ::xla::gpu::GpuExecutable; +using ::xla::gpu::GpuHloSchedule; +using ::xla::gpu::GpuVersion; +using ::xla::gpu::StreamAssignment; +using ::xla::gpu::ThunkSchedule; -namespace { int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) { LLVMDialect* dialect = context->getRegisteredDialect(); llvm::Module& module = dialect->getLLVMModule(); @@ -37,6 +57,7 @@ int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) { module.setDataLayout(gpu::nvptx::kDataLayout); return module.getDataLayout().getPointerSize(); } + } // namespace MlirCompiler::MlirCompiler() @@ -51,34 +72,109 @@ StatusOr> MlirCompiler::RunHloPasses( se::DeviceMemoryAllocator* device_allocator) { // Until we find a reason to do something 
different, run the same passes // that the normal GPU backend runs. - TF_RETURN_IF_ERROR(xla::gpu::impl::OptimizeHloModule( - module.get(), stream_exec, device_allocator)); - - TF_RETURN_IF_ERROR( - xla::gpu::impl::PrepareHloModuleForIrEmitting(module.get())); + gpu::NVPTXCompiler xla_compiler; + TF_RETURN_IF_ERROR(xla_compiler.OptimizeHloModule(module.get(), stream_exec, + device_allocator)); + TF_RETURN_IF_ERROR(xla_compiler.PrepareHloModuleForIrEmitting(module.get())); return std::move(module); } +namespace { + +// TODO(b/137624192): Move this to custom call handling and share. +absl::optional CanShareBufferHint(const HloInstruction* user, + const HloInstruction* operand, + const ShapeIndex& user_index) { + if (user->opcode() == HloOpcode::kCustomCall) { + // Share the bias buffer with the parent instruction. + if (user->custom_call_target() == xla::gpu::kGemmCallTarget) { + if (user->operand_count() == 3 && user->operand(2) == operand) { + return true; + } + } + // The operand of cholesky can be shared with the first output. + if (user->custom_call_target() == xla::gpu::kCusolverCholeskyCallTarget) { + return user_index.size() == 1 && user_index[0] == 0; + } + } + return absl::nullopt; +} + +// TODO(b/137624192): Share this with nvptx backend. +GpuVersion GetGpuVersion(se::StreamExecutor* stream_exec) { + int cc_major, cc_minor; + const auto& device_description = stream_exec->GetDeviceDescription(); + if (!device_description.cuda_compute_capability(&cc_major, &cc_minor)) { + LOG(WARNING) + << "Couldn't get compute capability for device; assuming sm_20."; + cc_major = 2; + cc_minor = 0; + } + return std::make_pair(cc_major, cc_minor); +} + +} // namespace + StatusOr> MlirCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) { - return Unimplemented("Not yet implemented in MLIR compiler"); -} + // Determine the HLO schedule, which is an ordering of HLO instructions. This + // is used by buffer assignment to enable buffer reuse, and the same ordering + // must also be used to determine the thunk launch schedule. + std::unique_ptr stream_assignment = + xla::gpu::AssignStreams(*module); + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_schedule, + GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); -Status MlirCompiler::RunHloPassesOnModuleGroup( - HloModuleGroup* module_group, - absl::Span executors, - se::DeviceMemoryAllocator* device_allocator) { - return Unimplemented("Not yet implemented in MLIR compiler"); -} + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. 
+ TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_assignment, + BufferAssigner::Run( + module.get(), hlo_schedule->ConsumeHloOrdering(), + BufferSizeBytesFunction(), + /*color_alignment=*/ + [](LogicalBuffer::Color) { + return xla::gpu::kXlaAllocatedBufferAlignBytes; + }, + /*allocate_buffers_for_constants=*/true, + /*colorer=*/BufferAssigner::DefaultColorer(), + /*must_not_live_out=*/{}, &CanShareBufferHint)); + DumpHloModuleIfEnabled(*module, *buffer_assignment, "after_optimizations"); -StatusOr>> -MlirCompiler::RunBackendOnModuleGroup( - std::unique_ptr module_group, - std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) { - return Unimplemented("Not yet implemented in MLIR compiler"); + MLIRContext mlir_context; + OwningModuleRef mlir_module = + ModuleOp::create(UnknownLoc::get(&mlir_context)); + LhloDialectEmitter lhlo_emitter(*module, *buffer_assignment, + stream_exec->platform(), *mlir_module); + + TF_RETURN_IF_ERROR( + lhlo_emitter.EmitComputation(*module->entry_computation())); + + if (module_hook_.callback && !module_hook_.apply_on_lowered) { + module_hook_.callback(*mlir_module); + } + + // TODO(b/137624192): Emit function per hlo and turn into ptx string and blob. + std::string ptx; + std::vector cubin; + + auto thunk_schedule = absl::make_unique( + lhlo_emitter.ConsumeThunkSequence(), std::move(stream_assignment), + hlo_schedule->ThunkLaunchOrder()); + + if (DumpingEnabledForHloModule(*module)) { + DumpToFileInDirOrStdout(*module, "thunk_schedule", + thunk_schedule->ToString()); + } + + // TODO(b/137624192): Add profiling support. + + return static_cast>( + absl::make_unique( + ptx, cubin, GetGpuVersion(stream_exec), std::move(thunk_schedule), + std::move(module), std::move(buffer_assignment), nullptr, nullptr)); } StatusOr>> MlirCompiler::Compile( @@ -94,14 +190,20 @@ MlirCompiler::CompileAheadOfTime(std::unique_ptr module_group, return Unimplemented("Not yet implemented in MLIR compiler"); } -} // namespace mlir +void MlirCompiler::SetModuleHook(IRHook module_hook) { + module_hook_ = module_hook; +} + +void MlirCompiler::RemoveModuleHook() { module_hook_ = {nullptr, false}; } + +} // namespace mlir_gpu } // namespace xla static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::cuda::kCudaPlatformId, []() { return absl::make_unique( - absl::make_unique(), + absl::make_unique(), absl::make_unique()); }); return true; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h index f02164c4d24..fdc71903a06 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h @@ -17,10 +17,11 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir #include "tensorflow/compiler/xla/service/compiler.h" namespace xla { -namespace mlir { +namespace mlir_gpu { // A Compiler implementation that converts XLAs IR to a matching MLIR dialect, // performs all lowering on the MLIR IR and finally converts MLIR to LLVMIR for @@ -39,16 +40,6 @@ class MlirCompiler : public Compiler { std::unique_ptr module, se::StreamExecutor* stream_exec, se::DeviceMemoryAllocator* device_allocator) override; - Status RunHloPassesOnModuleGroup( - HloModuleGroup* module_group, - absl::Span executors, - se::DeviceMemoryAllocator* device_allocator) override; - - StatusOr>> RunBackendOnModuleGroup( - std::unique_ptr module_group, - std::vector> stream_exec, - se::DeviceMemoryAllocator* device_allocator) override; - StatusOr>> Compile( std::unique_ptr module_group, std::vector> stream_execs, @@ -65,12 +56,21 @@ class MlirCompiler : public Compiler { }; } + struct IRHook { + std::function callback; + bool apply_on_lowered; + }; + + void SetModuleHook(IRHook module_hook); + void RemoveModuleHook(); + private: ::mlir::MLIRContext context_; int64 pointer_size_; + IRHook module_hook_; }; -} // namespace mlir +} // namespace mlir_gpu } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc new file mode 100644 index 00000000000..4b6a03270c7 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.cc @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h" + +#include +#include + +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/failover_compiler.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace mlir_gpu { + +void MlirIrGenTestBase::CompileAndVerifyIr( + std::unique_ptr hlo_module, const string& pattern, + bool match_lowered_ir) { + MlirCompiler* compiler = GetMLIRCompiler(); + string ir; + compiler->SetModuleHook({[&ir](mlir::ModuleOp module) -> Status { + std::string buffer_string; + llvm::raw_string_ostream ostream(buffer_string); + module.print(ostream); + ostream.flush(); + ir = buffer_string; + return Status::OK(); + }, + match_lowered_ir}); + Status status = CompileToExecutable(std::move(hlo_module)).status(); + compiler->RemoveModuleHook(); + TF_ASSERT_OK(status); + + StatusOr filecheck_result = RunFileCheck(ir, pattern); + TF_ASSERT_OK(filecheck_result.status()); + EXPECT_TRUE(filecheck_result.ValueOrDie()); +} + +void MlirIrGenTestBase::CompileAndVerifyIr(const string& hlo_text, + const string& expected_llvm_ir, + bool match_lowered_ir) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule(hlo_text, config)); + CompileAndVerifyIr(std::move(module), expected_llvm_ir, match_lowered_ir); +} + +MlirCompiler* MlirIrGenTestBase::GetMLIRCompiler() { + // TODO(b/137624192): Remove failover once no longer in place. + FailoverCompiler* failover = + static_cast(backend().compiler()); + return static_cast(failover->GetPrimary()); +} + +} // namespace mlir_gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h new file mode 100644 index 00000000000..613ddc27bf6 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_IRGEN_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_IRGEN_TEST_BASE_H_ + +#include + +#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h" +#include "tensorflow/compiler/xla/tests/codegen_test_base.h" + +namespace xla { +namespace mlir_gpu { + +// Tests that verify IR emitted by the CPU/GPU backend is as expected. 
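+// A typical use, patterned after the Add test in
+// tests/mlir_gpu_lhlo_gen_test.cc from this change (the HLO text and
+// FileCheck pattern below are illustrative only):
+//
+//   class MyLhloTest : public MlirIrGenTestBase {};
+//
+//   TEST_F(MyLhloTest, Mul) {
+//     CompileAndVerifyIr(R"(
+// HloModule Mul
+// ENTRY %Mul (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
+//   %x = f32[2,2]{1,0} parameter(0)
+//   %y = f32[2,2]{1,0} parameter(1)
+//   ROOT %mul = f32[2,2]{1,0} multiply(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y)
+// })",
+//                        R"(;CHECK: "xla_lhlo.mul")");
+//   }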
+class MlirIrGenTestBase : public CodegenTestBase { + protected: + // Compiles the given HLO module to MLIR IR and verifies the IR matches the + // given pattern. `pattern` is in the FileCheck pattern matching syntax + // (http://llvm.org/docs/CommandGuide/FileCheck.html). + // + // This function invokes the JIT compiler. + // + // If `match_lowered_ir` is true, match the version of the IR after lowering + // steps to LLVM IR are applied; otherwise, the IR before lowering is + // matched. + void CompileAndVerifyIr(std::unique_ptr hlo_module, + const string& pattern, bool match_lowered_ir = false); + + // A thin wrapper around CompileAndVerifyIr that parses `hlo_text` to create + // an HLO module. + void CompileAndVerifyIr(const string& hlo_text, + const string& expected_llvm_ir, + bool match_lowered_ir = false); + + // Compiles and returns module with optimizations from a given HLO. + StatusOr> GetOptimizedModule( + absl::string_view hlo); + + private: + MlirCompiler* GetMLIRCompiler(); +}; + +} // namespace mlir_gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_IRGEN_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD new file mode 100644 index 00000000000..2e799381c48 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/BUILD @@ -0,0 +1,42 @@ +# TODO(herhut): describe this package. + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core/platform:default/build_config_root.bzl", + "tf_cuda_tests_tags", +) + +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +tf_cc_test( + name = "mlir_gpu_lhlo_gen_test", + srcs = ["mlir_gpu_lhlo_gen_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:mlir_gpu_plugin", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/service/mlir_gpu:mlir_irgen_test_base", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/memory", + ], +) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc new file mode 100644 index 00000000000..5e9413c1b5e --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/mlir_gpu_lhlo_gen_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/mlir_irgen_test_base.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace mlir_gpu { + +class LhloGenTest : public MlirIrGenTestBase {}; + +TEST_F(LhloGenTest, Add) { + CompileAndVerifyIr(R"( +HloModule Add + +ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { + %x = f32[2,2]{1,0} parameter(0) + %y = f32[2,2]{1,0} parameter(1) + ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) +})", + R"( +;CHECK: module { +;CHECK: func @add(%{{.*}}: memref<2x2xf32>, %{{.*}}: memref<2x2xf32>, %{{.*}}: memref<2x2xf32>) { +;CHECK: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %{{.*}}) {name = "add"} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () +;CHECK: } +;CHECK: } + )"); +} + +} // namespace mlir_gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index 582e59349e8..6c31f6bdc86 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -123,7 +123,6 @@ HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, if (fused->IsMultiOutputFusion()) { std::swap(remaining, fused); } - if (fused->opcode() == HloOpcode::kFusion) { remaining->MergeFusionInstructionIntoMultiOutput(fused); } else { @@ -249,14 +248,12 @@ bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, multioutput_user_is_not_gte(instr2)) { return false; } - if (is_connected(instr1, instr2)) { return false; } if (!ShapesCompatibleForFusion(instr1, instr2)) { return false; } - return true; } @@ -339,4 +336,5 @@ bool MultiOutputFusion::Perform() { } bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion() { return false; } + } // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 3d129c4ec50..9000370f6f3 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -40,8 +40,8 @@ namespace xla { // fused and their fusion profit scores. // // Function Perform() applies the optimization. It picks up the most profitable -// pair in the worklist_, check if it's legal to fuse and fuse the pair. -// After fusion, it updates the associated structure such as reachability_, +// pair in the worklist_, checks if it's legal to fuse and fuses the pair. +// After fusion, it updates the associated structures such as reachability_, // candidates_ and worklist_. // Note that the reachability map is updated based on the original computation. // This works because the reachability is monotonically increasing with @@ -105,13 +105,6 @@ class MultiOutputFusion : public HloModulePass { virtual bool DoProducerConsumerMultiOutputFusion(); private: - // Update the internal data structures after instr1 and instr2 are fused into - // one fusion instruction. - void Update(HloInstruction* instr1, HloInstruction* instr2); - - // Computation for the pass. 
- HloComputation* computation_; - // An internal data structure for each instruction in current computation. // When an instruction is removed, member 'hlo' is set to nullptr. struct FusionCandidate { @@ -119,16 +112,6 @@ class MultiOutputFusion : public HloModulePass { std::list> fusibles; explicit FusionCandidate(HloInstruction* hlo) : hlo(hlo) {} }; - std::vector candidates_; - - // A map that maps an instruction to the index_. - absl::flat_hash_map candidates_index_; - - // The reachability map of current computation. - std::unique_ptr reachability_; - - // This stores all the candidate instructions in current computation. - std::vector all_fusion_candidates_; // The pair of candidates to be fused and the profit score. struct ToBeFused { @@ -139,7 +122,10 @@ class MultiOutputFusion : public HloModulePass { : instr1(instr1), instr2(instr2), score(score) {} bool operator<(const ToBeFused& rhs) const { return score < rhs.score; } }; - std::priority_queue worklist_; + + // Update the internal data structures after instr1 and instr2 are fused into + // one fusion instruction. + void Update(HloInstruction* instr1, HloInstruction* instr2); int64 get_candidate_id(HloInstruction* instr) { return FindOrDie(candidates_index_, instr); @@ -156,6 +142,21 @@ class MultiOutputFusion : public HloModulePass { bool is_connected(HloInstruction* instr1, HloInstruction* instr2) { return reachability_->IsConnected(instr1, instr2); } + + std::vector candidates_; + std::priority_queue worklist_; + + // A map that maps an instruction to the index_. + absl::flat_hash_map candidates_index_; + + // The reachability map of current computation. + std::unique_ptr reachability_; + + // This stores all the candidate instructions in current computation. + std::vector all_fusion_candidates_; + + // Computation for the pass. 
+ HloComputation* computation_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index db2cd28d0c5..32e4c636327 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -455,9 +455,9 @@ class LayoutPattern { template auto AppendImpl(NewImpl new_impl) const -> LayoutPattern(std::declval(), - std::move(new_impl)))> { - auto new_allof = AllOf(impl_, std::move(new_impl)); + decltype(AllOf<::xla::Layout>(std::declval(), + std::move(new_impl)))> { + auto new_allof = AllOf<::xla::Layout>(impl_, std::move(new_impl)); return LayoutPattern(std::move(new_allof), matched_layout_); } @@ -869,7 +869,7 @@ class ShapePatternLayoutImpl { layout_.Match(&shape->layout(), option); } - bool Match(Shape* shape, MatchOption option) const { + bool Match(::xla::Shape* shape, MatchOption option) const { if (!LayoutUtil::HasLayout(*shape)) { EXPLAIN << "Shape does not have a layout"; return false; @@ -946,9 +946,10 @@ class ShapePattern { private: template auto AppendImpl(NewImpl new_impl) const - -> ShapePattern(std::declval(), - std::move(new_impl)))> { - auto new_all_of = AllOf(impl_, std::move(new_impl)); + -> ShapePattern(std::declval(), + std::move(new_impl)))> { + auto new_all_of = AllOf<::xla::Shape>(impl_, std::move(new_impl)); return ShapePattern(std::move(new_all_of), matched_shape_); } @@ -1077,7 +1078,7 @@ class ShapePattern { } ShapePattern& op2) : op1_(op1), op2_(op2) {} - bool Match(HloInstruction* inst, MatchOption option) const { + bool Match(::xla::HloInstruction* inst, MatchOption option) const { return MatchImpl(inst, option); } - bool Match(const HloInstruction* inst, MatchOption option) const { + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { return MatchImpl(inst, option); } @@ -1663,7 +1664,7 @@ class HloInstructionPatternOneUseOrUserImpl { class HloInstructionPatternOneUseImpl : public HloInstructionPatternOneUseOrUserImpl { public: - bool Match(const HloInstruction* inst, MatchOption option) const { + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { if (!MatchOneUser(inst, option)) { return false; } @@ -1688,7 +1689,7 @@ class HloInstructionPatternOneUseImpl class HloInstructionPatternOneUserImpl : public HloInstructionPatternOneUseOrUserImpl { public: - bool Match(const HloInstruction* inst, MatchOption option) const { + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { return MatchOneUser(inst, option); } @@ -1779,30 +1780,19 @@ class HloConstantScalarImpl { return true; } - // Check that literal == static_cast(val) and - // val == static_cast(literal). This is sufficient to ensure that - // the two constant scalars are actually "equal". 
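For orientation, a hedged usage sketch of the matcher this hunk rewrites; the m:: helper names are assumed from pattern_matcher.h's public interface rather than introduced by this change.

// Hypothetical usage: returns true if `instr` computes add(x, 0) and binds x.
bool IsAddOfZero(const xla::HloInstruction* instr,
                 const xla::HloInstruction** x) {
  namespace m = xla::match;
  return xla::Match(instr, m::Add(m::Op(x), m::ConstantScalar(0)));
}
// With this change, ConstantScalar(0) reshapes the matched constant's literal
// to an effective scalar and compares it via Literal::IsEqualAt({}, 0) instead
// of round-tripping both sides through Literal conversions.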
- auto val_literal = LiteralUtil::CreateR0(*val_); - auto literal_r0_or = const_inst->literal().Reshape({}); - auto val_as_literal_ty_or = - val_literal.Convert(const_inst->shape().element_type()); - if (!literal_r0_or.ok() || !val_as_literal_ty_or.ok()) { - EXPLAIN << "could not construct relevant Literals (how did this happen?)"; + auto const_inst_scalar_or = const_inst->literal().Reshape({}); + if (!const_inst_scalar_or.ok()) { + EXPLAIN << "could not convert matched literal to effective scalar"; return false; } - auto literal_r0 = std::move(literal_r0_or).ValueOrDie(); - auto val_as_literal_ty = std::move(val_as_literal_ty_or).ValueOrDie(); - auto literal_r0_as_val_ty_or = - literal_r0.Convert(val_literal.shape().element_type()); - bool rv = literal_r0_as_val_ty_or.ok() && // - literal_r0_as_val_ty_or.ValueOrDie() == val_literal && - literal_r0 == val_as_literal_ty; - if (!rv) { + Literal const_inst_scalar = std::move(const_inst_scalar_or).ValueOrDie(); + if (!const_inst_scalar.IsEqualAt({}, *val_)) { EXPLAIN << "HloInstruction's constant value " - << literal_r0.ToStringWithoutShape() + << const_inst_scalar.ToStringWithoutShape() << " did not match expected value " << *val_; + return false; } - return rv; + return true; } absl::optional val_; @@ -1815,9 +1805,9 @@ class HloInstructionPattern { private: template auto AppendImpl(NewImpl new_impl) const -> HloInstructionPattern< - HloInstructionType, decltype(AllOf( + HloInstructionType, decltype(AllOf<::xla::HloInstruction>( std::declval(), std::move(new_impl)))> { - auto new_allof = AllOf(impl_, std::move(new_impl)); + auto new_allof = AllOf<::xla::HloInstruction>(impl_, std::move(new_impl)); return HloInstructionPattern( std::move(new_allof), matched_inst_); } diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 5ec45eb491a..e3a7efff0b1 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -166,8 +166,9 @@ Service::Service(const ServiceOptions& options, << "Requested more replicas than there are devices."; } LOG(INFO) << StrFormat( - "XLA service %p executing computations on platform %s. Devices:", this, - execute_backend_->platform()->Name()); + "XLA service %p initialized for platform %s (this does not guarantee " + "that XLA will be used). Devices:", + this, execute_backend_->platform()->Name()); auto stream_executors = execute_backend_->stream_executors(); for (int i = 0; i < execute_backend_->device_count(); ++i) { se::StreamExecutor* executor = stream_executors.at(i); @@ -351,11 +352,11 @@ StatusOr>> Service::BuildExecutables( VLOG(1) << StrFormat("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. 
- std::vector> hlo_snapshots; + std::vector> hlo_protos; for (int64 i = 0; i < module_protos.size(); ++i) { - auto hlo_snapshot = absl::make_unique(); - *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; - hlo_snapshots.push_back(std::move(hlo_snapshot)); + auto hlo_proto = absl::make_unique(); + *hlo_proto->mutable_hlo_module() = *module_protos[i]; + hlo_protos.push_back(std::move(hlo_proto)); } VLOG(1) << "Computations:"; @@ -383,7 +384,7 @@ StatusOr>> Service::BuildExecutables( const auto& debug_opts = module_configs[i]->debug_options(); if (DumpingEnabledForHloModule(module_protos[i]->name(), debug_opts) && debug_opts.xla_dump_hlo_snapshots()) { - executables[i]->set_hlo_snapshot(std::move(hlo_snapshots[i])); + executables[i]->set_hlo_proto(std::move(hlo_protos[i])); } } @@ -451,13 +452,19 @@ Service::ExecuteParallelAndRegisterResult( options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); + // Use run-time profile information from execution_profile on the 0th + // device. + if (i == 0) { + options.set_execution_profile(profile); + } ServiceExecutableRunOptions run_options(options, backend->StreamBorrower()); // Asynchronously launch the computation. TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, executables[i]->ExecuteAsyncOnStream( - &run_options, arguments[i][replica])); + &run_options, arguments[i][replica], + /*hlo_execution_profile=*/nullptr)); if (replica == 0 && profile != nullptr) { streams.back()->ThenStopTimer(timers.back().get()); @@ -490,10 +497,6 @@ Service::ExecuteParallelAndRegisterResult( uint64 nanoseconds = *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end()); - // Merge in run-time profile information from execution_profile on the - // zeroth device. - profile->MergeFrom(executables[0]->execution_profile()); - // Overall execution time (in nanoseconds) from the executor timer. 
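Condensed sketch of the revised profiling flow, assembled from the surrounding hunks rather than new API: the ExecutionProfile is attached to the run options before launch, so merging the executable's execution_profile() after the fact is no longer needed.

// `profile`, `backend`, `executable` and `arguments` stand for the local
// variables already in scope in this function.
ExecutableRunOptions options;
options.set_execution_profile(profile);  // populated during execution
ServiceExecutableRunOptions run_options(options, backend->StreamBorrower());
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                    executable->ExecuteAsyncOnStream(
                        &run_options, arguments,
                        /*hlo_execution_profile=*/nullptr));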
profile->set_compute_and_transfer_time_ns(nanoseconds); @@ -546,13 +549,13 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); + options.set_execution_profile(profile); run_options.emplace_back(options, backend->StreamBorrower()); } if (options_.number_of_replicas() == 1) { - TF_ASSIGN_OR_RETURN( - auto result, executable->ExecuteOnStreamWrapper(&run_options[0], - profile, arguments[0])); + TF_ASSIGN_OR_RETURN(auto result, executable->ExecuteOnStreamWrapper( + &run_options[0], arguments[0])); return allocation_tracker_.Register(std::move(result), result_tag); } @@ -692,14 +695,17 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, executable_ptrs.push_back(executable.get()); } + std::vector snapshots; + snapshots.resize(executable_ptrs.size()); for (int i = 0; i < executable_ptrs.size(); i++) { if (executable_ptrs[i]->dumping_snapshot()) { + *snapshots[i].mutable_hlo() = *executable_ptrs[i]->hlo_proto(); TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( all_executors[i][0]->device_ordinal())); TF_RETURN_IF_ERROR(RecordArguments(all_arguments[i].front(), stream.get(), execute_backend_->transfer_manager(), - executable_ptrs[i]->hlo_snapshot())); + &snapshots[i])); } } @@ -746,9 +752,8 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, execute_backend_->BorrowStream(all_executors[i][0])); TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), execute_backend_->transfer_manager(), - executable->hlo_snapshot())); - DumpHloSnapshotIfEnabled(executable->module(), - *executable->hlo_snapshot()); + &snapshots[i])); + DumpHloSnapshotIfEnabled(executable->module(), snapshots[i]); } } @@ -803,9 +808,9 @@ StatusOr> Service::BuildExecutable( const auto& debug_opts = module_config->debug_options(); if (DumpingEnabledForHloModule(module_proto.name(), debug_opts) && debug_opts.xla_dump_hlo_snapshots()) { - auto hlo_snapshot = absl::make_unique(); - *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; - executable->set_hlo_snapshot(std::move(hlo_snapshot)); + auto hlo_proto = absl::make_unique(); + *hlo_proto->mutable_hlo_module() = module_proto; + executable->set_hlo_proto(std::move(hlo_proto)); } return std::move(executable); @@ -891,12 +896,13 @@ Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream( execute_backend_->default_stream_executor())); + HloSnapshot snapshot; if (executable->dumping_snapshot()) { - executable->hlo_snapshot()->set_execution_platform( - execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments( - replicated_arguments.front(), stream.get(), - execute_backend_->transfer_manager(), executable->hlo_snapshot())); + *snapshot.mutable_hlo() = *executable->hlo_proto(); + snapshot.set_execution_platform(execute_backend_->platform()->Name()); + TF_RETURN_IF_ERROR( + RecordArguments(replicated_arguments.front(), stream.get(), + execute_backend_->transfer_manager(), &snapshot)); } TF_ASSIGN_OR_RETURN( @@ -913,8 +919,8 @@ Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { allocation_tracker_.ResolveForReplica(result->output(), 0)); TF_RETURN_IF_ERROR(RecordResult(*result_buffer, stream.get(), execute_backend_->transfer_manager(), - executable->hlo_snapshot())); - DumpHloSnapshotIfEnabled(executable->module(), *executable->hlo_snapshot()); + 
&snapshot)); + DumpHloSnapshotIfEnabled(executable->module(), snapshot); } VLOG(1) << "successfully completed 'execute' request"; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 3510e4913f4..30f6faada43 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1711,15 +1711,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 kernel_output_features = rhs.dimensions(dnums.kernel_output_feature_dimension()); - if (batch_group_count > 1 && input_batch % kernel_output_features != 0) { + if (batch_group_count > 1 && kernel_output_features != batch_group_count) { return InvalidArgument( - "Expected input batch (value %d) to be divisible by output feature " - "dimension size (value %d) for batch group count %d; " - "got (%s, %s)\n" + "Expected output feature dimension size (value %d) to be equal to " + "batch group count %d; got (%s, %s)\n" "Dimension numbers: {%s}.", - input_batch, kernel_output_features, batch_group_count, - ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), - dnums.DebugString()); + kernel_output_features, batch_group_count, ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs), dnums.DebugString()); } if (input_features % feature_group_count != 0 || @@ -2119,10 +2117,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferReduceWindowShape( const Shape& operand_shape, const Shape& init_value_shape, const Window& window, const ProgramShape& to_apply_shape) { - TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, {&init_value_shape}, {operand_shape.element_type()}, /*inputs=*/1)); + return InferReduceWindowShape(operand_shape, init_value_shape, window); +} + +/* static */ StatusOr ShapeInference::InferReduceWindowShape( + const Shape& operand_shape, const Shape& init_value_shape, + const Window& window) { + TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window")); return InferWindowOutputShape(operand_shape, window, init_value_shape.element_type(), /*allow_negative_padding=*/false); @@ -2207,6 +2211,60 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::MakeShape(U32, {}); } +/* static */ StatusOr ShapeInference::InferWindowFromDimensions( + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation) { + const auto verify_size = [&](const size_t x, const char* x_name) { + if (x == 0 || x == window_dimensions.size()) { + return Status::OK(); + } else { + return InvalidArgument( + "%s", absl::StrCat( + "Window has different number of window dimensions than of ", + x_name, + "\nNumber of window dimensions: ", window_dimensions.size(), + "\nNumber of ", x_name, ": ", x, "\n")); + } + }; + TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides")); + TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries")); + TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors")); + TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors")); + + Window window; + for (size_t i = 0; i < window_dimensions.size(); i++) { + auto dim = window.add_dimensions(); + dim->set_size(window_dimensions[i]); + if (!window_strides.empty()) { + dim->set_stride(window_strides[i]); + } 
else { + dim->set_stride(1); + } + if (!padding.empty()) { + dim->set_padding_low(padding[i].first); + dim->set_padding_high(padding[i].second); + } else { + dim->set_padding_low(0); + dim->set_padding_high(0); + } + if (!lhs_dilation.empty()) { + dim->set_base_dilation(lhs_dilation[i]); + } else { + dim->set_base_dilation(1); + } + if (!rhs_dilation.empty()) { + dim->set_window_dilation(rhs_dilation[i]); + } else { + dim->set_window_dilation(1); + } + dim->set_window_reversal(false); + } + return window; +} + /* static */ StatusOr ShapeInference::InferSliceShape( const Shape& arg, absl::Span starts, absl::Span limits, absl::Span strides) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 590a664224e..393b45e5ac3 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -159,6 +159,10 @@ class ShapeInference { const Shape& operand_shape, const Shape& init_value, const Window& window, const ProgramShape& to_apply_shape); + static StatusOr InferReduceWindowShape(const Shape& operand_shape, + const Shape& init_value, + const Window& window); + // Infers the shape produced by scattering the given source shape to the // selected indices of each window on the operand shape. static StatusOr InferSelectAndScatterShape( @@ -295,6 +299,15 @@ class ShapeInference { static StatusOr InferGetDimensionSizeShape(const Shape& shape, int64 dimension); + // Helper function for creating a Window proto from user-supplied data. + // Returns error if the user-supplied data was invalid. + static StatusOr InferWindowFromDimensions( + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. 
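A hedged usage sketch of the new InferWindowFromDimensions helper declared above: spans that are left empty fall back to per-dimension defaults (stride 1, zero padding, base and window dilation 1, no reversal), and mismatched span lengths are rejected by the verify_size check.

// Hypothetical call; the element types of the spans follow the declaration
// above (int64 dimensions and strides, pair<int64, int64> padding).
StatusOr<Window> window_or = ShapeInference::InferWindowFromDimensions(
    /*window_dimensions=*/{4, 4}, /*window_strides=*/{},
    /*padding=*/{}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{});
// On success this describes a 4x4 window in which every dimension has
// stride 1, padding 0/0, base_dilation 1, window_dilation 1 and no reversal.
// Passing, say, three strides for two window dimensions would instead return
// InvalidArgument.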
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 3bfa971f857..c241a4ac2ce 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -573,6 +573,43 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { HasSubstr("each dimension exactly once")); } +TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) { + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.add_kernel_spatial_dimensions(2); + dnums.add_kernel_spatial_dimensions(3); + dnums.set_output_batch_dimension(0); + dnums.set_output_feature_dimension(1); + dnums.add_output_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(3); + Shape lhs_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4}); + Window window; + auto dim0 = window.add_dimensions(); + auto dim1 = window.add_dimensions(); + dim0->set_size(4); + dim1->set_size(4); + dim0->set_padding_low(0); + dim0->set_padding_high(2); + dim1->set_padding_low(2); + dim1->set_padding_high(1); + dim0->set_stride(1); + dim1->set_stride(1); + dim0->set_window_dilation(3); + dim1->set_window_dilation(2); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6, + window, dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("to be equal to batch group count")); +} + namespace fft { static const char* unsupported_rank = "only supports ranks 1-3"; diff --git a/tensorflow/compiler/xla/service/slow_operation_alarm.cc b/tensorflow/compiler/xla/service/slow_operation_alarm.cc new file mode 100644 index 00000000000..3a0bd830d30 --- /dev/null +++ b/tensorflow/compiler/xla/service/slow_operation_alarm.cc @@ -0,0 +1,136 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/slow_operation_alarm.h" + +#include +#include // NOLINT (for std::call_once, not std::mutex) + +#include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" +#include "tensorflow/core/platform/env.h" + +namespace xla { +namespace { + +absl::Mutex mu(absl::kConstInit); +absl::CondVar* ready; +std::once_flag init_flag; +std::list* outstanding_alarms ABSL_PT_GUARDED_BY(mu) = + nullptr; + +void AlarmLoop() { + while (true) { + absl::MutexLock lock(&mu); + + // Fire any alarms which are ready. 
+ absl::Time now = absl::Now(); + for (auto it = outstanding_alarms->begin(); + it != outstanding_alarms->end();) { + auto next = std::next(it); + auto* alarm = *it; + // Fire the alarm if applicable. + if (alarm->deadline() <= now) { + outstanding_alarms->erase(it); + int64 count = + alarm->counter() == nullptr ? 0 : alarm->counter()->fetch_add(1); + // If the alarm has a counter, only fire if the count is a power of 2. + if (count == 0 || (count & (count - 1)) == 0) { + // We fire alarms with LOG(ERROR) because otherwise it might not show + // up without --logtostderr. + LOG(ERROR) << alarm->msg(); + } + } + it = next; + } + + if (outstanding_alarms->empty()) { + ready->Wait(&mu); + continue; + } + + SlowOperationAlarm* next_alarm = *absl::c_min_element( + *outstanding_alarms, + [](const SlowOperationAlarm* a, const SlowOperationAlarm* b) { + return a->deadline() < b->deadline(); + }); + ready->WaitWithDeadline(&mu, next_alarm->deadline()); + } +} + +void ScheduleAlarm(SlowOperationAlarm* alarm) { + std::call_once(init_flag, [] { + ready = new absl::CondVar(); + outstanding_alarms = new std::list(); + (void)tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "SlowOperationAlarm", [] { AlarmLoop(); }); + }); + + absl::MutexLock lock(&mu); + outstanding_alarms->push_back(alarm); + ready->Signal(); +} + +void UnscheduleAlarm(const SlowOperationAlarm* alarm) { + absl::MutexLock lock(&mu); + CHECK(outstanding_alarms != nullptr); + auto it = absl::c_find(*outstanding_alarms, alarm); + if (it != outstanding_alarms->end()) { + outstanding_alarms->erase(it); + } +} + +} // namespace + +SlowOperationAlarm::SlowOperationAlarm(absl::Duration timeout, string msg, + std::atomic* counter /*=nullptr*/) + : deadline_(absl::Now() + timeout), + msg_(std::move(msg)), + counter_(counter) { + ScheduleAlarm(this); +} + +SlowOperationAlarm::~SlowOperationAlarm() { UnscheduleAlarm(this); } + +std::unique_ptr SlowCompilationAlarm() { + // Pass a counter to these alarms so they only log once every power-of-two + // occurrences. + static auto* counter = new std::atomic(0); + + const char* separator = "\n********************************"; +#if NDEBUG + return absl::make_unique( + absl::Duration(absl::Minutes(2)), + absl::StrCat( + separator, + "\nVery slow compile? If you want to file a bug, run with envvar " + "XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.", + separator), + counter); +#else + return absl::make_unique( + absl::Duration(absl::Seconds(10)), + absl::StrCat( + separator, + "\nSlow compile? XLA was built without compiler optimizations, " + "which can be slow. Try rebuilding with -c opt.", + separator), + counter); +#endif +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/slow_operation_alarm.h b/tensorflow/compiler/xla/service/slow_operation_alarm.h new file mode 100644 index 00000000000..014fc7709f8 --- /dev/null +++ b/tensorflow/compiler/xla/service/slow_operation_alarm.h @@ -0,0 +1,70 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
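Aside (not part of the patch) on the throttling check used in AlarmLoop above: n & (n - 1) clears the lowest set bit, so the expression is zero exactly when the pre-increment counter is 0 or a power of two, which is what limits a counter-equipped alarm to logarithmically many log lines.

#include <cstdint>

// Mirrors the condition `count == 0 || (count & (count - 1)) == 0`.
constexpr bool FiresAt(int64_t count) {
  return count == 0 || (count & (count - 1)) == 0;
}
static_assert(FiresAt(0) && FiresAt(1) && FiresAt(2) && FiresAt(4), "");
static_assert(!FiresAt(3) && !FiresAt(6) && !FiresAt(12), "");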
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SLOW_OPERATION_ALARM_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SLOW_OPERATION_ALARM_H_ + +#include +#include +#include +#include + +#include "absl/time/time.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// This RAII object asynchronously prints a warning if it's alive for more than +// a certain amount of time. +class SlowOperationAlarm { + public: + // If `counter` is not null, this alarm will throttle itself to logging + // once-every-power-of-two occurrences. The counter must outlive this object. + SlowOperationAlarm(absl::Duration timeout, std::string msg, + std::atomic* counter = nullptr); + ~SlowOperationAlarm(); + + // Not copyable or movable, because the constructor stores a pointer to `this` + // into a global variable. + SlowOperationAlarm(const SlowOperationAlarm&) = delete; + SlowOperationAlarm(const SlowOperationAlarm&&) = delete; + SlowOperationAlarm& operator=(const SlowOperationAlarm&) = delete; + SlowOperationAlarm& operator=(const SlowOperationAlarm&&) = delete; + + absl::Time deadline() const { return deadline_; } + absl::string_view msg() const { return msg_; } + std::atomic* counter() { return counter_; } + + private: + absl::Time deadline_; + std::string msg_; + // counter_ may be null. If it's not, this alarm prints something only once + // every power of two occurrences. + std::atomic* counter_; +}; + +// Returns an object which prints a warning about slow compilation after a +// certain amount of time. +// +// In debug builds, recommends building with -c opt. +// +// In opt builds, recommends filing a bug. +// +// This is throttled to once-every-power-of-two occurrences, globally. +ABSL_MUST_USE_RESULT std::unique_ptr SlowCompilationAlarm(); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SLOW_OPERATION_ALARM_H_ diff --git a/tensorflow/compiler/xla/service/tree_reduction_rewriter.cc b/tensorflow/compiler/xla/service/tree_reduction_rewriter.cc new file mode 100644 index 00000000000..69af16ef428 --- /dev/null +++ b/tensorflow/compiler/xla/service/tree_reduction_rewriter.cc @@ -0,0 +1,110 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
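A hedged usage sketch of the RAII alarm declared above; CompileStep is a hypothetical placeholder for the guarded work, and the snippet assumes it lives inside namespace xla.

#include <memory>

void CompileStep();  // placeholder for the potentially slow work

void CompileWithAlarm() {
  // The warning is logged only if this scope outlives the alarm's timeout.
  std::unique_ptr<SlowOperationAlarm> alarm = SlowCompilationAlarm();
  CompileStep();
}  // alarm destroyed here; UnscheduleAlarm removes it if it has not fired yet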
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { + +class ReductionRewriterVisitor : public DfsHloRewriteVisitor { + public: + explicit ReductionRewriterVisitor(int64 reduce_window_size) + : reduce_window_size_(reduce_window_size) {} + + Status HandleReduce(HloInstruction *hlo) override { + HloInstruction *reduced_op = hlo->mutable_operand(0); + HloInstruction *initial_value = hlo->mutable_operand(1); + const Shape &input_shape = reduced_op->shape(); + const Shape &reduce_shape = hlo->shape(); + if (!reduce_shape.IsArray()) { + return Status::OK(); + } + auto reduced_dimensions = hlo->dimensions(); + std::vector window_dimensions; + std::vector window_strides; + for (int64 dim = 0; dim < input_shape.rank(); dim++) { + if (!absl::c_linear_search(hlo->dimensions(), dim)) { + window_dimensions.push_back(1); + window_strides.push_back(1); + continue; + } + // One of the reduced dimensions is smaller than the window size, + // do not perform the rewrite. 
+ if (input_shape.dimensions(dim) < reduce_window_size_) { + return Status::OK(); + } + + window_dimensions.push_back(reduce_window_size_); + window_strides.push_back(reduce_window_size_); + } + + std::vector> padding = + MakePadding(AsInt64Slice(input_shape.dimensions()), window_dimensions, + window_strides, Padding::kSame); + + TF_ASSIGN_OR_RETURN( + Window window, ShapeInference::InferWindowFromDimensions( + window_dimensions, window_strides, padding, {}, {})); + + TF_ASSIGN_OR_RETURN(Shape intermediate_shape, + ShapeInference::InferReduceWindowShape( + input_shape, initial_value->shape(), window)); + + HloInstruction *reduce_window = + hlo->parent()->AddInstruction(HloInstruction::CreateReduceWindow( + intermediate_shape, reduced_op, initial_value, window, + hlo->to_apply())); + + std::unique_ptr new_output = + HloInstruction::CreateReduce(reduce_shape, reduce_window, initial_value, + hlo->dimensions(), hlo->to_apply()); + + return ReplaceWithNewInstruction(hlo, std::move(new_output)); + } + + private: + int64 reduce_window_size_; +}; + +StatusOr TreeReductionRewriter::Run(HloModule *module) { + ReductionRewriterVisitor visitor(reduce_window_size_); + bool changed = false; + for (const auto &computation : module->MakeNonfusionComputations()) { + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + changed |= visitor.changed(); + } + + return changed; +} + +} // end namespace xla diff --git a/tensorflow/compiler/xla/service/tree_reduction_rewriter.h b/tensorflow/compiler/xla/service/tree_reduction_rewriter.h new file mode 100644 index 00000000000..a9852d88a6e --- /dev/null +++ b/tensorflow/compiler/xla/service/tree_reduction_rewriter.h @@ -0,0 +1,58 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TREE_REDUCTION_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TREE_REDUCTION_REWRITER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Increase precision for the reduction operation by applying the reduce-window +// first. +// +// E.g. suppose we want to reduce f32[1024] to a scalar. This pass first applies +// a reduce-window (with kSame padding) of size `reduce_window_size`, and then +// reduces the resulting array f32[32]. The rewrite is not applied if any of the +// reduced dimensions is smaller than the `reduce_window_size`. +// +// Applying this pass until a fixed point performs a variant of pairwise +// summation (https://en.wikipedia.org/wiki/Pairwise_summation), which is +// guaranteed to have an assymptotically smaller error bound provided that +// intermediate roundoff errors are random and have random sign. 
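A small, self-contained numerical illustration (not part of the patch) of the benefit described above: adding 2^25 ones one at a time in f32 stalls at 2^24, because 2^24 + 1 is not representable, while a blocked, two-level sum keeps every partial result exactly representable and recovers the true total.

#include <cstdio>

int main() {
  const long n = 1L << 25;  // sum 2^25 ones in single precision
  float sequential = 0.f;
  for (long i = 0; i < n; ++i) sequential += 1.f;  // stalls at 2^24

  // Two-level sum, analogous to a reduce-window of size 2^16 followed by a
  // second reduce over the partial results.
  float two_level = 0.f;
  const long block = 1L << 16;
  for (long b = 0; b < n; b += block) {
    float partial = 0.f;
    for (long i = b; i < b + block; ++i) partial += 1.f;  // each block is exact
    two_level += partial;  // multiples of 2^16 stay exactly representable
  }
  std::printf("sequential=%.1f two_level=%.1f exact=%ld\n",
              static_cast<double>(sequential), static_cast<double>(two_level),
              n);
  // Prints: sequential=16777216.0 two_level=33554432.0 exact=33554432
}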
+// +// If this pass lowers the performance too much, the window size can always be +// increased to a larger value. +class TreeReductionRewriter : public HloModulePass { + public: + explicit TreeReductionRewriter(int64 reduce_window_size = 32) + : reduce_window_size_(reduce_window_size) {} + ~TreeReductionRewriter() override = default; + absl::string_view name() const override { return "tree_reduction_rewriter"; } + + StatusOr Run(HloModule* module) override; + + private: + int64 reduce_window_size_; +}; + +} // end namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TREE_REDUCTION_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc index 57efee700be..0a8e2c3849f 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.cc +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc @@ -266,8 +266,12 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, int64 m_dim = (left_side) ? -1 : -2; int64 m = ShapeUtil::GetDimension(b_shape, m_dim); + std::vector update_ops; + int bdims = b_shape.rank(); + int64 block_dim = (left_side) ? bdims - 2 : bdims - 1; + // Initialize the solution - auto x = ZerosLike(b); + XlaOp x; // This loop is unrolled for performance reasons, but it could be expressed // rolled as well since the matrices are of the same size each iteration @@ -278,7 +282,8 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, // can be solved for X[i] as X[i] = inv(L[i, i]) @ B[i] - L[i, :i] @ X[:i] // Decide whether we go from first block to last or vice versa - auto j = (left_side ^ lower ^ transpose_a) ? num_blocks - 1 - i : i; + bool backward = left_side ^ lower ^ transpose_a; + auto j = backward ? num_blocks - 1 - i : i; // Get the size of the inverse blocks (the last one might be smaller) int64 block = (n % block_size != 0 && j + 1 == num_blocks) @@ -304,9 +309,17 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, if (i == 0) { remainder = b_row; } else { - // This matrix multiply involves a lot of multiplying with zero (namely, - // X[i * block_size:] = 0), but this is faster than slicing... 
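A brief worked illustration, with hypothetical sizes (block_size = 8, num_blocks = 4, n = 32), of the narrowed product introduced below: at forward iteration i = j = 2 only X[0:16] has been computed, so restricting the GEMM to those columns gives the same result as the old full-width product whose extra columns hit rows of X that are still zero.

//   old bounds: end = {k, n}                           -> {k, 32}
//   new bounds: end = {k, std::min(i * block_size, n)} -> {k, 16}
// The backward case is symmetric, clipping the start column instead of the end.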
- end = {k, n}; + // This matrix multiply get rid of a lot of multiplying with zero + // (namely, X[i * block_size:] = 0), L[i, :i] @ X[:i] + if (backward) { + start = {j * block_size, + std::max(0LL, (num_blocks - i) * block_size)}; + end = {k, n}; + } else { + start = {j * block_size, 0}; + end = {k, std::min(i * block_size, n)}; + } + if (!left_side) { std::swap(end[0], end[1]); } @@ -335,7 +348,16 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks, BatchDot(remainder, false, inv_block, transpose_a, precision); std::swap(update_starts[0], update_starts[1]); } - x = DynamicUpdateSliceInMinorDims(x, x_update, /*starts=*/update_starts); + + if (i == 0) { + x = x_update; + } else { + if (backward) { + x = ConcatInDim(builder, {x_update, x}, block_dim); + } else { + x = ConcatInDim(builder, {x, x_update}, block_dim); + } + } } return x; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index d0515fb5825..be7ad99aac4 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -564,8 +564,8 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) { HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto constant2 = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); - auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( - constant2->shape(), HloOpcode::kBitcast, constant2)); + auto bitcast = builder.AddInstruction( + HloInstruction::CreateBitcast(constant2->shape(), constant2)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant1, bitcast})); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index ebb56746518..e2d74627c60 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -43,6 +43,8 @@ limitations under the License. namespace xla { +class ShapeIndexView; + // An index for specifying a particular nested subshape within a shape. Used in // ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data // structures (trees) and ShapeIndex defines a path through the tree where each @@ -69,6 +71,8 @@ class ShapeIndex { template ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {} + explicit ShapeIndex(ShapeIndexView v); + bool empty() const { return indices_.empty(); } size_t size() const { return indices_.size(); } void push_back(int64 value) { indices_.push_back(value); } @@ -137,6 +141,10 @@ class ShapeIndexView { CHECK(!empty()); return indices_.front(); } + int64 back() const { + CHECK(!empty()); + return indices_.back(); + } ShapeIndexView ConsumeFront() const { ShapeIndexView result = *this; result.indices_.remove_prefix(1); @@ -161,6 +169,9 @@ class ShapeIndexView { absl::Span indices_; }; +inline ShapeIndex::ShapeIndex(ShapeIndexView v) + : ShapeIndex(v.begin(), v.end()) {} + std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); diff --git a/tensorflow/compiler/xla/test.h b/tensorflow/compiler/xla/test.h index a657554dc2f..c20c1341541 100644 --- a/tensorflow/compiler/xla/test.h +++ b/tensorflow/compiler/xla/test.h @@ -41,6 +41,7 @@ limitations under the License. 
#else #include #include +#include #endif #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index f67050863d3..ae0d70610be 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -3,7 +3,7 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") load( - "//tensorflow/core:platform/default/build_config_root.bzl", + "//tensorflow/core/platform:default/build_config_root.bzl", "tf_cuda_tests_tags", ) load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") @@ -296,9 +296,12 @@ xla_test( xla_test( name = "conv_depthwise_test", timeout = "long", - srcs = ["conv_depthwise_test.cc"], + srcs = [ + "conv_depthwise_test.cc", + ], shard_count = 50, deps = [ + ":conv_depthwise_common", ":test_macros_header", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", @@ -709,9 +712,151 @@ cc_library( ], ) +cc_library( + name = "conv_depthwise_common", + testonly = True, + srcs = ["conv_depthwise_common.cc"], + hdrs = ["conv_depthwise_common.h"], + deps = [ + ":test_macros_header", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:bfloat16_normalization", + "//tensorflow/compiler/xla/service:despecializer", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:optional", + ], +) + xla_test( - name = "exhaustive_unary_test", + name = "exhaustive_unary_test_f32_or_smaller", srcs = ["exhaustive_unary_test.cc"], + copts = ["-DUNARY_TEST_TARGET_F32_OR_SMALLER"], + real_hardware_only = True, # Very slow on the interpreter. + shard_count = 48, + tags = [ + "optonly", + # This is a big test that we skip for capacity reasons in OSS testing. + "no_oss", + ], + deps = [ + ":exhaustive_op_test_utils", + ], +) + +xla_test( + name = "exhaustive_unary_test_f64", + srcs = ["exhaustive_unary_test.cc"], + backends = [ + "gpu", + "cpu", + ], + copts = ["-DUNARY_TEST_TARGET_F64"], + real_hardware_only = True, # Very slow on the interpreter. + shard_count = 48, + tags = [ + "optonly", + # This is a big test that we skip for capacity reasons in OSS testing. + "no_oss", + ], + deps = [ + ":exhaustive_op_test_utils", + ], +) + +xla_test( + name = "exhaustive_unary_test_complex", + srcs = ["exhaustive_unary_test.cc"], + backends = [ + "gpu", + "cpu", + ], + copts = ["-DUNARY_TEST_TARGET_COMPLEX"], + real_hardware_only = True, # Very slow on the interpreter. + shard_count = 48, + tags = [ + "optonly", + # This is a big test that we skip for capacity reasons in OSS testing. + "no_oss", + ], + deps = [ + ":exhaustive_op_test_utils", + ], +) + +xla_test( + name = "exhaustive_binary_test_f16", + srcs = ["exhaustive_binary_test.cc"], + backends = [ + "gpu", + "cpu", + ], + copts = ["-DBINARY_TEST_TARGET_F16"], + real_hardware_only = True, # Very slow on the interpreter. + shard_count = 48, + tags = [ + "optonly", + # This is a big test that we skip for capacity reasons in OSS testing. 
+ "no_oss", + ], + deps = [ + ":exhaustive_op_test_utils", + ], +) + +xla_test( + name = "exhaustive_binary_test_bf16", + srcs = ["exhaustive_binary_test.cc"], + backends = [ + "gpu", + "cpu", + ], + copts = ["-DBINARY_TEST_TARGET_BF16"], + real_hardware_only = True, # Very slow on the interpreter. + shard_count = 48, + tags = [ + "optonly", + # This is a big test that we skip for capacity reasons in OSS testing. + "no_oss", + ], + deps = [ + ":exhaustive_op_test_utils", + ], +) + +xla_test( + name = "exhaustive_binary_test_f32", + srcs = ["exhaustive_binary_test.cc"], + backends = [ + "gpu", + "cpu", + ], + copts = ["-DBINARY_TEST_TARGET_F32"], + real_hardware_only = True, # Very slow on the interpreter. + shard_count = 48, + tags = [ + "optonly", + # This is a big test that we skip for capacity reasons in OSS testing. + "no_oss", + ], + deps = [ + ":exhaustive_op_test_utils", + ], +) + +xla_test( + name = "exhaustive_binary_test_f64", + srcs = ["exhaustive_binary_test.cc"], + backends = [ + "gpu", + "cpu", + ], + copts = ["-DBINARY_TEST_TARGET_F64"], real_hardware_only = True, # Very slow on the interpreter. shard_count = 48, tags = [ @@ -1505,6 +1650,7 @@ xla_test( name = "fmax_fmin_test", srcs = ["fmax_fmin_test.cc"], deps = [ + ":test_macros_header", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", @@ -1744,6 +1890,7 @@ xla_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service/gpu:nccl_all_reduce_thunk", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1954,7 +2101,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", - "//tensorflow/compiler/xla/service/gpu:nvptx_compiler", "//tensorflow/compiler/xla/service/gpu:nvptx_compiler_impl", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -1987,8 +2133,8 @@ xla_test( ) xla_test( - name = "fusion_test", - srcs = ["fusion_test.cc"], + name = "cpu_gpu_fusion_test", + srcs = ["cpu_gpu_fusion_test.cc"], deps = [ ":test_macros_header", "//tensorflow/compiler/xla:array2d", diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 48719c6c47c..7153ace8789 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -4,7 +4,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load( - "//tensorflow/core:platform/default/build_config_root.bzl", + "//tensorflow/core/platform:default/build_config_root.bzl", "tf_cuda_tests_tags", ) diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc index b8439ee0fdd..efa7448f191 100644 --- a/tensorflow/compiler/xla/tests/collective_ops_test.cc +++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "absl/strings/str_replace.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h" diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_common.cc b/tensorflow/compiler/xla/tests/conv_depthwise_common.cc new file mode 100644 index 00000000000..e11ec33e730 --- /dev/null +++ b/tensorflow/compiler/xla/tests/conv_depthwise_common.cc @@ -0,0 +1,135 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/conv_depthwise_common.h" + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +string GetFloatDataType(bool use_bfloat16) { + return use_bfloat16 ? "bf16" : "f32"; +} + +string DepthwiseConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& spec = ::testing::get<0>(data.param); + const string data_type = GetFloatDataType(::testing::get<1>(data.param)); + string str = absl::StrCat( + "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), + "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), + "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", + absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", + absl::StrJoin(spec.output_dims, "x"), "_output_layout_", + absl::StrJoin(spec.output_layout, "_"), data_type); + // -1 indicates non-existence. + if (spec.stride != -1) { + absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1"); + } + + // Test names are not allowed to contain the '-' character. + absl::c_replace(str, '-', 'n'); + return str; +} + +string BuildHloTextDepthwiseConvolution2D( + const DepthwiseConvolution2DSpec& spec, bool use_bfloat16, + bool is_scheduled) { + const string data_type = GetFloatDataType(use_bfloat16); + const string sched_tag = is_scheduled ? 
", is_scheduled=true " : ""; + if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv %s + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + sched_tag, data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.window, spec.window, spec.window, spec.output_feature); + + } else if (spec.stride == -1) { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv %s + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d}, dim_labels=b01f_01io->b01f, + feature_group_count=%d + } + )", + sched_tag, data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.output_feature); + } else { + return absl::StrFormat( + R"( + HloModule TensorFlowDepthwiseConv %s + + ENTRY main { + activation = %s[%s]{%s} parameter(0) + kernel = %s[%s]{%s} parameter(1) + ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), + window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, + dim_labels=b01f_01io->b01f, feature_group_count=%d + } + )", + sched_tag, data_type, absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), data_type, + absl::StrJoin(spec.output_dims, ","), + absl::StrJoin(spec.output_layout, ","), data_type, + absl::StrJoin(spec.activation_dims, ","), + absl::StrJoin(spec.activation_layout, ","), data_type, + absl::StrJoin(spec.kernel_dims, ","), + absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, + spec.stride, 0, 0, spec.lhs_dilate, spec.output_feature); + } +} +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_common.h b/tensorflow/compiler/xla/tests/conv_depthwise_common.h new file mode 100644 index 00000000000..0c00f8d0abe --- /dev/null +++ b/tensorflow/compiler/xla/tests/conv_depthwise_common.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_CONV_DEPTHWISE_COMMON_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_CONV_DEPTHWISE_COMMON_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/despecializer.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +string GetFloatDataType(bool use_bfloat16); + +struct DepthwiseConvolution2DSpec { + int64 output_feature, window, stride, pad, lhs_dilate; + std::vector activation_dims; + std::vector activation_layout; + std::vector kernel_dims; + std::vector kernel_layout; + std::vector output_dims; + std::vector output_layout; +}; + +string DepthwiseConvolution2DTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data); + +string BuildHloTextDepthwiseConvolution2D( + const DepthwiseConvolution2DSpec& spec, bool use_bfloat16, + bool is_scheduled = false); + +} // namespace xla +#endif // TENSORFLOW_COMPILER_XLA_TESTS_CONV_DEPTHWISE_COMMON_H_ diff --git a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc index fe958242329..98f6b5bc6d7 100644 --- a/tensorflow/compiler/xla/tests/conv_depthwise_test.cc +++ b/tensorflow/compiler/xla/tests/conv_depthwise_test.cc @@ -22,26 +22,13 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/conv_depthwise_common.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" namespace xla { namespace { -string GetFloatDataType(bool use_bfloat16) { - return use_bfloat16 ? "bf16" : "f32"; -} - -struct DepthwiseConvolution2DSpec { - int64 output_feature, window, stride, pad, lhs_dilate; - std::vector activation_dims; - std::vector activation_layout; - std::vector kernel_dims; - std::vector kernel_layout; - std::vector output_dims; - std::vector output_layout; -}; - class DepthwiseConvolution2DTest : public HloTestBase, public ::testing::WithParamInterface< @@ -70,6 +57,7 @@ static std::vector GetConv2DTestCases() { config.kernel_dims = {kernel_size, kernel_size, 1, feature}; config.kernel_layout = {3, 2, 1, 0}; + config.output_layout = {3, 0, 2, 1}; if (activation_size == 1 && kernel_size == 2) { // Test for outer dim. @@ -87,127 +75,12 @@ static std::vector GetConv2DTestCases() { config.output_dims = {batch, activation_size - kernel_size + 1, activation_size - kernel_size + 1, feature}; } - - // Try this layout for all kernel shapes. 
- config.output_layout = {3, 0, 2, 1}; config_set.push_back(config); - - // Try other layouts only for certain kernel shapes. - if (kernel_size % 2 == 0) { - config.activation_layout = {0, 3, 2, 1}; - config_set.push_back(config); - - config.output_layout = {0, 3, 2, 1}; - config_set.push_back(config); - - config.activation_layout = {3, 0, 2, 1}; - config_set.push_back(config); - } } return config_set; } -string DepthwiseConvolution2DTestDataToString( - const ::testing::TestParamInfo< - ::testing::tuple>& data) { - const auto& spec = ::testing::get<0>(data.param); - const string data_type = GetFloatDataType(::testing::get<1>(data.param)); - string str = absl::StrCat( - "activation_dims_", absl::StrJoin(spec.activation_dims, "x"), - "_activation_layout_", absl::StrJoin(spec.activation_layout, "_"), - "_kernel_dims_", absl::StrJoin(spec.kernel_dims, "x"), "_kernel_layout_", - absl::StrJoin(spec.kernel_layout, "_"), "_output_dims_", - absl::StrJoin(spec.output_dims, "x"), "_output_layout_", - absl::StrJoin(spec.output_layout, "_"), data_type); - // -1 indicates non-existence. - if (spec.stride != -1) { - absl::StrAppend(&str, "_lhs_dilation_", spec.lhs_dilate, "x1"); - } - - // Test names are not allowed to contain the '-' character. - absl::c_replace(str, '-', 'n'); - return str; -} - -string BuildHloTextDepthwiseConvolution2D( - const DepthwiseConvolution2DSpec& spec, bool use_bfloat16) { - const string data_type = GetFloatDataType(use_bfloat16); - if (spec.activation_dims[1] == 1 && spec.kernel_dims[1] == 2) { - return absl::StrFormat( - R"( - HloModule TensorFlowDepthwiseConv - - ENTRY main { - activation = %s[%s]{%s} parameter(0) - kernel = %s[%s]{%s} parameter(1) - ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), - window={size=%dx%d pad=1_1x%d_%d rhs_dilate=1x%d}, dim_labels=b01f_01io->b01f, - feature_group_count=%d - } - )", - data_type, absl::StrJoin(spec.activation_dims, ","), - absl::StrJoin(spec.activation_layout, ","), data_type, - absl::StrJoin(spec.kernel_dims, ","), - absl::StrJoin(spec.kernel_layout, ","), data_type, - absl::StrJoin(spec.output_dims, ","), - absl::StrJoin(spec.output_layout, ","), data_type, - absl::StrJoin(spec.activation_dims, ","), - absl::StrJoin(spec.activation_layout, ","), data_type, - absl::StrJoin(spec.kernel_dims, ","), - absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, - spec.window, spec.window, spec.window, spec.output_feature); - - } else if (spec.stride == -1) { - return absl::StrFormat( - R"( - HloModule TensorFlowDepthwiseConv - - ENTRY main { - activation = %s[%s]{%s} parameter(0) - kernel = %s[%s]{%s} parameter(1) - ROOT conv = %s[%s]{%s} convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), - window={size=%dx%d}, dim_labels=b01f_01io->b01f, - feature_group_count=%d - } - )", - data_type, absl::StrJoin(spec.activation_dims, ","), - absl::StrJoin(spec.activation_layout, ","), data_type, - absl::StrJoin(spec.kernel_dims, ","), - absl::StrJoin(spec.kernel_layout, ","), data_type, - absl::StrJoin(spec.output_dims, ","), - absl::StrJoin(spec.output_layout, ","), data_type, - absl::StrJoin(spec.activation_dims, ","), - absl::StrJoin(spec.activation_layout, ","), data_type, - absl::StrJoin(spec.kernel_dims, ","), - absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, - spec.output_feature); - } else { - return absl::StrFormat( - R"( - HloModule TensorFlowDepthwiseConv - - ENTRY main { - activation = %s[%s]{%s} parameter(0) - kernel = %s[%s]{%s} parameter(1) - ROOT conv = %s[%s]{%s} 
convolution(%s[%s]{%s} activation, %s[%s]{%s} kernel), - window={size=%dx%d stride=%dx1 pad=%d_%dx0_0 lhs_dilate=%dx1}, - dim_labels=b01f_01io->b01f, feature_group_count=%d - } - )", - data_type, absl::StrJoin(spec.activation_dims, ","), - absl::StrJoin(spec.activation_layout, ","), data_type, - absl::StrJoin(spec.kernel_dims, ","), - absl::StrJoin(spec.kernel_layout, ","), data_type, - absl::StrJoin(spec.output_dims, ","), - absl::StrJoin(spec.output_layout, ","), data_type, - absl::StrJoin(spec.activation_dims, ","), - absl::StrJoin(spec.activation_layout, ","), data_type, - absl::StrJoin(spec.kernel_dims, ","), - absl::StrJoin(spec.kernel_layout, ","), spec.window, spec.window, - spec.stride, 0, 0, spec.lhs_dilate, spec.output_feature); - } -} XLA_TEST_P(DepthwiseConvolution2DTest, DoIt) { const DepthwiseConvolution2DSpec& spec = ::testing::get<0>(GetParam()); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 0ab765aefa0..e656951a968 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -1842,15 +1842,11 @@ INSTANTIATE_TEST_CASE_P( Convolve1DTestParam{130, 1, 1, 1, 3}, Convolve1DTestParam{64, 1, 1, 1, 1}, Convolve1DTestParam{128, 1, 1, 1, 1}, -// TODO(b/72566306): The following five tests failed on CPU with unreasonable -// relative errors. Last ran on 2018-02-22. -#if XLA_TEST_BACKEND_GPU Convolve1DTestParam{139, 1, 1, 128, 1}, Convolve1DTestParam{640, 3, 3, 128, 1}, Convolve1DTestParam{900, 1, 1, 10, 1}, Convolve1DTestParam{1, 10, 10, 1, 10}, Convolve1DTestParam{1, 10, 130, 1, 1}, -#endif Convolve1DTestParam{1, 10, 130, 1, 2}, Convolve1DTestParam{1, 64, 64, 1, 10}, Convolve1DTestParam{1, 65, 65, 1, 1}, @@ -1946,7 +1942,8 @@ XLA_TEST_F(ConvolutionTest, ConvolveF32BackwardInputGroupedConvolution) { class ConvolutionHloTest : public HloTestBase {}; -XLA_TEST_F(ConvolutionHloTest, ConvolveF64Forward) { +// double datatype is not yet supported in ROCm +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU_ROCM(ConvolveF64Forward)) { constexpr char kHlo[] = R"( HloModule TestModule @@ -1970,7 +1967,9 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } -XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardFilter) { +// double datatype is not yet supported in ROCm +XLA_TEST_F(ConvolutionHloTest, + DISABLED_ON_GPU_ROCM(ConvolveF64BackwardFilter)) { constexpr char kHlo[] = R"( HloModule TestModule @@ -1982,7 +1981,8 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } -XLA_TEST_F(ConvolutionHloTest, ConvolveF64BackwardInput) { +// double datatype is not yet supported in ROCm +XLA_TEST_F(ConvolutionHloTest, DISABLED_ON_GPU_ROCM(ConvolveF64BackwardInput)) { constexpr char kHlo[] = R"( HloModule TestModule @@ -1995,5 +1995,18 @@ ENTRY Test { EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.001})); } +XLA_TEST_F(ConvolutionHloTest, ConvolveBackwardInput) { + constexpr char kHlo[] = R"( +HloModule TestModule + +ENTRY Test { + %output = f32[3,3,64,64] parameter(0) + %kernel = f32[672,7,7,64] parameter(1) + %reverse = f32[672,7,7,64]{3,2,1,0} reverse(f32[672,7,7,64]{3,2,1,0} %kernel), dimensions={1,2} + ROOT %convolution = f32[672,9,9,64]{3,2,1,0} convolution(f32[3,3,64,64]{3,2,1,0} %output, f32[672,7,7,64]{3,2,1,0} %reverse), window={size=7x7 pad=6_6x6_6}, dim_labels=01bf_o01i->f01b +})"; + EXPECT_TRUE(RunAndCompare(kHlo, ErrorSpec{0.01, 0.01})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc 
b/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc similarity index 94% rename from tensorflow/compiler/xla/tests/fusion_test.cc rename to tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc index 2d0805cdb0e..7719e89f9e8 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/cpu_gpu_fusion_test.cc @@ -60,7 +60,7 @@ const float test_float_vals[3][test_width][test_height] = { // Test whether fusion operations are emitted with no errors and compute // accurate outputs. -class FusionTest : public HloTestBase { +class CpuGpuFusionTest : public HloTestBase { protected: template void TestElementwise2D( @@ -148,8 +148,8 @@ class FusionTest : public HloTestBase { } }; -float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode, - absl::Span xs) { +float CpuGpuFusionTest::ComputeElementwiseAnswerFloat( + HloOpcode opcode, absl::Span xs) { switch (opcode) { case HloOpcode::kAdd: return xs[0] + xs[1]; @@ -172,8 +172,8 @@ float FusionTest::ComputeElementwiseAnswerFloat(HloOpcode opcode, } } -bool FusionTest::ComputeElementwiseAnswerCompare(ComparisonDirection direction, - absl::Span xs) { +bool CpuGpuFusionTest::ComputeElementwiseAnswerCompare( + ComparisonDirection direction, absl::Span xs) { switch (direction) { case ComparisonDirection::kEq: return xs[0] == xs[1]; @@ -190,7 +190,7 @@ bool FusionTest::ComputeElementwiseAnswerCompare(ComparisonDirection direction, } } -XLA_TEST_F(FusionTest, Test) { +XLA_TEST_F(CpuGpuFusionTest, Test) { // test expression: // slice(select({{T, F, T}, {F, T, F}}, // concat(transpose({{1.0}, {2.0}, {3.0}} + @@ -243,7 +243,7 @@ XLA_TEST_F(FusionTest, Test) { } // Test whether we emit appropriate code for parameters of fusion instructions. -XLA_TEST_F(FusionTest, Parameter) { +XLA_TEST_F(CpuGpuFusionTest, Parameter) { // Build a computation and fuse part of it so the fusion instruction has an // operand parameter. auto builder = HloComputation::Builder(TestName()); @@ -268,7 +268,7 @@ XLA_TEST_F(FusionTest, Parameter) { ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } -XLA_TEST_F(FusionTest, RandomizedParallelPartition) { +XLA_TEST_F(CpuGpuFusionTest, RandomizedParallelPartition) { // Tests parallel partitioning of a fusion instruction. // Create shape with random outer dimension size to generate random parallel // partition counts for each test run. 
@@ -304,7 +304,7 @@ XLA_TEST_F(FusionTest, RandomizedParallelPartition) { } } -XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { +XLA_TEST_F(CpuGpuFusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( @@ -328,7 +328,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } -XLA_TEST_F(FusionTest, ReshapeToScalar) { +XLA_TEST_F(CpuGpuFusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto single_element_array = builder.AddInstruction( @@ -343,7 +343,7 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { +XLA_TEST_F(CpuGpuFusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -358,7 +358,7 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { +XLA_TEST_F(CpuGpuFusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -373,7 +373,7 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, Reshape_1by1by1_) { +XLA_TEST_F(CpuGpuFusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( @@ -388,7 +388,7 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, Reshape__1by1by1) { +XLA_TEST_F(CpuGpuFusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( @@ -403,7 +403,7 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, Reshape__) { +XLA_TEST_F(CpuGpuFusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( @@ -418,7 +418,7 @@ XLA_TEST_F(FusionTest, Reshape__) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { +XLA_TEST_F(CpuGpuFusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -433,7 +433,7 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, Transpose_2by3) { +XLA_TEST_F(CpuGpuFusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -448,7 +448,7 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, Transpose_3by3) { +XLA_TEST_F(CpuGpuFusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = 
CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -463,7 +463,7 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, Reverse) { +XLA_TEST_F(CpuGpuFusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( @@ -479,7 +479,7 @@ XLA_TEST_F(FusionTest, Reverse) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, ReverseNegate) { +XLA_TEST_F(CpuGpuFusionTest, ReverseNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( @@ -497,7 +497,7 @@ XLA_TEST_F(FusionTest, ReverseNegate) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, BroadcastNegate) { +XLA_TEST_F(CpuGpuFusionTest, BroadcastNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction( @@ -515,7 +515,7 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, SliceNegate) { +XLA_TEST_F(CpuGpuFusionTest, SliceNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -533,7 +533,7 @@ XLA_TEST_F(FusionTest, SliceNegate) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, DynamicSliceNegate) { +XLA_TEST_F(CpuGpuFusionTest, DynamicSliceNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -555,7 +555,7 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, ReshapeNegate) { +XLA_TEST_F(CpuGpuFusionTest, ReshapeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -573,7 +573,7 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, TransposeNegate) { +XLA_TEST_F(CpuGpuFusionTest, TransposeNegate) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -602,7 +602,7 @@ std::unique_ptr MakeReduceTestComputation() { return builder.Build(); } -XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { +XLA_TEST_F(CpuGpuFusionTest, DISABLED_ON_CPU(Reduce)) { auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( @@ -621,7 +621,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, ReduceImplicitBroadcast) { +XLA_TEST_F(CpuGpuFusionTest, ReduceImplicitBroadcast) { auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -643,7 +643,7 @@ XLA_TEST_F(FusionTest, ReduceImplicitBroadcast) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { +XLA_TEST_F(CpuGpuFusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = 
CreateNewVerifiedModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( @@ -696,7 +696,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { // When a constant (or other op) which has multiple users is imported // into a fusion, it should remain shared, rather than being duplicated // within the fusion. -XLA_TEST_F(FusionTest, SharedConstant) { +XLA_TEST_F(CpuGpuFusionTest, SharedConstant) { auto hlo_module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); @@ -729,57 +729,59 @@ XLA_TEST_F(FusionTest, SharedConstant) { ExecuteAndTransfer(std::move(hlo_module), {}))); } -XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } +XLA_TEST_F(CpuGpuFusionTest, Add2D) { + TestElementwise2D(HloOpcode::kAdd); +} -XLA_TEST_F(FusionTest, Subtract2D) { +XLA_TEST_F(CpuGpuFusionTest, Subtract2D) { TestElementwise2D(HloOpcode::kSubtract); } -XLA_TEST_F(FusionTest, Multiply2D) { +XLA_TEST_F(CpuGpuFusionTest, Multiply2D) { TestElementwise2D(HloOpcode::kMultiply); } -XLA_TEST_F(FusionTest, Divide2D) { +XLA_TEST_F(CpuGpuFusionTest, Divide2D) { TestElementwise2D(HloOpcode::kDivide); } -XLA_TEST_F(FusionTest, Power2D) { +XLA_TEST_F(CpuGpuFusionTest, Power2D) { TestElementwise2D(HloOpcode::kPower); } -XLA_TEST_F(FusionTest, Minimum2D) { +XLA_TEST_F(CpuGpuFusionTest, Minimum2D) { TestElementwise2D(HloOpcode::kMinimum); } -XLA_TEST_F(FusionTest, Maximum2D) { +XLA_TEST_F(CpuGpuFusionTest, Maximum2D) { TestElementwise2D(HloOpcode::kMaximum); } -XLA_TEST_F(FusionTest, Equal2D) { +XLA_TEST_F(CpuGpuFusionTest, Equal2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kEq); } -XLA_TEST_F(FusionTest, Inequal2D) { +XLA_TEST_F(CpuGpuFusionTest, Inequal2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kNe); } -XLA_TEST_F(FusionTest, Greater2D) { +XLA_TEST_F(CpuGpuFusionTest, Greater2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kGt); } -XLA_TEST_F(FusionTest, Lesser2D) { +XLA_TEST_F(CpuGpuFusionTest, Lesser2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kLt); } -XLA_TEST_F(FusionTest, GreaterOrEqual2D) { +XLA_TEST_F(CpuGpuFusionTest, GreaterOrEqual2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kGe); } -XLA_TEST_F(FusionTest, LesserOrEqual2D) { +XLA_TEST_F(CpuGpuFusionTest, LesserOrEqual2D) { TestElementwise2D(HloOpcode::kCompare, ComparisonDirection::kLe); } -XLA_TEST_F(FusionTest, Clamp2D) { +XLA_TEST_F(CpuGpuFusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 25e82842b05..ff2fd7e2297 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -1409,6 +1409,54 @@ ENTRY MatrixVectorComplex { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); } +// Regression test for b/138155357, where we were incorrectly creating a dot-add +// fusion where the dot had a batch dimension. This isn't supported on the CPU +// backend. 
+XLA_TEST_F(DotOperationTextTest, FusedBatchDotRegressionTest) { + absl::string_view module_string = R"( +HloModule jaxpr_computation__5.33 + +jaxpr_computation__6.8 { + tuple.9 = () tuple() + parameter.14 = () parameter(4) + parameter.13 = (f32[2]{0}) parameter(3) + get-tuple-element.15 = f32[2]{0} get-tuple-element(parameter.13), index=0 + reshape.16 = f32[1,2]{1,0} reshape(get-tuple-element.15) + parameter.10 = f32[2,2]{1,0} parameter(0) + reshape.17 = f32[2,1]{1,0} reshape(get-tuple-element.15) + dot.18 = f32[2,1]{1,0} dot(parameter.10, reshape.17), lhs_contracting_dims={1}, rhs_contracting_dims={0} + reshape.19 = f32[2]{0} reshape(dot.18) + reshape.20 = f32[2,1]{1,0} reshape(reshape.19) + dot.21 = f32[1,1]{1,0} dot(reshape.16, reshape.20), lhs_contracting_dims={1}, rhs_contracting_dims={0} + reshape.22 = f32[] reshape(dot.21) + parameter.11 = f32[2,1,2]{2,1,0} parameter(1) + broadcast.23 = f32[2,2,1]{2,1,0} broadcast(reshape.20), dimensions={1,2} + dot.24 = f32[2,1,1]{2,1,0} dot(parameter.11, broadcast.23), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} + broadcast.25 = f32[2,1,2]{2,1,0} broadcast(reshape.16), dimensions={1,2} + parameter.12 = f32[2,2,1]{2,1,0} parameter(2) + dot.26 = f32[2,1,1]{2,1,0} dot(broadcast.25, parameter.12), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} + add.27 = f32[2,1,1]{2,1,0} add(dot.24, dot.26) + reshape.28 = f32[2]{0} reshape(add.27) + ROOT tuple.29 = (f32[], f32[2]{0}) tuple(reshape.22, reshape.28) +} + +ENTRY jaxpr_computation__5.33 { + constant.2 = f32[] constant(1) + broadcast.3 = f32[2,2]{1,0} broadcast(constant.2), dimensions={} + constant.5 = f32[2,1,2]{2,1,0} constant({ { { 1, 0 } }, { { 0, 1 } } }) + constant.4 = f32[2,2,1]{2,1,0} constant({ { {1}, {1} }, { {1}, {1} } }) + parameter.6 = f32[2]{0} parameter(0) + tuple.7 = (f32[2]{0}) tuple(parameter.6) + tuple.1 = () tuple() + call.30 = (f32[], f32[2]{0}) call(broadcast.3, constant.5, constant.4, tuple.7, tuple.1), to_apply=jaxpr_computation__6.8 + get-tuple-element.31 = f32[] get-tuple-element(call.30), index=0 + ROOT get-tuple-element.32 = f32[2]{0} get-tuple-element(call.30), index=1 +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + EXPECT_TRUE(RunAndCompare(std::move(module), /*error=*/absl::nullopt)); +} + XLA_TEST_F(DotOperationTest, ReorderContractingDimsConstLHS_RL) { Array3D input_arr(2, 3, 2); Array2D const_arr(2, 6); diff --git a/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc new file mode 100644 index 00000000000..c0f8a0dc626 --- /dev/null +++ b/tensorflow/compiler/xla/tests/exhaustive_binary_test.cc @@ -0,0 +1,392 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h" + +#ifdef __FAST_MATH__ +#error("Can't be compiled with fast math on"); +#endif + +namespace xla { +namespace { + +template +using ExhaustiveBinaryTest = ExhaustiveOpTestBase; + +// Exhaustive test for binary operations for 16 bit floating point types, +// including float16 and bfloat. +// +// Test parameter is a pair of (begin, end) for range under test. +template < + PrimitiveType T, + typename std::enable_if< + std::is_same::type, + half>::value || + std::is_same::type, + bfloat16>::value>::type* = nullptr> +class Exhaustive16BitBinaryTest + : public ExhaustiveBinaryTest, + public ::testing::WithParamInterface> { + public: + int64 GetInputSize() override { + int64 begin, end; + std::tie(begin, end) = GetParam(); + return end - begin; + } + + // Given a range of uint64 representation, uses bits 0..15 and bits 16..31 for + // the values of src0 and src1 for a 16 bit binary operation being tested, + // and generates the cartesian product of the two sets as the two inputs for + // the test. + void FillInput(std::array* input_literals) override { + int64 input_size = GetInputSize(); + CHECK_EQ(input_size, (*input_literals)[0].element_count()); + CHECK_EQ(input_size, (*input_literals)[1].element_count()); + + int64 begin, end; + std::tie(begin, end) = GetParam(); + VLOG(2) << "Checking range [" << begin << ", " << end << "]"; + + absl::Span input_arr_0 = (*input_literals)[0].data(); + absl::Span input_arr_1 = (*input_literals)[1].data(); + for (int64 i = 0; i < input_size; i++) { + uint32 input_val = i + begin; + // Convert the lower 16 bits to the NativeT and replaced known incorrect + // input values with 0. + input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); + input_arr_1[i] = + ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0); + } + } + + protected: + using typename ExhaustiveBinaryTest::NativeT; + using ExhaustiveBinaryTest::ConvertAndReplaceKnownIncorrectValueWith; +}; + +using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest; +using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest; + +// Returns a wrapper of the given build method, which build an HLO operation +// with an empty broadcast dimension. +inline std::function AddEmptyBroadcastDimension( + std::function)> build_method) { + return [&](XlaOp src0, XlaOp src1) -> XlaOp { + return build_method(src0, src1, {}); + }; +} + +#define XLA_TEST_16BIT(test_name, ...) \ + XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \ + __VA_ARGS__ \ + XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \ + __VA_ARGS__ + +XLA_TEST_16BIT(Add, { + auto host_add = [](float x, float y) { return x + y; }; + Run(AddEmptyBroadcastDimension(Add), host_add); +}) + +XLA_TEST_16BIT(Sub, { + auto host_sub = [](float x, float y) { return x - y; }; + Run(AddEmptyBroadcastDimension(Sub), host_sub); +}) + +// TODO(bixia): Mul fails with bfloat16 on CPU. +XLA_TEST_16BIT(DISABLED_ON_CPU(Mul), { + auto host_mul = [](float x, float y) { return x * y; }; + Run(AddEmptyBroadcastDimension(Mul), host_mul); +}) + +// TODO(bixia): Div fails with bfloat16 on CPU. +XLA_TEST_16BIT(DISABLED_ON_CPU(Div), { + auto host_div = [](float x, float y) { return x / y; }; + Run(AddEmptyBroadcastDimension(Div), host_div); +}) + +template ::value || + std::is_same::value>::type* = nullptr> +T ReferenceMax(T x, T y) { + // We need to propagate NAN here becasue std::max may not propagate NAN. 
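  // (Concretely: every ordered comparison against a NaN is false, so an
  // implementation of std::max as `(x < y) ? y : x` returns x when y is NaN,
  // silently dropping the NaN. The explicit fpclassify checks below keep the
  // NaN instead.)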
+ if (std::fpclassify(x) == FP_NAN) { + return x; + } + if (std::fpclassify(y) == FP_NAN) { + return y; + } + + return std::max(x, y); +} + +template ::value || + std::is_same::value>::type* = nullptr> +T ReferenceMin(T x, T y) { + // We need to propagate NAN here becasue std::max may not propagate NAN. + if (std::fpclassify(x) == FP_NAN) { + return x; + } + if (std::fpclassify(y) == FP_NAN) { + return y; + } + + return std::min(x, y); +} + +XLA_TEST_16BIT(Max, + { Run(AddEmptyBroadcastDimension(Max), ReferenceMax); }) + +XLA_TEST_16BIT(Min, + { Run(AddEmptyBroadcastDimension(Min), ReferenceMin); }) + +// TODO(bixia): Pow fails with bfloat16 on CPU. +XLA_TEST_16BIT(DISABLED_ON_CPU(Pow), + { Run(AddEmptyBroadcastDimension(Pow), std::powf); }) + +// TODO(bixia): Atan2 fails with bfloat16 on CPU. +XLA_TEST_16BIT(DISABLED_ON_CPU(Atan2), + { Run(AddEmptyBroadcastDimension(Atan2), std::atan2f); }) + +#if defined(BINARY_TEST_TARGET_F16) +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest, + ::testing::ValuesIn(CreateExhaustiveF32Ranges())); +#endif +#endif + +#if defined(BINARY_TEST_TARGET_BF16) +#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) +INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest, + ::testing::ValuesIn(CreateExhaustiveF32Ranges())); +#endif +#endif + +// Exhaustive test for binary operations for float and double. +// +// Test parameter is a tuple of (FpValues, FpValues) describing the possible +// values for each operand. The inputs for the test are the Cartesian product +// of the possible values for the two operands. +template +class Exhaustive32BitOrMoreBinaryTest + : public ExhaustiveBinaryTest, + public ::testing::WithParamInterface> { + protected: + using typename ExhaustiveBinaryTest::NativeT; + using ExhaustiveBinaryTest::ConvertAndReplaceKnownIncorrectValueWith; + + private: + int64 GetInputSize() override { + FpValues values_0; + FpValues values_1; + std::tie(values_0, values_1) = GetParam(); + return values_0.GetTotalNumValues() * values_1.GetTotalNumValues(); + } + + void FillInput(std::array* input_literals) override { + int64 input_size = GetInputSize(); + FpValues values_0; + FpValues values_1; + std::tie(values_0, values_1) = GetParam(); + + VLOG(2) << " testing " << values_0.ToString() << " " << values_1.ToString() + << "total values " << input_size; + CHECK(input_size == (*input_literals)[0].element_count() && + input_size == (*input_literals)[1].element_count()); + + absl::Span input_arr_0 = (*input_literals)[0].data(); + absl::Span input_arr_1 = (*input_literals)[1].data(); + + uint64 i = 0; + for (auto src0 : values_0) { + for (auto src1 : values_1) { + input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(src0, 1); + input_arr_1[i] = ConvertAndReplaceKnownIncorrectValueWith(src1, 1); + ++i; + } + } + CHECK_EQ(i, input_size); + } +}; + +using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest; +using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; + +XLA_TEST_P(ExhaustiveF32BinaryTest, Add) { + auto host_add = [](float x, float y) { return x + y; }; + Run(AddEmptyBroadcastDimension(Add), host_add); +} + +XLA_TEST_P(ExhaustiveF32BinaryTest, Sub) { + auto host_sub = [](float x, float y) { return x - y; }; + Run(AddEmptyBroadcastDimension(Sub), host_sub); +} + +// TODO(bixia): Need to investigate the failure on CPU and file bugs. 
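// A minimal standalone sketch (plain C++; FpValues and the literal plumbing
// above are not used here) of the Cartesian-product fill that
// Exhaustive32BitOrMoreBinaryTest::FillInput performs: every value of the
// first operand set is paired with every value of the second, so the two
// flattened operand arrays each hold |set0| * |set1| elements.
#include <vector>

inline void CartesianFill(const std::vector<float>& set0,
                          const std::vector<float>& set1,
                          std::vector<float>* operand0,
                          std::vector<float>* operand1) {
  operand0->reserve(set0.size() * set1.size());
  operand1->reserve(set0.size() * set1.size());
  for (float lhs : set0) {
    for (float rhs : set1) {
      operand0->push_back(lhs);  // i-th pair's left input
      operand1->push_back(rhs);  // i-th pair's right input
    }
  }
}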
+XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Mul)) { + auto host_mul = [](float x, float y) { return x * y; }; + Run(AddEmptyBroadcastDimension(Mul), host_mul); +} + +// TODO(bixia): Need to investigate the failure on CPU and file bugs. +XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(Div)) { + auto host_div = [](float x, float y) { return x / y; }; + Run(AddEmptyBroadcastDimension(Div), host_div); +} + +XLA_TEST_P(ExhaustiveF32BinaryTest, Max) { + Run(AddEmptyBroadcastDimension(Max), ReferenceMax); +} + +XLA_TEST_P(ExhaustiveF32BinaryTest, Min) { + Run(AddEmptyBroadcastDimension(Min), ReferenceMin); +} + +// It is more convenient to implement Abs(complex) as a binary op than a unary +// op, as the operations we currently support all have the same data type for +// the source operands and the results. +// TODO(bixia): May want to move this test to unary test if we will be able to +// implement Abs(complex) as unary conveniently. +// +// TODO(bixia): Need to investigate the failure on CPU and file bugs. +XLA_TEST_P(ExhaustiveF32BinaryTest, DISABLED_ON_CPU(AbsComplex)) { + auto host_abs_complex = [](float x, float y) { + return std::abs(std::complex(x, y)); + }; + auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); }; + + Run(device_abs_complex, host_abs_complex); +} + +#if defined(BINARY_TEST_TARGET_F32) + +INSTANTIATE_TEST_SUITE_P( + SpecialValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + SpecialAndNormalValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::Values(GetNormals(2000)))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndSpecialValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::Values(GetNormals(2000)), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndNormalValues, ExhaustiveF32BinaryTest, + ::testing::Combine(::testing::Values(GetNormals(2000)), + ::testing::Values(GetNormals(2000)))); + +// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test. +// Comparing with the unary tests, the binary tests use a smaller set of inputs +// for each sub-test to avoid timeout because the implementation of ExpectNear +// more than 2x slower for binary test. +INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnituedNormalValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals(40000, + 2000)), + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); + +#endif + +XLA_TEST_P(ExhaustiveF64BinaryTest, Add) { + auto host_add = [](double x, double y) { return x + y; }; + Run(AddEmptyBroadcastDimension(Add), host_add); +} + +XLA_TEST_P(ExhaustiveF64BinaryTest, Sub) { + auto host_sub = [](double x, double y) { return x - y; }; + Run(AddEmptyBroadcastDimension(Sub), host_sub); +} + +// TODO(bixia): Need to investigate the failure on CPU and file bugs. +XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Mul)) { + auto host_mul = [](double x, double y) { return x * y; }; + Run(AddEmptyBroadcastDimension(Mul), host_mul); +} + +// TODO(bixia): Need to investigate the failure on CPU and file bugs. 
+XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(Div)) { + auto host_div = [](double x, double y) { return x / y; }; + Run(AddEmptyBroadcastDimension(Div), host_div); +} + +XLA_TEST_P(ExhaustiveF64BinaryTest, Max) { + Run(AddEmptyBroadcastDimension(Max), ReferenceMax); +} + +XLA_TEST_P(ExhaustiveF64BinaryTest, Min) { + Run(AddEmptyBroadcastDimension(Min), ReferenceMin); +} + +// TODO(bixia): Need to investigate the failure on CPU and file bugs. +XLA_TEST_P(ExhaustiveF64BinaryTest, DISABLED_ON_CPU(AbsComplex)) { + auto host_abs_complex = [](double x, double y) { + return std::abs(std::complex(x, y)); + }; + auto device_abs_complex = [](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); }; + + Run(device_abs_complex, host_abs_complex); +} + +#if defined(BINARY_TEST_TARGET_F64) + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +INSTANTIATE_TEST_SUITE_P( + SpecialValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + SpecialAndNormalValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::Values(GetNormals(1000)))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndSpecialValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::Values(GetNormals(1000)), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndNormalValues, ExhaustiveF64BinaryTest, + ::testing::Combine(::testing::Values(GetNormals(1000)), + ::testing::Values(GetNormals(1000)))); + +// Tests a total of 40000 ^ 2 inputs, with 1000 ^ 2 inputs in each sub-test. +// Similar to ExhaustiveF64BinaryTest, we use a smaller set of inputs for each +// for each sub-test comparing with the unary test to avoid timeout. +INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnituedNormalValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)), + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); +#endif + +#endif +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc index 465da47faeb..1d3248fe04c 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.cc @@ -17,8 +17,8 @@ limitations under the License. namespace xla { -// For f32, f16, and bf16, we need 9, 5, and 4 decimal places of precision to be -// guaranteed that we're printing the full number. +// For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 decimal places of +// precision to be guaranteed that we're printing the full number. // // (The general formula is, given a floating-point number with S significand // bits, the number of decimal digits needed to print it to full precision is @@ -26,71 +26,237 @@ namespace xla { // ceil(1 + S * log_10(2)) ~= ceil(1 + S * 0.30103). // // See https://people.eecs.berkeley.edu/~wkahan/Math128/BinDecBin.pdf.) 
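// A quick standalone check of the formula above (a sketch; the significand
// widths used here include the implicit leading bit: 53 for f64, 24 for f32,
// 11 for f16, 8 for bf16):
//
//   f64:  ceil(1 + 53 * 0.30103) = ceil(16.95) = 17
//   f32:  ceil(1 + 24 * 0.30103) = ceil( 8.22) =  9
//   f16:  ceil(1 + 11 * 0.30103) = ceil( 4.31) =  5
//   bf16: ceil(1 +  8 * 0.30103) = ceil( 3.41) =  4
#include <cmath>

inline int DecimalDigitsForSignificandBits(int significand_bits) {
  return static_cast<int>(std::ceil(1 + significand_bits * std::log10(2.0)));
}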
-/*static*/ -string ExhaustiveOpTestBase::StringifyNum(float x) { - return absl::StrFormat("%0.9g (0x%08x)", x, BitCast(x)); -} +namespace { +template +struct ComponentStringifyFormat {}; + +template <> +struct ComponentStringifyFormat { + static constexpr absl::string_view value = "%0.17g (0x%16x)"; +}; + +template <> +struct ComponentStringifyFormat { + static constexpr absl::string_view value = "%0.8g (0x%08x)"; +}; + +template <> +struct ComponentStringifyFormat { + static constexpr absl::string_view value = "%0.5g (0x%04x)"; +}; + +template <> +struct ComponentStringifyFormat { + static constexpr absl::string_view value = "%0.4g (0x%04x)"; +}; +} // namespace /*static*/ -string ExhaustiveOpTestBase::StringifyNum(half x) { - return absl::StrFormat("%0.5g (0x%04x)", static_cast(x), - BitCast(x)); +template +string ExhaustiveOpTestBase::StringifyNum( + typename ExhaustiveOpTestBase::ComponentNativeT x) { + typedef typename ExhaustiveOpTestBase::ComponentNativeT ComponentType; + typedef typename ExhaustiveOpTestBase::ComponentIntegralNativeT + IntegralType; + return absl::StrFormat(ComponentStringifyFormat::value, + static_cast(x), BitCast(x)); } -/*static*/ -string ExhaustiveOpTestBase::StringifyNum(bfloat16 x) { - return absl::StrFormat("%0.4g (0x%04x)", static_cast(x), - BitCast(x)); -} - -/*static*/ -std::vector> -ExhaustiveOpTestBase::CreateExhaustiveF32Ranges() { - // We break up the 2^32-element space into small'ish chunks to keep peak - // memory usage low. - std::vector> result; - const int64 step = 1 << 25; - for (int64 i = 0; i < (1l << 32); i += step) { - result.push_back({i, i + step}); +template +void ExhaustiveOpTestBase::ExpectNear(const InputLiterals& input_literals, + const Literal& result_literal, + EvaluateOp evaluate_op, + ErrorSpecGen error_spec_gen) { + // Cache for when all components are subnormal testing values. + std::vector pure_subnormal_cache; + pure_subnormal_cache.reserve(GetMaxCacheSize()); + for (int i = 0; i < GetMaxCacheSize(); ++i) { + pure_subnormal_cache.push_back( + CallOperation(evaluate_op, FromCacheLocation(i))); } - return result; + + NativeInputsList inputs_arr; + for (int i = 0; i < N; ++i) { + const Literal& literal = input_literals[i]; + inputs_arr[i] = literal.data(); + } + + absl::Span result_arr = result_literal.data(); + + int64 mismatches = 0; + + for (int64 i = 0; i < result_arr.size(); ++i) { + NativeInputs inputs; + NativeRefInputs inputs_ref_ty; + + for (int j = 0; j < N; ++j) { + inputs[j] = inputs_arr[j][i]; + inputs_ref_ty[j] = static_cast(inputs[j]); + } + + NativeT actual = result_arr[i]; + NativeT expected = + static_cast(CallOperation(evaluate_op, inputs_ref_ty)); + ErrorSpec error_spec = CallErrorSpec(error_spec_gen, inputs); + + if (IsClose(static_cast(expected), + static_cast(actual), error_spec)) { + continue; + } + + std::vector subnormal_test_inputs = + GetTestValuesWithSubnormalSubstitutions(inputs_ref_ty); + + // Easy case: If `input` is not subnormal and !IsClose(expected, actual, + // error_spec), print an error. + if (subnormal_test_inputs.size() == 1) { + PrintMismatch(&mismatches, [&] { + return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.", + StringifyNum(inputs), StringifyNum(expected), + StringifyNum(actual)); + }); + continue; + } + + // Otherwise, we need to test the additional subnormal test values. 
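    // (For example, if a component of the input is denormal, the loop below
    // also evaluates the reference function with sign-preserving zero and
    // with the sign-preserving min normal float substituted for that
    // component, and the actual output is accepted if it is close to any of
    // those results.)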
+ std::vector subnormal_test_results; + subnormal_test_results.reserve(subnormal_test_inputs.size()); + bool passed_subnormal_test = false; + + for (NativeRefInputs test_value : subnormal_test_inputs) { + NativeRefT result; + int cache_loc = GetCacheLocation(test_value); + if (cache_loc == kInvalidCacheIndex) { + result = CallOperation(evaluate_op, test_value); + } else { + result = pure_subnormal_cache[cache_loc]; + } + + if (IsClose(result, static_cast(actual), error_spec)) { + passed_subnormal_test = true; + break; + } + subnormal_test_results.push_back(std::move(result)); + } + + if (passed_subnormal_test) { + continue; + } + + std::string mismatch = absl::StrFormat( + "Mismatch on subnormal value %s. Expected one of:\n" + " %10s (evaluated at full-precision value)\n", + StringifyNum(inputs), StringifyNum(expected)); + + CHECK_EQ(subnormal_test_inputs.size(), subnormal_test_results.size()); + for (int i = 0; i < subnormal_test_inputs.size(); ++i) { + absl::StrAppend( + &mismatch, + absl::StrFormat(" %10s (evaluated at %s)\n", + StringifyNum(subnormal_test_results[i]), + GetSubnormalDescription(subnormal_test_inputs[i], + inputs_ref_ty))); + } + absl::StrAppend(&mismatch, + absl::StrFormat("but got %s", StringifyNum(actual))); + + PrintMismatch(&mismatches, [mismatch] { return mismatch; }); + } + EXPECT_EQ(mismatches, 0); } namespace { -ExhaustiveOpTestBase::ErrorSpec DefaultF64SpecGenerator(float) { - return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001}; +template +inline typename ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + typename ExhaustiveOpTestBase::NativeT) { + LOG(FATAL) << "Unhandled Type"; } -ExhaustiveOpTestBase::ErrorSpec DefaultF32SpecGenerator(float) { - return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001}; +template +inline typename ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + typename ExhaustiveOpTestBase::NativeT, + typename ExhaustiveOpTestBase::NativeT) { + LOG(FATAL) << "Unhandled Type"; } -ExhaustiveOpTestBase::ErrorSpec DefaultF16SpecGenerator(float) { - return ExhaustiveOpTestBase::ErrorSpec{0.001, 0.001}; +template <> +inline ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + complex128) { + return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001}; } -ExhaustiveOpTestBase::ErrorSpec DefaultBF16SpecGenerator(float) { - return ExhaustiveOpTestBase::ErrorSpec{0.002, 0.02}; +template <> +inline ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + complex64) { + return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001}; +} + +template <> +inline ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + double) { + return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001}; +} + +template <> +inline ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + float) { + return ExhaustiveOpTestBase::ErrorSpec{0.0001, 0.0001}; +} + +template <> +inline ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + Eigen::half) { + return ExhaustiveOpTestBase::ErrorSpec{0.001, 0.001}; +} + +template <> +inline ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + bfloat16) { + return ExhaustiveOpTestBase::ErrorSpec{0.002, 0.02}; +} + +template <> +inline ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + double, double) { + return ExhaustiveOpTestBase::ErrorSpec{0.001, 0.001}; +} + +template <> +inline ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + float, float) { + return ExhaustiveOpTestBase::ErrorSpec{0.001, 0.001}; +} + +template <> +inline ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + Eigen::half, Eigen::half) { + return 
ExhaustiveOpTestBase::ErrorSpec{0.001, 0.001}; +} + +template <> +inline ExhaustiveOpTestBase::ErrorSpec DefaultSpecGenerator( + bfloat16, bfloat16) { + return ExhaustiveOpTestBase::ErrorSpec{0.002, 0.02}; } } // namespace /*static*/ -std::function -ExhaustiveOpTestBase::GetDefaultSpecGenerator(PrimitiveType ty) { - switch (ty) { - case C128: - case F64: - return DefaultF64SpecGenerator; - case C64: - case F32: - return DefaultF32SpecGenerator; - case F16: - return DefaultF16SpecGenerator; - case BF16: - return DefaultBF16SpecGenerator; - default: - LOG(FATAL) << "Unhandled Type"; - } +template +typename ExhaustiveOpTestBase::ErrorSpecGen +ExhaustiveOpTestBase::GetDefaultSpecGenerator() { + return DefaultSpecGenerator; } +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; + +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h index 3df4de295e3..3d77b44b53a 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h +++ b/tensorflow/compiler/xla/tests/exhaustive_op_test_utils.h @@ -28,8 +28,10 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" namespace xla { -using Eigen::half; +// T: The primitive type being tested. +// N: The number of operands that the function being tested takes. +template class ExhaustiveOpTestBase : public ClientLibraryTestBase { public: struct ErrorSpec { @@ -41,11 +43,186 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // spec; this only covers the case when both `expected` and `actual` are // equal to 0. bool strict_signed_zeros = false; + + ErrorSpec(float a, float r) : abs_err(a), rel_err(r) {} }; - // `ty` is the primitive type being tested. - explicit ExhaustiveOpTestBase(PrimitiveType ty) - : ty_(ty), platform_(client_->platform()->Name()) {} + // Definitions depending on the primitive type T. + + static constexpr bool kIsComplex = (T == C128 || T == C64); + + // The primitive type used to compute the reference output. + struct RefT { + static constexpr PrimitiveType value = (T == F16 || T == BF16) ? F32 : T; + }; + + // The primitive type of the component of T. If T is not complex, then + // ComponentT = T. + struct ComponentT { + static constexpr PrimitiveType value = + !kIsComplex ? T + : T == C128 ? F64 : T == C64 ? F32 : PRIMITIVE_TYPE_INVALID; + }; + + // Same as ComponentT, but for the RefT primitive type. + struct ComponentRefT { + static constexpr PrimitiveType value = + !kIsComplex ? RefT::value + : RefT::value == C128 + ? F64 + : RefT::value == C64 ? F32 : PRIMITIVE_TYPE_INVALID; + }; + + // The primitive type of an unsigned integer that can be bitcasted to and from + // ComponentT. + struct ComponentIntegralT { + static constexpr PrimitiveType value = + (T == C128 || T == F64) + ? U64 + : (T == C64 || T == F32) + ? U32 + : (T == F16 || T == BF16) ? U16 : PRIMITIVE_TYPE_INVALID; + }; + + // Native types that correspond to the primtive types above. 
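  // (For instance, with T = C64 the traits above give kIsComplex = true,
  // RefT = C64, ComponentT = F32 and ComponentIntegralT = U32, so below
  // NativeT = complex64, ComponentNativeT = float and
  // ComponentIntegralNativeT = uint32. With T = F16 they give RefT = F32,
  // so NativeT = Eigen::half but NativeRefT = float, and
  // ComponentIntegralNativeT = uint16.)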
+ using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + using NativeRefT = + typename primitive_util::PrimitiveTypeToNative::type; + using ComponentNativeT = + typename primitive_util::PrimitiveTypeToNative::type; + using ComponentNativeRefT = typename primitive_util::PrimitiveTypeToNative< + ComponentRefT::value>::type; + using ComponentIntegralNativeT = + typename primitive_util::PrimitiveTypeToNative< + ComponentIntegralT::value>::type; + + using InputLiterals = std::array; + + private: + // N spans corresponding to the list of literal data values. + using NativeInputsList = std::array, N>; + + // N data items representing a single input to an XLA function. + using NativeInputs = std::array; + + // N data items representing a single input to an interpreter backend + // function. + using NativeRefInputs = std::array; + + // N data items representing a single input to an XLA function. + using XlaInputs = std::array; + + // Representations of the reference function passed in by the user. + template + struct EvaluateOpWrapper {}; + template <> + struct EvaluateOpWrapper<1> { + using type = NativeRefT (*)(NativeRefT); + }; + template <> + struct EvaluateOpWrapper<2> { + using type = NativeRefT (*)(NativeRefT, NativeRefT); + }; + + // Representations of the reference function passed in by the user. + template + struct EnqueueOpWrapper {}; + template <> + struct EnqueueOpWrapper<1> { + using type = std::function; + static XlaOp BuildFromInputs(XlaInputs inputs, type ty) { + return ty(inputs[0]); + } + }; + template <> + struct EnqueueOpWrapper<2> { + using type = std::function; + static XlaOp BuildFromInputs(XlaInputs inputs, type ty) { + return ty(inputs[0], inputs[1]); + } + }; + + // Representations of the ErrorSpecGen function passed in by the user. + template + struct ErrorSpecGenWrapper {}; + template <> + struct ErrorSpecGenWrapper<1> { + using type = ErrorSpec (*)(NativeT); + }; + template <> + struct ErrorSpecGenWrapper<2> { + using type = ErrorSpec (*)(NativeT, NativeT); + }; + + public: + using ErrorSpecGen = typename ErrorSpecGenWrapper::type; + using EvaluateOp = typename EvaluateOpWrapper::type; + using EnqueueOp = typename EnqueueOpWrapper::type; + + explicit ExhaustiveOpTestBase() + : ty_(T), platform_(client_->platform()->Name()) { + SetFastMathDisabled(true); + + // Run all HLO passes. In particular, constant folding is disabled by + // default for tests, but we need to run it in order to tickle some bugs. + mutable_debug_options()->clear_xla_disable_hlo_passes(); + } + + void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op) { + Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator()); + } + + // A helper for implementing the Run method for exhaustive op tests. It + // constructs the HLO module, compiles and runs the module and checks the + // result. + // + // We use a function pointer for evaluate_op for performance because it is + // called each time an output element is compared inside a loop in routine + // ExpectNear. 
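  // (Usage sketch, assuming a concrete ExhaustiveOpTestBase<F32, 2> fixture:
  //
  //   float HostAdd(float x, float y) { return x + y; }   // plain function
  //   ...
  //   Run(AddEmptyBroadcastDimension(Add), HostAdd);
  //
  // A captureless lambda such as [](float x, float y) { return x + y; } also
  // converts to the EvaluateOp function pointer, which is what the binary
  // tests above rely on.)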
+ void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op, + ErrorSpecGen error_spec_gen) { + InputLiterals input_literals = CreateInputLiterals(); + FillInput(&input_literals); + + XlaBuilder builder(TestName()); + XlaInputs xla_inputs; + for (int i = 0; i < N; ++i) { + xla_inputs[i] = + Parameter(&builder, i, input_literals[i].shape(), "input"); + } + EnqueueOpWrapper::BuildFromInputs(xla_inputs, enqueue_op); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, + RunComputationHelper(comp, input_literals)); + ExpectNear(input_literals, result_literal, evaluate_op, error_spec_gen); + } + + StatusOr RunComputationHelper(const XlaComputation& comp, + const Literal& literal) { + return RunComputation(comp, {&literal}); + } + + StatusOr RunComputationHelper( + const XlaComputation& comp, const std::array& literals) { + std::array lit_ptrs; + for (int i = 0; i < N; ++i) { + lit_ptrs[i] = &literals[i]; + } + return RunComputation(comp, lit_ptrs); + } + + // We essentially reimplement LiteralTestUtil::Near here because + // a) this streamlined implementation is much faster, and + // b) we can print out better error messages (namely, we can print out + // which floating-point value input failed, while LiteralTestUtil::Near + // can only print out the input index that failed). + // c) we need special handling of certain inputs. For example, we say that + // a denormal input has multiple correct outputs (namely, f(x) and f(0)) + // and just needs to be close to one of them. + void ExpectNear(const InputLiterals& input_literals, + const Literal& result_literal, EvaluateOp evaluate_op, + ErrorSpecGen error_spec_gen); // Builds and runs the computation using the LocalClient API, rather than the // plain Client API, which is used by ClientLibraryTestBase. This is because @@ -94,30 +271,395 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { return std::move(result_literal); } + const string& Platform() { return platform_; } + // Returns the number of elements in each input literal. virtual int64 GetInputSize() = 0; - Literal CreateInputLiteral() { - return LiteralUtil::CreateFromDimensions(ty_, {GetInputSize()}); + // Fills the literals with values to test for. + virtual void FillInput(InputLiterals* literals) = 0; + + // Replace infinites with max value to help compute errors. + static ComponentNativeRefT ReplaceInfWithMax(ComponentNativeRefT value) { + if (std::isinf(value)) { + return std::copysign(std::numeric_limits::max(), + value); + } + return value; } - // `T` is the type of the value being compared, which is float if ty_ is of 32 - // bits or less, and double otherwise. - template - bool IsClose(T expected, T actual, ErrorSpec spec) { - static_assert( - std::is_same::value || std::is_same::value, - "Only supports float and double."); - T abs_err = std::abs(expected - actual); - T rel_err = abs_err / std::abs(expected); - if (spec.strict_signed_zeros && actual == T{0} && expected == T{0}) { - // Check sign of zero. - return std::signbit(actual) == std::signbit(expected); + // Returns true if both components are 0, but their sign bits differ. + static bool CheckSignedZeroError(ComponentNativeRefT expected, + ComponentNativeRefT actual) { + return expected == 0 && actual == 0 && + std::signbit(expected) != std::signbit(actual); + } + + // Sets the components to 0 if both are NaNs. 
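  // (e.g. expected = NaN, actual = NaN becomes expected = 0, actual = 0, so a
  // matching pair of NaNs compares as equal in IsClose below.)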
+ static void RemoveCorrespondingNaNs(ComponentNativeRefT* expected, + ComponentNativeRefT* actual) { + if (std::isnan(*expected) && std::isnan(*actual)) { + *expected = 0; + *actual = 0; } - return abs_err <= spec.abs_err || rel_err <= spec.rel_err || - (std::isnan(expected) && std::isnan(actual)) || - (std::isinf(expected) && std::isinf(actual) && - (expected > 0) == (actual > 0)); + } + + // The Implementation of the functions above, except for complex inputs. + + static std::complex ReplaceInfWithMax( + std::complex value) { + value.real(ReplaceInfWithMax(value.real())); + value.imag(ReplaceInfWithMax(value.imag())); + return value; + } + + static bool CheckSignedZeroError(std::complex expected, + std::complex actual) { + return CheckSignedZeroError(expected.real(), actual.real()) || + CheckSignedZeroError(expected.imag(), actual.imag()); + } + + static void RemoveCorrespondingNaNs( + std::complex* expected, + std::complex* actual) { + ComponentNativeRefT expected_real = expected->real(); + ComponentNativeRefT expected_imag = expected->imag(); + ComponentNativeRefT actual_real = actual->real(); + ComponentNativeRefT actual_imag = actual->imag(); + RemoveCorrespondingNaNs(&expected_real, &actual_real); + RemoveCorrespondingNaNs(&expected_imag, &actual_imag); + expected->real(expected_real); + expected->imag(expected_imag); + actual->real(actual_real); + actual->imag(actual_imag); + } + + // Returns a list of inputs that should be tested for closeness given some + // original input values. + // + // For denormal component inputs, we accept answers that are close to any of: + // + // - evaluate_op(input) + // - evaluate_op(+/-0), where the sign of 0 equal to the sign of + // `input`, + // - evaluate_op(+/-min_normal_float), where the sign of + // min_normal_float matches `input`. + // - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of + // 0 is the opposite of `input`. + // + // (In particular, the XLA:CPU implementation of log flushes positive + // denormals to min-normal-float. This seems kind of reasonable if our + // goal is to avoid infinities because they cause nans?) + std::vector GetTestValuesWithSubnormalSubstitutions( + ComponentNativeRefT value) { + std::vector test_values; + if (std::fpclassify(value) == FP_SUBNORMAL) { + test_values.reserve(relaxed_denormal_signs_ ? 3 : 2); + test_values.push_back(std::copysign(0, value)); + test_values.push_back(std::copysign( + std::numeric_limits::min(), value)); + if (relaxed_denormal_signs_) { + test_values.push_back(std::copysign(0, -value)); + } + } else { + test_values.push_back(value); + } + return test_values; + } + + // Similar to complex numbers, we only need to test the components that are + // subnormal. We can find the subnormal testing values for each component, + // then take the Cartesian product of each set of component values. + std::vector> + GetTestValuesWithSubnormalSubstitutions( + std::complex value) { + using complex = std::complex; + + auto real_values = GetTestValuesWithSubnormalSubstitutions(value.real()); + auto imag_values = GetTestValuesWithSubnormalSubstitutions(value.imag()); + + std::vector test_values; + test_values.reserve(real_values.size() * imag_values.size()); + for (auto real : real_values) { + for (auto imag : imag_values) { + test_values.push_back(complex(real, imag)); + } + } + + return test_values; + } + + // The test values for an XLA function with N operands are the Cartesian + // product of the test values for each of the N operands. 
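  // (For example, with N = 2, a positive denormal first operand and a normal
  // second operand: the first operand expands to {+0, +min_normal, -0} when
  // relaxed_denormal_signs_ is set (or {+0, +min_normal} otherwise), the
  // second stays as itself, so 3 (or 2) candidate input pairs are tried.)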
+ std::vector> + GetTestValuesWithSubnormalSubstitutions( + const std::array& value) { + std::vector> test_values; + + std::array, N> component_test_values; + int total = 1; + for (int i = 0; i < N; ++i) { + component_test_values[i] = + GetTestValuesWithSubnormalSubstitutions(value[i]); + if (!component_test_values.empty()) { + total *= component_test_values[i].size(); + } + } + + // If total == 1, then value has no subnormal components, so we can just + // return a vector with value in it. + if (total == 1) { + test_values.push_back(value); + return test_values; + } + + test_values.reserve(total); + + // Perform a Cartesian product of the vectors in component_test_values. + // We can calculate this by uniquely mapping each integer from 0 to + // (total - 1) to a list of component indices. The function that maps an + // integer z to the index of component j is: + // component_index(j) = (i / NumValues(0, j-1)) % NumValues(j, j) + // and NumIndices(x, y) is the number of values in the Cartesian product of + // component_test_values[x], component_test_values[x+1], ... + // component_test_values[y]. + for (int i = 0; i < total; ++i) { + int accumulated_num_values = 1; + std::array test_value; + for (int j = 0; j < N; ++j) { + int num_indices = component_test_values[j].size(); + int component_index = (i / accumulated_num_values) % num_indices; + test_value[j] = component_test_values[j][component_index]; + accumulated_num_values *= num_indices; + } + test_values.push_back(std::move(test_value)); + } + return test_values; + } + + // The number of values that can be substituted for subnormal inputs. + static constexpr int kNumSubnormalSubstitutionValues = 4; + + // Encodings used to determine where subnormal test values are cached. + static constexpr int kPositiveMin = 0; + static constexpr int kNegativeMin = 1; + static constexpr int kPositiveZero = 2; + static constexpr int kNegativeZero = 3; + static constexpr int kNonSubnormal = -1; + static constexpr int kInvalidCacheIndex = -1; + + // Since we take the cross product of all possible test values, and each + // component has kNumSubnormalSubstitutionValues possible test values, then + // the total number of different cache locations are + // kNumSubnormalSubstitutionValues raised to the num_components. + // num_components = N for the reals, and 2*N for the complex. + static constexpr int GetMaxCacheSize() { + return pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1)); + } + + // When we are testing a value such that all of its components are subnormal, + // we also need to test inputs made up of the Cartesian product of values + // replaced for each subnormal component. These additional test inputs are + // common enough where it will be efficient to just cache the results of these + // Cartesian products. In order to cache these values, we need a one to one + // mapping between these Cartesian products and cache locations. + // + // Our mapping works by assigning each component an integer in + // [0, kNumSubnormalSubstitutionValues) based on its test value. By lining + // these integers up with the n'th component corresponding to the n'th digit, + // then for each Cartesian product element we essentially create a unique base + // kNumSubnormalSubstitutionValues number. This number represents our cache + // index. + // + // In the event that there a component is not a subnormal, the value should + // not be cached, so we return a kNonSubnormal value. 
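  // (Worked example, for a real-valued op with N = 2: inputs (+0, -min_normal)
  // map to component codes kPositiveZero = 2 and kNegativeMin = 1, and the
  // overloads below fold them as 2 * 4 + 1 = 9, i.e. the base-4 number "21".
  // For a single complex component, real = -0 and imag = +min_normal give
  // 3 * 4 + 0 = 12.)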
+ + static int GetCacheLocation(ComponentNativeRefT value) { + bool positive = !std::signbit(value); + if (std::abs(value) == std::numeric_limits::min()) { + if (positive) { + return kPositiveMin; + } else { + return kNegativeMin; + } + } else if (value != 0) { + CHECK(std::fpclassify(value) != FP_SUBNORMAL); + return kNonSubnormal; + } else if (positive) { + return kPositiveZero; + } else { + return kNegativeZero; + } + } + + static int GetCacheLocation(std::complex value) { + int real_loc = GetCacheLocation(value.real()); + int imag_loc = GetCacheLocation(value.imag()); + if (real_loc == kNonSubnormal || imag_loc == kNonSubnormal) { + return kNonSubnormal; + } else { + return real_loc * kNumSubnormalSubstitutionValues + imag_loc; + } + } + + static int GetCacheLocation(const NativeRefInputs& input) { + int location = 0; + int cache_size_per_element = + (kIsComplex + ? kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues + : kNumSubnormalSubstitutionValues); + for (int i = 0; i < N; ++i) { + int comp_loc = GetCacheLocation(input[i]); + if (i == kNonSubnormal) { + return kNonSubnormal; + } + location *= cache_size_per_element; + location += comp_loc; + } + return location; + } + + // The inverse function of GetCacheLocation. + + template + static RetT FromCacheLocationComponent(int cache_loc) { + LOG(FATAL) << "Not implemented."; + } + + template <> + static ComponentNativeRefT + FromCacheLocationComponent(int cache_loc) { + switch (cache_loc) { + case kPositiveMin: + return std::numeric_limits::min(); + case kNegativeMin: + return -std::numeric_limits::min(); + case kPositiveZero: + return static_cast(0.0); + case kNegativeZero: + return static_cast(-0.0); + default: + LOG(FATAL) << "Invalid cache_loc value of " << cache_loc; + } + } + + template <> + static std::complex + FromCacheLocationComponent>( + int cache_loc) { + CHECK_LT(cache_loc, + kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues); + CHECK_GE(cache_loc, 0); + + std::complex value; + value.real(FromCacheLocationComponent( + cache_loc / kNumSubnormalSubstitutionValues)); + value.imag(FromCacheLocationComponent( + cache_loc % kNumSubnormalSubstitutionValues)); + return std::move(value); + } + + static NativeRefInputs FromCacheLocation(int cache_loc) { + NativeRefInputs input; + int cache_size_per_element = + (kIsComplex + ? kNumSubnormalSubstitutionValues * kNumSubnormalSubstitutionValues + : kNumSubnormalSubstitutionValues); + for (int i = N - 1; i >= 0; --i) { + input[i] = FromCacheLocationComponent( + cache_loc % cache_size_per_element); + cache_loc /= cache_size_per_element; + } + + return input; + } + + // Returns a string that describes the test value for the actual value. + std::string GetSubnormalDescription(ComponentNativeRefT test_val, + ComponentNativeRefT actual_val) { + const string sp_min_normal = "sign-preserving min-normal-float"; + const string sp_zero = "sign-preserving zero"; + const string nsp_zero = "non-sign-preserving zero"; + + switch (GetCacheLocation(test_val)) { + case kNegativeMin: + case kPositiveMin: + return sp_min_normal; + case kNegativeZero: + case kPositiveZero: + return (std::signbit(test_val) == std::signbit(actual_val)) ? 
sp_zero + : nsp_zero; + default: + return ""; + } + } + + std::string GetSubnormalDescription( + std::complex test_val, + std::complex actual_val) { + std::string real = + GetSubnormalDescription(test_val.real(), actual_val.real()); + std::string imag = + GetSubnormalDescription(test_val.imag(), actual_val.imag()); + + if (real.empty()) { + if (imag.empty()) { + return ""; + } + real = "real"; + } else if (imag.empty()) { + imag = "imag"; + } + + return absl::StrCat("(", real, ", ", imag, ")"); + } + + std::string GetSubnormalDescription(std::array test_vals, + std::array actual_vals) { + if (N == 1) { + return GetSubnormalDescription(test_vals[0], actual_vals[0]); + } + + std::array str_vals; + for (int i = 0; i < N; ++i) { + str_vals[i] = GetSubnormalDescription(test_vals[i], actual_vals[i]); + if (str_vals[i].empty()) { + str_vals[i] = "original"; + } + } + + return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")"); + } + + InputLiterals CreateInputLiterals() { + InputLiterals literals; + for (int i = 0; i < N; ++i) { + literals[i] = LiteralUtil::CreateFromDimensions(T, {GetInputSize()}); + } + return std::move(literals); + } + + // Determines if two output values are sufficiently close to each other based + // on an error spec. + bool IsClose(NativeRefT expected, NativeRefT actual, ErrorSpec spec) { + // When two corresponding values are a NaN, they can be considered to have + // the same value, so the values are just set to 0. + RemoveCorrespondingNaNs(&expected, &actual); + + if (spec.strict_signed_zeros) { + if (CheckSignedZeroError(expected, actual)) { + return false; + } + } + + // Replace Inf with Max when calculating absolute or relative errors. This + // allows the test to pass when another value are close to Inf and the + // specified absolute or relative errors are not zero. + double abs_err = + std::abs(ReplaceInfWithMax(expected) - ReplaceInfWithMax(actual)); + double rel_err = abs_err / std::abs(ReplaceInfWithMax(expected)); + + return abs_err <= spec.abs_err || rel_err <= spec.rel_err; } template @@ -140,24 +682,6 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { } } - template - struct IntegralTypeWithByteWidth {}; - - template <> - struct IntegralTypeWithByteWidth<2> { - using type = uint16; - }; - - template <> - struct IntegralTypeWithByteWidth<4> { - using type = uint32; - }; - - template <> - struct IntegralTypeWithByteWidth<8> { - using type = uint64; - }; - // Converts part or all bits in an uint64 to the value of the floating point // data type being tested. // @@ -166,47 +690,57 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // bit patterns for T. This bit pattern is zero extended and stored as uint64. // This function is used to convert such a bit pattern stored as uint64 to // the input value for T. - // - // T is the type of the floating value represented by the `bits`. 
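The IsClose helper above treats matching NaNs as equal, optionally rejects mismatched signed zeros, and clamps infinities to max-float before computing absolute and relative error, accepting the result if either error is within the spec. A minimal self-contained sketch of that comparison policy (the Tolerance struct and the tolerances in main are assumptions, not the XLA ErrorSpec):

// Sketch of a closeness check in the spirit of IsClose above: matching NaNs
// compare equal, +/-0 can be distinguished, and Inf is clamped to max-float
// before computing absolute/relative error. Names and tolerances are
// illustrative only.
#include <cmath>
#include <cstdio>
#include <limits>

struct Tolerance {
  double abs_err;
  double rel_err;
  bool strict_signed_zeros;
};

double ClampInfToMax(double x) {
  if (std::isinf(x)) {
    return std::copysign(std::numeric_limits<double>::max(), x);
  }
  return x;
}

bool IsCloseSketch(double expected, double actual, Tolerance tol) {
  // Two NaNs are considered equal.
  if (std::isnan(expected) && std::isnan(actual)) return true;
  // Optionally require the sign of zero to match exactly.
  if (tol.strict_signed_zeros && expected == 0 && actual == 0 &&
      std::signbit(expected) != std::signbit(actual)) {
    return false;
  }
  double e = ClampInfToMax(expected);
  double a = ClampInfToMax(actual);
  double abs_err = std::fabs(e - a);
  double rel_err = abs_err / std::fabs(e);
  return abs_err <= tol.abs_err || rel_err <= tol.rel_err;
}

int main() {
  Tolerance tol{1e-6, 1e-6, /*strict_signed_zeros=*/true};
  std::printf("%d\n", IsCloseSketch(1.0, 1.0 + 1e-9, tol));  // 1: within tolerance
  std::printf("%d\n", IsCloseSketch(0.0, -0.0, tol));        // 0: signed-zero mismatch
  std::printf("%d\n", IsCloseSketch(INFINITY, 1e308, tol));  // 0: still far apart after clamping
}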
- template - T ConvertValue(uint64 bits) { - using I = typename IntegralTypeWithByteWidth::type; + static ComponentNativeT ConvertValue(uint64 bits) { + using I = ComponentIntegralNativeT; I used_bits = static_cast(bits); - return BitCast(used_bits); + return BitCast(used_bits); } - template - T ConvertAndReplaceKnownIncorrectValueWith(uint64 bits, - int replacement_value = 0) { + ComponentNativeT ConvertAndReplaceKnownIncorrectValueWith( + uint64 bits, int replacement_value = 0) { if (known_incorrect_fn_ && known_incorrect_fn_(bits)) { - return static_cast(replacement_value); + return static_cast(replacement_value); } - return ConvertValue(bits); + return ConvertValue(bits); } - static string StringifyNum(float x); + static string StringifyNum(ComponentNativeT x); - static string StringifyNum(half x); - - static string StringifyNum(bfloat16 x); - - template - static string StringifyNum(std::complex x) { - return absl::StrCat(StringifyNum(x.real()), " ", StringifyNum(x.imag())); + static string StringifyNum(std::complex x) { + return absl::StrCat("(", StringifyNum(x.real()), ", ", + StringifyNum(x.imag()), ")"); } - template - static void AppendStringifyNum(std::string* s, T x) { + // We also stringify the NativeRefT, so we need to generate an additional + // version of this function when NativeRefT != NativeT. + template < + typename T1 = NativeRefT, + class = typename std::enable_if::value>::type> + static string StringifyNum(NativeRefT x) { + return ExhaustiveOpTestBase::StringifyNum(x); + } + + static string StringifyNum(const NativeInputs& inputs) { + if (N == 1) { + return StringifyNum(inputs[0]); + } + + std::array str_vals; + for (int i = 0; i < N; ++i) { + str_vals[i] = StringifyNum(inputs[i]); + } + + return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")"); + } + + static void AppendStringifyNum(std::string* s, NativeT x) { absl::StrAppend(s, StringifyNum(x)); } - static std::function GetDefaultSpecGenerator( - PrimitiveType ty); - - static std::vector> CreateExhaustiveF32Ranges(); + static ErrorSpecGen GetDefaultSpecGenerator(); protected: - // The primitive type under test. + // The primitive type being tested. const PrimitiveType ty_; // The platform under test. @@ -225,7 +759,448 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // // XLA:GPU preserves denormal signs, but other backends don't. bool relaxed_denormal_signs_ = platform_ != "CUDA"; + + private: + using EvaluateOpInternal = NativeRefT (*)(NativeRefInputs); + using ErrorSpecGenInternal = ErrorSpec (*)(NativeInputs); + + template + ErrorSpec CallErrorSpec(FuncPtr* func, const std::array& in) { + return func(in[0]); + } + + template + ErrorSpec CallErrorSpec(FuncPtr* func, const std::array& in) { + return func(in[0], in[1]); + } + + template + Type CallOperation(FuncPtr* func, const std::array& in) { + return func(in[0]); + } + + template + Type CallOperation(FuncPtr* func, const std::array& in) { + return func(in[0], in[1]); + } }; +// Represents a set of 64 bit chunks by representing the starting bit chunk, +// the last bit chunk, and the spacing between two adjacent bit chunks, without +// actually storing all the bit chunks being generated. The bit chunk iterator +// is provided to retrieve all the bit chunks. +// +// This data structure is used to generate the bit representation to test +// operations that requires more than 64 bit input data. In this case, +// truly exhaustive testing is not possible and we want to test a value every +// n values, where n == spacing_. 
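ConvertValue above truncates a 64-bit counter to the integral type with the same width as the tested floating-point type and then bit-casts it. A sketch of the same idea for float, using memcpy as a portable stand-in for BitCast (the sample bit patterns are just illustrations):

// Sketch: interpret the low 32 bits of a 64-bit counter as an IEEE-754 float,
// which is how an exhaustive test walks every f32 bit pattern.
#include <cstdint>
#include <cstdio>
#include <cstring>

float FloatFromBits(std::uint64_t bits) {
  std::uint32_t low = static_cast<std::uint32_t>(bits);  // truncate to 32 bits
  float f;
  std::memcpy(&f, &low, sizeof(f));  // portable bit cast
  return f;
}

int main() {
  std::printf("%g\n", FloatFromBits(0x3f800000));  // 1.0f
  std::printf("%g\n", FloatFromBits(0x00000001));  // smallest positive denormal
  std::printf("%g\n", FloatFromBits(0x7f800000));  // +inf
}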
+// +// Currently, the iterator of BitChunks adds the `spacing_` to a bit chunk to +// compute the next bit chunk. We can change this to use values generated +// by a random number generator that can achieve the average spacing +// statistically, if we will find this is necessary. +class BitChunks { + public: + class iterator + : public std::iterator { + public: + iterator() {} + + explicit iterator(const BitChunks* bit_chunks) + : bit_chunks_(bit_chunks), next_bit_chunk_(bit_chunks->start_) {} + + iterator& operator++() { + Next(); + return *this; + } + + iterator operator++(int) { + iterator retval = *this; + Next(); + return retval; + } + + bool operator==(iterator other) const { + return bit_chunks_ == other.bit_chunks_ && + next_bit_chunk_ == other.next_bit_chunk_; + } + + bool operator!=(iterator other) const { return !(*this == other); } + + iterator MoveToEnd() { + MoveNextBitChunkToOnePassEnd(); + return *this; + } + + reference operator*() const { + CHECK(*this != this->bit_chunks_->end()); + return next_bit_chunk_; + } + + const BitChunks* GetBitChunks() const { return bit_chunks_; } + + void Reset() { next_bit_chunk_ = bit_chunks_->start_; } + + void Next() { + CHECK(*this != this->bit_chunks_->end()); + if (next_bit_chunk_ == bit_chunks_->end_) { + MoveNextBitChunkToOnePassEnd(); + } else { + next_bit_chunk_ += bit_chunks_->spacing_; + if (next_bit_chunk_ > bit_chunks_->end_) { + next_bit_chunk_ = bit_chunks_->end_; + } + } + } + + std::string ToString() const { + return absl::StrFormat("0x%08x", next_bit_chunk_); + } + + private: + // Move next_bit_chunk_ to 1 pass the bit_chunks_->end, to mark that the + // iterator has reached the end. When spacing_ is not one, or if we will + // change to use a random value instead of spacing_ in function Next(), + // normalizing the representation of the iterator ending this way can + // can simplify the checking for iterator ending. + void MoveNextBitChunkToOnePassEnd() { + next_bit_chunk_ = bit_chunks_->end_ + 1; + } + + const BitChunks* bit_chunks_; + uint64 next_bit_chunk_; + }; + + iterator begin() const { return iterator(this); } + iterator end() const { + iterator end(this); + return end.MoveToEnd(); + } + + explicit BitChunks(uint64 start = 0, uint64 end = 0, uint64 spacing = 1) + : start_(start), end_(end), spacing_(spacing) { + CHECK_GE(end_, start_); + CHECK_NE(spacing, 0) << ToString(); + } + + int64 GetTotalBitChunks() const { + if (start_ == end_) { + return 1; + } + + return 1 + (end_ - start_ + spacing_ - 1) / spacing_; + } + + std::string ToString() const { + return absl::StrFormat("(0x%08x, 0x%08x, 0x%08x)", start_, end_, spacing_); + } + + uint64 start_; + uint64 end_; + uint64 spacing_; +}; + +inline string StringifyNum(BitChunks c) { return c.ToString(); } + +inline string StringifyNum(BitChunks::iterator c) { return c.ToString(); } + +template +void AppendStringifyNum(std::string* s, T x) { + absl::StrAppend(s, StringifyNum(x)); +} + +// Represents a set of floating point values through the possible values for +// the three components: mantissa, exponent, and sign. Also implements an +// iterator for retrieving all the represented floating point values. 
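BitChunks represents the progression start, start + spacing, ... with the final step clamped to end, and GetTotalBitChunks counts it with a ceiling division. A small standalone sketch with the same semantics (StridedRange is a hypothetical name):

// Sketch of a strided, clamped range in the spirit of BitChunks: values are
// start, start + spacing, ... and the last emitted value is always `end`.
#include <cstdint>
#include <cstdio>
#include <vector>

struct StridedRange {
  std::uint64_t start, end, spacing;

  std::uint64_t Count() const {
    if (start == end) return 1;
    return 1 + (end - start + spacing - 1) / spacing;  // ceiling division
  }

  std::vector<std::uint64_t> Materialize() const {
    std::vector<std::uint64_t> out;
    std::uint64_t v = start;
    while (true) {
      out.push_back(v);
      if (v == end) break;
      v += spacing;
      if (v > end) v = end;  // clamp the final step, as the BitChunks iterator does
    }
    return out;
  }
};

int main() {
  StridedRange r{0, 10, 4};  // yields 0, 4, 8, 10
  std::printf("count formula: %llu\n", (unsigned long long)r.Count());
  for (std::uint64_t v : r.Materialize()) {
    std::printf("%llu ", (unsigned long long)v);
  }
  std::printf("\n");
}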
+class FpValues { + public: + static constexpr uint kTotalBitChunks = 3; + + class iterator + : public std::iterator { + public: + explicit iterator(const FpValues* fp_values) : fp_values_(fp_values) { + for (int i = 0; i < FpValues::kTotalBitChunks; ++i) { + iters_[i] = BitChunks::iterator(&fp_values->GetBitChunks(i)); + } + } + + iterator& operator++() { + Next(); + return *this; + } + + iterator operator++(int) { + iterator retval = *this; + Next(); + return retval; + } + + bool operator==(iterator other) const { + for (int i = 0; i < FpValues::kTotalBitChunks; ++i) { + if (iters_[i] != other.GetBitChunksIter(i)) { + return false; + } + } + return true; + } + + bool operator!=(iterator other) const { return !(*this == other); } + + iterator MoveToEnd() { + for (int i = 0; i < FpValues::kTotalBitChunks; ++i) { + iters_[i].MoveToEnd(); + } + return *this; + } + + uint64 operator*() const { + uint64 value = 0; + for (int i = 0; i < FpValues::kTotalBitChunks; ++i) { + value = value | (*iters_[i]) << fp_values_->offsets_[i]; + } + return value; + } + + const BitChunks::iterator& GetBitChunksIter(int i) { return iters_[i]; } + + std::string ToString() const { + return absl::StrJoin(iters_, ",", + AppendStringifyNum); + } + + private: + // Moves the iterator for the ith BitChunks to the next value, and + // returns true if the new state is not the end of the iterator. + bool Next(int i = 0) { + iters_[i].Next(); + if (iters_[i] == iters_[i].GetBitChunks()->end()) { + if (i == FpValues::kTotalBitChunks - 1) { + return false; + } + if (Next(i + 1)) { + iters_[i].Reset(); + return true; + } + return false; + } + return true; + } + + std::array iters_; + const FpValues* fp_values_; + }; + + FpValues() : bit_chunks_(), offsets_() {} + FpValues(absl::Span chunks, absl::Span offsets) { + CHECK_EQ(chunks.size(), offsets.size() - 1); + CHECK_EQ(chunks.size(), kTotalBitChunks); + std::copy_n(chunks.begin(), kTotalBitChunks, bit_chunks_.begin()); + std::copy_n(offsets.begin(), kTotalBitChunks, offsets_.begin()); + + // The last value in `offsets` is the total number of bits. + offsets_[kTotalBitChunks] = offsets[kTotalBitChunks]; + // Validate the input values. 
+ for (int i = 0; i < kTotalBitChunks; ++i) { + int total_bits = offsets[i + 1] - offsets[i]; + if (total_bits < 64) { + uint64 bound = 1ull << total_bits; + CHECK_LT(chunks[i].start_, bound); + CHECK_LT(chunks[i].end_, bound); + } else { + CHECK_EQ(total_bits, 64); + } + } + } + + iterator begin() const { return iterator(this); } + + iterator end() const { + iterator end(this); + return end.MoveToEnd(); + } + + int64 GetTotalNumValues() const { + int64 total = 1; + absl::c_for_each(bit_chunks_, [&](const BitChunks& chunks) { + total *= chunks.GetTotalBitChunks(); + }); + return total; + } + + const BitChunks& GetBitChunks(int i) const { return bit_chunks_[i]; } + + std::string ToString() const { + return absl::StrCat( + "[", absl::StrJoin(bit_chunks_, ",", AppendStringifyNum), + "]"); + } + + std::array bit_chunks_; + std::array offsets_; +}; + +template ::value || + std::is_same::value>::type* = nullptr> +int GetMantissaTotalBits() { + return std::numeric_limits::digits - 1; +} + +template +int GetFpTotalBits() { + return sizeof(T) * 8; +} + +template +int GetExponentTotalBits() { + return GetFpTotalBits() - GetMantissaTotalBits() - 1; +} + +template +uint64 GetAllOneMantissa() { + return (1ull << GetMantissaTotalBits()) - 1ull; +} + +template +uint64 GetAllOneExponent() { + return (1ull << GetExponentTotalBits()) - 1ull; +} + +template ::value || + std::is_same::value>::type* = nullptr> +FpValues GetFpValues(BitChunks mantissa, BitChunks exponent, BitChunks sign) { + int total_bits = GetFpTotalBits(); + return FpValues({mantissa, exponent, sign}, + {0, GetMantissaTotalBits(), total_bits - 1, total_bits}); +} + +template +FpValues GetZeros() { + return GetFpValues(BitChunks(0, 0, 1), BitChunks(0, 0, 1), + BitChunks(0, 1, 1)); +} + +template +FpValues GetSubnormals(int approx_num_values) { + int mantissa = GetMantissaTotalBits(); + uint64 mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2); + return GetFpValues( + BitChunks(0x1, GetAllOneMantissa(), mantissa_spacing), + BitChunks(0, 0, 1), BitChunks(0, 1, 1)); +} + +template +FpValues GetInfinites() { + uint64 all_one_exp = GetAllOneExponent(); + return GetFpValues(BitChunks(0, 0, 1), + BitChunks(all_one_exp, all_one_exp, 1), + BitChunks(0, 1, 1)); +} + +template +FpValues GetNans(int approx_num_values) { + int mantissa = GetMantissaTotalBits(); + uint64 mantissa_spacing = (1ull << mantissa) / (approx_num_values * 2); + uint64 all_one_exp = GetAllOneExponent(); + return GetFpValues( + BitChunks(0x1, GetAllOneMantissa(), mantissa_spacing), + BitChunks(all_one_exp, all_one_exp, 1), BitChunks(0, 1, 1)); +} + +template +FpValues GetNormals(int approx_num_values) { + float component_total = std::sqrt(static_cast(approx_num_values)); + return GetFpValues( + BitChunks(0x1, GetAllOneMantissa(), + (1ull << (GetMantissaTotalBits() + 1)) / component_total), + BitChunks(0x1, GetAllOneExponent() - 1, + (1ull << (GetExponentTotalBits() + 1)) / component_total), + BitChunks(0, 1, 1)); +} + +// Returns a vector of FpValues, which together represent about +// `approx_num_values` floating point values of type `T`, with each FpValues +// represents about `num_values_per_group` floating point values. 
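GetFpValues above assembles a bit pattern from mantissa, exponent, and sign chunks placed at offsets {0, mantissa_bits, total_bits - 1}. For binary32 that layout is 23 mantissa bits, 8 exponent bits, and 1 sign bit; the sketch below packs those fields by hand and hits the same special categories that GetZeros, GetSubnormals, GetInfinites, and GetNans target (PackF32 is a hypothetical helper):

// Sketch: assemble an IEEE-754 binary32 value from (sign, exponent, mantissa)
// fields, mirroring how FpValues shifts each BitChunks value to its offset.
#include <cstdint>
#include <cstdio>
#include <cstring>

constexpr int kMantissaBits = 23;  // std::numeric_limits<float>::digits - 1
constexpr int kExponentBits = 8;

float PackF32(std::uint32_t sign, std::uint32_t exponent, std::uint32_t mantissa) {
  std::uint32_t bits = (sign << (kMantissaBits + kExponentBits)) |
                       (exponent << kMantissaBits) | mantissa;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  std::printf("%g\n", PackF32(0, 0, 0));     // +0: zero exponent, zero mantissa
  std::printf("%g\n", PackF32(1, 0, 1));     // negative subnormal
  std::printf("%g\n", PackF32(0, 0xff, 0));  // +inf: all-one exponent
  std::printf("%g\n", PackF32(0, 0xff, 1));  // NaN: all-one exponent, nonzero mantissa
  std::printf("%g\n", PackF32(0, 127, 0));   // 1.0: biased exponent 127
}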
+template +std::vector GetFpValuesWithExponents(uint64 first_exponent, + uint64 exponent_spacing, + uint64 num_exponents, + uint64 approx_num_values, + uint64 num_values_per_group) { + const uint64 num_signs = 2; + uint64 approx_num_mantissa = approx_num_values / (num_exponents * num_signs); + uint64 num_mantissa_per_group = + num_values_per_group / (num_exponents * num_signs); + CHECK_GT(approx_num_mantissa, 0); + CHECK_GT(num_mantissa_per_group, 0); + + CHECK_LT(first_exponent + num_exponents - 1ull, GetAllOneExponent()); + int mantissa = GetMantissaTotalBits(); + uint64 mantissa_spacing = (1ull << mantissa) / approx_num_mantissa; + + std::vector result; + for (uint64 group_start = 0; group_start < GetAllOneMantissa(); + group_start += mantissa_spacing * num_mantissa_per_group) { + uint64 group_end = + group_start + (num_mantissa_per_group - 1) * mantissa_spacing; + if (group_end > GetAllOneMantissa()) { + group_end = GetAllOneMantissa(); + } + result.push_back(GetFpValues( + BitChunks(group_start, group_end, mantissa_spacing), + BitChunks(first_exponent, first_exponent + num_exponents - 1, 1), + BitChunks(0, 1, 1))); + } + return result; +} + +// Returns a vector of FpValues together represent about `approx_num_values` +// "very large" floating point values and `approx_num_values` "very small" +// floating point values of type `T`, which each FpValues represent about +// `num_values_per_group` floating point values. Because we use FpValues as +// a parameter for parameterized testing, the number of floating values +// represented by each FpValues affects the input size for each sub-test and +// the hence the peak memory usage of the test. +template +std::vector GetFpValuesForMagnitudeExtremeNormals( + uint64 approx_num_values = 40000, uint64 num_values_per_group = 4000) { + std::vector large = + GetFpValuesWithExponents(GetAllOneExponent() - 5, 1, 5, + approx_num_values / 2, num_values_per_group); + std::vector small = GetFpValuesWithExponents( + 1, 1, 5, approx_num_values / 2, num_values_per_group); + large.insert(large.end(), small.begin(), small.end()); + return large; +} + +template +std::vector CreateFpValuesForBoundaryTest() { + return {GetZeros(), GetSubnormals(1000), GetInfinites(), + GetNans(1000)}; +} + +inline std::vector> CreateExhaustiveF32Ranges() { + // We break up the 2^32-element space into small'ish chunks to keep peak + // memory usage low. + std::vector> result; + const int64 step = 1 << 25; + for (int64 i = 0; i < (1l << 32); i += step) { + result.push_back({i, i + step}); + } + return result; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_EXHAUSTIVE_OP_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc index 0186d7d668d..3a14bb2d4cc 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_unary_test.cc @@ -155,154 +155,8 @@ float HostDigamma(float x) { return result - reflection; } -class ExhaustiveRealUnaryTestBase : public ExhaustiveOpTestBase { - public: - explicit ExhaustiveRealUnaryTestBase(PrimitiveType ty) - : ExhaustiveOpTestBase(ty) {} - - // A helper for implementing the Run method for unary op test. It constructs - // the HLO module, compiles and runs the module and checks the result. - // - // T: is the input and output data type. - // RefT: is the type used for the host function to get the reference result. 
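CreateExhaustiveF32Ranges above splits the 2^32 possible bit patterns into chunks of 2^25 so each sub-test keeps peak memory low; each chunk is then expanded into float inputs, one bit pattern at a time. A sketch of expanding one such [begin, end) chunk (the tiny bounds in main are stand-ins for a real 2^25-wide range):

// Sketch: expand one [begin, end) chunk of 32-bit patterns into float inputs,
// the way an exhaustive F32 sub-test fills its input literal.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

std::vector<float> FillChunk(std::uint64_t begin, std::uint64_t end) {
  std::vector<float> inputs;
  inputs.reserve(end - begin);
  for (std::uint64_t i = begin; i < end; ++i) {
    std::uint32_t bits = static_cast<std::uint32_t>(i);
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    inputs.push_back(f);
  }
  return inputs;
}

int main() {
  // A tiny stand-in chunk; the real ranges use a step of 1 << 25.
  std::vector<float> chunk = FillChunk(0x3f800000, 0x3f800004);
  for (float f : chunk) std::printf("%.8g\n", f);  // 1.0 and the next few floats
}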
- // RefT is different from T when T is of less than 32 bits, that is half and - // bfloat16. - // - // We use a function pointer for evaluate_op for performance because it is - // called each time an output element is compared inside a loop in routine - // ExpectNear. - template - void RunImpl(std::function enqueue_op, - RefT (*evaluate_op)(RefT), const Literal& input_literal, - std::function error_spec_gen) { - XlaBuilder builder(TestName()); - XlaOp input = Parameter(&builder, 0, input_literal.shape(), "input"); - enqueue_op(input); - TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, - RunComputation(comp, {&input_literal})); - ExpectNear(input_literal, result_literal, evaluate_op, - error_spec_gen); - } - - // We essentially reimplement LiteralTestUtil::Near here because - // a) this streamlined implementation is much faster, and - // b) we can print out better error messages (namely, we can print out - // which floating-point value input failed, while LiteralTestUtil::Near - // can only print out the input index that failed). - // c) we need special handling of certain inputs. For example, we say that - // a denormal input has multiple correct outputs (namely, f(x) and f(0)) - // and just needs to be close to one of them. - template - void ExpectNear(const Literal& input_literal, const Literal& result_literal, - RefT (*evaluate_op)(RefT), - std::function error_spec_gen) { - absl::Span input_arr = input_literal.data(); - absl::Span result_arr = result_literal.data(); - ASSERT_EQ(result_arr.size(), input_arr.size()); - int64 mismatches = 0; - // Hoisting these out of the loop is a nice speedup on shards that have many - // denormals. - const T expected_at_pos_zero = static_cast(evaluate_op(0)); - const T expected_at_neg_zero = static_cast(evaluate_op(-0.0)); - const T expected_at_pos_min_normal_float = - static_cast(evaluate_op(std::numeric_limits::min())); - const T expected_at_neg_min_normal_float = - static_cast(evaluate_op(-std::numeric_limits::min())); - - for (int64 i = 0; i < input_arr.size(); ++i) { - T input = input_arr[i]; - RefT input_ref_ty = static_cast(input); - T actual = result_arr[i]; - T expected = static_cast(evaluate_op(input_ref_ty)); - - ErrorSpec error_spec = error_spec_gen(input_ref_ty); - - // We only implement fpclassify for float and double, so we call - // IsClose for half and bfloat16. - if (IsClose(static_cast(expected), static_cast(actual), - error_spec)) { - continue; - } - - // Easy case: If `input` is not denormal and !IsClose(expected, actual, - // error_spec), print an error. - if (std::fpclassify(input_ref_ty) != FP_SUBNORMAL) { - PrintMismatch(&mismatches, [&] { - return absl::StrFormat("Mismatch on %s. Expected %s, but got %s.", - StringifyNum(input), StringifyNum(expected), - StringifyNum(actual)); - }); - continue; - } - - // Otherwise, `input` is denormal. For denormal inputs, we accept answers - // that are close to any of: - // - // - evaluate_op(input) - // - evaluate_op(+/-0), where the sign of 0 equal to the sign of - // `input`, - // - evaluate_op(+/-min_normal_float), where the sign of - // min_normal_float matches `input`. - // - if relaxed_denormal_signs_, evaluate_op(-/+0), where the sign of - // 0 is the opposite of `input`. - // - // (In particular, the XLA:CPU implementation of log flushes positive - // denormals to min-normal-float. This seems kind of reasonable if our - // goal is to avoid infinities because they cause nans?) 
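In other words, for a denormal input x the checker accepts a result close to any of f(x), f at the same-signed zero, f at the same-signed min-normal float, or, when denormal signs are relaxed, f at the opposite-signed zero. A standalone sketch of that acceptance rule (the Close tolerance and the use of exp are illustrative only):

// Sketch: for a denormal input, accept the reference result evaluated at the
// original value, at the sign-preserving zero, at the sign-preserving
// min-normal float, and (optionally) at the opposite-signed zero. The
// tolerance and the choice of exp as the function are illustrative.
#include <cassert>
#include <cmath>
#include <limits>

bool Close(float a, float b, float tol = 1e-4f) {
  if (std::isnan(a) && std::isnan(b)) return true;
  if (std::isinf(a) && std::isinf(b)) return std::signbit(a) == std::signbit(b);
  return std::fabs(a - b) <=
         tol * std::fmax(1.0f, std::fmax(std::fabs(a), std::fabs(b)));
}

bool AcceptableForDenormal(float (*f)(float), float x, float actual,
                           bool relaxed_denormal_signs) {
  assert(std::fpclassify(x) == FP_SUBNORMAL);
  const float mn = std::numeric_limits<float>::min();  // min normal float
  const float same_zero = std::signbit(x) ? -0.0f : 0.0f;
  const float same_min = std::signbit(x) ? -mn : mn;
  const float other_zero = std::signbit(x) ? 0.0f : -0.0f;
  if (Close(f(x), actual) || Close(f(same_zero), actual) ||
      Close(f(same_min), actual)) {
    return true;
  }
  return relaxed_denormal_signs && Close(f(other_zero), actual);
}

int main() {
  float denormal = std::numeric_limits<float>::denorm_min();
  auto exp_fn = +[](float x) { return std::exp(x); };
  // A backend that flushes denormals to zero computes exp(0) == 1, which the
  // rule above accepts for the smallest positive denormal.
  assert(AcceptableForDenormal(exp_fn, denormal, 1.0f,
                               /*relaxed_denormal_signs=*/false));
}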
- T sign_preserving_ftz_expected = std::signbit(input_ref_ty) - ? expected_at_neg_zero - : expected_at_pos_zero; - T flush_to_normal_expected = std::signbit(input_ref_ty) - ? expected_at_neg_min_normal_float - : expected_at_pos_min_normal_float; - T sign_nonpreserving_ftz_expected = std::signbit(input_ref_ty) - ? expected_at_pos_zero - : expected_at_neg_zero; - if (IsClose(static_cast(sign_preserving_ftz_expected), - static_cast(actual), error_spec) || - IsClose(static_cast(flush_to_normal_expected), - static_cast(actual), error_spec) || - (relaxed_denormal_signs_ && - IsClose(static_cast(sign_nonpreserving_ftz_expected), - static_cast(actual), error_spec))) { - continue; - } - - if (relaxed_denormal_signs_) { - PrintMismatch(&mismatches, [&] { - return absl::StrFormat( - "Mismatch on denormal value %s. Expected one of:\n" - " %10s (evaluated at full-precision value)\n" - " %10s (evaluated at sign-preserving min-normal-float)\n" - " %10s (evaluated after flushing to sign-preserving zero)\n" - " %10s (evaluated after flushing to non-sign-preserving " - "zero)\n" - "but got %s.", - StringifyNum(input), // - StringifyNum(expected), StringifyNum(flush_to_normal_expected), - StringifyNum(sign_preserving_ftz_expected), - StringifyNum(sign_nonpreserving_ftz_expected), - StringifyNum(actual)); - }); - } else { - PrintMismatch(&mismatches, [&] { - return absl::StrFormat( - "Mismatch on denormal value %s. Expected one of:\n" - " %10s (evaluated at full-precision value)\n" - " %10s (evaluated at sign-preserving min-normal-float)\n" - " %10s (evaluated after flushing to sign-preserving zero)\n" - "but got %s.", - StringifyNum(input), // - StringifyNum(expected), StringifyNum(flush_to_normal_expected), - StringifyNum(sign_preserving_ftz_expected), // - StringifyNum(actual)); - }); - } - } - EXPECT_EQ(mismatches, 0); - } -}; +template +using ExhaustiveUnaryTest = ExhaustiveOpTestBase; // Exhaustive test for unary operations for <= 32bit floating point types. // @@ -310,53 +164,21 @@ class ExhaustiveRealUnaryTestBase : public ExhaustiveOpTestBase { // - primitive type under test, // - (begin, end) range under test, as zero-extended int64s bitcast to the // primtive type under test. +template class Exhaustive32BitOrLessUnaryTest - : public ExhaustiveRealUnaryTestBase, - public ::testing::WithParamInterface< - std::tuple>> { + : public ExhaustiveUnaryTest, + public ::testing::WithParamInterface> { public: - typedef float (*F32EvaluateOp)(float); - - Exhaustive32BitOrLessUnaryTest() - : ExhaustiveRealUnaryTestBase(std::get<0>(GetParam())) {} - - void Run(std::function enqueue_op, F32EvaluateOp evaluate_op) { - return Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(ty_)); - } - - void Run(std::function enqueue_op, F32EvaluateOp evaluate_op, - std::function error_spec_gen) { - SetFastMathDisabled(true); - - // Run all HLO passes. In particular, constant folding is disabled by - // default for tests, but we need to run it in order to tickle some bugs. 
- mutable_debug_options()->clear_xla_disable_hlo_passes(); - Literal input_literal = CreateInputLiteral(); - switch (ty_) { - case F32: - FillInput(&input_literal); - return RunImpl(enqueue_op, evaluate_op, input_literal, - error_spec_gen); - case F16: - FillInput(&input_literal); - return RunImpl(enqueue_op, evaluate_op, input_literal, - error_spec_gen); - case BF16: - FillInput(&input_literal); - return RunImpl(enqueue_op, evaluate_op, input_literal, - error_spec_gen); - default: - LOG(FATAL) << "Unhandled type."; - } - } - // Sets error parameters appropriately for testing sin/cos/tan. void SetParamsForSinCosTan(); + protected: + using typename ExhaustiveUnaryTest::NativeT; + private: int64 GetInputSize() override { int64 begin, end; - std::tie(begin, end) = std::get<1>(GetParam()); + std::tie(begin, end) = GetParam(); VLOG(2) << "Checking range [" << begin << ", " << end << ")"; return end - begin; } @@ -367,54 +189,64 @@ class Exhaustive32BitOrLessUnaryTest // pattern. Each bit representation is first truncated to the integral type of // the same bit as the type being tested, if needed, and then bitcasted to the // type being tested. - template - void FillInput(Literal* input_literal) { - using IntegralT = typename IntegralTypeWithByteWidth::type; - int64 input_size = input_literal->element_count(); + void FillInput(std::array* input_literal) override { + using IntegralT = + typename ExhaustiveOpTestBase::ComponentIntegralNativeT; + int64 input_size = (*input_literal)[0].element_count(); int64 begin, end; - std::tie(begin, end) = std::get<1>(GetParam()); + std::tie(begin, end) = GetParam(); VLOG(2) << "Checking range [" << begin << ", " << end << ")"; CHECK_EQ(input_size, end - begin); - absl::Span input_arr = input_literal->data(); + absl::Span input_arr = (*input_literal)[0].data(); for (int64 i = 0; i < input_size; i++) { IntegralT input_val = i + begin; - input_arr[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); + input_arr[i] = + this->ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); } } }; -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Log) { - auto error_spec_gen = GetDefaultSpecGenerator(ty_); - if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { - error_spec_gen = [](float x) { return ErrorSpec{0.001, 0.001}; }; - } +typedef Exhaustive32BitOrLessUnaryTest ExhaustiveF32UnaryTest; +typedef Exhaustive32BitOrLessUnaryTest ExhaustiveF16UnaryTest; +typedef Exhaustive32BitOrLessUnaryTest ExhaustiveBF16UnaryTest; +#define XLA_TEST_FLOAT_32_BITS_OR_LESS(test_name, ...) 
\ + XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \ + __VA_ARGS__ \ + XLA_TEST_P(ExhaustiveF16UnaryTest, test_name) \ + __VA_ARGS__ \ + XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \ + __VA_ARGS__ + +XLA_TEST_FLOAT_32_BITS_OR_LESS(Log, { + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); + if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { + error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; }; + } Run(Log, std::log, error_spec_gen); -} +}) -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Log1p) { - auto error_spec_gen = GetDefaultSpecGenerator(ty_); +XLA_TEST_FLOAT_32_BITS_OR_LESS(Log1p, { + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { - error_spec_gen = [](float x) { return ErrorSpec{0.001, 0.001}; }; + error_spec_gen = +[](NativeT x) { return ErrorSpec{0.001, 0.001}; }; } - Run(Log1p, std::log1p, error_spec_gen); -} +}) -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Exp) { +XLA_TEST_FLOAT_32_BITS_OR_LESS(Exp, { // When x < -105, the true value of exp(x) is smaller than the smallest F32, // so exp(x) should return exactly 0. We want our implementation of exp to // return exactly 0 as well, as not doing so implies either that our // implementation of exp is not following the asymptotic behavior that exp(x) // approaches 0 as x approaches -inf, or that our implementation is not // approaching 0 fast enough. - auto default_spec_gen = GetDefaultSpecGenerator(ty_); - auto error_spec_gen = [default_spec_gen](float x) { - if (x < -105) { + ErrorSpecGen error_spec_gen = +[](NativeT x) { + if (x < static_cast(-105)) { return ErrorSpec{0, 0}; } - return default_spec_gen(x); + return GetDefaultSpecGenerator()(x); }; // Our CPU implementation of exp returns one incorrect value: says @@ -432,20 +264,13 @@ XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Exp) { } else { Run(Exp, std::exp, error_spec_gen); } -} +}) -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Expm1) { - auto default_spec_gen = GetDefaultSpecGenerator(ty_); - auto error_spec_gen = [default_spec_gen](float x) { - if (x < -105) { - return ErrorSpec{0, 0}; - } else if (std::abs(x) < 5e-6) { - // For points around x=0, we should make sure that the result is accurate - // within 1 ULP of the value. - return ErrorSpec{0, 1.1921e-7}; - } - return default_spec_gen(x); - }; +XLA_TEST_FLOAT_32_BITS_OR_LESS(Expm1, { + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); + if (ty_ == F32) { + error_spec_gen = +[](NativeT x) { return ErrorSpec{0, 0.00015}; }; + } // Our CPU implementation of expm1 returns one incorrect value: says // exp(88.7228394) = max-float, but the correct answer is inf. We deem this @@ -462,65 +287,73 @@ XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Expm1) { } else { Run(Expm1, std::expm1, error_spec_gen); } -} +}) // It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but // this *did* find a bug, namely that some backends were assuming sqrt(x) == // pow(x, 0.5), but this is not true for x == -inf. -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, PowOneHalf) { - Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, - +[](float x) { return std::pow(x, 0.5f); }); -} +XLA_TEST_FLOAT_32_BITS_OR_LESS(PowOneHalf, { + EvaluateOp fn = +[](float x) { return std::pow(x, 0.5f); }; + // TODO(b/123837116): Enable the test for all values after fixing the bug. 
+ if (platform_ != "Host" && platform_ != "CUDA") { + fn = +[](float x) { + if (x == -std::numeric_limits::infinity()) { + return std::nanf(""); + } + return std::pow(x, 0.5f); + }; + } + Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, fn); +}) -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Rsqrt) { +XLA_TEST_FLOAT_32_BITS_OR_LESS(Rsqrt, { Run( Rsqrt, +[](float x) { return 1 / std::sqrt(x); }); -} +}) -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Sqrt) { - auto default_spec_gen = GetDefaultSpecGenerator(ty_); - std::function error_spec_gen; +XLA_TEST_FLOAT_32_BITS_OR_LESS(Sqrt, { + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ == "Host" || platform_ == "CUDA") { - error_spec_gen = [default_spec_gen](float x) { - ErrorSpec spec = default_spec_gen(x); + error_spec_gen = +[](NativeT x) { + auto spec = GetDefaultSpecGenerator()(x); spec.strict_signed_zeros = true; return spec; }; - } else { - error_spec_gen = default_spec_gen; } Run(Sqrt, std::sqrt, error_spec_gen); -} +}) // TODO(jlebar): Test trig functions over complex inputs. - -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Acosh) { +XLA_TEST_P(ExhaustiveF32UnaryTest, Acosh) { // Error inherited from Log, which our implementation of Acosh uses. - std::function error_spec_gen; - if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { - error_spec_gen = [](float x) { return ErrorSpec{0.001, 0.001}; }; - } else { - error_spec_gen = GetDefaultSpecGenerator(ty_); + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); + if (platform_ != "Host" && platform_ != "CUDA") { + error_spec_gen = +[](float x) { return ErrorSpec{0.001, 0.001}; }; } Run(Acosh, std::acosh, error_spec_gen); } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Asinh) { - // Error inherited from Log, which our implementation of Asinh uses. - std::function error_spec_gen; - if (platform_ != "Host" && platform_ != "CUDA" && ty_ == F32) { - error_spec_gen = [](float x) { return ErrorSpec{0.001, 0.001}; }; - } else { - error_spec_gen = GetDefaultSpecGenerator(ty_); +XLA_TEST_P(ExhaustiveF16UnaryTest, Acosh) { Run(Acosh, std::acosh); } +XLA_TEST_P(ExhaustiveBF16UnaryTest, Acosh) { Run(Acosh, std::acosh); } + +// Tests for Asinh +XLA_TEST_P(ExhaustiveF32UnaryTest, Asinh) { + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); + if (platform_ != "Host" && platform_ != "CUDA") { + error_spec_gen = +[](float x) { return ErrorSpec{0.001, 0.001}; }; } + Run(Asinh, std::asinh, error_spec_gen); } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Atanh) { Run(Atanh, std::atanh); } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Acos) { Run(Acos, std::acos); } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Asin) { Run(Asin, std::asin); } +XLA_TEST_P(ExhaustiveF16UnaryTest, Asinh) { Run(Asinh, std::asinh); } +XLA_TEST_P(ExhaustiveBF16UnaryTest, Asinh) { Run(Asinh, std::asinh); } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Cosh) { +XLA_TEST_FLOAT_32_BITS_OR_LESS(Atanh, { Run(Atanh, std::atanh); }) +XLA_TEST_FLOAT_32_BITS_OR_LESS(Acos, { Run(Acos, std::acos); }) +XLA_TEST_FLOAT_32_BITS_OR_LESS(Asin, { Run(Asin, std::asin); }) + +XLA_TEST_FLOAT_32_BITS_OR_LESS(Cosh, { // Our cosh implementation incorrectly overflows to inf for +/-89.4159851. // The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to // max-float, so we deem this acceptable. 
@@ -539,8 +372,9 @@ XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Cosh) { }; } Run(Cosh, host_cosh); -} -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Sinh) { +}) + +XLA_TEST_FLOAT_32_BITS_OR_LESS(Sinh, { // Our sinh implementation incorrectly overflows to +/-inf for +/-89.4159851. // The correct answer of 3.40281961e+38 (0x7f7fffec) is very close to // max-float, so we deem this acceptable. @@ -559,76 +393,103 @@ XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Sinh) { }; } Run(Sinh, host_sinh); -} -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Tanh) { Run(Tanh, std::tanh); } +}) -void Exhaustive32BitOrLessUnaryTest::SetParamsForSinCosTan() { - if (platform_ == "Host" || platform_ == "CUDA") { +XLA_TEST_FLOAT_32_BITS_OR_LESS(Tanh, { + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); + if (platform_ == "CUDA") { + error_spec_gen = +[](NativeT x) { + return x <= static_cast(-20.0) || x >= static_cast(20.0) + ? ErrorSpec{0, 0} + : GetDefaultSpecGenerator()(x); + }; + } + Run(Tanh, std::tanh, error_spec_gen); +}) + +template +void Exhaustive32BitOrLessUnaryTest::SetParamsForSinCosTan() { + if (this->platform_ == "Host" || this->platform_ == "CUDA") { return; } // Non CPU/GPU targets may have used the Cody-Waite range reduction technique // and will not provide meaningful results for sin/cos/tan if magnitudes // exceed 2**p. - if (ty_ == F32) { - known_incorrect_fn_ = [](int64 v) { + if (T == F32) { + this->known_incorrect_fn_ = [](int64 v) { float f = BitCast(static_cast(v)); return std::abs(f) > (1 << 13); }; - } else if (ty_ == BF16) { - known_incorrect_fn_ = [](int64 v) { + } else if (T == BF16) { + this->known_incorrect_fn_ = [](int64 v) { float f = static_cast(BitCast(static_cast(v))); return std::abs(f) > (1 << 13); }; } } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Cos) { +XLA_TEST_P(ExhaustiveF32UnaryTest, Cos) { SetParamsForSinCosTan(); - std::function error_spec_gen; - if (ty_ == F32) { - error_spec_gen = [](float) { return ErrorSpec{0.001, 0.001}; }; - } else { - error_spec_gen = GetDefaultSpecGenerator(ty_); - } - Run(Cos, std::cos, error_spec_gen); + Run( + Cos, std::cos, +[](NativeT) { + return ErrorSpec{0.001, 0.001}; + }); } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Sin) { +XLA_TEST_P(ExhaustiveF16UnaryTest, Cos) { SetParamsForSinCosTan(); - std::function error_spec_gen; - if (ty_ == F32) { - error_spec_gen = [](float) { return ErrorSpec{0.001, 0.001}; }; - } else { - error_spec_gen = GetDefaultSpecGenerator(ty_); - } - Run(Sin, std::sin, error_spec_gen); + Run(Cos, std::cos); } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Tan) { +XLA_TEST_P(ExhaustiveBF16UnaryTest, Cos) { SetParamsForSinCosTan(); - std::function error_spec_gen; - if (ty_ == F32) { - error_spec_gen = [](float) { return ErrorSpec{0.001, 0.001}; }; - } else { - error_spec_gen = GetDefaultSpecGenerator(ty_); - } - Run(Tan, std::tan, error_spec_gen); + Run(Cos, std::cos); +} + +XLA_TEST_P(ExhaustiveF32UnaryTest, Sin) { + SetParamsForSinCosTan(); + Run( + Sin, std::sin, +[](NativeT) { + return ErrorSpec{0.001, 0.001}; + }); +} +XLA_TEST_P(ExhaustiveF16UnaryTest, Sin) { + SetParamsForSinCosTan(); + Run(Sin, std::sin); +} +XLA_TEST_P(ExhaustiveBF16UnaryTest, Sin) { + SetParamsForSinCosTan(); + Run(Sin, std::sin); +} + +XLA_TEST_P(ExhaustiveF32UnaryTest, Tan) { + SetParamsForSinCosTan(); + Run( + Tan, std::tan, +[](NativeT) { + return ErrorSpec{0.001, 0.001}; + }); +} +XLA_TEST_P(ExhaustiveF16UnaryTest, Tan) { + SetParamsForSinCosTan(); + Run(Tan, std::tan); +} +XLA_TEST_P(ExhaustiveBF16UnaryTest, Tan) { + 
SetParamsForSinCosTan(); + Run(Tan, std::tan); } // TODO(jlebar): Enable these. -// XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Atan) { Run(Atan, std::atan); } -// XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Atan2) { Run(Atan2, std::atan2); } +// XLA_TEST_FLOAT_32_BITS_OR_LESS(Atan) { Run(Atan, std::atan); } +// XLA_TEST_FLOAT_32_BITS_OR_LESS(Atan2) { Run(Atan2, std::atan2); } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Erf) { Run(Erf, std::erf); } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Erfc) { Run(Erfc, std::erfc); } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, ErfInv) { Run(ErfInv, HostErfInv); } -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Digamma) { - std::function error_spec_gen; +XLA_TEST_FLOAT_32_BITS_OR_LESS(Erf, { Run(Erf, std::erf); }) +XLA_TEST_FLOAT_32_BITS_OR_LESS(Erfc, { Run(Erfc, std::erfc); }) +XLA_TEST_FLOAT_32_BITS_OR_LESS(ErfInv, { Run(ErfInv, HostErfInv); }) +XLA_TEST_FLOAT_32_BITS_OR_LESS(Digamma, { + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ != "Host" && platform_ != "CUDA") { // TODO(b/123956399): This is a fairly high error, significantly higher than // we see on CPU/GPU. - error_spec_gen = [](float) { return ErrorSpec{0.01, 0.01}; }; - } else { - error_spec_gen = GetDefaultSpecGenerator(ty_); + error_spec_gen = +[](NativeT) { return ErrorSpec{0.01, 0.01}; }; } if (platform_ == "CUDA") { @@ -651,27 +512,25 @@ XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Digamma) { } else { Run(Digamma, HostDigamma, error_spec_gen); } -} -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Lgamma) { +}) + +XLA_TEST_FLOAT_32_BITS_OR_LESS(Lgamma, { // Our implementation gets within 0.0001 rel error except for ~20 denormal // inputs on GPU. Anyway 0.001 rel error should be good enough for lgamma. - auto default_spec_gen = GetDefaultSpecGenerator(ty_); - std::function error_spec_gen; + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); if (platform_ == "CUDA" && (ty_ == F32 || ty_ == F16)) { - error_spec_gen = [default_spec_gen](float x) { - ErrorSpec spec = default_spec_gen(x); + error_spec_gen = +[](NativeT x) { + auto spec = GetDefaultSpecGenerator()(x); spec.rel_err = 0.001; return spec; }; - } else { - error_spec_gen = default_spec_gen; } float (*host_lgamma)(float) = std::lgamma; if (platform_ != "Host" && platform_ != "CUDA") { // TODO(b/123956399): This is a fairly high error, significantly higher than // we see on CPU/GPU. - error_spec_gen = [](float) { return ErrorSpec{0.01, 0.01}; }; + error_spec_gen = +[](NativeT) { return ErrorSpec{0.01, 0.01}; }; // Overflows to inf for input 4.08500343e+36 (0x7c44af8e). 
if (ty_ == F32) { @@ -684,28 +543,362 @@ XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Lgamma) { } } Run(Lgamma, host_lgamma, error_spec_gen); -} +}) -XLA_TEST_P(Exhaustive32BitOrLessUnaryTest, Round) { Run(Round, std::round); } +XLA_TEST_FLOAT_32_BITS_OR_LESS(Round, { Run(Round, std::round); }) -INSTANTIATE_TEST_SUITE_P( - F32, Exhaustive32BitOrLessUnaryTest, - ::testing::Combine(::testing::Values(F32), - ::testing::ValuesIn( - ExhaustiveOpTestBase::CreateExhaustiveF32Ranges()))); +#if defined(UNARY_TEST_TARGET_F32_OR_SMALLER) + +INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest, + ::testing::ValuesIn(CreateExhaustiveF32Ranges())); #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -INSTANTIATE_TEST_SUITE_P( - F16, Exhaustive32BitOrLessUnaryTest, - ::testing::Combine(::testing::Values(F16), - ::testing::Values(std::make_pair(0, 1 << 16)))); +INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16UnaryTest, + ::testing::Values(std::make_pair(0, 1 << 16))); #endif #if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) +INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest, + ::testing::Values(std::make_pair(0, 1 << 16))); +#endif + +#endif + +// Exhaustive test for unary operations for double. +// +// Test parameter is a tuple containing +// - primitive type under test, +// - FpValues representing a set of double values. + +class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, + public ::testing::WithParamInterface { + private: + int64 GetInputSize() override { + FpValues values = GetParam(); + return values.GetTotalNumValues(); + } + + void FillInput(std::array* input_literal) override { + FpValues fp_values = GetParam(); + int64 input_size = (*input_literal)[0].element_count(); + LOG(INFO) << "Checking fp values " << fp_values.ToString() << ", " + << input_size; + absl::Span input_arr = (*input_literal)[0].data(); + + uint64 i = 0; + for (auto bits : fp_values) { + input_arr[i] = this->ConvertAndReplaceKnownIncorrectValueWith(bits, 1); + ++i; + } + CHECK_EQ(i, input_size); + } +}; + +XLA_TEST_P(ExhaustiveF64UnaryTest, Log) { Run(Log, std::log); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Log1p) { Run(Log1p, std::log1p); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Exp) { Run(Exp, std::exp); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Expm1) { Run(Expm1, std::expm1); } + +// TODO(b/138385863): Turn on the test for GPU after fixing the bug. +XLA_TEST_P(ExhaustiveF64UnaryTest, DISABLED_ON_GPU(PowOneHalf)) { + Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, + +[](double x) { return std::pow(x, 0.5); }); +} + +XLA_TEST_P(ExhaustiveF64UnaryTest, Rsqrt) { + Run( + Rsqrt, +[](double x) { return 1 / std::sqrt(x); }); +} + +XLA_TEST_P(ExhaustiveF64UnaryTest, Sqrt) { Run(Sqrt, std::sqrt); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Acosh) { Run(Acosh, std::acosh); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Asinh) { Run(Asinh, std::asinh); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Atanh) { Run(Atanh, std::atanh); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Acos) { Run(Acos, std::acos); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Asin) { Run(Asin, std::asin); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Cosh) { Run(Cosh, std::cosh); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Sinh) { Run(Sinh, std::sinh); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Tanh) { + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); + if (platform_ == "CUDA") { + error_spec_gen = +[](NativeT x) { + return x <= static_cast(-20.0) || x >= static_cast(20.0) + ? 
ErrorSpec{0, 0} + : GetDefaultSpecGenerator()(x); + }; + } + Run(Tanh, std::tanh, error_spec_gen); +} + +XLA_TEST_P(ExhaustiveF64UnaryTest, Cos) { Run(Cos, std::cos); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Sin) { Run(Sin, std::sin); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Tan) { Run(Tan, std::tan); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Round) { Run(Round, std::round); } + +XLA_TEST_P(ExhaustiveF64UnaryTest, Erf) { + Run(Erf, std::erf, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; }); +} + +XLA_TEST_P(ExhaustiveF64UnaryTest, Erfc) { + Run(Erfc, std::erfc, [](NativeT x) { return ErrorSpec{1e-20, 1e-20}; }); +} + +#if defined(UNARY_TEST_TARGET_F64) +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) INSTANTIATE_TEST_SUITE_P( - BF16, Exhaustive32BitOrLessUnaryTest, - ::testing::Combine(::testing::Values(BF16), - ::testing::Values(std::make_pair(0, 1 << 16)))); + SpecialValues, ExhaustiveF64UnaryTest, + ::testing::ValuesIn(CreateFpValuesForBoundaryTest())); + +INSTANTIATE_TEST_SUITE_P(NormalValues, ExhaustiveF64UnaryTest, + ::testing::Values(GetNormals(1000))); + +// Tests a total of 4000000000 inputs, with 16000000 inputs in each sub-test, to +// keep the peak memory usage low. +INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnituedNormalValues, ExhaustiveF64UnaryTest, + ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals( + 4000000000ull, 16000000))); +#endif +#endif + +// T is the Primitive Type of the complex number +// Test parameter is a tuple containing +// - primitive type under test, +// - two FpValues representing the values for the real and imaginary +// components. The complex numbers for the test input is the cartesian +// product of the values represented by the two FpValues. +template +class ExhaustiveComplexUnaryTestBase + : public ExhaustiveUnaryTest, + public ::testing::WithParamInterface> { + protected: + using typename ExhaustiveUnaryTest::NativeT; + + void SetParamsForTanh() { + // TODO(b/138126045): Current libc++ implementation of the complex tanh + // function returns (NaN, NaN) when the imaginary + // component is more than half of the max value. + // TODO(b/138750327): Current libc++ implementation of the complex tanh + // function returns (1, 0) when the real component is + // negative infinity, when it should return (-1, 0). + // We only need to set the former as incorrect values for C128 because when + // testing with C64, we first cast our input to a C128 value. + this->known_incorrect_fn_ = [&](int64 v) { + double f = this->ConvertValue(v); + return (T == C128 && + std::abs(f) > std::numeric_limits::max() / 2) || + f == -std::numeric_limits::infinity(); + }; + } + + private: + // Generates the input complex literal given the FpValues representation for + // the real and imaginary components. 
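As the comment above notes, the complex test inputs are the Cartesian product of the candidate real parts and candidate imaginary parts. A trivial standalone sketch of that pairing (the component values are arbitrary):

// Sketch: build complex test inputs as the Cartesian product of candidate
// real parts and candidate imaginary parts.
#include <complex>
#include <cstdio>
#include <vector>

std::vector<std::complex<float>> ComplexProduct(
    const std::vector<float>& reals, const std::vector<float>& imags) {
  std::vector<std::complex<float>> inputs;
  inputs.reserve(reals.size() * imags.size());
  for (float re : reals) {
    for (float im : imags) {
      inputs.emplace_back(re, im);
    }
  }
  return inputs;
}

int main() {
  std::vector<float> reals = {0.0f, 1.0f};
  std::vector<float> imags = {-1.0f, 2.0f};
  for (const auto& z : ComplexProduct(reals, imags)) {
    std::printf("(%g, %g)\n", z.real(), z.imag());
  }
}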
+ void FillInput(std::array* input_literal) override { + FpValues real_values = std::get<0>(GetParam()); + FpValues imag_values = std::get<1>(GetParam()); + + VLOG(2) << " testing input total " + << real_values.GetTotalNumValues() * imag_values.GetTotalNumValues() + << ", range " << real_values.ToString() << " " + << imag_values.ToString(); + + absl::Span input_arr = (*input_literal)[0].data(); + + uint64 i = 0; + for (auto real : real_values) { + for (auto imag : imag_values) { + input_arr[i] = + NativeT(this->ConvertAndReplaceKnownIncorrectValueWith(real, 1), + this->ConvertAndReplaceKnownIncorrectValueWith(imag, 1)); + + ++i; + } + } + } + + int64 GetInputSize() override { + FpValues real_values = std::get<0>(GetParam()); + FpValues imag_values = std::get<1>(GetParam()); + return real_values.GetTotalNumValues() * imag_values.GetTotalNumValues(); + } +}; + +typedef ExhaustiveComplexUnaryTestBase ExhaustiveC64UnaryTest; +typedef ExhaustiveComplexUnaryTestBase ExhaustiveC128UnaryTest; + +// TODO(b/138578594): Enable the test for the CPU backend after fixing the bug. +XLA_TEST_P(ExhaustiveC64UnaryTest, DISABLED_ON_CPU(Log)) { + Run(Log, [](complex64 x) { return std::log(x); }); +} + +XLA_TEST_P(ExhaustiveC64UnaryTest, Sqrt) { + Run(Sqrt, [](complex64 x) { + return static_cast( + std::sqrt(static_cast(x))); + }); +} + +XLA_TEST_P(ExhaustiveC64UnaryTest, Rsqrt) { + Run(Rsqrt, [](complex64 x) { + return static_cast( + complex128(1, 0) / std::sqrt(static_cast(x))); + }); +} + +// The current libc++ implementation of the complex tanh function provides +// less accurate results when the denomenator of a complex tanh is small, due +// to floating point precision loss. To avoid this issue for complex64 numbers, +// we cast it to and from a complex128 when computing tanh. +XLA_TEST_P(ExhaustiveC64UnaryTest, Tanh) { + SetParamsForTanh(); + ErrorSpecGen error_spec_gen = +[](complex64 x) { + // This implementation of Tanh becomes less accurate when the denominator + // is small. + if (std::cosh(2 * x.real()) + std::cos(2 * x.imag()) < 1e-4) { + return ErrorSpec{5e-2, 5e-2}; + } + + return GetDefaultSpecGenerator()(x); + }; + Run( + Tanh, + +[](complex64 x) { + return static_cast(std::tanh(static_cast(x))); + }, + error_spec_gen); +} + +#if defined(UNARY_TEST_TARGET_COMPLEX) +INSTANTIATE_TEST_SUITE_P( + F32SpecialValues, ExhaustiveC64UnaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + F32SpecialAndNormalValues, ExhaustiveC64UnaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::Values(GetNormals(10000)))); + +INSTANTIATE_TEST_SUITE_P( + F32NormalAndSpecialValues, ExhaustiveC64UnaryTest, + ::testing::Combine( + ::testing::Values(GetNormals(10000)), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + F32NormalAndNormalValues, ExhaustiveC64UnaryTest, + ::testing::Combine(::testing::Values(GetNormals(10000)), + ::testing::Values(GetNormals(10000)))); + +// Tests a total of 40000 ^ 2 inputs, with 4000 ^ 2 inputs in each sub-test, to +// keep the peak memory usage low. 
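The C64 Tanh error spec above keys off cosh(2*Re(x)) + cos(2*Im(x)), which is the denominator in the identity tanh(a+bi) = (sinh(2a) + i*sin(2b)) / (cosh(2a) + cos(2b)); when it is near zero the implementation loses precision, so the tolerance is loosened. A sketch of that check (the 1e-4 cutoff and 5e-2 tolerance come from the test above; the 1e-6 default is a stand-in):

// Sketch: widen the error tolerance for complex tanh when the denominator of
// tanh(a+bi) = (sinh(2a) + i*sin(2b)) / (cosh(2a) + cos(2b)) is tiny, which is
// where the implementation loses precision.
#include <cmath>
#include <complex>
#include <cstdio>

struct Tol {
  double abs_err;
  double rel_err;
};

Tol TanhTolerance(std::complex<float> x) {
  double denom = std::cosh(2.0 * x.real()) + std::cos(2.0 * x.imag());
  if (denom < 1e-4) {
    return {5e-2, 5e-2};  // loosened tolerance near the small denominator
  }
  return {1e-6, 1e-6};  // stand-in for the default spec
}

int main() {
  // Near a + bi = 0 + (pi/2)i the denominator cosh(0) + cos(pi) is ~0.
  std::complex<float> near_pole(0.0f, 1.5707964f);
  Tol t = TanhTolerance(near_pole);
  std::printf("abs=%g rel=%g\n", t.abs_err, t.rel_err);
}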
+INSTANTIATE_TEST_SUITE_P( + F32LargeAndSmallMagnituedNormalValues, ExhaustiveC64UnaryTest, + ::testing::Combine( + ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals(40000, + 4000)), + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 4000)))); +#endif + + +XLA_TEST_P(ExhaustiveC128UnaryTest, Log) { + // TODO(b/138578313): Enable the test for all values after fixing the bug. + known_incorrect_fn_ = [&](int64 v) { + double f = this->ConvertValue(v); + return std::fpclassify(f) == FP_NAN || std::abs(f) > 1.0e+300 || + std::abs(f) < 1.0e-300; + }; + Run(Log, [](complex128 x) { return std::log(x); }); +} + +XLA_TEST_P(ExhaustiveC128UnaryTest, Sqrt) { + // Similar to the Tanh bug. + known_incorrect_fn_ = [&](int64 v) { + double f = this->ConvertValue(v); + return std::abs(f) > std::numeric_limits::max() / 2; + }; + Run(Sqrt, [](complex128 x) { return std::sqrt(x); }); +} + +XLA_TEST_P(ExhaustiveC128UnaryTest, Rsqrt) { + ErrorSpecGen error_spec_gen = GetDefaultSpecGenerator(); + if (platform_ == "CUDA") { + // Edge case on CUDA backend where the Log of a complex number made up of + // the smallest denormals is more accurate than the interpreter backend. + error_spec_gen = [](complex128 x) { + constexpr double denorm_min = std::numeric_limits::denorm_min(); + if (std::abs(x.real()) == denorm_min && + std::abs(x.imag()) == denorm_min) { + return ErrorSpec(0.5, 0.5); + } + return GetDefaultSpecGenerator()(x); + }; + } + Run( + Rsqrt, + [](complex128 x) { return complex128(1, 0) / std::sqrt(x); }, + error_spec_gen); +} + +XLA_TEST_P(ExhaustiveC128UnaryTest, Tanh) { + SetParamsForTanh(); + Run( + Tanh, +[](complex128 x) { return std::tanh(x); }); +} + +#if defined(UNARY_TEST_TARGET_COMPLEX) +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +INSTANTIATE_TEST_SUITE_P( + SpecialValues, ExhaustiveC128UnaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + SpecialAndNormalValues, ExhaustiveC128UnaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::Values(GetNormals(10000)))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndSpecialValues, ExhaustiveC128UnaryTest, + ::testing::Combine( + ::testing::Values(GetNormals(10000)), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + F32NormalAndNormalValues, ExhaustiveC128UnaryTest, + ::testing::Combine(::testing::Values(GetNormals(10000)), + ::testing::Values(GetNormals(10000)))); + +// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test, to +// keep the peak memory usage low. 
+INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnituedNormalValues, ExhaustiveC128UnaryTest, + ::testing::Combine( + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)), + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); +#endif #endif } // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 701dac3902b..8df4a57afcd 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -84,6 +84,29 @@ void LlvmIrGenTestBase::CompileAheadOfTimeAndVerifyIr( EXPECT_TRUE(filecheck_result.ValueOrDie()); } +void LlvmIrGenTestBase::MatchOptimizedHlo(absl::string_view hlo, + absl::string_view pattern, + bool print_operand_shape) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + GetOptimizedModule(hlo)); + HloPrintOptions print_opts; + print_opts.set_print_operand_shape(print_operand_shape); + StatusOr filecheck_result = + RunFileCheck(optimized_module->ToString(print_opts), pattern); + TF_ASSERT_OK(filecheck_result.status()); + EXPECT_TRUE(filecheck_result.ValueOrDie()); +} + +StatusOr> LlvmIrGenTestBase::GetOptimizedModule( + absl::string_view hlo) { + HloModuleConfig config; + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo, config)); + return backend().compiler()->RunHloPasses( + std::move(module), backend().default_stream_executor(), + backend().default_stream_executor()->GetAllocator()); +} + LLVMCompiler* LlvmIrGenTestBase::GetLLVMCompiler() { return static_cast(backend().compiler()); } diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h index 018f9546afc..ff69787c273 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.h @@ -58,6 +58,21 @@ class LlvmIrGenTestBase : public CodegenTestBase { const string& pattern, bool match_optimized_ir); + // Compiles the given `hlo` with optimizations, and verifies that optimized + // HLO matches the given FileCheck pattern. + void MatchOptimizedHlo(absl::string_view hlo, absl::string_view pattern, + bool print_operand_shape = false); + + // LikeMatchOptimizedHlo, but checks operand shapes as well. + void MatchOptimizedHloWithShapes(absl::string_view hlo, + absl::string_view pattern) { + MatchOptimizedHlo(hlo, pattern, /*print_operand_shape=*/true); + } + + // Compiles and returns module with optimizations from a given HLO. 
+ StatusOr> GetOptimizedModule( + absl::string_view hlo); + private: LLVMCompiler* GetLLVMCompiler(); diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index c5e1dbe7432..ff8adb0c460 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -142,6 +142,15 @@ XLA_TEST_P(ReduceWindowTest, Min3In5Stride2) { {}, ErrorSpec(0.00001)); } +XLA_TEST_P(ReduceWindowTest, Min3In5Stride2Same) { + const auto input = CreateConstantFromLiteral( + LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + ReduceWindowMin(input, {3}, {2}, Padding::kSame); + ComputeAndCompareLiteral(&builder_, + LiteralUtil::CreateR1({1000, 10, 1}), {}, + ErrorSpec(0.00001)); +} + XLA_TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { const auto input = CreateConstantFromLiteral( LiteralUtil::CreateR1({10000, 1000, 100, 10, 1}), &builder_); diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 9636df2ff5f..c9c2cb7630b 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -36,6 +36,7 @@ limitations under the License. #define DISABLED_ON_CPU(X) X #define DISABLED_ON_GPU(X) X +#define DISABLED_ON_GPU_ROCM(X) X #define DISABLED_ON_INTERPRETER(X) X // We need this macro instead of pasting directly to support nesting @@ -54,6 +55,12 @@ limitations under the License. #ifdef XLA_TEST_BACKEND_GPU # undef DISABLED_ON_GPU # define DISABLED_ON_GPU(X) XLA_TEST_PASTE(DISABLED_, X) + +#if TENSORFLOW_USE_ROCM +# undef DISABLED_ON_GPU_ROCM +# define DISABLED_ON_GPU_ROCM(X) XLA_TEST_PASTE(DISABLED_, X) +#endif // TENSORFLOW_USE_ROCM + #endif // XLA_TEST_BACKEND_GPU #ifdef XLA_TEST_BACKEND_INTERPRETER diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index c3618eb20fa..4563d7e0df2 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/tests/test_utils.h" + #include #include "absl/base/casts.h" @@ -21,9 +23,9 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" -#include "tensorflow/compiler/xla/tests/test_utils.h" namespace xla { @@ -349,13 +351,14 @@ void PopulateWithRandomIntegralDataWithBounds(Literal* literal, // range [min, max]. Currently this works only for INT types. 
StatusOr MakeFakeLiteralInternalWithBounds(const Shape& shape, std::minstd_rand0* engine, - int64 min, int64 max) { + int64 min, int64 max, + bool is_sorted) { if (shape.IsTuple()) { std::vector elements; for (const Shape& element_shape : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN( - Literal element, - MakeFakeLiteralInternalWithBounds(element_shape, engine, min, max)); + TF_ASSIGN_OR_RETURN(Literal element, + MakeFakeLiteralInternalWithBounds( + element_shape, engine, min, max, is_sorted)); elements.push_back(std::move(element)); } return LiteralUtil::MakeTupleOwned(std::move(elements)); @@ -373,34 +376,58 @@ StatusOr MakeFakeLiteralInternalWithBounds(const Shape& shape, case S8: PopulateWithRandomIntegralDataWithBounds( &literal, engine, static_cast(min), static_cast(max)); + if (is_sorted) { + std::sort(literal.data().begin(), literal.data().end()); + } break; case U8: PopulateWithRandomIntegralDataWithBounds( &literal, engine, static_cast(min), static_cast(max)); + if (is_sorted) { + std::sort(literal.data().begin(), literal.data().end()); + } break; case S16: PopulateWithRandomIntegralDataWithBounds( &literal, engine, static_cast(min), static_cast(max)); + if (is_sorted) { + std::sort(literal.data().begin(), literal.data().end()); + } break; case U16: PopulateWithRandomIntegralDataWithBounds( &literal, engine, static_cast(min), static_cast(max)); + if (is_sorted) { + std::sort(literal.data().begin(), literal.data().end()); + } break; case S32: PopulateWithRandomIntegralDataWithBounds( &literal, engine, static_cast(min), static_cast(max)); + if (is_sorted) { + std::sort(literal.data().begin(), literal.data().end()); + } break; case U32: PopulateWithRandomIntegralDataWithBounds( &literal, engine, static_cast(min), static_cast(max)); + if (is_sorted) { + std::sort(literal.data().begin(), literal.data().end()); + } break; case S64: PopulateWithRandomIntegralDataWithBounds( &literal, engine, static_cast(min), static_cast(max)); + if (is_sorted) { + std::sort(literal.data().begin(), literal.data().end()); + } break; case U64: PopulateWithRandomIntegralDataWithBounds( &literal, engine, static_cast(min), static_cast(max)); + if (is_sorted) { + std::sort(literal.data().begin(), literal.data().end()); + } break; default: return Unimplemented( @@ -510,6 +537,7 @@ StatusOr CreateLiteralForConstrainedUses( int64 index_bound = INT64_MAX; bool no_duplicates = false; bool needs_constant = false; + bool needs_sorted_indices = false; ConstantType constant_type = ConstantType::kUnknown; for (HloInstruction* use : constrained_uses) { switch (use->opcode()) { @@ -547,6 +575,13 @@ StatusOr CreateLiteralForConstrainedUses( std::min(index_bound, operand_shape.dimensions(dim_in_operand)); } } + if (use->opcode() == HloOpcode::kScatter) { + needs_sorted_indices |= + Cast(use)->indices_are_sorted(); + } else { + needs_sorted_indices |= + Cast(use)->indices_are_sorted(); + } break; } case HloOpcode::kReduce: @@ -579,7 +614,7 @@ StatusOr CreateLiteralForConstrainedUses( } if (index_bound != INT64_MAX) { return MakeFakeLiteralInternalWithBounds(param.shape(), engine, -1, - index_bound); + index_bound, needs_sorted_indices); } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index dacb5faa228..06ea42235b2 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -424,19 +424,6 @@ T CeilOfRatio(T dividend, T divisor) { return tensorflow::MathUtil::CeilOfRatio(dividend, divisor); } 
-template -std::vector ElementWiseCeilOfRatio(absl::Span dividends, - absl::Span divisors) { - std::vector ceil_of_ratios; - CHECK_EQ(dividends.size(), divisors.size()); - ceil_of_ratios.reserve(dividends.size()); - absl::c_transform(dividends, divisors, std::back_inserter(ceil_of_ratios), - [](const T dividend, const T divisor) { - return CeilOfRatio(dividend, divisor); - }); - return ceil_of_ratios; -} - // Rounds the value up to a multiple of the divisor by first calling CeilOfRatio // then multiplying by the divisor. For example: RoundUpToNearest(13, 8) => 16 template diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index d91bc72c2f8..bfd79b537e3 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -1,15 +1,15 @@ """Wrapper around cc_proto_library used inside the XLA codebase.""" load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "cc_proto_library", ) load( - "//tensorflow/core:platform/default/build_config_root.bzl", + "//tensorflow/core/platform:default/build_config_root.bzl", "if_static", ) load( - "//tensorflow/core:platform/default/cuda_build_defs.bzl", + "//tensorflow/core/platform:default/cuda_build_defs.bzl", "if_cuda_is_configured", ) @@ -48,3 +48,6 @@ ORC_JIT_MEMORY_MAPPER_TARGETS = [] # We link the GPU plugin into the XLA Python extension if CUDA is enabled. def xla_python_default_plugins(): return if_cuda_is_configured(["//tensorflow/compiler/xla/service:gpu_plugin"]) + +def xla_py_test_deps(): + return [] diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 7a40e4096de..09c6c793a2f 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -180,6 +180,11 @@ message DebugOptions { // xla_cpu_enable_fast_math is false. bool xla_cpu_fast_math_honor_division = 126; + // When xla_cpu_enable_fast_math is true then this controls whether we forbid + // to approximate calculations for functions. Ignored when + // xla_cpu_enable_fast_math is false. + bool xla_cpu_fast_math_honor_functions = 129; + // When true we lower the Minimum and Maximum hlos in the GPU backend such // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag // this is true we don't propagate NaNs through Min and Max. @@ -282,7 +287,13 @@ message DebugOptions { bool xla_gpu_force_conv_nchw = 125; - // Next id: 127 + // Paths to files with ptx code. + repeated string xla_gpu_ptx_file = 127; + + // Blacklist for cuDNN convolutions. + string xla_gpu_algorithm_blacklist_path = 128; + + // Next id: 130 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 1bd6db2662e..f5218ad4d8c 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -294,6 +294,10 @@ message ExecutionProfile { // The size of the binary code in the executable. int64 executable_size_in_bytes = 6; + + // Whether this profile was drawn from a cache of profiles instead of from + // execution on the hardware. + bool profile_cache_hit = 7; } // Handle given to a user that represents an execution that the user launched @@ -579,6 +583,12 @@ message CholeskyOptions { bool lower = 1; } +// Generic map of attributes used to pass hints / configuration options from +// the Python frontend to the XLA backend. 
+message FrontendAttributes { + map map = 1; +} + message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index 67402c11fcc..ce614904523 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -8,7 +8,7 @@ load( ) load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library_py", ) diff --git a/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.cc b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.cc index 39c83c14f0a..d5f60ec33bb 100644 --- a/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.cc +++ b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.cc @@ -45,7 +45,6 @@ EAGER_CLIENT_METHOD(WaitQueueDone); EAGER_CLIENT_METHOD(KeepAlive); EAGER_CLIENT_METHOD(CloseContext); EAGER_CLIENT_METHOD(RegisterFunction); -EAGER_CLIENT_METHOD(SendTensor); #undef EAGER_CLIENT_METHOD #define WORKER_CLIENT_METHOD(method) \ diff --git a/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h index 2ef4efa652c..75e32e6d8f0 100644 --- a/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h +++ b/tensorflow/compiler/xrt/client/xrt_grpc_eager_client.h @@ -73,9 +73,6 @@ class XrtGrpcEagerClient { eager::RegisterFunctionResponse* response, StatusCallback done, CallOptions* call_opts = nullptr); - void SendTensorAsync(const eager::SendTensorRequest* request, - eager::SendTensorResponse* response, StatusCallback done, - CallOptions* call_opts = nullptr); // The following two methods are actually from the WorkerService API, not // EagerService, but are necessary for using remote Eager, and we include them diff --git a/tensorflow/compiler/xrt/client/xrt_tf_client.cc b/tensorflow/compiler/xrt/client/xrt_tf_client.cc index 88d0d25f84a..20206088799 100644 --- a/tensorflow/compiler/xrt/client/xrt_tf_client.cc +++ b/tensorflow/compiler/xrt/client/xrt_tf_client.cc @@ -286,15 +286,16 @@ XrtTensorHandle XrtTfContext::SendTensor( op_id = op->id; } - eager::SendTensorRequest request; + eager::EnqueueRequest request; request.set_context_id(context_id_); - request.set_op_id(op_id); - request.mutable_tensors()->AddAllocated(tensor_proto.release()); - request.set_device_name(devices_.at(rpc_device_id).name()); - auto response = std::make_shared(); + auto* send_tensor = request.add_queue()->mutable_send_tensor(); + send_tensor->set_op_id(op_id); + send_tensor->mutable_tensors()->AddAllocated(tensor_proto.release()); + send_tensor->set_device_name(devices_.at(rpc_device_id).name()); + auto response = std::make_shared(); auto context_ptr = shared_from_this(); absl::Notification done; - eager_client_->SendTensorAsync( + eager_client_->EnqueueAsync( &request, response.get(), [context_ptr, op_id, response, &done](Status status) { absl::MutexLock lock(&context_ptr->mu_); @@ -440,6 +441,7 @@ XrtTensorHandle& XrtTensorHandle::operator=(XrtTensorHandle&& other) { void XrtTensorHandle::Serialize(eager::RemoteTensorHandle* proto) const { proto->set_op_id(tensor_id_.first); proto->set_output_num(tensor_id_.second); + proto->set_device(context_->devices_.at(device_id_).name()); } AttrValue MakeAttrValue(std::string s) { diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index b791519c097..89daa98ee18 100644 --- 
a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -151,7 +151,7 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { xrt::XLAComputation computation_proto; OP_REQUIRES( ctx, - computation_proto.ParseFromString(computation_input.scalar()()), + computation_proto.ParseFromString(computation_input.scalar()()), errors::InvalidArgument( "Unable to parse computation input to XLAComputation")); @@ -191,7 +191,7 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { .ComputeProgramShape() .ToProto(); Tensor program_shape_output(DT_STRING, TensorShape({1})); - program_shape_output.vec()(0) = program_shape.SerializeAsString(); + program_shape_output.vec()(0) = program_shape.SerializeAsString(); ctx->set_output(1, program_shape_output); } diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 231387e314f..1c4e1f7e2c7 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -260,7 +260,7 @@ Status XRTExecuteOp::DoWork(OpKernelContext* context) { TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape())); xrt::XRTExecutionConfig config_proto; TF_RET_CHECK( - config_proto.ParseFromString(execution_config.scalar()())); + config_proto.ParseFromString(execution_config.scalar()())); int core_index_in_replica = config_proto.core_index_in_replica(); TF_RET_CHECK(core_index_in_replica == 0); @@ -343,12 +343,12 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { const Tensor& execution_plan = context->input(0); TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape())); xrt::XRTChainedExecutePlan plan; - TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar()())); + TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar()())); const Tensor& execution_config = context->input(1); TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape())); xrt::XRTChainedExecuteConfig config; - TF_RET_CHECK(config.ParseFromString(execution_config.scalar()())); + TF_RET_CHECK(config.ParseFromString(execution_config.scalar()())); XRTCompilationCache* cache; TF_RETURN_IF_ERROR(rm->Lookup( diff --git a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h index 2ffde52af06..769ec188349 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_state_ops.h +++ b/tensorflow/compiler/xrt/kernels/xrt_state_ops.h @@ -177,7 +177,7 @@ class XRTAllocateOp : public OpKernel { xrt::XLAAllocation allocation_proto; OP_REQUIRES( ctx, - allocation_proto.ParseFromString(allocation_info.scalar()()), + allocation_proto.ParseFromString(allocation_info.scalar()()), errors::InvalidArgument( "Unable to parse allocation input to XLAAllocation")); @@ -419,7 +419,7 @@ class XRTMakeTupleOp : public OpKernel { errors::Internal("tuple description input should be a string scalar")); xrt::XLATupleNode tuple_proto; OP_REQUIRES( - ctx, tuple_proto.ParseFromString(tuple_info.scalar()()), + ctx, tuple_proto.ParseFromString(tuple_info.scalar()()), errors::InvalidArgument("Unable to parse tuple input to XLATupleNode")); OpInputList arg_list; @@ -512,7 +512,7 @@ class XRTReadLiteralOp : public OpKernel { xla::LiteralProto literal_proto = literal.ToProto(); Tensor output(DT_STRING, TensorShape({})); - literal_proto.SerializeToString(&output.scalar()()); + SerializeToTString(literal_proto, &output.scalar()()); ctx->set_output(0, output); } }; @@ -627,7 +627,7 @@ class XRTWriteLiteralOp : public OpKernel { 
errors::Internal("literal input should be a string scalar")); xla::LiteralProto literal_proto; OP_REQUIRES(ctx, - literal_proto.ParseFromString(literal_info.scalar()()), + literal_proto.ParseFromString(literal_info.scalar()()), errors::InvalidArgument( "Unable to parse allocation input to LiteralProto")); xla::Literal literal; diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD index cc6ab9a3ed4..701125f63f0 100644 --- a/tensorflow/compiler/xrt/tests/BUILD +++ b/tensorflow/compiler/xrt/tests/BUILD @@ -1,6 +1,6 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_cc_test") load( - "//tensorflow/core:platform/default/build_config_root.bzl", + "//tensorflow/core/platform:default/build_config_root.bzl", "tf_cuda_tests_tags", ) diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index f0729251eeb..427a631f82d 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -127,7 +127,7 @@ xla::LiteralProto FloatMatrix( xla::Literal ReadOutputLiteral(const std::vector& outputs, size_t idx) { xla::LiteralProto response; - CHECK(response.ParseFromString(outputs[idx].scalar()())); + CHECK(response.ParseFromString(outputs[idx].scalar()())); return xla::Literal::CreateFromProto(response).ValueOrDie(); } @@ -316,7 +316,7 @@ TEST(RawApiTest, AllocFromTensor) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); } @@ -351,7 +351,7 @@ TEST(RawApiTest, AllocUninitialized) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto read_back_literal; EXPECT_TRUE( - read_back_literal.ParseFromString(outputs[0].scalar()())); + read_back_literal.ParseFromString(outputs[0].scalar()())); Tensor read_back_tensor; TF_ASSERT_OK(LiteralToHostTensor( xla::Literal::CreateFromProto(read_back_literal).ValueOrDie(), DT_FLOAT, @@ -381,7 +381,7 @@ TEST(RawApiTest, AllocUninitialized) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(response, new_literal)); } } @@ -413,7 +413,7 @@ TEST(RawApiTest, AllocFromTensorTuple) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); } @@ -439,7 +439,7 @@ TEST(RawApiTest, AllocFromTensorTupleSingle) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response)); } @@ -465,7 +465,7 @@ TEST(RawApiTest, AllocFromTensorRelayout) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); // We have sent literal's data (in array layout) with a attribute layout // {0,1}, so the expected literal read from device needs to be changed // accordingly. 
@@ -493,7 +493,7 @@ TEST(RawApiTest, AllocAndRewrite) { int64 allocation_handle = outputs[1].scalar()(); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); xla::LiteralProto new_literal = @@ -512,7 +512,7 @@ TEST(RawApiTest, AllocAndRewrite) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto new_response; - EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(new_response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response)); Tensor release_tensor(DT_INT64, TensorShape({1})); @@ -652,7 +652,7 @@ TEST(RawApiTest, ReadAndWriteState) { session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); } @@ -673,7 +673,7 @@ TEST(RawApiTest, ReadAndWriteStateAutoFree) { TF_EXPECT_OK(session.Run({read_back}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response)); } @@ -707,13 +707,13 @@ TEST(RawApiTest, SubBuffer) { auto base_elements = base_literal.DecomposeTuple(); auto nested_0_elements = base_elements[0].Clone().DecomposeTuple(); xla::LiteralProto response_0; - EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0)); xla::LiteralProto response_1; - EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar()())); + EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1)); xla::LiteralProto response_00; - EXPECT_TRUE(response_00.ParseFromString(outputs[2].scalar()())); + EXPECT_TRUE(response_00.ParseFromString(outputs[2].scalar()())); EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00)); } @@ -779,9 +779,9 @@ TEST(RawApiTest, MakeTuple) { std::vector outputs; TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs)); xla::LiteralProto response_0; - EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar()())); xla::LiteralProto response_1; - EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar()())); + EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar()())); auto expected_0 = MakeTuple0(); EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0)); @@ -853,7 +853,7 @@ TEST(RawApiTest, ExecuteChainedOpByOp) { TF_EXPECT_OK(session.Run({read_back}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({-150.0f, -36.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); @@ -973,7 +973,7 @@ TEST(RawApiTest, ExecuteChained) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({-150.0f, -36.0f}); 
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); @@ -1022,13 +1022,13 @@ TEST(RawApiTest, CompileAndExecute) { TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); xla::ProgramShapeProto program_shape; - EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); } @@ -1077,13 +1077,13 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); xla::ProgramShapeProto program_shape; - EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); } @@ -1128,7 +1128,8 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { {release}, &outputs)); xla::ProgramShapeProto program_shape_proto; - EXPECT_TRUE(program_shape_proto.ParseFromString(outputs[0].vec()(0))); + EXPECT_TRUE( + program_shape_proto.ParseFromString(outputs[0].vec()(0))); xla::ProgramShape program_shape(program_shape_proto); EXPECT_EQ(program_shape.parameters_size(), 1); @@ -1196,7 +1197,7 @@ TEST(RawApiTest, DotGeneralWithLayoutTest) { TF_EXPECT_OK(session.Run({read_back}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR2WithLayout({{18.0f}, {44.0f}}, layout); @@ -1231,7 +1232,7 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) { TF_EXPECT_OK(session.Run({read_back}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR0(3.0f); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); @@ -1281,7 +1282,7 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) { TF_EXPECT_OK(session.Run({read_back}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto sum = xla::LiteralUtil::CreateR1({9.0f, 7.0f}); auto expected = xla::LiteralUtil::MakeTuple({&sum}); @@ -1343,7 +1344,7 @@ TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) { EXPECT_EQ(voutputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR0(kResults[i]); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); @@ -1514,13 +1515,13 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + 
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR0(15123899); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); xla::ProgramShapeProto program_shape; - EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType( xla::Shape(program_shape.result()), xla::S64)); @@ -1580,7 +1581,7 @@ TEST(RawApiTest, TestDeviceMemoryCompaction) { // we have on record. for (size_t i = 1, j = 0; i < handles.size(); i += 2, ++j) { xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[j].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[j].scalar()())); EXPECT_TRUE(CompareLiteralProtos(allocs[i].value(), response)); } } @@ -1668,7 +1669,7 @@ TEST(RawApiTest, TestDeviceMemorySwap) { EXPECT_EQ(outputs.size(), 1); xla::LiteralProto response; - EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto literal = xla::Literal::CreateFromProto(response).ValueOrDie(); EXPECT_EQ(literal, zero_literal); } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 79c0a4136e1..034ecd85fd0 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -1,7 +1,6 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. -load("//third_party/mpi:mpi.bzl", "if_mpi") load("//tensorflow:tensorflow.bzl", "if_not_windows") package( @@ -42,7 +41,6 @@ py_library( "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/feature_column:feature_column_py", "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/gan", "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/contrib/grid_rnn:grid_rnn_py", "//tensorflow/contrib/hadoop", @@ -109,7 +107,7 @@ py_library( "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", "//tensorflow/python/estimator:estimator_py", - ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({ + ] + select({ "//tensorflow:android": [], "//tensorflow:ios": [], "//tensorflow:linux_s390x": [], @@ -176,7 +174,7 @@ cc_library( "//tensorflow/contrib/tensor_forest:stats_ops_kernels", "//tensorflow/contrib/tensor_forest:tensor_forest_kernels", "//tensorflow/contrib/text:all_kernels", - ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({ + ] + select({ "//tensorflow:android": [], "//tensorflow:ios": [], "//tensorflow:linux_s390x": [], diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 0d510a16601..1611cf4f338 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -49,7 +49,6 @@ from tensorflow.contrib import estimator from tensorflow.contrib import factorization from tensorflow.contrib import feature_column from tensorflow.contrib import framework -from tensorflow.contrib import gan from tensorflow.contrib import graph_editor from tensorflow.contrib import grid_rnn from tensorflow.contrib import image diff --git a/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb b/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb index bf824e2760e..c51d2124920 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb +++ b/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb @@ -18,18 +18,29 @@ 
"cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "TuWj26KWz1fZ" }, "outputs": [], "source": [ - "!pip install -U -q tf-nightly" + "!pip install -U -q tf-nightly-2.0-preview" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Cp7iTarmz62Y" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "\n", + "tf = tf.compat.v2\n", + "tf.enable_v2_behavior()" ] }, { @@ -41,25 +52,21 @@ "source": [ "### Fibonacci numbers\n", "\n", - "https://en.wikipedia.org/wiki/Fibonacci_number" + "https://en.wikipedia.org/wiki/Fibonacci_number\n" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 11, "metadata": { "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 197 + "height": 187 }, "colab_type": "code", "executionInfo": { - "elapsed": 7512, + "elapsed": 709, "status": "ok", - "timestamp": 1532101577266, + "timestamp": 1563825398552, "user": { "displayName": "", "photoUrl": "", @@ -68,7 +75,7 @@ "user_tz": 240 }, "id": "H7olFlMXqrHe", - "outputId": "472dbfe0-9449-4f93-e908-1a0785188a92" + "outputId": "25243e7b-99a7-4a6d-ad00-e97c52be7d97" }, "outputs": [ { @@ -89,25 +96,19 @@ } ], "source": [ - "import tensorflow as tf\n", - "from tensorflow.contrib import autograph as ag\n", - "\n", - "\n", + "@tf.function\n", "def fib(n):\n", " f1 = 0\n", " f2 = 1\n", - " for i in range(n):\n", + " for i in tf.range(n):\n", " tmp = f2\n", " f2 = f2 + f1\n", " f1 = tmp\n", - " print(i, ': ', f2)\n", + " tf.print(i, ': ', f2)\n", " return f2\n", "\n", "\n", - "with tf.Graph().as_default():\n", - " final_fib = ag.to_graph(fib)(tf.constant(10))\n", - " with tf.Session() as sess:\n", - " sess.run(final_fib)" + "_ = fib(tf.constant(10))" ] }, { @@ -122,68 +123,15 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 541 - }, + "colab": {}, "colab_type": "code", - "executionInfo": { - "elapsed": 103, - "status": "ok", - "timestamp": 1532101577412, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "UeWjK8rHq6Cj", - "outputId": "73ece895-12fb-489a-e52c-032945d7ed7a" + "id": "UeWjK8rHq6Cj" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "from __future__ import print_function\n", - "import tensorflow as tf\n", - "\n", - "def tf__fib(n):\n", - " try:\n", - " with tf.name_scope('fib'):\n", - " f1 = 0\n", - " f2 = 1\n", - "\n", - " def extra_test(f1_1, f2_1):\n", - " with tf.name_scope('extra_test'):\n", - " return True\n", - "\n", - " def loop_body(i, f1_1, f2_1):\n", - " with tf.name_scope('loop_body'):\n", - " tmp = f2_1\n", - " f2_1 = f2_1 + f1_1\n", - " f1_1 = tmp\n", - " with ag__.utils.control_dependency_on_returns(ag__.utils.\n", - " dynamic_print(i, ': ', f2_1)):\n", - " f2, i_1 = ag__.utils.alias_tensors(f2_1, i)\n", - " return f1_1, f2\n", - " f1, f2 = ag__.for_stmt(ag__.utils.dynamic_builtin(range, n),\n", - " extra_test, loop_body, (f1, f2))\n", - " return f2\n", - " except:\n", - " ag__.rewrite_graph_construction_error(ag_source_map__)\n", - "\n" - ] - } - ], + "outputs": [], "source": [ - "print(ag.to_code(fib))" + "print(tf.autograph.to_code(fib.python_function))" ] }, { @@ -200,20 +148,16 @@ }, { "cell_type": "code", - "execution_count": 4, + 
"execution_count": 12, "metadata": { "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 125 + "height": 119 }, "colab_type": "code", "executionInfo": { - "elapsed": 233, + "elapsed": 663, "status": "ok", - "timestamp": 1532101577681, + "timestamp": 1563825401385, "user": { "displayName": "", "photoUrl": "", @@ -222,7 +166,7 @@ "user_tz": 240 }, "id": "33CAheYsrEQ7", - "outputId": "82a493ee-15b5-419d-8c9c-5f4159090a05" + "outputId": "2a88b65d-4fed-4d96-8770-0c68ffece861" }, "outputs": [ { @@ -240,8 +184,9 @@ ], "source": [ "import tensorflow as tf\n", - "from tensorflow.contrib import autograph as ag\n", "\n", + "\n", + "@tf.function(experimental_autograph_options=tf.autograph.experimental.Feature.EQUALITY_OPERATORS)\n", "def fizzbuzz(i, n):\n", " while i \u003c n:\n", " msg = ''\n", @@ -251,14 +196,11 @@ " msg += 'Buzz'\n", " if msg == '':\n", " msg = tf.as_string(i)\n", - " print(msg)\n", + " tf.print(msg)\n", " i += 1\n", " return i\n", "\n", - "with tf.Graph().as_default():\n", - " final_i = ag.to_graph(fizzbuzz)(tf.constant(10), tf.constant(16))\n", - " with tf.Session() as sess:\n", - " sess.run(final_i)" + "_ = fizzbuzz(tf.constant(10), tf.constant(16))" ] }, { @@ -273,98 +215,15 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 1081 - }, + "colab": {}, "colab_type": "code", - "executionInfo": { - "elapsed": 289, - "status": "ok", - "timestamp": 1532101578003, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "bBhFIIaZrxvx", - "outputId": "d076a7ea-e643-4689-f90a-57f5d086dedc" + "id": "bBhFIIaZrxvx" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "from __future__ import print_function\n", - "import tensorflow as tf\n", - "\n", - "def tf__fizzbuzz(i, n):\n", - " try:\n", - " with tf.name_scope('fizzbuzz'):\n", - "\n", - " def loop_test(i_1):\n", - " with tf.name_scope('loop_test'):\n", - " return tf.less(i_1, n)\n", - "\n", - " def loop_body(i_1):\n", - " with tf.name_scope('loop_body'):\n", - " msg = ''\n", - "\n", - " def if_true():\n", - " with tf.name_scope('if_true'):\n", - " msg_1, = msg,\n", - " msg_1 += 'Fizz'\n", - " return msg_1,\n", - "\n", - " def if_false():\n", - " with tf.name_scope('if_false'):\n", - " return msg,\n", - " msg = ag__.utils.run_cond(tf.equal(i_1 % 3, 0), if_true, if_false)\n", - "\n", - " def if_true_1():\n", - " with tf.name_scope('if_true_1'):\n", - " msg_2, = msg,\n", - " msg_2 += 'Buzz'\n", - " return msg_2,\n", - "\n", - " def if_false_1():\n", - " with tf.name_scope('if_false_1'):\n", - " return msg,\n", - " msg = ag__.utils.run_cond(tf.equal(i_1 % 5, 0), if_true_1, if_false_1\n", - " )\n", - "\n", - " def if_true_2():\n", - " with tf.name_scope('if_true_2'):\n", - " msg_3, = msg,\n", - " msg_3 = tf.as_string(i_1)\n", - " return msg_3,\n", - "\n", - " def if_false_2():\n", - " with tf.name_scope('if_false_2'):\n", - " return msg,\n", - " msg = ag__.utils.run_cond(tf.equal(msg, ''), if_true_2, if_false_2)\n", - " with ag__.utils.control_dependency_on_returns(ag__.utils.\n", - " dynamic_print(msg)):\n", - " msg_4 = ag__.utils.alias_tensors(msg)\n", - " i_1 += 1\n", - " return i_1,\n", - " i = ag__.while_stmt(loop_test, loop_body, (i,), (tf, n, ag__, i))\n", - " return i\n", - " except:\n", - " ag__.rewrite_graph_construction_error(ag_source_map__)\n", - "\n" - ] - } - ], + "outputs": [], "source": [ - 
"print(ag.to_code(fizzbuzz))" + "print(tf.autograph.to_code(fizzbuzz.python_function))" ] }, { @@ -393,12 +252,7 @@ "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "7moIlf8VABkl" }, @@ -414,44 +268,47 @@ "id": "QlEvfIQPAYF5" }, "source": [ - "#### Game of Life for AutoGraph" + "#### Game of Life for AutoGraph\n", + "\n", + "Note: the code may take a while to run." ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, + "colab": {}, "colab_type": "code", "id": "5pCK2qQSAAK4" }, "outputs": [], "source": [ "#@test {\"skip\": true} \n", - "NUM_STEPS = 100" + "NUM_STEPS = 75" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GPZANPdhMagD" + }, + "source": [ + "Note: This code uses a non-vectorized algorithm, which is quite slow. For 75 steps, it will take a few minutes to run. " ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": { "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 308 + "height": 309 }, "colab_type": "code", "executionInfo": { - "elapsed": 14892, + "elapsed": 147654, "status": "ok", - "timestamp": 1532101593030, + "timestamp": 1563825336196, "user": { "displayName": "", "photoUrl": "", @@ -460,15 +317,15 @@ "user_tz": 240 }, "id": "hC3qMqryPDHS", - "outputId": "8405c0e9-e518-41d6-f5bc-e78df6474169" + "outputId": "56a095a3-28a3-455d-e95e-2c4c9dcd97d2" }, "outputs": [ { "data": { "text/html": [ - "\u003cvideo width=\"432.0\" height=\"288.0\" controls autoplay loop\u003e\n", - " \u003csource type=\"video/mp4\" src=\"data:video/mp4;base64,AAAAHGZ0eXBNNFYgAAACAGlzb21pc28yYXZjMQAAAAhmcmVlAACZUm1kYXQAAAKuBgX//6rcRem9\n", - "5tlIt5Ys2CDZI+7veDI2NCAtIGNvcmUgMTQ4IHIyNzk1IGFhYTlhYTggLSBILjI2NC9NUEVHLTQg\n", + "\u003cvideo width=\"432\" height=\"288\" controls autoplay loop\u003e\n", + " \u003csource type=\"video/mp4\" src=\"data:video/mp4;base64,AAAAHGZ0eXBNNFYgAAACAGlzb21pc28yYXZjMQAAAAhmcmVlAABdAG1kYXQAAAKuBgX//6rcRem9\n", + "5tlIt5Ys2CDZI+7veDI2NCAtIGNvcmUgMTUyIHIyODU0IGU5YTU5MDMgLSBILjI2NC9NUEVHLTQg\n", "QVZDIGNvZGVjIC0gQ29weWxlZnQgMjAwMy0yMDE3IC0gaHR0cDovL3d3dy52aWRlb2xhbi5vcmcv\n", "eDI2NC5odG1sIC0gb3B0aW9uczogY2FiYWM9MSByZWY9MyBkZWJsb2NrPTE6MDowIGFuYWx5c2U9\n", "MHgzOjB4MTEzIG1lPWhleCBzdWJtZT03IHBzeT0xIHBzeV9yZD0xLjAwOjAuMDAgbWl4ZWRfcmVm\n", @@ -479,725 +336,449 @@ "bWlkPTIgYl9hZGFwdD0xIGJfYmlhcz0wIGRpcmVjdD0xIHdlaWdodGI9MSBvcGVuX2dvcD0wIHdl\n", "aWdodHA9MiBrZXlpbnQ9MjUwIGtleWludF9taW49MTAgc2NlbmVjdXQ9NDAgaW50cmFfcmVmcmVz\n", "aD0wIHJjX2xvb2thaGVhZD00MCByYz1jcmYgbWJ0cmVlPTEgY3JmPTIzLjAgcWNvbXA9MC42MCBx\n", - "cG1pbj0wIHFwbWF4PTY5IHFwc3RlcD00IGlwX3JhdGlvPTEuNDAgYXE9MToxLjAwAIAAAAPQZYiE\n", - "ABH//veIHzLLafk613IR560urR9Q7kZxXqS9/iAAAAMAFpyZZ6/h5MpYA5/oqv4s2qPbYpW3jfK6\n", - "zQ6q7WMrNj7Hy8jZzmBpfHCwAAO1W4riBNsrapcCk+5V1W0XkkFULR4Qe+H3uGA2HgNW0zFAAUgt\n", - "W4tdpXv2OEg0Vuy5W5l/xGRmEGKDyeXyrM0S6q/1EKbad0x2mcHseUqNmeOGLy1N3b376XZKZcPY\n", - "IXC5F2332tNMj8CwOQiXM9PiCLyCVfZ3rQSkKBTZErkpS5kXUyoJG3FdIqLjRFKEapbUjcW64HIo\n", - "BeIbtRyWV9FyZfcTakx2KW3eB4ZI//MDykSe8CRgN76uBEqZFXwO63wmUREhHOb5AdaLV3xyGl/I\n", - "RV70rU/3t9t1aq5mFD3hy1aLTAV2U7nG072dyX87F7NgCxZHT2kFxu44fxf6gqVzE3PEbGr5fx9x\n", - "7TKXtmY53VP8UaeCd2HJiZ/sd165SutTnfiWvaLuCnmmXGF0AGqbj9S19kgOhTubZIJBydTTqQOV\n", - "YRlxbgKn2nzvunv9+NDG0/2ikyyp73W15QClmjyt8dUeynoN8CwtEQ59DdrAPZe4ARZTwWAfsRXw\n", - 
"1vcZ6Gr1nCNWllQw5IyZyxQtXrfc5p4wjPvGaltciG7d3FG1SGk6HDsZy5i/PsnkjRXLUvGbzYp2\n", - "2gs7ZSGfSJbEifctcMGeSqhOOYORKy6f/9omoieCVEEkniBXwWZ/eImb3nxF7SFIaBjgG2j9w5ut\n", - "BY6zSuQ5zRCdajzJ1loNO0havI8mp5yViAeAlLKYCxeK0Lha1FskL67W1YsARZVZ5EkhqAYEeTNI\n", - "M38Og48OXmj6QBN7c1b9uDUTacYEXO88ZQ1gCIREIMnm2Fgkir8pN4gtSeQ12sfOVz5x5KX7sa95\n", - "L4LyFQPDrFZcDBr4PWLeEEv8yzk0cYHE97GmAlA6WQ0HlWsS42cnXefvTPXnx4vcq8pbEo/slAuH\n", - "IBsrJEN1+aMCc9FNxwUPVbZVaWVjwLY0qh+mNWEaiNGRmacDXrYWw0NjqMPiLiFHacY5oGELRgym\n", - "S2mSo6zhsD1wKQ3EUQtwrjKPiDYc/HCqhkVwoWKUdI8xTS60kn4f5UqB0L77Yevh/wt7AnvQKQAq\n", - "QAEEevggRl1uigbOBTtscnYRnAj0edW4QExAzdo+RwLWXTzW/l3cBWTrh3ORzZQlxJ8jQTvPLB+f\n", - "bLazJZWFQQDcWhuhQ3gYcP1ruNwIroINRIr8px0UOgAhnk6CllxMN6gA5S0YPhFVFKd3n0AAAC9f\n", - "vYgISQAAAltBmiRsQR/+tSqC8p1IAOZemTPutEfx0mzK8zG8tdIxonBsDpoLZ+NnIOp4qK6idP1s\n", - "vbGvZz/zHM86Bg3q0yx2atmtgoo/Trt3YRy3se4HTjou+tCi7oJt2d7A8vEhVDu33JNJx+WCOgP0\n", - "03nVdg9lBs15v/0w7qMc3zqqJXCOy/Whl9aRhcaeOEWcD7uK6mCV8a6MpDJ959xBRfv2i/qFOFbL\n", - "Grs58WiGJcq4MQJI+rVWuFN50oiqBgiunfUrRmdviPYpNN11V9pwcOJwssWfIE3agnor/RC7vfLY\n", - "YoXzaJjtWLEL92OOaHLZT0j555xfb4FZcoJee+RXovB9IaoDdYRusngtBXPMUvnO+g2Z5Qdo9P8q\n", - "Zb8ItBAeHT8IBZAD/Z2nEA6qbxqOBSBtQNW6ZFYLtCTIoP/bLjCDHgtZk3cf+N1CpXs15pUIYWDW\n", - "elZtlTkM4w4EJlLdjLZyQPAeaBx/qoLmKyTKAEhm0hU8EcTq00f6fwkWgz2J6GTGtL/vJXgC8u4o\n", - "nTnf+Ou7sVJGVaouXxrzx+yGVHEcp/eV4gaFA95rInngQAOZWbA3558nK61JBPZl3NjEv5B9r9pg\n", - "2+SYY3wBAUeu2fgAB2+yYGw82pkoJJKpzYWORs6i1vn3GEgUTcwlYsdJcraYC5SnGvqSZhX7KM72\n", - "uE1e9bkpvpVyG/mkACn5R4jwX3xc2utCjjZgM101rirIF/7VfDtmJsSTDes+UVhbSr3SeMSI9ixJ\n", - "+fVuFZ5bnQPoRIfPc+Erw+K99JiGN+HE98/eq4pPlMY9oCfVPSdNyOAAAAFfQZ5CeId/AUuqOi5D\n", - "jlKfxuJGZZ1+rVyomjOIykvxtsjsuCiGtElbraCSFWcn3aIYWLrF3fPovVLcOnroBkiRMsdf5yJA\n", - "F87MQuoKeTaGOrxojCCCS64RiHrqNsE+7mfRRUDuB4sAEHFQHxBorgTukPSvrdFr5QDq+BhZj/6H\n", - "KN+IutwFWKX3ZX9pO3sI8My78TgRY5AA6FEcT91WcvnMypB/OWXzK6M8fYuhVVWipAZigjVOYhcF\n", - "9i6GweQFX9AV9EUQOp2qFbkrT5jceBRFLX6j4JUQ781/UGTekv1fcpCmzlpNpp8GdSeWxRL4gasp\n", - "F5uO5KW63rlhYccBo1cFwIN8txHNnwyQNiP00XC0PWDRZfaWSxsACRWrISow71IyUfcL7JNhjTII\n", - "rwDYATS0xZ9ep8siFC3JTxg1eNaroYfeI4tbkRHok47Vk+CUOQPuagVBtFMOOcy2OUbw8AWlAAAA\n", - "ugGeYXRDfwHM79ghzBo9nMnzfQPPIuvorxBb6AC8F4fYGD/t93kNSKNSEuhUXq9FKGtxnCkxN880\n", - "BPb/uTbjLTQVyPNuYlGl/gTlyLcVA/cDoLrl5TvaR/AcSLFE7C/t3kLx0STNibmdAf4TsHWKSblH\n", - "VWB4X7oQHrrDdhwIivRgUZf7f63j2XaGB+cbp5aHCCwJoovY51YTqsZZTz70FlSnypPHQBNzif7h\n", - "uvZkXhtEzpu9rYMo3YECkgAAAXIBnmNqQ38BDchAitLfY16mYQAQlVmv7062W8KLpIS1/zhS50Ib\n", - "b3ERigmkZKZMPaCsAi+zsLcku/gHGHnVZpuCZMFs72gmyuL4JFo6VjWcr5FtBvzIgD26rBNvP73P\n", - "nJjl3JImmFHiKjNez/gG3zTuYyCACuJCEYXyuEmzCM13hdCPHKg5GZtso0Z1qk6T1k2oiqF/3RIn\n", - "kyjRWuxBlHHmJ46TXULiUY14G+RAGoXI+u/G6muNclld2bq+6Zztuy+5ynaDWNNjuN1Ag9KUIx2F\n", - "XwNdepmp52/rOvISNPbMJ0U26OvqplXi+qHTbg8MLpUSIGCY8w9FZ5woLAENgvgu9M79yGlL20e7\n", - "ypJ4RMBqHYDpEz6Z+SSjXD8LsJ7VKlwo22A5Yukp1vTp6HHA35nV+PXK09DuRWKKdQUzmXVihF51\n", - "/+bB0PEFdoNxGdbbM7WveaCJN8XI7JgQWvw2nPlHX8M5QyPGSJ2HEexumoFrABvRAAAB70GaaEmo\n", - "QWiZTAgj//61KoCPNGHq/MxnjqmxxQAEHvTwibmyMZGX3ES9Abh1tMR+/DjR+6dnqRr/VxCl6gEP\n", - "wJ/5EYCYfGaGmQYsLOeM3v2SZjdvqQBwrwKk5A/63kFm8fc3QCLe93Mldv3KWXHdFT7/mudSntDc\n", - "vJwStG4jgi5LKlWdSrVaAxOmElsF+zWNzaCIQ1dOiZqi3JKj64hOeq1XIWyGvRvh6OLKBpB4rL6W\n", - "ugf7H/IPbSQuF5jWV7zL5LhxWiTiI+kAZTUMfO2YOLzmhCUSN9GAmNzgY4D2awYB4V4QTDjI7kdQ\n", - "tL+3Pmfl1HVilu7nC9CzQSvWIosiwv4btyHTL7IPT2gusybyNfW8QO133L6KbDhhXSDWUtcIFCgn\n", - "QUm36C9hvgGjorpKYr5VnErpJX6fRJm76fFYs8/nt763alyqdcSrqaTOLaf/72Wkkmlwbq3nLOIw\n", - 
"ADFDkkAPwzaM811K11iK/3HaYRT3nEhjJQFk5v4WBXwIVLAZeKdtC8YoGN9K6isN142fOG3s6fm4\n", - "J1nMtOEZHIwep8In4slLmHh39qBzhGZO3igiVpgz7u+JMBeFkVHe72vduBjIy+1dqvxL/TPics3s\n", - "+alwfTMNQKave1qW+5Uj8jZQTjcLAtKvzoako9VMIOfQUQAAAQpBnoZFESw7/wC9ZU4P+UeGsidW\n", - "4n5tFkXmtxppYvKQ+WGj/x3AAdl6+9c9x7N2b/yJykTvVggfpMnFUWtxla4sr1ouwANom+Uf4IBJ\n", - "/zXPovndpGdy98nJbZxFU4rrWpr8aI4YmRX65+IGTn756CZWwXKY5DyMgKnDcCtk0HEuoHgdGhh7\n", - "1PG8+nue+pE9pBHqiBNWAjPd90qfMtABmMShLoXtUObqYbqXhJvVjjFhKdPS03IF24fu9Z0ax15V\n", - "DnkiLmgyOCvJmcdIX70L2ZEECd/hxrSq9JUVjC41OX0F/ayI6GtkPMUuZ2xWkMFo5rqOAo7v0Zlk\n", - "ke/79TjeY13FNiowqcbhMwfDuwAAATIBnqV0Q38BDXNpg2t4nJdhAA5ru/5Co2KbB/AnQt7fa959\n", - "0crOQgtTxL36jtVyKPmfuQMYuWbJ/7bYTEV8sEjceHvN6B0CSEZzVCjaPLzOQJZMQpQ4K4WKPlGc\n", - "lnEwYAC9Dsejj7Fbk2RyCFiJinyU2HOscjUR6fW2jRsAFpVq/PtZDVPvesPG3AqooVaKHp9Ex+Da\n", - "AH0OvccSugyDKsRBAEiYR8645aXxbFSzraQsELDsIIr6HRN8F3lUNVBvzNO3mxBhq4th/kgZSjjJ\n", - "JZrYmg3UfIUO/jn4xs2XQ9Pa7Uy5K3JhuIQwAOUKDmAMC0p6fgz2on4ceyEcfiCGDPZpPyL3391F\n", - "dXID0ctPQ1a+Hk7UcAc9gSDL8CZKz59YyO0ACPjfAKV3Y2dbTAKdWBsUU0EAAAFEAZ6nakN/AItk\n", - "aaqbMCcBE0iEIDnEBfRZN0neHQxaz5DPSzK0ZSL640q0AA5jkP0YAYAumNCN0MxJYpWFoQ9r43H0\n", - "i9SZLdv1UbgpG3aX6KESZW7AgdlevaBngH/w8xYsqWx5t90zzi7x9VyRYpIAD+XTrxvgBoFILNCs\n", - "gd+zDA9uvbAPlLMwG/qFltlwvLokMt344erv3a/C/ySOwZHFzpakInpJ7MQHkmKi1KHZB5KrfqwF\n", - "FnglZJwWbe7LtVojTdwQnAksziDNlEWCkMQQJwziY1KYtlXMNX8mZ3MtYR1KNf/CNin7/ys9ZQyx\n", - "4Zlk//H5KDc/8O2+JaxH20CAaAABxgSxo+yJal1LnRHYfOQ1TygNueW/rPAA37g/6fLS7mbYKz7k\n", - "dsiSiy1mAV7n/qq81UHJPShQSXK+E4Y5XKuXEWG4AAAB8UGarEmoQWyZTAgj//61KoAW7kO9JCjl\n", - "XSE6nAngAJVxWWFl/YDS0gZ32xjwUFed4hmI6rj18z16nS3Mz1iMmFblrtaE4zGXS046COODiIwH\n", - "QG5lRmcBExMKlnynQruQtA8n/NitzdP/ysLrucGyp5nKV+XyJURULfxk4kwNp0a5TFlJ1fusOOJm\n", - "y0hvsvEg+d4Jz3anvWT6M9n5A84CGucNifV+WlN9gI9gs3qSoCZdU/gglcFYM5u8YchzhQFyMKxn\n", - "kpfWK2LU7aaZHt6xLbqjuv74523K9/dtrrsFq/LySiv1P9Wk6/6d5RC72z4cyaUq6hMMn4IWWRo0\n", - "zJIM1/lSYsWxt5/M1Mkv00Rt8OZvmLxuFfd1BIVlANlpgZ39RYhqqzU6v1HwaW0EudelFBGhr5mf\n", - "GaDE05Z8ywp5rN4Qq4D4GNAGD/qgEjtaDDf4ZBAD/TAHBwxfNjm2nPAdbbbIuWSkkv8NK6EMlKqH\n", - "mOktd+CB3P6Szd1+HPnUsyQ3659r3XLnoi0cvM4usfW+BgxqT0mgHSgn/F6ajdTNM+a8xJQnT036\n", - "7195r0uF5vwi7PIviCQ2E4Vs4Wx80/8tBDEJS4qOY1YJ5aNV1OV82fB3HOimLHd2vU/d4Cv7OBh8\n", - "k3gNFcjeBGh+3lQcDCLZrG1mAAAA3kGeykUVLDv/AGVBMHxAlJYGEpFnv2bb0ADrwvVKxe7+SIJI\n", - "g0dPJdL0s9Hd2mGX7rpdIiUH9ZgtnBO+m3uPNae/YtN3u2p0kkCez2KiPNqgSoEcHM+ePgq7afkq\n", - "0HHTSZl/+QbjsyfbI/0lv1mLAJUd3u7VZPPHSdXK3vwLfAwOe3Nid72slU892DijWVvanzM1IzDQ\n", - "XfN6x6GH2qfaLrHePrJTJxXC/RSxcAol7x2JJ5OA8VjN8jXu0yKirBiYqgcdFf9odG8j4bRmE2wD\n", - "MG0SKuGrJfd91b6B7hbRUwAAAPYBnul0Q38Ahz7YAbwPIqnkAA5sEIcKo2/sVUP0LEeFOLjKjaet\n", - "5YFAjDbL5BIdGqWouG/H8ozoec2ZpUbIZu0ELtG5yXc/5opSZlnqbOpqdTQkLs6gr9dv5GbFvVjS\n", - "Os1j9FIMQsdc8pttosNtygWB8gLxr65El6umAZE5CVU9Mc8Xxg/tenmTduGK9Cd7qRDiu1sLYR2f\n", - "or3KBMo8ebz5q5EmWucvREbYSziQIIycIwJg9OG+aH+ZUEQbjbfHfaiX7yoxGJGP78aNOHP7GvC+\n", - "JwM6DxnSyowUBAqkW8ckgrhet8gYYrt8MIe1MPJQB6sv8hHuAXkAAAFWAZ7rakN/AI9XvmYGr0rf\n", - "QEvrPPTQWEAA5ru3wBCXPJiC8OaE25OBvVl2wRXqp61wQU4HxGJCAxkSOz+G3Yzvg36uCK8bPZTq\n", - "avaOG/H9WxjsuwAl/bIYJdnyD151CiUZ34aErVIixKJ53oKrLeHr3xLgxuH+y3w5uH5lQRsL0Pmp\n", - "0jQItTBkKwlPywxFk55pROuYZWi/h/N19QaFlF7WPobUElLlr+nCH+pVt1nW9/YwVGz/cO8zwmWe\n", - "Fb0OnFji7CYSsi9ScC3a50GjUP7IpaY5NAHv33V57bkO/BD6dnreymTbSmQdcj7PAJkvz610fMqn\n", - "mDGTMB31oxAIE5eWeH7mBZouSgmtxEamul7sYaTPe7mP6FqNCz0h6wLot/zAFwx9/D2+XB0x8mmS\n", - "b086o+gqkoYoHQeQm2Sb3MU1Bz0KHDGo9jCmsBmecxs3oNHV4KaIoLKAAAABrEGa8EmoQWyZTAgj\n", - 
"//61KoAcdmk2P6doyaR4wEHxsIcmssCD5f+3/v8PGtlbWZ+A0oGGFPTAdgmU2TFbrRxlmwUCouNe\n", - "8freV7blHDodFImzwP3saA3AZT6NUl7vDGH/tw5n9y8rP4XGnhEXBHK+6jIhoAYc6G1CDX0mqczJ\n", - "7tbei5I0YSkDjza4rJSbAF6cRoJQH3s2Q+ggBQR0BfH6N3QlPVwd9YFvP6++J+XrbNU56Pxu6Wey\n", - "51asar4AaARXHregTXL4xn/VNt8Ppk2xD3/1jXAVXdqMlS0tYGM/TtrcuTC63Lx21RQtklG6k0xA\n", - "eWm6W0oL0KTvxuyegpC2ySp5v6zpSEYvzWR4IYirfT0RYU+jLtX0t4M/L/0k8xOLTHbouoUPD6DN\n", - "dYYLYlVX5noJzjCAVCiS21OCcIKqWD/YiU/+dTZpdFFNdHEa/MPvUEq7cJD7ANJ0YUweepq2Eqdh\n", - "57SC4Tpg6jyEnFgMaHQLSz1nJNh4lxM1TPouGZ9bmQdDr9WY+nwzRBa+ZLnaqBSYKWSKEs/TNtNZ\n", - "ev7d+EnJUf9G9CAmmiSDlRAvAAAAz0GfDkUVLDv/AGU2nAwHHyQlvUxuENDSO8vXFIAPilnMlQWb\n", - "nTHwb8wkIo6JKOaIP9blrrNXcWeeQDVprB1Bn//+nbSDHls1apJcUyMHUmojA58P91gutTiF40zp\n", - "fDaF096G01gcvpH5Za4+DfUvxQpt/wH5PntJzggww1tLhP1NyH5U2TTgrnA/BevK2aCa9xCuCVgA\n", - "JJZF4uqHE//COeWbJ6LIFJPoadxAxbrAcxPQQHMzEG5G5S3Yfd+YJBLrdO35JvVrsUTYO4AfvJeC\n", - "zwAAAe8Bny10Q38Aj03WPPyvISnWAC7KM5WfLH925SBeAKcvJaYOa5WZCzX9H5nU/7qAFTCgAnl3\n", - "rAoSnKk1337XDAnLfPYAAOSIcqQwF++e4HouwNVAWCEsVyl7Y6DnBaBT2mD1H8560KoMvm3kKNNC\n", - "oxFCc4BdAIXk45JUbGFNGYAjCbBbJInMjwa41HA404yKnJG7rNXdBctnsSL/36UoXvVx3J2tGX84\n", - "+FHk7e72CsAyB49ajd62idmFQji9Jj1GaiqtCIjWs5o6Mz8s5QfrvipNYYD0YZ7gBBGm4AEz17d8\n", - "isscgsp4QI2odbuEJDq1nfJbW6+1HGcN1XfDC1Xfa5IptM5UYHm5zIT4rSPBIDE6l8/NhVxlFP21\n", - "JPQ0DZxnZFvxIBznQbqkhaGZjMafgFoRzC9Nl17x+K6e75RlplRZtXaUIbjAUFBJIQPkoIrT6/O9\n", - "NtkAmnl8qqUC1RktW/RjiJqOyRTTITHqNKvKy/0gb88xEvvGPgzcSs2KpkbHJWmCGIlSWEkuqcCE\n", - "jBn3Y8XOQxMUxEYeLPJ/9s/F2fT5NAnko+RFlv75fWLekZZP2s17yJ5ccFGhZyrkGX6u7xXK7N8G\n", - "Qlz8qfOHvgMQrlB8p4j7qtnPgBPf8mcsM295CuAZxkK+sut074W+0hM24VMAAADaAZ8vakN/AI9G\n", - "UrhSy/Rrhc/LGXguupji5cAHC2DVoxU1gWUkKeMT366GcmuxH5O8lBZJeHl8r2KNT0EaVARyW7pN\n", - "L4uNsKKl/WAzLJ1OZWTQf4NaAfodQGO9KzZS0j6oGvr/urKiQwbP44Tv//glYQyyCFeq+8nnrHBj\n", - "aACu2w1otySh0DYMX412uY6EYcx3GtQaRpNPiKQniWdVV2KH48fVxDy0uLS0SmCZEAWLVNvtWqO+\n", - "q2OwCBr1m50s0i8eRTlSP9xoKtxWC4ZqL77eAW3kYEBJOAywYUAAAAH6QZs0SahBbJlMCCP//rUq\n", - "gBY3NzYDjVIwwAKbp/vtZn3NtK6t0V/4sA0MV4ijJVoTZ+e36T0E9eQ0LOyzsqR0ULZJUDRy41oM\n", - "RdsBwM4wyEJC67daWmuDEXKhZo862uqAH8A0QJ5u5RKBPFpngChYYJdWzP3onEWImG8Yryy/SXt0\n", - "jQ5te76AagLius72bzwZ4AZfLm/04ID6oXhPwqkf1cNsu4/kIt7oCOETiL+lzwHLEnEsdPSz3DxD\n", - "uLGkH8o6jHofDxEXcB6cOS43aUxGKPYPtHCj2gw6RzcRoX5lD5mwqtoCTxk6N8TxyipSUyNnbA2b\n", - "G5NuBUVLHTce3QKY3SdkbyH/wzdOpT3YHUE+FYQwMKCF6SMyMBxp2gI9k4yUZYljUiekF2XIFkfv\n", - "TFy1RUmikOycLKkTYTreTarsMD5JfjZ2FJWrroj/YX+uNeGtKNZl9Zyt+k8u4Htq1bPYEjCrLHds\n", - "qeIuFWmvxTYEQblStjDXmWfITtxy8KvOgn9iV+KlidrnVhlE7Dz30fuHXxxFZvIzhgU9uv6sSC7T\n", - "vZuGMsKGBGTYmSe0P9hLI2VyM/8GUWwG/AITiU4a7OVDjUNRPaiIEt8jt2oImPIY8qcrJ82CVd+P\n", - "mSjoppoeHUTHmeo+koGqjhwT7ueVHNT5VZ4yuGKEDdFfEIkAAAEMQZ9SRRUsO/8AYrbCELHs5dcg\n", - "AyOPuRHZUWtdXLx9XaNQixO/8Cc4Q2MgEa/wKETsHiR8C1XOv7rI3JB0rg46JfjEArbHaTHmANKo\n", - "+czcI/sIduYNFOE3TvObMh/KtGpZSdF+qnDDtY8zD+7RQUdzmkG5zeDj3u4Vq+f3qnKCwgbU+U0R\n", - "dQR9Q60wXqL03p/iYVxkI8jJqvkECuxT7efJI+5rmzyP1yn+WKY2EsjjB7bwwVfe6RxBmzR9Ed/9\n", - "CA95ILUJxNg4HsmCO2Ko+MqZAH3wMlG18kUm2ogL3cKIkVXogjofyKhbsSpKLpFFk71DzB6NrY/3\n", - "HfknWM2yn9yeQB/joufGEf/bvMAS8QAAAN4Bn3F0Q38Ado97WJWiqN4XS53kTA5YWsnJBdebpf+9\n", - "lcN5zPySAC6fH/XzBsBKbxdm4pTiPFVrmGXyhaRiB6dxtlwj8MyI40Do8AXHq41BAunk4K4PTgzR\n", - "rFycWqaL549wB2C5jNCLXlq6Tuytik3ijlMSkx9noeIG2Lc83eWkRkQieksQSO4xI1tzzkdqaNhG\n", - "ExZARu3MauZwrBopslb/ZLdR5ZS0G6p8o9DD5cphJjxJoSV/70/0Gr+woS8Zj0JpVvvpygE5bXQp\n", - "/YBCqjmq4uOCyt9SvCzPelUEwXEAAAGyAZ9zakN/AHZ6+HiwE6fxvgA5rqP9zmI+FShvhJS43N4N\n", - 
"sc5a7qq0DK7DHadXkQxf+APmeqLrIGM9X5aCQgeyxdoAlcQoyNsm6ol85w5z6JV8A3YntmCae+s8\n", - "+8/Yheg1ctJWrSharoeypUyemQeq9Rm5cIkSOS9Ej0hbIHyFhPQW6K3SawgMNVKQ0s1BpJvXDQSY\n", - "x3jIEdIgEtwe7zce/DjcO3RNN3g+SlPoM7cl0qJbM44NIDG9JGXcwVrY/YKNrpChX0yegP2ZHDI1\n", - "MzOs5eWP/2l5loJrLid2mK4Qhw6EGFrIadsV8rSjzgHRNuzJ4U3JdubidEobU0ehkU0P6MYRK/XM\n", - "58mVywGbsw6LPu56h1S4w3zHGYMd1zPKOsnCUhaRfrSZTxvjerNQ22prVPqBstk4JgHdnSScrwGw\n", - "eQcqvIw7gKhonPDKM4fJtO4n2EsI5Cd0iGMjmgPw/PU3FL8ZP3QbYLMwZ81Wd7BLLBDf+ngKiFIe\n", - "it4neyhhaE/a71b8TxeM/ZrgH9+D76dlgPI1ZJW6CCVyIs6Y5gK2plkcgRYa0MwWF+1A6zPtBEgA\n", - "LOAAAAIIQZt4SahBbJlMCCP//rUqgBY9we30eRuAA2kMf/9/gX2SHKs8Uq31+W7Vx4LugxILnhMT\n", - "6icG5WQzdpL8yjIXjBq99nVaYweUdJE3LrdOpsVxNJ3kODVBkposYOoRuOMi/SNhcjrJwShp6ljG\n", - "Qs7tSeRJSYDkvm+SI2ckjbManbEesw6wo2ZffuryaLuWkU9SNALC+2QbPJD4bFy7sTmB9+6VOdMm\n", - "rnLvYN4ZyAJz7OhQG85P+JnxdgXgvSv66sWBs05p3vOE+53H+HQCMTLVgvoYmHNTIYtZ5CIln4hA\n", - "GrjLg53unVVQTiYlSzZrRE2vmtsqac+v6CrcbtgC4HktflvPTsvgqWNHri9NWa+EuXgx/AgGkZVJ\n", - "r1n6gAd3jtjLtv6YvbPiBBo2AhBUxCbYyroAjcvjwUBtRjXTdDEvdYfItmTKA7W3+KvVi/PCtod6\n", - "/3gOoaA7zRdO+8+MHlGl/c2xzQhj2O1n8eJkOu+NcsBkpmxyosDi11EOEaiQ6vfnOvH9MSM+7D/v\n", - "k91SLlwv/nF+5eDPHSLZQIoFUjHjwVoSGCdOLqmIe6tsfTERCeAhC+1bhRhe0612KIL6izjolsR2\n", - "nUgrl1o39HqnKAVqQ/HguEezLTgmGW27Df2kp4E1wRl/EQgEcsMfBPga1ndY4uHPYq84ArNCWk+c\n", - "YwxlHAPVC3PK3Zp2kQAAAWFBn5ZFFSw7/wBXFVHDEfqz5TAg6AmqzzGCl9B1ICKhB+tKz4Y9Km1L\n", - "/vZyZ1OR5rO815FlrTgGoncUDKVNjpKrVerCm+HleHb1b4FhYQG8B61zGq10uLuoQHIyL4Cv2/mm\n", - "s5Mi7ZftErBt64oWYphUyh0Hmn9dYYheGFzLdE9gvqcAEGJDyLZq+nfiK0Px8pHIgaIfsEdSUYcC\n", - "8Otyxta0EKY+Dm2m8AtQ8jjuDmkSHm/uLhgf1uCnztOKFhkR+ydRCeR9tnIlTfiv3gJbsPT8swjP\n", - "0OUm6yT8LhwwCJU0AGI9hN0/kTkz+NeSHjSPaBx26MAfS2Y5NEtva844h4B/RttjqxMsNDiDrfB4\n", - "5xn/Cl/3XrcF40eivyUSC+FHzx3M4BoLQLOKf7iz8hKiUrqRGVkGToUMxkr5192x9xCjbuvLRMd8\n", - "9Pel4WIOhSi52xuSf1eEhC5VVAp4lHpZmHCbgAAAAaABn7V0Q38AdnTaV3jxqK844c19uepGJJSA\n", - "C7DQuTz6pWfCzxcMbX5JwHItpyM9y3YT46z61a7h5Lyukp+nSKoO0zQhT0EB/u6ILUCNvVbb/89X\n", - "7TVI5UN6EFwYYfi4uoFmqb+5Cd0J/+d2405yTsK/f6WH/T+vNB1DYWrW67ctgHOgMHAWDLG9mitl\n", - "16bXmPVSi2sWzpWYg3147nlnaD00aZHqQlrMPzYTLLFwWHOLNqCoWpNLMMEevc8AnQWeykk9VNTU\n", - "NXzAXhrKDXl1tLQTxZG7GX3K9cQyeUnjfH3rMBGDD2zCLGXrMfPVl9EJ/F5M49Rjn38sXUf2JvF8\n", - "D9r9tV1APCHN27+egfFIMDg9OhrQMtjAe3WEfpYS7pl5yHh7ZZ2CedEo/Wf/ygYTAQFI72AaUTrV\n", - "n47d9OSqAdYs7lkgV0864auRyPQeTKK1Sp3ADeIFS134VGBNG1VnrfyZuznYkI2r0FVkGFrAXpUu\n", - "ZJmyKqqILhJ1OTBM8C0VBV2QXBYa2aSn2jj9t40/wJJWc9IGAVR0vj/u+wFocjwf4QAAAZYBn7dq\n", - "Q38AeUc/pR5QUuADgu7/kKjYlIf8yn+MfKKvFMJ4eRJz/DRqteBIBJsZW3T3phi3NzuSw0zOvEhr\n", - "CHz7xEUteyaR+fa6YCBeiCtangbUerW/UGoCobzV/74XB/lXH53NcEw+6x9o3/ZgwG/7l4psK3P0\n", - "EqSwtCrcKAAv8Wi0Z88mFp3Sp19shMF41mqYa8pNsyefrruQONS60LHg/1GySbrTeTWW74lCDwnt\n", - "BGXpwghp/QF087PP7hxkE8lvu8APh5F1FTiOCBSvJFm6yFC/tz24gmveLoV4Rq/qtYWRE09VDCDH\n", - "yjftToPMsyi4DoCtXsPRk5Jxr9Mn6xDxGjfz8uMmOKJ15ejPi/Sx9cR1QrBsU9dhcYifdB+c0AMF\n", - "PolB3N4pBZAASP6m7EzaTer6yZ2sIKcQdlGt9xsZ0SHtS2313gpdJkLEVrHpO5/BTcfUTTcK1+bC\n", - "PwRYX+iIyInP1m6htprdy84ySZ5IaGCpRKFxMCf5w22wXyyon+dlMPKACguyEPTCCZQ2MqEuC+sa\n", - "uB/hAAABxUGbvEmoQWyZTAgj//61KoAXgR9s4tVmwJ9HTza3s57iAAoQf/wjqzjlXnP+29f12EfR\n", - "S7B+4I2epG2qM/uoQ7VlrfXFlhjyX/aTq0n55QXAKa2xUKolKsuMfmZFFc6+GP96b13JiSidvPgt\n", - "2SSGnq9Yw4MfceFmgOaZRcwoMnpdb0UpI73YdP+DfypKyrkDqKWcBc/BGhrH8+XdnpCNDXfg5rMl\n", - "b0uFlQ11yUxnDYOfRwLbdjJA6FYddawSEVorFtY7jkSQx+OUBUgWkKC9rhKB+uV/yqQsvbuFiyYV\n", - "MviBpsZgSSN0TOC5JedQ5H38ENVBLjXnWZD9PQyueLoT4qwtI+7lodFSnBG3zboWdj6P7XDbgKT/\n", - 
"zKkFObUjwhstiQtohzxd5AXhBH3DQqNv6mRzuMxFDcTEo5ut/0/1HrPGOF4R3sJ/eQT+YnYseqvc\n", - "0m5njpgI3qkLmn8efBB4q3zWGpHCxBwC84HKjuugMICuXfcJHKn0aWkn65aEjT8AdxDWE09InGyo\n", - "EM1wsU0JgJ/qq/6MdHWfQW6+bt5xWlpYJ4axi9wZc3Aoz+Rixn8UVM2e/bd31+W37ucz9udquxnL\n", - "2JdNUAAAARlBn9pFFSw7/wBZVXkLa/7xg9HEtDOpc+GkSv0gCD3x6eQNkROUaCyL6QH8m/0USPLW\n", - "nllgC+uXg2X8kUpaUiErsLvwKd9y+trtKwV7xlvkAn0JqEnToCvptE1Sb8eF86DTi2ywy7WE/imn\n", - "jNBYQny1cV38ScnZp/V3phWQAYBG3kUdNNuj/FyVB7DgbQbTLK48AO5nLYv8B3LvBNBfBJ+ym1yg\n", - "YJXKwjm8kt8xUjO2UGKeggZOs7YHWr5Fj8OX4jV/B3/cMzP+f6YyrayA/80F6f9vgrbTlhWdlFQ8\n", - "QtrHKjmrl874OSSPJYH5wfQfF/1NrQd6soxjmSWYI9/FqOPoy6ujUPxQvg1fUda+wK31Cv8gD96H\n", - "LPqpgQAAAXkBn/l0Q38AeBaU9hYCjxV6lA176iBcJKIHTfhwkqkAB+a0LmdvcgdK3vyEsSkCI+8U\n", - "up3OQ4OQId/B45+Mf5P4Fc2VsfnQAACxyzNkvgEEYwZk+TyOR6/VZmeFNYMrBdqc2NNBlh56ISK/\n", - "h5V9lagvsX7yv0p9Hk6RXo3uoMgKhKOv/QgBAqhUvAKDw4DS7G31tehd/myRMmCPxIJ79bZsQe2/\n", - "iq7Nquzc/VDpPXFZHPvOmiyfyrt6Fxc2jLHZJGpvacPTIeLJiSaBxgRTEKBr/xXaKQjc5nLhlwgc\n", - "HSz1WRlyOsXOkob3rY8KoGVETaaIvHEl7sVHsV3QN7iR2rIGzf6YHv+c3l8OW1b7tAMShtcCLifl\n", - "8k1OtS8Z5o7MNTObuLXIONSPGo1fC97qRzqHFEfMZntEMqsFjjWPM6JduvRiAv8p/h0kRdcTeRox\n", - "t4PEdFJikYgCJgtFa00LDpNvd6Vv6MImiivCAgL9L7zEaNCr8p/p5ZiDugAAAO8Bn/tqQ38AfAnX\n", - "r+Rl0wYAC9kEZglKr0YEZPxbFiynbDVLyUoB5/4mwbggJCKqWcWLXkOc702XkfuMANGy7OD7QUCV\n", - "nopFHkp77AuzGvvM2JQndhYVkdbX30/kmHQDID1DcpthKQBbzUjm7wgAOqbulxKDc1OUw1plN1OA\n", - "iXs8Ju+zQDtZelKPfekDEF5iPA8IQMn3LLocZ168PVHW73hdmgfMFTsqduJxZ1oiezDuUBPUKdNQ\n", - "1lGg5KUsS5A9iNuo+n1shJKCmk20FfXGeNEywAjYeaq4bao/dd8nZn//htlIayY083IymAgdHbKW\n", - "UQAAAW1Bm/5JqEFsmUwUTBH//rUqgBbB5O6qXkABRezeefAxp9PjwxeDBuTTFSUNk2voPSz0T3Lj\n", - "1K/LmQtEI6YkskJKgxvIXHGf8LHTV/h2Mg/qV3IQ4zvBygOQs98iZyR5jgV+hQ58R6xIcus/6y5a\n", - "HrkViRrv8Sk7So3LYWmfkLzyR6vcCKhF/sCJsY8RS8BK5OOGU2Ll4Qs1n4jPQwTLDELf8SF2+07z\n", - "zB5hexERnOHmWZ9THKXS8j6NXPrj2p32k0gvmlI4b/Of9evEX9mDBp5GtQHOvTswQ/VYUajAUXz4\n", - "5w6EHuB/k+FBz9pe+B69syJ2X5MYn7Qi9rKpCl2kZv4uAWXuNo7oIaU7hr6elcFz53tdL9AEjCAb\n", - "BlT3p448134hjvo9lj95CHF5teK1w+R310Gc3NQ0eeJcsiYD2EoVrHHjVDF/m8I8JtTUFdJ3xm+G\n", - "muADOcIpcqYbeqyKWwHmgvRze+DMQbkLo4AlgQAAAR4Bnh1qQ38AfBSmnoPKZzTuFWeZOcrkeWeU\n", - "yVIALsozlefbqRZf6f7w7fkPoFSkdlxkJJsnO6qzfbc/Kotbm2yeFrIQw5yspszQL8gAAvMHKSnw\n", - "f4CTQ2vfLY55MADj1baDD7LZtn0UK1Eh1HnwXobc+mdHd/JEl/a2Tszf/EZ9+J7oMl+BYsjWKwNY\n", - "vOv5flnnPLcex/hWFIF4n+hpBybvasl5hI9mV0CeAAyAclftj8N9n7hadcpM/TOVmHbSkJ3cr/k+\n", - "StSwI8gY9k3tmbMSZc42caMpFr6YdNCCIj52zmNBccPNFxW+UT/4qCqtX1gc2j7obKDaWzC1yj1A\n", - "td8/VAjqVn+FzuuEokhhvubRT3RCdxeWnBTCG0CxwC7gAAACMkGaAknhClJlMCCP//61KoAXgkIw\n", - "VJpvAgAqN7f+5rJJcY8tkjj7p4LozjswOy2dTydK33mOBGS+NojRzBOlwt3ro+/vdQIUTIVrXKwh\n", - "2SrHPCPJXQoCjJUPkRODCmqbZeBHsv1r7iIOZPpX66HYYhWgPLvPzAb/Nqu9nQqKoyphhNy32+S5\n", - "qAFvjRKLSjPAx7GoKGUNMbYduhsBsrvVTwhrV8uWAls2mxYggJzVuRUZSL9cSt+tjl44BXjlbo1a\n", - "I7ybNHG97GCzcbSNcg0RA+iqwDsdnrZCO0zsNdWK1qVmER0PsSf0dicSrZwIcxZWy6JbkwQn5TnO\n", - "kAah3wAs6pJvW+a5ZiJHl6sVlU3yCOlrECAESqWu0YR75WfiMXgesBOuXGGNsC3icmPYNzM93us1\n", - "7GQTI6RmmFHGo+B2yAB2YJiK1YN/T0ltUuXfFAvL4UdHgEXOVIqVj+S+YpITMKy740IvYQ5zuZPD\n", - "ahdXF7HIU7xE0W12w+6qkuyZwxUMXLXdgx6svudMor1GNfDCdymcKIidhuuXh7vdQrgbivH7usVC\n", - "zjMqgjGahkW1YlmytCooEIoULx5ux9DK360iAi4u/nAomESdiosanRfQ9jQdJSpo4rurLfeCLF1Z\n", - "XsQAQRTcezHlxp1tz3A3WsYMA9urPBB8pUlDdB63MfZDCBphVx/Ddv1AMvPXFEPu18oREsV3BdKx\n", - "e3lxLWWpytzF3zXttYGgBb90j9DgRGE1uaAWyEAAAAEiQZ4gRTRMO/8AWVV6uU/hFqUNYqrP23yu\n", - "FpB+ECoAQNVnJ92i7ZF1i7u1D6K4L4gxm2RaiGsRDmf2iYWEjO8yGHAqwpcDep1/+H221WMh98AE\n", - 
"VV9Ferf+hy0D7Zu5rX4Hp3s1TpcNcEBIKPHVSHIzaZKKfPXkqE/ga/eepp8Bzdc39OW6g91hVVvf\n", - "WJxrnf77rapWbmivuJFfeO9u+RRykk/agdEi5E/5a475KGQprA2yl390PNrCvoamPyXbETwtbYAQ\n", - "pF9uDZkHdN/NQ1P4rz+zQLJx21eQsP9WBLswpDFYg9BjPw+3VrVEzeid2j5wJBlq+56Hw+Ex6fI6\n", - "1O0GbWSAC5/5Zg+kGX0Yx7/We9PseMWGwXWIVwqI7oHPEnK6wUkAAADgAZ5fdEN/AHk02mburIzA\n", - "1V5U+8CauxZABexQ9zxvy3GIkNn2+19EyZqnRm0DMMsXP4ZwiY8vW/qdBTlATfbmIFDxCTzt76+L\n", - "X3WaNfG+rqTfzj6gLFFHl5IJDtQmIC9KAmTgQM0Lp8TEDdYJnPYGFybq0Xdyl74+130DteV0SYTD\n", - "hgB6230zJvCx8ZW04pZHmYvtJ1LZAxF3BAWKPXcstkh7/Er8zYdPblR7K6t0r3b/sIHpME53VRBk\n", - "ggj1uN/p+iN4KwToxjP8kZ1opB7xpkyOQpicygiGnwjU7EpZpywAAAF2AZ5BakN/AIdka2Wer/IA\n", - "EJVZr+9KNmiS7zXHA/5uJU6D0CbJOrsLPWcfwAUCZZjhlCsnAlgzrrGOONmuxU3En1TfTKb/7Pu5\n", - "1R8PfIYkV/dZFitvMyRPMvzwXX1OcxtjbhM+M0LCh6zNEWJFi2Pi95t8cspIknD4iXNUblA3oEFp\n", - "VGuXt+8S3Upf64YqAxWADhb5zxXL+O/gnWiyawM9fyRrYcExecMkEiv5MHRsJs8Euzdps1vwxzNA\n", - "Zu4bu6ic2K2ueNja78qXGaHz7xLoPIVJv/T4KAuseyOhznfFtKf0Ey0eSBVK9qutGGF83lfe5Wtv\n", - "xb73lHTKLAyiyJassoDHBSQLAcUPb4nB6xWNr9G9gWtqEIp4Or9tKJzZIZ1tnIKZFZGb0ELAlV2+\n", - "pKKDz5nW+syHi871Soc3HtgomT3Y1cp83yQG1GdKkcJPkU1uJVzsVPzbXbSU7/z2Q7cikc4seN2D\n", - "ryQ1l58HjUs0ikCXV/V/CDkAAAH6QZpGSahBaJlMCCP//rUqgBbmS0XBN5gNQAaCJTjyhVwVkMwl\n", - "GF6KXnd0XUyzqjFCJEv0D2xQiJu8if6sKo6qHl+BP/MZw8ss5OKq407INzCjWOsjf2HTKyC5fNLK\n", - "wiJv+PzieOozn64ZK7RRud2QUaDe0kuhk4uCClSYQBImrxmWeEf/X9zH3+ilYhfoZigVm0IoMiuu\n", - "YX1ERVdg0Ld9E6wxbYMiQAGJU1qeeTwc8vb3w3kiJheTA2PNXtrJ98RwtpnhN6QxMe1dw+aQWI7S\n", - "j0oQ9iNx73N93RuNVRxXj/57S9VltjA0RTZBjLvYS81QDA3fBgaNHNzOBZ7dztz/rTxxOpumjTTw\n", - "x9FgnvlMsjx7FYPKUcXD5quVKd8lwTlOiGVI7X1HEv3Hh4EvpYVt6azhUBI1qGunVb3X1lyMhWJ9\n", - "p3muqcicwInEt+BuHY92HoNXaaJJbbQmNX5s3QJbI28Pg4gc2gaUF4SQRcBgM8uwcYUzxEkBS06L\n", - "0moZm8bwMsLYCLj3fgXOyFudpfg6jkYPDeVK811WbzEz8Hcd42XVL0EwE3bwDc+i2I4+NERo6J6l\n", - "d4d7nOIvqUuorZnDPtlYcfSWgBqdP0tQHvFb4Sv9QUCBvXlH2IEiNzo/daaHVtbFRNZ3cag2HOiP\n", - "lMxyt8xYJMnG7di2JiwAAAD7QZ5kRREsO/8AVwwP3fRRACC0tQoY45xe6yfL8KMHlR1wbd4HcPUC\n", - "+4PcnqOzdoNv80ufRyOopFYryJahX+qWFUVKK+nDtdvegTv/PqvENcT8ykEwwQ7z2oNUdaMITYi5\n", - "4tC5YA9FaLSBorMGx3aocAbiF8065MBqyaTkiW7FtGRHVSPubGixAl7hiQRoBoEipfCxkE/EBoII\n", - "omSCNrFRyjd8oY66cDfZt+iBI44uLDeP6eHMEpBALsV0FY7iWjBLaYO1t2PsklOb93SAExoyIX1I\n", - "TiPXiUgrCYe7dgepAF31BCnOuxiIAPWKLDHZLhGOJBLqdemk1EZoKCEAAAE5AZ6DdEN/AIteG4cJ\n", - "hGXgWAAHNd3/IaNiUh/zKhTXYgf+UKkbUvWJoLo7whMXByWkvy3MotNcPaSHeaKS5vKy/hBJIgk5\n", - "CWcdsbd5QzFHyjOIZiaEAA1AziqRPTDRRVYKhcrm181rAlAdaYmvKZAOu92pmI39/PSQjhiMouSe\n", - "XVT3pg0s+/zN7WMQCHqTmey2TTctwD0YnAH9CK4EMAw1jPCCTXgop9epuL/iXjup2S+LS3pGE3iO\n", - "oIHon+1ERGRC2Vp3b2QAstSXzK/2zI+bVnxf0PhgKqa/NeuEaF2SBGZ/TyqGPDnQfJRorCp1s+mw\n", - "tm/3aVbjKRTXeSwl+OCfF6rMqjf/Zw8/4yrjLNmiyOgD8OWqATkM50NFqOShrrTCaHdcxgVW70ss\n", - "cCXKxvzAUCe+4nK4C3zP8QAAAWMBnoVqQ38Ai2Rc7ISR6q0L0pberS7nbElvP1eAuajd6ehFPCEk\n", - "va4007gA4DkP0YAYAumNCN0kma3A2DvFPa+NTDmrilkXNhiNVTFRLzynsy8rdgQPBH6k5DFr/4eZ\n", - "jmJjfYPWB5+2eEYYc9uJ5Ni70hsVFfV+T8zp+ZkLZnd2wv7AZ7A8baF9R5O9oQlCkoVPxkDHTrmt\n", - "rElQhX8Fi0yj2+BVP5O9UNPGQU0+M3KYUTg9yTBG2cCw6Drt49/5M/86NN03F5R9JS9KGOfJjIlA\n", - "koCavGpTFqq7OYU0RM3ilfXBmxvL5QoIK28Uvs71J3h/IvKmg4v/14n3/eoSpqNUCC77ty2SgAAi\n", - "rxQNIHz2GF/lpTynlwsORrYNT1lJMVud8AAQb+/SaHWQXmhJ+8cZTt8XuMgG/t/hdF6GqyG0A/Pn\n", - "hWRq+asN+zBaeyQUWZrjl8ry0h3WPkAZksFb/gV7ABWxAAAB/0GaikmoQWyZTAgj//61KoAWw9mB\n", - "34Nmlq4DQoTYIkneVdOFHxDDrFwsv7yxZXXwNkGuLMduj7QGT/7lr2bNfzApMJfo9/ffM5g789Cz\n", - "1Mn0zxePHMHBL6IHHRVXWyqDMhVLYnQ9xFtc1jml18If/8STBCOf+AZjMnARcFmX1IwLt/ziVSoN\n", - 
"e4GPKKZqfZWytoW7461OuaeZ9dvtxrCL+W45zobgR5vOrVM+Opl+w/eFlupHlgpQBWgJcPy8sZC4\n", - "/O9laiYA63xx6M701UUvGFsRI+RM6anXyjKc7TVrmZ/YQKRjqB6Mejs2G1mTDkBn7T2ZURI2vZ3u\n", - "VXRNsQnGYDxRUokS3YRHs9LEF/gxKSdLEEiHDqcoIHyS2FPM+cIJRSvB7sxIA3hgfN/O4qDK6VO+\n", - "t71oi1H0Bkz1ugONnVTpQr+WeMS5AtXXNBMXU+ycO0+R9eRe9BwSk0V6tHm/HJ45oIYvyWTj3yZa\n", - "JQ6q+o4isbf26PsTbuSAcvQoMnzEXJkqElGJ8Z3rZtdkIzQW0DDnXeNRbj2wQmuUNBknMsWOw2/t\n", - "fD8BErzYLXI65PwTY+6R5c6RWYzF9HNMLBaO1c6cI4yEu1DMKtZW5FrmVuc6hg7VnWxgAgOdFKFA\n", - "QvmmcrbHsqCH4rkez1y5GoMlxeOuW5WKa/JdcefAflYgakEAAAEQQZ6oRRUsO/8AZUEtmg0dqwLy\n", - "ubLYtABfXw0ri+bvSnwBqWW9hB3/jYP94x5LyZNY560IvuBe5T4EX3/71Gbqj7BS5SJLQ7X1JK0z\n", - "I9iR6McwRU2BDEhu+2JQm1RA2fBVxnzCyNr1JVnfyyuumlkNzE8n1UgnkIbS/FMxc8DghB7zqZzK\n", - "rkagW0hHwSjNf+LJf3DnbXyvnzmB1lcv8Z9QlsnPKDef2giSgbZeTNWRMfeu91kckRy0SSKkaYVK\n", - "KUUpf450Vl2TzPLRaNhk7Du1IJzIJRf9supxssXD9v31LAVibgyznyLU/cS57Vr8KEXG+WpKysV+\n", - "6iQmQ/hCoRg82drzuniAPltxm8MMUZwVMGAAAAEzAZ7HdEN/AHUKF3WsfCAA7NAZyuGlRySXJzA8\n", - "WtPYIqCp+udF6BaVoG3w794kSqeP3syNbVlr+uFhruNMOOzTsNGrbATFZMl9DU6mhIXZ1HEAskmI\n", - "VVSgXlz4sVX35JqYrDPP8r9Bsg/O9tAp7LnTMjWlqOdgOPhHpyqf/hmokPsCwqtKfsDhxP/tmX60\n", - "fhM4KsfvpygzK8jmUmY/GDBCISRQeW6U8uaq8guf+cvy+sP09JLJ4HsULhIsm6kyYO04HBdOFUDr\n", - "/8IzlOKX3w/FCxhimlJIduY8iySAFQmALOuag1Ry1Z3p7NpGIGhZp/q5hzsMAsH2jpHXQPdtFNFH\n", - "4VkqDlRDeGqieCr6gwu3hPQQfF9yauq4qf5R+bfPha9tZ3XjpRO4eqNaj2xEQrcb5cIJOAAAAUsB\n", - "nslqQ38Aj1e+ZhXsJE07lvgA5ryx/X3Tt1hQ2T/wP93u+Km2fQtCsS47kHT/v+BMMbdxEWzwYvcd\n", - "d3NYalS7o/aUthPBRfYGmx2hUIQijLOXN4leC3SONeoCputIRor3Lgsy985K8UL4nvf1+pFmRQg0\n", - "eJgJ9ubt7jVqU4S6enDDZ82+hYwxDWOROomkxsOv8nlizRgAHHE1n42Dq5sLIu8oVYp/4M1h4rCy\n", - "m7AmDrR9dbHlpV6pqPLshIJSKr7R6XCF5H/mgt+78ttEoS2XxbrmVQj6DQtTzcYF1gqzE9DaiXTc\n", - "rKcf1aBAFclenBiNHhbAMEE20Br4FIkr51a0ynzJocMgaUhstOH+7gKJGCsTPkykOiVzQeIGOfi6\n", - "AmLkbzIds0NOnV21ExFbxIFAMu1BymG8Kjwvo1cLb7372R2f+Qt5Z8LjmGrBAAABxUGazkmoQWyZ\n", - "TAgj//61KoAWP/AeMmkxh4qDG8hcZFMZjYIY//v8PGtlbWZ+A0oGGFPTAdgmU2TFbrR0QmwUCouN\n", - "e8fq+V7LhZ4IhSGjAEZXRALCc6lvXQaVk4Hy29vGup69bTfpCSIWWGXFW7WfQjL50GRbZZRZHQ2m\n", - "pjAJ2N9/bloCCNQEfrVxCeDkKfJqKlRpIdnOUaiQpsnEysqkLqMfxaCLAtiv1vFXcLPLizzlMPs7\n", - "NIiiAuhD4+CMokPsODEut5yq6fM1zRym2P9iids6rfyvN0EtWlvUXkAIdmS8HfE5DlX5rtipWZ2i\n", - "d9rb+tQcwCfWN6erokI6tARQJu2c+ZSF/sI7qofDkfNVCHii2Msza0cnJEbLkEfdF+gBET2KrdRv\n", - "E5mgO+6ICEAI6O/h7r7DxvTQ9Wxzo3mHNo6898yojVZYUAEyiEUBn5+alz6XfA0d5GcOXFRjv906\n", - "SVSt5h/ZyjXd+HmcrubYPlDuxhjCrkqyrKcbhfJHp/Mq+DI065H9OXdNO/+uDSHvPcKkibqiAVhI\n", - "DqTA+NZM5+PbtXMsqU6iKpSzqr3AN5mBITP84n9JoTkmCR2U/+5h8eajZc3UcAAAAOdBnuxFFSw7\n", - "/wBlSP3uCsGGoV8bqfG+TF6JTvUuRSAD4pZzJUFnxrFOJYnshFJtjPOw7rAcguf7FPJIlPqbN5qs\n", - "fqCPl7TU74m2w4/OJHMnDpS1+crxo620hZORUqqaN/UeMSuSm/KKx2/MSsIgkvOy0fYS1MAD67Fk\n", - "Z5FUhBYQOPZatG+Xc3Icj+kvLjp5v9fX+nJsaNN4CCl0quEK1R//8eZO87p6DKKxlnRfV62uCNE9\n", - "o2MWYwf9qwHYbtyqG6I4xWPTngQnrsOmiw1Sy0bIvHiKKw6nsCsKdLVPqCFU/q5rppy8Ah4AAAIT\n", - "AZ8LdEN/AI9CIO0JMMhrV/0AB0HLuqwUdobO4BdVbPV1Ioua5WZC0IWTaPE/7qAFTCgAnl3rAoSn\n", - "Kk1336t4zGyyPYAAOSIcqQwF8zee7dn7XFk1tvgy6W/qOMTmkEiEdwceoRsnhNmrNp/TK9OoMIUg\n", - "ShyIuwXG8nP6tDCpAEYSuvpzo5kchXf9jICMUEGqQZjLulIdzbNUEecLTDRk1r3gpdToPPcXdXTM\n", - "AElxf3acmkXSo1kx4tBmKJrXm4kNQ2oDIaqLOc1dGZ+ccoProxsI+jQiCldj17rGF1/E4alcIa3L\n", - "dIofRLGOPkev2msNj9eN+tELiQktxoUq9fKnDsRx9Nbc5IkysRYA/KsIu02gpfPyisLPQwjLSjpr\n", - "jTxnZViCfPC6UCMSLVKUvso8AB0eV8Q+lldoHmqd+EeBeeJOkPU3vuU/GQacMWsLnKmVt/65Nw0r\n", - "y1AnL9+YKkDmvNgpqgQANfZvj5NhddHche/p4la1cXWhY3W/jmtWxMTkOC4tX16bao5sNwcVWRvt\n", - 
"UHjkDIOIXB+3akBV5Lzaef6YjjT1MeUeFh/FB0tOMV3Bhvdw35krP/ItZ1RF5hRCk1oYqz0ykGZW\n", - "YkciBlvCsweWM2wXwX55h7SZHtxiKM3rO4Aff+TOWGbe8hXaapPE+4wKof+j5KoQ530gP62KsQIG\n", - "BV49pf0LYkAEd7yVzO9dhYYFAAAA+QGfDWpDfwCPWoxxjdaiaFtca/OwfG9dSAC6jYuqYuZmzKSC\n", - "kzbTtnf9idy9v7frgKuFjQymibohZCHRXBQdujo9Laqcw233I4Za+//Mdf06kxHe/IBTsCsxcSfV\n", - "ksVUEdqCe9dEwWwg//4Ee8Le2gLXqz21e4jiFyBOjP5GsM1hpupcfwZtr5Mo/ou28BY4QZExXJ0H\n", - "FzCqK0jKq6c//ut1tsd+kiOyZUVGRAFVkS8bi0vvjrj3zga9Zaa6Mt7yQii43DdcrobbVIWdc0QI\n", - "3+rsc8fgmOnJ+GJGdWYzpFLd5zMjS5ofw5IMBt0GmHVcG82Z6YQkqKJHzQAAAe9BmxJJqEFsmUwI\n", - "I//+tSqAFjc3NgONUfiwAKbp/vtZn3NtK6t0V/4sA0MV4unWIJlE1N72EjQeUPmvxOpceaVXIrAK\n", - "21oMRdsBwM4wyEJDPiji6fXmMlmmsCvOtr78Aj8gA+xKnVDFjoVlH7PPNvnMo0iZJruZeFy1B4T9\n", - "/2iVnlLy1r3LZhoykeyNXqaKEANWeqYl2HjpH92g+fHSONko5D2m4SRKJwFWFllUBg2RTQ3etVYS\n", - "PdQGNCLeaZwhH8zjnIe5Vuu46VBC79Le/PF0x5A18FileZQS8Adcvcamp8leUQ9dML537b7ARaSt\n", - "9Lyu3Sdke9BouNe3+hTyxzxAi1Setn//aNMjVtdKZIT0wLvPIMCsfe3gvhpNMtez9cWJYRUO4qU0\n", - "Dlg6h/pUIog+BzidDDvn6SZ9WUgEXhGZOFeOBYowQfwTGI3ac1V8O93aTpJwa/om7scQbOrwAjjK\n", - "gaYt9yqViBt3FWYRIoJJGYqmGJkf0tLvcymA+Hyayho8kg3J33tLzi7Gkd8xVzsn0AbjvoJ9u5le\n", - "OKsB4L1kcStddnytXouu9GStBCQSRLPeb+iGeZTwQ5uYY8D5fTAcb3C6Ob+B7IWRbbytzq93Kz0y\n", - "yYvbeUq1qJCNW3/zJeXeH+8yV69x5FRyM+55j6UAAAEdQZ8wRRUsO/8AYsUcQvOGOSSADI46r94B\n", - "/W+PEO3biH5wUahFid/4E5wZcJb1S+5KPsyD0qQEL2HibG5BPsDLysut2eDJfU6ijjP6zrYmNEWR\n", - "huQfgh9NsMVuoggiphkYt9ccXxVhYHn++9K8YAnkm28Kzp0jUWHgD2VeIoDjCfJPNnBqH+CERm3s\n", - "nubUQ9LmttVf/+MNJAJgtOFW5A6IBAcBpJtd5kPS+zJ8VxzguhOiD6Pf/zfgjMDUsehmT57QUanw\n", - "gbdNgBf1mSXZw3Czfs4swXmaj+42V39PQblTRJ5hVxxBfyBMHdtD+eP+pUlQP8pBAAnf3v75+Q0T\n", - "L19oeS5dx79IIwiodA3vtFf2KOiU2gODZqY3kJGizWNAAAAA3AGfT3RDfwB2j3tYlaKo3hdLneRM\n", - "Dlhayh8NourV4B4kYRi+kgAOdUf8hAGAI5XCPTeroAwXn8G2yGEphnv3FPeZqmLNmvgLgUkPciaQ\n", - "A3x0WVLvMk+lZn6cJdklOXHEnjNKsClw6wU0RbMDBk1zQUzYb/75rZ2h0N0KqL096XGATDutyhUZ\n", - "RVkyTgfbEgHdPAmzdroStgpcOUEN4xVVZX2E+XrryGs2/tIi+iUaglsBszkGSHUeEuoEpHc8PRHH\n", - "tDc+6s5rO2oABm+Gux/PUd+4yoXEBbF4DtdMIooAAAHGAZ9RakN/AHaNgkMVTymoPnXABzXUf7nM\n", - "R8KlDfCSlxubwbY5y13VVoGV2GO0t+vExf+APmeqLrIGM9X5aCQgGSaQJX4OQoECqyNRzFZQDLhW\n", - "KA4dfYJp7oYRPF8AMOzGYqm7AO7w7FtM2J0yD1XqM3LrKYS1dGZTAzMM0YXyhFuS7+8HWwRTCnl1\n", - "B1MtLMYaA8qvJY/AATH13D2takXBcx78I1sCsI+P57X6Q2Nh62/bggQuV3uhAAN0tyrIgbNQYVBH\n", - "gFwoUmXrxaEApAv0P2E40tM9SJDDcZe8DyE7ljCyxGjQA+gKJHzTkZCCQsmlxDg5It6wsdQ6cusN\n", - "DyWnlyoq3MMo7ugMYcm1YMEY73l36Y/R5wo4wUzuNvV2tJ3rSYBCfXsVjc5o1oA8OllKUpgpBG5u\n", - "9AavXOqCqjA07sUF9WlQ9JPrhiXa9bThYRp0lNBazKKlKwsBPK9zJ1/OayuptCCUOtFLyDYWpp2k\n", - "qNXWH8r0IpnJjxnQFcNmI3LKk+rH0vqX+48vd2BUqTcJ4rwX4e+V6oU1+lJyU8fmS4Kj/iQFUx5A\n", - "ntiGKLVWwqfkoYN2YexrEPVBTpKi81wf61aU8NAxYQAAAjdBm1ZJqEFsmUwII//+tSqAFj3B7fR5\n", - "G4ADaQx//3+BfZIcqzxSrotcVc8CLm7cBBc8JifUTg3KyGbsl0UtvUGR3t77PRffuzjjVfcKeiAp\n", - "EmDpLoqmMXTQU5wmHksjapt36fasfEiGyN1dOKyOI9nT0TFFL0pzQSss7Ux5GajOaQUF29zSIoeo\n", - "7hOusjWiFyZylISVuEBU8nCgDYn9P601XpFko2u3FAuYp/svCLJOzc9W7b14FY05eVZdhfmiv0Wm\n", - "d+i5ZPIv9mhB+8Cb50V0LQeFfsyfPeAABtfp/HIPaN+amWONE9vQ2YbC1JsqKljPbi6Vrd258gHB\n", - "PNyXvESqATfkK1Gnk0AWxo7XFr5y0Ce95pJr1n6gAd91M5RV5lL/XAgE7sYG4524aA+cXAa2XPdd\n", - "1BugfbN6YGWbktwAoVIXoUq7TnrmhBrw2FHa1aE9uMJerl9x/Rs847iKP+iuBUD2VIUOVa/G9Po0\n", - "ksPo1bHVIsITIKnrhXV1NabDgHAc5kIv+PJk6IroGA19oMw2I1d4rGiaYQZE9dmK1VRARJ9VXDBJ\n", - "Vlz3aoQhCyQZvwzvxWhVA1iU1RO1TWnJsppajNeO4Vg4/b+BSviIvrSwwqmjaRr8iuCpVTgz+ZJ6\n", - "95zLiSdnoIFqQJA1Hz4YR/KIOmAfhTTnHcdDelso1m8Bx2oHlzAOiYwR4NhSSRD6EhhCU2kXf5vn\n", - 
"vYdShk1Y3/pp+Wd9yZwIwTneJB0AoI0bbmfrtbbWj1oAAAFQQZ90RRUsO/8AVxVRwqizyog1fzvw\n", - "w3oFk0s5kH60rPhj0qbUv+9nJnU5H1hbksC+yivmpdt3FAylOp/Re8NoooEKQr4q7MX/kjNCB5zj\n", - "aCmG5E3TxVGWGCYMCsdEF1I+HuXX2a3wLCwf1iqCfznNMRG46GE6nIgxc91oY/zfMduLLCzyb8AQ\n", - "b20W2eRODsXd4+7XC1RndLreJ7Km543AdL1iUo99hYdoASXjyWRNv6wvJrmyFngIDlQOrLluZf/9\n", - "T8Y21pcggXpfTtvdj+B+3lZv29AFHkL2xGPZvyL4UyVUgb3U1DWd/iySeGzlK1IbRNu7obP1czi4\n", - "Rchm1nI/pS+cSuamJbhlQHIreF0u2/zcrSGkuOpbObSfAY//5j6RVfcQovw5wL1RQN0tcA1GtFxu\n", - "ZpovaLthGUkeOPh8iV5bEpupJR1R79Ew1sEkTDugAAABwQGfk3RDfwB2dNpntdq7wHtHkfExb8Mi\n", - "4AOIW+6weDVD4WeLhja/JOA5FtORnuW7CfHWfWrXcPJWyNJJfpx2maEKeggtR3RVEAdA1a1truYO\n", - "N3PBvt2C5hri51AyWveiUQtRNh8OhcT8b+NVPo5dLHlfN2wr8ZipKDuUP3k1md+EiPqVCrK5TuMQ\n", - "knvfHHEV8fXqrrFiHhWYrAGbSJdOrXgrQTN4JDv0LMwXs1Nl1nmEdfSgT5BF3DohYi4r2xGfiJcJ\n", - "KMZ1oPHaRBjgxhu40ZP5HqUG5rQWHD92UCH/Terh0cf4e0554mxHgDF9CBXD2Ey6LaV8LB9Jb9nA\n", - "f7tFFMQRIVaLiP+uig+B5OoeaCY5+GdEeHuY+ZE9jNToZ4yOUwNfysZaXJBrtfqEkQosI3EYRZQA\n", - "COu9BHjZjXsKjEmWe9Jj9yWusbXq4WMANyEJEPNSeDcqy2nLsc2OqSE4CgyCqy8blbRZqycUiZt/\n", - "3NpFflI5dk/7eeQ8Uo727U5FhceNm/3Tv/0N3CZNlPGV4f+3/HHJknpIjibzMw4AkTq3Lkxy1XZ+\n", - "FA9yAR3cZ0/eN1EscyudULe5dTvs1EvlYMWBAAABtgGflWpDfwB5Rz+lHWcxYALocP/IVGxKQ/5l\n", - "P8Y+UVeKYTw8iTn+GjVV8vbhgCZ5cI/70wvHdrfJYaZZyRIawh8+61+/vwo8HAkEyAQL0QVrU8Db\n", - "Z7+ORIRATWUQyS/LIyP8q4/O5rf7OuybqgrrJ5JQm3dvb5EYgnYLHCULt4xtpfvTsT5gEynxu9HL\n", - "Km20sO4q1oqcF4MPx2dj7xETa3veUfVJqfvwop/9NWsmPrdhY/wz7rinYt2HcWm7+ulSBZtWIRv3\n", - "yMRoNM+lyCvZDr0PaN2HfwYWOYr/NgyLM3qvI6TujkJkGWBIPuiFK/SHsSPx7iAMcrZ3CQvQC1rq\n", - "psLEx1Lx0vtWsdQAcjEYe6l7VHqUFbgcjcHAYPQIIgi8NauIxLhxUOQnkJo1mXO/e5w2N9AAHA22\n", - "RlXXsFU92TGe3GmYdLlI4OC3IklyabPhxs95veQzY6n0a2BnyANXxWrQG1vVVVAYgtb88NEdo6By\n", - "gCh1aEE1VpUTP0of4shaZpNk/2gd6T34r4uIClLqdADAAdaA4/epPc357p2Ro8OkrT9okATGaQDM\n", - "AYBiPC2kAQBkyn5ImAAAAdBBm5pJqEFsmUwII//+tSqAF4In0o7iUdIU6DQAMu59v/f4eNbK2my3\n", - "LFfU4bVvmOXvurgANJp+yhdNshfKZWyf1yiq02eNo25TtXkBg+c9UZquU5KtxkSr2wTyRJb5fWbg\n", - "+NL8Fosje7XYkSxYEiB3sVwPhHSvNWh2d4v6fN1lP9qvuUnfb1Bn+TdruqmJdM2vx9efbO5Th2CP\n", - "KiH3jeuRzoCzSIUG7cY38FVzT4nUIJdz+2KjjjJ0E7ZNKQ6lROaPqjFN4utrXaZfqGFX2nWmlL+h\n", - "PxS7plcEcSC1oWpbRWphWgodqD5c2VmFV0yO9NkxWYeDoEeaPVORAB/gqWAbIHdoZVHMBBV6fLyv\n", - "D3u5FppjGB4tzB+WC5jnXJKg0Sk3SkInESay6cwWUVJt/G4Tfg6wbMdEkCvCKlRosg/RTpp5P6wR\n", - "Z2iZfctuN2EQi36vtriULh4PVI/bw9ZXWlyhMpAYPlW3C1NvZrlJMNaSqGSSnh5cJMfrxHquXcAN\n", - "CTgojRhZ3tMe14Ny/HV3UfnpEJgrqxN8KZxlRpYS28Q96uqEu6NBBsBIIz0ei/Mg1x57c0aguL4j\n", - "dVBDXATm12Zi0uXfiRBRiIror0O2CDrlUQAAAPNBn7hFFSw7/wBgSQL3wIE2Tv5B6OJXPcoXMcSb\n", - "cE8qv/1v/uy5HaAJNUQCTSWlcVovOwe/GLZOdN2BNEgb1OlzNEinzyASzg3GuZ9zFeyJHe/zvxXW\n", - "qHgQlhmuH8QdE1M1s5tXy5mwAyoAiCrzupaN60ez6jWL/yRvGdGiPt3qJJLeMG60zAMKa7QhUJFJ\n", - "FMWUFrcLW6iQXx7VTZR7Qo0gz/aCe+BxT2h34J4bdpQTH59SHjOd2X4DMr2kpW5buE3EQBEKSUD8\n", - "yEiNy7MVRtsZHXt1V4Pb6TljTGXtC9pzGwEXtgadiRP8dhtDjxgpVN3IyoEAAAFOAZ/XdEN/AHkx\n", - "u7J3fsEfo6cXtbkNOd4swcOB3voAJyKHu0c0/MGiiYXv+2wca3XUwSOEG+s8df2rHPxj/J/Armyt\n", - "j86AAAWOWZsl8AgjGF9fWv1mQf9jrWNuA4APvfeLBFbZJZm7otp6Fc0DFqB0XCbEvLTkRU5ySc7e\n", - "Y4CD3ziWyxgWkLgxNxAV0V3rzOqUGhFxcTbBCJI75knYyulzgB9+SazwgLVSR2N8nND844Y7GLCN\n", - "0aeRWZgNIAWJkPPhP1VnSRo1jOpV+axgAXL8ExpNwIvLk+O8lekZ0/1o7sI+uJ46XyI2SuA6uJHd\n", - "bwUKNMI2qDKAM6f4kKlJLSQWqzXAi8hAQzI017i25Vpi5npQJ4TsJeyOHRvmO1wY5ZnIEZHyhgB4\n", - "IoLWrdA5opbAou9XxH6m1F6osqepeJLd97Dr7+5BqWzoHoOLhOxNwAAAAQ4Bn9lqQ38Ah1fDGltb\n", - "SoFNBABy4LNe514R+dnaDTYn5E46OmsRrJgYyAm1lSXdflAXI1+CFQXE0A4eKb0poyZSLaaXfRBJ\n", - 
"r/tA3jW8xYt/UxFDszVrqnPHP/Ny6pw3mJ+pwWr+YYAHxNaLyZj85nxRNPFMUkOr96iCB+MslYrg\n", - "cr/vUoZCrrFka9nw08yFJlyN4Ky9KHUYJOXDrBIiz8KQQaHFalCe3rENKk9raHLB9E2PdI37xydW\n", - "9R3Ktqa3KW5rMJCOoArO2/3trkkCh+/FDlbsei4VdbDQ32DjCaAkDFjCyuqOJNsi8nSI2KDSRFCB\n", - "83l81kCObhPemVMTlMBQzSDvOtDFUtuVwHtirD8AAAFqQZvcSahBbJlMFEwR//61KoAWweTusUEY\n", - "AFR7WLigAceU/KgvW9LBBRTRioW652v1Xpv5tYMFhkRmmlUca4/8lM9NJwOZFgbdLq3dhRjr1SQ+\n", - "iitgTnIKVe77qt/yWy3INzcVxffYfGucVy2ypyvLSUZVvVzu37Ufe4d1uKQAC1EE3Wwzkx7sEK4N\n", - "QwJyCdTZZnLiyrlEXcLAMbB36CvMtmCiaP8XPpa1U2RaJxnBB9qYeP0+JCORflaC8m/hyWfMppd0\n", - "XeCFuAYTEakC9vO4HVF02QH4GZZigg7j7bXnvstEtP5QgYZViZcOoAaQGKtWm3PCHoS8mKWfCUk8\n", - "ZLC6z2a10V0U2DavVH2m02W1Lc4/2WzrwUTHr66DOaP+urnPdabeHdXruv1HJ087InGSipJtxGko\n", - "4rppNbdlP4z6g2o/ksCKcSZ76uS1diKM/39wzVYDu1tkCD1lomve9NoQwUToKqCn30PDqMAAAAEr\n", - "AZ/7akN/AIdka2XuDkeawxOj/BZhZtP+kNbRABb4RmWT8vSOMSH2HVKuz5/n3pn38gQM6YQqY5bV\n", - "v8KsLMWKt//3BpX7BUiSjA/GsXEpiGachc2o+KqjjRfujy3SLc+TvzNfgePwT9w0Jj9Y8j6ORxA7\n", - "13x9/iM5Lx1s2OQQyRluiOYKxXDE9QjNulPCcMLJFKpvAfnZmzl0pzzHw/ANcBEDhABHQ9ftCkUs\n", - "Q4pQOQF20mJ1++bXoRcUz/lR79ACwohpzpGuaQCknCVhUL3lnnyQzloB0PAIRq1VnOd+y8D18t8/\n", - "IEva3L9FTrRi90eT/2pNxjMaqrOmFzrhjd2kmSd3YBlll+A3KrjDn/HtXx8SDjztM7Km7BEd2LVO\n", - "U1pVGn0+C8gCov9gxoEAAAIMQZvgSeEKUmUwII///rUqgBet471BV4xl2QAFRvb+6Uilj9hVaCt9\n", - "oXOXB19FM5G4bNDJAOl9w7HrxMOF2dPOUf977Rp9NoBObCR9cN42Ht77Y+l36qfp5SrWPFz3DG9k\n", - "Uks1s5yfRvMME5RxPYk9+qohbe5TR7z2WNWBJjaTvhnu4485WU3BaTyIbA4BRRdj0/JwsbCXRVZy\n", - "OMmFdXnFdxhNGZ5JMCQy+ip435WTv8KevLzG3OUTxX5d8x0gaiQZdaPwNC9GVrgmtqTc0z7He5Hx\n", - "p/UnXiE+WgHU095CwXga4AbeOtQbj0tjxKUoS9sAoJ5fyTlHv9FnU0ujgUuoA3Kj0ma5qF69zgnv\n", - "MTXEIqf8zuYuInk435YB6s5Aa1W77q49/ZLR70JdKU9F42nWnuaGIFvaX8JNp0NTGvA0s1VSOWIl\n", - "YVdpY6hSPbDqLYXO/LE7X1D3sWpexh+/kcA2B6pYDzx14bD7OD1f9pMDWxIrW6BpNH75M54gOMY1\n", - "SxoTsfh6KVoyFK4Yqd6lPKCLY4O17tm0vzqLEva8zNeuM7b2yHKwMHpqK8FV5yaEer9Zd+uSgIqd\n", - "eftECExc0GDPrda1mDLPyRR8iDjZRvRS/EElnceTaWiUEonB934ThxItQqnJINdKSyNdNwx44Jgq\n", - "H9/Zh55FLA3sdVDr+1aesKMfNmYnbwaje7GN0y0AAAENQZ4eRTRMO/8AYEUc98FD5/CYkGD6VZTK\n", - "7qaMD8JeD5Yvz1s+LaCSFWcn3aLtkXWLu76WBTjEp2boTz2lISGgYIiIhTqGBdSAvn4GaApcqQ2+\n", - "sy0LjwIg9aZXDdjP9AWFTV1H8wY3dWCf+Rn8X8p7dsAFRxXZ4015PG0t6STtIq5DOqARSPJ32oCq\n", - "OenP2L2rQhT0bU7kBXZqDOvuedMFko4K8dbR3EOKtstAjt1gHGNubjQIVeNhJsdrdMtXEY7juX3P\n", - "NuPteAILXrR8S3R5mIOtuZ+vWEUdS+Inr7FnZsbQiIv9i7KDzU2m3LJLNdjmArFBBLgFXYHDvQmL\n", - "9VT51Mb8gx1TyNar/CPWDggAAADyAZ49dEN/AInJdfYNr4ilmYSAMFB4GADpypoeWWXE3q20mGL8\n", - "wfGmH6ZgcbtTXJWZn5/uB2IPeQFG/rqNYZ/bmIUcKhccFRuPa9wOgu4Qnm9oi81y+ChWQK1KoKDK\n", - "TWWDeg/SDhV8w/q9dFY0rcekgnjPKbKFgzK+IO7hoMF7vhpMoVCqvwMtBaesBfF4bzxIufyftMba\n", - "VRaJWuZpM22/FtH8FxujQ6EjGNr9PHZg3rsxXbkYHRqZvH6RGypNdfKRL4serPMKtCeuCWEKaj1Z\n", - "h+pr+ULdNvwpLLHfA3OCu3Ql8v/sLDD/O1LVB9ug+l/wHpAAAAGVAZ4/akN/AInJdjcgUcZACEqh\n", - "GvWiTtr19IbQdv8WE1dBOa+lNipi00vM+C9W8F7IDH0aaS+KKFaekfOwUNG520lVemVKNYbjnPl7\n", - "LimE+s4N2NJ5SYT5+XRMb+vTvKCkG/By5wQO/WbZo9HorEm10+Tu4CVIj+2Ky5hDZl+kA6mkBK7E\n", - "3LwAW+4rGYiO9JH1BLFQj0ZOJq0ybrdVynOYOw8TudsCI+I3fiT5nmYCkIO1N7h++s67fASBLfgP\n", - "CYo7yLNwfifRM3ay+JhoRmwX5tGJ8l9w676Zo1wDaqZ0Q5guAYSxSJk2jHShR6LxlZmIVJnq7S00\n", - "iBOM0mxomzMhjpxeX6zqy/aA2SEREi4ulxZsEvlIWhLQ5YFv6LMkVEh9RITRQOsKGEls7Y4eSRWc\n", - "f23FGWOVxL2MZUmPGVh++Xygx19XCiXwoatt/s2T7zGfLkQ2IBiMKXoeDb7yiR4q+0v6UjACWT2H\n", - "kOIRMpG/B4KQPsfMRT0Rk3cAwV9dNnKm4XTlo9P9TmyT71B/Greq+KvhEBDxAAACJkGaJEmoQWiZ\n", - "TAgj//61KoAW5ktFwTkgtAAhBassVgP2a7WSOTniW7GlpUC5YARIimzpboyDKn/53KIxVBS+A0NS\n", - 
"3NuuWMzq53zfHvhoSdYO4dYooBUDN2VkLpVK3v3kQo1FoE02X3cyV2j6ziOTJORgWGzqU5k0XKJO\n", - "1VCPDS1gJclQYem5NlGAENmSiR9I8XvNQLGvpLGF/2+aU31xCZzIPp4tUxyLu/gVqq+6L5DezfDz\n", - "gPP3+vv4JFttE5Nyc7LysmCaQfUhi6zPymHmdLjs3bZdma4hV61UMMsGBNZfYf2GUkV1dVZ9kkfz\n", - "RyUYJPFdwjA5S++T8sc03o81MYXnXYkO9hGiG6RRLRRV2fPSgGhghnaqxRhYVQiuVS0ENIpjxqqc\n", - "KBEaAMs1VoaLKEOrNhZ8yB1VLLV9KSiM7/prkkNKRuNLp0WeTv2eHtXhIdAfhKb+ic7Pb48CqpOl\n", - "FnnbgphlxDaS1dplrA4VxMNzEL/27xNMQzhuRvnSDNb60j/kSJHw5x2JG6G/VwCoVAfFrZll45AB\n", - "Puajv4y9+7flMd/pR8Rg9UAn+cey+vNCcCbbn7FNSWq2hl9cymk4fwW6iqBgiFEQ7YZtyDoNCyYz\n", - "KAnW0gvHCg+5n6+qxC+xDS291Y4JfSW927ZZudU0tXxvupwcKf6fDXxz/bqsOMvxj6Y81+e6Dezh\n", - "B2/8nCpk1Qc7N5s0JoStEQ8+K2ir0vIXayhFQIgAAAEeQZ5CRREsO/8AZTZTJbuKD3PiQhYpzA/Q\n", - "3Iqsld8XUz3sHppFsAHZevvXPBLN2cIUd+YCbEEH6MplVFEcbuDDV0dnlBcrCNrbp3+CAOdBsr6h\n", - "0YfLGDPxHlFlUCi4qTS1o0TT2Jzkq8/O+TU7SSImG1EjEmOGpKvxjn7KxERq2Pbd/0y1sNHk5hiQ\n", - "eJwHwc7Z19aIrWes4h3UYQqHeU6kfCpUHVgnGubU2A0Xjg0UrouNSumFogz0StLk4fuhL5slF3Bb\n", - "3NpP7YhgiVLV0FNM21/pfbXvRQFzmliOaZuScgePqa02nvOdEHEpGVRPLCGL/tvzSkZqhXResmQg\n", - "1qZ/TxlvqjWYqPRThBIk2nP66jbd6NLagdWz1BtbrwB3TQAAAVkBnmF0Q38Ajz7dDL7wKLyRAA5r\n", - "u/5Co2KbB/AnQg3XvWeaImUuto8KuobiZ5Rpi0jf/+r5lFprj/mYxpQ5OwqjQqFG0eXwqi1D6M23\n", - "HLH/3LvgYXkbAAGr9uWkQaEU+TeJ38WNXodDC29t8Y0uYEpwNzyC6FqtgkCyDYDpd/nESpdVRRJh\n", - "15SV0TP88AKwZsT7yWH2r5gpJv8AhXnnWmKJ/WMwiS/2+Kf3ikj614P+BDohXhMYGO4GSZ19EkRI\n", - "RjwO1zoy3Umd4iOMuBBPzevAs74sU7IUdkUF24rNAstoyqnAUgY510L3SgPXbZmJYMv+tRpT7ZuM\n", - "oLxE5ACIQ+eHStmGZgh2P1nvrIaZRiBxoWZ1B+DDOtu5OZpc7LbajGP/oy8HbEFyJIcGXHGB5VXY\n", - "HnskMmabuu5xyFIJcVaqbGg3TlqrbBE29OX6xO7K38oavU/okVlIM+AAAAGEAZ5jakN/AIdXv9ZL\n", - "/wCpeCQF0zyG8897iu+TVNq8xXl3pE8eXm424VBKoADmOQ/RgBgC6Y0IzpqUKPVKwCZafdEIuhUv\n", - "zhgtxewRpr3F4VdMy9NUqqvPfGroLPxDW64Af18RtCEv8t7amX9ezvEWK8AgZjHjHXeVi2k8dp4r\n", - "TuMjdngEOGe6y0V0qXE0vJudyGSblaiStnW6rV0e34JxbdN3Qbajy6ozlLfOkq7Wqx1iLXxa4foY\n", - "IPBIjzxdye8gOjZW7bP0axd+wppVHkXrrvuxUf9dp18AanJIIFv6MCm6ujRO2wyu4ZfSbZp/KVFm\n", - "xvxpBAJyjKSdCoPxWylEDyms9NAmwAADmUiy6WUOIsiAC130X9MRKfeLHi3miJh/YDGeINuX+P+e\n", - "NWBXxp3RqAzo1eISPcPztmgXUHCSN2VRpnCOFQoF4yyryK4v7s2U4a7V5e2sVJBhb7kguiVFACK3\n", - "rbLSCnWI4OCs6u017nghnGW3Juq0rF80iqmo5QCt19S62wAAAkZBmmhJqEFsmUwII//+tSqAFu/w\n", - "HjJpMYeKfGxaFh4NwH9VzFzipiNnWLhZf3lim8qQP0NcWviT9hCfSjxxrnYEE59yPQn7u6+tCr/u\n", - "vn8/iyWB73TxWIDTyqwOWzo0R8Wj7McP4QWP8yE0svd//Wkug5+3cHmcpP/ONbeBn+TAQ0VzErlc\n", - "2hXFLnmGW7EB004qvGi/S7JfG21T+V5Sx9Nre0PuomioWltV0uJSYiMg18UwZktQhoyeO+qpPgky\n", - "U9/xX6NUrUyAfCz03v4wSV58lpzV7BxftApX8ZGWBx2zWQV/YeOCEWbmbHqvN18Jd5FxK1iHRqe+\n", - "nBGg6SyBQEQQfCMxCo37AXM212ulRN9X2fE3P9HkhvkaOxQZ5AElyFJ4BlaM9J8bcUgOX6NS6Cqb\n", - "n7IHMcCIPjAIJ36atWVr0EheDYyrwatT/sRxqfSoF0RgoVqtGqstMXZF7XACu2N9LDV5Ss0B+mSl\n", - "kJJqGxc50wazbtpofP341QOLrRCoQigLO2IFkJyqTpln4FgoWIMbx8x6cKkFmIESXv7mZEx6LOrL\n", - "ggZa/EdzllkBPCO/+zBjmey1Y55MrbMpoidNDpdQ6yZ4UDU0ai3HtghNjtrUaVDC+dCrSCASLB02\n", - "bO819PX27qwUTWW1MCrVhUzQkUkht4Xa4bdnUW7zTudPa++EPxUMVY36vPDJoCGilCgIXzTOV6S9\n", - "OVTh4+OA6S/XkcoA6ZjbQLERX5kZSQMoFJs4bPot93titzpDSKAhc1QMx6eKK6Ol2IEAAAEkQZ6G\n", - "RRUsO/8AZUEFdKFRxHYcrgnLV1IJewAc5dAL6/Pr5YWcZb4ejev9b/lpY1ea5Xk1AlTe44c3rPkF\n", - "DXI6yAdEC7kxPh5StAse03AARSF2nro+Dr5bfPJyYF/ERJ9NScPmUIVihvTCsyh5qmuoAH9P7eCu\n", - "Y8rdH1hF/pTSa+Z1tzZc8gwGtgV/YsMtlWLs3VbLWxt2KTDW5Y2b0HA6zgNn25rXu72r6iiN5aw7\n", - "sjFipq/8rjgHE9K0EK2Opn+0SPK2Rbo28aoNdC9V8VxW1CpMNxKjFOs8YmQmJE6Qtkw+Uo5mh3ic\n", - "7Ng6Xje5wAF7a8Iyr8DMIwvMZnnVp6ilQ1B/LSGEPncviRIHH8w83Grtt0CsL1L2isuyMboY11N9\n", - 
"lxQPpwAAAUABnqV0Q38Aiz6zZgMl5b2XXQAXQ9yHCqNv7FVD9CxHdTnw5pqRTLAoFiba5ss3lqXG\n", - "QCf4/o32jzmzNKjZDN2ghdo3OS7n/NFKTMs4yX0NTqaEhdnVRvrbcGvcKo0NYMgzE8UNwneueU22\n", - "1vpuKbOkae4P82iS9XSi8TlOPcF8mmD+n9qfVTXzL4r0M/s5xxZempvnxqhz38EgmSM/Zw7kEyiv\n", - "giyuP/YjNhFl3FVcOSLiQTCj+F0nLUE7lia+UkuO/YNBXwUKZKD8Add8BG6ZTC4bD/RSktc7uv8w\n", - "NB82AXgnpuELTB2xZFOLAYJncjo03/3uAK678Cl8cw8fzlbnSpp5eUkHacCUtAY9LPrz/OMf2bA9\n", - "vBE2eUwrxz/W0Sg0tjzkUrpnJSF+xYsA2fgRolT6A0NA++mVN8PJVhaGzQAAAX4BnqdqQ38Aj1eg\n", - "HO2BrhbSJp3bjAA7Lyx/X3Tt1hQ2T/wP93u+Km2fQtCsS47kHT/v6cxSu0EEWzwOVr17m7uMIt8s\n", - "rOS2NL0s+wNbNsQiUhFGWcubxLdtukca9QFTdaQjRXuW15l7gz2QnuVPe/r9SLMinrQ8TAT7c4JB\n", - "GrUpwbYY2wvPKUw4NOIKdjGz2TGxM02Yhqm+YQD7nu+MPeXg/5dBf+XeKfPK+RchTbfnRfx28pUm\n", - "+MUq+ynmpWVmmfO3TbD8gZCbZRUeK4LOH5lP3nvVvkbZlQVhN5vPlxxNouZsDfsmprxmWrHzH3vb\n", - "E+c7VsDA88L9wCH+ZmQGzxFjyOQ8cz4P9rsZSuU8vQS1h6fmk4XXUosrmweEGKJT/Sv5qb0OG8e9\n", - "voRxFaPrroiqkALWSnA5n4zcQMwfY/xXX1aR5rslt9ItB406qJIsbsrkl8pXUe2CwOVm9B72bhd1\n", - "lqsCRNktqyPMF/Ek4JsxscPvDjbSqbQZL+uT8zjgAAAB5EGarEmoQWyZTAgj//61KoAZQB+OVG5p\n", - "SZHABUb2//v8PGtlbWZ+A0oGGFPTAdgmU2TFbsuJ6mwUCouNe8f1I2ythN04JSJ5lx+ik6KpnC91\n", - "1FD3eD5Jit+kJIg5holbnldcijL50GRMV+Tt0L65TPBxqSAUdrQu+eLUTHPpJCL4CV5RJau8pEIv\n", - "uK3a7QA/UMQ/nrDjeZ6jqf1BF3JjbyaeIc5drvnYbR6lQ0gBIzp/QRU9xrHm8FESnIe42aooWDJ9\n", - "bVMccs59QBQd45WisW0MXV7NFtyepgfK7biPJN57MDsWL2A4LYHAXH6f6In3GVsSrYQ2HUKGlxpv\n", - "Yf/Xvk0pBnHsuIEsslXTjxwTTzuRb2YT7QCJp6yHiUVL67n8RfvHMNoHfUzP4rVgPSXcPL8FOP2d\n", - "F8GxovHNOmsOSUyc+t9OZXQFF+4FJNSN23FsgARohBEJ3c1u0ax3ACLYlwfCd3/U1mT29ftZkWMR\n", - "uj01t9v2AGHvgKM29X2Vs/ALzLNDd2OM9z+AC4TlcpgcRujIhnjHf17Je/8RMBqJCZtdfrFmz6AW\n", - "Z/aNIv/p/WX6adpvStFWxoDAnf+Tai9COS20TO4GHDviQkpMo6tbNTk4tiYWsmvBNq5u/aO08r2y\n", - "Bs1eH2kAAAD6QZ7KRRUsO/8AZUj9pUTz7rNMoHjJ4gSsLw2wABNFEVCVBZ8at73oa3C8UmeDMVba\n", - "M3uHP8p2EFDXTkl9EiChbxZZgpuvefKfc50lYhoTJ/7H62X0Z9NX2I7S32WT1XJeJtD32zfVBu3K\n", - "VmE+30x6+W2pKnyMM0ZejDKLq8WyIyi+9rC0QVVyU0N739nDCyt6aqRfMfSdljqTnwOmgDB5pHyK\n", - "U8Nf/BZxnIET5uBVX/VcS4bjmT9sCYYwmAz5vBy8cv5J53FYPh0/wF7kP2myhm8SfTnmNtpTej0y\n", - "JjLbrdGSBUAu+lwbCsr/YdOCYrxvvrklZP4j4s5VlQAAAgYBnul0Q38Aiz6zZf6skuDOogA4jl3V\n", - "YKO0NncAuqtob34dJ/eVmQtCFk2jxP+6gBUwoAJ5d6wKEpypNd+AlIf83kNIAAC8trXyGAv3zzzV\n", - "tAa7kzCHOXS39Rxic+qZEHcHH0Hx0iIZnH1UNeoS6dQYQqolDkQpOXG8nP6tDCpAEYSQsJzo5kch\n", - "Xf9jICMUCBjMQXeVS1i3FdA07mrKCBowVzEdee9WvqvXV7KuMTufiL0hA8BHvtD6VFvEZ6eiqgvN\n", - "8RNM5cYXQ2i+4Lx4R2QlAIN1NNxqM8GvSjSh/rgipqY8DwHJh8p9Jbu0Zs+w86pgxJN8m/cvWxRZ\n", - "yFAtI7sBhDbJnNXx83ll0o93YVJhxi0TxWXPf6PlHZeEyvr6QOF2VVafQjsZUg34P/p6tj3lkAer\n", - "aZouLIrbfbTrpoGdtXuXR2qC418s780GZsUBVTlvppC7dgGYqQzB5daoV61BoiIg6tQyG20Yk/Ib\n", - "TtwSJmeU5Eiu/zRo0bpbU2jgV79WVCB/SVzxsmoD1jJEhzN1FHxsbajOijl9Vp76GofsezNr+37n\n", - "UWWhPPzCk1rCLQgaI34ekcMUWq/vBK2WDe7wKACe/5M5UglN5Ct9Orsd3SfYPc0336usW56marFA\n", - "xW2XgVLc1GludnoFyQrT+oASHSl68jJc1j3I4WTIeU/p+eW8RtUF4AAAAR4BnutqQ38Ai1egJmdK\n", - "YqnGBlYUAF9obzNVJ+s4Wyt0Rq0YuZmzKSClvCu/741bUzMW9+2RqBxHf8xROd9WCD2DFO6m3iiG\n", - "ZOgLMC6WQsGlrWDKBATBQkW8M70y/ztO1ZzNQj1ow5FREW75+T8qWeYnaEkP0sDPfhS/8A++EHpT\n", - "ONUZpoNHugOpCj8EFvE/MnQhkWbqDB+V4zYJeD+V1h9PGTTPeM5Ykyq4ZMi+8E5Gka9dd2CFXMaQ\n", - "M99mRo+FOH0+y87A4U4JusoMgrnGwBHn7tNdR1Jgk+wKYqmIwBj2jGPnQFJXhHhE3ZkpIjaeakM2\n", - "8MH5c8xC359KRjK1nfiZHGSkxS98YPps7lGGiAJ2WdM/l0XaVpItX1VPHy/wAAACGUGa8EmoQWyZ\n", - "TAgj//61KoAWNzc2A41R+LAApun++OIZUz7EikV/szjfxvYPLx+f9K2/F/he8DHawkBMdV2wRLxA\n", - "t50GIuRUSWE/39Xo4nAQqkjDTJdufKMgNIx0erMAcY2QA5ejjVo1tlzncJOxCqGpuGwA+5/4IKyu\n", - 
"bmTzdPecTw0ZdpVPq5j/sb/uUTmyS5oriK2QJUn4uMhurpWU0pM90BFHxmx/55iJQnC/E4AiRjGv\n", - "TSfvy9eol7L6q3/AmWDGKQmta5h6TQecJSS7keMMTmFMkcgh+dQEUTFbphGIZpTz6vxfkWPPyqpQ\n", - "VmS0gectGBeLssajkGiu1ivhXeMUvGnpqjpc6XSD8FJ8sVdfwdsse9JozsVq/t5YFq5+AnEYcopl\n", - "mlIiLVwif6/glDa/FvPVZyUrYuYY9L3TA7eEHe1IcHWSOPxpnafEFBrVGoeZPrbfymiVcHOQ/3CX\n", - "aGrpVwdWrmOHr8jLuajUxWOW37ajHobcyT1hYWMxRTx80fZmsfvsrNw/Nztdx7LidHGE8jPZ4gQZ\n", - "DABlByR/bof6mTmjqkfbsR1PCXy4RDNnn9nCnaSnb8pCApsF6YsDTv0+UmVzx2ZPSdm2LhZIqOim\n", - "mhiXHWt+ZE1dnYkLwTdsgNYEeAUTjY5XG25CAykSMfKGwGWeeOwqKmLAqTmb7mCXXxxpy4+bbELo\n", - "RAxOLFOR7z+Rlt4VIVMH4QAAASRBnw5FFSw7/wBiyP2mEJvZyVx6ACpM7CM8ZBKHKR5j7ndOem+L\n", - "X5lQTliSlHrc19blDxI+BarmPxVVRFr/CorqLGvI+vHNUfF9L5rOth1seL+LchCRD6bYXJMlctoQ\n", - "KBnrSfN8OsFA3rCX0rxhgXIKgdEDuCNRYd4XCiw0AyO8VPwgQ3UKQOwN4T9AdwOVZht3xWSjlGSY\n", - "LTfR+DOcni9vpFUI/V99yTFNeriW/Ezi0Mmb4Xp+UrrTAn+/oqePQryHATZ97i1I4TzdZJ6ol421\n", - "ZZiGDIa6I2z+mz36WJISXYfn5PcaqZon5evy7wkHdXdLSXQuyy6RoW3UMK1kv4eYGMx6MEUBV881\n", - "1DxJ4Az2tfQhJ60iq3lK6xGARpoGTWiGA3pBAAABAwGfLXRDfwCHPtdry+v+2nyY2Sk+gF5YW5HN\n", - "XoAL6QRR4alJgXnPRJGLu1H/XzBsCOVwj2OHZ7/Befz18ioG7PdTUWTo/DFmzXwFwKSHq5MESJ/K\n", - "+czoaBaMU0SilMUvvgF9NaNkzEcYOJjCpUUkl+lvc9iWY7aNcNT0YkO2YuPLl1ZJa6XpXyzgvJfC\n", - "YABMMMlHP4hWdgac8C4JyYJle4OEiXwhanMhhDIkpZpmZqqPP6iXGzuSTb+0ZDMJHqoDGqJmkb8S\n", - "IJuvyZGNE4panvJTPVd9f7g4/aXxMPm3Cn3wfT3mTthI056NzanOEWKjM1qGy4olpTOi0cV3zUKu\n", - "VGl1k7sAAAHXAZ8vakN/AInJcXImIY9AsY+/nZAB2XUf7nMR8KlDfCSlxubwbY5yyAvaK6FdhjtI\n", - "iTEMX/gD5nqi6yBjPV+WgerMVdQiwmsTWCh4ZDRMTEvRNiTK06p6H4BM93iWfwAaKh8Gz9Gaukwy\n", - "InHLEZ0yD1XqM2twrrM9K/zMIWUOeN0Z6Qpdges4mCaPjYBUMA0KTxEuHmES85gUYlt0s0Ks9Nu+\n", - "2hfyb2t0rmyvRs70WgBBgYrdeTZMCwmoCbRHPK4oxsSlCang/p1gu/DmbjnwYRln/v7ufz7R3gdP\n", - "Fr7XrHKEZc+f98DBxQMF82PBbmDGtLAQXHwptz6g5mqHfaJhvvgj78jkqTGrQ4WXMBaKzHGNvGYe\n", - "XIR0bHtcMMQd0uz0UHs+NS8bhlZ93PGBn0DI4S7X4qFOiND2PCIg5ogjbfFqU4Kuh5oLH4L3vi2E\n", - "bzWP7DaofhwjMqjCqAvZAgznNJDsvnJzQxJ6Pqjj2ny04t1drdQRUisSLN+PcLenLQZbe401Xg2H\n", - "yhW845ouHrITGSqb9EOEeoN97gj42PjsdYRMVLRDVvCV2BOAqdLbEmICPHZnyy75qPsejK7duPuc\n", - "fJ9rEnjynB/HxYz7zf/RM6xyYbzIoc3AAAACEkGbNEmoQWyZTAgj//61KoAbj1lLPyvb6PAZgAh9\n", - "7f/9/gX2SHKs8Uq31kdycpXc3bf6XPCYn1E4Nyshm7SbxYTXwR3t77AgzFtBuE6fBgZeY48yXmAW\n", - "rqOr3iMlgArjVOjemrjz47grY/T9rKmhvhaqPi8pvZTzkzZCl+tV6nzXVbBFw15yZW9xk2z611V7\n", - "GITjv5GH4Oi/06B5IbjEMVKEcRpvt893HwIyUBXniM9I90uh0TBxOedvsxxE2iLZsr/m/GNXryb+\n", - "9as6btju6GU5FfXHAHKy97PxI2Rac5Rx/FoPiuKEecRx7EQrDfRmlggPPP63oMY4jkBeTzC7Drwp\n", - "8ik2Z4rhoAMWlcRPfXCI56oe4Jt09oRInuaD3ww9/jGDjhHIXGbNYM/s5UG1XuYLCqaLxESIyPG/\n", - "eNnETthXX/QZDvDCFX3YINANkqDvHlUQ+vcUvksaWF/g1aVcMu45c8BoP1coWBAVWVE6iyDMwfYl\n", - "RYTcnNfp26mpOfqiSJnYH+AFj0qGJttgeZBuJCzdV4F5EDreo0WWAiq/0jdXljJ+ZxDij/UazQOM\n", - "0ct15Q7rTOqLKy+lpOVa/koSWj06e8eyy0wY1FBSVaROGYbDgXze1QzYiVyP6+WTk1fjz+Do+J+/\n", - "TxVlHJsfUOz0tbPJ3R4cSjRVigTxPg9VAYynpzzMlIr0/pCOGd4XYyl3SGTwAAABOUGfUkUVLDv/\n", - "AGU2ltMhgssRVFnYDYHdfwUIOpARUIP1pWfDHpU2pf97OTOpyP7SrW+j72yMHgCy10/KQJvVenOE\n", - "eMrSHUfyq6lVIsdEDgl0M+/NXx5VMpg+IZB+I7xozsY2f0ARjiAjA8ZSqG32YEqaGwpGp+vfKL3P\n", - "hav1CfnyaUmopPCa0Y5ww/PZN4YINPOwE+Gg36kaKP/ME/B0d8v00CzvLXmI8pIa3TqrGIa7PF4X\n", - "8miGO6oXkRH45ag0gFdgkGj+BD1PvtIptIkuqTa5jzG/NewDN9cCfws/hjc474K6NoCTyr++7Tth\n", - "LSIM60DcVje0csuhEMwOmCNob99l/AJp/9hMVsVsEaxUNsWBZFMKnZoLJU/ljkNlTtF1zcUwJoZD\n", - "oLTT6FmWVzlFnyfjiJdVIqMAAYsAAAIPAZ9xdEN/AI8+s1VkrBucudR5tN1L4cUDsugAOgW+6weD\n", - "VD4WeLhja/JOA5FtORnuW7CfHWfWrXcPJlwit0rQdaNL8wYmpMOBxVMKErdopYTnWfb0EZST9ZFP\n", - 
"kGeAI5wBNyE7pmk7U/hz6/Uncd5yONsvInzdtLdlFGIUuwPsZsiC4nxcPKJ4ER73zqMcPC62dMwB\n", - "YeP2JTSzcWxmsY8AuUeSUMff3wugzCWo2dZWIqj8MEevc9dnI6e4RX4rfqOmeKfJ7QFxuPllAOzz\n", - "FkyERujhdmr2mdRExctZgI01tg+iF/NwBCqP+hQ0BZaq12BgDPwBcWyuj8PXGo/75aroqbic3atK\n", - "78lcQoP6TccBH3q4TpJbdFKZCXZFrS7Hh71ZQxzuADlZ8DDRzGHyvFJs8+7LX0Z3SVEeli/7hzNR\n", - "3en2BovQV52x/rwTox00ojUHS89/I6QK5rr9xZ5z1Evdog7ewBETCofR8FQPxE+2X576ofb9SYpa\n", - "RU+FFWJ4WPQBj/u1ljXdmoINHOgs90YcpGG37DHSgRaxKh3h9samVWdsr/7ZPH7Krx9nfE8zJoXc\n", - "5Frf0sUOO22BhUTf6MatKarbA54SuNAmIi3ejRZKQJ4XCjhpsLBrmw33yy9Nk6OT0LCi0ELysL29\n", - "OvbOK/J+/iRz4bP6v+/3ppYXG9MzSEeggmS96wm6yOsevJy9wrAAAAHWAZ9zakN/AIdXwVSZADwX\n", - "ZeAC6HD/yFRsSkP+ZT/GPlFXimE8PIk5/ho1VfL2NNL2pqViOd6YYnwc7ksNMs5IkNYQ+fdC2XMm\n", - "GpZcBQdS+anJcAkZpOHFxqdIo1pLhI3h3bcsWXXBd+BTXZhbA2JSmhm8EWBGqSBNaO0U3Qcdcea5\n", - "428f3xthr08dSK0oFN+HNErgBuKfL3JZNShDHaW66u0MaG1B/cF2Go8z1F6LGKUAmsy0D/C2CM25\n", - "q38c827dgYTnZjZnTFxlPuxm+JuWvYpOeWyy3J/wjV/USVL+4BKz61/Ccy+EH/JkQUqRmUOtvYei\n", - "XxTdexyug9nI6kyTGc2H3hy0C3uFxKKFKo9PfiwDCQWhQ1+vZIsII4FYexn+pQbkz5kmdlWKB5Lx\n", - "ONpNVggWvIuTYEFI34NTLTOf285YYkebB68ywIJ5f1uX/OXMZ5RxH3gjNZ8mKLNX9suvs06qOt/Q\n", - "e2ZfZ7Orgt/l3O7GLxwWvzugIsO88I1KhpZhgYDdYZ//1lVBcwG/tKVYjF1obqjtyFctY9LPGIag\n", - "318ehZmIvkhW9djj90e+pnWknudbQDv3Os17s3l7qFADdqSGqYyGaSU47a6O12HCRSwmepV1bewA\n", - "AAIrQZt4SahBbJlMCCH//qpVAC8LE+AX+ndLRI9AAL65x3/f4eNbK2tvWi3seP5qm31GHdf4edmk\n", - "0/ZKv9BuxjUGH/qoYxXDUlaWZFHb65x0lomfbckqRBtklU+1LGTmYtvnPAbKnUSAh/jTBATZpFND\n", - "l6V6ofQ5PTBcFjOWwgI6YqalXUkmqnN6g77O4xvodhM7XQWhsA44ADmvatn61wvReF9d9MqoCN9N\n", - "Twpkx2kbbrSoHJrSyqidCsv+e2gnLoWDEdLGn/42++dseweQBj40iKRQ7paDrpDRwTZVjGQJ+52c\n", - "gaUSUp5A/cAn4FgESmp/sZ0NpfD9/7ZAmCbSUfPUar6ndxZ3XG2DXWcNFu473rzFQZNpJnXg/Pfh\n", - "QCQDuu/iX2Vi2NjGs1QVI3BReUxvD8Z/YeLy6w0jDh9dcJGJdKoNjb9Epdy5r0lFeFb9L8AWhdEd\n", - "sGreMPdTiMRlq+JOqjdogseyQTcuDo5iesxIsb0dhY+P9VqSJtTxyPO42dn6TXPZDgt1vROlp+Ic\n", - "VTutbib7FY5U+jSckVQsLzLRwDuIoa+HpEcHjzuwHMaHrKVljgiPeRI3Afdpqx3nHgy0MFCOhGEr\n", - "Jkw+Dadh5qrWjCGOX2K5HPLV0E5qw7krTDhpWX8sTsYsIqvxr/V2EjIFiKwnheBvunmhlbHNUKTl\n", - "ykWRC9Afa8QE+vO8sLJHYNqVh5kOrsn0+NP1Mm4JPbYiahSDJa4o8TJzkXFBAAABAkGflkUVLDv/\n", - "AGBJAvfAgTZO/kHo4lc9yaSVZkgaxkXEQAgySaAqoJy8U1XmJXFaLzsHv4KqZnckX0gP1AYFUr5X\n", - "3Zof5zltHp7OQG87KhkyMuJLOz4diYjf3ctsH2KA3/S29L1hP4qjZ9kfgNEsjrH/nSlX3ikiiFcQ\n", - "/2mu5vwlzQMTIUj5/0pAslvbULpI2rwxcgfjtpeW3qe/Q0sCZXyJ3L7VhEaeyKZo/ALUAi114xdn\n", - "Gao6fyKpZhWohGCsI53i8XO3Y7Dq+aD4ONx4A265BL770fTZiNNw+oM7dwTK1vcPMdOTVjz4fi6j\n", - "bCMBPzMCGM7CsAz7OQTIKiUTlOi8YAAAAakBn7V0Q38AeTG7snd+wR+ioRwfka+slSBm7w4HiigA\n", - "mYoe7RzT8waKJhe/5/xyHdk2lI4Qb6yur2vWdYx/k/gVzZWx+dAAALHLM2W5kE06MD+/WY8W9vMg\n", - "jgsWx+NCob+sUo3r0m3kC7Z6vE5pa/kp8NVK1XizBU/gSaY6/S/NP+nzZeAUHhvnb6LPnQnTmhI7\n", - "+CLAa1UiK6P+lwPbKP0S0Q5RWiopmhls/AKTmwxXB+WRWyrrFglLMCCi/H7yBlZCPn3f1nUi1WXW\n", - "txmtCNftDVTPLfu3fbw+YSszpG0LQoe/d+Hn14JtNEXcVveVKgdRtrJ2SZSzkDZoD5uTokEopKbG\n", - "geSmsxJSe6mDenK/tstnSjFiozTKWgyJb1mTK9iBWStV+uPeceDypkgatRgkwgz17Zgn457UL8xo\n", - "RIb3Rzvhn1PaM6KKHv4wQMqvpqRXKRm+SScKgBhgUzc706tHx+sk3QXrFbfmTj3VwEqpASdMV8SQ\n", - "Rc7Pl7VdiwexHM38nPcgZguGyvH4NF1CZay1mT9d+wee9MfU3VHZJgMp057sUGFJIJZNmQAAASYB\n", - "n7dqQ38Ah1fDGltbSoFNBABy4LNfpqaOuQiA03rsvInHR01iNZMDGQE2sq9jRvjWYcCsjv8TgHDx\n", - "TelM9UgK8aIkbW5xZBO7YH31DMzHB/HcoCKmBUni45/7i/CIo8gF1pGPr0DAA7wV6D09MIgWLTIz\n", - "u2RlgzWHXLOhQSqpesq6gEgghz4eO+szzJWiaji2cgnbFYV7gS1iXMpBIisJc8i3U9gywhFgtGxt\n", - "IPW/7TiYEwGOLwxyjZX1HkROuSI8lAAdZBpungwbYVpPKSngzu3PnOIcBqes7c29MHD8jRPn7Zrt\n", - 
"720E/jZ4jB2yT62h5AEs+TCYeJmiY6lwGwXm58hIVqeMFafCwAYhd3vDCtfE6mymrvYwtLYQ0YeE\n", - "Ebj2MbA5+zEAAAFwQZu6SahBbJlMFEwR//61KoAWx89GABUe1i4OfaowcQHQyqHCv9PnwkHOB5jh\n", - "ZaY1nqaJvfgMHLxnx0HRU319XsFiIgZ3fycxZ7MoTbod+V6rFy2y2Qtld8RvCt0Ug4PVQuLFLU9x\n", - "N6gbeWntqj92UVkXYHO8rtnoyHbc5vkyDRwK85+1rEknOmV2fCPAJQWJQHZKzqn/akJ6R91HlWya\n", - "u/8GgP8q7KTtX0XyZMALsB3jT/UhmW5AlGIwNHeW1rtDiMG/Xy+69i+m2kTOjww4y5o0/8WfwLLR\n", - "RKlhEE1LYjJQjoy3+hNy7YguxzdtR0GOg0UsPQLFZIBnnCwGmFharg9MSkzKoZck80tBnNzVcu5F\n", - "Ot8W+bdDLv2E/9UTXci1RXlM26z5jearPa/9d/CciU6kElsImbzJ5J2YpzVs+pvW89XbvAJMExZq\n", - "wXD26iUkefzti1p2cc2CbM5qN5CGCTCmR13du1Y9J/JQwXkxhEAAAAFiAZ/ZakN/AHwUpp6Dymc0\n", - "2L536BR5shJlFypABdlGcrzfdaw/6f5GB/atQKmEnLjISTsAvG6zfbdBMs7bm2yeFrIQxXuK81kC\n", - "9pAAAXcBlvswH72knWeKBsU0Ht1g5h3YcKtQv4e82ah693wXobc+mdHgPA3TBKIFWUv/iM+/E90G\n", - "S/NmTeZC+lgt/zT/+HMt/QSFK9C1+AMdH9l6Wmy5eJzA8pumBNuqAArwclv8LW1AC9Ryj7J7dIqZ\n", - "2nhKIYQ08cavMFAGExrDHt7RiTs4Auer+jpijDT1MWhCFcQjNZn9nbOp1MdYUZ3batlHR94YKH39\n", - "SB9iaEe1H+vDrSDRsP3b0PfVLevCUtQQ7tTMju5YxLigI0SkXHby6oMGwH35DOmYdZ/QEHihEbbH\n", - "ljlaWypqm6TR7b/zNBCPoaZiHS0IlbTr/gzMbXxGasP7GssB89XtUV2jZihKJYcij8456L2VAAAC\n", - "WkGb3knhClJlMCCH//6qVQAvW48vGhnpxPcAFRvWsRQfCH0ZQNKlkI/Fmy/VFBZqjdqwlFWyRDRU\n", - "ATa/x8nSCThm/LYIboN0iejGj3Uchm8nyLv3P3+HOOnCw7+XGsyycSpaT/SKI8hu4RwjrdDxqaYn\n", - "k6pZ6qjZtX+IZ04XS8X44piBkZKHHklQnddyez3eJG0JjT0fN5b/c72jAD+sOeXlR6iPKkSUzu0o\n", - "3ha2oHN6UEDmISbP1cbB3piI/SHrisHlFNjIuHiEdkqSzG95tlcEE5RmJMFHyIZtmV+VUnHUg//H\n", - "WOVjyT0+oFlaS4c8th8dtoQJgchjo9u+OPpSDxEJgWI6zeeh28ogNTGzlwRqjfRSsrTItvjA1MD/\n", - "oBFhKLk5Gm5LLSkMpDHu9T5I2IaoH3PKDFRJp5FswrHAqK+C6EMiKJRw3UfQ++e71IzTL0xpDNJL\n", - "z6AeitOHT7WHH1q0lcaxtRKIXyzlri2FOeAU+zEh7DbcM3wvbzCPYrbD4ePmP1flYALif0DM+F20\n", - "woqO1ciEp6KvfcdLwkVhOi6HukmunTXGsruYaqjkaLT2QlUIMJVPTAaXGvEAsJSG/0vfsDXKkk6Z\n", - "sB3ElNrSO3yHej1aIEgW5xnCNisEQsWn6TKnOYGilPN4ZN8EB64V0F8PWNB9Aq0baX+T8kKesmFw\n", - "2y/668NRP8ypn4s+0TEew3V5nLH+An+XxWolypflMoVnWhEhG2W+IIgxfWfPuSgDmqBKtSemnfnO\n", - "mj2z1HJ4yEmqNoBjJwYnWfK8e0PHHb381Mk1zGGJOgWAAAABUEGf/EU0TDv/AFlVerlP4Rak+BQA\n", - "rfH1MAekqKZtO9rI3YpPu0XbIusXd4D2mikBBjNWCs5ZCx1/nIkAW78LpHSyCScRX686DgqeELvg\n", - "+6gjEvz9oPv/Q5SyPMBeMNrb/QJ3ato+Qw19nLJWjl0bduh+HilMsrklIYKHCWBaC/dNC4s7Xl/r\n", - "RCzM7ZJuRKmUY/D5sEAdr/H6TIVmiD0u2jiehC8y8Gw6flB5fdlWyz5ArpMes88RS9cHH1n4Dp5A\n", - "9YiKoxa6XsjMVtwy/Q1CE1CcjEE8nX1x2wi3FF+AiuFwqQsSRlHtfUsVksDBdXLvE8zjbyOIuIMV\n", - "pnJU22cEHHqRAVAAAQz/a8I3JUwtCYefKDlHQuITIdlhxtkj1S9/MOKY0At1R1tnioLMWN7HUVCo\n", - "b6XS9uoGwS6oOJgKcTFbR1vNa4wchWq0XCPds0DBwQAAAPYBnht0Q38AeTSjvudgsbkOLNHOwJSE\n", - "7MIAOT4Tae/DlzyAOhFcKHSt+XmND2K3krM1WAe1ksxoXOx8R5ib25iI4yoXHAvjcPvcDoLvQIYy\n", - "rfzkEj8FCsgVqTty2M7mcrrsvBMmGI/tSEAq1Wpq/wSUg2I4oZj0GjiChzewD+uw3YnWAi/Ntf5Y\n", - "Cv2dU9qEo9e3jPCavhxnj6HVQyqcvxekJ6cEcAGQvRh8PwiQyys4LYMz+Th6jmnZO6zDQlY1h459\n", - "aXiX/1NPDVjhvbOibPxdXy1nW8ZFN/ZpmMtUtTAz4mvuGfLCJYTZv8r0n1cztBPRieehovEAAAGy\n", - "AZ4dakN/AHwTrqiSAEDVZr7cfUIfCi6SEtf6z4BBmn/qEvCbGFYoG0hJzipIIEfgPxGLOPb5hgYo\n", - "3EqlxYfhyi3ADlPB0rSvUe/2K1c1bOHHkBdbN7v2fRCe6cTgBUViIyBzKbW8+YVzs1NjLsftvDLF\n", - "Jws+AVbFUOsz2XZO6+tJqS4okplORVfI8Zh8pjE7ly6+HI7Omo301kEp6VZks8VHiVKJOuTRsuFe\n", - "1lak9cDIgZS7IV3MkEjdmu8V6wPVTOui5KhgRegdKpe7dvKwiZROacSHUyEpgoiQ49NAkgd9ICSC\n", - "nOG96XtcVUK5qLGXI1ECEXtJcuaFVMtCmmOBBiFL8jC1MpHbxQ+4k2qRSUjP3JvFi0NfrsxeXbrH\n", - "Ebg5vBmNpJE6T+wdC73c70xC+Mtp+wYFzu5kfTKcL8d+Nzu4GlIr338e6SWwNSpXRGjfdLp9o3Ic\n", - "2PzMtQmrlpbEeUDp1vnkaZoqSF5M9xanIk/zohgoPX5++NN/ebYvr56WROjUeIUdsOf6nrJlmboT\n", - 
"DZEat6r4aY15lVCgiz4Mpb/mqSazxzrszmdRYRxGsW8DnzAAAAHfQZoCSahBaJlMCHf//qmWALFy\n", - "5oM61QiAB+cxK4+jNCOHXw6RALujtnWF0llKsvjvaSIz+44BdTBn8Dqmduydu0Ab2yYLL8rBa9BR\n", - "bM/WBrO6FCt4pfpaT57HiAbORTevnWHgnUCdwsiqbddvhjkiuJYbgCMD0kEP1SURu/b2Z5hWsq5s\n", - "eIdJwlVUmffx/GFsHH2OVg2kldaudIzyWEsMXsnZccvZ4+1TTMECSDKdUtlhUW9AAgPUraaePKP1\n", - "hatMAsKbsEP5g1nzjTlmyHjs7FjRbwjKng4/qsqVQ+s9Z8Le9mq44VPerxrlkKxdRgf8PQXTEpxP\n", - "gMR8UP9I/vRSJBbzTafYsMhPytfC8ESUe9ySga0pNZKSvC+bN1h7zO9OEjqF3rsnXJU2SZN7NAbS\n", - "01WCPkWQIdWN39TZ8BwhuM2E1/XfXA9OxCI/7PAG40Z8M1rKVJPTY+iwZnIQA6cEF3rnJVasn/JZ\n", - "rircnzzi1JQr5NiwthCEkD02k7GAoyHtF8lIKArvw+GqH7Ox1Tpd6DhPPJm2hmyijeFH6E+9UCJk\n", - "Iiolc9K3UW1rmUlHlF/p9jHAvsiiJUpuG/KCfna2LEYj9yn6P2oNlWfqq5P2HNtctaJeVRZv9Qb/\n", - "mNVjyjAAAAErQZ4gRREsO/8AZUEtk8LzOoS4AAhIFC88oI10PfUAs3UxxCOOtSzHREgn4/jgVfHt\n", - "0r483Tf2Y8D+zGlycQw2lUV6Nidlo0k0sASUCm4dEwF8Hb0+IzseFE0dYexJdLqvhcI7IIUIH6RG\n", - "uv8cjTXFD8CTksvYGpGc+uBYXhlwc3/jHhNGtm8G24uHniey+Zy/NtEpSl5dub3bE324kx+/N1gF\n", - "sU/CxkQF6UQWvd6Br4nL+i2L6udCLqM/JAVJhScc01UR/bE+NX2i3upx0qofgxfWL8unNZ/BP9Vc\n", - "CvVXAtxPw+0JopAnWMlwtBFG9wd+oP4zOIJ88u/VEvyZQd0JJP1Y3qhYk13Deyiv0C1r6ci1z7CQ\n", - "UwYqgUT64pT/hlIvHeCzEZxqH+WbUbEAAAGYAZ5fdEN/AIteE+hbrZmAAHNd3/IVGxTYP4E6C+Wr\n", - "63le3xAHjzqOqEil1tIAAUY3LvF62/277H30QskV8sEjceHvPe7bE0mfZ44avBY2gS0AAAMByRDk\n", - "EKOyh31Y2H0mdsy+zcGsPrGm3pHtO2riBcgILxHO0F5398HG90hK8UgtDUfp9CQyPOvDSyEU4WTb\n", - "6/WT9Z3aca6tb4C53W6p8Geyjq/mwbvNpnCVbbqIcx1ZT2+dencovmeYmPlI7jrhk6KwLYEd+5gO\n", - "J2YeKk4iWai6BsaO9+Tb5P52jBVHcSZ+Vws5QhTxkBSpdHlWJRcbh50V4ViVltwUN//XNx+jx2bk\n", - "KsfglI41FGmS2xAJtr8ZhKDk1VRRL2tGsNB5nztuRXCFd8q4MIuVVWGjim0ntcxZ/R18mzJZN+sI\n", - "qKUvfsxoaeZp+oIaU1hLeXzgcHEe+3/6emdZeJWoDNhUqhkfWzWzVZbEzUKpDBS9AbVIA5KR27LD\n", - "3HEfRMw9yt8eYILg7m/Rm2ubtU8u6V2QuxVXq1OHry5oY2TAAAABvQGeQWpDfwCPV5unds/RGF4o\n", - "aWlq+XwTSVpG+igacFOApaqyNJIXSXT4q7gA4DkP0YAYAumNCN0MwD7HSEeIsv3Q3L9kZ2RagxvU\n", - "jle4yQq6Zl5W7AgdlZnaBngH/w8xYsqWx5t90zzi7s9VyRY9jaNshfxuJAZcRgFILNTmQNCPoCtl\n", - "wyo5Ht91VCy2qSby6JDLeTD096PzM4KOK7/I+amuefuT0S/QnDNs952oi11JV2mbadqtKDqJE9x4\n", - "nX/OjU9PBP1uhsFLNkjsz6ZHlTOcsZvWUxabbw0HBNFuLXWIYqtAYdWN7c/QUoqY2IlVBR//v+NN\n", - "Bxf/rxPv+9QlTTeUOAVhzyU/kQACorW+VEL2KFNUPF85LUxlbSGEYQv/98/fAQAu6hKRw3yoJoPy\n", - "tyr7S7Za9gGurMYseuvuasNoB+fPCmp37VWgm4yNZQ0LM+8CPtaQgShVMs2/RIG2cXksHuYVqEB7\n", - "PJtzP2tl8EYDen8RohIb2UO5d/Xdc8aoi/Nu4IzGq8ApuZIxjC5J9bUYtMDEDA6eChGKPjb20vqg\n", - "2PRBI2fSXJrcSROGTC4m+VsF+VagO1LnjrakndEAAAHtQZpDSahBbJlMCG///qeEAVH55ayIAL6z\n", - "9D9Go2JR/VsPgULYIy+HM1JNQWUio64eqKV59gHDbxQ77xKGvVi/RlMeepNHF+Cplpp4rKqgivaK\n", - "14o0jVVjKwdzXmYfm8QJck76NrSj9rXzMi3Th9DbQ5HQHvlFr1+Ft6fGVXaubVoF+Bx3J4nvsWO+\n", - "FhXDphKaWh9geM/3PqX1TK4zqhRL2wKgDCWdLvIi2s2e48RSWR1zksj0SjkMINJfgjA7wVj0dW8Z\n", - "NZGlcRPjgkoSgpomI+x9/l7dJ5fHEj4WOkMQMTJnj+KOqaXfgtXbhBachZ0Av1Z6rh+qw/iObJOy\n", - "7q2gUdlftEWI7In7KZjqqg18Bg+z35wI2FmknOyXdEiDAPaFiRrhqkKOLfgLssw1BdohiuTGWlKn\n", - "NvPL4EzIbAUeS+0qv5cFdXvRjnn1zOMYTMpyN1CZYg4pqjj8mGtGdm1F7w0Xo4Mnm3hRmvZyyOaW\n", - "yf38s1SCwyOkhQcwJhrAAebvkxMWrAUWrTq9K9PdCUqFbMVB9+93aovoux8zBfM/WLangtLLXd/D\n", - "T9TcgY0eosWGZeAhQk2sxNC3bgvMT328AT2T2XCg2nG4jsOakPWfscwbc0zKfItj/1eXvyR2tk+K\n", - "fpgdg9dJ/OdcXINTUAAAB95tb292AAAAbG12aGQAAAAAAAAAAAAAAAAAAAPoAAAnEAABAAABAAAA\n", - "AAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAA\n", - "AAAAAAAAAAAAAAAAAAACAAAHCHRyYWsAAABcdGtoZAAAAAMAAAAAAAAAAAAAAAEAAAAAAAAnEAAA\n", - "AAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAEAAAAABsAAAASAA\n", - 
"AAAAACRlZHRzAAAAHGVsc3QAAAAAAAAAAQAAJxAAAAgAAAEAAAAABoBtZGlhAAAAIG1kaGQAAAAA\n", - "AAAAAAAAAAAAACgAAAGQAFXEAAAAAAAtaGRscgAAAAAAAAAAdmlkZQAAAAAAAAAAAAAAAFZpZGVv\n", - "SGFuZGxlcgAAAAYrbWluZgAAABR2bWhkAAAAAQAAAAAAAAAAAAAAJGRpbmYAAAAcZHJlZgAAAAAA\n", - "AAABAAAADHVybCAAAAABAAAF63N0YmwAAACzc3RzZAAAAAAAAAABAAAAo2F2YzEAAAAAAAAAAQAA\n", - "AAAAAAAAAAAAAAAAAAABsAEgAEgAAABIAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\n", - "AAAAAAAAAAAY//8AAAAxYXZjQwFkABX/4QAYZ2QAFazZQbCWhAAAAwAEAAADAFA8WLZYAQAGaOvj\n", - "yyLAAAAAHHV1aWRraEDyXyRPxbo5pRvPAyPzAAAAAAAAABhzdHRzAAAAAAAAAAEAAABkAAAEAAAA\n", - "ABRzdHNzAAAAAAAAAAEAAAABAAADMGN0dHMAAAAAAAAAZAAAAAEAAAgAAAAAAQAAFAAAAAABAAAI\n", - "AAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQA\n", - "AAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAA\n", - "AAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAA\n", - "AAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAMAAAAAAEAAAQAAAAA\n", + "cG1pbj0wIHFwbWF4PTY5IHFwc3RlcD00IGlwX3JhdGlvPTEuNDAgYXE9MToxLjAwAIAAAAQZZYiE\n", + "ABH//veIHzLLafk613IR560urR9Q7kZxXqS9/iAAAAMAFpyZQ/thx05aw0AAQoAAjZrf0Z7SQAFS\n", + "RBmrGveunhOj4JFso/zYXaRjQ18w/5BhxFIRpIkBeRXl9T8OOtGMbM52JtIMXIY7KRr49/IsKi0w\n", + "jJUK8Z7XIFmlAjIU+jSbWER5LmeK+6/diSLijDB3co/ebDgChTdnt/smJJAlFMJhzTUcdwoA8NQo\n", + "YBnpXwCtHd9MDNyz4x4zrqfgfXAXtVDOuKqK+ZIROmkudESU5HAc84NxG9mIFkHTHpfRFX0vfuvN\n", + "v30XneTe8IilYhOJYkyOcVBz9L5D3N5P2RHbPf8d2Ia4qkwGurGLJl8PxjFsKE4dm+f6WYtxh4/M\n", + "EbibuuIVHuFVTrhDBdjGsnlvGJ613cHSu4frv4bqhIfOz9nOKI/zhLw9zlvfAkAek0G+jTz8be7+\n", + "o/ndntGdno6L1LXJpdgGJYFOyZwDpk3suJqu9FKdCFsjDfQ4s5OYpZkBRm/h6ksvqs/jKOI7H7Eu\n", + "JEDtMn0Px1875SS+KLSHaHwtTCNzTTTEE83rjSnRcLH2qekoCAzC/F7u+tWoo8/5q7AU8ZwbFyde\n", + "C0AcLGLOTLX2dctD5sMzDYlYtX/lYiEND4SUALBVfbetB5IH67pM/22hp7cM4zkyUfekvXZeKUpq\n", + "ihxpjZ/b0GfRGel+eaIkRAMer8l0HHBl4xOpdwEUiGEQqacmsmAKA7/Wn0I4FZAkAeHbrP6JQw8G\n", + "T6oLn8jHc2YBwe6YY+t5SuugRFwnijdFTQ2IYMHZ9spzZjJhn/lftFm13UY9ay8CDty2j8dXZfss\n", + "pdN3RSB6EMFrirN6yUkoxa8UPGBKHs9MUFO5MnKDgADHT4JhBGInxUASlDV0lsFB0GH9ED4tkRc6\n", + "7SnaMmZwf9T2i4a1NSsheM+jHEQWr9fgPDBABuIyToLYLrnVeLXqSC8JMeZigh4GOpQKyiIsG8oa\n", + "f6kiBTwG/5RebTqU6O7rrQLj5Wd5YFdqaacUZGByo8AxJ60NHIoQcxeNjsWAj6m8SKd2+g3en70+\n", + "zVQW9HkvHI7nnRF3FhwhZYu/LvproEPyWSYykJIx75ojR14WE7oWSjYs0X2AFiwEouayVGii6owJ\n", + "gdlCmnN8HoqT5PPnaOWG7mPgq/3meUuz982ZX4+4VMage3Fe0K3cqRdKLTge+gs4pyQbSUIdrgo3\n", + "4P4R1ejF0wAW1R8YjLZz6fQUzzzchgNN0t7aa8tlO2yDCmII5BbaYJXJrRvBm8Lb1m7TLILNalgu\n", + "RMjYD4Pf/P4iQqWsBEdgB3p334RMzrBfcviq+49N2SRQlYxV0SbSMdybZaH+vxuw+VyvLt3ulEcF\n", + "rmBwnxL4kpGATPv8mogAAAMAUMEAAAI7QZokbEEf/rUqgAYz+kaAoYS6oZnCZBWChU49QzRvBVh/\n", + "3Pl1tY/3h6ui3wW2qKCfpdwQ1h/uuKhRazpong7+Xsbw5g3mv3E7I0N68sUiey8Dbt0hMUrR6zYj\n", + "YtzMQ7gEdgcbbOEgu3H73w44JvEzvgZ4iO4Q2Kwp7BHY2uxxtdUENoG1kHXqnnQawFSCHZ9W6pRZ\n", + "ZX580jW/ekv7tzX5SLrr2mknIiIEL/9OqO/hdKRWyIS92L0VbeMgboQPIpdXZEemH8ScfWR641oo\n", + "Kb2ZqixayrynX4qeQdDAXvtKdnTPfgTsOJHs6zrnaaKb6SpoCg9ffzFUfiQ1YwLPZpLhwkJ1F58m\n", + "QtliSU1LCArOxcL0CdX1xv0PO1XbIga8mvD2ON78HrYIlpd7r9MIJUgGiGlRxLTUITjvxtxjLYBG\n", + "TBzSQ2Mqy08Y4xvBh9/AZrWGoBvplKVOooBAXsS/J3OngcAaMApnGniTlEgacIB/4ihqQm9Zync1\n", + "WrLEldONGr9K6gbteZcFnK/hoe6B53agN6YwjF+Hm1IYltzK42eiNQbmeo0nT6xx724Sek57Pcpp\n", + "/+64lZEYNhMLw61j8cLCmWJLqJ9+OlV3Tu4kvqWM5A7mBmXunK5EElFvFoiaHvfKnFzVKUZHVN47\n", + "dwwOu2bQK/GEFcs57H1A4Ddl2JAlJt4ZWgrJx+vzAgyhhcl1LtQgQcd3rX3aPisDf1CYETnay05i\n", + "xe8yUL0AVMzI07+lqERP6auGU//nlrslfAAAAS1BnkJ4h38AGAsZbANezx+IWo4Ni9MoMfKTC08P\n", + 
"cqaDTyueeuPLGgFgW9U33kZ+Bw1xhP+VnfaIAfTxYvkb1WNMMRMsh5PjwSMCmaFIlQvFeKZwdgkf\n", + "0eHuoCcg/XQXRqCvEyyYU7Kr945fY16Tu/18Zd8NU8RAJRLFspmBVoIPZ/aTPmSIXDq8KOSzL6TG\n", + "sWN+V8RKxGwExIfHZpdEvHu1tOeg+pVzKTracfnYiBxxlkuVIyzOz2mFv1LQ72jZQGocAdWS14tD\n", + "EtCsmNljiTGQDRggnoajq8kpnFHws9ZMWmcsV4dQvczexFmx4YibNvvMPauj3CH/KK6FXvQFumid\n", + "ftiga3Uno6si2epmOuEVTuVQwXsgCmOyejpjAiAjZuUS1zq40WginD1EPNgRAAAAXQGeYXRDfwAh\n", + "r6zZu6OyBrfB5mVsAz3QNRRqvrwAcnFznD7NXanOaWlAADNOwlJX/xGmO79sH9XeNRT/FnLuEPBH\n", + "1GJhJV/Xt2R0YziQPpgXV9BLMr5IaMaU9R2CpgAAAPgBnmNqQ38AHhCAmS1kGlkSnBkADoOXdXaF\n", + "NGZr+Q4fCvQ7bHDsrrZk+gghfDnB3EgAw+hgyCz7QjPCBdm4Oua2VioU2d4nUZ+UABLNnRNNghIa\n", + "znH4EU6++iAxhcURNicOGGgil2sQO5YirsL6J7S/TznXYcILcn91E9qrSkdqAKeiqMttbt/NlBlt\n", + "zFtTLIQV87eeTgQtRSaGjNkYcjtT9zsSroMxdQkaS/rgzWfPKqioru5///iiFvV7FHhGNapsB8Ep\n", + "xA6YqLEIyfxd3iBKiJ3g/96H/WMQrMVl8ykLYh6g9L/mEknpMxDRuX+/d5vuR5TJpN2l4QAAAY9B\n", + "mmdJqEFomUwII//+tSqABipnkgGrJGhoF2xhqIGFJgrTiV28TOHP6iMSZwA4LzauSvgcy42/qpKz\n", + "PF+GKWIn2EJeWsQWOqhnFWAeu8Qy08RHEYzw2BIfhXKPnsvQ1D45gRUsCZjYq85tliORVeVqHlvt\n", + "fzWrMqI5f+favhs74Q/1bo2ebSMVUSFuP3HPqFVDjXrf/wjJSgWTFPNzCZtjDghfnhYgAzPVh4sd\n", + "mfpnfQi7UGcAu+X0SPRW+sCzjBKyZsabYXRLvCvcRgXcWHRJnqJZ7DbIL5Ahmra4MUmiAdrDqxi1\n", + "yixz8Ge2MnwDKePhHbASj9FgVyabApZmODkYAk9x2eNsu3NC/GWuEsOYUEJXb3NkJ3H0Ehpogb5q\n", + "/7IADF2Rk2r94PZTFE6TdqRa+DeKrhf1PoBJxN2bNx2sA7Pci476Sn+ZpPsAPTlXaikJNRAhO4tD\n", + "lakPd29Edmfvk34bCqY6rFMuCfUJ3yzCy+VRKB59CtgS68dVzaJO/FxZ2Of18yjXsScM2fL16/kA\n", + "AADDQZ6FRREsN/8AHa60qBaQmR4IRAA6Dl3Sc6VtGJbtr5vbN23f25BY5Mbt9ZodJaqeGLgSZDt5\n", + "tMt3+exLq/o1or+DyDOaUjfDuI6HO9EMKVIFrK5bBNySwYGQ9ZOLXviohcSZAskgQCT8YbljWqgY\n", + "W5O+m+Ip3OoA9JMxAp4EiGRPR1hmuQDeRomyGX7bvvzp+lmhQcgx50Gtf2FsWph71RE5OIfz3vbU\n", + "YPJzvstNoHMLjQVN28uexbTk/wUswGjCQ8u5AAABFwGepmpDfwAhvaAbJNR/9ddNI1ZNZPr5vm6q\n", + "XTetXH7Eo8GqFltKJbOb+WxFxg1OZ9LY7Pm4G1n+FvJzAc9iMK3kbM6geeeFIdRl75A0UZYsXIff\n", + "dQXiQxB/kP/GUeJS/ghHdsFXhovY2ei0jBYXhl7XCQdiM+OxqVpdBNYdLY+vhvtTydDweWAQhmfY\n", + "3fYN3w2o0+YtvleCAQNIu+tN7OfSeOifT7EOLQk4YDYkvT1QcI6scYDf1en6ihiP1DSq11Clzx8a\n", + "ja6cddGuoMqDaNkxCF1dzf2Jvz1VA4BpWPjukcCUvSBL5Hjn5IenmZHNevhC9Ri5TKMMAK1OUZos\n", + "eUJttkHLI36Z4EqqgVQeXc7fMR78LG9GpQAAATJBmqlJqEFsmUwUTBH//rUqgAcd7WUAG1wL+eMP\n", + "5NbNjI1PanDtCkQqkSzemsYEjSdqyjDQBhMRhcVkBjrLnQ37QRY6anUo9HtaOXKEvV3Oq3t3zJnU\n", + "VnRnO4+DsYDha+hVjf2RQfz8iIHBAMZBzDCidKRjdK++FyTTJT//wjjoyDzrLD81EvvOEfP1hNq1\n", + "E7Mf/LNi4VzZp3xaz5k3oYD4Uh8itElOoUglEcP1/ghF2UcJA9hOtkSUpVhA8+T8Ytc1zpVMfYyg\n", + "QqbyRa4EvI2+PCgNWtypZmPOW/fUb8LPNYTg5GLhzbOmSjYpenEUzkib0QksNLKbj/E9aHrV1qHX\n", + "qXiny+3UUPxYGvj/pDuYRozh1EchMNkv/eHEkrQhTQjnyxDirLtyAwkvICbz8w9UK2AAAAC1AZ7I\n", + "akN/ACK9oCBuM4cceanCEEWpV8cuy27lpLcHp0RFJ/onjSEljOG8VqS2Rkf30kIRre+KMlNGVcvp\n", + "cL4orO6Yp5KjC/RRBwQz/yE8UKLNeO0Y0FFhQfICXcBtO9ndieTXXlspFHuGf4S6CeBKlAO/lDFn\n", + "Bm6rf4RqP1vvLrD8KUBlig+AFH77l/U3BNsHxmcjURJ4rz9SBUp3dWhkBmKNCP57UtC9bKnqFyE+\n", + "YvACZ+sMCAAAAZlBms1J4QpSZTAgj//+tSqAClE1egBKEwbZY3t792fWy96pbeQQCnoXHta8keYB\n", + "6YD4iyrisk5RAGXAP8hftXkqsIp3gIADtqeyulunIxMvA+tHyMYI4mH7Ktx24JQCDLGwr+SW5Lfl\n", + "LFzLN5Z5EpfMBtjuN1e5MGJfkKE7RLofReD1fgshPg5Hiu3eNzKNtXPqCUQOQrANHyjLVDHW1On8\n", + "GbpMg//3+EW5h//MyUrV8C3bm65GCPAdr+IiAQS5PLqRpJaqPFXYImLzCfEF4IcxGqfKzcnaOGUe\n", + "P5zhUa+at6SYruNLfSBlr3+mvyhAAxPUBpQBX3a2ZIbz3QLaxiA/KmUnrCDmuWAQmEAoRWFYDkhB\n", + "vSu304LzlIj5BSPPqNvyTdiIsLpzAu+SwxleN8rOU8p84R24aRhgQwchoF64pWQkYvhDlixS1XkC\n", + "+1BFsz/ugThqWNrj6DMWcUAmd8tN3JWA8raGQmJpBH1Zjd5483GFE2+DssYAdvIzFktdYvwqJy33\n", + 
"xqAAiKb/jZmChnRmwaKmyp+usNPBAAAA+UGe60U0TDv/ABgTM0cFpiU9S5COo+Eq1a5EDpKRq+6p\n", + "lSs4dhBzMdhHGYju3Syu9sir+n5TA4S4EozXRjp4djOH9s6Ebl4mnuRqUkAVVyRRxloLXXdAVwvm\n", + "Kw2kt3nH3KtGiXPZtoKRlLMwsYrakek54VGjJMSSK7z2j4bZfzdU5fWILhtGELYhukSGMv6CXtq0\n", + "ugZLCx24z5CJjXHZ6aJugoOXVvLE5AMKcYDe/LowGji7OLeFgeB849mfSaUGlnh7jxuhBOU+fRS4\n", + "p0ITI4vXzUUR4XVTQrOXBNie8HQwoivm+WRv0nW15Zl5mZ7wAnqm6XldppA1IAAAAMIBnwp0Q38A\n", + "Ir2gIG4zgb64sxYLzhi9P+r7lwy6Wa7RRkAjTYM9mY6ueOaRzgw6T2RlVKQ/Wnw9OUPsoB+98v3K\n", + "7Ai/8Ku9oiX4fIaC4XxFxl+0lQDznNsd4UfPo3AQh6FoBHug176P/7mBbtXW9HioX3mZhTRXJOlh\n", + "Psk7HP1i1klJ4f63KMPuZvFOjkq75Z+u+/aiOQvmn6+lP0r2vSaqs7nxNSGwPqSwNXaUgQz58aD0\n", + "pB2v6eKf+Yy3eGu8f7HHrAAAANkBnwxqQ38AH77opN4Quy1TZxAAOg5d0nOlbRa1oa+CUrbGUKO9\n", + "s1K1K60LxAZlk8ZQWiHU0UUuQDnHAAyjelIcwOj4NipQdTlRBT+HrLVCVEK5smCT4WEyhlST21vf\n", + "pS9QIx6rrJJt1ZwRk3fLMy3lh+GbSU8p/deKiRgvPKu2y5xljT8HokdUfoJBN0b+9AYNdPwZxzfv\n", + "wRj3rjB+XbCQdH7rLOmVBWtc7YBBcmnLfJ50Xx9vsPrIGyT/orCu88gDS7Q97WNMWaRoINuEV0SN\n", + "7lASQ8YC8xeRAAAByEGbEUmoQWiZTAgj//61KoAGg+KazAhO48Rk+mELCfGa3jedcL7j4gDd4k3m\n", + "hfDQA786lCeWa51/s1J2qe/kkvnBjg4L/5tqqnPuWzD5CtqsuCrBZfD9tieYn0V6h2QRjHTgf2S7\n", + "KbBJVduRkgXz0DCyLCsDRdQx7ZVeilFNQPYHPpL3dFbV2ZQLhZ15DCVv0ijUbfdtbaCxQWk4hFwi\n", + "4Cl7Vcv5eumMKNjbBf29eX+p4vfxRMeLxQVGLH+o2FLpf2SZwh6nFX8ReHwFB2aNAZojees14KLO\n", + "dDXVOKLwRfawG/F4iTHLNjIHr9KJ7RMP+ZW2v4UodTEwj2IkfoeugjPYygxsYBEN/HIWo7Lp4BiH\n", + "W+sGNW6nzMrLHeZnfPrIXJzjKMZ2dMe3r2TPoxLKTVgPHlFgXbB9gOVEkvjr1YtxEt3sHivjr7TH\n", + "zrmzrXSS01xk914HSqt/CnYSKPxa2MF69g9I/BNJSHdHCdNGwRVm5U4w/DYDySkJOTHhPK5xLTdI\n", + "6pomON2J7Snu3IFO1cMuZQAgHAwoynkWURtTVoyQbA1o0XW4HcVte0xmLSUrxW27KPhiReLpDIah\n", + "P07+6UwIug2Iw2yxWwAAAP1Bny9FESw7/wAZUxOT3tiejYgyJDRrCYHaMUHhX+buBbaoqZ/1iUWs\n", + "Jb7slI/imiQ6OnWj09SEskbfc/zlMQQ4SNXZauWfHJ95XYh7wMFGgh1p51IG9qMewyJwQS444Zn2\n", + "viLgUg5+yrpXHCf0t8/9jDlbqwjDulbT62pdxpAyxuynsO8RFT3dUKeSE5htp/jbraDowEdpXZyE\n", + "hG0WYkl+RbztI/PQNZCwZsz+nvpxvKr5XHM1hBpXHcYTolc3yg25EknXG5iovx0Y9EuSqthrt+Xw\n", + "mK43mYVJUVC/Oh8GeZYMuS8/kSjScKjb9J2cbfyAxgmK23G/LX345QQtAAAA2AGfTnRDfwAc/TTk\n", + "s3FNYSmNHdPgDfXQC1GBEwJGCqSU6MsmeFhDrrArJ4DXkS7h5Olwl5LsAdAjNSMWnsyuwfwlhiS4\n", + "Iu9nXiMR2gsFQTdJfxAGWv/oGKrfOpY9OM+oH5mmAEYRbo0uYIZjYyyv9H1tg0RX725ktocEeT9I\n", + "3B3Tp4qYCOAxN7JPiw1LGqnL098ntFu5ng1+yPoA7ayjGtnhqUNzDdxHw06qdCQZykRFXaAS2mFv\n", + "lmomA2wH7gnlU4hH+9/QtYxMog0PKOypGE94HJSUfoT7gAAAAEEBn1BqQ38AHE7WHA5VnN1RP/m4\n", + "B17wBGTsyVXKs9N7WlI9AxsJJ7v9zVkMjf6pvv+Cg6JoQ3BLOK7r3bcONYUtZQAAAddBm1VJqEFs\n", + "mUwII//+tSqABlJow5npTNmtYD16z8AGI7v0s/GnfyqOWKggEMwd90EmHsgCWksYKFE4Qru8Yv50\n", + "LqOKJvWMLHGzKIf1mWoops1hD8q4hCLJMEdRItKEcO/AvOw75DCgogAQMHz94YdBlV1FB7/3PGw/\n", + "kvp11c7Zd3bjgbTV5f9wCrj5V98Wrk1QkXKTao3xn1WeAORpyCtFJo3KIIzvry0ktsvXmShsZdHK\n", + "SF2Q6qY6Id0i1QRrrPRdF2iq2m2rhv1eY7FLgTuR+kimJsshiQFr/qQ4tOO2msQRBI4huY4JSA+L\n", + "KftHgweMeBwJfCg9ocoILqar/ZxuCC1Kx59hrQRJPfm8amRIkwU/k+wKJNYh9fLLSBsxlrg4XoMn\n", + "PzXBXS36HS/Vq/PUU0Saj0Ks8oGCHCVcz3eoIxgiU+QJY/DixHlF4+MYR1JrL+dYLi5XU6rOa8uy\n", + "cymZbC8fCrT8nFmCuYcD3DNSzmKt2Ypk8ahqcNxMHCCE377w4QcAAK8hLicCDiuo9KVio6ugqDQM\n", + "DiWya9QmBn0ClIbSCznyVdfSZyODo1gjrJ9IiCMcnWI45hcgB0F/w3f4fUDX3TFD/vbMoTmxwMKV\n", + "hWEq4XvI4IEAAAE5QZ9zRRUsO/8AFKVUcHl/E43Gt6o4RZvBs+iAp/X/n7d7Pz7RdmO0J7CPEDVr\n", + "YOGCwg4aa5sRnK1DwPx5sIYzP38566ezpK1+yb8tpnK38Otysb+fPORXq89pSQ+5zLmadq08PRPq\n", + "ft5b+CuHdsaohxgMdfr5HBiNNodd0VK8TNpXmgIXzYR5RpK7ScM1kMS9Nv/EnJHMV/HrvGwgTDTj\n", + "k64XWbP6seQRZKb98opQD+okWzwHsAFj5ehr/ekl0IlB4NOOkEs2vqjJoc0vIcwkba8FSFkLe2wm\n", + 
"HNG8c/q9E5Tipy3avrHlLTvT0bjPkjeD4HLfC3isImW2RvjzyyF2TiLuxINvE8y7u04RbyNnhNhC\n", + "J15BQDsVja0XtFDfnnr/h18foOkLRpLJ1yQTMBboYsOrVzSZ9GDWwAAAAM0Bn5J0Q38AHQXz6rvN\n", + "uarixND043ZCNdAAIHUCWbOjp5TUpZdEciERk/s2Hj36k/1QHuy5AO7bU6FcTtkwLNXpp4kEhhr2\n", + "pj14tuqcy7uq8XfveV+qzHFw516IWJuk3fnleTKVnyg4EmdGVkh8uUm8KAFIin8/UzurGkP5FXB1\n", + "JS0uIqtx2mbD94hCpeHMsXHXmWbW3GUD6bwQzUCwUdgGFWWOBIzHIH3jzzxIIZ0rnTzx6fd8zSRM\n", + "hMrhmhy9AElVESMBSl9RUVwHxFBAAAABSgGflGpDfwAhvaB1qIOto5yaJpOYSSkbksLCkPuZStd4\n", + "LeT7CV/DcB+jLm/y8AhlFfeod4crFEXxelJR/fWiWC5cEAQJB3xoICKkbqYOm6EmFwfhOJrnHL3F\n", + "i7egoJ4YJywxTcfWExKLj/7q5Qta5s9pQnji3v49xEhquy1bNbsP/0r8degDcM/eCvveCCuWJP4W\n", + "kmgZOsTL6w2RcANA9FiGFsZYFgwwIJNSoi5uPhHUWhw8DgpZUJJwhbcwAlrJ/XkpDgMQdv8+KTaK\n", + "5RNrXWUI+DQboZuQqh0EP6Ucm1iy8BiBubHVtPfvfM6aTMlQH2sGDo7kxk+QnIaS5zzgTFrv32D9\n", + "yKVtBoqoPJ0AuZgM4FsUTuUjy7Mb8fU+FNoSPESiOFS3CYbvMWBzWtiplx16c8G+2sTGiL+yia5h\n", + "U5UjqF9tl+DCrXkPmQAAAhVBm5lJqEFsmUwII//+tSqABlvipo+ln6jP3YEZZAIeN2gdAdBG93Am\n", + "88+PBAP+pBG1b08i0fIFrYTfZkz4SYTuxIQ1JlthBpef+blJppNwqif1piWVs/t6bCj9Z+mNxSeq\n", + "fY1/wgLfvSZhz+cH951YQ+3lZMxDj+AnlpOYgaA5ONYw7fbC4eXvAp07e1QLTwt7AKsxs6j/dp/S\n", + "ROqifCEiS8aS31tyrNd0WUbq8QssOlpj1+9+m64Uuc7+f7EFYNlp0SQRRU2ux+5kBFuUthOQf/99\n", + "ODAIvGEvExgFy7U9xycg96i+XWorpOkUsmc8UuZbMVhIEf4MYVuxmTzjhiOVDlxwcksj2gNb3xa2\n", + "pmXlh1zp/jlUP6lnJbCcR5jJhGaBJ/wuH3P+rOiJDpAwjSIE4agxxO9XGnmQRqhYjiBkbby/Qs/C\n", + "0p6IlpvwhBITpwXRBm1mH+MtJEskEccmYaNT1YNO6b966q1ndwWmG4wqG8yXMOLAMIGnxTjTIpRG\n", + "9a5Z9Xdl+HR4ndQhvFfQ+mQNsGUdDPAaOtDr9NfsDESdrHz/VFsWMxlbozv6ME9/FBsTE8SLTZxK\n", + "uKA7LtdEmFdsikvrVwkDRWs6mlddIWSLEJey878D400I9Bm2F1YzYF8hIer8urpKTRWH3dl5Pnql\n", + "OkpPyvm3RplNwN8DaGYvFB3ajEHHx79ej7jTTF7j2dZAVPOuzAAAAQNBn7dFFSw7/wAYtYg8t2YJ\n", + "aBl5mT7LoVquTMWPsAY8JEk7n2Ltj2VU9Y6yhnUjGblNmyV5I1tDP1WCa31R20KBx8ZAPYjEjgAl\n", + "IBPsF6gwEF1mGQPgwIt+DQ7Ltrn+WWljoOZe6qmL3ODaEJKUCy9wZy8Qi5WMsDYzpEybVU1vipuE\n", + "rsjD5epFom/S3CRpP+JRc2SuBGV9X135AtKz2dAbEFqb0f/DUfvRpyE/xar90tpMsUisBmDyfPqC\n", + "QCIWsyVA62u0XX4SHuuo3VkmdASLaLWJS0hWsThucD2h8t0xx4j3t8tQeFkAoX+vhWm72BA6IAOh\n", + "cP5AynBLYvgLjkBSaw6ZAAABWgGf1nRDfwAgt5i6arm7oDsF+i9EHiOJ6m6rVkYAHTQbG9yseMuo\n", + "2+jJx58xpeovc881Wv+6nIPwZiRTONb2IQaBwPwYP/UAnKjoweUWtNn8yjj61Yi1F5n9oYReT9vo\n", + "YNykd6+UIhqXBR69VB8JEqms6DNcB++Z+7S8cRY1PTjUFRAm3tXpZtcqOC46Yje8Z3mZdWtke57d\n", + "wfIWf/bjH+PQoHPWtMGigrlGqEUElC6TETXz+nB7X3pF40yVazdjxa5pCPS8j1Bqo/RmILtftGxN\n", + "Yu+1c8QTzG5+3qHYIB5lZeEW8bNhQmHlV1zck8pKhAWM+UMUo8Yo1gMDIjGuUuNGCTYOoVand7oO\n", + "JxBESUm+840sI50gEtqO5mhNaTQVfGrhYgQvynil8I63rBmEOncCHtkN57Vx9gduQDjk6aOyO6bY\n", + "qsBt2jiwg3SW9pmMOjEKBDS6IfMiAxcAAAD/AZ/YakN/ACK6K1xrl4Eswd4/m5m3eDoe6aKYRGzt\n", + "qScyJrEz0/YMsioeM46osJc2N8un8CXkVjpps6zgsf8LlkG70ab3ccrB+um/wXzisesiYCwJDgAm\n", + "D8ODYrLA2f4XQyaEvxMLwdPggFdV9SLGW7IaDs1Gj2MKL95CD69ggFd4PlXdr+MMXaKnRfCfYej6\n", + "jyRkJ6YHIJryGsscniQRwJ0d+J+1KTOriJZQomY6moOkqhpxON7UIyt9lzU6HlHOyQJ+oRH5iOIM\n", + "+hKNz7H8znQxxv6dKCBY67rZbPlwYKywoLx2OIjAEQohlh7LdbGhKMy/zzEiJYFobhp2mH1gAAAB\n", + "WkGb3UmoQWyZTAgj//61KoAGC/pGgJ9CubE/Hy/U90CEEMEEbF2Q4cnB3oAeksXBYLQl6DX56J1l\n", + "w/mHq8WxaGt2MnAvQ41YNYO39iE6FvpuFKpW712yS65PLr83LJiqo7HZlMfRzKZN59Hb83g9Yzjb\n", + "LItfty44d54BI12++V5xh28HT7V7r0Y3bFC5OovybNWx1HQWDmvmM+uWQT6BKmA1pblkm0jWUuJ0\n", + "KAyepKH6sPnyIzz9TF/cTcVBDLcJ0ebq4QoNf0i/efDFq1nH+LtoZFDiLpeCwZkCLTOE+JMjcVxC\n", + "aWP/XfyRHhNANFDKtoVePLPasXuBVFa5xCh3bB99SWFmaQdxLlk9zHTMNOyCWoiRa9OkdBShrOe1\n", + "dfGrU6t4YEao5nNo7umRhNJMptOYWcUtCbSBQmV/4G3c/zgmpJb1N+5bNROg3nNApsFhNWPnDxXX\n", + 
"YEcAkKEAAADvQZ/7RRUsO/8AGBSepWN8xnNsxE4oE6H3s58lr1m+iqw+EfUFRD+Jna0+Uvzz41Eu\n", + "ATVBokoBIC1dZOqsBeTj8Ij9FIuxNitjsFqDL+DuZwvmGihDa0HIS79MTSVw/f89Ulk3p2M2jbij\n", + "TpCkIItiAXbWCZspatvMx2+GoOmu0/Pjqc6iwrXWXyi9/N9Jj+yY/ClUEyj7sTv82Y9nVf++GCrf\n", + "1w5ltOrH9rRQKpUQaVxp4gxcgxC4qFFOgMxs83r/WkZSqY9kO/9UmmCqExD/ljnRMUJvxp8FxL1d\n", + "H7PGv4WLI5AeltB+MOGIOr9NYMAAAADwAZ4adEN/ACG6NY+qIzQfcYKCb0AhP1JJtQboSZcB2Ux6\n", + "0kAZypUjTcd/OmJjJuZBZL4W6I8Qwzms0HJLp8KRrHdk5GfU6sWQ2Z+fhfAzgzC1XgPD4QBqkDkc\n", + "T0sPX8iasgf4/DARkJP486Pq1cqH5kOYBwnnR907+n/qb/xaeHwouVk6h00s/qlqepq0S1p/xGR/\n", + "GdINVBgCemrU+PPAyI+EQBjfU66sma3ahiVaLQtsD7mxr/vZVvwLqa7Chr1J9NZveiHKnAzIMG16\n", + "G9Gmkk/8FUHgdrIbZ2heuBDh1KQSBCztE11k+ocodRJkiMj5AAABBQGeHGpDfwAhujWPq8KUOIXq\n", + "Yi8pfsfzwlVQDEG6igccpABq5mcqZlBxZf6f05WsPP5oiGUHFHfSykAR60y9PVPsKziKYov/dHwR\n", + "Kft2Arvz4qT56TCewQ06i1++DP3k7arAvxqk9+C83xiDX/XWrTHQ1+jT9fNei76g+LJLvs+Z4UVk\n", + "oEaQ3c6fXvOR9+Md7sWQeZnYPXpC/0w6s38iG8bM/+n0jsTdTFeBwE6YfrCAsv/ybSEXYS5eoPM3\n", + "f/HRzfWrUb9MZw2WEuoxs0K4qVyNiDTxcyb1DdadbkuzwkaFG7T2ZM6Pebp0YyXRqckmxx6YTGzB\n", + "LlKwKmWHeooj6Lm9LlzVgQAAAaFBmh9JqEFsmUwUTBH//rUqgAYrWZggqZs1s6MH6FUT684nhne8\n", + "ykZKf89h+0voVegpTcVlgsFoS6xwNTcMDCv9PiwISM3bG5gmdpPxwsd2af4u9VMbVGyE78HSQ5M/\n", + "nbkySYm5CPjed6c1fzFNEjUv+hlxYNfv3cPYnGT/Yav/5erFhxatniKB++1xw2wwwm3hwteUjAt3\n", + "Bi79ySg16ijYqJM5fa8+vosVJZysXRlnbW7/ITdmkkl3c8ndruo8FzJ7m8m8z0kOYciXI4QIL6Xh\n", + "qroOcvOVcWB7Uug78ZH3AowGQXzMbzVMrLD5Q7gJi2vHbYwWBG8EpVzYFtaj2m+v5trtiq/wJKtt\n", + "WosqXvVBFnxrWYQFjXg41D/ASyQHPzn2WsqemfWG6/EDepgeax6MAFQfxyDScuq3fNmr8jf0net2\n", + "tjnK9AbUeZfaZDCLHpnptMZuk8clMx5Y+UVSA4sRK6q5yL86vVu3TWQ+TGs9ZFdT4m8kNBPSkwSz\n", + "rQpsGSml5JPzqe84pJi6yJhqfYRsb2q5mJ8tkrUntJCF8lR106wAAACuAZ4+akN/AB1RsSI82HuA\n", + "EDVZr5mUHFl/p/ZTcmoRWj4TfRvTsYw8OlDJB7dvZ/vcXyur4LGUumPqBQUBQHfGq57+bI/8tRzs\n", + "Z+nHU7WH8qJ9BM8/NBixjH12m2oVcRb4XvfrX32V+Y0hU+0j88MNPEcdX4rv7aeeep8jA96PadWJ\n", + "mSmtmcZfJIFp4fz7nGsOeHvsRUbV0MKDUYmKN+mrh03bThLfJGXI3U9Tnh+UAAABmUGaI0nhClJl\n", + "MCCP//61KoAFm+ceSLbmAtKM+jG0tYuAZBSWLg59auQBOS8BoT1gHMsjZkIU234iG6WAeSbLJEu0\n", + "KCLhFA+AqaJQGzw142KKgdSAFtORqvq8YepvegTzCCnS1DU11oB/GUVDtDnboQEryLd0x6NUSSMN\n", + "cECL9Mzb9QebAeTbVcgtE4xPKr7FEgVH4vbNIioC6rYN5svm+n7fErwoxd1c4B0MbzpTJ9ypWCIt\n", + "jDqP/6ecCXKe8Ac6gqcpyPRaKmFcKdx7byHCFs3Y36UHxsmpasB5iKonQtfou1T7ViPEDD+TNshw\n", + "6ncI9FQOyx3EYxNs7CdmXQjjuiQ/hVztgan/8HWeS5jp2zgzBv5BXUEnWn+A7+FBONSn2LL/uQ/w\n", + "xRZTcRa0x52ow/V5cvgKu7FATp/RCkX/G+w1Qnp+0VyZbVkCutQ1yOnQYxf79Uw65C1zWPQdQMP/\n", + "K+VS6vPAs27IKeqUeSeiBKHv/3isIgE+rjxQbN9Lh1YW9R/9r++mSeHrs60NzUtdlXFG/VIZkaKd\n", + "XMkAAADXQZ5BRTRMO/8AFlm8HmElw5CLBq61UEezfOfwLuaBDj371pFQE2TaGfrDL2cPvWN1QZqb\n", + "tmH36IVd+buOk4nAS7OK6LGtZWekVP+ro0ezqUL6LNjplSKI15AkcuTQweCsbYhrSLoTsRiawYgs\n", + "mv975sfbTCY9L8bxROvDNcwG30R1+JWvK+o/hwf/xA32LhBb08HGKIsZFejSCR/ZACyPMiASYPKQ\n", + "KnKHiabUDVxwGq+/saT475SIsPn2KAHPd1oy/JYI5la+DZBAp1lqCWQj4yUkciIB5BAAAABzAZ5g\n", + "dEN/AB8V9DqLglnogAnlbAbcaeEM/+Dr1d94BLu23/b924ZA1vKLZ+NWO2PdXQ6go3Sf7NA4nwhe\n", + "Jfk07l2+PnIu+kI9sd8bYLUmTTByKGfoyEUnQqTPIf5dfjB+AgnVTc5y8pWcKU354gRsJCt4lQAA\n", + "AO0BnmJqQ38AHxX0OouCWHEND0XeNAIAEOFUWlDAA6yKdnA6h0XJ5AHh6k3PwK41LuRgTA6dFitc\n", + "eGcLOFImUAXmZeNXd8BBiP4Y7WDb/nj/8t7UR/ChuIYJmbMzvyMcttz9Od2nvufuLeTpnnGxlC5D\n", + "sKIQ4TiAF1Zf6Jjc46nP71VK4g2t6fmiQijizaslPXbGXByTezIrwT4YraOsiMH4GMwabs58JhIR\n", + "tYealSfNunZO0jU9FNwqBbfEknuQIRSATwmWr49+JU7MtkfWDJ9lAsDVu2W/43LTVqxccM6dY8NC\n", + "EBnYMhV6U9uYbKYAAAGwQZpnSahBaJlMCCP//rUqgAZTWZgI3NAzNytjReukCJhCqRIQrgVE5TFG\n", + 
"RpO1ZRhoAw39KCX0FTF/pEpCWlYTREK0RX8M+i/Zkz6IOh5zRR0GMJniH0SeRA8U+ZBIRrL9Hl62\n", + "8kZwKv6q5Netv/8gTYt8wrrWIwWANbXHJaruY4G39urxvB/yx7ozBV54M/wmK8P5AgF0ljjPQAUZ\n", + "DnLEHwmopi3rWM++lGz+7pSmghGU/3PNF3AxzoRutm1cdRdLqAFKdPRrKeDtflDHW39dHMmsizA0\n", + "JAD4HEW4vO3o1CbLX2IxlZFPJGuT1QOtzPR7lO7pJCxfeGJXFchlosXXXbYjZoXRMBBKcHqbIWa+\n", + "lcjl1FcSEXbk84/WCNR/hEiDPBQ56Zc4Yg/Uu5te5H7B3WBkQkc5+tttienjQao2TkWT/tLarBIb\n", + "fSMA+83k8gbv1oyeFIIWqR6ZYarMVbzfFtnH/fWhWkYB/el6Kk3P0OPSTUOVwdEnhQ/ztu0l8Ij9\n", + "PRLg28jDAaygyMt+MtthW/hM1h+aETPrMcrgZoJoV2dKCm8mLdDu/CmksDfLJBRBAAABQkGehUUR\n", + "LDv/ABi1i6Ag4bMBZUwXqVJnyx2PYc2F7FCjvy82YHTp5//HJrbZhCcYERymRfl1ah1T5z9noaM6\n", + "FqCYiKh/nb1NKcv6lay4yu1An9EGWzEXMRaTXWcwehWRMZky6GX2Elv0mAOhcWIk8WVG2FWKKMhd\n", + "27a8KH0mx5CnVDu76Igw2moc1+yPfDPZnRGymeVWDMSj1/TY3hGgb5hmSfANHPp4nyrFETtH62Dy\n", + "FIZnfZ2tua96PI/858zqXLfYaSaEy66elRjPHGSUQ+kLj7sT6e2TgQoh23asg1dvl0lw6aW2KtOQ\n", + "yQVjdxBZzehiTDj2VDDo/FI5LuGH/jfe71B2giPdfSUEN0GwZPmh+oBJ3YPtBDdEXjvqGtPnj9YN\n", + "o2RsGDqkSW3oa8BY1cptmQPEHp1SMBrX83w6xtQW5X0AAAD0AZ6kdEN/ACG9oBtcoOCFYVPj9Yn2\n", + "v/zfoFr4rWL2j9A7ZlqQHr0ZVpbLuAQJB33EyTSBNnFvVuljxMl3V6GA7Dl0BClPwL31OrTpG1l7\n", + "a7ghzL0atyS5ApCJWtp2wOBNzezTQ3N+Y1tH+luIT/i1PP0KLgniqnzZyMrwKfZeXoYEIl7twi0H\n", + "PJVeAcAdd8vPtJ2LywfKZ3u1S3on0S/4f7cj446r85qt7SkU/lr6c/+gK5erYXiPq/kf9oXoMNwY\n", + "9h0XgCkkY0ibuAMW3BGf/tJy6AGuO11Q5hQVr9nNkIcjB8Plen8B0nqwKQkOaIEp5QYqYQAAAQkB\n", + "nqZqQ38AIr2gIG4zhxx5qcIQ9c2Osw5+uNtUP7c8wH627Nk93kOS5kJwZOUsa/GuB8LSJPcgk4rv\n", + "NNy4X5Kv65LRXZpkjxKOzss2V4BAkHf3fdjwk53/8IYs8s8oIvwVKvgR9wljv8Ag07Nf+XJo681q\n", + "NbSzOUK6bv18ql/byQhgzEpF9gyeKzBYpIes4Jq5ygJqsHenGCQnuZZGCejK/v7YZig/zrXj2vhG\n", + "gCib7VW/rlAZYnZRYtYW6jN8+34R58oAelpNik7qpp/KkHdSQspzMHjVSAa9yHgI/KVEUfAeaSTC\n", + "N1Z3u1GIF1TdZRU1zNyC6xbuAxPXtz6Ez91WiAF1zBDEIltBAAABt0Gaq0moQWyZTAgj//61KoAG\n", + "e1mYdETW3g4OxfplN37UKMHTaFqDxb+9ytAjpKDc3XnMw/MxT04D0MH+PToJ4KWEuN7AocErZRv2\n", + "Rz2GQBbpS8lS31542pk6xM8YYh0/yeF1AnMnBxO2+HilOPhojFg3EW0klIcf/AybMYAo9NSuBD9C\n", + "s4e75EU0t8atdvYkg/yfik+FMNyFYTUg/mi4EKL8VgLWVSi8mxQ1+/EWE53/+fwb7K+j+527pMW9\n", + "VCj1B/8oEXG8oxyHRw/TQGPoBS7lGz9zLwh8gXusGZBvY9Xy0pnRdJKDkZLO/YjZFLNiCRPsHTqL\n", + "i2GYmJ9itG9pRnevDN9cAKQP0fgHBe/nvlXFVK7JMen+RKub1gCuPtFfO/y6rA2fstwepz1bap4Z\n", + "wJXzTLHNbeZ6/jnjul1UTQDo+Wyv2+WNy23qAxLYAQV2nquSCySITwJSTVvg+SdePIAmj5UPClGF\n", + "OrJIf0RX1xfSrhrpF0W0EhW8ceypgG4+dXb+bPwXKBwbO3GymyW89X2WJwubd13etWWTwju8K204\n", + "+w8LWTwxqMyJaP52mExMi4W5Yjr9AyAAAAElQZ7JRRUsO/8AGBMzRwWmJT1LkI6j4SrVrkQOkpGr\n", + "7qmVB6agtU/P7NMI3vz5LIs62lee9zlMDhLgStRXRkKeHaPAGaY9hwFwZg4RZnlEijsKiC6r+GA3\n", + "jOJMGPR2G+iEvFq9JqYdk0b1d9ABTX/7oiMKav8zTfVNhhkqe32oj6u1ioYXU2U/9Y4cH3f/N9Gx\n", + "JhjbFALTGuJMdeB2a/pmxPSRSx2DhwUwXe3BT4iK5IJF2QdQUjRydlTK56i3AOElSAfT6NVqnLr8\n", + "mfbO/AiWtC7ZCdSKqLQrBheoCisxuwRDc+0Qj4IlPLBawyneGpiLaece3KMzpKTos+5YxlSYlKtg\n", + "/Me6PG+fH2sUI9B09T2Px/9ucFTXTUC5j4ELLv01D5MY2VAAAADfAZ7odEN/ACK9oCBuM4G+uLMW\n", + "L2dP1lfTvDhmlpluM7IE4yEUJKicqu4KM5OijIBGmwd/fv/FYUE8C16mNefQ0Uy/D+0+Hpx1ZFAP\n", + "3vl+5XYGW/hV3tVz6fpDmClx2VYPTKI+QsHyxc+qQa6raGV2rQAFnERDWDAoPELDpD0DBzrtQ9Gj\n", + "f1X0zbjtJNpqrwp/hRbaIrr15pQNp8wHXKVl3vyz9d+FD2rUtkJQVzj6V7XpNVWdz4mpDYH1JRGS\n", + "i2MURr0RotwXgP3Qnz/8L/EyxM0Sb/CNWw8xQFPmbCgpDwAAAOUBnupqQ38AH77opN4Quy1TZxAA\n", + "Og5d0nOlbRa1c67qPfhIW7P+8Av3GtFE0HFQCvcwO1xKybwlnguY0Nqo5bzwqVZ4m1UebapfH7JG\n", + "d9M94gSTzLBzp+7XrhnquJ9dwfh5fBCyLWBt8xSfTcJZr1HXGrAMOw+Jv+pCMMogCsMVlWbHeQuT\n", + "mD3/yuQp5lDob+9AYNdyDEIT/fV+2vxg/LuQxTIX08ne1pWMu28zMsHEcHxols+2LTEYzIWCi8BU\n", + 
"K3ZtJRE3rAjZxLOQ4w3m2m/D157HitClmlKcP9jJchoyWV95Jy2gAAABu0Ga70moQWyZTAgj//61\n", + "KoAGg+KazAhO48Rk+mELCfGa3jedcL7j4i4wMKqReszSNQj5h17BpSVMT9hX+zPhBrSs6Vj7HyaE\n", + "qm6lvw7kPbwwNhW67XEllpB7/AB7Dtmc/Lsrl2N4BzMZzIFVEJCqVkWDwHz0DCyLCsDRdQx8uGEg\n", + "Ikolt9wM9AgzvQ7TxR98jTrIYP8SP9CCVhDDASOwwiUKcH0pWRrgAYwjw8Gf7OlbogYj/no1BpFx\n", + "lYglvem+TH822s9SIsjJ3EA1IN/sTGSWgAXqwMREDl6rGx1E4un7krghrGWUm+/7j4jDoGqrYrQI\n", + "g7E+ktnqOLNELPNyQd8WQ/umSuXC1xL1umwA8X5+yPqMMHEIeQL1fzz/JWAXyMH93QMSzGumbhKw\n", + "Zwg0U+25Tvu4PnK5VQHbV0zvOU2Pj+MGf/nsDxqxrqZsD9S4YY9rcTfMxz/MkkzIgfRGQF/OgLHr\n", + "joIjF7P6XCeWe+XUgCwqZQG68PRNzfXkn+zUJpMMk0jjnoYnDkQ975Dz0Z65i4o7OdZtwLEOfaoE\n", + "pB0fo5td4PyA9vYIFlRo3xi7uvrQcih7/M7KbZFgAAAA9kGfDUUVLDv/ABlUeHLsmGHl+OQZEho1\n", + "hMDtEgrgr/N3AttUVM/7crMT5dwlm5uvzGVCn6w/p670sqgr5PJ6oiWC1npINQXp4CRzsctCmXzn\n", + "Ugai5K7NbwfaQcfbZKrjzT/10H2u4nhhcuuZyNqUHfbG94mETU3kKDy9A89Il0BA9I1A+R3yjNfc\n", + "+Nz5BwP3DN+ZYjka/GHLl0y68JgPyPoe9w8jyG5IXdu2vCa+LYvH9kU234z4psgT4qxlrdkhxxyP\n", + "UJXN8nPpx6cXDiQznv0L2owqy0csZbCzUw4CVJ98G+4T1R39bjI9WT0YHLigorskW6Eh4QAAAMYB\n", + "nyx0Q38AHP005LNxTWEpiZ1J9di26t3EruDGda0AVBouFN0G1ywEJMXJZuIMxrfHCac7PtwdnQsN\n", + "5ABPxruKApfvrd4v1WFO3Cl2Zd1SOG3/r1ORn6HwtueiSFcG0RNU2EL7iLFK3PfYpxwH299J2sER\n", + "9fENVpZ0Q3jjs6HsM0edV/QB07Ofn+R5vOS4TYLqhcaZAnuosw5RlS5g1Q8CuW9BZXMHWP4TGLry\n", + "nY5Y9ez3m8FrqVUEclyyvuywjGI3odTE+j8AAABPAZ8uakN/ABxO1hwOVZzdUT/5uAde8ARk7MlV\n", + "yrPTe1pSPQMTQCpdw5z/lBFmnGZwxWyqh+3IqkDkhpoxeW8ZCVdNB2x/1RnvvpDhcO3MwQAAAbZB\n", + "mzNJqEFsmUwII//+tSqABlJow5npTNmtYD16z8AGI7v0s/GnfyqOWOrIj7MzWLMA+5yFNFLu1hTu\n", + "dlbGlkD8jL3ONezhs0gurnHp2pFLsP3djo3BgKHcLr5q4kg5WMX28rT11jnIH4bHAuJDI0/Gub5+\n", + "542H8l9OurnbLu7ccDaau7k+AVcLYmIJfjhEaissSRpn2usY/14Z8WeJwbzUwclx5b0pufbMDj2m\n", + "E4jonmtfVQvsVKXSLVBGus9F0XUey7wsw1/Hxpa1Dj6X89JFMTZZDEgLc8SXNlb52uC+3SYuA3pO\n", + "yIZ3zYRDkwb5/sIpC9s/jtT+DR4JrFHAg/zOLQvdBHh2BZ/H88Qk1FOi1nkBwtogVwTsAvTRwaaM\n", + "L+Fy6Vw65xxtt2p06IrGo+vGB6Ev7rBsQ1lA5dJTwIES1/HSnI96cCqyJNRkq8io7XoKHq1jP8jJ\n", + "K8KCILcbnjTzWMILhY3EuZ8pRzEGblkg+ofcWDech+PkwDbk4flJvQ1eVGNBBbzkH58MbHNkp5C1\n", + "pRDfsnIb9VIwGZIgexRK5GP0EM8ZveKhcNpqg0C7EdFVGM7dDkwAAAFMQZ9RRRUsO/8AFKVU3AQX\n", + "TKYCKlUskM896ABcbpuBaq23+VbIBAleYM+Uh2fmC8hKxXufvA+Jyd8ERfcMKq2QBuOeaw8cG8nv\n", + "l00dW9FnZ2ewlISmCmZ99L0bw0GXPORXq89pSQ+5zLmGTJWLpbqXg/Gg/k26eFQ7yctp0OrjpANw\n", + "gpKfTmSwqfpdIyAO4i1HmWAczC/dxtyvK6EJns7ev/M+uhg/UBsLPdCc4ktjYaoFvgpYJl8v+SaB\n", + "iW6/qJFs8B7ABY+Xoa/3pJdDPx7Wo16RIr9F0VKx7gY2CroKhVZyesK3QK039pTJworswqeMoYtQ\n", + "SxUGWdIlnZAh/LxAqJSAgdbCea7vV7Jw7UJ3RZWLCaN03DO0g6FTEO0PNlB/y2w2d5hCS2yZtMLR\n", + "726poAjDu+5lgVHjodzIR1vHcKS57NpFhydymmBuCPgAAAD3AZ9wdEN/AB0F8+qoYAk/JkWPAABe\n", + "eS/K4R2z8W8rEZ4Es2dHO2B1xqZeWERk/2j9D35SD32hnizfkl5AQkKu7sKMRtxB0qUTg/5Ai8ci\n", + "ewPsEvh0cTnE+UnVVZQsy2FhpSkguxSgj2GzhV7H4B4oQdASRatW+4ge9XWWDwbNzKDfs2ikSZGn\n", + "ZK2J2cdk5ZNdF/NbhHS0c6vDp3S53pob/1OoP8UOX13YMuZJYtnSstfaINj9HWvrLOMusuMgy0ge\n", + "hr00WpqM4G4LNFMeeHMWs3VdDioqjp1BlI0pyKTUMl2eH+Urm0ENGx6u7gM90gDkOBdN7tgm4QAA\n", + "ASkBn3JqQ38AIb2gdaiDraOcmiaTmEkpG5LCwpD7mwoBhbYx9hK/huA/Rlz76MMOi96iXfBz3DSh\n", + "vG5XYVehGnggzBAkHfGgYDsO5F3SWLpvAiWuQYgw379rpdMwhqWoBgIHHe7UqoU3PiKCUX8CUwon\n", + "PUuq8JY4AYYztu7mmGelokJyoAJS97RU/X6H+RdsNNzitkC1d8I6jDPIy7qqN4tCnL3rY6Yesfv1\n", + "e8kTaN9S190RCoZyxCFd2JzsfgZhniY0nZmfUb/Ilr3HhSfAoNjT9YPJpZU0gCEN/XEjzBiwlPnv\n", + "oPqWZP16sXNdepP+5XR/WuewqnrAjpV8x4yn9rFVK/AamriL1xzzEUk66pD3JF3R2TNlp/oPgGf2\n", + "3Zht7rWDs3F41xpI2UAAAAHmQZt3SahBbJlMCCP//rUqgAZb4qaPpZ+oz92BGWQCHjdoHQHQRvdw\n", + 
"JuWMeCAf9SCNq3pRzo+QLWwm+zJnwkwndhEvWHQ/SujctvY5pe+lS1QEjQXzeizSF8k6tO14eAtl\n", + "F+Mync2FH/YIAKwBXgDqn6AXOHpWQcynHtaJryxWYm270/11pJpJLJP1UcyORiPI54DPlbzdu+l/\n", + "jiFd4hpdaoZTSIPUh6A6ClqPxEqekFrNjAxud2WiOSd4IE7Kaf//vpwZ0mh9bmck4Z3rAu3/6Cvy\n", + "KA3WyoqAFX4UT0ZjH4z6LrUYRBEZElMEZc4snCHRyZf+tjKnoDXWOrVFpzxu69dV7GJ+V1irRKox\n", + "Pd1LRXYUoYi+P14fumR2pYbtX+VBW+m+c7NAd8Z01d3TTKV7Mg7nTZdtCA/oFcETl7++5b2EIheP\n", + "k2Fg+5ToPyynpqzSsvv9vWMyfYTJnDg6PojbFsxSs0nRUvqnP5QCdr6QHBhWXFOG60F0RsLzEsNc\n", + "wpNcPfKeYjjdCfe8YUIVjq0PBSvcnC+B/ETQWaX7IFbWhPaknWILlx3KsiYwYSMVn5rwfQd4Jkdd\n", + "9H+fdht5f/EJHYCK5IGupAjPxHpu+QiB/iUSmCHkkTiMqsG8twzlljjsl22n8veAAAABCEGflUUV\n", + "LDv/ABi1iDy3ZgloGXmZPsuhVsylb+qqNi7GSIfQ+OHuoRwObuWCiDJsleSNbQz9VgmS3f493Q1l\n", + "fk0LSjQ0QBKQCe3UmCkV8vYYHcKN9CZn1L0i/3IstLHQcy91VMXucG0IQjYMvd5K4nw1TsRQ+zNt\n", + "c33OM7wT4gTiFbFnfUP6sORkbyxKD8+9VWHRCKkGnoAnjqhwkHV3YzaNKz290rB0XwxFDvsi8iqf\n", + "z+DNrf49LxpvDCniJY8b921MDAhjoaXQisEELwuIkEG2MG16iA+xn4KZIc8cifkUnLKYTAHTEosc\n", + "/geFGHZmG9d/0Ad4ehB1+UFj3eeT8gc12jWX2ySdSQAAAUIBn7R0Q38AHbXz6qhgDdTYSzAi1h3K\n", + "16Xr3JTVUajJdHP4n1zwK/61yxZ9pP4QSRtJbkJZWH6vivN5vckWYfjVoaQoNcq3qWx+bI+OTtrh\n", + "UNznJnNVmMngQpK+748FuR69zyCunCVVntkmuIrtQvOCVbqBuRz5Qxvz7t49H+VL6IAp+Rh2gf74\n", + "0j/UPUfosZ/ElbvCMu7rvOP7cWI+JN6KUOE+/AXQCyHGSkSvvSc5FsX0fFal2fQXaEkH67EHfCc5\n", + "xhdseiByl+PiqAs8A9zuy4qmXDeeIj+3Yojnw30fZXbmjymzKitBenCylofDP0QjYedpgwNVFWxv\n", + "pKDrpf57i5C5JHBxrkMOZNs3TkoKjfQLvKDT/j1Fvw02tHitRU1MR1mnPja0zhtM0e5b68dpKMZ6\n", + "9AO+761c+Ba/40Js4HhAAAABBwGftmpDfwAiuitca5eBLMHeP5uZuF9cX0/VXhqHcuiBABGdnZlB\n", + "vvbdh+1A3f4uQyVZizhw70/9zDh2nx3tQGn11M/7g3e0ETDcFJMpuy3pyqZj8OhCsFXcJg/Dg2Ky\n", + "wNn+F0Nd65xqPmrT4IAWVNyWgNuyHhWrg80hH2qe3n3QFTH+AG0t1LUQWRwdt8cDbAi+8IGZZrTn\n", + "QzKAGB5g+jkMrZS2t5af/14Dikh/TUO9x6vp3udUZwfEqX9x43nyKd2KkcrjEt0VxTQ1LHt4TKTU\n", + "ov9g2wymXIrIg/m2cGScMEoY8xa4E2v0IBu8Siv364Oh7cF3cjWG+ZJkZ6xGCUsmpmsJt4n9AAAB\n", + "cEGbu0moQWyZTAgj//61KoAGC/pGgJ9CubE/Hy/U90CEEMEEbF2P5yKT5EQsPLolJYuDn1q5ANTN\n", + "SJwpmVcvZVK2Tco4v2Comd7hwZPuuXhX+lvh+l6ZtjrC3czf1ZVbdumb3r3D/ioYe7qcFNf7aS5r\n", + "2YnlPFx/ox3Po4uR9L227Pa5JPu/JVHojzbyIvC2hUPLYoK3yo8EFTOEx9VW2Kka/dDqBAClQEXM\n", + "coaHOVrqvWOBlx0SmrR2Fn5qD0ttjA+wKyG9Ww/+/fxdGsIy8lThxbGnpYEDoqIDxAPPdyC1j/7C\n", + "x1S6SZ6cX8TWD+edELbCVScHr4twowGayNRkN1sGJ3ChzFZqefnm592USWq1KVPalCkn+IgAbkI0\n", + "gf8crEnxuQcz5L3ov1loEzryk4ptgt40vN/cUUrwi49uNdXDzDlba6ntBbOYIPKYQqVbRsWX//V3\n", + "7VjjZzb0fU2VitbTbNlERmPP5obsCvIRmiOfAAAA7EGf2UUVLDv/ABgUnqVjfMZzbMROTbEr98Ov\n", + "G6hTv8LwbEOVBTuoZFwTL9eOUuW51yt7Pk5XoOwvCITHjPxM0+ACPLC5p8LXGPLXOMFwxyKNAOm2\n", + "+bVnL7eC/eonqWYHV7ElnGiaPE4DZvhksvIAUMvT1hgYsLWg5pHxPTMEf4vPc7k/U4gx+qn0dLIb\n", + "xLE6WPqhOli4SJOCHhekKlwgxlnM6S8wIxjTrZQVP6tyjUXc7nRDpn5+4xHTB5JTQd/Y+v5uYYim\n", + "vSxL9Lp9+sJa/YqUqQ0UFcQR3Tlp/PCrTJ5gUcQmlTDSjEV8pdpwAAABAgGf+HRDfwAhujWPq7Ze\n", + "gCJPvLBRhSSbcG6El3BFXKqbl3V6+XLJCsWmxwO7Xskzh85D3/GGBbxCjXU3okqTeEYfyjkOl+SH\n", + "4VGFs6uGeBXI6FuyUdCktochZVIQW+D6bukSQtQ9xBoZWqRH4hlWFBiT6bV+GQGerlgKyeaNsqD5\n", + "s+IDfM/wce0dikHUV0++Nr2rHe3jcRRrSy2FHjFSMdnyldmaj1iFauYYGv6d3l/8LPJtc5g5u4Q0\n", + "WerxF6DQAN+WlQUAod5dWuqnUKOySujKDQh4Sh1bNoaribkhCngsbjiJUpnyDzJfWcRyF47YB87L\n", + "Omkfy8ijCTvweGsJYAgScQAAAQUBn/pqQ38AIbo1j6vClDiF6mIvKX7IDWIXdy1QyeJm7hwAhKrN\n", + "5ZQTH6lrtJ9D3xtslHyvy2ywnd5a5/owLJHRc2EtkPadJ8Uji+G9O7CT6ooBM3rAgAWaKgWADHof\n", + "Rk55HzZ+V8DMw4S4pnRLudTRFnX1DyLXHV3VXMnhAeP+ewFDtdkUHGMhcSI0U8KajX0wWNdBGeGb\n", + "D8Ns9BH8mxfhSu/SqyYkA2AIdaTRVyL0w7XOVFH3DXljVqrcwMdXPvGgiBcw6chMaLbepo7nSmh1\n", + 
"vAbwAQYruBhNTN0eawky0jofbme4HocI40c1sz31wjy2n2/uelK4XikXYFYmVtl4Kdutz8YAAAGb\n", + "QZv9SahBbJlMFEwR//61KoAGK1mYIKmbNbOjB+hVE+vOJ4Z3vMpGSn/PYftL6FXoKU3FZYLBaEus\n", + "cDU8hX8r/T4sCEjN2tKC+to/+IoDOzT/F3qpjao2Qnfg6SHJn87cmSTE3IR8bzvTmr+Ye4Ac/+hl\n", + "xYNmjmRG01XaPV08JLNnbV2zuL5cn/7CsR7I4pKAadGKE6UheVLfqn0i791ThTaaO2OCRjsSWF8e\n", + "1o7SXLcWHdmh1WCFSlfjet1S/FkIphxf8M1ZQjLPF96/W7wlOpiP6jEis8o6251YpmdqxS3VSmv/\n", + "s9Bv3ISLvkMspiZj+iQwr28MINay/7syEY2A7ZiKqNUJX069yti8CuYwd1gGvQZSlufV+auVaTNU\n", + "xocXs0XuFW0e/AWENf2i3yxrLFTHW9CCBeoKH21CafAHq6hi+H/e9DkZU77nSidgvmP6DIx/XjI4\n", + "Sp9anaBxYwcylzQtEH2XN+nrwpDPp45KYG9LI0xieadJ2QOTHIvADfNhP/PY2gqE0NQ2qkvQc0a7\n", + "Xw6JCi5LfZz745MNAAAA8QGeHGpDfwAdo0DVwAgarNdw1dyEo22Z+2voCmn3MepWOJpNH9uE22Fc\n", + "UAf4fo25DS3VGYdH0kZ3bYGxdzd+R7awrh1yiW2ItRU9+fbZ+7eJ43X/1GQK2tLeuYX+rXNnNYVn\n", + "3JiyKGKiuk48G4gEpBGTo6LBxeBZg0OXhUHfR3yB3h9X56ir+g4EbNusZoLNQh23BaGzc9/s1PO9\n", + "1PPSEqrUiAosSTAygJNCJGqMs5yCqcS+EZopY3ntHhRp/rTMQhL4aAxAb8XQkEJtEmWrzD4p1eX6\n", + "QEZh/6hTVX/Gz191R2H/Dtkpg79J3GkssFm0vPkAAAH+QZoBSeEKUmUwII///rUqgAWb5x2D2a6r\n", + "t0Z9OpYFG2tABdnWLgsFoSkhKeOGdpZQLTxZJNtdR1o3VEUaCsJe7TDcWLiNBjbFk4iCHCNTwP1B\n", + "ET8aIdy/mqBaPrTdtuT/6FMRex7yXV0X/b0t3IdDKZDeFLpQzjHVkdbvbm3BNwCciVQUNcJ7Sjbw\n", + "T4hbhPp0oEDMMYqhG0FXqi8cqsDNhwZenV4L974lIjS1k1BRVCVuxIwrhHZ+ZNeKQOVccqtyU7fb\n", + "1nmmkdbnAEav9V5tnQTxoYHQvrZLL4f7C+LE0IOtSnKggNbex2Xp0FNi9T/+fjTgmF5bW9OJ+WCx\n", + "leyLvNiQF8k0bwSPMh7702+7OB9yXypsT0VFN+3fNlolLg4yJ7ye2ijeDcs0TyR0KI9OqHHwk9VT\n", + "lv0R4DjKMuNtxv3yyDdQ02ld84rRe/IbVoqtujoBlwArv27SRkTybmrwQddynU1vfFNgJ2tkTxsX\n", + "EuhAyTUDk1pdyrePvO3Kyjq07E+ZdqW1unVDCL0p2PAM0Bdj+ozOm4QJPGRq3YEQjJpnk1BNx6E0\n", + "yZMxRvkyW2tYZosgoDR8rW5jEN/sH3PsICgk/jLYhgpsvFfXxjf0NPxMCt81bgYfKxBAoUrGuF/8\n", + "Gb453zLMx96NgDfHj/3/yVULmADuEWX3e7X8vwCYAAAA+kGeP0U0TDv/ABZZvB5hJcOQiwautVBH\n", + "s3zn8C7pn/fvWkU93yxomewKAdw+9VXghKzj8nMy4EQ6n26QhvOvN3ZOGl4wrl9GlrTzWwgssqXz\n", + "oLBd9XVA4LrC7D/kDb3CEAYvcHCWxuhsk3WHFeLlRhwB95RghbDR4boSp+CQz3CY9L8bxC9Ohf/r\n", + "dy9+xoLX1H7kyaZJ3YehTdM+5Wu6Hpc4XocPo/ogFns0WlfgVPekkiZdh228q3p+OFEAyCsprsbc\n", + "bh4x6zwYau0C11ECccZga0PS18ku4j08dAfMYirHksImmVD9Aw8yto6D9YLwntF8IaA+FPG9VagA\n", + "AAB+AZ5edEN/AB8T3aVQEVcYwT0kXXzzDP4yP2lC7bONTcb6acU9HQ87UdrkSLI4+OHKFlU0EAFz\n", + "P/GPhcZ5NOIVfnz6vsVd3DH3XZLg43PF1cMypwOcG8sbzfthjMA4FQSgVvJe40X2MhECJet9t2G/\n", + "XdWa+YBzkUuLdbRPBeGTAAABAAGeQGpDfwAfDtYNiNYeWLJ1JGi8AHLac8oZrJR5tDRFy80bn36g\n", + "01RfxVuWBDFeUQUU4VHoswV2zHbq6MzAloc0SM3f88f/qXApn5tj32GTO8MmdjG+5h2BlZLr7lVk\n", + "BcTdEueULRCVgGF4dFB9PX4Y3jYyGQfKH/BWnAEfbs4hEQ8ebrGB8mSRpcKz5q1oNG7pkp8qNfsq\n", + "nkhG1h5qVJ826dklpNvhQDQdQnVi0zusZWH7g9GItx1/0euTzo8U/z7D4DrbASMUmgB0DC8TSqJd\n", + "xZ+UMAYbubxMdW+iPv2N1tIKXHdcOVBHhDDt1MeY4rBQavQwdjpFZBiUMt5ya+AAAAH0QZpFSahB\n", + "aJlMCCH//qpVAAyk+dgiPwCMdRFSufgoxGSIR+/0rSe9Cp9hy8WpEfkfjpu1RSHWd3zlulcFC+Nh\n", + "XPR//hjTft5KlTxkfWUjrzSX8Q8sCzZTRHqzVvb/rscPsXHQf0E6taB/yJXWDm9ZR5fbjX3mwQRc\n", + "72p/7Nk/lJUO//4LM1qLgtlckFFvGA4aviZYHpBb9w1OJg/Jqwvkkixar7ua0LNG3ane8+4yu/5g\n", + "n8krsqxREhrpsaI39b317zkKj6KVaeKiNvQ1KBsts5QsX+yTO1tzmbv5PRxGS8tz2hKf4zB8fbWM\n", + "XhqB6Gi2mMVEo6jXnv5vErjT3e551EcovqLpcSnuFBTI4jT6V7ZqZq5zqsmn23ZqFTbBnXJfy5qg\n", + "Xc1RIbUSG7SAPcicWIbuNtZ4GQS+WKAEZUxr++6VPQD3gpW4BeKCxEy910wCA11VXaqCgcSgS5FA\n", + "dwACIPfrp0NhEyPCvA4qNFC9NitDM1I8HthEGAjfRL6imFuJfW4+Sk08ZcO8JNBK0/bkkNG7XFo7\n", + "Hs15nZek/o+FGsRiwki6FYqc1HBc8skTelrrFiYgicL9M/ehriAlP3GGSVQdD58oSyTAbR/XOwHh\n", + "/k7736bu5rnUg2SpAi/FdrWUFq0zx+C7UUDgbK+SgABs/nsA2PEAAAE5QZ5jRREsO/8AGLWLoCDh\n", + 
"swFlTBepUmfLHY9h6nZJebQZXCAk5QrW0LEqJOc6Tf3RfmBa+BH+trXpxDsoWsYBGGxFB6vHNSw7\n", + "QTuHxSINvJ7kINONdsnA7unyZfe+/dUQpBab4cd9DfyyBJrHeEf61R0Nfn0RkLu3bt6BWIYQlYtM\n", + "K9Nfs/vIPwJSfjpXcON5DPtNNDffXZk4RydlgN+S/E7EUmDtA6DaeTT9v6cz5zUd9DSGZ32drbmv\n", + "ejyP/MmN69TJZPy1fo/BndGgtSNNbFsKVeTDjxqdcz9cfjIrJ3P86/aSSTu++gY85cN7L+QFkn5k\n", + "/lX20+90kKxSs6X+x+u7me+jslyG1ZQaBGKwDx+RwViiPwDARZocg2yGxzRByDsEM59E93SHlUl9\n", + "GT+PqBiUfn848MoVbAAAAOoBnoJ0Q38AIb2gG1yg4IVhU+P1ifa//N+gWvitYvaP0DtmWpAevRlW\n", + "lsu4BAkHfcTJNIE2cW9WUS6DTP7xEfhthE/Au7/XTkrYH5bPnHuWMD+L4E2Ys7TDv/WnXsb8WMjs\n", + "GVKLefmxcZqtW10iMABVusPZiYCVoxR1g16JAWeZ7iIjTKxZ0g1yWUY7SYbSh6LLTrvWvhE7lU5U\n", + "CdpswEmIpPdhoFfYojayY1ypJuWbbU1PB5nvwD9t85tVUeFQcQm5aN4kQawNooLXHpvRUW63Gqd8\n", + "iY0WiZheEXu2JHmP8XM7t/dfyrk3Fx0AAAEdAZ6EakN/ACK9oCBuM4cceanCEPXT9ZV29ukUDhUK\n", + "Q43qY97tIKPQ4ZLk+xSOgxBfQxL7yIrZscfkKmKCSoYxQfZ+tSzvOZ1GhW2ifFuVzAIEg7+77ixc\n", + "Kx//CGLPLPJ464HVUHGkhcx37PQ+kbQrXlUbN3cWUp0Qf4LtEibFhZ+LpSZJ4udEDKi6Q/S18Psl\n", + "/qmdcccWROb1W4f/Xy9V+lMS0Du/XhxzsIhWccm/rlAZXG9J5NMLdRfS734QHwqLqFpe0KPTU/Mz\n", + "iY1ev2MPDzHxs95uiDK6gRc1gvD7TgXhVki57ReTigwP0Vcnsm9mMNHj3Nt6/RMhlMwCLQhy6qqL\n", + "YC7Z58RnNbEutfWZAa9Y2SYIcplB+x/e/c7TAAABlkGaiUmoQWyZTAh3//6plgAykehDX8oAigHL\n", + "uS7e5BiYpAhLP0Zp72qQ9WFfih2hD6ViubvwAAAy+5vuYY1yi1tJuPfBi/DL0xvClymIwqUp5EK2\n", + "pijOf291KPaqRN5kbJjB/2wfKr1+XMiKLX6DysREeFfQlDwQLBucvt+vNOXQokOSOb4yTYfCyIZ/\n", + "GHqmX89FI8GoC7SVJ8dqrGOCOpcjHfvSY2QsrqBh9dhAV5Sl9v/BQKeopbgb9Qoepn/uEMh2fyEW\n", + "JmX+JgRFJalJclAgIlVBNaF+FoinY0YPKhqMcuoH+rtaEk2LTWu4NHdn9ysTAkHlBR2G+58hU289\n", + "8X49s9CJy7d2oeKmsapTwnIxxJ2LNCm+TxMniHit0ZHqI5VMxQ+5ZJ2tPHM7/cT3gdae3yVR8+YM\n", + "/KU5H6oISvxSd8TybIcXMyYVHn6O+gwy4SKx3AkMYLFpKRIO1eI3ZmEPll+L/2Ahp3aDBQxulIlY\n", + "Qc1v4+BSAHSjYxY/VpZwrkFkWgmuXijX9pnceU+eCQb0BkKKYYEAAAE7QZ6nRRUsO/8AGBMrmVrk\n", + "4p6lyIABr6JcUvWGXYV0DKg9NQWqfn9mmEaqxk2L7hAoVLAefd2AT3uOnaK6MhbcdSJ0jbOgAdky\n", + "1NCtoTFYEK1L3oNAlJW78V3WE6NttmJ67HTQFhc7jbPt6n2fAdknrF4tehh2ttPPRj0ZMNDck2O/\n", + "Og/0bAxzaaL7DSYz/qGCfH6ue/8E9mejEEqzP8HffVv8Obhn2u8eQxOotWj4hO+DblITeYVYJXny\n", + "h4Mo9PoOPQCtWY4pEEbVZmokYfc6NrhoTMJC8d+WVfQUp/9dQN2FtoGBhQPHEwvVbIcYhR7B4iO2\n", + "lHuM7fr8Nz2PLRQOuR4Lhle59+tgw9IpLSJGfVu5u0NIKILKM/viNoDYYuKxIDdR/J6apnFKAoah\n", + "uk9v6if+0v3ru/qsdmBBAAAA3wGexnRDfwAivaAgbjOBvrizFgvOGL0/6vuXDLpZruFaiDwd2rdX\n", + "jHVzx9p+aFelpPZGVUpD9afD05Q+ygH73y/cGcCL/wq72iJds0hr5PUpNV/aSoB5zpjnS1krIC0g\n", + "xgvcsTNLJd1aFsq1w5umkQK05c9QgDPa1eUOrMmn+/YlpdytXE6u+4FAjIpYVgn74StUYfcT8IT8\n", + "SGX5Wru0UB/4BiwZwXDYz0r2pPySvTt1TUg57ubb0S/BqMvEVZ5rArNFw0GaRO5EmmTuHjFK31Ed\n", + "ZcrudMiOWUSCfSesj44AAADfAZ7IakN/AB++6KTdC0Gg2vR2G3QAHQcu6TnStota0MGq57eEms8e\n", + "GSZ8YTYymFLgl7YZGG1YXmh3orKEBl6b97W6tU9/+wsf9/cg00EpDLAMwmuhlqrl+tcaP161PaCT\n", + "db1JjfLZ6rQlIR/u8Lq+hDMPBrZgZ6lFmsHEDUzmL1vhrC/Eg5wjH+dLR3xJpn70Bg13IMQhP99X\n", + "7a/GD8u5DFMhlFEykeU8M0AF5LVwxauGljyJ2PG9wt/W7GNjLNgsX4aFTR897+cKWdUMsr13pC8x\n", + "KjWMpGHXcQ2lKSkGzAAAAR5BmspJqEFsmUwIb//+p4QAYn+ayCPJyJ7QOf/irXuB3I7yUvrv3Wd8\n", + "OLQaJBb/+EMR1r6SAeh0um3VtQPrwYoZU0zDlMzZlECRYSRYOAqgamI/sUVWVEYaYAVab8QpucQ/\n", + "sSTh0wVtYsFYYkt/gr7uhkEpx1NPSuJ9CqWeDhMsefol+oaGZkPTooDGiCB29X8Zubhk7s13xY5c\n", + "l2KWl6cdQs8QOBu4PKBLJa04v3ctO+FHUCNJTXN7J5YnaOHn+BLPFy7A6HoUxVmuK9kB/hB9j6ln\n", + "0nykP3r6vgXJiVxtga3Ek+Zj3edZUHSAUux6bbxkCgdvPWLgxmKM0iIQ0SZS+9McjsqW/5Kw1hL5\n", + "sobdDT0GsHJ+I+IDODn9/vmRAAAGqm1vb3YAAABsbXZoZAAAAAAAAAAAAAAAAAAAA+gAAB1MAAEA\n", + "AAEAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAA\n", + 
"AAAAAAAAAAAAAAAAAAAAAAAAAAIAAAXUdHJhawAAAFx0a2hkAAAAAwAAAAAAAAAAAAAAAQAAAAAA\n", + "AB1MAAAAAAAAAAAAAAAAAAAAAAABAAAAAAAAAAAAAAAAAAAAAQAAAAAAAAAAAAAAAAAAQAAAAAGw\n", + "AAABIAAAAAAAJGVkdHMAAAAcZWxzdAAAAAAAAAABAAAdTAAACAAAAQAAAAAFTG1kaWEAAAAgbWRo\n", + "ZAAAAAAAAAAAAAAAAAAAKAAAASwAVcQAAAAAAC1oZGxyAAAAAAAAAAB2aWRlAAAAAAAAAAAAAAAA\n", + "VmlkZW9IYW5kbGVyAAAABPdtaW5mAAAAFHZtaGQAAAABAAAAAAAAAAAAAAAkZGluZgAAABxkcmVm\n", + "AAAAAAAAAAEAAAAMdXJsIAAAAAEAAAS3c3RibAAAALNzdHNkAAAAAAAAAAEAAACjYXZjMQAAAAAA\n", + "AAABAAAAAAAAAAAAAAAAAAAAAAGwASAASAAAAEgAAAAAAAAAAQAAAAAAAAAAAAAAAAAAAAAAAAAA\n", + "AAAAAAAAAAAAAAAAABj//wAAADFhdmNDAWQAFf/hABhnZAAVrNlBsJaEAAADAAQAAAMAUDxYtlgB\n", + "AAZo6+PLIsAAAAAcdXVpZGtoQPJfJE/FujmlG88DI/MAAAAAAAAAGHN0dHMAAAAAAAAAAQAAAEsA\n", + "AAQAAAAAFHN0c3MAAAAAAAAAAQAAAAEAAAJgY3R0cwAAAAAAAABKAAAAAQAACAAAAAABAAAUAAAA\n", + "AAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABAAAAAAAgAABAAAAAABAAAMAAAAAAEAAAQAAAAA\n", "AQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAAB\n", "AAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEA\n", - "AAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAA\n", - "CAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAM\n", + "AAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAAAwAAAAAAQAA\n", + "BAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAA\n", "AAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgA\n", "AAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAA\n", "AAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAA\n", - "AAEAABQAAAAAAQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAA\n", - "AQAABAAAAAABAAAMAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAAB\n", - "AAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAAAgAAAAAHHN0c2MAAAAAAAAAAQAAAAEA\n", - "AABkAAAAAQAAAaRzdHN6AAAAAAAAAAAAAABkAAAGhgAAAl8AAAFjAAAAvgAAAXYAAAHzAAABDgAA\n", - "ATYAAAFIAAAB9QAAAOIAAAD6AAABWgAAAbAAAADTAAAB8wAAAN4AAAH+AAABEAAAAOIAAAG2AAAC\n", - "DAAAAWUAAAGkAAABmgAAAckAAAEdAAABfQAAAPMAAAFxAAABIgAAAjYAAAEmAAAA5AAAAXoAAAH+\n", - "AAAA/wAAAT0AAAFnAAACAwAAARQAAAE3AAABTwAAAckAAADrAAACFwAAAP0AAAHzAAABIQAAAOAA\n", - "AAHKAAACOwAAAVQAAAHFAAABugAAAdQAAAD3AAABUgAAARIAAAFuAAABLwAAAhAAAAERAAAA9gAA\n", - "AZkAAAIqAAABIgAAAV0AAAGIAAACSgAAASgAAAFEAAABggAAAegAAAD+AAACCgAAASIAAAIdAAAB\n", - "KAAAAQcAAAHbAAACFgAAAT0AAAITAAAB2gAAAi8AAAEGAAABrQAAASoAAAF0AAABZgAAAl4AAAFU\n", - "AAAA+gAAAbYAAAHjAAABLwAAAZwAAAHBAAAB8QAAABRzdGNvAAAAAAAAAAEAAAAsAAAAYnVkdGEA\n", - "AABabWV0YQAAAAAAAAAhaGRscgAAAAAAAAAAbWRpcmFwcGwAAAAAAAAAAAAAAAAtaWxzdAAAACWp\n", - "dG9vAAAAHWRhdGEAAAABAAAAAExhdmY1Ny44My4xMDA=\n", + "AAEAAAwAAAAAAQAABAAAAAABAAAUAAAAAAEAAAgAAAAAAQAAAAAAAAABAAAEAAAAAAEAABQAAAAA\n", + "AQAACAAAAAABAAAAAAAAAAEAAAQAAAAAAQAAFAAAAAABAAAIAAAAAAEAAAAAAAAAAQAABAAAAAAB\n", + "AAAIAAAAABxzdHNjAAAAAAAAAAEAAAABAAAASwAAAAEAAAFAc3RzegAAAAAAAAAAAAAASwAABs8A\n", + "AAI/AAABMQAAAGEAAAD8AAABkwAAAMcAAAEbAAABNgAAALkAAAGdAAAA/QAAAMYAAADdAAABzAAA\n", + "AQEAAADcAAAARQAAAdsAAAE9AAAA0QAAAU4AAAIZAAABBwAAAV4AAAEDAAABXgAAAPMAAAD0AAAB\n", + "CQAAAaUAAACyAAABnQAAANsAAAB3AAAA8QAAAbQAAAFGAAAA+AAAAQ0AAAG7AAABKQAAAOMAAADp\n", + "AAABvwAAAPoAAADKAAAAUwAAAboAAAFQAAAA+wAAAS0AAAHqAAABDAAAAUYAAAELAAABdAAAAPAA\n", + "AAEGAAABCQAAAZ8AAAD1AAACAgAAAP4AAACCAAABBAAAAfgAAAE9AAAA7gAAASEAAAGaAAABPwAA\n", + "AOMAAADjAAABIgAAABRzdGNvAAAAAAAAAAEAAAAsAAAAYnVkdGEAAABabWV0YQAAAAAAAAAhaGRs\n", + 
"cgAAAAAAAAAAbWRpcmFwcGwAAAAAAAAAAAAAAAAtaWxzdAAAACWpdG9vAAAAHWRhdGEAAAABAAAA\n", + "AExhdmY1Ny44My4xMDA=\n", "\"\u003e\n", " Your browser does not support the video tag.\n", "\u003c/video\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f84b2253b50\u003e" + "\u003cIPython.core.display.HTML at 0x7f1286b190b8\u003e" ] }, "metadata": { @@ -1209,15 +790,15 @@ "source": [ "import time\n", "import traceback\n", + "import sys\n", "\n", "from matplotlib import pyplot as plt\n", "from matplotlib import animation as anim\n", - "import tensorflow as tf\n", - "from tensorflow.contrib import autograph as ag\n", + "import numpy as np\n", "from IPython import display\n", "\n", "\n", - "@ag.do_not_convert(ag.RunMode.PY_FUNC)\n", + "@tf.autograph.experimental.do_not_convert\n", "def render(boards):\n", " fig = plt.figure()\n", "\n", @@ -1237,74 +818,71 @@ " except RuntimeError:\n", " print('Coult not render animation:')\n", " traceback.print_exc()\n", + " return 1\n", + " return 0\n", "\n", "\n", "def gol_episode(board):\n", - " directions = tf.constant(\n", - " ((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)))\n", + " new_board = tf.TensorArray(tf.int32, 0, dynamic_size=True)\n", "\n", - " new_board = []\n", - " ag.set_element_type(new_board, tf.int32)\n", - "\n", - " for i in range(len(board)):\n", - " for j in range(len(board[i])):\n", - " num_neighbors = 0\n", - " for d in directions:\n", - " ni = i + d[0]\n", - " nj = j + d[1]\n", - " if ni \u003e= 0 and nj \u003e= 0 and ni \u003c len(board) and nj \u003c len(board[i]):\n", - " num_neighbors += board[ni][nj]\n", + " for i in tf.range(len(board)):\n", + " for j in tf.range(len(board[i])):\n", + " num_neighbors = tf.reduce_sum(\n", + " board[tf.maximum(i-1, 0):tf.minimum(i+2, len(board)),\n", + " tf.maximum(j-1, 0):tf.minimum(j+2, len(board[i]))]\n", + " ) - board[i][j]\n", " \n", - " new_cell = 0\n", " if num_neighbors == 2:\n", " new_cell = board[i][j]\n", " elif num_neighbors == 3:\n", " new_cell = 1\n", + " else:\n", + " new_cell = 0\n", " \n", " new_board.append(new_cell)\n", - " final_board = ag.stack(new_board)\n", + " final_board = new_board.stack()\n", " final_board = tf.reshape(final_board, board.shape)\n", " return final_board\n", " \n", "\n", + "@tf.function(experimental_autograph_options=(\n", + " tf.autograph.experimental.Feature.EQUALITY_OPERATORS,\n", + " tf.autograph.experimental.Feature.BUILTIN_FUNCTIONS,\n", + " tf.autograph.experimental.Feature.LISTS,\n", + " ))\n", "def gol(initial_board):\n", " board = initial_board\n", - " boards = []\n", - " ag.set_element_type(boards, tf.int32)\n", - " # We are being explicit about tensor constants to ensure the loop\n", - " # is not unrolled in the graph. 
This may change in the future.\n", - " for i in range(tf.constant(NUM_STEPS)):\n", + " boards = tf.TensorArray(tf.int32, size=0, dynamic_size=True)\n", + "\n", + " i = 0\n", + " for i in tf.range(NUM_STEPS):\n", " board = gol_episode(board)\n", " boards.append(board)\n", - " boards = ag.stack(boards)\n", - " render(boards)\n", - " return tf.no_op()\n", + " boards = boards.stack()\n", + " tf.py_function(render, (boards,), (tf.int64,))\n", + " return i\n", " \n", "\n", - "with tf.Graph().as_default():\n", - " # Gosper glider gun\n", - " # Adapted from http://www.cplusplus.com/forum/lounge/75168/\n", - " _ = 0\n", - " initial_board = tf.constant((\n", - " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", - " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", - " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", - " ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n", - " ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n", - " ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", - " ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,1,_,1,1,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", - " ( _,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", - " ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", - " ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", - " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", - " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", - " ))\n", - " initial_board = tf.pad(initial_board, ((0, 20), (0, 10)))\n", - " \n", - " tf_gol = ag.to_graph(gol)\n", - " game_ops = tf_gol(initial_board)\n", - " with tf.Session() as sess:\n", - " sess.run(game_ops)\n" + "# Gosper glider gun\n", + "# Adapted from http://www.cplusplus.com/forum/lounge/75168/\n", + "_ = 0\n", + "initial_board = tf.constant((\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_ ),\n", + " ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,1,1,_,_,_,_,_,_,_,_,1,_,_,_,1,_,1,1,_,_,_,_,1,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,1,_,_,_,_,_,1,_,_,_,_,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,1,_,_,_,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,1,1,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + " ( _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ ),\n", + "))\n", + "initial_board = tf.pad(initial_board, ((0, 10), (0, 5)))\n", + "\n", + "_ = gol(initial_board)" ] }, { @@ -1319,179 +897,21 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 0, "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "height": 2323 - }, + "colab": {}, "colab_type": "code", - "executionInfo": { - "elapsed": 
753, - "status": "ok", - "timestamp": 1532101593840, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 240 - }, - "id": "hIGYeX0Cxs3i", - "outputId": "e0b62eb1-3e12-4e53-dc54-8a3fa56d823d" + "id": "hIGYeX0Cxs3i" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "from __future__ import print_function\n", - "import tensorflow as tf\n", - "\n", - "def tf__gol_episode(board):\n", - " try:\n", - " with tf.name_scope('gol_episode'):\n", - " directions = tf.constant(((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1),\n", - " (1, -1), (1, 0), (1, 1)))\n", - " new_board = ag__.new_list([])\n", - "\n", - " def extra_test_2(new_board_2):\n", - " with tf.name_scope('extra_test_2'):\n", - " return True\n", - "\n", - " def loop_body_2(i, new_board_2):\n", - " with tf.name_scope('loop_body_2'):\n", - "\n", - " def extra_test_1(new_board_1):\n", - " with tf.name_scope('extra_test_1'):\n", - " return True\n", - "\n", - " def loop_body_1(j, new_board_1):\n", - " with tf.name_scope('loop_body_1'):\n", - " num_neighbors = 0\n", - "\n", - " def extra_test(num_neighbors_2):\n", - " with tf.name_scope('extra_test'):\n", - " return True\n", - "\n", - " def loop_body(d, num_neighbors_2):\n", - " with tf.name_scope('loop_body'):\n", - " ni = i + ag__.get_item(d, (0), opts=ag__.GetItemOpts(\n", - " element_dtype=None))\n", - " nj = j + ag__.get_item(d, (1), opts=ag__.GetItemOpts(\n", - " element_dtype=None))\n", - "\n", - " def if_true():\n", - " with tf.name_scope('if_true'):\n", - " num_neighbors_1, = num_neighbors_2,\n", - " num_neighbors_1 += ag__.get_item(ag__.get_item(board,\n", - " (ni), opts=ag__.GetItemOpts(element_dtype=None)),\n", - " (nj), opts=ag__.GetItemOpts(element_dtype=None))\n", - " return num_neighbors_1,\n", - "\n", - " def if_false():\n", - " with tf.name_scope('if_false'):\n", - " return num_neighbors_2,\n", - " num_neighbors_2 = ag__.utils.run_cond(tf.logical_and(tf.\n", - " greater_equal(ni, 0), tf.logical_and(tf.greater_equal\n", - " (nj, 0), tf.logical_and(tf.less(ni, ag__.utils.\n", - " dynamic_builtin(len, board)), tf.less(nj, ag__.utils.\n", - " dynamic_builtin(len, ag__.get_item(board, (i), opts=\n", - " ag__.GetItemOpts(element_dtype=None))))))), if_true,\n", - " if_false)\n", - " return num_neighbors_2,\n", - " num_neighbors = ag__.for_stmt(directions, extra_test,\n", - " loop_body, (num_neighbors,))\n", - " new_cell = 0\n", - "\n", - " def if_true_2():\n", - " with tf.name_scope('if_true_2'):\n", - " new_cell_2, = new_cell,\n", - " new_cell_2 = ag__.get_item(ag__.get_item(board, (i), opts\n", - " =ag__.GetItemOpts(element_dtype=None)), (j), opts=\n", - " ag__.GetItemOpts(element_dtype=None))\n", - " return new_cell_2,\n", - "\n", - " def if_false_2():\n", - " with tf.name_scope('if_false_2'):\n", - " new_cell_3, = new_cell,\n", - "\n", - " def if_true_1():\n", - " with tf.name_scope('if_true_1'):\n", - " new_cell_1, = new_cell_3,\n", - " new_cell_1 = 1\n", - " return new_cell_1,\n", - "\n", - " def if_false_1():\n", - " with tf.name_scope('if_false_1'):\n", - " return new_cell_3,\n", - " new_cell_3 = ag__.utils.run_cond(tf.equal(num_neighbors, \n", - " 3), if_true_1, if_false_1)\n", - " return new_cell_3,\n", - " new_cell = ag__.utils.run_cond(tf.equal(num_neighbors, 2),\n", - " if_true_2, if_false_2)\n", - " new_board_1 = ag__.list_append(new_board_1, new_cell)\n", - " return new_board_1,\n", - " new_board_2 = ag__.for_stmt(ag__.utils.dynamic_builtin(range,\n", - " ag__.utils.dynamic_builtin(len, ag__.get_item(board, 
(i),\n", - " opts=ag__.GetItemOpts(element_dtype=None)))), extra_test_1,\n", - " loop_body_1, (new_board_2,))\n", - " return new_board_2,\n", - " new_board = ag__.for_stmt(ag__.utils.dynamic_builtin(range, ag__.\n", - " utils.dynamic_builtin(len, board)), extra_test_2, loop_body_2, (\n", - " new_board,))\n", - " final_board = ag__.list_stack(new_board, opts=ag__.ListStackOpts(\n", - " element_dtype=tf.int32, original_call=ag.stack))\n", - " final_board = tf.reshape(final_board, board.shape)\n", - " return final_board\n", - " except:\n", - " ag__.rewrite_graph_construction_error(ag_source_map__)\n", - "\n", - "def tf__gol(initial_board):\n", - " try:\n", - " with tf.name_scope('gol'):\n", - " board = initial_board\n", - " boards = ag__.new_list([])\n", - "\n", - " def extra_test(board_1, boards_1):\n", - " with tf.name_scope('extra_test'):\n", - " return True\n", - "\n", - " def loop_body(i, board_1, boards_1):\n", - " with tf.name_scope('loop_body'):\n", - " board_1 = tf__gol_episode(board_1)\n", - " boards_1 = ag__.list_append(boards_1, board_1)\n", - " return board_1, boards_1\n", - " board, boards = ag__.for_stmt(ag__.utils.dynamic_builtin(range, tf.\n", - " constant(NUM_STEPS)), extra_test, loop_body, (board, boards))\n", - " boards = ag__.list_stack(boards, opts=ag__.ListStackOpts(\n", - " element_dtype=tf.int32, original_call=ag.stack))\n", - " with ag__.utils.control_dependency_on_returns(render(boards)):\n", - " boards_2 = ag__.utils.alias_tensors(boards)\n", - " return tf.no_op()\n", - " except:\n", - " ag__.rewrite_graph_construction_error(ag_source_map__)\n", - "\n" - ] - } - ], + "outputs": [], "source": [ - "print(ag.to_code(gol))" + "print(tf.autograph.to_code(gol.python_function))" ] } ], "metadata": { "colab": { - "collapsed_sections": [ - "p8zZyj-tq4K3", - "Lkq3DBGOv3fA", - "r8_0ioEuAI-a", - "7NgrSPCZxs3h" - ], - "default_view": {}, + "collapsed_sections": [], "last_runtime": { "build_target": "", "kind": "local" @@ -1503,8 +923,11 @@ "timestamp": 1528465909719 } ], - "version": "0.3.2", - "views": {} + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" } }, "nbformat": 4, diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc index 51b27ea4212..1e6de7ee17e 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -214,8 +214,8 @@ class ToBigtableOp : public AsyncOpKernel { std::vector columns; columns.reserve(column_families_tensor->NumElements()); for (uint64 i = 0; i < column_families_tensor->NumElements(); ++i) { - column_families.push_back(column_families_tensor->flat()(i)); - columns.push_back(columns_tensor->flat()(i)); + column_families.push_back(column_families_tensor->flat()(i)); + columns.push_back(columns_tensor->flat()(i)); } DatasetBase* dataset; @@ -317,7 +317,7 @@ class ToBigtableOp : public AsyncOpKernel { "Iterator produced a set of Tensors shorter than expected"); } ::google::cloud::bigtable::SingleRowMutation mutation( - std::move(tensors[0].scalar()())); + std::move(tensors[0].scalar()())); std::chrono::milliseconds timestamp(timestamp_int); for (size_t i = 1; i < tensors.size(); ++i) { if (!TensorShapeUtils::IsScalar(tensors[i].shape())) { @@ -326,11 +326,11 @@ class ToBigtableOp : public AsyncOpKernel { if (timestamp_int == -1) { mutation.emplace_back(::google::cloud::bigtable::SetCell( column_families[i - 1], columns[i - 1], - std::move(tensors[i].scalar()()))); + 
std::move(tensors[i].scalar()()))); } else { mutation.emplace_back(::google::cloud::bigtable::SetCell( column_families[i - 1], columns[i - 1], timestamp, - std::move(tensors[i].scalar()()))); + std::move(tensors[i].scalar()()))); } } bulk_mutation->emplace_back(std::move(mutation)); diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc index 01cedd8d762..13658558bc0 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.cc @@ -67,9 +67,9 @@ Status GcpStatusToTfStatus(const ::google::cloud::Status& status) { strings::StrCat("Error reading from Cloud Bigtable: ", status.message())); } -string RegexFromStringSet(const std::vector& strs) { +string RegexFromStringSet(const std::vector& strs) { CHECK(!strs.empty()) << "The list of strings to turn into a regex was empty."; - std::unordered_set uniq(strs.begin(), strs.end()); + std::unordered_set uniq(strs.begin(), strs.end()); if (uniq.size() == 1) { return *uniq.begin(); } diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h index 1325560e772..ce2bea0d759 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -25,7 +25,7 @@ namespace tensorflow { Status GcpStatusToTfStatus(const ::google::cloud::Status& status); -string RegexFromStringSet(const std::vector& strs); +string RegexFromStringSet(const std::vector& strs); class BigtableClientResource : public ResourceBase { public: @@ -115,6 +115,15 @@ class BigtableReaderDatasetIterator : public DatasetIterator { const ::google::cloud::bigtable::Row& row, std::vector* out_tensors) = 0; + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented("SaveInternal is currently not supported"); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented("RestoreInternal is currently not supported"); + } + private: Status EnsureIteratorInitialized() EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (reader_) { diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc index 8039ef8cd77..a69936236be 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -29,11 +29,11 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { core::RefCountPtr table; OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &table)); - std::vector column_families; - std::vector columns; - OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", - &column_families)); - OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); + std::vector column_families; + std::vector columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); OP_REQUIRES( ctx, column_families.size() == columns.size(), errors::InvalidArgument("len(columns) != len(column_families)")); @@ -58,8 +58,8 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, BigtableTableResource* table, - std::vector column_families, - std::vector columns, + std::vector column_families, + std::vector columns, const DataTypeVector& 
output_types, std::vector output_shapes) : DatasetBase(DatasetContext(ctx)), @@ -97,18 +97,23 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { return "BigtableLookupDatasetOp::Dataset"; } + Status CheckExternalState() const override { + return errors::FailedPrecondition(DebugString(), + " depends on external state."); + } + protected: Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented("%s does not support serialization", - DebugString()); + return errors::Unimplemented(DebugString(), + " does not support serialization"); } private: static ::google::cloud::bigtable::Filter MakeFilter( - const std::vector& column_families, - const std::vector& columns) { + const std::vector& column_families, + const std::vector& columns) { string column_family_regex = RegexFromStringSet(column_families); string column_regex = RegexFromStringSet(columns); @@ -154,13 +159,13 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { ::google::cloud::StatusOr< std::pair> row = dataset()->table_->table().ReadRow( - input_tensors[0].scalar()(), dataset()->filter_); + input_tensors[0].scalar()(), dataset()->filter_); if (!row.ok()) { return GcpStatusToTfStatus(row.status()); } if (!row->first) { return errors::DataLoss("Row key '", - input_tensors[0].scalar()(), + input_tensors[0].scalar()(), "' not found."); } TF_RETURN_IF_ERROR(ParseRow(ctx, row->second, out_tensors)); @@ -172,13 +177,24 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented("SaveInternal is currently not supported"); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "RestoreInternal is currently not supported"); + } + private: Status ParseRow(IteratorContext* ctx, const ::google::cloud::bigtable::Row& row, std::vector* out_tensors) { out_tensors->reserve(dataset()->columns_.size() + 1); Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); - row_key_tensor.scalar()() = string(row.row_key()); + row_key_tensor.scalar()() = tstring(row.row_key()); out_tensors->emplace_back(std::move(row_key_tensor)); if (row.cells().size() > 2 * dataset()->columns_.size()) { @@ -194,9 +210,9 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { for (auto cell_itr = row.cells().begin(); !found_column && cell_itr != row.cells().end(); ++cell_itr) { if (cell_itr->family_name() == dataset()->column_families_[i] && - string(cell_itr->column_qualifier()) == + tstring(cell_itr->column_qualifier()) == dataset()->columns_[i]) { - col_tensor.scalar()() = string(cell_itr->value()); + col_tensor.scalar()() = tstring(cell_itr->value()); found_column = true; } } @@ -216,8 +232,8 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { const DatasetBase* const input_; BigtableTableResource* table_; - const std::vector column_families_; - const std::vector columns_; + const std::vector column_families_; + const std::vector columns_; const DataTypeVector output_types_; const std::vector output_shapes_; const ::google::cloud::bigtable::Filter filter_; diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc index e9d4a1e05ea..6af5c6d0fc2 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc +++ 
b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -26,8 +26,8 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { using DatasetOpKernel::DatasetOpKernel; void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - string prefix; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + tstring prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); core::RefCountPtr resource; OP_REQUIRES_OK(ctx, @@ -71,12 +71,17 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + Status CheckExternalState() const override { + return errors::FailedPrecondition(DebugString(), + " depends on external state."); + } + protected: Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented("%s does not support serialization", - DebugString()); + return errors::Unimplemented(DebugString(), + " does not support serialization"); } private: @@ -97,7 +102,7 @@ class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { const ::google::cloud::bigtable::Row& row, std::vector* out_tensors) override { Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); - output_tensor.scalar()() = string(row.row_key()); + output_tensor.scalar()() = tstring(row.row_key()); out_tensors->emplace_back(std::move(output_tensor)); return Status::OK(); } diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc index be3c7cc5f38..22f7ddfe15d 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -26,11 +26,11 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { using DatasetOpKernel::DatasetOpKernel; void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - string start_key; + tstring start_key; OP_REQUIRES_OK(ctx, - ParseScalarArgument(ctx, "start_key", &start_key)); - string end_key; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + ParseScalarArgument(ctx, "start_key", &start_key)); + tstring end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); core::RefCountPtr resource; OP_REQUIRES_OK(ctx, @@ -76,12 +76,17 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + Status CheckExternalState() const override { + return errors::FailedPrecondition(DebugString(), + " depends on external state."); + } + protected: Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented("%s does not support serialization", - DebugString()); + return errors::Unimplemented(DebugString(), + " does not support serialization"); } private: @@ -103,7 +108,7 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { const ::google::cloud::bigtable::Row& row, std::vector* out_tensors) override { Tensor output_tensor(ctx->allocator({}), DT_STRING, {}); - output_tensor.scalar()() = string(row.row_key()); + output_tensor.scalar()() = tstring(row.row_key()); out_tensors->emplace_back(std::move(output_tensor)); return Status::OK(); } diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc index 880f5e40f25..08bf35f6c23 
100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -27,14 +27,14 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { using DatasetOpKernel::DatasetOpKernel; void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - string prefix; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + tstring prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); - string start_key; + tstring start_key; OP_REQUIRES_OK(ctx, - ParseScalarArgument(ctx, "start_key", &start_key)); - string end_key; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + ParseScalarArgument(ctx, "start_key", &start_key)); + tstring end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); core::RefCountPtr resource; OP_REQUIRES_OK(ctx, @@ -89,12 +89,17 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { return "BigtableSampleKeyPairsDatasetOp::Dataset"; } + Status CheckExternalState() const override { + return errors::FailedPrecondition(DebugString(), + " depends on external state."); + } + protected: Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented("%s does not support serialization", - DebugString()); + return errors::Unimplemented(DebugString(), + " does not support serialization"); } private: @@ -175,16 +180,27 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { *end_of_sequence = false; out_tensors->emplace_back(ctx->allocator({}), DT_STRING, TensorShape({})); - out_tensors->back().scalar()() = keys_[index_]; + out_tensors->back().scalar()() = keys_[index_]; out_tensors->emplace_back(ctx->allocator({}), DT_STRING, TensorShape({})); - out_tensors->back().scalar()() = keys_[index_ + 1]; + out_tensors->back().scalar()() = keys_[index_ + 1]; ++index_; return Status::OK(); } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented("SaveInternal is currently not supported"); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "RestoreInternal is currently not supported"); + } + private: mutex mu_; size_t index_ GUARDED_BY(mu_) = 0; diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc index 53be3b5a2bb..f4498305aa2 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -64,12 +64,17 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + Status CheckExternalState() const override { + return errors::FailedPrecondition(DebugString(), + " depends on external state."); + } + protected: Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented("%s does not support serialization", - DebugString()); + return errors::Unimplemented(DebugString(), + " does not support serialization"); } private: @@ -97,8 +102,8 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { if (index_ < row_keys_.size()) { out_tensors->emplace_back(ctx->allocator({}), DT_STRING, TensorShape({})); - out_tensors->back().scalar()() = 
- string(row_keys_[index_].row_key); + out_tensors->back().scalar()() = + tstring(row_keys_[index_].row_key); *end_of_sequence = false; index_++; } else { @@ -107,6 +112,17 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel { return Status::OK(); } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented("SaveInternal is currently not supported"); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "RestoreInternal is currently not supported"); + } + private: mutex mu_; size_t index_ = 0; diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc index e68c83ed547..d2b6959fef5 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -26,13 +26,13 @@ class BigtableScanDatasetOp : public DatasetOpKernel { using DatasetOpKernel::DatasetOpKernel; void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - string prefix; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); - string start_key; + tstring prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "prefix", &prefix)); + tstring start_key; OP_REQUIRES_OK(ctx, - ParseScalarArgument(ctx, "start_key", &start_key)); - string end_key; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); + ParseScalarArgument(ctx, "start_key", &start_key)); + tstring end_key; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "end_key", &end_key)); OP_REQUIRES(ctx, !(prefix.empty() && start_key.empty()), errors::InvalidArgument( @@ -46,11 +46,11 @@ class BigtableScanDatasetOp : public DatasetOpKernel { "If prefix is specified, end_key must be empty.")); } - std::vector column_families; - std::vector columns; - OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", - &column_families)); - OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); + std::vector column_families; + std::vector columns; + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "column_families", + &column_families)); + OP_REQUIRES_OK(ctx, ParseVectorArgument(ctx, "columns", &columns)); OP_REQUIRES( ctx, column_families.size() == columns.size(), errors::InvalidArgument("len(columns) != len(column_families)")); @@ -90,8 +90,8 @@ class BigtableScanDatasetOp : public DatasetOpKernel { public: explicit Dataset(OpKernelContext* ctx, BigtableTableResource* table, string prefix, string start_key, string end_key, - std::vector column_families, - std::vector columns, float probability, + std::vector column_families, + std::vector columns, float probability, const DataTypeVector& output_types, std::vector output_shapes) : DatasetBase(DatasetContext(ctx)), @@ -131,12 +131,17 @@ class BigtableScanDatasetOp : public DatasetOpKernel { BigtableTableResource* table() const { return table_; } + Status CheckExternalState() const override { + return errors::FailedPrecondition(DebugString(), + " depends on external state."); + } + protected: Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented("%s does not support serialization", - DebugString()); + return errors::Unimplemented(DebugString(), + " does not support serialization"); } private: @@ -175,7 +180,7 @@ class BigtableScanDatasetOp : public DatasetOpKernel { std::vector* out_tensors) override { 
out_tensors->reserve(dataset()->columns_.size() + 1); Tensor row_key_tensor(ctx->allocator({}), DT_STRING, {}); - row_key_tensor.scalar()() = string(row.row_key()); + row_key_tensor.scalar()() = tstring(row.row_key()); out_tensors->emplace_back(std::move(row_key_tensor)); if (row.cells().size() > 2 * dataset()->columns_.size()) { @@ -191,9 +196,9 @@ class BigtableScanDatasetOp : public DatasetOpKernel { for (auto cell_itr = row.cells().begin(); !found_column && cell_itr != row.cells().end(); ++cell_itr) { if (cell_itr->family_name() == dataset()->column_families_[i] && - string(cell_itr->column_qualifier()) == + tstring(cell_itr->column_qualifier()) == dataset()->columns_[i]) { - col_tensor.scalar()() = string(cell_itr->value()); + col_tensor.scalar()() = tstring(cell_itr->value()); found_column = true; } } @@ -212,8 +217,8 @@ class BigtableScanDatasetOp : public DatasetOpKernel { const string prefix_; const string start_key_; const string end_key_; - const std::vector column_families_; - const std::vector columns_; + const std::vector column_families_; + const std::vector columns_; const string column_family_regex_; const string column_regex_; const float probability_; diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py index 4f1d7990ce6..e55c0dc7806 100644 --- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py +++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py @@ -476,7 +476,7 @@ class BigtableTable(object): if tensor_type != dtypes.string: raise ValueError("Not all elements of the dataset were `tf.string`") for shape in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset)): - if not shape.is_compatible_with(tensor_shape.scalar()): + if not shape.is_compatible_with(tensor_shape.TensorShape([])): raise ValueError("Not all elements of the dataset were scalars") if len(column_families) != len(columns): raise ValueError("len(column_families) != len(columns)") diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 5a8b2ba9caf..60f92a0ff25 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -20,8 +20,10 @@ from __future__ import print_function import tempfile import numpy as np +from google.protobuf import text_format from tensorflow.contrib.boosted_trees.estimator_batch import estimator from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.python.estimator.canned import head as head_lib @@ -137,6 +139,15 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self._export_dir_base = tempfile.mkdtemp() + "export/" gfile.MkDir(self._export_dir_base) + def _assert_checkpoint_and_return_model(self, model_dir, global_step): + reader = checkpoint_utils.load_checkpoint(model_dir) + self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) + serialized = reader.get_tensor("ensemble_model:0_config") + ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig() + ensemble_proto.ParseFromString(serialized) + + return ensemble_proto + def _assert_checkpoint(self, model_dir, global_step): reader = 
checkpoint_utils.load_checkpoint(model_dir) self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP)) @@ -404,8 +415,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.tree_complexity = ( - 1.0 / _QUANTILE_REGRESSION_SIZE) + learner_config.regularization.tree_complexity = (1.0 / + _QUANTILE_REGRESSION_SIZE) train_input_fn, test_input_fn, y = _quantile_regression_input_fns() @@ -437,8 +448,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.tree_complexity = ( - 1.0 / _QUANTILE_REGRESSION_SIZE) + learner_config.regularization.tree_complexity = (1.0 / + _QUANTILE_REGRESSION_SIZE) train_input_fn, test_input_fn, y = _quantile_regression_input_fns( two_dimension=True) @@ -471,6 +482,329 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): self.assertTrue(frac_both_below_upper >= 0.91) self.assertTrue(frac_both_below_upper <= 0.99) + def testForcedInitialSplits(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 3 + + initial_subtree = """ + nodes { + dense_float_binary_split { + feature_column: 0 + threshold: -0.5 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + dense_float_binary_split { + feature_column: 0 + threshold: 0.52 + left_id: 3 + right_id: 4 + } + node_metadata { + gain: 0 + } + } + nodes { + dense_float_binary_split { + feature_column: 0 + threshold: 0.554 + left_id: 5 + right_id: 6 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + """ + tree_proto = tree_config_pb2.DecisionTreeConfig() + text_format.Merge(initial_subtree, tree_proto) + + # Set initial subtree info. + learner_config.each_tree_start.CopyFrom(tree_proto) + learner_config.each_tree_start_num_layers = 2 + + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=2, + examples_per_layer=6, + model_dir=model_dir, + config=config, + center_bias=False, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=False) + + classifier.fit(input_fn=_train_input_fn, steps=100) + # When no override of global steps, 5 steps were used. + ensemble = self._assert_checkpoint_and_return_model( + classifier.model_dir, global_step=6) + + # TODO(nponomareva): find a better way to test this. 
+ expected_ensemble = """ + trees { + nodes { + dense_float_binary_split { + threshold: -0.5 + left_id: 1 + right_id: 2 + } + node_metadata { + } + } + nodes { + dense_float_binary_split { + threshold: 0.519999980927 + left_id: 3 + right_id: 4 + } + node_metadata { + } + } + nodes { + dense_float_binary_split { + threshold: 0.554000020027 + left_id: 5 + right_id: 6 + } + node_metadata { + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + dense_float_binary_split { + threshold: 1.0 + left_id: 7 + right_id: 8 + } + node_metadata { + gain: 0.888888895512 + } + } + nodes { + leaf { + vector { + value: -2.0 + } + } + } + nodes { + leaf { + vector { + value: 2.00000023842 + } + } + } + } + trees { + nodes { + dense_float_binary_split { + threshold: -0.5 + left_id: 1 + right_id: 2 + } + node_metadata { + } + } + nodes { + dense_float_binary_split { + threshold: 0.519999980927 + left_id: 3 + right_id: 4 + } + node_metadata { + } + } + nodes { + dense_float_binary_split { + threshold: 0.554000020027 + left_id: 5 + right_id: 6 + } + node_metadata { + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + dense_float_binary_split { + threshold: 1.0 + left_id: 7 + right_id: 8 + } + node_metadata { + gain: 0.727760672569 + } + } + nodes { + leaf { + vector { + value: -1.81873059273 + } + } + } + nodes { + leaf { + vector { + value: 1.81873047352 + } + } + } + } + trees { + nodes { + dense_float_binary_split { + threshold: -0.5 + left_id: 1 + right_id: 2 + } + node_metadata { + } + } + nodes { + dense_float_binary_split { + threshold: 0.519999980927 + left_id: 3 + right_id: 4 + } + node_metadata { + } + } + nodes { + dense_float_binary_split { + threshold: 0.554000020027 + left_id: 5 + right_id: 6 + } + node_metadata { + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + nodes { + leaf { + vector { + value: 0.0 + } + } + } + } + tree_weights: 0.10000000149 + tree_weights: 0.10000000149 + tree_weights: 0.10000000149 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 3 + is_finalized: true + } + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 3 + is_finalized: true + } + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 2 + } + growing_metadata { + num_layers_attempted: 3 + } + """ + self.assertProtoEquals(expected_ensemble, ensemble) + class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): @@ -674,8 +1008,8 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): learner_config.constraints.min_node_weight = 1 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l2 = 1.0 / _QUANTILE_REGRESSION_SIZE learner_config.regularization.l1 = 1.0 / _QUANTILE_REGRESSION_SIZE - learner_config.regularization.tree_complexity = ( - 1.0 / _QUANTILE_REGRESSION_SIZE) + learner_config.regularization.tree_complexity = (1.0 / + _QUANTILE_REGRESSION_SIZE) train_input_fn, test_input_fn, y = _quantile_regression_input_fns() y = y.reshape(_QUANTILE_REGRESSION_SIZE, 1) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 07fa4ca684b..477b191bcb7 100644 --- 
a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -29,6 +29,9 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.training import training_util
+from google.protobuf import text_format
+from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
+

 class ModelBuilderOutputType(object):
   MODEL_FN_OPS = 0
@@ -106,10 +109,30 @@ def model_builder(features,
   training_features = copy.copy(features)
   training_features.pop(weight_column_name, None)
   global_step = training_util.get_global_step()
+
+  initial_ensemble = ""
+  if learner_config.each_tree_start.nodes:
+    if learner_config.each_tree_start_num_layers <= 0:
+      raise ValueError("You must provide each_tree_start_num_layers.")
+    num_layers = learner_config.each_tree_start_num_layers
+    initial_ensemble = """
+        trees { %s }
+        tree_weights: 0.1
+        tree_metadata {
+          num_tree_weight_updates: 1
+          num_layers_grown: %d
+          is_finalized: false
+        }
+        """ % (text_format.MessageToString(
+            learner_config.each_tree_start), num_layers)
+    tree_ensemble_proto = tree_config_pb2.DecisionTreeEnsembleConfig()
+    text_format.Merge(initial_ensemble, tree_ensemble_proto)
+    initial_ensemble = tree_ensemble_proto.SerializeToString()
+
   with ops.device(global_step.device):
     ensemble_handle = model_ops.tree_ensemble_variable(
         stamp_token=0,
-        tree_ensemble_config="",  # Initialize an empty ensemble.
+        tree_ensemble_config=initial_ensemble,  # Initialize the ensemble.
         name="ensemble_model")

   # Create GBDT model.
diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc
index 9655e49d91b..5f9976a491c 100644
--- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc
@@ -46,7 +46,7 @@ class CreateTreeEnsembleVariableOp : public OpKernel {
     OP_REQUIRES_OK(context, context->input("tree_ensemble_config",
                                             &tree_ensemble_config_t));
     auto* result = new DecisionTreeEnsembleResource();
-    if (!result->InitFromSerialized(tree_ensemble_config_t->scalar<string>()(),
+    if (!result->InitFromSerialized(tree_ensemble_config_t->scalar<tstring>()(),
                                     stamp_token)) {
       result->Unref();
       OP_REQUIRES(
@@ -99,7 +99,7 @@ class TreeEnsembleSerializeOp : public OpKernel {
     Tensor* output_config_t = nullptr;
     OP_REQUIRES_OK(
         context, context->allocate_output(1, TensorShape(), &output_config_t));
-    output_config_t->scalar<string>()() =
+    output_config_t->scalar<tstring>()() =
         ensemble_resource->SerializeAsString();
   }
 };
@@ -130,7 +130,7 @@ class TreeEnsembleDeserializeOp : public OpKernel {
     OP_REQUIRES(
         context,
         ensemble_resource->InitFromSerialized(
-            tree_ensemble_config_t->scalar<string>()(), stamp_token),
+            tree_ensemble_config_t->scalar<tstring>()(), stamp_token),
         errors::InvalidArgument("Unable to parse tree ensemble config."));
   }
 };
diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
index 431dc68836b..ee31a4b72c8 100644
--- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
@@ -324,7 +324,7 @@ class QuantileAccumulatorAddSummariesOp : public OpKernel {
           context,
           ParseProtoUnlimited(
               summary_proto,
-              summary_list[resource_handle_idx].scalar<string>()()),
+              summary_list[resource_handle_idx].scalar<tstring>()()),
           errors::InvalidArgument("Unable to parse quantile summary."));
       std::vector entries;
       entries.reserve(summary_proto->entries_size());
@@ -398,7 +398,7 @@ class MakeQuantileSummariesOp : public OpKernel {
       // Output to tensor.
       Tensor* output_t = nullptr;
       OP_REQUIRES_OK(context, output_list->allocate(index, {}, &output_t));
-      summary_proto->SerializeToString(&output_t->scalar<string>()());
+      SerializeToTString(*summary_proto, &output_t->scalar<tstring>()());
     };

     // These are blocks of ranges. We are iterating over both sparse and
@@ -494,7 +494,7 @@ class QuantileAccumulatorSerializeOp : public OpKernel {
     for (const auto& summary : stream.SerializeInternalSummaries()) {
       CopySummaryToProto(summary, stream_proto->add_summaries());
     }
-    stream_proto->SerializeToString(&stream_state_t->scalar<string>()());
+    SerializeToTString(*stream_proto, &stream_state_t->scalar<tstring>()());
     Tensor* buckets_t = nullptr;
     OP_REQUIRES_OK(
         context,
@@ -543,7 +543,7 @@ class QuantileAccumulatorDeserializeOp : public OpKernel {
     ::boosted_trees::QuantileStreamState state_proto;
     OP_REQUIRES(
         context,
-        ParseProtoUnlimited(&state_proto, stream_state_t->scalar<string>()()),
+        ParseProtoUnlimited(&state_proto, stream_state_t->scalar<tstring>()()),
         errors::InvalidArgument("Unabnle to parse quantile stream state."));
     std::vector summaries;
     summaries.reserve(state_proto.summaries_size());
@@ -669,7 +669,7 @@ class QuantileAccumulatorFlushSummaryOp : public OpKernel {
     Tensor* output_t = nullptr;
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({}), &output_t));
-    summary_proto->SerializeToString(&output_t->scalar<string>()());
+    SerializeToTString(*summary_proto, &output_t->scalar<tstring>()());
     streams_resource->Reset(next_stamp_token);
   }
 };
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 65276242aba..0afab357414 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -213,8 +213,8 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
     OP_REQUIRES_OK(context, context->allocate_output("split_infos",
                                                      TensorShape({size_output}),
                                                      &output_splits_t));
-    tensorflow::TTypes<string>::Vec output_splits =
-        output_splits_t->vec<string>();
+    tensorflow::TTypes<tstring>::Vec output_splits =
+        output_splits_t->vec<tstring>();

     if (num_elements == 0) {
       return;
@@ -248,7 +248,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
       const Tensor* gradients_t, const Tensor* hessians_t,
       tensorflow::TTypes<int32>::Vec* output_partition_ids,
       tensorflow::TTypes<float>::Vec* gains,
-      tensorflow::TTypes<string>::Vec* output_splits) {
+      tensorflow::TTypes<tstring>::Vec* output_splits) {
     for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
       float best_gain = std::numeric_limits<float>::lowest();
       int start_index = partition_boundaries[root_idx];
@@ -293,7 +293,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
       state->FillLeaf(best_left_node_stats, left_child);
       state->FillLeaf(best_right_node_stats, right_child);
-      split_info.SerializeToString(&(*output_splits)(root_idx));
+      SerializeToTString(split_info, &(*output_splits)(root_idx));
       (*gains)(root_idx) =
           best_gain - root_stats.gain - state->tree_complexity_regularization();
       (*output_partition_ids)(root_idx) = partition_ids(start_index);
@@ -308,7 +308,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
       const Tensor* gradients_t, const Tensor* hessians_t,
       tensorflow::TTypes<int32>::Vec* output_partition_ids,
       tensorflow::TTypes<float>::Vec* gains,
-      tensorflow::TTypes<string>::Vec* output_splits) {
+      tensorflow::TTypes<tstring>::Vec* output_splits) {
     // Holds the root stats per each node to be split.
std::vector current_layer_stats; current_layer_stats.reserve(num_elements); @@ -411,7 +411,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel { (*output_partition_ids)(root_idx) = partition_ids(start_index); oblivious_split_info.add_children_parent_id(partition_ids(start_index)); } - oblivious_split_info.SerializeToString(&(*output_splits)(0)); + SerializeToTString(oblivious_split_info, &(*output_splits)(0)); } }; REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU), @@ -529,8 +529,8 @@ class BuildSparseInequalitySplitsOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output( "split_infos", TensorShape({num_elements}), &output_splits_t)); - tensorflow::TTypes::Vec output_splits = - output_splits_t->vec(); + tensorflow::TTypes::Vec output_splits = + output_splits_t->vec(); SplitBuilderState state(context); // For each tree node that needs to be split. for (int root_idx = 0; root_idx < num_elements; ++root_idx) { @@ -674,7 +674,7 @@ class BuildSparseInequalitySplitsOp : public OpKernel { auto* right_child = split_info.mutable_right_child(); state.FillLeaf(best_left_node_stats, left_child); state.FillLeaf(best_right_node_stats, right_child); - split_info.SerializeToString(&output_splits(root_idx)); + SerializeToTString(split_info, &output_splits(root_idx)); gains(root_idx) = best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(bias_start_index); @@ -780,8 +780,8 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output("split_infos", TensorShape({size_output}), &output_splits_t)); - tensorflow::TTypes::Vec output_splits = - output_splits_t->vec(); + tensorflow::TTypes::Vec output_splits = + output_splits_t->vec(); if (num_elements == 0) { return; } @@ -818,7 +818,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { const Tensor* gradients_t, const Tensor* hessians_t, tensorflow::TTypes::Vec* output_partition_ids, tensorflow::TTypes::Vec* gains, - tensorflow::TTypes::Vec* output_splits) { + tensorflow::TTypes::Vec* output_splits) { for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[non_empty_partitions[root_idx]]; @@ -873,7 +873,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { auto* right_child = split_info.mutable_right_child(); state->FillLeaf(best_left_node_stats, left_child); state->FillLeaf(best_right_node_stats, right_child); - split_info.SerializeToString(&(*output_splits)(root_idx)); + SerializeToTString(split_info, &(*output_splits)(root_idx)); (*gains)(root_idx) = best_gain - root_stats.gain - state->tree_complexity_regularization(); (*output_partition_ids)(root_idx) = partition_ids(start_index); @@ -891,7 +891,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { const Tensor* gradients_t, const Tensor* hessians_t, tensorflow::TTypes::Vec* output_partition_ids, tensorflow::TTypes::Vec* gains, - tensorflow::TTypes::Vec* output_splits) { + tensorflow::TTypes::Vec* output_splits) { // Holds the root stats per each node to be split. 
std::vector current_layer_stats; current_layer_stats.reserve(num_elements); @@ -992,7 +992,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel { (*output_partition_ids)(root_idx) = partition_ids(start_index); oblivious_split_info.add_children_parent_id(partition_ids(start_index)); } - oblivious_split_info.SerializeToString(&(*output_splits)(0)); + SerializeToTString(oblivious_split_info, &(*output_splits)(0)); } }; diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 91c017839b5..bf5f5d34457 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -432,6 +432,27 @@ class GrowTreeEnsembleOp : public OpKernel { if (tree_config->nodes_size() <= 0) { ensemble_resource->RemoveLastTree(); } + + if ((ensemble_resource->num_trees() == 0 || + ensemble_resource->LastTreeMetadata()->is_finalized()) && + learner_config_.has_each_tree_start() && + learner_config_.each_tree_start().nodes_size() > 0) { + DCHECK_GT(learner_config_.each_tree_start_num_layers(), 0); + // Add new dummy tree + boosted_trees::trees::DecisionTreeConfig* const tree_config = + ensemble_resource->AddNewTree(learning_rate); + VLOG(1) << "Adding a new forced tree"; + + *tree_config = learner_config_.each_tree_start(); + + boosted_trees::trees::DecisionTreeMetadata* const tree_metadata = + ensemble_resource->LastTreeMetadata(); + + tree_metadata->set_is_finalized(max_tree_depth <= 1); + tree_metadata->set_num_tree_weight_updates(1); + tree_metadata->set_num_layers_grown( + learner_config_.each_tree_start_num_layers()); + } } } @@ -447,7 +468,7 @@ class GrowTreeEnsembleOp : public OpKernel { for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) { const auto& partition_ids = partition_ids_list[handler_id].vec(); const auto& gains = gains_list[handler_id].vec(); - const auto& splits = splits_list[handler_id].vec(); + const auto& splits = splits_list[handler_id].vec(); OP_REQUIRES(context, partition_ids.size() == gains.size(), errors::InvalidArgument( "Inconsistent partition Ids and gains tensors: ", @@ -481,7 +502,7 @@ class GrowTreeEnsembleOp : public OpKernel { // Find best split per partition going through every feature candidate. 
for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) { const auto& gains = gains_list[handler_id].vec(); - const auto& splits = splits_list[handler_id].vec(); + const auto& splits = splits_list[handler_id].vec(); OP_REQUIRES(context, gains.size() == 1, errors::InvalidArgument( "Gains size must be one for oblivious weak learner: ", diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py index 386dc19fc7b..04dec603667 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py @@ -60,8 +60,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): indices = [[0, 0], [0, 1], [2, 0], [3, 0]] values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = categorical_split_handler.EqualitySplitHandler( @@ -183,8 +183,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): indices = [[0, 0], [1, 0], [2, 0], [3, 0]] values = array_ops.constant([1, 2, 1, 2], dtype=dtypes.int64) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = categorical_split_handler.EqualitySplitHandler( @@ -294,8 +294,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): indices = [[0, 0], [0, 1], [2, 0], [3, 0]] values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = categorical_split_handler.EqualitySplitHandler( @@ -489,8 +489,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): indices = constant_op.constant_v1([], dtype=dtypes.int64, shape=[0, 2]) values = constant_op.constant_v1([], dtype=dtypes.int64) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = categorical_split_handler.EqualitySplitHandler( @@ -537,8 +537,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): indices = [[0, 0], [0, 1], [2, 0], [3, 0]] values = array_ops.constant([1, 2, 2, 1], dtype=dtypes.int64) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = categorical_split_handler.EqualitySplitHandler( @@ -591,8 +591,8 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase): indices = [[0, 0], [0, 1], [2, 0]] values = array_ops.constant([1, 2, 2], dtype=dtypes.int64) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = categorical_split_handler.EqualitySplitHandler( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py 
b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 0e6a9f8f3a0..75881945fde 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -75,7 +75,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -261,8 +260,7 @@ class DenseSplitHandler(InequalitySplitHandler): def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - if (self._gradient_shape == tensor_shape.scalar() and - self._hessian_shape == tensor_shape.scalar()): + if (self._gradient_shape.rank == 0 and self._hessian_shape.rank == 0): handler = make_dense_split_scalar else: handler = make_dense_split_tensor @@ -441,8 +439,7 @@ class SparseSplitHandler(InequalitySplitHandler): def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - if (self._gradient_shape == tensor_shape.scalar() and - self._hessian_shape == tensor_shape.scalar()): + if self._gradient_shape.rank == 0 and self._hessian_shape.rank == 0: handler = make_sparse_split_scalar else: handler = make_sparse_split_tensor diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 4a1b528646e..d41463d002f 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -63,8 +63,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) class_id = -1 - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, l2_regularization=1., @@ -197,8 +197,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): partition_ids = array_ops.constant([1, 1, 1, 2], dtype=dtypes.int32) class_id = -1 - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, l2_regularization=1., @@ -333,8 +333,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) class_id = -1 - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.2, l2_regularization=2., @@ -645,8 +645,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) - gradient_shape = 
tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = ordinal_split_handler.DenseSplitHandler( @@ -720,8 +720,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13]) partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = ordinal_split_handler.DenseSplitHandler( @@ -854,8 +854,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessians = array_ops.constant([0.12, 0.07, 0.2, 2]) partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = ordinal_split_handler.DenseSplitHandler( @@ -965,8 +965,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): values = array_ops.constant([0.52, 0.3, 0.52]) sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( @@ -1088,8 +1088,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): values = array_ops.constant([0.52, 0.3, 0.52]) sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( @@ -1411,8 +1411,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): values = array_ops.constant([0.52, 0.3, 0.52]) sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( @@ -1481,8 +1481,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): values = constant_op.constant_v1([], dtype=dtypes.float32) sparse_column = sparse_tensor.SparseTensor(indices, values, [4, 1]) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( @@ -1565,8 +1565,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): non_empty_indices, non_empty_values, [4, 2]) non_empty_sparse_column = non_empty_sparse_column.eval(session=sess) - gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( @@ -1650,8 +1650,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): values = array_ops.constant([0.58]) sparse_column = sparse_tensor.SparseTensor(indices, values, [1, 1]) - 
gradient_shape = tensor_shape.scalar() - hessian_shape = tensor_shape.scalar() + gradient_shape = tensor_shape.TensorShape([]) + hessian_shape = tensor_shape.TensorShape([]) class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD index edddc59956a..ca3dd545489 100644 --- a/tensorflow/contrib/boosted_trees/proto/BUILD +++ b/tensorflow/contrib/boosted_trees/proto/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library") package( licenses = ["notice"], # Apache 2.0 @@ -12,6 +12,9 @@ tf_proto_library( "learner.proto", ], cc_api_version = 2, + protodeps = [ + ":tree_config_proto", + ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto index c49cb48cdea..fc5f158c073 100644 --- a/tensorflow/contrib/boosted_trees/proto/learner.proto +++ b/tensorflow/contrib/boosted_trees/proto/learner.proto @@ -1,9 +1,11 @@ syntax = "proto3"; -option cc_enable_arenas = true; - package tensorflow.boosted_trees.learner; +import "tensorflow/contrib/boosted_trees/proto/tree_config.proto"; + +option cc_enable_arenas = true; + // Tree regularization config. message TreeRegularizationConfig { // Classic L1/L2. @@ -149,4 +151,11 @@ message LearnerConfig { // By default we use NORMAL_DECISION_TREE as weak learner. WeakLearnerType weak_learner_type = 12; + + // If you want to enforce some splits and allow boosting to figure out the + // rest, you can provide a tree that represents the starting splits for each + // tree in the ensemble. + // Set both each_tree_start and each_tree_start_num_layers. 
+ tensorflow.boosted_trees.trees.DecisionTreeConfig each_tree_start = 13; + int32 each_tree_start_num_layers = 14; } diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py index ba459e8b812..d21a0f16621 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py @@ -32,8 +32,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with self.cached_session() as sess: accumulator = stats_accumulator_ops.StatsAccumulator( stamp_token=0, - gradient_shape=tensor_shape.scalar(), - hessian_shape=tensor_shape.scalar()) + gradient_shape=tensor_shape.TensorShape([]), + hessian_shape=tensor_shape.TensorShape([])) with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, @@ -60,8 +60,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with self.cached_session() as sess: accumulator = stats_accumulator_ops.StatsAccumulator( stamp_token=0, - gradient_shape=tensor_shape.scalar(), - hessian_shape=tensor_shape.scalar()) + gradient_shape=tensor_shape.TensorShape([]), + hessian_shape=tensor_shape.TensorShape([])) with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, @@ -89,8 +89,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with self.cached_session() as sess: accumulator = stats_accumulator_ops.StatsAccumulator( stamp_token=0, - gradient_shape=tensor_shape.scalar(), - hessian_shape=tensor_shape.scalar()) + gradient_shape=tensor_shape.TensorShape([]), + hessian_shape=tensor_shape.TensorShape([])) with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, @@ -121,8 +121,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with self.cached_session() as sess: accumulator = stats_accumulator_ops.StatsAccumulator( stamp_token=0, - gradient_shape=tensor_shape.scalar(), - hessian_shape=tensor_shape.scalar()) + gradient_shape=tensor_shape.TensorShape([]), + hessian_shape=tensor_shape.TensorShape([])) with ops.control_dependencies([accumulator.initializer]): op1 = accumulator.add( stamp_token=0, @@ -162,8 +162,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with self.cached_session() as sess: accumulator = stats_accumulator_ops.StatsAccumulator( stamp_token=0, - gradient_shape=tensor_shape.scalar(), - hessian_shape=tensor_shape.scalar()) + gradient_shape=tensor_shape.TensorShape([]), + hessian_shape=tensor_shape.TensorShape([])) with ops.control_dependencies([accumulator.initializer]): # These will be deleted due to deserialize call. 
op1 = accumulator.add( @@ -199,8 +199,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): with self.cached_session() as sess: accumulator = stats_accumulator_ops.StatsAccumulator( stamp_token=0, - gradient_shape=tensor_shape.scalar(), - hessian_shape=tensor_shape.scalar()) + gradient_shape=tensor_shape.TensorShape([]), + hessian_shape=tensor_shape.TensorShape([])) partition, feature, grads, hessians = accumulator._make_summary( partition_ids=[1, 2, 1], feature_ids=[[2, 0], [3, 1], [2, 0]], diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index 86fd5770a03..74a51f4e4d8 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -142,7 +142,8 @@ def _gen_categorical_split_info(fc, feat_id, left_weight, right_weight): def _get_bias_update(grads, hess): - return array_ops.where(hess > 0, -grads / hess, array_ops.zeros_like(grads)) + return array_ops.where_v2(hess > 0, -grads / hess, + array_ops.zeros_like(grads)) class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index 4dc764f9571..8083d8fac85 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -25,7 +25,6 @@ import six from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -65,7 +64,7 @@ def _move_tensors(tensors, device): # logic. 
zero = constant_op.constant(0, dtype=dtypes.int32) with ops.device(None): - if all(tensor.shape == tensor_shape.scalar() for tensor in tensors): + if all(tensor.shape.rank == 0 for tensor in tensors): with ops.device(tensors[0].device): values = array_ops.stack(tensors) with ops.device(device): diff --git a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py index 1f6bbbf5740..62d0d0821b2 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/stats_accumulator_ops.py @@ -23,7 +23,6 @@ from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader # pylint: enable=unused-import from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import resources from tensorflow.python.training import saver from tensorflow.python.training.tracking import tracking @@ -134,8 +133,7 @@ class StatsAccumulator(tracking.TrackableResource): self._hessian_shape = hessian_shape self._container = container - if (gradient_shape == tensor_shape.scalar() and - hessian_shape == tensor_shape.scalar()): + if (gradient_shape.rank == 0 and hessian_shape.rank == 0): self._is_scalar = True else: self._is_scalar = False diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 4a13da4b5be..ffad201cbf1 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -34,6 +34,7 @@ from tensorflow.contrib.boosted_trees.python.ops import training_ops from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.layers.python.layers import feature_column_ops from tensorflow.python.feature_column import feature_column as fc_core +from tensorflow.python.feature_column import feature_column_v2 as fc_v2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -184,16 +185,20 @@ def extract_features(features, feature_columns, use_core_columns): # Make a shallow copy of features to ensure downstream usage # is unaffected by modifications in the model function. 
features = copy.copy(features) + # pylint: disable=protected-access + state_manager = fc_v2._StateManagerImpl(layer=None, trainable=False) if feature_columns: scope = "gbdt" with variable_scope.variable_scope(scope): feature_columns = list(feature_columns) transformed_features = collections.OrderedDict() for fc in feature_columns: - # pylint: disable=protected-access if use_core_columns: - # pylint: disable=protected-access - tensor = fc_core._transform_features(features, [fc])[fc] + if isinstance(fc, fc_v2.FeatureColumn): + tensor = fc_v2._transform_features_v2( + features, [fc], state_manager)[fc] + else: + tensor = fc_core._transform_features(features, [fc])[fc] transformed_features[fc.name] = tensor elif isinstance(fc, feature_column_lib._EmbeddingColumn): # pylint: enable=protected-access @@ -368,8 +373,8 @@ class GradientBoostedDecisionTreeModel(object): if logits_dimension == 1 or learner_config.multi_class_strategy == ( learner_pb2.LearnerConfig.TREE_PER_CLASS): - self._gradient_shape = tensor_shape.scalar() - self._hessian_shape = tensor_shape.scalar() + self._gradient_shape = tensor_shape.TensorShape([]) + self._hessian_shape = tensor_shape.TensorShape([]) else: if center_bias: raise ValueError("Center bias should be False for multiclass.") @@ -838,8 +843,8 @@ class GradientBoostedDecisionTreeModel(object): # Create steps accumulator. steps_accumulator = stats_accumulator_ops.StatsAccumulator( stamp_token=0, - gradient_shape=tensor_shape.scalar(), - hessian_shape=tensor_shape.scalar(), + gradient_shape=tensor_shape.TensorShape([]), + hessian_shape=tensor_shape.TensorShape([]), name="StepsAccumulator") # Create ensemble stats summaries. summary.scalar("layer_stats/num_examples", num_layer_examples) @@ -1212,7 +1217,7 @@ class GradientBoostedDecisionTreeModel(object): def _get_weights(self, hessian_shape, hessians): """Derives weights to be used based on hessians and multiclass strategy.""" - if hessian_shape == tensor_shape.scalar(): + if hessian_shape.rank == 0: # This is tree per class. 
weights = hessians elif len(hessian_shape.dims) == 1: diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 728b764898a..c9f37508677 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -29,6 +29,7 @@ from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.feature_column import feature_column_lib as core_feature_column +from tensorflow.python.feature_column import feature_column_v2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor @@ -176,6 +177,38 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertAllEqual(sparse_int_shapes[0].eval(), features["sparse_categorical"].dense_shape.eval()) + def testExtractFeaturesFromV2FeatureColumns(self): + """Tests feature extraction when using v2 columns.""" + with self.cached_session(): + features = {} + features["dense_float"] = array_ops.zeros([2, 1], dtypes.float32) + features["sparse_categorical"] = sparse_tensor.SparseTensor( + array_ops.zeros([2, 2], dtypes.int64), + array_ops.zeros([2], dtypes.string), array_ops.zeros([2], + dtypes.int64)) + feature_columns = set() + feature_columns.add(feature_column_v2.numeric_column("dense_float")) + feature_columns.add( + feature_column_v2.categorical_column_with_hash_bucket( + "sparse_categorical", hash_bucket_size=1000000)) + (fc_names, dense_floats, _, _, _, sparse_int_indices, sparse_int_values, + sparse_int_shapes) = ( + gbdt_batch.extract_features( + features, feature_columns, use_core_columns=True)) + self.assertEqual(len(fc_names), 2) + self.assertAllEqual(fc_names, ["dense_float", "sparse_categorical"]) + self.assertEqual(len(dense_floats), 1) + self.assertEqual(len(sparse_int_indices), 1) + self.assertEqual(len(sparse_int_values), 1) + self.assertEqual(len(sparse_int_shapes), 1) + self.assertAllEqual(dense_floats[0].eval(), + features["dense_float"].eval()) + self.assertAllEqual(sparse_int_indices[0].eval(), + features["sparse_categorical"].indices.eval()) + self.assertAllEqual(sparse_int_values[0].eval(), [397263, 397263]) + self.assertAllEqual(sparse_int_shapes[0].eval(), + features["sparse_categorical"].dense_shape.eval()) + def testExtractFeaturesFromCoreFeatureColumns(self): """Tests feature extraction when using core columns.""" with self.cached_session(): diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 152d8836df5..d7bbbc10a17 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -10,7 +10,7 @@ load( # For platform specific build config load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library", ) diff --git a/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc index b0f9237ea27..ae6402b391e 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_reader_ops.cc @@ -66,7 +66,7 @@ class BigQueryReader : public ReaderBase { return Status::OK(); } - Status 
ReadLocked(string* key, string* value, bool* produced, + Status ReadLocked(tstring* key, tstring* value, bool* produced, bool* at_end) override { *at_end = false; *produced = false; @@ -153,7 +153,7 @@ class GenerateBigQueryReaderPartitionsOp : public OpKernel { context->allocate_output(0, TensorShape({num_partitions_}), &output_tensor)); - auto output = output_tensor->template flat(); + auto output = output_tensor->template flat(); for (int64 i = 0; i < num_partitions_; ++i) { BigQueryTablePartition partition; partition.set_start_index(i * partition_size); diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc index 648a219fb87..04571348272 100644 --- a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc +++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc @@ -83,8 +83,9 @@ class GcsCredentialsOpKernel : public OpKernel { RetryingGcsFileSystem* gcs = nullptr; OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); - string json_string; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "json", &json_string)); + tstring json_string; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "json", &json_string)); Json::Value json; Json::Reader reader; @@ -179,13 +180,13 @@ class GcsBlockCacheOpKernel : public OpKernel { RetryingGcsFileSystem* gcs = nullptr; OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); - size_t max_cache_size, block_size, max_staleness; - OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "max_cache_size", + uint64 max_cache_size, block_size, max_staleness; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "max_cache_size", &max_cache_size)); OP_REQUIRES_OK(ctx, - ParseScalarArgument(ctx, "block_size", &block_size)); + ParseScalarArgument(ctx, "block_size", &block_size)); OP_REQUIRES_OK( - ctx, ParseScalarArgument(ctx, "max_staleness", &max_staleness)); + ctx, ParseScalarArgument(ctx, "max_staleness", &max_staleness)); if (gcs->underlying()->block_size() == block_size && gcs->underlying()->max_bytes() == max_cache_size && diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index b15143bfc1c..2926889301a 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 1.20.2) +set(nsync_TAG 1.22.0) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index ee0f1f02835..ae6f77238c5 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -172,16 +172,6 @@ tensorflow/contrib/fused_conv tensorflow/contrib/fused_conv/kernels tensorflow/contrib/fused_conv/python tensorflow/contrib/fused_conv/python/ops -tensorflow/contrib/gan -tensorflow/contrib/gan/python -tensorflow/contrib/gan/python/estimator -tensorflow/contrib/gan/python/estimator/python -tensorflow/contrib/gan/python/eval -tensorflow/contrib/gan/python/eval/python -tensorflow/contrib/gan/python/features -tensorflow/contrib/gan/python/features/python -tensorflow/contrib/gan/python/losses -tensorflow/contrib/gan/python/losses/python tensorflow/contrib/graph_editor tensorflow/contrib/graph_editor/examples tensorflow/contrib/grid_rnn diff --git 
a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index be66fac66b8..5831781c2ac 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function import argparse -import collections import functools import itertools import os @@ -59,6 +58,7 @@ from tensorflow.python.training import momentum from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib from tensorflow.python.training.tracking import util as trackable_utils +from tensorflow.python.util.compat import collections_abc CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM @@ -1131,7 +1131,7 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase): return numeric_grad.reshape(x_shape) def _GetShape(self, sess, inputs): - if not isinstance(inputs, collections.Iterable): + if not isinstance(inputs, collections_abc.Iterable): return sess.run(array_ops.shape(inputs)) else: return sess.run([array_ops.shape(x) for x in inputs]) diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index 220f9934b67..d5bcdebf81a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import os import shutil +import sys from tensorflow.contrib.data.python.ops import readers from tensorflow.python.data.kernel_tests import test_base @@ -40,7 +41,10 @@ class LMDBDatasetTest(test_base.DatasetTestBase): def setUp(self): super(LMDBDatasetTest, self).setUp() # Copy database out because we need the path to be writable to use locks. - path = os.path.join(prefix_path, "lmdb", "testdata", "data.mdb") + # The on-disk format of an LMDB database is different on big-endian + # machines, because LMDB is a memory-mapped database. + db_file = "data.mdb" if sys.byteorder == "little" else "data_bigendian.mdb" + path = os.path.join(prefix_path, "lmdb", "testdata", db_file) self.db_path = os.path.join(self.get_temp_dir(), "data.mdb") shutil.copy(path, self.db_path) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index d51fa2e0c5c..92d4820d60a 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -45,8 +45,8 @@ def make_csv_dataset( shuffle=True, shuffle_buffer_size=10000, shuffle_seed=None, - prefetch_buffer_size=dataset_ops.AUTOTUNE, - num_parallel_reads=1, + prefetch_buffer_size=None, + num_parallel_reads=None, sloppy=False, num_rows_for_inference=100, compression_type=None, @@ -112,7 +112,7 @@ def make_csv_dataset( batches to prefetch for performance improvement. Recommended value is the number of batches consumed per training step. Defaults to auto-tune. num_parallel_reads: Number of threads used to read CSV records from files. - If >1, the results will be interleaved. + If >1, the results will be interleaved. Defaults to `1`. sloppy: If `True`, reading performance will be improved at the cost of non-deterministic ordering. 
If `False`, the order of elements produced is deterministic prior to shuffling (elements are still @@ -173,9 +173,9 @@ def make_batched_features_dataset(file_pattern, shuffle=True, shuffle_buffer_size=10000, shuffle_seed=None, - prefetch_buffer_size=dataset_ops.AUTOTUNE, - reader_num_threads=1, - parser_num_threads=2, + prefetch_buffer_size=None, + reader_num_threads=None, + parser_num_threads=None, sloppy_ordering=False, drop_final_batch=False): """Returns a `Dataset` of feature dictionaries from `Example` protos. @@ -248,9 +248,9 @@ def make_batched_features_dataset(file_pattern, improve performance. Recommended value is the number of batches consumed per training step. Defaults to auto-tune. reader_num_threads: Number of threads used to read `Example` records. If >1, - the results will be interleaved. + the results will be interleaved. Defaults to `1`. parser_num_threads: Number of threads to use for parsing `Example` tensors - into a dictionary of `Feature` tensors. + into a dictionary of `Feature` tensors. Defaults to `2`. sloppy_ordering: If `True`, reading performance will be improved at the cost of non-deterministic ordering. If `False`, the order of elements produced is deterministic prior to shuffling (elements are still diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index a0b2ca59d7b..ebbb9b3c052 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -1,5 +1,5 @@ load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library", "tf_pyclif_proto_library", ) diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 8730dd45f3a..926797bebf1 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -1,7 +1,7 @@ # Implementation of a prototype TF distributed computation library. load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") -load("//tensorflow/core:platform/default/distribute.bzl", "distribute_py_test") +load("//tensorflow/core/platform:default/distribute.bzl", "distribute_py_test") load("//tensorflow:tensorflow.bzl", "cuda_py_test") package( @@ -206,6 +206,7 @@ cuda_py_test( ], tags = [ "multi_and_single_gpu", + "noguitar", # b/139307796 ], ) @@ -273,6 +274,7 @@ distribute_py_test( "no_windows_gpu", "notsan", ], + xla_enable_strict_auto_jit = False, # Ignoring due to in contrib. 
deps = [ ":mirrored_strategy", "//tensorflow/python/distribute:tpu_strategy", diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py index 6dda497459f..1f527340d8d 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py @@ -32,11 +32,9 @@ from tensorflow.python.distribute import cross_device_ops as cross_device_ops_li from tensorflow.python.distribute import cross_device_utils from tensorflow.python.distribute import distribute_lib from tensorflow.python.distribute import multi_worker_test_base -from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_test_lib from tensorflow.python.distribute import values -from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -54,7 +52,6 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.training import adam from tensorflow.python.training import training_util -from tensorflow.python.training.server_lib import ClusterSpec class MockCollectiveAllReduceStrategy(distribute_lib.StrategyV1): @@ -71,38 +68,22 @@ class MockCollectiveAllReduceStrategy(distribute_lib.StrategyV1): def create_test_objects(cluster_spec=None, task_type=None, task_id=None, - num_gpus=None, - use_core_strategy=False): + num_gpus=None): sess_config = config_pb2.ConfigProto() if num_gpus is None: num_gpus = context.num_gpus() - if use_core_strategy: - if cluster_spec and task_type and task_id is not None: - cluster_resolver = SimpleClusterResolver( - cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), - task_type=task_type, - task_id=task_id, - num_accelerators={'GPU': num_gpus}) - target = 'grpc://' + cluster_spec[task_type][task_id] - else: - cluster_resolver = SimpleClusterResolver( - ClusterSpec({}), num_accelerators={'GPU': num_gpus}) - target = '' - strategy = MockCollectiveAllReduceStrategy(cluster_resolver) - sess_config = strategy.update_config_proto(sess_config) + strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( + num_gpus_per_worker=num_gpus) + if task_type and task_id is not None: + strategy.configure( + session_config=sess_config, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id) + target = 'grpc://' + cluster_spec[task_type][task_id] else: - strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy( - num_gpus_per_worker=num_gpus) - if task_type and task_id is not None: - strategy.configure( - session_config=sess_config, - cluster_spec=cluster_spec, - task_type=task_type, - task_id=task_id) - target = 'grpc://' + cluster_spec[task_type][task_id] - else: - target = '' + target = '' return strategy, target, sess_config @@ -120,17 +101,12 @@ class CollectiveAllReduceStrategyTestBase( CollectiveAllReduceStrategyTestBase.collective_key_base += 100000 super(CollectiveAllReduceStrategyTestBase, self).setUp() - def _get_test_object(self, - task_type, - task_id, - num_gpus=0, - use_core_strategy=False): + def _get_test_object(self, task_type, task_id, num_gpus=0): strategy, target, session_config = create_test_objects( cluster_spec=self._cluster_spec, 
task_type=task_type, task_id=task_id, - num_gpus=num_gpus, - use_core_strategy=use_core_strategy) + num_gpus=num_gpus) collective_keys = cross_device_utils.CollectiveKeys( group_key_start=10 + @@ -144,11 +120,7 @@ class CollectiveAllReduceStrategyTestBase( return strategy, target, session_config - def _test_minimize_loss_graph(self, - task_type, - task_id, - num_gpus, - use_core_strategy=False): + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): d, master_target, config = self._get_test_object(task_type, task_id, num_gpus) with ops.Graph().as_default(), \ @@ -215,11 +187,7 @@ class CollectiveAllReduceStrategyTestBase( # Error should go down self.assertLess(error_after, error_before) - def _test_complex_model(self, - task_type, - task_id, - num_gpus, - use_core_strategy=False): + def _test_complex_model(self, task_type, task_id, num_gpus): d, master_target, config = self._get_test_object(task_type, task_id, num_gpus) @@ -270,11 +238,7 @@ class CollectiveAllReduceStrategyTestBase( sess.run(variables.global_variables_initializer()) sess.run(train_op) - def _test_variable_initialization(self, - task_type, - task_id, - num_gpus, - use_core_strategy=False): + def _test_variable_initialization(self, task_type, task_id, num_gpus): distribution, master_target, config = self._get_test_object( task_type, task_id, num_gpus) with ops.Graph().as_default(), \ @@ -309,8 +273,7 @@ class CollectiveAllReduceStrategyTestBase( input_fn, expected_values, test_reinitialize=True, - ignore_order=False, - use_core_strategy=False): + ignore_order=False): distribution, master_target, config = self._get_test_object( task_type, task_id, num_gpus) devices = distribution.extended.worker_devices @@ -360,62 +323,41 @@ class DistributedCollectiveAllReduceStrategyTest( cls._cluster_spec = multi_worker_test_base.create_in_process_cluster( num_workers=3, num_ps=0) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def test_num_replicas_in_sync(self, use_core_strategy): + @combinations.generate(combinations.combine(mode=['graph'])) + def test_num_replicas_in_sync(self): distribution, _, _ = create_test_objects( cluster_spec=self._cluster_spec, task_type='worker', task_id=0, - num_gpus=2, - use_core_strategy=use_core_strategy) + num_gpus=2) num_workers = len(self._cluster_spec.get('chief', []) + self._cluster_spec.get('worker', [])) self.assertEqual(2 * num_workers, distribution.num_replicas_in_sync) @combinations.generate( - combinations.combine( - mode=['graph'], - num_gpus=[0, 1, 2], - required_gpus=1, - use_core_strategy=[True, False])) - def testMinimizeLossGraph(self, num_gpus, use_core_strategy): - self._run_between_graph_clients( - self._test_minimize_loss_graph, - self._cluster_spec, - num_gpus, - use_core_strategy=use_core_strategy) + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) @combinations.generate( - combinations.combine( - mode=['graph'], - num_gpus=[0, 1, 2], - required_gpus=1, - use_core_strategy=[True, False])) - def testVariableInitialization(self, num_gpus, use_core_strategy): + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testVariableInitialization(self, num_gpus): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') self._run_between_graph_clients( self._test_variable_initialization, self._cluster_spec, - 
num_gpus=num_gpus, - use_core_strategy=use_core_strategy) + num_gpus=num_gpus) @combinations.generate( - combinations.combine( - mode=['graph'], - num_gpus=[0, 1, 2], - required_gpus=1, - use_core_strategy=[True, False])) - def testComplexModel(self, num_gpus, use_core_strategy): + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1)) + def testComplexModel(self, num_gpus): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') self._run_between_graph_clients( - self._test_complex_model, - self._cluster_spec, - num_gpus=num_gpus, - use_core_strategy=use_core_strategy) + self._test_complex_model, self._cluster_spec, num_gpus=num_gpus) # TODO(yuefengz): Update how we use num_gpus and required_gpus @combinations.generate( @@ -423,9 +365,8 @@ class DistributedCollectiveAllReduceStrategyTest( mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1, - use_dataset=[True, False], - use_core_strategy=[True, False])) - def testMakeInputFnIterator(self, num_gpus, use_dataset, use_core_strategy): + use_dataset=[True, False])) + def testMakeInputFnIterator(self, num_gpus, use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if use_dataset: @@ -452,17 +393,12 @@ class DistributedCollectiveAllReduceStrategyTest( input_fn, expected_values, test_reinitialize=use_dataset, - ignore_order=not use_dataset, - use_core_strategy=use_core_strategy) + ignore_order=not use_dataset) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testUpdateConfigProto(self, use_core_strategy): + @combinations.generate(combinations.combine(mode=['graph'])) + def testUpdateConfigProto(self): strategy, _, _ = self._get_test_object( - task_type='worker', - task_id=1, - num_gpus=2, - use_core_strategy=use_core_strategy) + task_type='worker', task_id=1, num_gpus=2) config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) rewrite_options = config_proto.graph_options.rewrite_options @@ -484,29 +420,6 @@ class DistributedCollectiveAllReduceStrategyTest( self.assertEqual(['CollectiveReduce'], new_rewrite_options.scoped_allocator_opts.enable_op) - @combinations.generate(combinations.combine(mode=['eager'])) - def testEnableCollectiveOps(self): - mock_called = [False] - - # pylint: disable=dangerous-default-value - def mock_enable_collective_ops(server_def, mock_called=mock_called): - self.assertEqual('worker', server_def.job_name) - self.assertEqual(1, server_def.task_index) - self.assertEqual('grpc', server_def.protocol) - mock_called[0] = True - - def mock_configure_collective_ops(*args, **kwargs): - del args, kwargs - - with test.mock.patch.object(context.context(), 'enable_collective_ops', - mock_enable_collective_ops), \ - test.mock.patch.object(context.context(), 'configure_collective_ops', - mock_configure_collective_ops): - strategy, _, _ = self._get_test_object( - task_type='worker', task_id=1, num_gpus=2, use_core_strategy=True) - self.assertTrue(strategy.extended._std_server_started) - self.assertTrue(mock_called[0]) - class DistributedCollectiveAllReduceStrategyTestWithChief( CollectiveAllReduceStrategyTestBase, parameterized.TestCase): @@ -550,41 +463,28 @@ class LocalCollectiveAllReduceStrategy( @combinations.generate( combinations.combine( - mode=['graph', 'eager'], - num_gpus=[2, 4], - required_gpus=2, - use_core_strategy=[True, False])) - def testMinimizeLoss(self, num_gpus, use_core_strategy): + mode=['graph', 'eager'], num_gpus=[2, 4], required_gpus=2)) + def testMinimizeLoss(self, num_gpus): # 
Collective ops doesn't support strategy with one device. if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if context.executing_eagerly(): - strategy, _, _ = self._get_test_object( - None, None, num_gpus, use_core_strategy=use_core_strategy) + strategy, _, _ = self._get_test_object(None, None, num_gpus) self._test_minimize_loss_eager(strategy) else: - self._test_minimize_loss_graph( - None, None, num_gpus, use_core_strategy=use_core_strategy) + self._test_minimize_loss_graph(None, None, num_gpus) @combinations.generate( - combinations.combine( - mode=['graph'], - num_gpus=[2, 4], - required_gpus=2, - use_core_strategy=[True, False])) - def testComplexModel(self, num_gpus, use_core_strategy): + combinations.combine(mode=['graph'], num_gpus=[2, 4], required_gpus=2)) + def testComplexModel(self, num_gpus): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - self._test_complex_model( - None, None, num_gpus, use_core_strategy=use_core_strategy) + self._test_complex_model(None, None, num_gpus) @combinations.generate( combinations.combine( - mode=['graph', 'eager'], - required_gpus=2, - use_dataset=[True, False], - use_core_strategy=[True, False])) - def testMakeInputFnIterator(self, use_dataset, use_core_strategy): + mode=['graph', 'eager'], required_gpus=2, use_dataset=[True, False])) + def testMakeInputFnIterator(self, use_dataset): num_gpus = 2 if use_dataset: fn = lambda: dataset_ops.Dataset.range(5 * num_gpus) @@ -607,71 +507,56 @@ class LocalCollectiveAllReduceStrategy( input_fn, expected_values, test_reinitialize=use_dataset, - ignore_order=not use_dataset, - use_core_strategy=use_core_strategy) + ignore_order=not use_dataset) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testAllReduceSum(self, use_core_strategy): + @combinations.generate(combinations.combine(mode=['graph'])) + def testAllReduceSum(self): if context.num_gpus() < 2: self.skipTest('Not enough GPUs') - distribution, target, config = self._get_test_object( - None, None, num_gpus=2, use_core_strategy=use_core_strategy) + distribution, target, config = self._get_test_object(None, None, num_gpus=2) with self.cached_session(config=config, target=target): self._test_all_reduce_sum(distribution) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testAllReduceSumGradients(self, use_core_strategy): + @combinations.generate(combinations.combine(mode=['graph'])) + def testAllReduceSumGradients(self): if context.num_gpus() < 2: self.skipTest('Not enough GPUs') - distribution, target, config = self._get_test_object( - None, None, num_gpus=2, use_core_strategy=use_core_strategy) + distribution, target, config = self._get_test_object(None, None, num_gpus=2) with self.cached_session(config=config, target=target): self._test_all_reduce_sum_gradients(distribution) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testAllReduceSumGradientTape(self, use_core_strategy): + @combinations.generate(combinations.combine(mode=['graph'])) + def testAllReduceSumGradientTape(self): if context.num_gpus() < 2: self.skipTest('Not enough GPUs') - distribution, target, config = self._get_test_object( - None, None, num_gpus=2, use_core_strategy=use_core_strategy) + distribution, target, config = self._get_test_object(None, None, num_gpus=2) with self.cached_session(config=config, target=target): self._test_all_reduce_sum_gradient_tape(distribution) - 
@combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testAllReduceMean(self, use_core_strategy): + @combinations.generate(combinations.combine(mode=['graph'])) + def testAllReduceMean(self): if context.num_gpus() < 2: self.skipTest('Not enough GPUs') - distribution, target, config = self._get_test_object( - None, None, num_gpus=2, use_core_strategy=use_core_strategy) + distribution, target, config = self._get_test_object(None, None, num_gpus=2) with self.cached_session(config=config, target=target): self._test_all_reduce_mean(distribution) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testAllReduceMeanGradients(self, use_core_strategy): + @combinations.generate(combinations.combine(mode=['graph'])) + def testAllReduceMeanGradients(self): if context.num_gpus() < 2: self.skipTest('Not enough GPUs') - distribution, target, config = self._get_test_object( - None, None, num_gpus=2, use_core_strategy=use_core_strategy) + distribution, target, config = self._get_test_object(None, None, num_gpus=2) with self.cached_session(config=config, target=target): self._test_all_reduce_mean_gradients(distribution) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testAllReduceMeanGradientTape(self, use_core_strategy): + @combinations.generate(combinations.combine(mode=['graph'])) + def testAllReduceMeanGradientTape(self): if context.num_gpus() < 2: self.skipTest('Not enough GPUs') - distribution, target, config = self._get_test_object( - None, None, num_gpus=2, use_core_strategy=use_core_strategy) + distribution, target, config = self._get_test_object(None, None, num_gpus=2) with self.cached_session(config=config, target=target): self._test_all_reduce_mean_gradient_tape(distribution) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testNumpyIterator(self, use_core_strategy): + @combinations.generate(combinations.combine(mode=['graph'])) + def testNumpyIterator(self): num_gpus = 2 if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') - strategy, _, _ = self._get_test_object( - None, None, num_gpus=num_gpus, use_core_strategy=use_core_strategy) + strategy, _, _ = self._get_test_object(None, None, num_gpus=num_gpus) self._test_numpy_iterator(strategy) diff --git a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py index c97f93371bf..98195cca3c3 100644 --- a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py +++ b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py @@ -369,7 +369,12 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + model.compile( + optimizer, + loss, + metrics=metrics, + distribute=distribution, + experimental_run_tf_function=False) inputs = np.zeros((64, 3), dtype=np.float32) targets = np.zeros((64, 4), dtype=np.float32) @@ -399,7 +404,11 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + model.compile( + optimizer, + loss, + distribute=distribution, + experimental_run_tf_function=False) input_a_np = 
np.asarray(np.random.random((64, 3)), dtype=np.float32) input_b_np = np.asarray(np.random.random((64, 5)), dtype=np.float32) @@ -432,7 +441,11 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, model = get_model() optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + model.compile( + optimizer, + loss, + distribute=distribution, + experimental_run_tf_function=False) inputs = np.zeros((20, 3), np.float32) targets = np.zeros((20, 4), np.float32) @@ -448,7 +461,11 @@ class TestDistributionStrategyWithNumpyArrays(test.TestCase, optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + model.compile( + optimizer, + loss, + distribute=distribution, + experimental_run_tf_function=False) # We take 6 input samples with each input having a dimension of 3 or 5. input_a_np = np.asarray(np.random.random((6, 3)), dtype=np.float32) @@ -478,7 +495,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + model.compile( + optimizer, + loss, + metrics=metrics, + distribute=distribution, + experimental_run_tf_function=False) dataset = get_dataset(distribution) @@ -497,7 +519,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase, gradient_descent.GradientDescentOptimizer(0.001), loss='mse', metrics=['mae', keras.metrics.CategoricalAccuracy()], - distribute=distribution) + distribute=distribution, + experimental_run_tf_function=False) interleaved_model = get_model() interleaved_model.set_weights(user_controlled_model.get_weights()) @@ -505,7 +528,8 @@ class TestDistributionStrategyWithDatasets(test.TestCase, gradient_descent.GradientDescentOptimizer(0.001), loss='mse', metrics=['mae', keras.metrics.CategoricalAccuracy()], - distribute=distribution) + distribute=distribution, + experimental_run_tf_function=False) dataset = get_dataset(distribution) @@ -546,7 +570,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001) loss = 'mse' metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + model.compile( + optimizer, + loss, + metrics=metrics, + distribute=distribution, + experimental_run_tf_function=False) input_a_np = np.random.random((10, 3)) input_b_np = np.random.random((10, 5)) @@ -578,7 +607,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae', keras.metrics.CategoricalAccuracy()] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + model.compile( + optimizer, + loss, + metrics=metrics, + distribute=distribution, + experimental_run_tf_function=False) dataset = get_dataset(distribution) @@ -592,7 +626,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model = get_model() loss = 'mse' - model.compile(optimizer(), loss, distribute=distribution) + model.compile( + optimizer(), + loss, + distribute=distribution, + experimental_run_tf_function=False) dataset = get_dataset(distribution) @@ -605,7 +643,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, model = get_model() optimizer = 
rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + model.compile( + optimizer, + loss, + distribute=distribution, + experimental_run_tf_function=False) inputs = np.zeros((10, 3), np.float32) targets = np.zeros((10, 4), np.float32) @@ -633,7 +675,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + model.compile( + optimizer, + loss, + distribute=distribution, + experimental_run_tf_function=False) # Wrong input shape inputs = np.zeros((10, 5), dtype=np.float32) @@ -660,7 +706,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + model.compile( + optimizer, + loss, + distribute=distribution, + experimental_run_tf_function=False) # User forgets to batch the dataset inputs = np.zeros((10, 3), dtype=np.float32) @@ -692,7 +742,12 @@ class TestDistributionStrategyWithDatasets(test.TestCase, optimizer = gradient_descent.GradientDescentOptimizer(0.005) loss = 'mse' metrics = ['acc'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + model.compile( + optimizer, + loss, + metrics=metrics, + distribute=distribution, + experimental_run_tf_function=False) batch_size = 8 if isinstance(distribution, mirrored_strategy.CoreMirroredStrategy): @@ -727,7 +782,11 @@ class TestDistributionStrategyWithDatasets(test.TestCase, optimizer = gradient_descent_keras.SGD(0.01) loss = 'mse' - model.compile(optimizer, loss, distribute=distribution) + model.compile( + optimizer, + loss, + distribute=distribution, + experimental_run_tf_function=False) dataset = get_dataset(distribution) @@ -761,7 +820,12 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + model.compile( + optimizer, + loss, + metrics=metrics, + distribute=distribution, + experimental_run_tf_function=False) dataset = get_dataset(distribution) @@ -816,7 +880,12 @@ class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase): optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae'] - model.compile(optimizer, loss, metrics=metrics, distribute=distribution) + model.compile( + optimizer, + loss, + metrics=metrics, + distribute=distribution, + experimental_run_tf_function=False) dataset = get_dataset(distribution) @@ -856,9 +925,11 @@ class TestDistributionStrategyWithLossMasking(test.TestCase, model.add( keras.layers.TimeDistributed( keras.layers.Dense(1, kernel_initializer='one'))) - model.compile(loss='mse', - optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=distribution) + model.compile( + loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=distribution, + experimental_run_tf_function=False) y = np.array([[[1], [1]], [[1], [1]]]) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.repeat(100) @@ -877,9 +948,11 @@ class TestDistributionStrategyWithNormalizationLayer( model = keras.models.Sequential() norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8) model.add(norm) - model.compile(loss='mse', - 
optimizer=gradient_descent.GradientDescentOptimizer(0.01), - distribute=distribution) + model.compile( + loss='mse', + optimizer=gradient_descent.GradientDescentOptimizer(0.01), + distribute=distribution, + experimental_run_tf_function=False) # centered on 5.0, variance 10.0 x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10)) @@ -924,7 +997,8 @@ class TestDistributionStrategyCorrectness(test.TestCase, loss=keras.losses.mean_squared_error, optimizer=gradient_descent.GradientDescentOptimizer(0.5), metrics=[keras.metrics.BinaryAccuracy()], - distribute=distribution) + distribute=distribution, + experimental_run_tf_function=False) batch_size = 64 if not distributed_training_utils.global_batch_size_supported( @@ -950,7 +1024,8 @@ class TestDistributionStrategyCorrectness(test.TestCase, loss='mae', metrics=['accuracy', keras.metrics.BinaryAccuracy()], optimizer=gradient_descent.GradientDescentOptimizer(0.001), - distribute=distribution) + distribute=distribution, + experimental_run_tf_function=False) # verify correctness of stateful and stateless metrics. x = np.ones((100, 4)).astype('float32') @@ -1026,7 +1101,8 @@ class TestDistributionStrategyCorrectness(test.TestCase, loss=keras.losses.mean_squared_error, optimizer=gradient_descent_keras.SGD(0.5), metrics=['mse'], - distribute=with_distribution) + distribute=with_distribution, + experimental_run_tf_function=False) training_inputs, eval_inputs, predict_inputs = ( get_correctness_test_inputs(use_numpy, use_validation_data, diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py index 12926cfa164..a4d5f0cf5a1 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py @@ -24,17 +24,14 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import parameter_server_strategy from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.distribute import central_storage_strategy from tensorflow.python.distribute import combinations from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import multi_worker_util -from tensorflow.python.distribute import parameter_server_strategy as core_parameter_server_strategy from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import strategy_test_lib from tensorflow.python.distribute import values -from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.estimator import run_config @@ -69,42 +66,24 @@ def create_test_objects(cluster_spec=None, task_type=None, task_id=None, num_gpus=None, - sess_config=None, - use_core_strategy=False): + sess_config=None): sess_config = sess_config or config_pb2.ConfigProto() if num_gpus is None: num_gpus = context.num_gpus() - if use_core_strategy: - if cluster_spec and task_type and task_id is not None: - cluster_resolver = SimpleClusterResolver( - cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec), - task_type=task_type, - task_id=task_id, - num_accelerators={'GPU': num_gpus}) - distribution = 
core_parameter_server_strategy.ParameterServerStrategy( - cluster_resolver) - target = 'grpc://' + cluster_spec[WORKER][task_id] - else: - distribution = ( - central_storage_strategy.CentralStorageStrategy._from_num_gpus( - num_gpus)) - target = '' + distribution = parameter_server_strategy.ParameterServerStrategy( + num_gpus_per_worker=num_gpus) + + if task_type: sess_config = copy.deepcopy(sess_config) - sess_config = distribution.update_config_proto(sess_config) + distribution.configure( + session_config=sess_config, + cluster_spec=cluster_spec, + task_type=task_type, + task_id=task_id) + target = 'grpc://' + cluster_spec[WORKER][task_id] else: - distribution = parameter_server_strategy.ParameterServerStrategy( - num_gpus_per_worker=num_gpus) - if task_type: - sess_config = copy.deepcopy(sess_config) - distribution.configure( - session_config=sess_config, - cluster_spec=cluster_spec, - task_type=task_type, - task_id=task_id) - target = 'grpc://' + cluster_spec[WORKER][task_id] - else: - target = '' + target = '' return distribution, target, sess_config @@ -122,27 +101,17 @@ class ParameterServerStrategyTestBase( self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True) super(ParameterServerStrategyTestBase, self).setUp() - def _get_test_objects(self, - task_type, - task_id, - num_gpus, - use_core_strategy=False): + def _get_test_objects(self, task_type, task_id, num_gpus): return create_test_objects( cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id, num_gpus=num_gpus, - sess_config=self._sess_config, - use_core_strategy=use_core_strategy) + sess_config=self._sess_config) - def _test_device_assignment_distributed(self, - task_type, - task_id, - num_gpus, - use_core_strategy=False): + def _test_device_assignment_distributed(self, task_type, task_id, num_gpus): worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id) - d, _, sess_config = self._get_test_objects( - task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) + d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) with ops.Graph().as_default(), \ self.cached_session(target=self._default_target, config=sess_config) as sess, \ @@ -240,9 +209,8 @@ class ParameterServerStrategyTestBase( self.assertEqual(f_val, 46.0) def _test_device_assignment_distributed_enable_partitioner( - self, task_type, task_id, num_gpus, use_core_strategy=False): - d, _, sess_config = self._get_test_objects( - task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) + self, task_type, task_id, num_gpus): + d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus) num_shards = len(d.extended.parameter_devices) partitioner = partitioned_variables.fixed_size_partitioner(num_shards) with ops.Graph().as_default(), \ @@ -390,13 +358,9 @@ class ParameterServerStrategyTestBase( self.assertEqual(z_val, 43.0) self.assertEqual(f_val, 46.0) - def _test_simple_increment(self, - task_type, - task_id, - num_gpus, - use_core_strategy=False): + def _test_simple_increment(self, task_type, task_id, num_gpus): d, master_target, sess_config = self._get_test_objects( - task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) + task_type, task_id, num_gpus) if d.extended._cluster_spec: num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER)) if 'chief' in d.extended._cluster_spec.as_dict(): @@ -462,13 +426,9 @@ class ParameterServerStrategyTestBase( self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas_in_sync) self.assertEqual(z_val, 30.0 + 1.0 * 
num_workers) - def _test_minimize_loss_graph(self, - task_type, - task_id, - num_gpus, - use_core_strategy=False): + def _test_minimize_loss_graph(self, task_type, task_id, num_gpus): d, master_target, sess_config = self._get_test_objects( - task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) + task_type, task_id, num_gpus) if task_type: # Multi-worker assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec @@ -561,10 +521,9 @@ class ParameterServerStrategyTestBase( input_fn, expected_values, test_reinitialize=True, - ignore_order=False, - use_core_strategy=False): + ignore_order=False): distribution, master_target, config = self._get_test_objects( - task_type, task_id, num_gpus, use_core_strategy=use_core_strategy) + task_type, task_id, num_gpus) devices = distribution.extended.worker_devices with ops.Graph().as_default(), \ @@ -613,84 +572,62 @@ class ParameterServerStrategyTest( num_workers=3, num_ps=2) cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0] - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def test_num_replicas_in_sync(self, use_core_strategy): - strategy, _, _ = create_test_objects( - num_gpus=2, use_core_strategy=use_core_strategy) + @combinations.generate(combinations.combine(mode=['graph'])) + def test_num_replicas_in_sync(self): + strategy, _, _ = create_test_objects(num_gpus=2) # All the devices on a given worker are in sync which in this case is the # number of gpus on each worker. self.assertEqual(2, strategy.num_replicas_in_sync) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testDeviceAssignmentLocalCPU(self, use_core_strategy): - strategy, _, _ = create_test_objects( - num_gpus=0, use_core_strategy=use_core_strategy) + @combinations.generate(combinations.combine(mode=['graph'])) + def testDeviceAssignmentLocalCPU(self): + strategy, _, _ = create_test_objects(num_gpus=0) self._test_device_assignment_local( strategy, compute_device='CPU', variable_device='CPU', num_gpus=0) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testDeviceAssignmentLocalOneGPU(self, use_core_strategy): - strategy, _, _ = create_test_objects( - num_gpus=1, use_core_strategy=use_core_strategy) + @combinations.generate(combinations.combine(mode=['graph'])) + def testDeviceAssignmentLocalOneGPU(self): + strategy, _, _ = create_test_objects(num_gpus=1) self._test_device_assignment_local( strategy, compute_device='GPU', variable_device='GPU', num_gpus=1) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testDeviceAssignmentLocalTwoGPUs(self, use_core_strategy): - strategy, _, _ = create_test_objects( - num_gpus=2, use_core_strategy=use_core_strategy) + @combinations.generate(combinations.combine(mode=['graph'])) + def testDeviceAssignmentLocalTwoGPUs(self): + strategy, _, _ = create_test_objects(num_gpus=2) self._test_device_assignment_local( strategy, compute_device='GPU', variable_device='CPU', num_gpus=2) @combinations.generate( - combinations.combine( - mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) - def testDeviceAssignmentDistributed(self, num_gpus, use_core_strategy): - self._test_device_assignment_distributed( - 'worker', 1, num_gpus, use_core_strategy=use_core_strategy) + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributed(self, num_gpus): + 
self._test_device_assignment_distributed('worker', 1, num_gpus) @combinations.generate( - combinations.combine( - mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) - def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus, - use_core_strategy): + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus): self._test_device_assignment_distributed_enable_partitioner( - 'worker', 1, num_gpus, use_core_strategy=use_core_strategy) + 'worker', 1, num_gpus) + + @combinations.generate(combinations.combine(mode=['graph'])) + def testSimpleBetweenGraph(self): + self._run_between_graph_clients(self._test_simple_increment, + self._cluster_spec, context.num_gpus()) @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testSimpleBetweenGraph(self, use_core_strategy): - self._run_between_graph_clients( - self._test_simple_increment, - self._cluster_spec, - context.num_gpus(), - use_core_strategy=use_core_strategy) + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testLocalSimpleIncrement(self, num_gpus): + self._test_simple_increment(None, 0, num_gpus) @combinations.generate( - combinations.combine( - mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) - def testLocalSimpleIncrement(self, num_gpus, use_core_strategy): - self._test_simple_increment(None, 0, num_gpus, use_core_strategy) + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraphDistributed(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) @combinations.generate( - combinations.combine( - mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) - def testMinimizeLossGraphDistributed(self, num_gpus, use_core_strategy): - self._run_between_graph_clients( - self._test_minimize_loss_graph, - self._cluster_spec, - num_gpus, - use_core_strategy=use_core_strategy) - - @combinations.generate( - combinations.combine( - mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) - def testMinimizeLossGraphLocal(self, num_gpus, use_core_strategy): - self._test_minimize_loss_graph(None, None, num_gpus, use_core_strategy) + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraphLocal(self, num_gpus): + self._test_minimize_loss_graph(None, None, num_gpus) # TODO(priyag): Refactor this and other multi worker tests. 
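Editor's note: with the `use_core_strategy` branch removed, `create_test_objects` in this file is left with a single construction path: build the contrib `ParameterServerStrategy` and, for multi-worker cases, call `configure` with the cluster information. The condensed sketch below restates that flow in readable form; the function name and the `'worker'` cluster key are placeholders standing in for the test helper and its `WORKER` constant.

```python
# Condensed, illustrative restatement of the single remaining construction
# path in create_test_objects; names and cluster values are placeholders.
import copy

from tensorflow.contrib.distribute.python import parameter_server_strategy
from tensorflow.core.protobuf import config_pb2


def make_strategy(cluster_spec=None, task_type=None, task_id=None, num_gpus=0,
                  sess_config=None):
  sess_config = sess_config or config_pb2.ConfigProto()
  distribution = parameter_server_strategy.ParameterServerStrategy(
      num_gpus_per_worker=num_gpus)
  if task_type:
    # Multi-worker case: configure against the cluster and point the session
    # target at this task's worker.
    sess_config = copy.deepcopy(sess_config)
    distribution.configure(
        session_config=sess_config,
        cluster_spec=cluster_spec,
        task_type=task_type,
        task_id=task_id)
    target = 'grpc://' + cluster_spec['worker'][task_id]
  else:
    # Local case: no cluster, in-process target.
    target = ''
  return distribution, target, sess_config
```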
@combinations.generate( @@ -698,10 +635,8 @@ class ParameterServerStrategyTest( mode=['graph'], num_gpus=[1, 2], required_gpus=1, - use_core_strategy=[True, False], use_dataset=[True, False])) - def testMakeInputFnIteratorDistributed( - self, num_gpus, use_core_strategy, use_dataset): + def testMakeInputFnIteratorDistributed(self, num_gpus, use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if use_dataset: @@ -726,18 +661,15 @@ class ParameterServerStrategyTest( input_fn, expected_values, test_reinitialize=use_dataset, - ignore_order=not use_dataset, - use_core_strategy=use_core_strategy) + ignore_order=not use_dataset) @combinations.generate( combinations.combine( mode=['graph'], num_gpus=[1, 2], required_gpus=1, - use_core_strategy=[True, False], use_dataset=[True, False])) - def testMakeInputFnIteratorLocal(self, num_gpus, use_core_strategy, - use_dataset): + def testMakeInputFnIteratorLocal(self, num_gpus, use_dataset): if context.num_gpus() < num_gpus: self.skipTest('Not enough GPUs') if use_dataset: @@ -762,24 +694,20 @@ class ParameterServerStrategyTest( input_fn, expected_values, test_reinitialize=use_dataset, - ignore_order=not use_dataset, - use_core_strategy=use_core_strategy) + ignore_order=not use_dataset) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testGlobalStepUpdate(self, use_core_strategy): - strategy, _, _ = create_test_objects(use_core_strategy=use_core_strategy) + @combinations.generate(combinations.combine(mode=['graph'])) + def testGlobalStepUpdate(self): + strategy, _, _ = create_test_objects() self._test_global_step_update(strategy) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testUpdateConfigProtoMultiWorker(self, use_core_strategy): + @combinations.generate(combinations.combine(mode=['graph'])) + def testUpdateConfigProtoMultiWorker(self): strategy, _, _ = create_test_objects( cluster_spec=self._cluster_spec, task_type='worker', task_id=1, - num_gpus=2, - use_core_strategy=use_core_strategy) + num_gpus=2) config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden']) @@ -792,11 +720,9 @@ class ParameterServerStrategyTest( # Verify isolate_session_state self.assertFalse(new_config.isolate_session_state) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testUpdateConfigProtoLocal(self, use_core_strategy): - strategy, _, _ = create_test_objects( - num_gpus=2, use_core_strategy=use_core_strategy) + @combinations.generate(combinations.combine(mode=['graph'])) + def testUpdateConfigProtoLocal(self): + strategy, _, _ = create_test_objects(num_gpus=2) config_proto = config_pb2.ConfigProto() new_config = strategy.update_config_proto(config_proto) @@ -854,30 +780,20 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, num_workers=3, num_ps=2, has_chief=True) cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0] - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testSimpleBetweenGraph(self, use_core_strategy): - self._run_between_graph_clients( - self._test_simple_increment, - self._cluster_spec, - context.num_gpus(), - use_core_strategy=use_core_strategy) + @combinations.generate(combinations.combine(mode=['graph'])) + def testSimpleBetweenGraph(self): + self._run_between_graph_clients(self._test_simple_increment, + self._cluster_spec, context.num_gpus()) 
@combinations.generate( - combinations.combine( - mode=['graph'], num_gpus=[0, 1, 2], use_core_strategy=[True, False])) - def testMinimizeLossGraph(self, num_gpus, use_core_strategy): - self._run_between_graph_clients( - self._test_minimize_loss_graph, - self._cluster_spec, - num_gpus, - use_core_strategy=use_core_strategy) + combinations.combine(mode=['graph'], num_gpus=[0, 1, 2])) + def testMinimizeLossGraph(self, num_gpus): + self._run_between_graph_clients(self._test_minimize_loss_graph, + self._cluster_spec, num_gpus) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testGlobalStepIsWrappedOnTwoGPUs(self, use_core_strategy): - strategy, _, _ = create_test_objects( - num_gpus=2, use_core_strategy=use_core_strategy) + @combinations.generate(combinations.combine(mode=['graph'])) + def testGlobalStepIsWrappedOnTwoGPUs(self): + strategy, _, _ = create_test_objects(num_gpus=2) with ops.Graph().as_default(), strategy.scope(): created_step = training_util.create_global_step() get_step = training_util.get_global_step() @@ -889,11 +805,9 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, self.assertIs(values.AggregatingVariable, type(get_step)) self.assertIs(strategy, created_step.distribute_strategy) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testGlobalStepIsNotWrappedOnOneGPU(self, use_core_strategy): - strategy, _, _ = create_test_objects( - num_gpus=1, use_core_strategy=use_core_strategy) + @combinations.generate(combinations.combine(mode=['graph'])) + def testGlobalStepIsNotWrappedOnOneGPU(self): + strategy, _, _ = create_test_objects(num_gpus=1) with ops.Graph().as_default(), strategy.scope(): created_step = training_util.create_global_step() get_step = training_util.get_global_step() @@ -908,11 +822,9 @@ class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase, self.assertFalse(hasattr(strategy, 'distribute_strategy')) self.assertIs(strategy, created_step._distribute_strategy) - @combinations.generate( - combinations.combine(mode=['graph'], use_core_strategy=[True, False])) - def testValueContainer(self, use_core_strategy): - strategy, _, _ = create_test_objects( - num_gpus=2, use_core_strategy=use_core_strategy) + @combinations.generate(combinations.combine(mode=['graph'])) + def testValueContainer(self): + strategy, _, _ = create_test_objects(num_gpus=2) with ops.Graph().as_default(), strategy.scope(): def f(): @@ -930,11 +842,9 @@ class CentralStorageStrategyTest(strategy_test_lib.DistributionTestBase, parameterized.TestCase): @combinations.generate(combinations.combine(mode=['graph', 'eager'], - use_core_strategy=[True, False], required_gpus=2)) - def testNumpyDataset(self, use_core_strategy): - strategy, _, _ = create_test_objects( - num_gpus=2, use_core_strategy=use_core_strategy) + def testNumpyDataset(self): + strategy, _, _ = create_test_objects(num_gpus=2) self._test_numpy_dataset(strategy) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index f502a0b8279..87c920efa2b 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -513,6 +513,7 @@ cuda_py_test( "//tensorflow/python:platform_test", ], tags = ["nomsan"], # disable to avoid false positives from scipy. 
+ xla_enable_strict_auto_jit = False, ) cuda_py_test( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py index dc18eb3df69..8b61d4be63c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -461,13 +462,14 @@ class AffineBijectorTest(test.TestCase): def testNoBatchMultivariateRaisesWhenSingular(self): with self.cached_session(): mu = [1., -1] - bijector = Affine( - shift=mu, - # Has zero on the diagonal. - scale_diag=[0., 1], - validate_args=True) - with self.assertRaisesOpError("diagonal part must be non-zero"): - bijector.forward([1., 1.]).eval() + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "diagonal part must be non-zero"): + _ = Affine( + shift=mu, + # Has zero on the diagonal. + scale_diag=[0., 1], + validate_args=True) + # Error detected statically; don't need to run the op. def _makeScale(self, x, diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 79eadf524b5..f3d63da373a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite @@ -150,6 +151,27 @@ class _ReshapeBijectorTest(object): with self.assertRaisesError(expected_error_message): sess.run(bijector.forward_event_shape_tensor(shape_in), feed_dict=feed_dict) + + def _testInvalidDimensionsStatic(self, expected_error_message): + """Version of _testInvalidDimensionsOpError for errors detected statically. + + Statically means at graph construction time. + + Args: + expected_error_message: String that should be present in the error + message that `Reshape` raises for invalid shapes. 
+ """ + shape_in, shape_out, _ = self.build_shapes([2, 3], [ + 1, + 2, + -2, + ]) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + expected_error_message): + _ = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) # pylint: enable=invalid-name def testValidButNonMatchingInputOpError(self): @@ -300,9 +322,9 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): assert_bijective_and_finite( bijector, x, y, event_ndims=2, rtol=1e-6, atol=0) - def testInvalidDimensionsOpError(self): - self._testInvalidDimensionsOpError( - "Invalid value in tensor used for shape: -2") + def testInvalidDimensionsStatic(self): + self._testInvalidDimensionsStatic( + "elements must be either positive integers or `-1`") def testInputOutputMismatchOpError(self): self._testInputOutputMismatchOpError("Cannot reshape a tensor with") diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py index e805619041d..2e7ab3ecfd2 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus +from tensorflow.python.framework import errors from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test @@ -43,9 +44,10 @@ class SoftplusBijectorTest(test.TestCase): def testHingeSoftnessZeroRaises(self): with self.cached_session(): - bijector = Softplus(hinge_softness=0., validate_args=True) - with self.assertRaisesOpError("must be non-zero"): - bijector.forward([1., 1.]).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "must be non-zero"): + _ = Softplus(hinge_softness=0., validate_args=True) + # Error detected statically; don't need to run op. 
def testBijectorForwardInverseEventDimsZero(self): with self.cached_session(): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py index 4411d6f4611..f5d6944d166 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import cauchy as cauchy_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -400,9 +401,10 @@ class CauchyTest(test.TestCase): def testCauchyNegativeLocFails(self): with self.cached_session(): - cauchy = cauchy_lib.Cauchy(loc=[1.], scale=[-5.], validate_args=True) - with self.assertRaisesOpError("Condition x > 0 did not hold"): - cauchy.mode().eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "Condition x > 0 did not hold"): + _ = cauchy_lib.Cauchy(loc=[1.], scale=[-5.], validate_args=True) + # Error detected statically; no need for _.mode().eval() def testCauchyShape(self): with self.cached_session(): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py index 36fc7a70c8a..bdcf6f39445 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib +from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -40,11 +41,10 @@ class DeterministicTest(test.TestCase): def testInvalidTolRaises(self): loc = rng.rand(2, 3, 4).astype(np.float32) - deterministic = deterministic_lib.Deterministic( - loc, atol=-1, validate_args=True) - with self.cached_session(): - with self.assertRaisesOpError("Condition x >= 0"): - deterministic.prob(0.).eval() + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Condition x >= 0"): + _ = deterministic_lib.Deterministic(loc, atol=-1, validate_args=True) + # Error detected statically; no need for _.prob(0.).eval() def testProbWithNoBatchDimsIntegerType(self): deterministic = deterministic_lib.Deterministic(0) @@ -195,16 +195,16 @@ class VectorDeterministicTest(test.TestCase): def testInvalidTolRaises(self): loc = rng.rand(2, 3, 4).astype(np.float32) - deterministic = deterministic_lib.VectorDeterministic( - loc, atol=-1, validate_args=True) - with self.cached_session(): - with self.assertRaisesOpError("Condition x >= 0"): - deterministic.prob(loc).eval() + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Condition x >= 0"): + _ = deterministic_lib.VectorDeterministic( + loc, atol=-1, validate_args=True) + # Error detected statically; no need for _.prob(loc).eval() def testInvalidXRaises(self): loc = rng.rand(2, 3, 4).astype(np.float32) deterministic = deterministic_lib.VectorDeterministic( - loc, atol=-1, validate_args=True) + loc, atol=None, validate_args=True) with 
self.cached_session(): with self.assertRaisesRegexp(ValueError, "must have rank at least 1"): deterministic.prob(0.).eval() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py index 686de9d2465..3ed96e6fdb8 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -41,6 +42,7 @@ def try_import(name): # pylint: disable=invalid-name tf_logging.warning("Could not import %s: %s" % (name, str(e))) return module + stats = try_import("scipy.stats") @@ -288,9 +290,10 @@ class HalfNormalTest(test.TestCase): def testNegativeSigmaFails(self): with self.cached_session(): - halfnorm = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G") - with self.assertRaisesOpError("Condition x > 0 did not hold"): - halfnorm.mean().eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "Condition x > 0 did not hold"): + _ = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G") + # Error detected statically; no need for _.mean().eval() def testHalfNormalShape(self): with self.cached_session(): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py index 70551d89d9c..7c46674cc04 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py @@ -22,6 +22,7 @@ from scipy import stats from tensorflow.contrib.distributions.python.ops import inverse_gamma from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test @@ -249,7 +250,8 @@ class InverseGammaTest(test.TestCase): fails += 0 if self._kstest(a, b, s) else 1 self.assertLess(fails, trials * 0.03) - def _kstest(self, alpha, beta, samples): + @staticmethod + def _kstest(alpha, beta, samples): # Uses the Kolmogorov-Smirnov test for goodness of fit. ks, _ = stats.kstest(samples, stats.invgamma(alpha, scale=beta).cdf) # Return True when the test passes. 
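Editor's note: for context on the `_kstest` helper that the hunk above turns into a static method, it checks sampled values against the analytic inverse-gamma CDF using SciPy's Kolmogorov-Smirnov test. The standalone sketch below uses made-up parameters, sample size, and acceptance threshold; it is not the test's exact code.

```python
# Standalone sketch of a Kolmogorov-Smirnov goodness-of-fit check; parameters,
# sample size, and threshold are illustrative only.
from scipy import stats

alpha, beta = 2.0, 3.0
samples = stats.invgamma(alpha, scale=beta).rvs(size=10000, random_state=0)

# Compare the empirical samples against the analytic CDF; a small KS statistic
# means the samples are consistent with the target distribution.
ks_statistic, p_value = stats.kstest(samples, stats.invgamma(alpha, scale=beta).cdf)
fit_ok = ks_statistic < 0.02
```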
@@ -295,16 +297,18 @@ class InverseGammaTest(test.TestCase): with self.cached_session(): alpha_v = constant_op.constant(0.0, name="alpha") beta_v = constant_op.constant(1.0, name="beta") - inv_gamma = inverse_gamma.InverseGamma( - concentration=alpha_v, rate=beta_v, validate_args=True) - with self.assertRaisesOpError("alpha"): - inv_gamma.mean().eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "alpha"): + _ = inverse_gamma.InverseGamma( + concentration=alpha_v, rate=beta_v, validate_args=True) + # Error detected statically; no need for _.mean().eval() alpha_v = constant_op.constant(1.0, name="alpha") beta_v = constant_op.constant(0.0, name="beta") - inv_gamma = inverse_gamma.InverseGamma( - concentration=alpha_v, rate=beta_v, validate_args=True) - with self.assertRaisesOpError("beta"): - inv_gamma.mean().eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "beta"): + _ = inverse_gamma.InverseGamma( + concentration=alpha_v, rate=beta_v, validate_args=True) + # Error detected statically; no need for _.mean().eval() def testInverseGammaWithSoftplusConcentrationRate(self): with self.cached_session(): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py index 07528cafaf1..82257e136ba 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py @@ -21,6 +21,7 @@ import numpy as np from scipy import stats from tensorflow.contrib import distributions as distributions_lib from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl @@ -361,15 +362,14 @@ class QuantizedDistributionTest(test.TestCase): def testLowerCutoffMustBeBelowUpperCutoffOrWeRaise(self): with self.cached_session(): - qdist = distributions.QuantizedDistribution( - distribution=distributions.Normal(loc=0., scale=1.), - low=1., # not strictly less than high. - high=1., - validate_args=True) - - self.assertTrue(qdist.validate_args) # Default is True. - with self.assertRaisesOpError("must be strictly less"): - qdist.sample().eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "must be strictly less"): + _ = distributions.QuantizedDistribution( + distribution=distributions.Normal(loc=0., scale=1.), + low=1., # not strictly less than high. 
+ high=1., + validate_args=True) + # Error detected statically; no need for _.sample().eval() def testCutoffsMustBeIntegerValuedIfValidateArgsTrue(self): with self.cached_session(): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py b/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py index fec23749286..aa90dae88bb 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py @@ -94,12 +94,11 @@ class RelaxedBernoulliTest(test.TestCase): """If validate_args, raises InvalidArgumentError when temperature is 0.""" temperature = constant_op.constant(0.0) p = constant_op.constant([0.1, 0.4]) - dist = relaxed_bernoulli.RelaxedBernoulli(temperature, probs=p, - validate_args=True) - with self.cached_session(): - sample = dist.sample() - with self.assertRaises(errors_impl.InvalidArgumentError): - sample.eval() + with self.assertRaisesWithPredicateMatch(errors_impl.InvalidArgumentError, + "x > 0 did not hold"): + _ = relaxed_bernoulli.RelaxedBernoulli( + temperature, probs=p, validate_args=True) + # Error detected statically; no need to run the op. def testDtype(self): temperature = constant_op.constant(1.0, dtype=dtypes.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py index cdee30bbc42..c924a22c290 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py @@ -382,7 +382,7 @@ class WishartCholeskyTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "cannot be less than"): distributions.WishartCholesky( df=2, scale=chol_scale, validate_args=False) - with self.assertRaisesRegexp(TypeError, "Argument tril must have dtype"): + with self.assertRaisesRegexp(TypeError, "."): distributions.WishartCholesky( df=4., scale=np.asarray( diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index d4503790888..e174596defd 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -191,10 +191,8 @@ class BatchReshape(distribution_lib.Distribution): self.distribution.survival_function, x) def _entropy(self): - return self._call_and_reshape_output( - self.distribution.entropy, - [], - [tensor_shape.scalar()]) + return self._call_and_reshape_output(self.distribution.entropy, [], + [tensor_shape.TensorShape([])]) def _mean(self): return self._call_and_reshape_output(self.distribution.mean) @@ -381,7 +379,7 @@ def calculate_reshape(original_shape, new_shape, validate=False, name=None): size_implicit_dim = ( original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape))) new_ndims = array_ops.shape(new_shape) - expanded_new_shape = array_ops.where( # Assumes exactly one `-1`. + expanded_new_shape = array_ops.where_v2( # Assumes exactly one `-1`. 
implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape) validations = [] if not validate else [ check_ops.assert_rank( diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index fcc8898f6eb..2e0fd592c6c 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -22,7 +22,6 @@ from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops @@ -254,8 +253,6 @@ class Affine(bijector.Bijector): super(Affine, self).__init__( forward_min_event_ndims=1, graph_parents=( - [self._scale] if tensor_util.is_tensor(self._scale) - else self._scale.graph_parents + [self._shift] if self._shift is not None else []), is_constant_jacobian=True, dtype=dtype, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index 91301f15ad8..722d843f7f4 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -141,7 +141,6 @@ class AffineLinearOperator(bijector.Bijector): raise TypeError("scale is not an instance of tf.LinearOperator") if validate_args and not scale.is_non_singular: raise ValueError("Scale matrix must be non-singular.") - graph_parents += scale.graph_parents if scale.tensor_rank is not None: batch_ndims = scale.tensor_rank - 2 else: diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index 241fba2cb7e..aee3a603d2b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -43,7 +43,7 @@ __all__ = [ warn_once=True) def _sqrtx2p1(x): """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" - return array_ops.where( + return array_ops.where_v2( math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1., math_ops.sqrt(x**2. + 1.), # For large x, calculating x**2 can overflow. 
This can be alleviated by diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py index b349e5966dd..38505c172f6 100644 --- a/tensorflow/contrib/distributions/python/ops/binomial.py +++ b/tensorflow/contrib/distributions/python/ops/binomial.py @@ -68,9 +68,9 @@ def _bdtr(k, n, p): # where(unsafe, safe_output, betainc(where(unsafe, safe_input, input))) ones = array_ops.ones_like(n - k) k_eq_n = math_ops.equal(k, n) - safe_dn = array_ops.where(k_eq_n, ones, n - k) + safe_dn = array_ops.where_v2(k_eq_n, ones, n - k) dk = math_ops.betainc(a=safe_dn, b=k + 1, x=1 - p) - return array_ops.where(k_eq_n, ones, dk) + return array_ops.where_v2(k_eq_n, ones, dk) class Binomial(distribution.Distribution): @@ -230,7 +230,7 @@ class Binomial(distribution.Distribution): return constant_op.constant([], dtype=dtypes.int32) def _event_shape(self): - return tensor_shape.scalar() + return tensor_shape.TensorShape([]) @distribution_util.AppendDocstring(_binomial_sample_note) def _log_prob(self, counts): diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index c461833b9ae..6b1a022a312 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -173,7 +173,7 @@ class Cauchy(distribution.Distribution): return constant_op.constant([], dtype=dtypes.int32) def _event_shape(self): - return tensor_shape.scalar() + return tensor_shape.TensorShape([]) def _sample_n(self, n, seed=None): shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index 507c5d36794..0d57a2ddc60 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -281,7 +281,7 @@ class Deterministic(_BaseDeterministic): return constant_op.constant([], dtype=dtypes.int32) def _event_shape(self): - return tensor_shape.scalar() + return tensor_shape.TensorShape([]) def _prob(self, x): return math_ops.cast( diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 85692d271b6..e6acae57a40 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -305,7 +305,7 @@ def shapes_from_loc_and_scale(loc, scale, name="shapes_from_loc_and_scale"): ValueError: If the last dimension of `loc` is determined statically to be different than the range of `scale`. """ - with ops.name_scope(name, values=[loc] + scale.graph_parents): + with ops.name_scope(name, values=[loc]): # Get event shape. 
event_size = scale.range_dimension_tensor() event_size_const = tensor_util.constant_value(event_size) @@ -475,10 +475,9 @@ def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution, return array_ops.shape(d.batch_shape_tensor())[0] dist_batch_ndims = _get_ndims(mixture_distribution) cat_batch_ndims = _get_ndims(categorical_distribution) - pad_ndims = array_ops.where( - categorical_distribution.is_scalar_batch(), - dist_batch_ndims, - dist_batch_ndims - cat_batch_ndims) + pad_ndims = array_ops.where_v2(categorical_distribution.is_scalar_batch(), + dist_batch_ndims, + dist_batch_ndims - cat_batch_ndims) s = array_ops.shape(x) x = array_ops.reshape(x, shape=array_ops.concat([ s[:-1], diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py index d62f024aa2a..0b5c47056f3 100644 --- a/tensorflow/contrib/distributions/python/ops/geometric.py +++ b/tensorflow/contrib/distributions/python/ops/geometric.py @@ -132,7 +132,7 @@ class Geometric(distribution.Distribution): return array_ops.constant([], dtype=dtypes.int32) def _event_shape(self): - return tensor_shape.scalar() + return tensor_shape.TensorShape([]) def _sample_n(self, n, seed=None): # Uniform variates must be sampled from the open-interval `(0, 1)` rather diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index 4b50df5b481..341d63f573b 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -178,7 +178,7 @@ class _Gumbel(distribution.Distribution): return constant_op.constant([], dtype=dtypes.int32) def _event_shape(self): - return tensor_shape.scalar() + return tensor_shape.TensorShape([]) def _sample_n(self, n, seed=None): # Uniform variates must be sampled from the open-interval `(0, 1)` rather diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index f1216370869..1f04090b3ac 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -150,7 +150,7 @@ class HalfNormal(distribution.Distribution): return constant_op.constant([], dtype=dtypes.int32) def _event_shape(self): - return tensor_shape.scalar() + return tensor_shape.TensorShape([]) def _sample_n(self, n, seed=None): shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 9f1e9d5cd1b..e55b4a1457a 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -187,7 +187,7 @@ class InverseGamma(distribution.Distribution): return constant_op.constant([], dtype=dtypes.int32) def _event_shape(self): - return tensor_shape.scalar() + return tensor_shape.TensorShape([]) @distribution_util.AppendDocstring( """Note: See `tf.random.gamma` docstring for sampling details and @@ -236,7 +236,7 @@ class InverseGamma(distribution.Distribution): self.batch_shape_tensor(), np.array(np.nan, dtype=self.dtype.as_numpy_dtype()), name="nan") - return array_ops.where(self.concentration > 1., mean, nan) + return array_ops.where_v2(self.concentration > 1., mean, nan) else: return control_flow_ops.with_dependencies([ check_ops.assert_less( @@ -257,7 +257,7 @@ class 
InverseGamma(distribution.Distribution): self.batch_shape_tensor(), np.array(np.nan, dtype=self.dtype.as_numpy_dtype()), name="nan") - return array_ops.where(self.concentration > 2., var, nan) + return array_ops.where_v2(self.concentration > 2., var, nan) else: return control_flow_ops.with_dependencies([ check_ops.assert_less( diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index e3712dd84e3..56f35c28b1b 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -235,7 +235,7 @@ class Kumaraswamy(transformed_distribution.TransformedDistribution): np.array(np.nan, dtype=self.dtype.as_numpy_dtype), name="nan") is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.) - return array_ops.where(is_defined, mode, nan) + return array_ops.where_v2(is_defined, mode, nan) return control_flow_ops.with_dependencies([ check_ops.assert_less( diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 21c9b5a3544..03c5ba2997a 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -173,7 +173,7 @@ class Logistic(distribution.Distribution): return constant_op.constant([], dtype=dtypes.int32) def _event_shape(self): - return tensor_shape.scalar() + return tensor_shape.TensorShape([]) def _sample_n(self, n, seed=None): # Uniform variates must be sampled from the open-interval `(0, 1)` rather diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 8fdc99824b6..f9b51cc5a62 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -186,7 +186,7 @@ class MultivariateNormalLinearOperator( if not scale.dtype.is_floating: raise TypeError("`scale` parameter must have floating-point dtype.") - with ops.name_scope(name, values=[loc] + scale.graph_parents) as name: + with ops.name_scope(name, values=[loc]) as name: # Since expand_dims doesn't preserve constant-ness, we obtain the # non-dynamic value if possible. 
loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc @@ -329,8 +329,7 @@ def _kl_brute_force(a, b, name=None): isinstance(x, linalg.LinearOperatorScaledIdentity) or isinstance(x, linalg.LinearOperatorDiag)) - with ops.name_scope(name, "kl_mvn", values=[a.loc, b.loc] + - a.scale.graph_parents + b.scale.graph_parents): + with ops.name_scope(name, "kl_mvn", values=[a.loc, b.loc]): # Calculation is based on: # http://stats.stackexchange.com/questions/60680/kl-divergence-between-two-multivariate-gaussians # and, diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py index 6acfc5746a0..faf9827c8bf 100644 --- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py +++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py @@ -145,7 +145,7 @@ class NegativeBinomial(distribution.Distribution): return array_ops.constant([], dtype=dtypes.int32) def _event_shape(self): - return tensor_shape.scalar() + return tensor_shape.TensorShape([]) def _sample_n(self, n, seed=None): # Here we use the fact that if: @@ -190,10 +190,9 @@ class NegativeBinomial(distribution.Distribution): return self.total_count * math_ops.exp(self.logits) def _mode(self): - adjusted_count = array_ops.where( - 1. < self.total_count, - self.total_count - 1., - array_ops.zeros_like(self.total_count)) + adjusted_count = array_ops.where_v2(1. < self.total_count, + self.total_count - 1., + array_ops.zeros_like(self.total_count)) return math_ops.floor(adjusted_count * math_ops.exp(self.logits)) def _variance(self): diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index 3d055085cc7..64c41c57d79 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -151,7 +151,7 @@ class Poisson(distribution.Distribution): return constant_op.constant([], dtype=dtypes.int32) def _event_shape(self): - return tensor_shape.scalar() + return tensor_shape.TensorShape([]) @distribution_util.AppendDocstring(_poisson_sample_note) def _log_prob(self, x): diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 85683e3233d..b23a3231d27 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -355,7 +355,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): self.mixture_distribution.logits.shape)[:-1] def _event_shape(self): - return tensor_shape.scalar() + return tensor_shape.TensorShape([]) def _sample_n(self, n, seed=None): # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py index 19d88d5ab5d..1be2dd1c719 100644 --- a/tensorflow/contrib/distributions/python/ops/shape.py +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -457,9 +457,9 @@ class _DistributionShape(object): batch_shape = s[1:1+self.batch_ndims] # Since sample_dims=1 and is left-most, we add 1 to the number of # batch_ndims to get the event start dim. 
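The hunk that follows, like several other hunks in this patch, replaces `array_ops.where` with `array_ops.where_v2`. At these call sites the operands already share a shape (often all scalars), so behaviour is unchanged and the swap reads as a forward-compatibility cleanup; the substantive difference is that `where_v2` broadcasts `condition`, `x` and `y` against each other NumPy-style, whereas the v1 op requires `condition` to match `x` exactly or to be a vector whose length equals `x`'s first dimension. (The parallel `tensor_shape.scalar()` to `tensor_shape.TensorShape([])` edits are a like-for-like rename, since `scalar()` simply returned `TensorShape([])`.) A minimal sketch of the `where` difference, not part of the patch, using the same internal modules the surrounding code already imports:

```python
# Sketch only: contrast v1/v2 `where` shape handling.
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops

cond = constant_op.constant([True, False, True])   # shape [3]
x = constant_op.constant([[1, 2, 3], [4, 5, 6]])   # shape [2, 3]
y = array_ops.zeros_like(x)

# where_v2 broadcasts `cond` against the trailing axis of `x`/`y`,
# yielding [[1, 0, 3], [4, 0, 6]].
selected = array_ops.where_v2(cond, x, y)

# array_ops.where would reject the same call: a rank-1 `condition` must have
# the same length as the *first* dimension of `x` (2 here, not 3).
```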
- event_start = array_ops.where( - math_ops.logical_and(expand_batch_dim, self._batch_ndims_is_0), - 2, 1 + self.batch_ndims) + event_start = array_ops.where_v2( + math_ops.logical_and(expand_batch_dim, self._batch_ndims_is_0), 2, + 1 + self.batch_ndims) event_shape = s[event_start:event_start+self.event_ndims] new_shape = array_ops.concat([sample_shape, batch_shape, event_shape], 0) x = array_ops.reshape(x, shape=new_shape) diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index f9748466c2e..f17ac136406 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -524,8 +524,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): parameters=parameters, graph_parents=( distribution._graph_parents # pylint: disable=protected-access - + [loc_ for loc_ in loc if loc_ is not None] - + [p for scale_ in scale for p in scale_.graph_parents]), + + [loc_ for loc_ in loc if loc_ is not None]), name=name) @property @@ -1060,5 +1059,5 @@ def softmax(x, axis, name=None): if axis_ is not None: axis = np.int(ndims + axis_ if axis_ < 0 else axis_) else: - axis = array_ops.where(axis < 0, ndims + axis, axis) + axis = array_ops.where_v2(axis < 0, ndims + axis, axis) return nn_ops.softmax(x, axis=axis) diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index fd5bf9ecc72..9dcd60dab5a 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -191,7 +191,7 @@ class VectorExponentialLinearOperator( if not scale.dtype.is_floating: raise TypeError("`scale` parameter must have floating-point dtype.") - with ops.name_scope(name, values=[loc] + scale.graph_parents) as name: + with ops.name_scope(name, values=[loc]) as name: # Since expand_dims doesn't preserve constant-ness, we obtain the # non-dynamic value if possible. loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index 67d2ccd28d6..313046db9ba 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -207,7 +207,7 @@ class VectorLaplaceLinearOperator( if not scale.dtype.is_floating: raise TypeError("`scale` parameter must have floating-point dtype.") - with ops.name_scope(name, values=[loc] + scale.graph_parents): + with ops.name_scope(name, values=[loc]): # Since expand_dims doesn't preserve constant-ness, we obtain the # non-dynamic value if possible. 
loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index a5bb880bed9..8b819053f92 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -170,8 +170,7 @@ class _WishartLinearOperator(distribution.Distribution): allow_nan_stats=allow_nan_stats, reparameterization_type=distribution.FULLY_REPARAMETERIZED, parameters=parameters, - graph_parents=([self._df, self._dimension] + - self._scale_operator.graph_parents), + graph_parents=[self._df, self._dimension], name=name) @property @@ -400,10 +399,9 @@ class _WishartLinearOperator(distribution.Distribution): def _mode(self): s = self.df - self.dimension - 1. - s = array_ops.where( + s = array_ops.where_v2( math_ops.less(s, 0.), - constant_op.constant(float("NaN"), dtype=self.dtype, name="nan"), - s) + constant_op.constant(float("NaN"), dtype=self.dtype, name="nan"), s) if self.cholesky_input_output_matrices: return math_ops.sqrt(s) * self.scale_operator.to_dense() return s * self._square_scale_operator() diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index 48925b1bfac..0bbece7d6c3 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -25,9 +25,9 @@ import numpy as np from tensorflow.contrib import lookup from tensorflow.contrib.eager.python import datasets -from tensorflow.python.data import Dataset from tensorflow.python.data.experimental.ops import threadpool from tensorflow.python.data.experimental.ops import unique +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -44,24 +44,24 @@ class IteratorTest(test.TestCase): def testBasic(self): got = [] - for t in datasets.Iterator(Dataset.range(4)): + for t in datasets.Iterator(dataset_ops.Dataset.range(4)): got.append(t.numpy()) self.assertAllEqual([0, 1, 2, 3], got) def testBasicOneShotIterator(self): got = [] - for t in Dataset.range(4).make_one_shot_iterator(): + for t in dataset_ops.Dataset.range(4).make_one_shot_iterator(): got.append(t.numpy()) self.assertAllEqual([0, 1, 2, 3], got) def testBasicImplicitIterator(self): got = [] - for t in Dataset.range(4): + for t in dataset_ops.Dataset.range(4): got.append(t.numpy()) self.assertAllEqual([0, 1, 2, 3], got) def testGetNext(self): - iterator = datasets.Iterator(Dataset.range(4)) + iterator = datasets.Iterator(dataset_ops.Dataset.range(4)) self.assertEqual(0, iterator.get_next().numpy()) self.assertEqual(1, iterator.get_next().numpy()) self.assertEqual(2, iterator.get_next().numpy()) @@ -70,7 +70,7 @@ class IteratorTest(test.TestCase): iterator.get_next() def testGetNextOneShotIterator(self): - iterator = Dataset.range(4).make_one_shot_iterator() + iterator = dataset_ops.Dataset.range(4).make_one_shot_iterator() self.assertEqual(0, iterator.get_next().numpy()) self.assertEqual(1, iterator.get_next().numpy()) self.assertEqual(2, iterator.get_next().numpy()) @@ -79,7 +79,7 @@ class IteratorTest(test.TestCase): iterator.get_next() def testMultipleIteratorsOnTheSameDataset(self): - ds = Dataset.range(4) + ds = dataset_ops.Dataset.range(4) it1 = datasets.Iterator(ds) it2 = datasets.Iterator(ds) got = [x.numpy() for x in it1] @@ -89,8 +89,10 @@ class 
IteratorTest(test.TestCase): self.assertAllEqual([0, 1, 2, 3], got) def testNestedOutputs(self): - ds = Dataset.zip((Dataset.range(4), Dataset.zip((Dataset.range(4), - Dataset.range(4))))) + ds = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.range(4), + dataset_ops.Dataset.zip( + (dataset_ops.Dataset.range(4), dataset_ops.Dataset.range(4))))) total = 0 # The Iterator will return a nested structure of Tensor objects. # Some funkiness to compare against simple integers. @@ -102,10 +104,12 @@ class IteratorTest(test.TestCase): self.assertEqual(4, total) def testMapAndFilter(self): + def even(x): return math_ops.equal(math_ops.mod(x, 2), 0) - it = datasets.Iterator(Dataset.range(8).map(math_ops.square).filter(even)) + it = datasets.Iterator( + dataset_ops.Dataset.range(8).map(math_ops.square).filter(even)) got = [x.numpy() for x in it] self.assertAllEqual([0, 4, 16, 36], got) @@ -115,14 +119,16 @@ class IteratorTest(test.TestCase): values = constant_op.constant([0, 1, 2], dtypes.int64) table = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), default_val) - dataset = Dataset.from_tensor_slices(['brain', 'salad', 'surgery']) + dataset = dataset_ops.Dataset.from_tensor_slices( + ['brain', 'salad', 'surgery']) dataset = dataset.map(table.lookup) it = datasets.Iterator(dataset) got = [x.numpy() for x in it] self.assertAllEqual([0, 1, 2], got) def testMultipleIteratorsOnADatasetThatUsesFunctions(self): - ds = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(math_ops.square) + ds = dataset_ops.Dataset.from_tensor_slices([1, 2, 3, 4, 5, + 6]).map(math_ops.square) got1 = [x.numpy() for x in datasets.Iterator(ds)] self.assertAllEqual([1, 4, 9, 16, 25, 36], got1) @@ -172,7 +178,7 @@ class IteratorTest(test.TestCase): ] for i, result in enumerate( - datasets.Iterator(Dataset.from_tensor_slices(components))): + datasets.Iterator(dataset_ops.Dataset.from_tensor_slices(components))): self.assertSparseValuesEqual(expected[i][0], result[0]) self.assertSparseValuesEqual(expected[i][1], result[1]) @@ -181,20 +187,20 @@ class IteratorTest(test.TestCase): def my_map(inp): return [[x + 1 for x in inp]] - ds = Dataset.range(4).map( + ds = dataset_ops.Dataset.range(4).map( lambda x: script_ops.py_func(my_map, [[x]], dtypes.int64)) got = [x.numpy() for x in datasets.Iterator(ds)] self.assertAllEqual([[1], [2], [3], [4]], got) def testTensorsPlacedOnDevice(self): - ds = Dataset.from_tensors([0., 1.]) + ds = dataset_ops.Dataset.from_tensors([0., 1.]) with ops.device(test.gpu_device_name()): x = datasets.Iterator(ds).next() x = math_ops.add(x, x) self.assertAllEqual([0., 2.], x.numpy()) def testGpuTensor(self): - ds = Dataset.from_tensors([0., 1.]) + ds = dataset_ops.Dataset.from_tensors([0., 1.]) with ops.device(test.gpu_device_name()): for x in ds: y = math_ops.add(x, x) @@ -213,7 +219,7 @@ class IteratorTest(test.TestCase): for num_threads in [1, 2, 4, 8, 16]: dataset = ( - Dataset.range(1000).map( + dataset_ops.Dataset.range(1000).map( lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), num_parallel_calls=32).apply(unique.unique())) @@ -235,8 +241,13 @@ class IteratorTest(test.TestCase): def testSaveRestore(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') - dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset_ops.Dataset.from_tensor_slices( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) dataset = dataset.map(math_ops.square).batch(2) + # TODO(b/138399725): Re-enable default optimizations. 
+ options = dataset_ops.Options() + options.experimental_optimization.apply_default_optimizations = False + dataset = dataset.with_options(options) iterator = datasets.Iterator(dataset) checkpoint = trackable_utils.Checkpoint(iterator=iterator) self.assertAllEqual([1, 4], iterator.get_next().numpy()) @@ -250,11 +261,16 @@ class IteratorTest(test.TestCase): def testSaveRestoreMultipleIterator(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') - dataset = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + dataset = dataset_ops.Dataset.from_tensor_slices( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) dataset = dataset.map(math_ops.square).batch(2) + # TODO(b/138399725): Re-enable default optimizations. + options = dataset_ops.Options() + options.experimental_optimization.apply_default_optimizations = False + dataset = dataset.with_options(options) iterator_1 = datasets.Iterator(dataset) iterator_2 = datasets.Iterator(dataset) - dataset_2 = Dataset.range(10) + dataset_2 = dataset_ops.Dataset.range(10) iterator_3 = datasets.Iterator(dataset_2) checkpoint = trackable_utils.Checkpoint( @@ -276,7 +292,7 @@ class IteratorTest(test.TestCase): def testRestoreExhaustedIterator(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') - dataset = Dataset.range(3) + dataset = dataset_ops.Dataset.range(3) iterator = datasets.Iterator(dataset) checkpoint = trackable_utils.Checkpoint(iterator=iterator) @@ -290,12 +306,12 @@ class IteratorTest(test.TestCase): def testRestoreInReconstructedIterator(self): checkpoint_directory = self.get_temp_dir() checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') - dataset = Dataset.range(10) + dataset = dataset_ops.Dataset.range(10) for i in range(5): iterator = datasets.Iterator(dataset) checkpoint = trackable_utils.Checkpoint(iterator=iterator) - checkpoint.restore(checkpoint_management.latest_checkpoint( - checkpoint_directory)) + checkpoint.restore( + checkpoint_management.latest_checkpoint(checkpoint_directory)) for j in range(2): self.assertEqual(i * 2 + j, iterator.get_next().numpy()) checkpoint.save(file_prefix=checkpoint_prefix) @@ -311,8 +327,8 @@ class DatasetConstructorBenchmark(test.Benchmark): input_data = np.random.randn(input_size) dataset = ( - Dataset.from_tensor_slices(input_data).repeat(num_epochs) - .batch(batch_size)) + dataset_ops.Dataset.from_tensor_slices(input_data).repeat( + num_epochs).batch(batch_size)) iterator = datasets.Iterator(dataset) ends = [time.time()] @@ -321,10 +337,8 @@ class DatasetConstructorBenchmark(test.Benchmark): deltas = np.ediff1d(ends) median_wall_time = np.median(deltas) - print( - 'Slice/repeat/batch eager input size: %d batch size: %d Median wall ' - 'time per element: %f' - % (input_size, batch_size, median_wall_time)) + print('Slice/repeat/batch eager input size: %d batch size: %d Median wall ' + 'time per element: %f' % (input_size, batch_size, median_wall_time)) self.report_benchmark( iters=len(deltas), wall_time=median_wall_time, @@ -339,8 +353,8 @@ class DatasetConstructorBenchmark(test.Benchmark): input_data = np.random.randn(input_size) dataset = ( - Dataset.from_tensor_slices(input_data).batch(batch_size).cache() - .repeat(num_epochs)) + dataset_ops.Dataset.from_tensor_slices(input_data).batch( + batch_size).cache().repeat(num_epochs)) iterator = datasets.Iterator(dataset) ends = [time.time()] @@ -349,10 +363,9 @@ class DatasetConstructorBenchmark(test.Benchmark): deltas = 
np.ediff1d(ends) median_wall_time = np.median(deltas) - print( - 'Slice/batch/cache/repeat eager input size: %d batch size: %d Median ' - 'wall time per element: %f' - % (input_size, batch_size, median_wall_time)) + print('Slice/batch/cache/repeat eager input size: %d batch size: %d Median ' + 'wall time per element: %f' % + (input_size, batch_size, median_wall_time)) self.report_benchmark( iters=len(deltas), wall_time=median_wall_time, diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 512605a17eb..cabc71c98e1 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -117,7 +117,7 @@ "source": [ "# Download the file\n", "path_to_zip = tf.keras.utils.get_file(\n", - " 'spa-eng.zip', origin='http://download.tensorflow.org/data/spa-eng.zip', \n", + " 'spa-eng.zip', origin='https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip', \n", " extract=True)\n", "\n", "path_to_file = os.path.dirname(path_to_zip)+\"/spa-eng/spa.txt\"" diff --git a/tensorflow/contrib/eager/python/examples/revnet/blocks.py b/tensorflow/contrib/eager/python/examples/revnet/blocks.py index f61354bc38a..221b0766225 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/blocks.py +++ b/tensorflow/contrib/eager/python/examples/revnet/blocks.py @@ -61,7 +61,7 @@ class RevBlock(tf.keras.Model): fused: use fused batch normalization if True dtype: float16, float32, or float64 """ - super(RevBlock, self).__init__() + super(RevBlock, self).__init__(dtype=dtype) self.blocks = tf.contrib.checkpoint.List() for i in range(n_res): curr_batch_norm_first = batch_norm_first and i == 0 @@ -135,7 +135,7 @@ class _Residual(tf.keras.Model): fused: use fused batch normalization if True dtype: float16, float32, or float64 """ - super(_Residual, self).__init__() + super(_Residual, self).__init__(dtype=dtype) self.filters = filters self.strides = strides @@ -283,7 +283,7 @@ class _BottleneckResidualInner(tf.keras.Model): fused: use fused batch normalization if True dtype: float16, float32, or float64 """ - super(_BottleneckResidualInner, self).__init__() + super(_BottleneckResidualInner, self).__init__(dtype=dtype) axis = 1 if data_format == "channels_first" else 3 if batch_norm_first: self.batch_norm_0 = tf.keras.layers.BatchNormalization( @@ -365,7 +365,7 @@ class _ResidualInner(tf.keras.Model): fused: use fused batch normalization if True dtype: float16, float32, or float64 """ - super(_ResidualInner, self).__init__() + super(_ResidualInner, self).__init__(dtype=dtype) axis = 1 if data_format == "channels_first" else 3 if batch_norm_first: self.batch_norm_0 = tf.keras.layers.BatchNormalization( @@ -416,7 +416,7 @@ class InitBlock(tf.keras.Model): Args: config: tf.contrib.training.HParams object; specifies hyperparameters """ - super(InitBlock, self).__init__() + super(InitBlock, self).__init__(config.dtype) self.config = config self.axis = 1 if self.config.data_format == "channels_first" else 3 self.conv2d = tf.keras.layers.Conv2D( @@ -430,7 +430,8 @@ class InitBlock(tf.keras.Model): dtype=self.config.dtype) self.batch_norm = tf.keras.layers.BatchNormalization( axis=self.axis, fused=self.config.fused, dtype=self.config.dtype) - self.activation = tf.keras.layers.Activation("relu") + self.activation = tf.keras.layers.Activation("relu", + 
dtype=self.config.dtype) if self.config.init_max_pool: self.max_pool = tf.keras.layers.MaxPooling2D( @@ -464,7 +465,7 @@ class FinalBlock(tf.keras.Model): Raises: ValueError: Unsupported data format """ - super(FinalBlock, self).__init__() + super(FinalBlock, self).__init__(dtype=config.dtype) self.config = config self.axis = 1 if self.config.data_format == "channels_first" else 3 @@ -488,7 +489,8 @@ class FinalBlock(tf.keras.Model): input_shape=input_shape, fused=self.config.fused, dtype=self.config.dtype) - self.activation = tf.keras.layers.Activation("relu") + self.activation = tf.keras.layers.Activation("relu", + dtype=self.config.dtype) self.global_avg_pool = tf.keras.layers.GlobalAveragePooling2D( data_format=self.config.data_format, dtype=self.config.dtype) self.dense = tf.keras.layers.Dense( diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet.py b/tensorflow/contrib/eager/python/examples/revnet/revnet.py index 7406787ba43..08f2d8d6f17 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet.py @@ -37,7 +37,7 @@ class RevNet(tf.keras.Model): Args: config: tf.contrib.training.HParams object; specifies hyperparameters """ - super(RevNet, self).__init__() + super(RevNet, self).__init__(dtype=config.dtype) self.axis = 1 if config.data_format == "channels_first" else 3 self.config = config diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 5c55f7f597b..e04de0579b1 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import numbers from six.moves import xrange # pylint: disable=redefined-builtin @@ -42,6 +41,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import resource_loader +from tensorflow.python.util.compat import collections_abc _factorization_ops = loader.load_op_library( resource_loader.get_path_to_datafile("_factorization_ops.so")) @@ -388,7 +388,7 @@ class WALSModel(object): return None init_mode = "list" - if isinstance(wt_init, collections.Iterable): + if isinstance(wt_init, collections_abc.Iterable): if num_shards == 1 and len(wt_init) == num_wts: wt_init = [wt_init] assert len(wt_init) == num_shards @@ -641,9 +641,9 @@ class WALSModel(object): extras = size % num_shards assignments = math_ops.maximum(ids // (ids_per_shard + 1), (ids - extras) // ids_per_shard) - new_ids = array_ops.where(assignments < extras, - ids % (ids_per_shard + 1), - (ids - extras) % ids_per_shard) + new_ids = array_ops.where_v2(assignments < extras, + ids % (ids_per_shard + 1), + (ids - extras) % ids_per_shard) return assignments, new_ids return func diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc index ca65ad45326..32e62a6725f 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc @@ -135,9 +135,10 @@ class DecodeAudioOpV2 : public OpKernel { "channel_count must be a rank-0 tensor but got shape ", channel_count_tensor.shape().DebugString())); - const tensorflow::StringPiece contents = contents_tensor.scalar()(); + const 
tensorflow::StringPiece contents = + contents_tensor.scalar()(); const string file_format = - absl::AsciiStrToLower(file_format_tensor.scalar()()); + absl::AsciiStrToLower(file_format_tensor.scalar()()); const int32 samples_per_second = samples_per_second_tensor.scalar()(); const int32 channel_count = channel_count_tensor.scalar()(); @@ -243,7 +244,7 @@ class DecodeAudioOp : public OpKernel { errors::InvalidArgument("contents must be scalar but got shape ", contents.shape().DebugString())); - const tensorflow::StringPiece file_contents = contents.scalar()(); + const tensorflow::StringPiece file_contents = contents.scalar()(); Decode(context, file_contents, file_format_, samples_per_second_, channel_count_, ""); } diff --git a/tensorflow/contrib/ffmpeg/decode_video_op.cc b/tensorflow/contrib/ffmpeg/decode_video_op.cc index 6f8ad486d10..0bfdc2781aa 100644 --- a/tensorflow/contrib/ffmpeg/decode_video_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_video_op.cc @@ -45,7 +45,8 @@ class DecodeVideoOp : public OpKernel { errors::InvalidArgument( "contents must be a rank-0 tensor but got shape ", contents_tensor.shape().DebugString())); - const tensorflow::StringPiece contents = contents_tensor.scalar()(); + const tensorflow::StringPiece contents = + contents_tensor.scalar()(); // Write the input data to a temp file. string extension; diff --git a/tensorflow/contrib/ffmpeg/encode_audio_op.cc b/tensorflow/contrib/ffmpeg/encode_audio_op.cc index 7de09e062ec..ee418fb9020 100644 --- a/tensorflow/contrib/ffmpeg/encode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/encode_audio_op.cc @@ -45,7 +45,7 @@ void Encode(OpKernelContext* context, const Tensor& contents, // Copy the encoded audio file to the output tensor. Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output)); - output->scalar()() = encoded_audio; + output->scalar()() = encoded_audio; } } // namespace @@ -95,7 +95,7 @@ class EncodeAudioOpV2 : public OpKernel { bits_per_second_tensor.shape().DebugString())); const string file_format = - absl::AsciiStrToLower(file_format_tensor.scalar()()); + absl::AsciiStrToLower(file_format_tensor.scalar()()); const int32 samples_per_second = samples_per_second_tensor.scalar()(); const int32 bits_per_second = bits_per_second_tensor.scalar()(); diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py index 6dd887edf59..811df7a55ae 100644 --- a/tensorflow/contrib/framework/python/framework/checkpoint_utils.py +++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils.py @@ -21,6 +21,7 @@ from __future__ import print_function import six +from tensorflow.python.framework import ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs @@ -116,9 +117,10 @@ def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec, name: Name of the operation. 
""" base_type = variable.dtype.base_dtype - restore_op = io_ops.restore_v2( - file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0] - variable._initializer_op = state_ops.assign(variable, restore_op) + with ops.device(variable.device), ops.device("/cpu:0"): + restore_op = io_ops.restore_v2( + file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0] + variable._initializer_op = state_ops.assign(variable, restore_op) def _set_variable_or_list_initializer(variable_or_list, file_pattern, diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index c26fdb1f0a2..8ef11109da9 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -229,13 +229,17 @@ class LaunchFusedConv2DBiasActivationOp { // (1) Scale and add bias. // NOTE(ezhulenev): We do not use Eigen expressions for this loop, // because it seems that packet FMA produces slightly different results, - // and we are targeting bit-by-bit equality with Nvidia implementation. + // and we are targeting close equality with Nvidia implementation. + // We could use std::fmaf, but it can be ~50x slower, on machines + // without fma instruction. for (int idx = 0; idx < num_rows; ++idx) { - conv_output_ptr[idx] = - std::fmaf(conv_output_ptr[idx], conv_input_scale, bias_ptr[idx]); + conv_output_ptr[idx] = static_cast(conv_output_ptr[idx]) * + static_cast(conv_input_scale) + + static_cast(bias_ptr[idx]); if (side_input_scale != 0.0f) { - conv_output_ptr[idx] = std::fmaf( - side_input_ptr[idx], side_input_scale, conv_output_ptr[idx]); + conv_output_ptr[idx] = static_cast(side_input_ptr[idx]) * + static_cast(side_input_scale) + + static_cast(conv_output_ptr[idx]); } } @@ -561,6 +565,14 @@ void LogFusedConvForwardAutotuneResults( *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec); *log.mutable_compute_capability() = GetComputeCapability(stream_exec); log.set_device_pci_bus_id(stream_exec->GetDeviceDescription().pci_bus_id()); + { + string blas_version; + if (auto* blas = stream_exec->AsBlas()) { + if (blas->GetVersion(&blas_version).ok()) { + log.set_blas_version(blas_version); + } + } + } for (const auto& result : results) { *log.add_results() = result; } diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD deleted file mode 100644 index ddd04947e9b..00000000000 --- a/tensorflow/contrib/gan/BUILD +++ /dev/null @@ -1,778 +0,0 @@ -# Files for using TF-GAN framework. 
- -load("//tensorflow:tensorflow.bzl", "py_test") - -package( - default_visibility = [ - "//tensorflow:__subpackages__", - ], - licenses = ["notice"], # Apache 2.0 -) - -exports_files(["LICENSE"]) - -py_library( - name = "gan", - srcs = [ - "__init__.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":estimator", - ":eval", - ":features", - ":losses", - ":namedtuples", - ":train", - "//tensorflow/python:util", - ], -) - -py_library( - name = "namedtuples", - srcs = ["python/namedtuples.py"], - srcs_version = "PY2AND3", -) - -py_library( - name = "train", - srcs = ["python/train.py"], - srcs_version = "PY2AND3", - deps = [ - ":losses", - ":namedtuples", - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/slim:learning", - "//tensorflow/contrib/training:training_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:training", - "//tensorflow/python:training_util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/losses", - ], -) - -py_test( - name = "train_test", - srcs = ["python/train_test.py"], - python_version = "PY2", - shard_count = 50, - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":namedtuples", - ":random_tensor_pool", - ":train", - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/slim:learning", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:training", - "//tensorflow/python:training_util", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/ops/distributions", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -py_library( - name = "eval", - srcs = ["python/eval/__init__.py"], - srcs_version = "PY2AND3", - deps = [ - ":classifier_metrics", - ":eval_utils", - ":sliced_wasserstein", - ":summaries", - "//tensorflow/python:util", - ], -) - -py_library( - name = "estimator", - srcs = ["python/estimator/__init__.py"], - srcs_version = "PY2AND3", - deps = [ - ":gan_estimator", - ":head", - ":latent_gan_estimator", - ":stargan_estimator", - ":tpu_gan_estimator", - "//tensorflow/python:util", - ], -) - -py_library( - name = "losses", - srcs = ["python/losses/__init__.py"], - srcs_version = "PY2AND3", - deps = [ - ":losses_impl", - ":tuple_losses", - "//tensorflow/python:util", - ], -) - -py_library( - name = "features", - srcs = ["python/features/__init__.py"], - srcs_version = "PY2AND3", - deps = [ - ":clip_weights", - ":conditioning_utils", - ":random_tensor_pool", - ":spectral_normalization", - ":virtual_batchnorm", - "//tensorflow/python:util", - ], -) - -py_library( - name = "losses_impl", - srcs = ["python/losses/python/losses_impl.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:clip_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients_impl", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:summary", - "//tensorflow/python:tensor_util", - 
"//tensorflow/python:variable_scope", - "//tensorflow/python/ops/losses", - ], -) - -py_test( - name = "losses_impl_test", - srcs = ["python/losses/python/losses_impl_test.py"], - python_version = "PY2", - srcs_version = "PY2AND3", - deps = [ - ":losses_impl", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:clip_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/ops/distributions", - "//tensorflow/python/ops/losses", - ], -) - -py_library( - name = "tuple_losses", - srcs = [ - "python/losses/python/losses_wargs.py", - "python/losses/python/tuple_losses.py", - "python/losses/python/tuple_losses_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":losses_impl", - ":namedtuples", - "//tensorflow/python:util", - ], -) - -py_test( - name = "tuple_losses_test", - srcs = ["python/losses/python/tuple_losses_test.py"], - python_version = "PY2", - srcs_version = "PY2AND3", - deps = [ - ":losses_impl", - ":namedtuples", - ":tuple_losses", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:math_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - -py_library( - name = "conditioning_utils", - srcs = [ - "python/features/python/conditioning_utils.py", - "python/features/python/conditioning_utils_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:embedding_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:tensor_util", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - ], -) - -py_test( - name = "conditioning_utils_test", - srcs = ["python/features/python/conditioning_utils_test.py"], - python_version = "PY2", - srcs_version = "PY2AND3", - deps = [ - ":conditioning_utils", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - ], -) - -py_library( - name = "random_tensor_pool", - srcs = [ - "python/features/python/random_tensor_pool.py", - "python/features/python/random_tensor_pool_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:data_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:util", - ], -) - -py_test( - name = "random_tensor_pool_test", - srcs = ["python/features/python/random_tensor_pool_test.py"], - python_version = "PY2", - shard_count = 6, - srcs_version = "PY2AND3", - deps = [ - ":random_tensor_pool", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//third_party/py/numpy", - ], -) - -py_library( - name = "virtual_batchnorm", - srcs = [ - "python/features/python/virtual_batchnorm.py", - "python/features/python/virtual_batchnorm_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - 
"//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:tensor_util", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - ], -) - -py_test( - name = "virtual_batchnorm_test", - srcs = ["python/features/python/virtual_batchnorm_test.py"], - python_version = "PY2", - srcs_version = "PY2AND3", - deps = [ - ":virtual_batchnorm", - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:layers", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:random_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - -py_library( - name = "clip_weights", - srcs = [ - "python/features/python/clip_weights.py", - "python/features/python/clip_weights_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/opt:opt_py", - "//tensorflow/python:util", - ], -) - -py_test( - name = "clip_weights_test", - srcs = ["python/features/python/clip_weights_test.py"], - python_version = "PY2", - srcs_version = "PY2AND3", - deps = [ - ":clip_weights", - "//tensorflow/python:client_testlib", - "//tensorflow/python:training", - "//tensorflow/python:variables", - ], -) - -py_library( - name = "classifier_metrics", - srcs = [ - "python/eval/python/classifier_metrics.py", - "python/eval/python/classifier_metrics_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework", - "//tensorflow/python:framework_ops", - "//tensorflow/python:functional_ops", - "//tensorflow/python:image_ops", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:nn_ops", - "//tensorflow/python:platform", - "//tensorflow/python:util", - "@six_archive//:six", - ], -) - -py_test( - name = "classifier_metrics_test", - srcs = ["python/eval/python/classifier_metrics_test.py"], - python_version = "PY2", - srcs_version = "PY2AND3", - tags = [ - "no_pip", - "no_windows", - ], - deps = [ - ":classifier_metrics", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:variables", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - ], -) - -py_library( - name = "eval_utils", - srcs = [ - "python/eval/python/eval_utils.py", - "python/eval/python/eval_utils_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:util", - ], -) - -py_test( - name = "eval_utils_test", - srcs = ["python/eval/python/eval_utils_test.py"], - python_version = "PY2", - srcs_version = "PY2AND3", - deps = [ - ":eval_utils", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - ], -) - -py_library( - name = "summaries", - srcs = [ - "python/eval/python/summaries.py", - "python/eval/python/summaries_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":eval_utils", - 
":namedtuples", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:functional_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:summary", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/losses", - ], -) - -py_test( - name = "summaries_test", - srcs = ["python/eval/python/summaries_test.py"], - python_version = "PY2", - srcs_version = "PY2AND3", - deps = [ - ":namedtuples", - ":summaries", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:summary", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - ], -) - -py_library( - name = "head", - srcs = [ - "python/estimator/python/head.py", - "python/estimator/python/head_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":namedtuples", - ":train", - "//tensorflow/python:framework_ops", - "//tensorflow/python:util", - "//tensorflow/python/estimator:estimator_py", - ], -) - -py_test( - name = "head_test", - srcs = ["python/estimator/python/head_test.py"], - python_version = "PY2", - shard_count = 1, - srcs_version = "PY2AND3", - deps = [ - ":head", - ":namedtuples", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:math_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:estimator_py", - ], -) - -py_library( - name = "gan_estimator", - srcs = [ - "python/estimator/python/gan_estimator.py", - "python/estimator/python/gan_estimator_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":namedtuples", - ":summaries", - ":train", - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/python:framework_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:estimator_py", - ], -) - -py_test( - name = "gan_estimator_test", - srcs = ["python/estimator/python/gan_estimator_test.py"], - python_version = "PY2", - shard_count = 1, - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":gan_estimator", - ":namedtuples", - ":tuple_losses", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/learn", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:summary", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", - "//tensorflow/python:training_util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/estimator:numpy_io", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - "@six_archive//:six", - ], -) - -py_library( - name = "stargan_estimator", - srcs = [ - "python/estimator/python/stargan_estimator.py", - "python/estimator/python/stargan_estimator_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":namedtuples", - ":summaries", - ":train", - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/python:framework_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:estimator_py", - ], -) - -py_test( - 
name = "stargan_estimator_test", - srcs = ["python/estimator/python/stargan_estimator_test.py"], - python_version = "PY2", - shard_count = 1, - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":namedtuples", - ":stargan_estimator", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:summary", - "//tensorflow/python:training", - "//tensorflow/python:training_util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/estimator:numpy_io", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - "@six_archive//:six", - ], -) - -py_library( - name = "tpu_gan_estimator", - srcs = [ - "python/estimator/python/tpu_gan_estimator.py", - "python/estimator/python/tpu_gan_estimator_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":gan_estimator", - ":namedtuples", - ":train", - "//tensorflow/contrib/tpu:tpu_estimator", - "//tensorflow/contrib/tpu:tpu_lib", - "//tensorflow/contrib/training:training_py", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:util", - "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/ops/losses", - ], -) - -py_test( - name = "tpu_gan_estimator_test", - srcs = ["python/estimator/python/tpu_gan_estimator_test.py"], - python_version = "PY2", - shard_count = 11, - srcs_version = "PY2AND3", - tags = ["notsan"], - deps = [ - ":namedtuples", - ":tpu_gan_estimator", - ":tuple_losses", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/tpu:tpu_estimator", - "//tensorflow/contrib/tpu:tpu_lib", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:metrics", - "//tensorflow/python:summary", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", - "//tensorflow/python:training_util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/estimator", - "//tensorflow/python/estimator:model_fn", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - "@six_archive//:six", - ], -) - -py_library( - name = "latent_gan_estimator", - srcs = [ - "python/estimator/python/latent_gan_estimator.py", - "python/estimator/python/latent_gan_estimator_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":train", - "//tensorflow/python:clip_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:random_ops", - "//tensorflow/python:summary", - "//tensorflow/python:training_util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:estimator_py", - ], -) - -py_test( - name = "latent_gan_estimator_test", - srcs = [ - "python/estimator/python/latent_gan_estimator_test.py", - ], - python_version = "PY2", - srcs_version = "PY2AND3", - deps = [ - ":latent_gan_estimator", - "//tensorflow/python:array_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:run_config", - "//tensorflow/python/ops/losses", - ], -) - -py_library( - name = "sliced_wasserstein", - srcs = [ - "python/eval/python/sliced_wasserstein.py", - "python/eval/python/sliced_wasserstein_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - 
"//tensorflow/python:array_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:nn", - "//tensorflow/python:nn_ops", - "//tensorflow/python:random_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:util", - "//third_party/py/numpy", - ], -) - -py_test( - name = "sliced_wasserstein_test", - srcs = ["python/eval/python/sliced_wasserstein_test.py"], - python_version = "PY2", - srcs_version = "PY2AND3", - deps = [ - ":sliced_wasserstein", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:random_ops", - "//third_party/py/numpy", - ], -) - -py_library( - name = "spectral_normalization", - srcs = [ - "python/features/python/spectral_normalization.py", - "python/features/python/spectral_normalization_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:standard_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python/keras:engine", - ], -) - -py_test( - name = "spectral_normalization_test", - srcs = ["python/features/python/spectral_normalization_test.py"], - python_version = "PY2", - srcs_version = "PY2AND3", - deps = [ - ":spectral_normalization", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/contrib/slim", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", - "//tensorflow/python/keras:layers", - "//third_party/py/numpy", - ], -) diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md deleted file mode 100644 index 3c1d814e70f..00000000000 --- a/tensorflow/contrib/gan/README.md +++ /dev/null @@ -1,281 +0,0 @@ - - -# TensorFlow-GAN (TF-GAN) - -TF-GAN is a lightweight library for training and evaluating Generative -Adversarial Networks (GANs). This technique allows you to train a network -(called the 'generator') to sample from a distribution, without having to -explicitly model the distribution and without writing an explicit loss. For -example, the generator could learn to draw samples from the distribution of -natural images. For more details on this technique, see -['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by -Goodfellow et al. See -[tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) -for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an introduction. - -#### Usage -```python -import tensorflow as tf -tfgan = tf.contrib.gan -``` - -## Why TF-GAN? - -* Easily train generator and discriminator networks with well-tested, flexible [library calls](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py). 
You can -mix TF-GAN, native TF, and other custom frameworks -* Use already implemented [GAN losses and penalties](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/losses_impl.py) (ex Wasserstein loss, gradient penalty, mutual information penalty, etc) -* [Monitor and visualize](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/summaries_impl.py) GAN progress during training, and [evaluate](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py) them -* Use already-implemented [tricks](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/) to stabilize and improve training -* Develop based on examples of [common GAN setups](https://github.com/tensorflow/models/tree/master/research/gan/) -* Use the TF-GAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model -* Improvements in TF-GAN infrastructure will automatically benefit your TF-GAN project -* Stay up-to-date with research as we add more algorithms - -## What are the TF-GAN components? - -TF-GAN is composed of several parts which were design to exist independently. -These include the following main pieces (explained in detail below). - -* [core](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py): - provides the main infrastructure needed to train a GAN. Training occurs in - four phases, and each phase can be completed by custom-code or by using a - TF-GAN library call. - -* [features](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/): - Many common GAN operations and normalization techniques are implemented for - you to use, such as instance normalization and conditioning. - -* [losses](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/): - Easily experiment with already-implemented and well-tested losses and - penalties, such as the Wasserstein loss, gradient penalty, mutual - information penalty, etc - -* [evaluation](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/): - Use `Inception Score`, `Frechet Distance`, or `Kernel Distance` with a - pretrained Inception network to evaluate your unconditional generative - model. You can also use your own pretrained classifier for more specific - performance numbers, or use other methods for evaluating conditional - generative models. - -* [examples](https://github.com/tensorflow/models/tree/master/research/gan/) - and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TF-GAN to make - GAN training easier, or use the more complicated examples to jump-start your - own project. These include unconditional and conditional GANs, InfoGANs, - adversarial losses on existing networks, and image-to-image translation. - -## Training a GAN model - -Training in TF-GAN typically consists of the following steps: - -1. Specify the input to your networks. -1. Set up your generator and discriminator using a `GANModel`. -1. Specify your loss using a `GANLoss`. -1. Create your train ops using a `GANTrainOps`. -1. Run your train ops. - -At each stage, you can either use TF-GAN's convenience functions, or you can -perform the step manually for fine-grained control. We provide examples below. - -There are various types of GAN setups. 
For instance, you can train a generator -to sample unconditionally from a learned distribution, or you can condition on -extra information such as a class label. TF-GAN is compatible with many setups, -and we demonstrate a few below: - -### Examples - -#### Unconditional MNIST generation - -This example trains a generator to produce handwritten MNIST digits. The generator maps -random draws from a multivariate normal distribution to MNIST digit images. See -['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by -Goodfellow et al. - -```python -# Set up the input. -images = mnist_data_provider.provide_data(FLAGS.batch_size) -noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims]) - -# Build the generator and discriminator. -gan_model = tfgan.gan_model( - generator_fn=mnist.unconditional_generator, # you define - discriminator_fn=mnist.unconditional_discriminator, # you define - real_data=images, - generator_inputs=noise) - -# Build the GAN loss. -gan_loss = tfgan.gan_loss( - gan_model, - generator_loss_fn=tfgan.losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss) - -# Create the train ops, which calculate gradients and apply updates to weights. -train_ops = tfgan.gan_train_ops( - gan_model, - gan_loss, - generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5), - discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5)) - -# Run the train ops in the alternating training scheme. -tfgan.gan_train( - train_ops, - hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)], - logdir=FLAGS.train_log_dir) -``` - -#### Conditional MNIST generation -This example trains a generator to generate MNIST images *of a given class*. -The generator maps random draws from a multivariate normal distribution and a -one-hot label of the desired digit class to an MNIST digit image. See -['Conditional Generative Adversarial Nets'](https://arxiv.org/abs/1411.1784) by -Mirza and Osindero. - -```python -# Set up the input. -images, one_hot_labels = mnist_data_provider.provide_data(FLAGS.batch_size) -noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims]) - -# Build the generator and discriminator. -gan_model = tfgan.gan_model( - generator_fn=mnist.conditional_generator, # you define - discriminator_fn=mnist.conditional_discriminator, # you define - real_data=images, - generator_inputs=(noise, one_hot_labels)) - -# The rest is the same as in the unconditional case. -... -``` -#### Adversarial loss -This example combines an L1 pixel loss and an adversarial loss to learn to -autoencode images. The bottleneck layer can be used to transmit compressed -representations of the image. Neutral networks with pixel-wise loss only tend to -produce blurry results, so the GAN can be used to make the reconstructions more -plausible. See ['Full Resolution Image Compression with Recurrent Neural Networks'](https://arxiv.org/abs/1608.05148) by Toderici et al -for an example of neural networks used for image compression, and ['Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network'](https://arxiv.org/abs/1609.04802) by Ledig et al for a more detailed description of -how GANs can sharpen image output. - -```python -# Set up the input pipeline. -images = image_provider.provide_data(FLAGS.batch_size) - -# Build the generator and discriminator. 
-#### Adversarial loss
-This example combines an L1 pixel loss and an adversarial loss to learn to
-autoencode images. The bottleneck layer can be used to transmit compressed
-representations of the image. Neural networks trained with a pixel-wise loss alone tend to
-produce blurry results, so the GAN can be used to make the reconstructions more
-plausible. See ['Full Resolution Image Compression with Recurrent Neural Networks'](https://arxiv.org/abs/1608.05148) by Toderici et al.
-for an example of neural networks used for image compression, and ['Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network'](https://arxiv.org/abs/1609.04802) by Ledig et al. for a more detailed description of
-how GANs can sharpen image output.
-
-```python
-# Set up the input pipeline.
-images = image_provider.provide_data(FLAGS.batch_size)
-
-# Build the generator and discriminator.
-gan_model = tfgan.gan_model(
- generator_fn=nets.autoencoder, # you define
- discriminator_fn=nets.discriminator, # you define
- real_data=images,
- generator_inputs=images)
-
-# Build the GAN loss and standard pixel loss.
-gan_loss = tfgan.gan_loss(
- gan_model,
- generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
- discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
- gradient_penalty=1.0)
-l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)
-
-# Modify the loss tuple to include the pixel loss.
-gan_loss = tfgan.losses.combine_adversarial_loss(
- gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)
-
-# The rest is the same as in the unconditional case.
-...
-```
-
-#### Image-to-image translation
-This example maps images in one domain to images of the same size in a different
-domain. For example, it can map segmentation masks to street images, or
-grayscale images to color. See ['Image-to-Image Translation with Conditional Adversarial Networks'](https://arxiv.org/abs/1611.07004) by Isola et al. for more details.
-
-```python
-# Set up the input pipeline.
-input_image, target_image = data_provider.provide_data(FLAGS.batch_size)
-
-# Build the generator and discriminator.
-gan_model = tfgan.gan_model(
- generator_fn=nets.generator, # you define
- discriminator_fn=nets.discriminator, # you define
- real_data=target_image,
- generator_inputs=input_image)
-
-# Build the GAN loss and standard pixel loss.
-gan_loss = tfgan.gan_loss(
- gan_model,
- generator_loss_fn=tfgan.losses.least_squares_generator_loss,
- discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)
-l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)
-
-# Modify the loss tuple to include the pixel loss.
-gan_loss = tfgan.losses.combine_adversarial_loss(
- gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)
-
-# The rest is the same as in the unconditional case.
-...
-```
-
-#### InfoGAN
-Train a generator to generate specific MNIST digit images, and control for digit style *without using any labels*. See ['InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets'](https://arxiv.org/abs/1606.03657) for more details.
-
-```python
-# Set up the input pipeline.
-images = mnist_data_provider.provide_data(FLAGS.batch_size)
-
-# Build the generator and discriminator.
-gan_model = tfgan.infogan_model(
- generator_fn=mnist.infogan_generator, # you define
- discriminator_fn=mnist.infogan_discriminator, # you define
- real_data=images,
- unstructured_generator_inputs=unstructured_inputs, # you define
- structured_generator_inputs=structured_inputs) # you define
-
-# Build the GAN loss with mutual information penalty.
-gan_loss = tfgan.gan_loss(
- gan_model,
- generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
- discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
- gradient_penalty=1.0,
- mutual_information_penalty_weight=1.0)
-
-# The rest is the same as in the unconditional case.
-...
-```
-
-#### Custom model creation
-Train an unconditional GAN to generate MNIST digits, but manually construct
-the `GANModel` tuple for more fine-grained control.
-
-```python
-# Set up the input pipeline.
-images = mnist_data_provider.provide_data(FLAGS.batch_size)
-noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])
-
-# Manually build the generator and discriminator.
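# The code below assumes `variable_scope` and `variables_lib` are the internal
# TensorFlow modules used throughout TF-GAN, i.e.
# `tensorflow.python.ops.variable_scope` and
# `tensorflow.contrib.framework.python.ops.variables`; the public equivalents
# are `tf.variable_scope` and `tf.contrib.framework.get_trainable_variables`.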
-with tf.variable_scope('Generator') as gen_scope: - generated_images = generator_fn(noise) -with tf.variable_scope('Discriminator') as dis_scope: - discriminator_gen_outputs = discriminator_fn(generated_images) -with variable_scope.variable_scope(dis_scope, reuse=True): - discriminator_real_outputs = discriminator_fn(images) -generator_variables = variables_lib.get_trainable_variables(gen_scope) -discriminator_variables = variables_lib.get_trainable_variables(dis_scope) -# Depending on what TF-GAN features you use, you don't always need to supply -# every `GANModel` field. At a minimum, you need to include the discriminator -# outputs and variables if you want to use TF-GAN to construct losses. -gan_model = tfgan.GANModel( - generator_inputs, - generated_data, - generator_variables, - gen_scope, - generator_fn, - real_data, - discriminator_real_outputs, - discriminator_gen_outputs, - discriminator_variables, - dis_scope, - discriminator_fn) - -# The rest is the same as the unconditional case. -... -``` - - -## Authors -Joel Shor (github: [joel-shor](https://github.com/joel-shor)) and Sergio Guadarrama (github: [sguada](https://github.com/sguada)) diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py deleted file mode 100644 index 1e6000898f7..00000000000 --- a/tensorflow/contrib/gan/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2017 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TF-GAN is a lightweight library for training and evaluating GANs. - -In addition to providing the infrastructure for easily training and evaluating -GANS, this library contains modules for a TFGAN-backed Estimator, -evaluation metrics, features (such as virtual batch normalization), and losses. -Please see README.md for details and usage. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Collapse TF-GAN into a tiered namespace. 
-from tensorflow.contrib.gan.python import estimator -from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin -from tensorflow.contrib.gan.python import features -from tensorflow.contrib.gan.python import losses -from tensorflow.contrib.gan.python import namedtuples -from tensorflow.contrib.gan.python import train - -# pylint: disable=unused-import,wildcard-import -from tensorflow.contrib.gan.python.namedtuples import * -from tensorflow.contrib.gan.python.train import * -# pylint: enable=unused-import,wildcard-import - -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - 'estimator', - 'eval', - 'features', - 'losses', -] -_allowed_symbols += train.__all__ -_allowed_symbols += namedtuples.__all__ -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py deleted file mode 100644 index 430266555b7..00000000000 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2016 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TF-GAN estimator module. - -GANEstimator provides all the infrastructure support of a TensorFlow Estimator -with the feature support of TF-GAN. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Collapse `estimator` into a single namespace. 
-# pylint: disable=unused-import,wildcard-import -from tensorflow.contrib.gan.python.estimator.python import gan_estimator -from tensorflow.contrib.gan.python.estimator.python import head -from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator -from tensorflow.contrib.gan.python.estimator.python import stargan_estimator -from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator - -from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * -from tensorflow.contrib.gan.python.estimator.python.head import * -from tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator import * -from tensorflow.contrib.gan.python.estimator.python.stargan_estimator import * -from tensorflow.contrib.gan.python.estimator.python.tpu_gan_estimator import * -# pylint: enable=unused-import,wildcard-import - -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = ([ - 'gan_estimator', - 'stargan_estimator', - 'tpu_gan_estimator', - 'latent_gan_estimator', - 'head', -] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__ + - tpu_gan_estimator.__all__ + latent_gan_estimator.__all__) -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py deleted file mode 100644 index bc0e4854091..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""`tf.Learn` components for `GANEstimator`.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.estimator.python.gan_estimator_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = gan_estimator_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py deleted file mode 100644 index d234558d4da..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A TF-GAN-backed GAN Estimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import enum - -from tensorflow.contrib.framework.python.ops import variables as variable_lib -from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples -from tensorflow.contrib.gan.python import train as tfgan_train -from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries -from tensorflow.python.estimator import estimator -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import metrics as metrics_lib -from tensorflow.python.ops import variable_scope -from tensorflow.python.util import tf_inspect as inspect - - -__all__ = [ - 'GANEstimator', - 'SummaryType' -] - - -class SummaryType(enum.IntEnum): - NONE = 0 - VARIABLES = 1 - IMAGES = 2 - IMAGE_COMPARISON = 3 - - -_summary_type_map = { - SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries, - SummaryType.IMAGES: tfgan_summaries.add_gan_model_image_summaries, - SummaryType.IMAGE_COMPARISON: tfgan_summaries.add_image_comparison_summaries, # pylint:disable=line-too-long -} - - -class GANEstimator(estimator.Estimator): - """An estimator for Generative Adversarial Networks (GANs). - - This Estimator is backed by TF-GAN. The network functions follow the TF-GAN - API except for one exception: if either `generator_fn` or `discriminator_fn` - have an argument called `mode`, then the tf.Estimator mode is passed in for - that argument. This helps with operations like batch normalization, which have - different train and evaluation behavior. - - Example: - - ```python - import tensorflow as tf - tfgan = tf.contrib.gan - - # See TF-GAN's `train.py` for a description of the generator and - # discriminator API. - def generator_fn(generator_inputs): - ... - return generated_data - - def discriminator_fn(data, conditioning): - ... - return logits - - # Create GAN estimator. - gan_estimator = tfgan.estimator.GANEstimator( - model_dir, - generator_fn=generator_fn, - discriminator_fn=discriminator_fn, - generator_loss_fn=tfgan.losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, - generator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5), - discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5)) - - # Train estimator. - gan_estimator.train(train_input_fn, steps) - - # Evaluate resulting estimator. - gan_estimator.evaluate(eval_input_fn) - - # Generate samples from generator. 
- predictions = np.array([ - x for x in gan_estimator.predict(predict_input_fn)]) - ``` - """ - - def __init__(self, - model_dir=None, - generator_fn=None, - discriminator_fn=None, - generator_loss_fn=None, - discriminator_loss_fn=None, - generator_optimizer=None, - discriminator_optimizer=None, - get_hooks_fn=None, - get_eval_metric_ops_fn=None, - add_summaries=None, - use_loss_summaries=True, - config=None, - warm_start_from=None, - is_chief=True): - """Initializes a GANEstimator instance. - - Args: - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator - to continue training a previously saved model. - generator_fn: A python function that takes a Tensor, Tensor list, or - Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TF-GAN` for more details and examples. Additionally, if - it has an argument called `mode`, the Estimator's `mode` will be passed - in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch - normalization. - discriminator_fn: A python function that takes the output of - `generator_fn` or real data in the GAN setup, and `generator_inputs`. - Outputs a Tensor in the range [-inf, inf]. See `TF-GAN` for more details - and examples. - generator_loss_fn: The loss function on the generator. Takes a `GANModel` - tuple. - discriminator_loss_fn: The loss function on the discriminator. Takes a - `GANModel` tuple. - generator_optimizer: The optimizer for generator updates, or a function - that takes no arguments and returns an optimizer. This function will - be called when the default graph is the `GANEstimator`'s graph, so - utilities like `tf.contrib.framework.get_or_create_global_step` will - work. - discriminator_optimizer: Same as `generator_optimizer`, but for the - discriminator updates. - get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a - list of hooks. These hooks are run on the generator and discriminator - train ops, and can be used to implement the GAN training scheme. - Defaults to `train.get_sequential_train_hooks()`. - get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a - dict of metric results keyed by name. The output of this function is - passed into `tf.estimator.EstimatorSpec` during evaluation. - add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. - use_loss_summaries: If `True`, add loss summaries. If `False`, does not. - If `None`, uses defaults. - config: `RunConfig` object to configure the runtime settings. - warm_start_from: A filepath to a checkpoint or saved model, or a - WarmStartSettings object to configure initialization. - is_chief: Whether or not this Estimator is running on a chief or worker. - Needs to be set appropriately if using SyncReplicasOptimizers. - - Raises: - ValueError: If loss functions aren't callable. - ValueError: If `use_loss_summaries` isn't boolean or `None`. - ValueError: If `get_hooks_fn` isn't callable or `None`. 
- """ - if not callable(generator_loss_fn): - raise ValueError('generator_loss_fn must be callable.') - if not callable(discriminator_loss_fn): - raise ValueError('discriminator_loss_fn must be callable.') - if use_loss_summaries not in [True, False, None]: - raise ValueError('use_loss_summaries must be True, False or None.') - if get_hooks_fn is not None and not callable(get_hooks_fn): - raise TypeError('get_hooks_fn must be callable.') - - def _model_fn(features, labels, mode): - """GANEstimator model function.""" - if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, - model_fn_lib.ModeKeys.PREDICT]: - raise ValueError('Mode not recognized: %s' % mode) - real_data = labels # rename inputs for clarity - generator_inputs = features # rename inputs for clarity - - # Make GANModel, which encapsulates the GAN model architectures. - gan_model = _get_gan_model( - mode, generator_fn, discriminator_fn, real_data, generator_inputs, - add_summaries) - - # Make the EstimatorSpec, which incorporates the GANModel, losses, eval - # metrics, and optimizers (if required). - return _get_estimator_spec( - mode, gan_model, generator_loss_fn, discriminator_loss_fn, - get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn, use_loss_summaries, is_chief) - - super(GANEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config, - warm_start_from=warm_start_from) - - -def _get_gan_model( - mode, generator_fn, discriminator_fn, real_data, generator_inputs, - add_summaries, generator_scope='Generator'): - """Makes the GANModel tuple, which encapsulates the GAN model architecture.""" - if mode == model_fn_lib.ModeKeys.PREDICT: - if real_data is not None: - raise ValueError('`labels` must be `None` when mode is `predict`. 
' - 'Instead, found %s' % real_data) - gan_model = _make_prediction_gan_model( - generator_inputs, generator_fn, generator_scope) - else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL - gan_model = _make_gan_model( - generator_fn, discriminator_fn, real_data, generator_inputs, - generator_scope, add_summaries, mode) - - return gan_model - - -def _get_estimator_spec( - mode, gan_model, generator_loss_fn, discriminator_loss_fn, - get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn=None, use_loss_summaries=True, is_chief=True): - """Get the EstimatorSpec for the current mode.""" - if mode == model_fn_lib.ModeKeys.PREDICT: - estimator_spec = model_fn_lib.EstimatorSpec( - mode=mode, predictions=gan_model.generated_data) - else: - gan_loss = tfgan_tuples.GANLoss( - generator_loss=generator_loss_fn( - gan_model, add_summaries=use_loss_summaries), - discriminator_loss=discriminator_loss_fn( - gan_model, add_summaries=use_loss_summaries)) - if mode == model_fn_lib.ModeKeys.EVAL: - estimator_spec = _get_eval_estimator_spec( - gan_model, gan_loss, get_eval_metric_ops_fn) - else: # model_fn_lib.ModeKeys.TRAIN: - if callable(generator_optimizer): - generator_optimizer = generator_optimizer() - if callable(discriminator_optimizer): - discriminator_optimizer = discriminator_optimizer() - get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() - estimator_spec = _get_train_estimator_spec( - gan_model, gan_loss, generator_optimizer, discriminator_optimizer, - get_hooks_fn, is_chief=is_chief) - - return estimator_spec - - -def _make_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries, mode): - """Construct a `GANModel`, and optionally pass in `mode`.""" - # If network functions have an argument `mode`, pass mode to it. - if 'mode' in inspect.getargspec(generator_fn).args: - generator_fn = functools.partial(generator_fn, mode=mode) - if 'mode' in inspect.getargspec(discriminator_fn).args: - discriminator_fn = functools.partial(discriminator_fn, mode=mode) - gan_model = tfgan_train.gan_model( - generator_fn, - discriminator_fn, - real_data, - generator_inputs, - generator_scope=generator_scope, - check_shapes=False) - if add_summaries: - if not isinstance(add_summaries, (tuple, list)): - add_summaries = [add_summaries] - with ops.name_scope(None): - for summary_type in add_summaries: - _summary_type_map[summary_type](gan_model) - - return gan_model - - -def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): - """Make a `GANModel` from just the generator.""" - # If `generator_fn` has an argument `mode`, pass mode to it. 
- if 'mode' in inspect.getargspec(generator_fn).args: - generator_fn = functools.partial(generator_fn, - mode=model_fn_lib.ModeKeys.PREDICT) - with variable_scope.variable_scope(generator_scope) as gen_scope: - generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs) # pylint:disable=protected-access - generated_data = generator_fn(generator_inputs) - generator_variables = variable_lib.get_trainable_variables(gen_scope) - - return tfgan_tuples.GANModel( - generator_inputs, - generated_data, - generator_variables, - gen_scope, - generator_fn, - real_data=None, - discriminator_real_outputs=None, - discriminator_gen_outputs=None, - discriminator_variables=None, - discriminator_scope=None, - discriminator_fn=None) - - -def _get_eval_estimator_spec(gan_model, gan_loss, get_eval_metric_ops_fn=None, - name=None): - """Return an EstimatorSpec for the eval case.""" - scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss - with ops.name_scope(None, 'metrics', - [gan_loss.generator_loss, - gan_loss.discriminator_loss]): - def _summary_key(head_name, val): - return '%s/%s' % (val, head_name) if head_name else val - eval_metric_ops = { - _summary_key(name, 'generator_loss'): - metrics_lib.mean(gan_loss.generator_loss), - _summary_key(name, 'discriminator_loss'): - metrics_lib.mean(gan_loss.discriminator_loss) - } - if get_eval_metric_ops_fn is not None: - custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model) - if not isinstance(custom_eval_metric_ops, dict): - raise TypeError('get_eval_metric_ops_fn must return a dict, ' - 'received: {}'.format(custom_eval_metric_ops)) - eval_metric_ops.update(custom_eval_metric_ops) - return model_fn_lib.EstimatorSpec( - mode=model_fn_lib.ModeKeys.EVAL, - predictions=gan_model.generated_data, - loss=scalar_loss, - eval_metric_ops=eval_metric_ops) - - -def _get_train_estimator_spec( - gan_model, gan_loss, generator_optimizer, discriminator_optimizer, - get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops, is_chief=True): - """Return an EstimatorSpec for the train case.""" - scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss - train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer, - discriminator_optimizer, is_chief=is_chief) - training_hooks = get_hooks_fn(train_ops) - return model_fn_lib.EstimatorSpec( - loss=scalar_loss, - mode=model_fn_lib.ModeKeys.TRAIN, - train_op=train_ops.global_step_inc_op, - training_hooks=training_hooks) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py deleted file mode 100644 index 66af79d1e81..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ /dev/null @@ -1,421 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for TF-GAN's estimator.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import shutil -import tempfile - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib import layers -from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples -from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator -from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses -from tensorflow.contrib.learn.python.learn.learn_io import graph_io -from tensorflow.core.example import example_pb2 -from tensorflow.core.example import feature_pb2 -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.estimator import WarmStartSettings -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework.errors_impl import NotFoundError -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import metrics as metrics_lib -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import test -from tensorflow.python.summary.writer import writer_cache -from tensorflow.python.training import input as input_lib -from tensorflow.python.training import learning_rate_decay -from tensorflow.python.training import sync_replicas_optimizer -from tensorflow.python.training import training -from tensorflow.python.training import training_util - - -def generator_fn(noise_dict, mode): - del mode - noise = noise_dict['x'] - return layers.fully_connected(noise, tensor_shape.dimension_value( - noise.shape[1])) - - -def discriminator_fn(data, unused_conditioning, mode): - del unused_conditioning, mode - return layers.fully_connected(data, 1) - - -class GetGANModelTest(test.TestCase, parameterized.TestCase): - """Tests that `GetGANModel` produces the correct model.""" - - @parameterized.named_parameters( - ('train', model_fn_lib.ModeKeys.TRAIN), - ('eval', model_fn_lib.ModeKeys.EVAL), - ('predict', model_fn_lib.ModeKeys.PREDICT)) - def test_get_gan_model(self, mode): - with ops.Graph().as_default(): - generator_inputs = {'x': array_ops.ones([3, 4])} - is_predict = mode == model_fn_lib.ModeKeys.PREDICT - real_data = array_ops.zeros([3, 4]) if not is_predict else None - gan_model = estimator._get_gan_model( - mode, generator_fn, discriminator_fn, real_data, generator_inputs, - add_summaries=False) - - self.assertEqual(generator_inputs, gan_model.generator_inputs) - self.assertIsNotNone(gan_model.generated_data) - self.assertLen(gan_model.generator_variables, 2) # 1 FC layer - self.assertIsNotNone(gan_model.generator_fn) - if mode == model_fn_lib.ModeKeys.PREDICT: - self.assertIsNone(gan_model.real_data) - self.assertIsNone(gan_model.discriminator_real_outputs) - self.assertIsNone(gan_model.discriminator_gen_outputs) - self.assertIsNone(gan_model.discriminator_variables) - self.assertIsNone(gan_model.discriminator_scope) - self.assertIsNone(gan_model.discriminator_fn) - else: - self.assertIsNotNone(gan_model.real_data) - self.assertIsNotNone(gan_model.discriminator_real_outputs) - self.assertIsNotNone(gan_model.discriminator_gen_outputs) - 
self.assertLen(gan_model.discriminator_variables, 2) # 1 FC layer - self.assertIsNotNone(gan_model.discriminator_scope) - self.assertIsNotNone(gan_model.discriminator_fn) - - -def get_dummy_gan_model(): - # TODO(joelshor): Find a better way of creating a variable scope. - with variable_scope.variable_scope('generator') as gen_scope: - gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) - with variable_scope.variable_scope('discriminator') as dis_scope: - dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) - return tfgan_tuples.GANModel( - generator_inputs=None, - generated_data=array_ops.ones([3, 4]), - generator_variables=[gen_var], - generator_scope=gen_scope, - generator_fn=None, - real_data=array_ops.zeros([3, 4]), - discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, - discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, - discriminator_variables=[dis_var], - discriminator_scope=dis_scope, - discriminator_fn=None) - - -def dummy_loss_fn(gan_model, add_summaries=True): - del add_summaries - return math_ops.reduce_sum(gan_model.discriminator_real_outputs - - gan_model.discriminator_gen_outputs) - - -def get_metrics(gan_model): - return { - 'mse_custom_metric': metrics_lib.mean_squared_error( - gan_model.real_data, gan_model.generated_data) - } - - -class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): - """Tests that the EstimatorSpec is constructed appropriately.""" - - @classmethod - def setUpClass(cls): - super(GetEstimatorSpecTest, cls).setUpClass() - cls._generator_optimizer = training.GradientDescentOptimizer(1.0) - cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) - - @parameterized.named_parameters( - ('train', model_fn_lib.ModeKeys.TRAIN), - ('eval', model_fn_lib.ModeKeys.EVAL), - ('predict', model_fn_lib.ModeKeys.PREDICT)) - def test_get_estimator_spec(self, mode): - with ops.Graph().as_default(): - self._gan_model = get_dummy_gan_model() - spec = estimator._get_estimator_spec( - mode, - self._gan_model, - generator_loss_fn=dummy_loss_fn, - discriminator_loss_fn=dummy_loss_fn, - get_eval_metric_ops_fn=get_metrics, - generator_optimizer=self._generator_optimizer, - discriminator_optimizer=self._discriminator_optimizer) - - self.assertEqual(mode, spec.mode) - if mode == model_fn_lib.ModeKeys.PREDICT: - self.assertEqual(self._gan_model.generated_data, spec.predictions) - elif mode == model_fn_lib.ModeKeys.TRAIN: - self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar - self.assertIsNotNone(spec.train_op) - self.assertIsNotNone(spec.training_hooks) - elif mode == model_fn_lib.ModeKeys.EVAL: - self.assertEqual(self._gan_model.generated_data, spec.predictions) - self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar - self.assertIsNotNone(spec.eval_metric_ops) - - def test_get_sync_estimator_spec(self): - """Make sure spec is loaded with sync hooks for sync opts.""" - - def get_sync_optimizer(): - return sync_replicas_optimizer.SyncReplicasOptimizer( - training.GradientDescentOptimizer(learning_rate=1.0), - replicas_to_aggregate=1) - - with ops.Graph().as_default(): - self._gan_model = get_dummy_gan_model() - g_opt = get_sync_optimizer() - d_opt = get_sync_optimizer() - - spec = estimator._get_estimator_spec( - model_fn_lib.ModeKeys.TRAIN, - self._gan_model, - generator_loss_fn=dummy_loss_fn, - discriminator_loss_fn=dummy_loss_fn, - get_eval_metric_ops_fn=get_metrics, - generator_optimizer=g_opt, - discriminator_optimizer=d_opt) - - 
self.assertLen(spec.training_hooks, 4) - sync_opts = [ - hook._sync_optimizer for hook in spec.training_hooks if - isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)] - self.assertLen(sync_opts, 2) - self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) - - -class GANEstimatorIntegrationTest(test.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) - - def _test_complete_flow( - self, train_input_fn, eval_input_fn, predict_input_fn, prediction_size, - lr_decay=False): - def make_opt(): - gstep = training_util.get_or_create_global_step() - lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) - return training.GradientDescentOptimizer(lr) - - gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) - dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) - est = estimator.GANEstimator( - generator_fn=generator_fn, - discriminator_fn=discriminator_fn, - generator_loss_fn=losses.wasserstein_generator_loss, - discriminator_loss_fn=losses.wasserstein_discriminator_loss, - generator_optimizer=gopt, - discriminator_optimizer=dopt, - get_eval_metric_ops_fn=get_metrics, - model_dir=self._model_dir) - - # Train. - num_steps = 10 - est.train(train_input_fn, steps=num_steps) - - # Evaluate. - scores = est.evaluate(eval_input_fn) - self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', scores) - self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], - scores['loss']) - self.assertIn('mse_custom_metric', scores) - - # Predict. - predictions = np.array([x for x in est.predict(predict_input_fn)]) - - self.assertAllEqual(prediction_size, predictions.shape) - - def test_numpy_input_fn(self): - """Tests complete flow with numpy_input_fn.""" - input_dim = 4 - batch_size = 5 - data = np.zeros([batch_size, input_dim]) - train_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size, - num_epochs=None, - shuffle=True) - eval_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size, - shuffle=False) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - batch_size=batch_size, - shuffle=False) - - self._test_complete_flow( - train_input_fn=train_input_fn, - eval_input_fn=eval_input_fn, - predict_input_fn=predict_input_fn, - prediction_size=[batch_size, input_dim]) - - def test_numpy_input_fn_lrdecay(self): - """Tests complete flow with numpy_input_fn.""" - input_dim = 4 - batch_size = 5 - data = np.zeros([batch_size, input_dim]) - train_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size, - num_epochs=None, - shuffle=True) - eval_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - y=data, - batch_size=batch_size, - shuffle=False) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': data}, - batch_size=batch_size, - shuffle=False) - - self._test_complete_flow( - train_input_fn=train_input_fn, - eval_input_fn=eval_input_fn, - predict_input_fn=predict_input_fn, - prediction_size=[batch_size, input_dim], - lr_decay=True) - - def test_input_fn_from_parse_example(self): - """Tests complete flow with input_fn constructed from parse_example.""" - input_dim = 4 - batch_size = 6 - data = np.zeros([batch_size, input_dim]) - - serialized_examples = [] - for datum in data: - example = example_pb2.Example(features=feature_pb2.Features( - feature={ - 'x': 
feature_pb2.Feature( - float_list=feature_pb2.FloatList(value=datum)), - 'y': feature_pb2.Feature( - float_list=feature_pb2.FloatList(value=datum)), - })) - serialized_examples.append(example.SerializeToString()) - - feature_spec = { - 'x': parsing_ops.FixedLenFeature([input_dim], dtypes.float32), - 'y': parsing_ops.FixedLenFeature([input_dim], dtypes.float32), - } - def _train_input_fn(): - feature_map = parsing_ops.parse_example( - serialized_examples, feature_spec) - _, features = graph_io.queue_parsed_features(feature_map) - labels = features.pop('y') - return features, labels - def _eval_input_fn(): - feature_map = parsing_ops.parse_example( - input_lib.limit_epochs(serialized_examples, num_epochs=1), - feature_spec) - _, features = graph_io.queue_parsed_features(feature_map) - labels = features.pop('y') - return features, labels - def _predict_input_fn(): - feature_map = parsing_ops.parse_example( - input_lib.limit_epochs(serialized_examples, num_epochs=1), - feature_spec) - _, features = graph_io.queue_parsed_features(feature_map) - features.pop('y') - return features, None - - self._test_complete_flow( - train_input_fn=_train_input_fn, - eval_input_fn=_eval_input_fn, - predict_input_fn=_predict_input_fn, - prediction_size=[batch_size, input_dim]) - - -class GANEstimatorWarmStartTest(test.TestCase): - - def setUp(self): - self._model_dir = self.get_temp_dir() - self.new_variable_name = 'new_var' - self.new_variable_value = [1, 2, 3] - - def tearDown(self): - writer_cache.FileWriterCache.clear() - - def _test_warm_start(self, warm_start_from=None): - """Tests whether WarmStartSettings work as intended.""" - def generator_with_new_variable(noise_dict, mode): - variable_scope.get_variable(name=self.new_variable_name, - initializer=self.new_variable_value, - trainable=True) - return generator_fn(noise_dict, mode) - - def train_input_fn(): - data = np.zeros([3, 4]) - return {'x': data}, data - - est = estimator.GANEstimator( - generator_fn=generator_fn, - discriminator_fn=discriminator_fn, - generator_loss_fn=losses.wasserstein_generator_loss, - discriminator_loss_fn=losses.wasserstein_discriminator_loss, - generator_optimizer=training.GradientDescentOptimizer(1.0), - discriminator_optimizer=training.GradientDescentOptimizer(1.0), - model_dir=self._model_dir) - - est.train(train_input_fn, steps=1) - - est_warm = estimator.GANEstimator( - generator_fn=generator_with_new_variable, - discriminator_fn=discriminator_fn, - generator_loss_fn=losses.wasserstein_generator_loss, - discriminator_loss_fn=losses.wasserstein_discriminator_loss, - generator_optimizer=training.GradientDescentOptimizer(1.0), - discriminator_optimizer=training.GradientDescentOptimizer(1.0), - model_dir=None if warm_start_from else self._model_dir, - warm_start_from=warm_start_from) - - est_warm.train(train_input_fn, steps=1) - - return est_warm - - def test_warm_start_error(self): - """Test if exception when reloading different estimators.""" - with self.assertRaises(NotFoundError): - self._test_warm_start() - - def test_warm_start_success(self): - """Test if GANEstimator allows explicit warm start variable assignment.""" - # Regex matches all variable names in ckpt except for new_var. 
- var_regex = '^(?!.*%s.*)' % self.new_variable_name - warmstart = WarmStartSettings(ckpt_to_initialize_from=self._model_dir, - vars_to_warm_start=var_regex) - est_warm = self._test_warm_start(warm_start_from=warmstart) - full_variable_name = 'Generator/%s' % self.new_variable_name - self.assertIn(full_variable_name, est_warm.get_variable_names()) - equal_vals = np.array_equal(est_warm.get_variable_value(full_variable_name), - self.new_variable_value) - self.assertTrue(equal_vals) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/estimator/python/head.py b/tensorflow/contrib/gan/python/estimator/python/head.py deleted file mode 100644 index 3225d6f41a1..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/head.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""`tf.Learn` components for `GANEstimator`'s loss.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.estimator.python import head_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.estimator.python.head_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = head_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py deleted file mode 100644 index cbe990b476c..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ /dev/null @@ -1,258 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""A TF-GAN-backed GAN Estimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools - -from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples -from tensorflow.contrib.gan.python import train as tfgan_train -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.canned import head -from tensorflow.python.estimator.export import export_output -from tensorflow.python.framework import ops -from tensorflow.python.ops import metrics as metrics_lib -from tensorflow.python.util import deprecation - -__all__ = [ - 'GANHead', - 'gan_head', -] - - -def _summary_key(head_name, val): - return '%s/%s' % (val, head_name) if head_name else val - - -@deprecation.deprecated( - None, 'Please use tf.contrib.gan.GANEstimator without explicitly making a ' - 'GANHead.') -def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, - discriminator_optimizer, use_loss_summaries=True, - get_hooks_fn=tfgan_train.get_sequential_train_hooks(), - get_eval_metric_ops_fn=None, name=None): - """Creates a `GANHead`. - - Args: - generator_loss_fn: A TFGAN loss function for the generator. Takes a - `GANModel` and returns a scalar. - discriminator_loss_fn: Same as `generator_loss_fn`, but for the - discriminator. - generator_optimizer: The optimizer for generator updates. - discriminator_optimizer: Same as `generator_optimizer`, but for the - discriminator updates. - use_loss_summaries: If `True`, add loss summaries. If `False`, does not. - If `None`, uses defaults. - get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a - list of hooks. - get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a - dict of metric results keyed by name. The output of this function is - passed into `tf.estimator.EstimatorSpec` during evaluation. - name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. - - Returns: - An instance of `GANHead`. - """ - return GANHead(generator_loss_fn=generator_loss_fn, - discriminator_loss_fn=discriminator_loss_fn, - generator_optimizer=generator_optimizer, - discriminator_optimizer=discriminator_optimizer, - use_loss_summaries=use_loss_summaries, - get_hooks_fn=get_hooks_fn, - get_eval_metric_ops_fn=get_eval_metric_ops_fn, - name=name) - - -class GANHead(head._Head): # pylint: disable=protected-access - """`Head` for a GAN.""" - - @deprecation.deprecated( - None, 'Please use tf.contrib.gan.GANEstimator without explicitly making ' - 'a GANHead.') - def __init__(self, generator_loss_fn, discriminator_loss_fn, - generator_optimizer, discriminator_optimizer, - use_loss_summaries=True, - get_hooks_fn=None, - get_eval_metric_ops_fn=None, - name=None): - """`Head` for GAN training. - - Args: - generator_loss_fn: A TFGAN loss function for the generator. Takes a - `GANModel` and returns a scalar. - discriminator_loss_fn: Same as `generator_loss_fn`, but for the - discriminator. - generator_optimizer: The optimizer for generator updates. - discriminator_optimizer: Same as `generator_optimizer`, but for the - discriminator updates. - use_loss_summaries: If `True`, add loss summaries. If `False`, does not. - If `None`, uses defaults. - get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a - list of hooks. 
Defaults to `train.get_sequential_train_hooks()` - get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a - dict of metric results keyed by name. The output of this function is - passed into `tf.estimator.EstimatorSpec` during evaluation. - name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. - """ - - if not callable(generator_loss_fn): - raise TypeError('generator_loss_fn must be callable.') - if not callable(discriminator_loss_fn): - raise TypeError('discriminator_loss_fn must be callable.') - if use_loss_summaries not in [True, False, None]: - raise ValueError('use_loss_summaries must be True, False or None.') - if get_hooks_fn is not None and not callable(get_hooks_fn): - raise TypeError('get_hooks_fn must be callable.') - if name is not None and not isinstance(name, str): - raise TypeError('name must be string.') - - if get_hooks_fn is None: - get_hooks_fn = tfgan_train.get_sequential_train_hooks() - - if use_loss_summaries in [True, False]: - generator_loss_fn = functools.partial( - generator_loss_fn, add_summaries=use_loss_summaries) - discriminator_loss_fn = functools.partial( - discriminator_loss_fn, add_summaries=use_loss_summaries) - self._generator_loss_fn = generator_loss_fn - self._discriminator_loss_fn = discriminator_loss_fn - self._generator_optimizer = generator_optimizer - self._discriminator_optimizer = discriminator_optimizer - self._get_hooks_fn = get_hooks_fn - self._get_eval_metric_ops_fn = get_eval_metric_ops_fn - self._name = name - - @property - def name(self): - return self._name - - @property - def logits_dimension(self): - return None - - def create_loss(self, features, mode, logits, labels): - """Returns a GANLoss tuple from the provided GANModel. - - See `Head` for more details. - - Args: - features: Input `dict` of `Tensor` objects. Unused. - mode: Estimator's `ModeKeys`. - logits: A GANModel tuple. - labels: Must be `None`. - - Returns: - A GANLoss tuple. - - """ - _validate_logits_and_labels(logits, labels) - del mode, labels, features # unused for this head. - gan_model = logits # rename variable for clarity - return tfgan_tuples.GANLoss( - generator_loss=self._generator_loss_fn(gan_model), - discriminator_loss=self._discriminator_loss_fn(gan_model)) - - def create_estimator_spec( - self, features, mode, logits, labels=None, - train_op_fn=tfgan_train.gan_train_ops): - """Returns `EstimatorSpec` that a model_fn can return. - - See `Head` for more details. - - Args: - features: Must be `None`. - mode: Estimator's `ModeKeys`. - logits: A GANModel tuple. - labels: Must be `None`. - train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer, - and discriminator optimizer, and returns a `GANTrainOps` tuple. For - example, this function can come from TFGAN's `train.py` library, or can - be custom. - - Returns: - `EstimatorSpec`. - - Raises: - ValueError: If `features` isn't `None`. - ValueError: If `train_op_fn` isn't provided in train mode. - """ - _validate_logits_and_labels(logits, labels) - if features is not None: - raise ValueError('`features` should be `None`. 
Instead, found: %s' % - features) - gan_model = logits # rename variable for clarity - with ops.name_scope('GANHead'): - if mode == model_fn_lib.ModeKeys.PREDICT: - return model_fn_lib.EstimatorSpec( - mode=model_fn_lib.ModeKeys.PREDICT, - predictions=gan_model.generated_data, - export_outputs={ - 'predict': export_output.PredictOutput(gan_model.generated_data) - }) - elif mode == model_fn_lib.ModeKeys.EVAL: - gan_loss = self.create_loss( - features=None, mode=mode, logits=gan_model, labels=None) - scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss - with ops.name_scope(None, 'metrics', - [gan_loss.generator_loss, - gan_loss.discriminator_loss]): - eval_metric_ops = { - _summary_key(self._name, 'generator_loss'): - metrics_lib.mean(gan_loss.generator_loss), - _summary_key(self._name, 'discriminator_loss'): - metrics_lib.mean(gan_loss.discriminator_loss) - } - if self._get_eval_metric_ops_fn is not None: - custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model) - if not isinstance(custom_eval_metric_ops, dict): - raise TypeError('get_eval_metric_ops_fn must return a dict, ' - 'received: {}'.format(custom_eval_metric_ops)) - eval_metric_ops.update(custom_eval_metric_ops) - return model_fn_lib.EstimatorSpec( - mode=model_fn_lib.ModeKeys.EVAL, - predictions=gan_model.generated_data, - loss=scalar_loss, - eval_metric_ops=eval_metric_ops) - elif mode == model_fn_lib.ModeKeys.TRAIN: - if train_op_fn is None: - raise ValueError('train_op_fn can not be None.') - gan_loss = self.create_loss(None, mode, gan_model, None) - scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss - train_ops = train_op_fn(gan_model, gan_loss, self._generator_optimizer, - self._discriminator_optimizer) - training_hooks = self._get_hooks_fn(train_ops) - return model_fn_lib.EstimatorSpec( - loss=scalar_loss, - mode=model_fn_lib.ModeKeys.TRAIN, - train_op=train_ops.global_step_inc_op, - training_hooks=training_hooks) - else: - raise ValueError('Mode not recognized: %s' % mode) - - -def _validate_logits_and_labels(logits, labels): - if labels is not None: - raise ValueError('`GANHead`\'s `create_estimator_spec` input `labels` must ' - 'be `None`. Instead, found: %s' % labels) - - if not isinstance(logits, tfgan_tuples.GANModel): - raise ValueError('`GANHead`\'s `create_estimator_spec` input `logits` must ' - 'be an instnace of a `GANModel`. Instead, found: %s' % - logits) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py deleted file mode 100644 index 5b50234a0e3..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for TF-GAN's head.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples -from tensorflow.contrib.gan.python.estimator.python import head - -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import test -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.training import training - -_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - - -def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument - return math_ops.reduce_sum(gan_model.discriminator_real_outputs - - gan_model.discriminator_gen_outputs) - - -def get_gan_model(): - # TODO(joelshor): Find a better way of creating a variable scope. - with variable_scope.variable_scope('generator') as gen_scope: - gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) - with variable_scope.variable_scope('discriminator') as dis_scope: - dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) - return tfgan_tuples.GANModel( - generator_inputs=None, - generated_data=array_ops.ones([3, 4]), - generator_variables=[gen_var], - generator_scope=gen_scope, - generator_fn=None, - real_data=None, - discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, - discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, - discriminator_variables=[dis_var], - discriminator_scope=dis_scope, - discriminator_fn=None) - - -class GANHeadTest(test.TestCase): - - def setUp(self): - super(GANHeadTest, self).setUp() - self.gan_head = head.gan_head( - generator_loss_fn=dummy_loss, - discriminator_loss_fn=dummy_loss, - generator_optimizer=training.GradientDescentOptimizer(1.0), - discriminator_optimizer=training.GradientDescentOptimizer(1.0), - get_eval_metric_ops_fn=self.get_metrics) - self.assertIsInstance(self.gan_head, head.GANHead) - - def get_metrics(self, gan_model): - self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel)) - return {} - - def _test_modes_helper(self, mode): - return self.gan_head.create_estimator_spec( - features=None, - mode=mode, - logits=get_gan_model()) - - def test_modes_predict(self): - spec = self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) - self.assertItemsEqual((_DEFAULT_SERVING_KEY, 'predict'), - spec.export_outputs.keys()) - - def test_modes_eval(self): - self._test_modes_helper(model_fn_lib.ModeKeys.EVAL) - - def test_modes_train(self): - self._test_modes_helper(model_fn_lib.ModeKeys.TRAIN) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py deleted file mode 100644 index 4e164e24168..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""`tf.Learn` components for `Train Input Estimator`.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = latent_gan_estimator_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py deleted file mode 100644 index f5afc773193..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_impl.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Implements an estimator wrapper that allows training the input latent space. - -This file implements a latent gan estimator that wraps around a previously -trained GAN. The latent gan estimator trains a single variable z, representing -the hidden latent distribution that is the 'noise' input to the GAN. By training -z, the inpainting estimator can move around the latent z space towards -minimizing a specific loss function. - -The latent gan estimator has a few key differences from a normal estimator. - -First: the variables in the estimator should not be saved, as we are not -updating the original GAN and are only adding a new z variable that is meant -to be different for each run. In order to do distributed training using -train_and_evaluate, the Tensorflow RunConfig is expected to save checkpoints -by having either save_checkpoints_steps or save_checkpoints_secs saved. -To avoid this conflict, we purposely set the save_checkpoints_steps value in -the RunConfig to be one step more than the total number of steps that the -inpainter estimator will run. - -Second: we need to specify warm start settings, as we are reloading the -GAN model into a different graph (specifically, one with a new z variable). -The warm start settings defined below reload all GAN variables and ignore the -new z variable (and the optimizer). - -Usage: - - def _generator(net, mode): - ... - - def _discriminator(net, condition, mode): - ... 
- - def _loss(gan_model, features, labels, add_summaries): - ... - - def optimizer(): - ... - - params = {} - config = tf.estimator.RunConfig() - tmp_dir = path/to/output/storage - - estimator = latent_gan_estimator.get_latent_gan_estimator( - _generator, _discriminator, _loss, optimizer, params, config, tmp_dir) - - def input_fn(): - ... - - estimator.train(input_fn=input_fn) - -See latent_gan_estimator_test.py or tensorflow_models/gan/face_inpainting for -further examples. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -from tensorflow.contrib.gan.python import train as tfgan_train -from tensorflow.python.estimator import estimator -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.summary import summary -from tensorflow.python.training import training_util - - -INPUT_NAME = 'new_var_z_input' # The name for the new z space input variable. -OPTIMIZER_NAME = 'latent_gan_optimizer' # The name for the new optimizer vars. - -__all__ = [ - 'get_latent_gan_estimator', -] - - -def _get_latent_gan_model_fn(generator_fn, discriminator_fn, loss_fn, - optimizer): - """Sets up a model function that wraps around a given GAN.""" - def model_fn(features, labels, mode, params): - """Model function defining an inpainting estimator.""" - batch_size = params['batch_size'] - z_shape = [batch_size] + params['z_shape'] - add_summaries = params['add_summaries'] - input_clip = params['input_clip'] - - z = variable_scope.get_variable( - name=INPUT_NAME, initializer=random_ops.truncated_normal(z_shape), - constraint=lambda x: clip_ops.clip_by_value(x, -input_clip, input_clip)) - - generator = functools.partial(generator_fn, mode=mode) - discriminator = functools.partial(discriminator_fn, mode=mode) - gan_model = tfgan_train.gan_model(generator_fn=generator, - discriminator_fn=discriminator, - real_data=labels, - generator_inputs=z, - check_shapes=False) - - loss = loss_fn(gan_model, features, labels, add_summaries) - - # Use a variable scope to make sure that estimator variables dont cause - # save/load problems when restoring from ckpts. - with variable_scope.variable_scope(OPTIMIZER_NAME): - opt = optimizer(learning_rate=params['learning_rate'], - **params['opt_kwargs']) - train_op = opt.minimize( - loss=loss, global_step=training_util.get_or_create_global_step(), - var_list=[z]) - - if add_summaries: - z_grads = gradients_impl.gradients(loss, z) - summary.scalar('z_loss/z_grads', clip_ops.global_norm(z_grads)) - summary.scalar('z_loss/loss', loss) - - return model_fn_lib.EstimatorSpec(mode=mode, - predictions=gan_model.generated_data, - loss=loss, - train_op=train_op) - return model_fn - - -def get_latent_gan_estimator(generator_fn, discriminator_fn, loss_fn, - optimizer, params, config, ckpt_dir, - warmstart_options=True): - """Gets an estimator that passes gradients to the input. - - This function takes in a generator and adds a trainable z variable that is - used as input to this generator_fn. The generator itself is treated as a black - box through which gradients can pass through without updating any weights. The - result is a trainable way to traverse the GAN latent space. The loss_fn is - used to actually train the z variable. 
The generator_fn and discriminator_fn - should be previously trained by the tfgan library (on reload, the variables - are expected to follow the tfgan format. It may be possible to use the - latent gan estimator with entirely custom GANs that do not use the tfgan - library as long as the appropriate variables are wired properly). - - Args: - generator_fn: a function defining a Tensorflow graph for a GAN generator. - The weights defined in this graph should already be defined in the given - checkpoint location. Should have 'mode' as an argument. - discriminator_fn: a function defining a Tensorflow graph for a GAN - discriminator. Should have 'mode' as an argument. - loss_fn: a function defining a Tensorflow graph for a GAN loss. Takes in a - GANModel tuple, features, labels, and add_summaries as inputs. - optimizer: a tf.Optimizer or a function that returns a tf.Optimizer with no - inputs. - params: An object containing the following parameters: - - batch_size: an int indicating the size of the training batch. - - z_shape: the desired shape of the input z values (not counting batch). - - learning_rate: a scalar or function defining a learning rate applied to - optimizer. - - input_clip: the amount to clip the x training variable by. - - add_summaries: whether or not to add summaries. - - opt_kwargs: optimizer kwargs. - config: tf.RunConfig. Should point model to output dir and should indicate - whether to save checkpoints (to avoid saving checkpoints, set - save_checkpoints_steps to a number larger than the number of train steps). - The model_dir field in the RunConfig should point to a directory WITHOUT - any saved checkpoints. - ckpt_dir: the directory where the model checkpoints live. The checkpoint is - used to warm start the underlying GAN. This should NOT be the same as - config.model_dir. - warmstart_options: boolean, None, or a WarmStartSettings object. If set to - True, uses a default WarmStartSettings object. If set to False or None, - does not use warm start. If using a custom WarmStartSettings object, make - sure that new variables are properly accounted for when reloading the - underlying GAN. Defaults to True. - Returns: - An estimator spec defining a GAN input training estimator. - """ - model_fn = _get_latent_gan_model_fn(generator_fn, discriminator_fn, - loss_fn, optimizer) - - if isinstance(warmstart_options, estimator.WarmStartSettings): - ws = warmstart_options - elif warmstart_options: - # Default WarmStart loads all variable names except INPUT_NAME and - # OPTIMIZER_NAME. - var_regex = '^(?!.*(%s|%s).*)' % (INPUT_NAME, OPTIMIZER_NAME) - ws = estimator.WarmStartSettings(ckpt_to_initialize_from=ckpt_dir, - vars_to_warm_start=var_regex) - else: - ws = None - - if 'opt_kwargs' not in params: - params['opt_kwargs'] = {} - - return estimator.Estimator(model_fn=model_fn, config=config, params=params, - warm_start_from=ws) diff --git a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py deleted file mode 100644 index ac139e532e3..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/latent_gan_estimator_test.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
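The warm-start logic removed above restores every checkpoint variable except the new latent input (INPUT_NAME) and its optimizer slots (OPTIMIZER_NAME) by building a negative-lookahead regex. A minimal sketch of that pattern with the public `tf.estimator.WarmStartSettings` API (TF 1.x style; the checkpoint path below is a placeholder, and the excluded names mirror the module's INPUT_NAME and OPTIMIZER_NAME constants):

```python
import tensorflow.compat.v1 as tf

# Names that should NOT be warm-started; these mirror INPUT_NAME and
# OPTIMIZER_NAME from the deleted module.
exclude = ('new_var_z_input', 'latent_gan_optimizer')

# Negative lookahead: match every variable whose name contains neither
# excluded substring.
vars_to_warm_start = '^(?!.*(%s|%s).*)' % exclude

ws = tf.estimator.WarmStartSettings(
    ckpt_to_initialize_from='/path/to/pretrained/gan/ckpt',  # placeholder path
    vars_to_warm_start=vars_to_warm_start)

# `ws` would then be handed to an Estimator via `warm_start_from=ws`.
```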
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for latent_gan_estimator. - -See g3.tp.tensorflow.contrib.gan.python.estimator.python.latent_gan_estimator. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tempfile -import numpy as np -from tensorflow.contrib.gan.python.estimator.python import latent_gan_estimator -from tensorflow.python.estimator import run_config as run_config -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops.losses import losses -from tensorflow.python.platform import test -from tensorflow.python.training import training - - -class TrainInputEstimatorTest(test.TestCase): - - def test_get_input_training_estimator(self): - """Integration test to make sure the input_training_estimator works.""" - - # Create dummy test input tensors. - true_features = np.reshape(np.random.uniform(size=100), (10, 10)) - true_labels = np.reshape(np.random.uniform(size=100), (5, 20)) - expected_z_output = [[1, -1], [-1, 1]] - - # Fill out required parameters randomly, includes optimizer kwargs. - params = { - 'batch_size': 2, - 'z_shape': [2], - 'learning_rate': 1.0, - 'input_clip': 1.0, - 'add_summaries': False, - 'opt_kwargs': { - 'beta1': 0.1 - } - } - - input_z_shape = [params['batch_size']] + params['z_shape'] - - # Create dummy model functions that represent an underlying GANEstimator and - # the input training wrapper. Make sure that everything is wired up - # correctly in the internals of each dummy function. - def _generator(net, mode): - """The generator function will get the newly created z variable.""" - del mode - self.assertSequenceEqual(net.shape, input_z_shape) - gen_dummy_var = variable_scope.get_variable( - name='generator_dummy_variable', - initializer=array_ops.ones(input_z_shape)) - return net * gen_dummy_var - - def _discriminator(net, condition, mode): - """The discriminator function will get either the z variable or labels.""" - del condition, mode - try: - self.assertSequenceEqual(net.shape, true_labels.shape) - except AssertionError: - self.assertSequenceEqual(net.shape, input_z_shape) - return net - - def _loss(gan_model, features, labels, _): - """Make sure that features and labels are passed in from input.""" - self.assertTrue(np.array_equal(features, true_features)) - self.assertTrue(np.array_equal(labels, true_labels)) - return losses.absolute_difference(expected_z_output, - gan_model.generated_data) - - optimizer = training.AdamOptimizer - - # We are not loading checkpoints, so set the corresponding directory to a - # dummy directories. - tmp_dir = tempfile.mkdtemp() - config = run_config.RunConfig(model_dir=tmp_dir, - save_summary_steps=None, - save_checkpoints_steps=1, - save_checkpoints_secs=None) - - # Get the estimator. Disable warm start so that there is no attempted - # checkpoint reloading. 
- estimator = latent_gan_estimator.get_latent_gan_estimator( - _generator, _discriminator, _loss, optimizer, params, config, tmp_dir, - warmstart_options=None) - - # Train for a few steps. - def dummy_input(): - return true_features, true_labels - estimator.train(input_fn=dummy_input, steps=10) - - # Make sure the generator variables did not change, but the z variables did - # change. - self.assertTrue(np.array_equal( - estimator.get_variable_value('Generator/generator_dummy_variable'), - np.ones(input_z_shape))) - self.assertTrue(np.array_equal( - estimator.get_variable_value('new_var_z_input'), - expected_z_output)) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py deleted file mode 100644 index 341bdf9fbbc..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""`tf.Learn` components for `GANEstimator`.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.estimator.python.stargan_estimator_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = stargan_estimator_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py deleted file mode 100644 index 06a1480c072..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py +++ /dev/null @@ -1,363 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
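The integration test above checks the central trick of the latent GAN estimator: only the new z variable is optimized while the pretrained generator weights stay frozen. That reduces to passing `var_list` to the optimizer. A self-contained toy version of the pattern, in graph-mode TF 1.x style with made-up shapes and a made-up target:

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()  # graph-mode sketch, matching the contrib-era API

# Trainable latent input, clipped to [-1, 1] as in the estimator above.
z = tf.get_variable(
    'z', shape=[2, 4], initializer=tf.truncated_normal_initializer(),
    constraint=lambda x: tf.clip_by_value(x, -1.0, 1.0))

# A stand-in "generator" weight that is deliberately left out of var_list,
# so it never receives gradient updates.
gen_weight = tf.get_variable('gen_weight', initializer=tf.ones([2, 4]))
generated = z * gen_weight

loss = tf.losses.absolute_difference(tf.zeros([2, 4]), generated)

# Only z appears in var_list, so only the latent input moves.
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, var_list=[z])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(200):
    sess.run(train_op)
  print(sess.run(loss))        # close to zero
  print(sess.run(gen_weight))  # still all ones
```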
-# ============================================================================== -"""A TF-GAN-backed StarGAN Estimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import enum - -from tensorflow.contrib.framework.python.ops import variables as variable_lib -from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples -from tensorflow.contrib.gan.python import train as tfgan_train -from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries -from tensorflow.python.estimator import estimator -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import metrics as metrics_lib -from tensorflow.python.ops import variable_scope -from tensorflow.python.util import tf_inspect as inspect - -__all__ = ['StarGANEstimator', 'SummaryType'] - - -class SummaryType(enum.IntEnum): - NONE = 0 - VARIABLES = 1 - IMAGES = 2 - IMAGE_COMPARISON = 3 - - -_summary_type_map = { - SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries, - SummaryType.IMAGES: tfgan_summaries.add_stargan_image_summaries, -} - - -class StarGANEstimator(estimator.Estimator): - """An estimator for Generative Adversarial Networks (GANs). - - This Estimator is backed by TFGAN. The network functions follow the TFGAN API - except for one exception: if either `generator_fn` or `discriminator_fn` have - an argument called `mode`, then the tf.Estimator mode is passed in for that - argument. This helps with operations like batch normalization, which have - different train and evaluation behavior. - - Example: - - ```python - import tensorflow as tf - tfgan = tf.contrib.gan - - # See TFGAN's `train.py` for a description of the generator and - # discriminator API. - def generator_fn(generator_inputs): - ... - return generated_data - - def discriminator_fn(data, conditioning): - ... - return logits - - # Create GAN estimator. - stargan_estimator = tfgan.estimator.StarGANEstimator( - model_dir, - generator_fn=generator_fn, - discriminator_fn=discriminator_fn, - loss_fn=loss_fn, - generator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5), - discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5)) - - # Train estimator. - stargan_estimator.train(train_input_fn, steps) - - # Evaluate resulting estimator. - stargan_estimator.evaluate(eval_input_fn) - - # Generate samples from generator. - stargan_estimator = np.array([ - x for x in stargan_estimator.predict(predict_input_fn)]) - ``` - """ - - def __init__(self, - model_dir=None, - generator_fn=None, - discriminator_fn=None, - loss_fn=None, - generator_optimizer=None, - discriminator_optimizer=None, - get_hooks_fn=None, - get_eval_metric_ops_fn=None, - add_summaries=None, - use_loss_summaries=True, - config=None): - """Initializes a StarGANEstimator instance. - - Args: - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator to - continue training a previously saved model. - generator_fn: A python function that takes a Tensor, Tensor list, or - Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TFGAN` for more details and examples. Additionally, if - it has an argument called `mode`, the Estimator's `mode` will be passed - in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch - normalization. 
- discriminator_fn: A python function that takes the output of - `generator_fn` or real data in the GAN setup, and `input_data`. Outputs - a Tensor in the range [-inf, inf]. See `TFGAN` for more details and - examples. - loss_fn: The loss function on the generator. Takes a `StarGANModel` - namedtuple and return a `GANLoss` namedtuple. - generator_optimizer: The optimizer for generator updates, or a function - that takes no arguments and returns an optimizer. This function will be - called when the default graph is the `StarGANEstimator`'s graph, so - utilities like `tf.contrib.framework.get_or_create_global_step` will - work. - discriminator_optimizer: Same as `generator_optimizer`, but for the - discriminator updates. - get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a - list of hooks. These hooks are run on the generator and discriminator - train ops, and can be used to implement the GAN training scheme. - Defaults to `train.get_sequential_train_hooks()`. - get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a - dict of metric results keyed by name. The output of this function is - passed into `tf.estimator.EstimatorSpec` during evaluation. - add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. - use_loss_summaries: If `True`, add loss summaries. If `False`, does not. - If `None`, uses defaults. - config: `RunConfig` object to configure the runtime settings. - - Raises: - ValueError: If loss functions aren't callable. - ValueError: If `use_loss_summaries` isn't boolean or `None`. - ValueError: If `get_hooks_fn` isn't callable or `None`. - """ - if not callable(loss_fn): - raise ValueError('loss_fn must be callable.') - if use_loss_summaries not in [True, False, None]: - raise ValueError('use_loss_summaries must be True, False or None.') - if get_hooks_fn is not None and not callable(get_hooks_fn): - raise TypeError('get_hooks_fn must be callable.') - - def _model_fn(features, labels, mode): - """StarGANEstimator model function.""" - if mode not in [ - model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, - model_fn_lib.ModeKeys.PREDICT - ]: - raise ValueError('Mode not recognized: %s' % mode) - - if mode == model_fn_lib.ModeKeys.PREDICT: - input_data = features[0] - input_data_domain_label = features[1] - else: - input_data = features # rename inputs for clarity - input_data_domain_label = labels # rename inputs for clarity - - # Make StarGANModel, which encapsulates the GAN model architectures. - gan_model = _get_gan_model(mode, generator_fn, discriminator_fn, - input_data, input_data_domain_label, - add_summaries) - - # Make the EstimatorSpec, which incorporates the StarGANModel, losses, - # eval, metrics, and optimizers (if required). 
- return _get_estimator_spec(mode, gan_model, loss_fn, - get_eval_metric_ops_fn, generator_optimizer, - discriminator_optimizer, get_hooks_fn) - - super(StarGANEstimator, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) - - -def _get_gan_model(mode, - generator_fn, - discriminator_fn, - input_data, - input_data_domain_label, - add_summaries, - generator_scope='Generator'): - """Makes the StarGANModel tuple.""" - if mode == model_fn_lib.ModeKeys.PREDICT: - gan_model = _make_prediction_gan_model(input_data, input_data_domain_label, - generator_fn, generator_scope) - else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL - gan_model = _make_gan_model(generator_fn, discriminator_fn, input_data, - input_data_domain_label, generator_scope, - add_summaries, mode) - - return gan_model - - -def _get_estimator_spec(mode, - gan_model, - loss_fn, - get_eval_metric_ops_fn, - generator_optimizer, - discriminator_optimizer, - get_hooks_fn=None): - """Get the EstimatorSpec for the current mode.""" - if mode == model_fn_lib.ModeKeys.PREDICT: - estimator_spec = model_fn_lib.EstimatorSpec( - mode=mode, predictions=gan_model.generated_data) - else: - gan_loss = loss_fn(gan_model) - if mode == model_fn_lib.ModeKeys.EVAL: - estimator_spec = _get_eval_estimator_spec(gan_model, gan_loss, - get_eval_metric_ops_fn) - else: # model_fn_lib.ModeKeys.TRAIN: - gopt = ( - generator_optimizer() - if callable(generator_optimizer) else generator_optimizer) - dopt = ( - discriminator_optimizer() - if callable(discriminator_optimizer) else discriminator_optimizer) - get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() - estimator_spec = _get_train_estimator_spec(gan_model, gan_loss, gopt, - dopt, get_hooks_fn) - - return estimator_spec - - -def _make_gan_model(generator_fn, discriminator_fn, input_data, - input_data_domain_label, generator_scope, add_summaries, - mode): - """Construct a `StarGANModel`, and optionally pass in `mode`.""" - # If network functions have an argument `mode`, pass mode to it. - if 'mode' in inspect.getargspec(generator_fn).args: - generator_fn = functools.partial(generator_fn, mode=mode) - if 'mode' in inspect.getargspec(discriminator_fn).args: - discriminator_fn = functools.partial(discriminator_fn, mode=mode) - gan_model = tfgan_train.stargan_model( - generator_fn, - discriminator_fn, - input_data, - input_data_domain_label, - generator_scope=generator_scope) - if add_summaries: - if not isinstance(add_summaries, (tuple, list)): - add_summaries = [add_summaries] - with ops.name_scope(None): - for summary_type in add_summaries: - _summary_type_map[summary_type](gan_model) - - return gan_model - - -def _make_prediction_gan_model(input_data, input_data_domain_label, - generator_fn, generator_scope): - """Make a `StarGANModel` from just the generator.""" - # If `generator_fn` has an argument `mode`, pass mode to it. 
- if 'mode' in inspect.getargspec(generator_fn).args: - generator_fn = functools.partial( - generator_fn, mode=model_fn_lib.ModeKeys.PREDICT) - with variable_scope.variable_scope(generator_scope) as gen_scope: - # pylint:disable=protected-access - input_data = tfgan_train._convert_tensor_or_l_or_d(input_data) - input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d( - input_data_domain_label) - # pylint:enable=protected-access - generated_data = generator_fn(input_data, input_data_domain_label) - generator_variables = variable_lib.get_trainable_variables(gen_scope) - - return tfgan_tuples.StarGANModel( - input_data=input_data, - input_data_domain_label=None, - generated_data=generated_data, - generated_data_domain_target=input_data_domain_label, - reconstructed_data=None, - discriminator_input_data_source_predication=None, - discriminator_generated_data_source_predication=None, - discriminator_input_data_domain_predication=None, - discriminator_generated_data_domain_predication=None, - generator_variables=generator_variables, - generator_scope=generator_scope, - generator_fn=generator_fn, - discriminator_variables=None, - discriminator_scope=None, - discriminator_fn=None) - - -def _get_eval_estimator_spec(gan_model, - gan_loss, - get_eval_metric_ops_fn=None, - name=None): - """Return an EstimatorSpec for the eval case.""" - scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss - with ops.name_scope(None, 'metrics', - [gan_loss.generator_loss, gan_loss.discriminator_loss]): - - def _summary_key(head_name, val): - return '%s/%s' % (val, head_name) if head_name else val - - eval_metric_ops = { - _summary_key(name, 'generator_loss'): - metrics_lib.mean(gan_loss.generator_loss), - _summary_key(name, 'discriminator_loss'): - metrics_lib.mean(gan_loss.discriminator_loss) - } - if get_eval_metric_ops_fn is not None: - custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model) - if not isinstance(custom_eval_metric_ops, dict): - raise TypeError('get_eval_metric_ops_fn must return a dict, ' - 'received: {}'.format(custom_eval_metric_ops)) - eval_metric_ops.update(custom_eval_metric_ops) - return model_fn_lib.EstimatorSpec( - mode=model_fn_lib.ModeKeys.EVAL, - predictions=gan_model.generated_data, - loss=scalar_loss, - eval_metric_ops=eval_metric_ops) - - -def _get_train_estimator_spec(gan_model, - gan_loss, - generator_optimizer, - discriminator_optimizer, - get_hooks_fn, - train_op_fn=tfgan_train.gan_train_ops): - """Return an EstimatorSpec for the train case.""" - scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss - train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer, - discriminator_optimizer) - training_hooks = get_hooks_fn(train_ops) - return model_fn_lib.EstimatorSpec( - loss=scalar_loss, - mode=model_fn_lib.ModeKeys.TRAIN, - train_op=train_ops.global_step_inc_op, - training_hooks=training_hooks) - - -def stargan_prediction_input_fn_wrapper(fn): - """StarGAN Estimator prediction input_fn wrapper. - - Since estimator will disregard the "label" variable pass to the model, we will - use a wrapper to pack the (feature, label) tuple as feature passed to the - model. - - Args: - fn: input_fn for the prediction. - - Returns: - A tuple ((feature, label), None) where the second element is the dummy label - to be disregarded and the first element is the true input to the estimator. 
- """ - - def new_fn(): - return fn(), None - - return new_fn diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py deleted file mode 100644 index 0fcd1b7924e..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py +++ /dev/null @@ -1,306 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for TF-GAN's stargan_estimator.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import shutil -import tempfile - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib import layers -from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples -from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl as estimator -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.inputs import numpy_io -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import metrics as metrics_lib -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import test -from tensorflow.python.summary.writer import writer_cache -from tensorflow.python.training import learning_rate_decay -from tensorflow.python.training import training -from tensorflow.python.training import training_util - - -def dummy_generator_fn(input_data, input_data_domain_label, mode): - del input_data_domain_label, mode - - return variable_scope.get_variable('dummy_g', initializer=0.5) * input_data - - -def dummy_discriminator_fn(input_data, num_domains, mode): - del mode - - hidden = layers.flatten(input_data) - output_src = math_ops.reduce_mean(hidden, axis=1) - output_cls = layers.fully_connected( - inputs=hidden, num_outputs=num_domains, scope='debug') - - return output_src, output_cls - - -class StarGetGANModelTest(test.TestCase, parameterized.TestCase): - """Tests that `StarGetGANModel` produces the correct model.""" - - @parameterized.named_parameters(('train', model_fn_lib.ModeKeys.TRAIN), - ('eval', model_fn_lib.ModeKeys.EVAL), - ('predict', model_fn_lib.ModeKeys.PREDICT)) - def test_get_gan_model(self, mode): - with ops.Graph().as_default(): - input_data = array_ops.ones([6, 4, 4, 3]) - input_data_domain_label = array_ops.one_hot([0] * 6, 5) - gan_model = estimator._get_gan_model( - mode, - dummy_generator_fn, - dummy_discriminator_fn, - input_data, - input_data_domain_label, - add_summaries=False) - - self.assertEqual(input_data, gan_model.input_data) - self.assertIsNotNone(gan_model.generated_data) - self.assertIsNotNone(gan_model.generated_data_domain_target) - self.assertLen(gan_model.generator_variables, 1) - 
self.assertIsNotNone(gan_model.generator_scope) - self.assertIsNotNone(gan_model.generator_fn) - if mode == model_fn_lib.ModeKeys.PREDICT: - self.assertIsNone(gan_model.input_data_domain_label) - self.assertEqual(input_data_domain_label, - gan_model.generated_data_domain_target) - self.assertIsNone(gan_model.reconstructed_data) - self.assertIsNone(gan_model.discriminator_input_data_source_predication) - self.assertIsNone( - gan_model.discriminator_generated_data_source_predication) - self.assertIsNone(gan_model.discriminator_input_data_domain_predication) - self.assertIsNone( - gan_model.discriminator_generated_data_domain_predication) - self.assertIsNone(gan_model.discriminator_variables) - self.assertIsNone(gan_model.discriminator_scope) - self.assertIsNone(gan_model.discriminator_fn) - else: - self.assertEqual(input_data_domain_label, - gan_model.input_data_domain_label) - self.assertIsNotNone(gan_model.reconstructed_data.shape) - self.assertIsNotNone( - gan_model.discriminator_input_data_source_predication) - self.assertIsNotNone( - gan_model.discriminator_generated_data_source_predication) - self.assertIsNotNone( - gan_model.discriminator_input_data_domain_predication) - self.assertIsNotNone( - gan_model.discriminator_generated_data_domain_predication) - self.assertLen(gan_model.discriminator_variables, 2) # 1 FC layer - self.assertIsNotNone(gan_model.discriminator_scope) - self.assertIsNotNone(gan_model.discriminator_fn) - - -def get_dummy_gan_model(): - """Similar to get_gan_model().""" - # TODO(joelshor): Find a better way of creating a variable scope. - with variable_scope.variable_scope('generator') as gen_scope: - gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) - with variable_scope.variable_scope('discriminator') as dis_scope: - dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) - return tfgan_tuples.StarGANModel( - input_data=array_ops.ones([1, 2, 2, 3]), - input_data_domain_label=array_ops.ones([1, 2]), - generated_data=array_ops.ones([1, 2, 2, 3]), - generated_data_domain_target=array_ops.ones([1, 2]), - reconstructed_data=array_ops.ones([1, 2, 2, 3]), - discriminator_input_data_source_predication=array_ops.ones([1]) * dis_var, - discriminator_generated_data_source_predication=array_ops.ones( - [1]) * gen_var * dis_var, - discriminator_input_data_domain_predication=array_ops.ones([1, 2 - ]) * dis_var, - discriminator_generated_data_domain_predication=array_ops.ones([1, 2]) * - gen_var * dis_var, - generator_variables=[gen_var], - generator_scope=gen_scope, - generator_fn=None, - discriminator_variables=[dis_var], - discriminator_scope=dis_scope, - discriminator_fn=None) - - -def dummy_loss_fn(gan_model): - loss = math_ops.reduce_sum( - gan_model.discriminator_input_data_domain_predication - - gan_model.discriminator_generated_data_domain_predication) - loss += math_ops.reduce_sum(gan_model.input_data - gan_model.generated_data) - return tfgan_tuples.GANLoss(loss, loss) - - -def get_metrics(gan_model): - return { - 'mse_custom_metric': - metrics_lib.mean_squared_error(gan_model.input_data, - gan_model.generated_data) - } - - -class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): - """Tests that the EstimatorSpec is constructed appropriately.""" - - @classmethod - def setUpClass(cls): - super(GetEstimatorSpecTest, cls).setUpClass() - cls._generator_optimizer = training.GradientDescentOptimizer(1.0) - cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) - - @parameterized.named_parameters(('train', 
model_fn_lib.ModeKeys.TRAIN), - ('eval', model_fn_lib.ModeKeys.EVAL), - ('predict', model_fn_lib.ModeKeys.PREDICT)) - def test_get_estimator_spec(self, mode): - with ops.Graph().as_default(): - self._gan_model = get_dummy_gan_model() - spec = estimator._get_estimator_spec( - mode, - self._gan_model, - loss_fn=dummy_loss_fn, - get_eval_metric_ops_fn=get_metrics, - generator_optimizer=self._generator_optimizer, - discriminator_optimizer=self._discriminator_optimizer) - - self.assertEqual(mode, spec.mode) - if mode == model_fn_lib.ModeKeys.PREDICT: - self.assertEqual(self._gan_model.generated_data, spec.predictions) - elif mode == model_fn_lib.ModeKeys.TRAIN: - self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar - self.assertIsNotNone(spec.train_op) - self.assertIsNotNone(spec.training_hooks) - elif mode == model_fn_lib.ModeKeys.EVAL: - self.assertEqual(self._gan_model.generated_data, spec.predictions) - self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar - self.assertIsNotNone(spec.eval_metric_ops) - - -# TODO(joelshor): Add pandas test. -class StarGANEstimatorIntegrationTest(test.TestCase): - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) - - def _test_complete_flow(self, - train_input_fn, - eval_input_fn, - predict_input_fn, - prediction_size, - lr_decay=False): - - def make_opt(): - gstep = training_util.get_or_create_global_step() - lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) - return training.GradientDescentOptimizer(lr) - - gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) - dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) - est = estimator.StarGANEstimator( - generator_fn=dummy_generator_fn, - discriminator_fn=dummy_discriminator_fn, - loss_fn=dummy_loss_fn, - generator_optimizer=gopt, - discriminator_optimizer=dopt, - get_eval_metric_ops_fn=get_metrics, - model_dir=self._model_dir) - - # TRAIN - num_steps = 10 - est.train(train_input_fn, steps=num_steps) - - # EVALUTE - scores = est.evaluate(eval_input_fn) - self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) - self.assertIn('loss', scores) - self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], - scores['loss']) - self.assertIn('mse_custom_metric', scores) - - # PREDICT - predictions = np.array([x for x in est.predict(predict_input_fn)]) - - self.assertAllEqual(prediction_size, predictions.shape) - - @staticmethod - def _numpy_input_fn_wrapper(numpy_input_fn, batch_size, label_size): - """Wrapper to remove the dictionary in numpy_input_fn. - - NOTE: - We create the domain_label here because the model expect a fully define - batch_size from the input. 
- - Args: - numpy_input_fn: input_fn created from numpy_io - batch_size: (int) number of items for each batch - label_size: (int) number of domains - - Returns: - a new input_fn - """ - - def new_input_fn(): - features = numpy_input_fn() - return features['x'], array_ops.one_hot([0] * batch_size, label_size) - - return new_input_fn - - def test_numpy_input_fn(self): - """Tests complete flow with numpy_input_fn.""" - batch_size = 5 - img_size = 8 - channel_size = 3 - label_size = 3 - image_data = np.zeros( - [batch_size, img_size, img_size, channel_size], dtype=np.float32) - train_input_fn = numpy_io.numpy_input_fn( - x={'x': image_data}, - batch_size=batch_size, - num_epochs=None, - shuffle=True) - eval_input_fn = numpy_io.numpy_input_fn( - x={'x': image_data}, batch_size=batch_size, shuffle=False) - predict_input_fn = numpy_io.numpy_input_fn( - x={'x': image_data}, shuffle=False) - - train_input_fn = self._numpy_input_fn_wrapper(train_input_fn, batch_size, - label_size) - eval_input_fn = self._numpy_input_fn_wrapper(eval_input_fn, batch_size, - label_size) - predict_input_fn = self._numpy_input_fn_wrapper(predict_input_fn, - batch_size, label_size) - - predict_input_fn = estimator.stargan_prediction_input_fn_wrapper( - predict_input_fn) - - self._test_complete_flow( - train_input_fn=train_input_fn, - eval_input_fn=eval_input_fn, - predict_input_fn=predict_input_fn, - prediction_size=[batch_size, img_size, img_size, channel_size]) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py deleted file mode 100644 index deb381f7be3..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""`tf.Learn` components for `TPUGANEstimator`.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.estimator.python.tpu_gan_estimator_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = tpu_gan_estimator_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py deleted file mode 100644 index 8ed64e869a0..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_impl.py +++ /dev/null @@ -1,423 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
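The `_numpy_input_fn_wrapper` above exists because `numpy_input_fn` yields a feature dict, while the StarGAN model function expects a `(tensor, one_hot_label)` pair with a fully defined batch size. A small sketch of the same unwrapping, with illustrative names and shapes:

```python
import numpy as np
import tensorflow.compat.v1 as tf

batch_size, label_size = 5, 3
images = np.zeros([batch_size, 8, 8, 3], dtype=np.float32)

# numpy_input_fn returns an input_fn that yields a {'x': tensor} dict.
base_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': images}, batch_size=batch_size, num_epochs=None, shuffle=True)

def wrapped_input_fn():
  # Unpack the dict and synthesize a one-hot domain label whose batch
  # dimension is fully defined, as the wrapper above does.
  features = base_input_fn()
  return features['x'], tf.one_hot([0] * batch_size, label_size)
```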
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A TF-GAN-backed GAN Estimator that works on TPU.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples -from tensorflow.contrib.gan.python import train as tfgan_train -from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as gan_estimator_lib -from tensorflow.contrib.tpu.python.tpu import tpu_estimator -from tensorflow.contrib.tpu.python.tpu import tpu_optimizer -from tensorflow.contrib.training.python.training import training -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.framework import ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import metrics as metrics_lib -from tensorflow.python.ops.losses import losses - -__all__ = [ - 'TPUGANEstimator', -] - - -class TPUGANEstimator(tpu_estimator.TPUEstimator): - """An estimator for Generative Adversarial Networks (GANs) on TPU. - - This Estimator is backed by TFGAN. It is similar to `tfgan.GANEstimator`, - but works on TPU. - - Example: - - ```python - import tensorflow as tf - tfgan = tf.contrib.gan - - # See TFGAN's `train.py` for a description of the generator and - # discriminator API. - def generator_fn(generator_inputs): - ... - return generated_data - - def discriminator_fn(data, conditioning): - ... - return logits - - # Create GAN estimator. - config = tpu_config.RunConfig(model_dir='/my/dir') - gan_estimator = tfgan.estimator.TPUGANEstimator( - generator_fn=generator_fn, - discriminator_fn=discriminator_fn, - generator_loss_fn=tfgan.losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, - generator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5), - discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(0.1, 0.5), - train_batch_size=4, - config=config) - - # Train estimator. - gan_estimator.train(train_input_fn, train_steps) - - # Evaluate resulting estimator. - gan_estimator.evaluate(eval_input_fn, eval_steps) - - # Generate samples from generator. - predictions = np.array([ - x['generated_data'] for x in gan_estimator.predict(predict_input_fn)]) - ``` - """ - - def __init__(self, - # Arguments to construct the `model_fn`. - generator_fn=None, - discriminator_fn=None, - generator_loss_fn=None, - discriminator_loss_fn=None, - generator_optimizer=None, - discriminator_optimizer=None, - get_eval_metric_ops_fn=None, - add_summaries=None, - joint_train=False, - gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1), - # TPUEstimator options. - model_dir=None, - config=None, - params=None, - use_tpu=True, - train_batch_size=None, - eval_batch_size=None, - predict_batch_size=None, - batch_axis=None, - eval_on_tpu=True, - export_to_tpu=True, - warm_start_from=None): - """Initializes a TPUGANEstimator instance. 
- - Args: - generator_fn: A python function that takes a Tensor, Tensor list, or - Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TFGAN` for more details and examples. Additionally, if - it has an argument called `mode`, the Estimator's `mode` will be passed - in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch - normalization. - discriminator_fn: A python function that takes the output of - `generator_fn` or real data in the GAN setup, and `generator_inputs`. - Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details - and examples. - generator_loss_fn: The loss function on the generator. Takes a `GANModel` - tuple. - discriminator_loss_fn: The loss function on the discriminator. Takes a - `GANModel` tuple. - generator_optimizer: The optimizer for generator updates, or a function - that takes no arguments and returns an optimizer. This function will - be called when the default graph is the `GANEstimator`'s graph, so - utilities like `tf.contrib.framework.get_or_create_global_step` will - work. - discriminator_optimizer: Same as `generator_optimizer`, but for the - discriminator updates. - get_eval_metric_ops_fn: A function that takes a list of arguments and - returns a dict of metric results keyed by name. The output of this - function is passed into `tf.estimator.EstimatorSpec` during evaluation. - The arguments must be: - * generator_inputs - * generated_data - * real_data - * discriminator_real_outputs - * discriminator_gen_outputs - add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. - This is ignored for jobs that run on TPU, such as the train job if - `use_tpu` is `True` or the eval job if `eval_on_tpu` is `True`. - joint_train: A Python boolean. If `True`, jointly train the generator and - the discriminator. If `False`, sequentially train them. See `train.py` - in TFGAN for more details on the differences between the two GAN - training methods. - gan_train_steps: A `tfgan.GANTrainSteps` named tuple describing the ratio - of generator to discriminator steps. For now, only supports 1:1 - training. - model_dir: Same as `TPUEstimator`: Directory to save model parameters, - graph and etc. This can also be used to load checkpoints from the - directory into a estimator to continue training a previously saved - model. If `None`, the model_dir in `config` will be used if set. If both - are set, they must be same. If both are `None`, a temporary directory - will be used. - config: Same as `TPUEstimator`: An `tpu_config.RunConfig` configuration - object. Cannot be `None`. - params: Same as `TPUEstimator`: An optional `dict` of hyper parameters - that will be passed into `input_fn` and `model_fn`. Keys are names of - parameters, values are basic python types. There are reserved keys for - `TPUEstimator`, including 'batch_size'. - use_tpu: Same as `TPUEstimator`: A bool indicating whether TPU support is - enabled. Currently, TPU training and evaluation respect this bit, but - eval_on_tpu can override execution of eval. See below. Predict still - happens on CPU. - train_batch_size: Same as `TPUEstimator`: An int representing the global - training batch size. TPUEstimator transforms this global batch size to a - per-shard batch size, as params['batch_size'], when calling `input_fn` - and `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be - divisible by total number of replicas. - eval_batch_size: Same as `TPUEstimator`: An int representing evaluation - batch size. 
Must be divisible by total number of replicas. - predict_batch_size: Same as `TPUEstimator`: An int representing the - prediction batch size. Must be divisible by total number of replicas. - batch_axis: Same as `TPUEstimator`: A python tuple of int values - describing how each tensor produced by the Estimator `input_fn` should - be split across the TPU compute shards. For example, if your input_fn - produced (images, labels) where the images tensor is in `HWCN` format, - your shard dimensions would be [3, 0], where 3 corresponds to the `N` - dimension of your images Tensor, and 0 corresponds to the dimension - along which to split the labels to match up with the corresponding - images. If None is supplied, and per_host_input_for_training is True, - batches will be sharded based on the major dimension. If - tpu_config.per_host_input_for_training is False or `PER_HOST_V2`, - batch_axis is ignored. - eval_on_tpu: Same as `TPUEstimator`: If False, evaluation runs on CPU or - GPU. In this case, the model_fn must return `EstimatorSpec` when called - with `mode` as `EVAL`. - export_to_tpu: Same as `TPUEstimator`: If True, `export_savedmodel()` - exports a metagraph for serving on TPU besides the one on CPU. - warm_start_from: Same as `TPUEstimator`: Optional string filepath to a - checkpoint or SavedModel to warm-start from, or a - `tf.estimator.WarmStartSettings` object to fully configure - warm-starting. If the string filepath is provided instead of a - `WarmStartSettings`, then all variables are warm-started, and it is - assumed that vocabularies and Tensor names are unchanged. - - Raises: - ValueError: If loss functions aren't callable. - ValueError: If `gan_train_steps` isn't a `tfgan_tuples.GANTrainSteps` - tuple. - ValueError: If `gan_train_steps` isn't 1:1 training. - """ - if not callable(generator_loss_fn): - raise ValueError('generator_loss_fn must be callable.') - if not callable(discriminator_loss_fn): - raise ValueError('discriminator_loss_fn must be callable.') - if not isinstance(gan_train_steps, tfgan_tuples.GANTrainSteps): - raise ValueError( - '`gan_train_steps` must be `tfgan_tuples.GANTrainSteps`. Instead, ' - 'was type: %s' % type(gan_train_steps)) - if (gan_train_steps.generator_train_steps != 1 or - gan_train_steps.discriminator_train_steps != 1): - raise ValueError('Estimator currently only supports 1:1 training.') - - if use_tpu: - generator_optimizer = _maybe_make_cross_shard_optimizer( - generator_optimizer) - discriminator_optimizer = _maybe_make_cross_shard_optimizer( - discriminator_optimizer) - - def _model_fn(features, labels, mode, params): - """GANEstimator model function.""" - del params # unused - if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, - model_fn_lib.ModeKeys.PREDICT]: - raise ValueError('Mode not recognized: %s' % mode) - real_data = labels # rename inputs for clarity - generator_inputs = features # rename inputs for clarity - - # Make GANModel, which encapsulates the GAN model architectures. - # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then - # remove `add_summaries` logic below. - is_on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu) - gan_model = gan_estimator_lib._get_gan_model( # pylint:disable=protected-access - mode, generator_fn, discriminator_fn, real_data, generator_inputs, - add_summaries=None if is_on_tpu else add_summaries) - - # Make the TPUEstimatorSpec, which incorporates the GANModel, losses, eval - # metrics, and optimizers (if required). 
- estimator_spec = _get_estimator_spec( - mode, gan_model, generator_loss_fn, discriminator_loss_fn, - get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - joint_train, is_on_tpu, gan_train_steps) - assert isinstance(estimator_spec, tpu_estimator.TPUEstimatorSpec) - return estimator_spec - - super(TPUGANEstimator, self).__init__( - model_fn=_model_fn, - model_dir=model_dir, - config=config, - params=params, - use_tpu=use_tpu, - train_batch_size=train_batch_size, - eval_batch_size=eval_batch_size, - predict_batch_size=predict_batch_size, - batch_axis=batch_axis, - eval_on_tpu=eval_on_tpu, - export_to_tpu=export_to_tpu, - warm_start_from=warm_start_from) - - -def _is_on_tpu(mode, use_tpu, eval_on_tpu): - if mode == model_fn_lib.ModeKeys.TRAIN: - return use_tpu - elif mode == model_fn_lib.ModeKeys.EVAL: - return eval_on_tpu - else: - return False - - -def _get_estimator_spec( - mode, gan_model, generator_loss_fn, discriminator_loss_fn, - get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - joint_train, is_on_tpu, gan_train_steps): - """Get the TPUEstimatorSpec for the current mode.""" - if mode == model_fn_lib.ModeKeys.PREDICT: - estimator_spec = tpu_estimator.TPUEstimatorSpec( - mode=mode, predictions={'generated_data': gan_model.generated_data}) - elif mode == model_fn_lib.ModeKeys.EVAL: - gan_loss = tfgan_tuples.GANLoss( - generator_loss=generator_loss_fn( - gan_model, add_summaries=not is_on_tpu), - discriminator_loss=discriminator_loss_fn( - gan_model, add_summaries=not is_on_tpu)) - # Eval losses for metrics must preserve batch dimension. - gan_loss_no_reduction = tfgan_tuples.GANLoss( - generator_loss=generator_loss_fn( - gan_model, add_summaries=False, reduction=losses.Reduction.NONE), - discriminator_loss=discriminator_loss_fn( - gan_model, add_summaries=False, reduction=losses.Reduction.NONE)) - estimator_spec = _get_eval_estimator_spec( - gan_model, gan_loss, gan_loss_no_reduction, get_eval_metric_ops_fn) - else: # model_fn_lib.ModeKeys.TRAIN: - gan_loss = tfgan_tuples.GANLoss( - generator_loss=generator_loss_fn( - gan_model, add_summaries=not is_on_tpu), - discriminator_loss=discriminator_loss_fn( - gan_model, add_summaries=not is_on_tpu)) - - # Construct optimizers if arguments were callable. For TPUs, they must be - # `CrossShardOptimizer`. - g_callable = callable(generator_optimizer) - gopt = generator_optimizer() if g_callable else generator_optimizer - d_callable = callable(discriminator_optimizer) - dopt = discriminator_optimizer() if d_callable else discriminator_optimizer - - estimator_spec = _get_train_estimator_spec( - gan_model, gan_loss, gopt, dopt, joint_train, gan_train_steps) - - return estimator_spec - - -def _get_eval_estimator_spec(gan_model, gan_loss, gan_loss_no_reduction, - get_eval_metric_ops_fn): - """Return an TPUEstimatorSpec for the eval case.""" - # Make the metric function and tensor names. 
- if get_eval_metric_ops_fn is not None: - def metric_fn( - generator_inputs, generated_data, real_data, discriminator_real_outputs, - discriminator_gen_outputs, generator_loss, discriminator_loss): - """`metric_fn` used in TPUEstimator to calculate metrics.""" - eval_metric_ops = { - 'generator_loss': metrics_lib.mean(generator_loss), - 'discriminator_loss': metrics_lib.mean(discriminator_loss), - } - custom_eval_metric_ops = get_eval_metric_ops_fn( - generator_inputs, generated_data, real_data, - discriminator_real_outputs, discriminator_gen_outputs) - if not isinstance(custom_eval_metric_ops, dict): - raise TypeError('`get_eval_metric_ops_fn` must return a dict, ' - 'received: {}'.format(custom_eval_metric_ops)) - eval_metric_ops.update(custom_eval_metric_ops) - return eval_metric_ops - tensors = { - 'generator_loss': gan_loss_no_reduction.generator_loss, - 'discriminator_loss': gan_loss_no_reduction.discriminator_loss, - 'generator_inputs': gan_model.generator_inputs, - 'generated_data': gan_model.generated_data, - 'real_data': gan_model.real_data, - 'discriminator_real_outputs': gan_model.discriminator_real_outputs, - 'discriminator_gen_outputs': gan_model.discriminator_gen_outputs, - } - else: - def metric_fn(generator_loss, discriminator_loss): - return { - 'generator_loss': metrics_lib.mean(generator_loss), - 'discriminator_loss': metrics_lib.mean(discriminator_loss), - } - tensors = { - 'generator_loss': gan_loss_no_reduction.generator_loss, - 'discriminator_loss': gan_loss_no_reduction.discriminator_loss, - } - - scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss - return tpu_estimator.TPUEstimatorSpec( - mode=model_fn_lib.ModeKeys.EVAL, - predictions=gan_model.generated_data, - loss=scalar_loss, - eval_metrics=(metric_fn, tensors)) - - -def _get_train_estimator_spec( - gan_model, gan_loss, generator_optimizer, discriminator_optimizer, - joint_train, gan_train_steps): - """Return a TPUEstimatorSpec for the train case.""" - scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss - - # Get generator and discriminator update ops. We split them so that update - # ops aren't accidentally run multiple times. For now, throw an error if - # there are update ops that aren't associated with either the generator or - # the discriminator. Might modify the `kwargs` dictionary. - gen_update_ops, dis_update_ops = tfgan_train._get_update_ops( # pylint:disable=protected-access - {}, gan_model.generator_scope.name, gan_model.discriminator_scope.name) - - def gen_train_op(): - with ops.name_scope('generator_train'): - return training.create_train_op( - total_loss=gan_loss.generator_loss, - optimizer=generator_optimizer, - variables_to_train=gan_model.generator_variables, - update_ops=gen_update_ops) - def dis_train_op(): - with ops.name_scope('discriminator_train'): - return training.create_train_op( - total_loss=gan_loss.discriminator_loss, - optimizer=discriminator_optimizer, - variables_to_train=gan_model.discriminator_variables, - update_ops=dis_update_ops) - - # Either optimize the generator and discriminator sequentially or jointly. - tpu_train_op = _combine_train_ops(gen_train_op, dis_train_op, joint_train, - gan_train_steps) - - return tpu_estimator.TPUEstimatorSpec( - loss=scalar_loss, - mode=model_fn_lib.ModeKeys.TRAIN, - train_op=tpu_train_op) - - -# TODO(joelshor): Add support for multiple D / G steps. 
-def _combine_train_ops(gen_train_op, dis_train_op, joint_train, - gan_train_steps): - """Combine generator and discriminator train ops into a single op.""" - del gan_train_steps - if joint_train: - tpu_train_op = control_flow_ops.group(gen_train_op(), dis_train_op(), - name='joint_train') - else: - with ops.control_dependencies([dis_train_op()]): - tpu_train_op = gen_train_op() - - return tpu_train_op - - -def _maybe_make_cross_shard_optimizer(opt): - if callable(opt): - if not isinstance(opt(), tpu_optimizer.CrossShardOptimizer): - return lambda: tpu_optimizer.CrossShardOptimizer(opt()) - elif not isinstance(opt, tpu_optimizer.CrossShardOptimizer): - return tpu_optimizer.CrossShardOptimizer(opt) - return opt diff --git a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py deleted file mode 100644 index baf2c28df4b..00000000000 --- a/tensorflow/contrib/gan/python/estimator/python/tpu_gan_estimator_test.py +++ /dev/null @@ -1,318 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for TF-GAN's TPU Estimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import shutil -import tempfile - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib import layers -from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples -from tensorflow.contrib.gan.python.estimator.python import tpu_gan_estimator_impl as estimator -from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses -from tensorflow.contrib.tpu.python.tpu import tpu_config -from tensorflow.contrib.tpu.python.tpu import tpu_estimator -from tensorflow.contrib.tpu.python.tpu import tpu_optimizer -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.estimator import WarmStartSettings -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework.errors_impl import NotFoundError -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import metrics as metrics_lib -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import flags -from tensorflow.python.platform import test -from tensorflow.python.summary.writer import writer_cache -from tensorflow.python.training import learning_rate_decay -from tensorflow.python.training import training -from tensorflow.python.training import training_util - -FLAGS = flags.FLAGS - -flags.DEFINE_bool('use_tpu', False, 'Whether to run test on TPU or not.') - - -def generator_fn(noise, mode): - del mode - return layers.fully_connected(noise, tensor_shape.dimension_value( - noise.shape[1])) - - -def discriminator_fn(data, 
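`_combine_train_ops` above takes the two update ops as callables so that, in the sequential case, the generator op is created inside the control-dependency scope and therefore runs only after the discriminator update. A runnable toy version of that choice, using counters in place of the real update ops:

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

g_steps = tf.get_variable('g_steps', initializer=0)
d_steps = tf.get_variable('d_steps', initializer=0)

def gen_train_op():
  return tf.assign_add(g_steps, 1)

def dis_train_op():
  return tf.assign_add(d_steps, 1)

joint_train = False
if joint_train:
  # Joint: both updates run in the same step, with no ordering between them.
  train_op = tf.group(gen_train_op(), dis_train_op(), name='joint_train')
else:
  # Sequential: create the generator op under a control dependency so the
  # discriminator update always finishes first.
  with tf.control_dependencies([dis_train_op()]):
    train_op = gen_train_op()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
  print(sess.run([g_steps, d_steps]))  # [1, 1] either way
```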
unused_conditioning, mode): - del unused_conditioning, mode - return layers.fully_connected(data, 1) - - -def get_dummy_gan_model(): - # TODO(joelshor): Find a better way of creating a variable scope. - with variable_scope.variable_scope('generator') as gen_scope: - gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) - with variable_scope.variable_scope('discriminator') as dis_scope: - dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) - return tfgan_tuples.GANModel( - generator_inputs=None, - generated_data=array_ops.ones([3, 4]), - generator_variables=[gen_var], - generator_scope=gen_scope, - generator_fn=None, - real_data=array_ops.zeros([3, 4]), - discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, - discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, - discriminator_variables=[dis_var], - discriminator_scope=dis_scope, - discriminator_fn=None) - - -def get_metrics(generator_inputs, generated_data, real_data, - discriminator_real_outputs, discriminator_gen_outputs): - del generator_inputs, discriminator_real_outputs, discriminator_gen_outputs - return { - 'mse_custom_metric': metrics_lib.mean_squared_error( - real_data, generated_data) - } - - -class GetTPUEstimatorSpecTest(test.TestCase, parameterized.TestCase): - """Tests that the EstimatorSpec is constructed appropriately.""" - - @classmethod - def setUpClass(cls): - super(GetTPUEstimatorSpecTest, cls).setUpClass() - cls._generator_optimizer = tpu_optimizer.CrossShardOptimizer( - training.GradientDescentOptimizer(1.0)) - cls._discriminator_optimizer = tpu_optimizer.CrossShardOptimizer( - training.GradientDescentOptimizer(1.0)) - - @parameterized.named_parameters( - ('joint_train', model_fn_lib.ModeKeys.TRAIN, True), - ('train_sequential', model_fn_lib.ModeKeys.TRAIN, False), - ('eval', model_fn_lib.ModeKeys.EVAL, None), - ('predict', model_fn_lib.ModeKeys.PREDICT, None)) - def test_get_estimator_spec(self, mode, joint_train): - with ops.Graph().as_default(): - self._gan_model = get_dummy_gan_model() - spec = estimator._get_estimator_spec( - mode, - self._gan_model, - generator_loss_fn=losses.wasserstein_generator_loss, - discriminator_loss_fn=losses.wasserstein_discriminator_loss, - get_eval_metric_ops_fn=get_metrics, - generator_optimizer=self._generator_optimizer, - discriminator_optimizer=self._discriminator_optimizer, - joint_train=joint_train, - is_on_tpu=FLAGS.use_tpu, - gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1)) - - self.assertIsInstance(spec, tpu_estimator.TPUEstimatorSpec) - self.assertEqual(mode, spec.mode) - if mode == model_fn_lib.ModeKeys.PREDICT: - self.assertEqual({'generated_data': self._gan_model.generated_data}, - spec.predictions) - elif mode == model_fn_lib.ModeKeys.TRAIN: - self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar - self.assertIsNotNone(spec.train_op) - self.assertIsNotNone(spec.training_hooks) - elif mode == model_fn_lib.ModeKeys.EVAL: - self.assertEqual(self._gan_model.generated_data, spec.predictions) - self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar - self.assertIsNotNone(spec.eval_metrics) - - -class TPUGANEstimatorIntegrationTest(test.TestCase, parameterized.TestCase): - - def setUp(self): - super(TPUGANEstimatorIntegrationTest, self).setUp() - self._model_dir = tempfile.mkdtemp() - self._config = tpu_config.RunConfig(model_dir=self._model_dir) - - def tearDown(self): - super(TPUGANEstimatorIntegrationTest, self).tearDown() - if self._model_dir: - writer_cache.FileWriterCache.clear() - 
shutil.rmtree(self._model_dir) - - def _test_complete_flow( - self, train_input_fn, eval_input_fn, predict_input_fn, prediction_size, - lr_decay=False, joint_train=True): - def make_opt(): - gstep = training_util.get_or_create_global_step() - lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) - return training.GradientDescentOptimizer(lr) - - gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) - dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) - est = estimator.TPUGANEstimator( - generator_fn=generator_fn, - discriminator_fn=discriminator_fn, - generator_loss_fn=losses.wasserstein_generator_loss, - discriminator_loss_fn=losses.wasserstein_discriminator_loss, - generator_optimizer=gopt, - discriminator_optimizer=dopt, - joint_train=joint_train, - get_eval_metric_ops_fn=get_metrics, - train_batch_size=4, - eval_batch_size=10, - predict_batch_size=8, - use_tpu=FLAGS.use_tpu, - config=self._config) - - # Train. - num_steps_train = 10 - est.train(train_input_fn, steps=num_steps_train) - - # Evaluate. - num_steps_eval = 2 - scores = est.evaluate(eval_input_fn, steps=num_steps_eval) - self.assertIn(ops.GraphKeys.GLOBAL_STEP, scores) - self.assertIn('loss', scores) - self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], - scores['loss']) - self.assertIn('mse_custom_metric', scores) - - # Predict. - predictions = np.array([x['generated_data'] for x in - est.predict(predict_input_fn)]) - self.assertAllEqual(prediction_size, predictions.shape) - - @parameterized.named_parameters( - ('joint_train', True, False, False), - ('train_sequential', False, False, False), - ('lr_decay', False, True, False), - ('train_sequential_ds', False, False, True)) - def test_numpy_input_fn(self, joint_train, lr_decay, return_ds): - """Tests complete flow with numpy_input_fn.""" - input_dim = 4 - def train_input_fn(params): - data = np.zeros([input_dim], dtype=np.float32) - ds = (dataset_ops.Dataset - .from_tensors((data, data)) - .repeat() - .batch(params['batch_size'], drop_remainder=True)) - if return_ds: - return ds - else: - x, y = ds.make_one_shot_iterator().get_next() - return x, y - def eval_input_fn(params): - data = np.zeros([input_dim], dtype=np.float32) - ds = (dataset_ops.Dataset - .from_tensors((data, data)) - .repeat() - .batch(params['batch_size'], drop_remainder=True)) - if return_ds: - return ds - else: - x, y = ds.make_one_shot_iterator().get_next() - return x, y - predict_size = 10 - def predict_input_fn(params): - del params # unused - data = np.zeros([input_dim], dtype=np.float32) - ds = (dataset_ops.Dataset - .from_tensors(data) - .repeat(predict_size) - .batch(1, drop_remainder=True)) - return ds - - self._test_complete_flow( - train_input_fn=train_input_fn, - eval_input_fn=eval_input_fn, - predict_input_fn=predict_input_fn, - prediction_size=[predict_size, input_dim], - lr_decay=lr_decay, - joint_train=joint_train) - - -class TPUGANEstimatorWarmStartTest(test.TestCase): - - def setUp(self): - self._model_dir = self.get_temp_dir() - self._config = tpu_config.RunConfig(model_dir=self._model_dir) - self.new_variable_name = 'new_var' - self.new_variable_value = [1.0, 2.0, 3.0] - - def tearDown(self): - writer_cache.FileWriterCache.clear() - - def _test_warm_start(self, warm_start_from=None): - """Tests whether WarmStartSettings work as intended.""" - def generator_with_new_variable(noise_dict, mode): - variable_scope.get_variable(name=self.new_variable_name, - initializer=self.new_variable_value, - trainable=True) - return 
generator_fn(noise_dict, mode) - - est = estimator.TPUGANEstimator( - generator_fn=generator_fn, - discriminator_fn=discriminator_fn, - generator_loss_fn=losses.wasserstein_generator_loss, - discriminator_loss_fn=losses.wasserstein_discriminator_loss, - generator_optimizer=training.GradientDescentOptimizer(1.0), - discriminator_optimizer=training.GradientDescentOptimizer(1.0), - train_batch_size=4, - use_tpu=FLAGS.use_tpu, - config=self._config) - - def train_input_fn(params): - data = np.zeros([params['batch_size'], 4], dtype=np.float32) - return data, data - - est.train(train_input_fn, steps=1) - - est_warm = estimator.TPUGANEstimator( - generator_fn=generator_with_new_variable, - discriminator_fn=discriminator_fn, - generator_loss_fn=losses.wasserstein_generator_loss, - discriminator_loss_fn=losses.wasserstein_discriminator_loss, - generator_optimizer=training.GradientDescentOptimizer(1.0), - discriminator_optimizer=training.GradientDescentOptimizer(1.0), - config=tpu_config.RunConfig( - model_dir=None if warm_start_from else self._model_dir), - train_batch_size=4, - use_tpu=FLAGS.use_tpu, - warm_start_from=warm_start_from) - - est_warm.train(train_input_fn, steps=1) - - return est_warm - - def test_warm_start_error(self): - """Test if exception when reloading different estimators.""" - with self.assertRaises(NotFoundError): - self._test_warm_start() - - def test_warm_start_success(self): - """Test if GANEstimator allows explicit warm start variable assignment.""" - # Regex matches all variable names in ckpt except for new_var. - var_regex = '^(?!.*%s.*)' % self.new_variable_name - warmstart = WarmStartSettings(ckpt_to_initialize_from=self._model_dir, - vars_to_warm_start=var_regex) - est_warm = self._test_warm_start(warm_start_from=warmstart) - full_variable_name = 'Generator/%s' % self.new_variable_name - self.assertIn(full_variable_name, est_warm.get_variable_names()) - equal_vals = np.array_equal(est_warm.get_variable_value(full_variable_name), - self.new_variable_value) - self.assertTrue(equal_vals) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/eval/__init__.py b/tensorflow/contrib/gan/python/eval/__init__.py deleted file mode 100644 index 92e9abf8a35..00000000000 --- a/tensorflow/contrib/gan/python/eval/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TF-GAN evaluation module. - -This module supports techniques such as Inception Score, Frechet Inception -distance, and Sliced Wasserstein distance. -""" -# pylint: disable=,wildcard-import,unused-import - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Collapse eval into a single namespace. 
-from tensorflow.contrib.gan.python.eval.python import classifier_metrics -from tensorflow.contrib.gan.python.eval.python import eval_utils -from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein -from tensorflow.contrib.gan.python.eval.python import summaries - -from tensorflow.contrib.gan.python.eval.python.classifier_metrics import * -from tensorflow.contrib.gan.python.eval.python.eval_utils import * -from tensorflow.contrib.gan.python.eval.python.sliced_wasserstein import * -from tensorflow.contrib.gan.python.eval.python.summaries import * -# pylint: enable=wildcard-import,unused-import - -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = [ - 'classifier_metrics', - 'sliced_wasserstein_distance', - 'summaries', - 'eval_utils', -] + ( - classifier_metrics.__all__ + sliced_wasserstein.__all__ + - summaries.__all__ + eval_utils.__all__) -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py deleted file mode 100644 index a52e899114b..00000000000 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Model evaluation tools for TF-GAN.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.eval.python import classifier_metrics_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.eval.python.classifier_metrics_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = classifier_metrics_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py deleted file mode 100644 index 2c301267900..00000000000 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ /dev/null @@ -1,1115 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Model evaluation tools for TF-GAN. 
- -These methods come from https://arxiv.org/abs/1606.03498, -https://arxiv.org/abs/1706.08500, and https://arxiv.org/abs/1801.01401. - -NOTE: This implementation uses the same weights as in -https://github.com/openai/improved-gan/blob/master/inception_score/model.py, -but is more numerically stable and is an unbiased estimator of the true -Inception score even when splitting the inputs into batches. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import os -import sys -import tarfile - -from six.moves import urllib - -from tensorflow.contrib.layers.python.layers import layers -from tensorflow.core.framework import graph_pb2 -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import importer -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import image_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import map_fn -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_impl -from tensorflow.python.ops import nn_ops -from tensorflow.python.platform import gfile -from tensorflow.python.platform import resource_loader - -__all__ = [ - 'get_graph_def_from_disk', - 'get_graph_def_from_resource', - 'get_graph_def_from_url_tarball', - 'preprocess_image', - 'run_image_classifier', - 'run_inception', - 'inception_score', - 'classifier_score', - 'classifier_score_from_logits', - 'frechet_inception_distance', - 'frechet_classifier_distance', - 'frechet_classifier_distance_from_activations', - 'mean_only_frechet_classifier_distance_from_activations', - 'diagonal_only_frechet_classifier_distance_from_activations', - 'kernel_inception_distance', - 'kernel_inception_distance_and_std', - 'kernel_classifier_distance', - 'kernel_classifier_distance_and_std', - 'kernel_classifier_distance_from_activations', - 'kernel_classifier_distance_and_std_from_activations', - 'INCEPTION_DEFAULT_IMAGE_SIZE', -] - -INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz' -INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb' -INCEPTION_INPUT = 'Mul:0' -INCEPTION_OUTPUT = 'logits:0' -INCEPTION_FINAL_POOL = 'pool_3:0' -INCEPTION_DEFAULT_IMAGE_SIZE = 299 - - -def _validate_images(images, image_size): - images = ops.convert_to_tensor(images) - images.shape.with_rank(4) - images.shape.assert_is_compatible_with([None, image_size, image_size, None]) - return images - - -def _symmetric_matrix_square_root(mat, eps=1e-10): - """Compute square root of a symmetric matrix. - - Note that this is different from an elementwise square root. We want to - compute M' where M' = sqrt(mat) such that M' * M' = mat. - - Also note that this method **only** works for symmetric matrices. - - Args: - mat: Matrix to take the square root of. - eps: Small epsilon such that any element less than eps will not be square - rooted to guard against numerical instability. - - Returns: - Matrix square root of mat. 
- """ - # Unlike numpy, tensorflow's return order is (s, u, v) - s, u, v = linalg_ops.svd(mat) - # sqrt is unstable around 0, just use 0 in such case - si = array_ops.where(math_ops.less(s, eps), s, math_ops.sqrt(s)) - # Note that the v returned by Tensorflow is v = V - # (when referencing the equation A = U S V^T) - # This is unlike Numpy which returns v = V^T - return math_ops.matmul( - math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True) - - -def preprocess_image(images, - height=INCEPTION_DEFAULT_IMAGE_SIZE, - width=INCEPTION_DEFAULT_IMAGE_SIZE, - scope=None): - """Prepare a batch of images for evaluation. - - This is the preprocessing portion of the graph from - http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz. - - Note that it expects Tensors in [0, 255]. This function maps pixel values to - [-1, 1] and resizes to match the InceptionV1 network. - - Args: - images: 3-D or 4-D Tensor of images. Values are in [0, 255]. - height: Integer. Height of resized output image. - width: Integer. Width of resized output image. - scope: Optional scope for name_scope. - - Returns: - 3-D or 4-D float Tensor of prepared image(s). Values are in [-1, 1]. - """ - is_single = images.shape.ndims == 3 - with ops.name_scope(scope, 'preprocess', [images, height, width]): - if not images.dtype.is_floating: - images = math_ops.cast(images, dtypes.float32) - if is_single: - images = array_ops.expand_dims(images, axis=0) - resized = image_ops.resize_bilinear(images, [height, width]) - resized = (resized - 128.0) / 128.0 - if is_single: - resized = array_ops.squeeze(resized, axis=0) - return resized - - -def _kl_divergence(p, p_logits, q): - """Computes the Kullback-Liebler divergence between p and q. - - This function uses p's logits in some places to improve numerical stability. - - Specifically: - - KL(p || q) = sum[ p * log(p / q) ] - = sum[ p * ( log(p) - log(q) ) ] - = sum[ p * ( log_softmax(p_logits) - log(q) ) ] - - Args: - p: A 2-D floating-point Tensor p_ij, where `i` corresponds to the minibatch - example and `j` corresponds to the probability of being in class `j`. - p_logits: A 2-D floating-point Tensor corresponding to logits for `p`. - q: A 1-D floating-point Tensor, where q_j corresponds to the probability - of class `j`. - - Returns: - KL divergence between two distributions. Output dimension is 1D, one entry - per distribution in `p`. - - Raises: - ValueError: If any of the inputs aren't floating-point. - ValueError: If p or p_logits aren't 2D. - ValueError: If q isn't 1D. - """ - for tensor in [p, p_logits, q]: - if not tensor.dtype.is_floating: - raise ValueError('Input %s must be floating type.', tensor.name) - p.shape.assert_has_rank(2) - p_logits.shape.assert_has_rank(2) - q.shape.assert_has_rank(1) - return math_ops.reduce_sum( - p * (nn_ops.log_softmax(p_logits) - math_ops.log(q)), axis=1) - - -def get_graph_def_from_disk(filename): - """Get a GraphDef proto from a disk location.""" - with gfile.GFile(filename, 'rb') as f: - return graph_pb2.GraphDef.FromString(f.read()) - - -def get_graph_def_from_resource(filename): - """Get a GraphDef proto from within a .par file.""" - return graph_pb2.GraphDef.FromString(resource_loader.load_resource(filename)) - - -def get_graph_def_from_url_tarball(url, filename, tar_filename=None): - """Get a GraphDef proto from a tarball on the web. 
- - Args: - url: Web address of tarball - filename: Filename of graph definition within tarball - tar_filename: Temporary download filename (None = always download) - - Returns: - A GraphDef loaded from a file in the downloaded tarball. - """ - if not (tar_filename and os.path.exists(tar_filename)): - - def _progress(count, block_size, total_size): - sys.stdout.write('\r>> Downloading %s %.1f%%' % - (url, - float(count * block_size) / float(total_size) * 100.0)) - sys.stdout.flush() - - tar_filename, _ = urllib.request.urlretrieve(url, tar_filename, _progress) - with tarfile.open(tar_filename, 'r:gz') as tar: - proto_str = tar.extractfile(filename).read() - return graph_pb2.GraphDef.FromString(proto_str) - - -def _default_graph_def_fn(): - return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH, - os.path.basename(INCEPTION_URL)) - - -def run_inception(images, - graph_def=None, - default_graph_def_fn=_default_graph_def_fn, - image_size=INCEPTION_DEFAULT_IMAGE_SIZE, - input_tensor=INCEPTION_INPUT, - output_tensor=INCEPTION_OUTPUT): - """Run images through a pretrained Inception classifier. - - Args: - images: Input tensors. Must be [batch, height, width, channels]. Input shape - and values must be in [-1, 1], which can be achieved using - `preprocess_image`. - graph_def: A GraphDef proto of a pretrained Inception graph. If `None`, - call `default_graph_def_fn` to get GraphDef. - default_graph_def_fn: A function that returns a GraphDef. Used if - `graph_def` is `None. By default, returns a pretrained InceptionV3 graph. - image_size: Required image width and height. See unit tests for the default - values. - input_tensor: Name of input Tensor. - output_tensor: Name or list of output Tensors. This function will compute - activations at the specified layer. Examples include INCEPTION_V3_OUTPUT - and INCEPTION_V3_FINAL_POOL which would result in this function computing - the final logits or the penultimate pooling layer. - - Returns: - Tensor or Tensors corresponding to computed `output_tensor`. - - Raises: - ValueError: If images are not the correct size. - ValueError: If neither `graph_def` nor `default_graph_def_fn` are provided. - """ - images = _validate_images(images, image_size) - - if graph_def is None: - if default_graph_def_fn is None: - raise ValueError('If `graph_def` is `None`, must provide ' - '`default_graph_def_fn`.') - graph_def = default_graph_def_fn() - - activations = run_image_classifier(images, graph_def, input_tensor, - output_tensor) - if isinstance(activations, list): - for i, activation in enumerate(activations): - if array_ops.rank(activation) != 2: - activations[i] = layers.flatten(activation) - else: - if array_ops.rank(activations) != 2: - activations = layers.flatten(activations) - - return activations - - -def run_image_classifier(tensor, - graph_def, - input_tensor, - output_tensor, - scope='RunClassifier'): - """Runs a network from a frozen graph. - - Args: - tensor: An Input tensor. - graph_def: A GraphDef proto. - input_tensor: Name of input tensor in graph def. - output_tensor: A tensor name or list of tensor names in graph def. - scope: Name scope for classifier. - - Returns: - Classifier output if `output_tensor` is a string, or a list of outputs if - `output_tensor` is a list. - - Raises: - ValueError: If `input_tensor` or `output_tensor` aren't in the graph_def. 
- """ - input_map = {input_tensor: tensor} - is_singleton = isinstance(output_tensor, str) - if is_singleton: - output_tensor = [output_tensor] - classifier_outputs = importer.import_graph_def( - graph_def, input_map, output_tensor, name=scope) - if is_singleton: - classifier_outputs = classifier_outputs[0] - - return classifier_outputs - - -def classifier_score(images, classifier_fn, num_batches=1): - """Classifier score for evaluating a conditional generative model. - - This is based on the Inception Score, but for an arbitrary classifier. - - This technique is described in detail in https://arxiv.org/abs/1606.03498. In - summary, this function calculates - - exp( E[ KL(p(y|x) || p(y)) ] ) - - which captures how different the network's classification prediction is from - the prior distribution over classes. - - NOTE: This function consumes images, computes their logits, and then - computes the classifier score. If you would like to precompute many logits for - large batches, use classifier_score_from_logits(), which this method also - uses. - - Args: - images: Images to calculate the classifier score for. - classifier_fn: A function that takes images and produces logits based on a - classifier. - num_batches: Number of batches to split `generated_images` in to in order to - efficiently run them through the classifier network. - - Returns: - The classifier score. A floating-point scalar of the same type as the output - of `classifier_fn`. - """ - generated_images_list = array_ops.split( - images, num_or_size_splits=num_batches) - - # Compute the classifier splits using the memory-efficient `map_fn`. - logits = map_fn.map_fn( - fn=classifier_fn, - elems=array_ops.stack(generated_images_list), - parallel_iterations=1, - back_prop=False, - swap_memory=True, - name='RunClassifier') - logits = array_ops.concat(array_ops.unstack(logits), 0) - - return classifier_score_from_logits(logits) - - -def classifier_score_from_logits(logits): - """Classifier score for evaluating a generative model from logits. - - This method computes the classifier score for a set of logits. This can be - used independently of the classifier_score() method, especially in the case - of using large batches during evaluation where we would like precompute all - of the logits before computing the classifier score. - - This technique is described in detail in https://arxiv.org/abs/1606.03498. In - summary, this function calculates: - - exp( E[ KL(p(y|x) || p(y)) ] ) - - which captures how different the network's classification prediction is from - the prior distribution over classes. - - Args: - logits: Precomputed 2D tensor of logits that will be used to - compute the classifier score. - - Returns: - The classifier score. A floating-point scalar of the same type as the output - of `logits`. - """ - logits.shape.assert_has_rank(2) - - # Use maximum precision for best results. 
- logits_dtype = logits.dtype - if logits_dtype != dtypes.float64: - logits = math_ops.cast(logits, dtypes.float64) - - p = nn_ops.softmax(logits) - q = math_ops.reduce_mean(p, axis=0) - kl = _kl_divergence(p, logits, q) - kl.shape.assert_has_rank(1) - log_score = math_ops.reduce_mean(kl) - final_score = math_ops.exp(log_score) - - if logits_dtype != dtypes.float64: - final_score = math_ops.cast(final_score, logits_dtype) - - return final_score - - -inception_score = functools.partial( - classifier_score, - classifier_fn=functools.partial( - run_inception, output_tensor=INCEPTION_OUTPUT)) - - -def trace_sqrt_product(sigma, sigma_v): - """Find the trace of the positive sqrt of product of covariance matrices. - - '_symmetric_matrix_square_root' only works for symmetric matrices, so we - cannot just take _symmetric_matrix_square_root(sigma * sigma_v). - ('sigma' and 'sigma_v' are symmetric, but their product is not necessarily). - - Let sigma = A A so A = sqrt(sigma), and sigma_v = B B. - We want to find trace(sqrt(sigma sigma_v)) = trace(sqrt(A A B B)) - Note the following properties: - (i) forall M1, M2: eigenvalues(M1 M2) = eigenvalues(M2 M1) - => eigenvalues(A A B B) = eigenvalues (A B B A) - (ii) if M1 = sqrt(M2), then eigenvalues(M1) = sqrt(eigenvalues(M2)) - => eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A)) - (iii) forall M: trace(M) = sum(eigenvalues(M)) - => trace(sqrt(sigma sigma_v)) = sum(eigenvalues(sqrt(sigma sigma_v))) - = sum(sqrt(eigenvalues(A B B A))) - = sum(eigenvalues(sqrt(A B B A))) - = trace(sqrt(A B B A)) - = trace(sqrt(A sigma_v A)) - A = sqrt(sigma). Both sigma and A sigma_v A are symmetric, so we **can** - use the _symmetric_matrix_square_root function to find the roots of these - matrices. - - Args: - sigma: a square, symmetric, real, positive semi-definite covariance matrix - sigma_v: same as sigma - - Returns: - The trace of the positive square root of sigma*sigma_v - """ - - # Note sqrt_sigma is called "A" in the proof above - sqrt_sigma = _symmetric_matrix_square_root(sigma) - - # This is sqrt(A sigma_v A) above - sqrt_a_sigmav_a = math_ops.matmul(sqrt_sigma, - math_ops.matmul(sigma_v, sqrt_sigma)) - - return math_ops.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) - - -def frechet_classifier_distance(real_images, - generated_images, - classifier_fn, - num_batches=1): - """Classifier distance for evaluating a generative model. - - This is based on the Frechet Inception distance, but for an arbitrary - classifier. - - This technique is described in detail in https://arxiv.org/abs/1706.08500. - Given two Gaussian distribution with means m and m_w and covariance matrices - C and C_w, this function calculates - - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) - - which captures how different the distributions of real images and generated - images (or more accurately, their visual features) are. Note that unlike the - Inception score, this is a true distance and utilizes information about real - world images. - - Note that when computed using sample means and sample covariance matrices, - Frechet distance is biased. It is more biased for small sample sizes. (e.g. - even if the two distributions are the same, for a small sample size, the - expected Frechet distance is large). It is important to use the same - sample size to compute Frechet classifier distance when comparing two - generative models. - - NOTE: This function consumes images, computes their activations, and then - computes the classifier score. 
If you would like to precompute many - activations for real and generated images for large batches, please use - frechet_clasifier_distance_from_activations(), which this method also uses. - - Args: - real_images: Real images to use to compute Frechet Inception distance. - generated_images: Generated images to use to compute Frechet Inception - distance. - classifier_fn: A function that takes images and produces activations - based on a classifier. - num_batches: Number of batches to split images in to in order to - efficiently run them through the classifier network. - - Returns: - The Frechet Inception distance. A floating-point scalar of the same type - as the output of `classifier_fn`. - """ - real_images_list = array_ops.split( - real_images, num_or_size_splits=num_batches) - generated_images_list = array_ops.split( - generated_images, num_or_size_splits=num_batches) - - real_imgs = array_ops.stack(real_images_list) - generated_imgs = array_ops.stack(generated_images_list) - - # Compute the activations using the memory-efficient `map_fn`. - def compute_activations(elems): - return map_fn.map_fn(fn=classifier_fn, - elems=elems, - parallel_iterations=1, - back_prop=False, - swap_memory=True, - name='RunClassifier') - - real_a = compute_activations(real_imgs) - gen_a = compute_activations(generated_imgs) - - # Ensure the activations have the right shapes. - real_a = array_ops.concat(array_ops.unstack(real_a), 0) - gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) - - return frechet_classifier_distance_from_activations(real_a, gen_a) - - -def mean_only_frechet_classifier_distance_from_activations( - real_activations, generated_activations): - """Classifier distance for evaluating a generative model from activations. - - Given two Gaussian distribution with means m and m_w and covariance matrices - C and C_w, this function calcuates - - |m - m_w|^2 - - which captures how different the distributions of real images and generated - images (or more accurately, their visual features) are. Note that unlike the - Inception score, this is a true distance and utilizes information about real - world images. - - Note that when computed using sample means and sample covariance matrices, - Frechet distance is biased. It is more biased for small sample sizes. (e.g. - even if the two distributions are the same, for a small sample size, the - expected Frechet distance is large). It is important to use the same - sample size to compute frechet classifier distance when comparing two - generative models. - - In this variant, we only compute the difference between the means of the - fitted Gaussians. The computation leads to O(n) vs. O(n^2) memory usage, yet - still retains much of the same information as FID. - - Args: - real_activations: 2D array of activations of real images of size - [num_images, num_dims] to use to compute Frechet Inception distance. - generated_activations: 2D array of activations of generated images of size - [num_images, num_dims] to use to compute Frechet Inception distance. - - Returns: - The mean-only Frechet Inception distance. A floating-point scalar of the - same type as the output of the activations. - """ - real_activations.shape.assert_has_rank(2) - generated_activations.shape.assert_has_rank(2) - - activations_dtype = real_activations.dtype - if activations_dtype != dtypes.float64: - real_activations = math_ops.cast(real_activations, dtypes.float64) - generated_activations = math_ops.cast(generated_activations, dtypes.float64) - - # Compute means of activations. 
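# --- Editorial aside (not part of the original file) ------------------------
# The mean-only variant computed here reduces to the squared L2 distance
# between the two activation means. A NumPy sketch on made-up activations
# (NumPy and the toy data are assumptions for illustration only):
import numpy as np

_real_act = np.random.RandomState(1).randn(100, 8)
_gen_act = np.random.RandomState(2).randn(100, 8) + 0.25
_mofid = np.square(_real_act.mean(axis=0) - _gen_act.mean(axis=0)).sum()
print(_mofid)  # 0.0 would mean the two activation means coincide
# -----------------------------------------------------------------------------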
- m = math_ops.reduce_mean(real_activations, 0) - m_w = math_ops.reduce_mean(generated_activations, 0) - - # Next the distance between means. - mean = math_ops.reduce_sum( - math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. - mofid = mean - if activations_dtype != dtypes.float64: - mofid = math_ops.cast(mofid, activations_dtype) - - return mofid - - -def diagonal_only_frechet_classifier_distance_from_activations( - real_activations, generated_activations): - """Classifier distance for evaluating a generative model. - - This is based on the Frechet Inception distance, but for an arbitrary - classifier. - - This technique is described in detail in https://arxiv.org/abs/1706.08500. - Given two Gaussian distribution with means m and m_w and covariance matrices - C and C_w, this function calcuates - - |m - m_w|^2 + (sigma + sigma_w - 2(sigma x sigma_w)^(1/2)) - - which captures how different the distributions of real images and generated - images (or more accurately, their visual features) are. Note that unlike the - Inception score, this is a true distance and utilizes information about real - world images. In this variant, we compute diagonal-only covariance matrices. - As a result, instead of computing an expensive matrix square root, we can do - something much simpler, and has O(n) vs O(n^2) space complexity. - - Note that when computed using sample means and sample covariance matrices, - Frechet distance is biased. It is more biased for small sample sizes. (e.g. - even if the two distributions are the same, for a small sample size, the - expected Frechet distance is large). It is important to use the same - sample size to compute frechet classifier distance when comparing two - generative models. - - Args: - real_activations: Real images to use to compute Frechet Inception distance. - generated_activations: Generated images to use to compute Frechet Inception - distance. - - Returns: - The diagonal-only Frechet Inception distance. A floating-point scalar of - the same type as the output of the activations. - - Raises: - ValueError: If the shape of the variance and mean vectors are not equal. - """ - real_activations.shape.assert_has_rank(2) - generated_activations.shape.assert_has_rank(2) - - activations_dtype = real_activations.dtype - if activations_dtype != dtypes.float64: - real_activations = math_ops.cast(real_activations, dtypes.float64) - generated_activations = math_ops.cast(generated_activations, dtypes.float64) - - # Compute mean and covariance matrices of activations. - m, var = nn_impl.moments(real_activations, axes=[0]) - m_w, var_w = nn_impl.moments(generated_activations, axes=[0]) - - actual_shape = var.get_shape() - expected_shape = m.get_shape() - - if actual_shape != expected_shape: - raise ValueError('shape: {} must match expected shape: {}'.format( - actual_shape, expected_shape)) - - # Compute the two components of FID. - - # First the covariance component. - # Here, note that trace(A + B) = trace(A) + trace(B) - trace = math_ops.reduce_sum( - (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w))) - - # Next the distance between means. - mean = math_ops.reduce_sum( - math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. - dofid = trace + mean - if activations_dtype != dtypes.float64: - dofid = math_ops.cast(dofid, activations_dtype) - - return dofid - - -def frechet_classifier_distance_from_activations(real_activations, - generated_activations): - """Classifier distance for evaluating a generative model. 
- - This methods computes the Frechet classifier distance from activations of - real images and generated images. This can be used independently of the - frechet_classifier_distance() method, especially in the case of using large - batches during evaluation where we would like precompute all of the - activations before computing the classifier distance. - - This technique is described in detail in https://arxiv.org/abs/1706.08500. - Given two Gaussian distribution with means m and m_w and covariance matrices - C and C_w, this function calculates - - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) - - which captures how different the distributions of real images and generated - images (or more accurately, their visual features) are. Note that unlike the - Inception score, this is a true distance and utilizes information about real - world images. - - Note that when computed using sample means and sample covariance matrices, - Frechet distance is biased. It is more biased for small sample sizes. (e.g. - even if the two distributions are the same, for a small sample size, the - expected Frechet distance is large). It is important to use the same - sample size to compute frechet classifier distance when comparing two - generative models. - - Args: - real_activations: 2D Tensor containing activations of real data. Shape is - [batch_size, activation_size]. - generated_activations: 2D Tensor containing activations of generated data. - Shape is [batch_size, activation_size]. - - Returns: - The Frechet Inception distance. A floating-point scalar of the same type - as the output of the activations. - - """ - real_activations.shape.assert_has_rank(2) - generated_activations.shape.assert_has_rank(2) - - activations_dtype = real_activations.dtype - if activations_dtype != dtypes.float64: - real_activations = math_ops.cast(real_activations, dtypes.float64) - generated_activations = math_ops.cast(generated_activations, dtypes.float64) - - # Compute mean and covariance matrices of activations. - m = math_ops.reduce_mean(real_activations, 0) - m_w = math_ops.reduce_mean(generated_activations, 0) - num_examples_real = math_ops.cast( - array_ops.shape(real_activations)[0], dtypes.float64) - num_examples_generated = math_ops.cast( - array_ops.shape(generated_activations)[0], dtypes.float64) - - # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T - real_centered = real_activations - m - sigma = math_ops.matmul( - real_centered, real_centered, transpose_a=True) / ( - num_examples_real - 1) - - gen_centered = generated_activations - m_w - sigma_w = math_ops.matmul( - gen_centered, gen_centered, transpose_a=True) / ( - num_examples_generated - 1) - - # Find the Tr(sqrt(sigma sigma_w)) component of FID - sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) - - # Compute the two components of FID. - - # First the covariance component. - # Here, note that trace(A + B) = trace(A) + trace(B) - trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component - - # Next the distance between means. - mean = math_ops.reduce_sum( - math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. 
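# --- Editorial aside (not part of the original file) ------------------------
# A NumPy/SciPy sketch of the full formula assembled here,
# |m - m_w|^2 + Tr(C + C_w - 2 (C C_w)^(1/2)), on toy activations. SciPy and
# the toy data are assumptions for this illustration; the deleted test file
# further below uses the same scipy.linalg.sqrtm reference in `_expected_fid`.
import numpy as np
from scipy import linalg as scp_linalg

_rng = np.random.RandomState(0)
_real_act = _rng.randn(256, 16)
_gen_act = _rng.randn(256, 16) + 0.5
_m, _m_w = _real_act.mean(axis=0), _gen_act.mean(axis=0)
_c = np.cov(_real_act, rowvar=False)
_c_w = np.cov(_gen_act, rowvar=False)
_covmean = scp_linalg.sqrtm(_c.dot(_c_w)).real   # drop tiny imaginary round-off
_fid = np.square(_m - _m_w).sum() + np.trace(_c + _c_w - 2.0 * _covmean)
print(_fid)
# -----------------------------------------------------------------------------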
- fid = trace + mean - if activations_dtype != dtypes.float64: - fid = math_ops.cast(fid, activations_dtype) - - return fid - -frechet_inception_distance = functools.partial( - frechet_classifier_distance, - classifier_fn=functools.partial( - run_inception, output_tensor=INCEPTION_FINAL_POOL)) - - -def kernel_classifier_distance(real_images, - generated_images, - classifier_fn, - num_classifier_batches=1, - max_block_size=1024, - dtype=None): - """Kernel "classifier" distance for evaluating a generative model. - - This is based on the Kernel Inception distance, but for an arbitrary - embedding. - - This technique is described in detail in https://arxiv.org/abs/1801.01401. - Given two distributions P and Q of activations, this function calculates - - E_{X, X' ~ P}[k(X, X')] + E_{Y, Y' ~ Q}[k(Y, Y')] - - 2 E_{X ~ P, Y ~ Q}[k(X, Y)] - - where k is the polynomial kernel - - k(x, y) = ( x^T y / dimension + 1 )^3. - - This captures how different the distributions of real and generated images' - visual features are. Like the Frechet distance (and unlike the Inception - score), this is a true distance and incorporates information about the - target images. Unlike the Frechet score, this function computes an - *unbiased* and asymptotically normal estimator, which makes comparing - estimates across models much more intuitive. - - The estimator used takes time quadratic in max_block_size. Larger values of - max_block_size will decrease the variance of the estimator but increase the - computational cost. This differs slightly from the estimator used by the - original paper; it is the block estimator of https://arxiv.org/abs/1307.1954. - - NOTE: the blocking code assumes that real_activations and - generated_activations are both in random order. If either is sorted in a - meaningful order, the estimator will behave poorly. - - NOTE: This function consumes images, computes their activations, and then - computes the classifier score. If you would like to precompute many - activations for real and generated images for large batches, or to compute - multiple scores based on the same images, please use - kernel_clasifier_distance_from_activations(), which this method also uses. - - Args: - real_images: Real images to use to compute Kernel Inception distance. - generated_images: Generated images to use to compute Kernel Inception - distance. - classifier_fn: A function that takes images and produces activations based - on a classifier. - num_classifier_batches: Number of batches to split images in to in order to - efficiently run them through the classifier network. - max_block_size: integer, default 1024. The distance estimator splits samples - into blocks for computational efficiency. Larger values are more - computationally expensive but decrease the variance of the distance - estimate. - dtype: if not None, coerce activations to this dtype before computations. - - Returns: - The Kernel Inception Distance. A floating-point scalar of the same type - as the output of the activations. 
- """ - return kernel_classifier_distance_and_std( - real_images, - generated_images, - classifier_fn, - num_classifier_batches=num_classifier_batches, - max_block_size=max_block_size, - dtype=dtype)[0] - - -kernel_inception_distance = functools.partial( - kernel_classifier_distance, - classifier_fn=functools.partial( - run_inception, output_tensor=INCEPTION_FINAL_POOL)) - - -def kernel_classifier_distance_and_std(real_images, - generated_images, - classifier_fn, - num_classifier_batches=1, - max_block_size=1024, - dtype=None): - """Kernel "classifier" distance for evaluating a generative model. - - This is based on the Kernel Inception distance, but for an arbitrary - embedding. Also returns an estimate of the standard error of the distance - estimator. - - This technique is described in detail in https://arxiv.org/abs/1801.01401. - Given two distributions P and Q of activations, this function calculates - - E_{X, X' ~ P}[k(X, X')] + E_{Y, Y' ~ Q}[k(Y, Y')] - - 2 E_{X ~ P, Y ~ Q}[k(X, Y)] - - where k is the polynomial kernel - - k(x, y) = ( x^T y / dimension + 1 )^3. - - This captures how different the distributions of real and generated images' - visual features are. Like the Frechet distance (and unlike the Inception - score), this is a true distance and incorporates information about the - target images. Unlike the Frechet score, this function computes an - *unbiased* and asymptotically normal estimator, which makes comparing - estimates across models much more intuitive. - - The estimator used takes time quadratic in max_block_size. Larger values of - max_block_size will decrease the variance of the estimator but increase the - computational cost. This differs slightly from the estimator used by the - original paper; it is the block estimator of https://arxiv.org/abs/1307.1954. - - NOTE: the blocking code assumes that real_activations and - generated_activations are both in random order. If either is sorted in a - meaningful order, the estimator will behave poorly. - - NOTE: This function consumes images, computes their activations, and then - computes the classifier score. If you would like to precompute many - activations for real and generated images for large batches, or to compute - multiple scores based on the same images, please use - kernel_clasifier_distance_from_activations(), which this method also uses. - - Args: - real_images: Real images to use to compute Kernel Inception distance. - generated_images: Generated images to use to compute Kernel Inception - distance. - classifier_fn: A function that takes images and produces activations based - on a classifier. - num_classifier_batches: Number of batches to split images in to in order to - efficiently run them through the classifier network. - max_block_size: integer, default 1024. The distance estimator splits samples - into blocks for computational efficiency. Larger values are more - computationally expensive but decrease the variance of the distance - estimate. Having a smaller block size also gives a better estimate of the - standard error. - dtype: if not None, coerce activations to this dtype before computations. - - Returns: - The Kernel Inception Distance. A floating-point scalar of the same type - as the output of the activations. - An estimate of the standard error of the distance estimator (a scalar of - the same type). 
- """ - real_images_list = array_ops.split( - real_images, num_or_size_splits=num_classifier_batches) - generated_images_list = array_ops.split( - generated_images, num_or_size_splits=num_classifier_batches) - - real_imgs = array_ops.stack(real_images_list) - generated_imgs = array_ops.stack(generated_images_list) - - # Compute the activations using the memory-efficient `map_fn`. - def compute_activations(elems): - return map_fn.map_fn( - fn=classifier_fn, - elems=elems, - parallel_iterations=1, - back_prop=False, - swap_memory=True, - name='RunClassifier') - - real_a = compute_activations(real_imgs) - gen_a = compute_activations(generated_imgs) - - # Ensure the activations have the right shapes. - real_a = array_ops.concat(array_ops.unstack(real_a), 0) - gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) - - return kernel_classifier_distance_and_std_from_activations( - real_a, gen_a, max_block_size, dtype) - - -kernel_inception_distance_and_std = functools.partial( - kernel_classifier_distance_and_std, - classifier_fn=functools.partial( - run_inception, output_tensor=INCEPTION_FINAL_POOL)) - - -def kernel_classifier_distance_from_activations(real_activations, - generated_activations, - max_block_size=1024, - dtype=None): - """Kernel "classifier" distance for evaluating a generative model. - - This methods computes the kernel classifier distance from activations of - real images and generated images. This can be used independently of the - kernel_classifier_distance() method, especially in the case of using large - batches during evaluation where we would like to precompute all of the - activations before computing the classifier distance, or if we want to - compute multiple metrics based on the same images. - - This technique is described in detail in https://arxiv.org/abs/1801.01401. - Given two distributions P and Q of activations, this function calculates - - E_{X, X' ~ P}[k(X, X')] + E_{Y, Y' ~ Q}[k(Y, Y')] - - 2 E_{X ~ P, Y ~ Q}[k(X, Y)] - - where k is the polynomial kernel - - k(x, y) = ( x^T y / dimension + 1 )^3. - - This captures how different the distributions of real and generated images' - visual features are. Like the Frechet distance (and unlike the Inception - score), this is a true distance and incorporates information about the - target images. Unlike the Frechet score, this function computes an - *unbiased* and asymptotically normal estimator, which makes comparing - estimates across models much more intuitive. - - The estimator used takes time quadratic in max_block_size. Larger values of - max_block_size will decrease the variance of the estimator but increase the - computational cost. This differs slightly from the estimator used by the - original paper; it is the block estimator of https://arxiv.org/abs/1307.1954. - - NOTE: the blocking code assumes that real_activations and - generated_activations are both in random order. If either is sorted in a - meaningful order, the estimator will behave poorly. - - Args: - real_activations: 2D Tensor containing activations of real data. Shape is - [batch_size, activation_size]. - generated_activations: 2D Tensor containing activations of generated data. - Shape is [batch_size, activation_size]. - max_block_size: integer, default 1024. The distance estimator splits samples - into blocks for computational efficiency. Larger values are more - computationally expensive but decrease the variance of the distance - estimate. - dtype: If not None, coerce activations to this dtype before computations. 
- - Returns: - The Kernel Inception Distance. A floating-point scalar of the same type - as the output of the activations. - """ - return kernel_classifier_distance_and_std_from_activations( - real_activations, generated_activations, max_block_size, dtype)[0] - - -def kernel_classifier_distance_and_std_from_activations(real_activations, - generated_activations, - max_block_size=1024, - dtype=None): - """Kernel "classifier" distance for evaluating a generative model. - - This methods computes the kernel classifier distance from activations of - real images and generated images. This can be used independently of the - kernel_classifier_distance() method, especially in the case of using large - batches during evaluation where we would like to precompute all of the - activations before computing the classifier distance, or if we want to - compute multiple metrics based on the same images. It also returns a rough - estimate of the standard error of the estimator. - - This technique is described in detail in https://arxiv.org/abs/1801.01401. - Given two distributions P and Q of activations, this function calculates - - E_{X, X' ~ P}[k(X, X')] + E_{Y, Y' ~ Q}[k(Y, Y')] - - 2 E_{X ~ P, Y ~ Q}[k(X, Y)] - - where k is the polynomial kernel - - k(x, y) = ( x^T y / dimension + 1 )^3. - - This captures how different the distributions of real and generated images' - visual features are. Like the Frechet distance (and unlike the Inception - score), this is a true distance and incorporates information about the - target images. Unlike the Frechet score, this function computes an - *unbiased* and asymptotically normal estimator, which makes comparing - estimates across models much more intuitive. - - The estimator used takes time quadratic in max_block_size. Larger values of - max_block_size will decrease the variance of the estimator but increase the - computational cost. This differs slightly from the estimator used by the - original paper; it is the block estimator of https://arxiv.org/abs/1307.1954. - The estimate of the standard error will also be more reliable when there are - more blocks, i.e. when max_block_size is smaller. - - NOTE: the blocking code assumes that real_activations and - generated_activations are both in random order. If either is sorted in a - meaningful order, the estimator will behave poorly. - - Args: - real_activations: 2D Tensor containing activations of real data. Shape is - [batch_size, activation_size]. - generated_activations: 2D Tensor containing activations of generated data. - Shape is [batch_size, activation_size]. - max_block_size: integer, default 1024. The distance estimator splits samples - into blocks for computational efficiency. Larger values are more - computationally expensive but decrease the variance of the distance - estimate. Having a smaller block size also gives a better estimate of the - standard error. - dtype: If not None, coerce activations to this dtype before computations. - - Returns: - The Kernel Inception Distance. A floating-point scalar of the same type - as the output of the activations. - An estimate of the standard error of the distance estimator (a scalar of - the same type). 
- """ - - real_activations.shape.assert_has_rank(2) - generated_activations.shape.assert_has_rank(2) - real_activations.shape[1].assert_is_compatible_with( - generated_activations.shape[1]) - - if dtype is None: - dtype = real_activations.dtype - assert generated_activations.dtype == dtype - else: - real_activations = math_ops.cast(real_activations, dtype) - generated_activations = math_ops.cast(generated_activations, dtype) - - # Figure out how to split the activations into blocks of approximately - # equal size, with none larger than max_block_size. - n_r = array_ops.shape(real_activations)[0] - n_g = array_ops.shape(generated_activations)[0] - - n_bigger = math_ops.maximum(n_r, n_g) - n_blocks = math_ops.cast(math_ops.ceil(n_bigger / max_block_size), - dtypes.int32) - - v_r = n_r // n_blocks - v_g = n_g // n_blocks - - n_plusone_r = n_r - v_r * n_blocks - n_plusone_g = n_g - v_g * n_blocks - - sizes_r = array_ops.concat([ - array_ops.fill([n_blocks - n_plusone_r], v_r), - array_ops.fill([n_plusone_r], v_r + 1), - ], 0) - sizes_g = array_ops.concat([ - array_ops.fill([n_blocks - n_plusone_g], v_g), - array_ops.fill([n_plusone_g], v_g + 1), - ], 0) - - zero = array_ops.zeros([1], dtype=dtypes.int32) - inds_r = array_ops.concat([zero, math_ops.cumsum(sizes_r)], 0) - inds_g = array_ops.concat([zero, math_ops.cumsum(sizes_g)], 0) - - dim = math_ops.cast(real_activations.shape[1], dtype) - - def compute_kid_block(i): - """Computes the ith block of the KID estimate.""" - r_s = inds_r[i] - r_e = inds_r[i + 1] - r = real_activations[r_s:r_e] - m = math_ops.cast(r_e - r_s, dtype) - - g_s = inds_g[i] - g_e = inds_g[i + 1] - g = generated_activations[g_s:g_e] - n = math_ops.cast(g_e - g_s, dtype) - - k_rr = (math_ops.matmul(r, r, transpose_b=True) / dim + 1)**3 - k_rg = (math_ops.matmul(r, g, transpose_b=True) / dim + 1)**3 - k_gg = (math_ops.matmul(g, g, transpose_b=True) / dim + 1)**3 - return (-2 * math_ops.reduce_mean(k_rg) + - (math_ops.reduce_sum(k_rr) - math_ops.trace(k_rr)) / (m * (m - 1)) + - (math_ops.reduce_sum(k_gg) - math_ops.trace(k_gg)) / (n * (n - 1))) - - ests = map_fn.map_fn( - compute_kid_block, math_ops.range(n_blocks), dtype=dtype, back_prop=False) - - mn = math_ops.reduce_mean(ests) - - # nn_impl.moments doesn't use the Bessel correction, which we want here - n_blocks_ = math_ops.cast(n_blocks, dtype) - var = control_flow_ops.cond( - math_ops.less_equal(n_blocks, 1), - lambda: array_ops.constant(float('nan'), dtype=dtype), - lambda: math_ops.reduce_sum(math_ops.square(ests - mn)) / (n_blocks_ - 1)) - - return mn, math_ops.sqrt(var / n_blocks_) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py deleted file mode 100644 index bc7c1057b47..00000000000 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ /dev/null @@ -1,566 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for TF-GAN classifier_metrics.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import tarfile -import tempfile - -from absl.testing import parameterized -import numpy as np -from scipy import linalg as scp_linalg - -from google.protobuf import text_format - -from tensorflow.contrib.gan.python.eval.python import classifier_metrics_impl as classifier_metrics -from tensorflow.core.framework import graph_pb2 -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - -mock = test.mock - - -def _numpy_softmax(x): - e_x = np.exp(x - np.max(x, axis=1)[:, None]) - return e_x / np.sum(e_x, axis=1)[:, None] - - -def _expected_inception_score(logits): - p = _numpy_softmax(logits) - q = np.expand_dims(np.mean(p, 0), 0) - per_example_logincscore = np.sum(p * (np.log(p) - np.log(q)), 1) - return np.exp(np.mean(per_example_logincscore)) - - -def _expected_mean_only_fid(real_imgs, gen_imgs): - m = np.mean(real_imgs, axis=0) - m_v = np.mean(gen_imgs, axis=0) - mean = np.square(m - m_v).sum() - mofid = mean - return mofid - - -def _expected_diagonal_only_fid(real_imgs, gen_imgs): - m = np.mean(real_imgs, axis=0) - m_v = np.mean(gen_imgs, axis=0) - var = np.var(real_imgs, axis=0) - var_v = np.var(gen_imgs, axis=0) - sqcc = np.sqrt(var * var_v) - mean = (np.square(m - m_v)).sum() - trace = (var + var_v - 2 * sqcc).sum() - dofid = mean + trace - return dofid - - -def _expected_fid(real_imgs, gen_imgs): - m = np.mean(real_imgs, axis=0) - m_v = np.mean(gen_imgs, axis=0) - sigma = np.cov(real_imgs, rowvar=False) - sigma_v = np.cov(gen_imgs, rowvar=False) - sqcc = scp_linalg.sqrtm(np.dot(sigma, sigma_v)) - mean = np.square(m - m_v).sum() - trace = np.trace(sigma + sigma_v - 2 * sqcc) - fid = mean + trace - return fid - - -def _expected_trace_sqrt_product(sigma, sigma_v): - return np.trace(scp_linalg.sqrtm(np.dot(sigma, sigma_v))) - - -def _expected_kid_and_std(real_imgs, gen_imgs, max_block_size=1024): - n_r, dim = real_imgs.shape - n_g = gen_imgs.shape[0] - - n_blocks = int(np.ceil(max(n_r, n_g) / max_block_size)) - - sizes_r = np.full(n_blocks, n_r // n_blocks) - to_patch = n_r - n_blocks * (n_r // n_blocks) - if to_patch > 0: - sizes_r[-to_patch:] += 1 - inds_r = np.r_[0, np.cumsum(sizes_r)] - assert inds_r[-1] == n_r - - sizes_g = np.full(n_blocks, n_g // n_blocks) - to_patch = n_g - n_blocks * (n_g // n_blocks) - if to_patch > 0: - sizes_g[-to_patch:] += 1 - inds_g = np.r_[0, np.cumsum(sizes_g)] - assert inds_g[-1] == n_g - - ests = [] - for i in range(n_blocks): - r = real_imgs[inds_r[i]:inds_r[i + 1]] - g = gen_imgs[inds_g[i]:inds_g[i + 1]] - - k_rr = (np.dot(r, r.T) / dim + 1)**3 - k_rg = (np.dot(r, g.T) / dim + 1)**3 - k_gg = (np.dot(g, g.T) / dim + 1)**3 - ests.append(-2 * k_rg.mean() + - k_rr[np.triu_indices_from(k_rr, k=1)].mean() + - k_gg[np.triu_indices_from(k_gg, k=1)].mean()) - - var = np.var(ests, ddof=1) if len(ests) > 1 else np.nan - return np.mean(ests), np.sqrt(var / len(ests)) - -# A dummy GraphDef string with the minimum number of Ops. 
-graphdef_string = """ -node { - name: "Mul" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: -1 - } - dim { - size: 299 - } - dim { - size: 299 - } - dim { - size: 3 - } - } - } - } -} -node { - name: "logits" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: -1 - } - dim { - size: 1001 - } - } - } - } -} -node { - name: "pool_3" - op: "Placeholder" - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: -1 - } - dim { - size: 2048 - } - } - } - } -} -versions { - producer: 24 -} -""" - - -def _get_dummy_graphdef(): - dummy_graphdef = graph_pb2.GraphDef() - text_format.Merge(graphdef_string, dummy_graphdef) - return dummy_graphdef - - -def _run_with_mock(function, *args, **kwargs): - with mock.patch.object( - classifier_metrics, - 'get_graph_def_from_url_tarball') as mock_tarball_getter: - mock_tarball_getter.return_value = _get_dummy_graphdef() - return function(*args, **kwargs) - - -class ClassifierMetricsTest(test.TestCase, parameterized.TestCase): - - @parameterized.named_parameters( - ('GraphDef', False), - ('DefaultGraphDefFn', True)) - def test_run_inception_graph(self, use_default_graph_def): - """Test `run_inception` graph construction.""" - batch_size = 7 - img = array_ops.ones([batch_size, 299, 299, 3]) - - if use_default_graph_def: - logits = _run_with_mock(classifier_metrics.run_inception, img) - else: - logits = classifier_metrics.run_inception(img, _get_dummy_graphdef()) - - self.assertIsInstance(logits, ops.Tensor) - logits.shape.assert_is_compatible_with([batch_size, 1001]) - - # Check that none of the model variables are trainable. - self.assertListEqual([], variables.trainable_variables()) - - @parameterized.named_parameters( - ('GraphDef', False), - ('DefaultGraphDefFn', True)) - def test_run_inception_graph_pool_output(self, use_default_graph_def): - """Test `run_inception` graph construction with pool output.""" - batch_size = 3 - img = array_ops.ones([batch_size, 299, 299, 3]) - - if use_default_graph_def: - pool = _run_with_mock( - classifier_metrics.run_inception, - img, - output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) - else: - pool = classifier_metrics.run_inception( - img, _get_dummy_graphdef(), - output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) - - self.assertIsInstance(pool, ops.Tensor) - pool.shape.assert_is_compatible_with([batch_size, 2048]) - - # Check that none of the model variables are trainable. - self.assertListEqual([], variables.trainable_variables()) - - def test_run_inception_multiple_outputs(self): - """Test `run_inception` graph construction with multiple outputs.""" - batch_size = 3 - img = array_ops.ones([batch_size, 299, 299, 3]) - logits, pool = _run_with_mock( - classifier_metrics.run_inception, - img, - output_tensor=[ - classifier_metrics.INCEPTION_OUTPUT, - classifier_metrics.INCEPTION_FINAL_POOL - ]) - - self.assertIsInstance(logits, ops.Tensor) - self.assertIsInstance(pool, ops.Tensor) - logits.shape.assert_is_compatible_with([batch_size, 1001]) - pool.shape.assert_is_compatible_with([batch_size, 2048]) - - # Check that none of the model variables are trainable. 
- self.assertListEqual([], variables.trainable_variables()) - - def test_inception_score_graph(self): - """Test `inception_score` graph construction.""" - score = _run_with_mock( - classifier_metrics.inception_score, - array_ops.zeros([6, 299, 299, 3]), - num_batches=3) - self.assertIsInstance(score, ops.Tensor) - score.shape.assert_has_rank(0) - - # Check that none of the model variables are trainable. - self.assertListEqual([], variables.trainable_variables()) - - def test_frechet_inception_distance_graph(self): - """Test `frechet_inception_distance` graph construction.""" - img = array_ops.ones([7, 299, 299, 3]) - distance = _run_with_mock( - classifier_metrics.frechet_inception_distance, img, img) - - self.assertIsInstance(distance, ops.Tensor) - distance.shape.assert_has_rank(0) - - # Check that none of the model variables are trainable. - self.assertListEqual([], variables.trainable_variables()) - - def test_kernel_inception_distance_graph(self): - """Test `frechet_inception_distance` graph construction.""" - img = array_ops.ones([7, 299, 299, 3]) - distance = _run_with_mock(classifier_metrics.kernel_inception_distance, img, - img) - - self.assertIsInstance(distance, ops.Tensor) - distance.shape.assert_has_rank(0) - - # Check that none of the model variables are trainable. - self.assertListEqual([], variables.trainable_variables()) - - def test_run_inception_multicall(self): - """Test that `run_inception` can be called multiple times.""" - for batch_size in (7, 3, 2): - img = array_ops.ones([batch_size, 299, 299, 3]) - _run_with_mock(classifier_metrics.run_inception, img) - - def test_invalid_input(self): - """Test that functions properly fail on invalid input.""" - with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'): - classifier_metrics.run_inception(array_ops.ones([7, 50, 50, 3])) - - p = array_ops.zeros([8, 10]) - p_logits = array_ops.zeros([8, 10]) - q = array_ops.zeros([10]) - with self.assertRaisesRegexp(ValueError, 'must be floating type'): - classifier_metrics._kl_divergence( - array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q) - - with self.assertRaisesRegexp(ValueError, 'must be floating type'): - classifier_metrics._kl_divergence(p, - array_ops.zeros( - [8, 10], dtype=dtypes.int32), q) - - with self.assertRaisesRegexp(ValueError, 'must be floating type'): - classifier_metrics._kl_divergence(p, p_logits, - array_ops.zeros( - [10], dtype=dtypes.int32)) - - with self.assertRaisesRegexp(ValueError, 'must have rank 2'): - classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits, q) - - with self.assertRaisesRegexp(ValueError, 'must have rank 2'): - classifier_metrics._kl_divergence(p, array_ops.zeros([8]), q) - - with self.assertRaisesRegexp(ValueError, 'must have rank 1'): - classifier_metrics._kl_divergence(p, p_logits, array_ops.zeros([10, 8])) - - def test_inception_score_value(self): - """Test that `inception_score` gives the correct value.""" - logits = np.array( - [np.array([1, 2] * 500 + [4]), - np.array([4, 5] * 500 + [6])]) - unused_image = array_ops.zeros([2, 299, 299, 3]) - incscore = _run_with_mock(classifier_metrics.inception_score, unused_image) - - with self.cached_session(use_gpu=True) as sess: - incscore_np = sess.run(incscore, {'concat:0': logits}) - - self.assertAllClose(_expected_inception_score(logits), incscore_np) - - def test_mean_only_frechet_classifier_distance_value(self): - """Test that `frechet_classifier_distance` gives the correct value.""" - np.random.seed(0) - - pool_real_a = np.float32(np.random.randn(256, 
2048)) - pool_gen_a = np.float32(np.random.randn(256, 2048)) - - tf_pool_real_a = array_ops.constant(pool_real_a) - tf_pool_gen_a = array_ops.constant(pool_gen_a) - - mofid_op = classifier_metrics.mean_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long - tf_pool_real_a, tf_pool_gen_a) - - with self.cached_session() as sess: - actual_mofid = sess.run(mofid_op) - - expected_mofid = _expected_mean_only_fid(pool_real_a, pool_gen_a) - - self.assertAllClose(expected_mofid, actual_mofid, 0.0001) - - def test_diagonal_only_frechet_classifier_distance_value(self): - """Test that `frechet_classifier_distance` gives the correct value.""" - np.random.seed(0) - - pool_real_a = np.float32(np.random.randn(256, 2048)) - pool_gen_a = np.float32(np.random.randn(256, 2048)) - - tf_pool_real_a = array_ops.constant(pool_real_a) - tf_pool_gen_a = array_ops.constant(pool_gen_a) - - dofid_op = classifier_metrics.diagonal_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long - tf_pool_real_a, tf_pool_gen_a) - - with self.cached_session() as sess: - actual_dofid = sess.run(dofid_op) - - expected_dofid = _expected_diagonal_only_fid(pool_real_a, pool_gen_a) - - self.assertAllClose(expected_dofid, actual_dofid, 0.0001) - - def test_frechet_classifier_distance_value(self): - """Test that `frechet_classifier_distance` gives the correct value.""" - np.random.seed(0) - - # Make num_examples > num_features to ensure scipy's sqrtm function - # doesn't return a complex matrix. - test_pool_real_a = np.float32(np.random.randn(512, 256)) - test_pool_gen_a = np.float32(np.random.randn(512, 256)) - - fid_op = _run_with_mock( - classifier_metrics.frechet_classifier_distance, - test_pool_real_a, - test_pool_gen_a, - classifier_fn=lambda x: x) - - with self.cached_session() as sess: - actual_fid = sess.run(fid_op) - - expected_fid = _expected_fid(test_pool_real_a, test_pool_gen_a) - - self.assertAllClose(expected_fid, actual_fid, 0.0001) - - def test_frechet_classifier_distance_covariance(self): - """Test that `frechet_classifier_distance` takes covariance into account.""" - np.random.seed(0) - - # Make num_examples > num_features to ensure scipy's sqrtm function - # doesn't return a complex matrix. - test_pool_reals, test_pool_gens = [], [] - for i in range(1, 11, 2): - test_pool_reals.append(np.float32(np.random.randn(2048, 256) * i)) - test_pool_gens.append(np.float32(np.random.randn(2048, 256) * i)) - - fid_ops = [] - for i in range(len(test_pool_reals)): - fid_ops.append(_run_with_mock( - classifier_metrics.frechet_classifier_distance, - test_pool_reals[i], - test_pool_gens[i], - classifier_fn=lambda x: x)) - - fids = [] - with self.cached_session() as sess: - for fid_op in fid_ops: - fids.append(sess.run(fid_op)) - - # Check that the FIDs increase monotonically. 
- self.assertTrue(all(fid_a < fid_b for fid_a, fid_b in zip(fids, fids[1:]))) - - def test_kernel_classifier_distance_value(self): - """Test that `kernel_classifier_distance` gives the correct value.""" - np.random.seed(0) - - test_pool_real_a = np.float32(np.random.randn(512, 256)) - test_pool_gen_a = np.float32(np.random.randn(512, 256) * 1.1 + .05) - - kid_op = _run_with_mock( - classifier_metrics.kernel_classifier_distance_and_std, - test_pool_real_a, - test_pool_gen_a, - classifier_fn=lambda x: x, - max_block_size=600) - - with self.cached_session() as sess: - actual_kid, actual_std = sess.run(kid_op) - - expected_kid, expected_std = _expected_kid_and_std(test_pool_real_a, - test_pool_gen_a) - - self.assertAllClose(expected_kid, actual_kid, 0.001) - self.assertAllClose(expected_std, actual_std, 0.001) - - def test_kernel_classifier_distance_block_sizes(self): - """Test that `kernel_classifier_distance` works with unusual max_block_size - - values.. - """ - np.random.seed(0) - - test_pool_real_a = np.float32(np.random.randn(512, 256)) - test_pool_gen_a = np.float32(np.random.randn(768, 256) * 1.1 + .05) - - max_block_size = array_ops.placeholder(dtypes.int32, shape=()) - kid_op = _run_with_mock( - classifier_metrics.kernel_classifier_distance_and_std_from_activations, - array_ops.constant(test_pool_real_a), - array_ops.constant(test_pool_gen_a), - max_block_size=max_block_size) - - for block_size in [50, 512, 1000]: - with self.cached_session() as sess: - actual_kid, actual_std = sess.run(kid_op, {max_block_size: block_size}) - - expected_kid, expected_std = _expected_kid_and_std( - test_pool_real_a, test_pool_gen_a, max_block_size=block_size) - - self.assertAllClose(expected_kid, actual_kid, 0.001) - self.assertAllClose(expected_std, actual_std, 0.001) - - def test_trace_sqrt_product_value(self): - """Test that `trace_sqrt_product` gives the correct value.""" - np.random.seed(0) - - # Make num_examples > num_features to ensure scipy's sqrtm function - # doesn't return a complex matrix. - test_pool_real_a = np.float32(np.random.randn(512, 256)) - test_pool_gen_a = np.float32(np.random.randn(512, 256)) - - cov_real = np.cov(test_pool_real_a, rowvar=False) - cov_gen = np.cov(test_pool_gen_a, rowvar=False) - - trace_sqrt_prod_op = _run_with_mock(classifier_metrics.trace_sqrt_product, - cov_real, cov_gen) - - with self.cached_session() as sess: - # trace_sqrt_product: tsp - actual_tsp = sess.run(trace_sqrt_prod_op) - - expected_tsp = _expected_trace_sqrt_product(cov_real, cov_gen) - - self.assertAllClose(actual_tsp, expected_tsp, 0.01) - - def test_preprocess_image_graph(self): - """Test `preprocess_image` graph construction.""" - incorrectly_sized_image = array_ops.zeros([520, 240, 3]) - correct_image = classifier_metrics.preprocess_image( - images=incorrectly_sized_image) - _run_with_mock(classifier_metrics.run_inception, - array_ops.expand_dims(correct_image, 0)) - - def test_get_graph_def_from_url_tarball(self): - """Test `get_graph_def_from_url_tarball`.""" - # Write dummy binary GraphDef to tempfile. - with tempfile.NamedTemporaryFile(delete=False) as tmp_file: - tmp_file.write(_get_dummy_graphdef().SerializeToString()) - relative_path = os.path.relpath(tmp_file.name) - - # Create gzip tarball. 
- tar_dir = tempfile.mkdtemp() - tar_filename = os.path.join(tar_dir, 'tmp.tar.gz') - with tarfile.open(tar_filename, 'w:gz') as tar: - tar.add(relative_path) - - with mock.patch.object(classifier_metrics, 'urllib') as mock_urllib: - mock_urllib.request.urlretrieve.return_value = tar_filename, None - graph_def = classifier_metrics.get_graph_def_from_url_tarball( - 'unused_url', relative_path) - - self.assertIsInstance(graph_def, graph_pb2.GraphDef) - self.assertEqual(_get_dummy_graphdef(), graph_def) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/eval/python/eval_utils.py b/tensorflow/contrib/gan/python/eval/python/eval_utils.py deleted file mode 100644 index bb7327040c9..00000000000 --- a/tensorflow/contrib/gan/python/eval/python/eval_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility file for visualizing generated images.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.eval.python import eval_utils_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.eval.python.eval_utils_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = eval_utils_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/eval/python/eval_utils_impl.py b/tensorflow/contrib/gan/python/eval/python/eval_utils_impl.py deleted file mode 100644 index 6623b56c706..00000000000 --- a/tensorflow/contrib/gan/python/eval/python/eval_utils_impl.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utility file for visualizing generated images.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops - - -__all__ = [ - "image_grid", - "image_reshaper", -] - - -# TODO(joelshor): Make this a special case of `image_reshaper`. 
-def image_grid(input_tensor, grid_shape, image_shape=(32, 32), num_channels=3): - """Arrange a minibatch of images into a grid to form a single image. - - Args: - input_tensor: Tensor. Minibatch of images to format, either 4D - ([batch size, height, width, num_channels]) or flattened - ([batch size, height * width * num_channels]). - grid_shape: Sequence of int. The shape of the image grid, - formatted as [grid_height, grid_width]. - image_shape: Sequence of int. The shape of a single image, - formatted as [image_height, image_width]. - num_channels: int. The number of channels in an image. - - Returns: - Tensor representing a single image in which the input images have been - arranged into a grid. - - Raises: - ValueError: The grid shape and minibatch size don't match, or the image - shape and number of channels are incompatible with the input tensor. - """ - if grid_shape[0] * grid_shape[1] != int(input_tensor.shape[0]): - raise ValueError("Grid shape %s incompatible with minibatch size %i." % - (grid_shape, int(input_tensor.shape[0]))) - if len(input_tensor.shape) == 2: - num_features = image_shape[0] * image_shape[1] * num_channels - if int(input_tensor.shape[1]) != num_features: - raise ValueError("Image shape and number of channels incompatible with " - "input tensor.") - elif len(input_tensor.shape) == 4: - if (int(input_tensor.shape[1]) != image_shape[0] or - int(input_tensor.shape[2]) != image_shape[1] or - int(input_tensor.shape[3]) != num_channels): - raise ValueError("Image shape and number of channels incompatible with " - "input tensor.") - else: - raise ValueError("Unrecognized input tensor format.") - height, width = grid_shape[0] * image_shape[0], grid_shape[1] * image_shape[1] - input_tensor = array_ops.reshape( - input_tensor, tuple(grid_shape) + tuple(image_shape) + (num_channels,)) - input_tensor = array_ops.transpose(input_tensor, [0, 1, 3, 2, 4]) - input_tensor = array_ops.reshape( - input_tensor, [grid_shape[0], width, image_shape[0], num_channels]) - input_tensor = array_ops.transpose(input_tensor, [0, 2, 1, 3]) - input_tensor = array_ops.reshape( - input_tensor, [1, height, width, num_channels]) - return input_tensor - - -def _validate_images(images): - for img in images: - img.shape.assert_has_rank(3) - img.shape.assert_is_fully_defined() - if img.shape[-1] not in (1, 3): - raise ValueError("image_reshaper only supports 1 or 3 channel images.") - - -# TODO(joelshor): Move the dimension logic from Python to Tensorflow. -def image_reshaper(images, num_cols=None): - """A reshaped summary image. - - Returns an image that will contain all elements in the list and will be - laid out in a nearly-square tiling pattern (e.g. 11 images will lead to a - 3x4 tiled image). - - Args: - images: Image data to summarize. Can be an RGB or grayscale image, a list of - such images, or a set of RGB images concatenated along the depth - dimension. The shape of each image is assumed to be [batch_size, - height, width, depth]. - num_cols: (Optional) If provided, this is the number of columns in the final - output image grid. Otherwise, the number of columns is determined by - the number of images. - - Returns: - A summary image matching the input with automatic tiling if needed. - Output shape is [1, height, width, channels]. 
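A minimal NumPy sketch of the tiling that `image_grid` performs, assuming the batch size exactly fills the grid (no padding); `np_image_grid` is an illustrative name, and the transpose sequence is a slightly shorter equivalent of the one above.

```python
import numpy as np

def np_image_grid(batch, grid_shape):
    # batch: [N, H, W, C] with N == grid_h * grid_w (assumed in this sketch).
    grid_h, grid_w = grid_shape
    n, h, w, c = batch.shape
    assert n == grid_h * grid_w
    x = batch.reshape(grid_h, grid_w, h, w, c)
    x = x.transpose(0, 2, 1, 3, 4)                  # -> [grid_h, H, grid_w, W, C]
    return x.reshape(1, grid_h * h, grid_w * w, c)  # one tiled image

tiles = np.arange(25 * 32 * 32 * 3, dtype=np.float32).reshape(25, 32, 32, 3)
print(np_image_grid(tiles, (5, 5)).shape)  # (1, 160, 160, 3)
```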
- """ - if isinstance(images, ops.Tensor): - images = array_ops.unstack(images) - _validate_images(images) - - num_images = len(images) - num_columns = (num_cols if num_cols else - int(math.ceil(math.sqrt(num_images)))) - num_rows = int(math.ceil(float(num_images) / num_columns)) - rows = [images[x:x+num_columns] for x in range(0, num_images, num_columns)] - - # Add empty image tiles if the last row is incomplete. - num_short = num_rows * num_columns - num_images - assert num_short >= 0 and num_short < num_columns - if num_short > 0: - rows[-1].extend([array_ops.zeros_like(images[-1])] * num_short) - - # Convert each row from a list of tensors to a single tensor. - rows = [array_ops.concat(row, 1) for row in rows] - - # Stack rows vertically. - img = array_ops.concat(rows, 0) - - return array_ops.expand_dims(img, 0) diff --git a/tensorflow/contrib/gan/python/eval/python/eval_utils_test.py b/tensorflow/contrib/gan/python/eval/python/eval_utils_test.py deleted file mode 100644 index cfed4dc513e..00000000000 --- a/tensorflow/contrib/gan/python/eval/python/eval_utils_test.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for eval_utils_test.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.eval.python import eval_utils_impl as eval_utils -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test - - -class UtilsTest(test.TestCase): - - def test_image_grid(self): - eval_utils.image_grid( - input_tensor=array_ops.zeros([25, 32, 32, 3]), - grid_shape=(5, 5)) - - # TODO(joelshor): Add more `image_reshaper` tests. - def test_image_reshaper_image_list(self): - images = eval_utils.image_reshaper( - images=array_ops.unstack(array_ops.zeros([25, 32, 32, 3])), - num_cols=2) - images.shape.assert_is_compatible_with([1, 13 * 32, 2 * 32, 3]) - - def test_image_reshaper_image(self): - images = eval_utils.image_reshaper( - images=array_ops.zeros([25, 32, 32, 3]), - num_cols=2) - images.shape.assert_is_compatible_with([1, 13 * 32, 2 * 32, 3]) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py deleted file mode 100644 index 326fcb3cdbf..00000000000 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Model evaluation tools for TF-GAN.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.eval.python.sliced_wasserstein_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = sliced_wasserstein_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py deleted file mode 100644 index 9657d4e3d0c..00000000000 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Implementation of Sliced Wasserstein Distance. - -Proposed in https://arxiv.org/abs/1710.10196 and the official Theano -implementation that we used as reference can be found here: -https://github.com/tkarras/progressive_growing_of_gans - -Note: this is not an exact distance but an approximation through random -projections. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import script_ops - -__all__ = ['sliced_wasserstein_distance'] -_GAUSSIAN_FILTER = np.float32([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [ - 6, 24, 36, 24, 6 -], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]]).reshape([5, 5, 1, 1]) / 256.0 - - -def _laplacian_pyramid(batch, num_levels): - """Compute a Laplacian pyramid. - - Args: - batch: (tensor) The batch of images (batch, height, width, channels). - num_levels: (int) Desired number of hierarchical levels. - Returns: - List of tensors from the highest to lowest resolution. 
- """ - gaussian_filter = constant_op.constant(_GAUSSIAN_FILTER) - - def spatial_conv(batch, gain): - s = array_ops.shape(batch) - padded = array_ops.pad(batch, [[0, 0], [2, 2], [2, 2], [0, 0]], 'REFLECT') - xt = array_ops.transpose(padded, [0, 3, 1, 2]) - xt = array_ops.reshape(xt, [s[0] * s[3], s[1] + 4, s[2] + 4, 1]) - conv_out = nn_ops.conv2d(xt, gaussian_filter * gain, [1] * 4, 'VALID') - conv_xt = array_ops.reshape(conv_out, [s[0], s[3], s[1], s[2]]) - conv_xt = array_ops.transpose(conv_xt, [0, 2, 3, 1]) - return conv_xt - - def pyr_down(batch): # matches cv2.pyrDown() - return spatial_conv(batch, 1)[:, ::2, ::2] - - def pyr_up(batch): # matches cv2.pyrUp() - s = array_ops.shape(batch) - zeros = array_ops.zeros([3 * s[0], s[1], s[2], s[3]]) - res = array_ops.concat([batch, zeros], 0) - res = array_ops.batch_to_space(res, crops=[[0, 0], [0, 0]], block_size=2) - res = spatial_conv(res, 4) - return res - - pyramid = [math_ops.cast(batch, dtypes.float32)] - for _ in range(1, num_levels): - pyramid.append(pyr_down(pyramid[-1])) - pyramid[-2] -= pyr_up(pyramid[-1]) - return pyramid - - -def _batch_to_patches(batch, patches_per_image, patch_size): - """Extract patches from a batch. - - Args: - batch: (tensor) The batch of images (batch, height, width, channels). - patches_per_image: (int) Number of patches to extract per image. - patch_size: (int) Size of the patches (size, size, channels) to extract. - Returns: - Tensor (batch*patches_per_image, patch_size, patch_size, channels) of - patches. - """ - - def py_func_random_patches(batch): - """Numpy wrapper.""" - batch_size, height, width, channels = batch.shape - patch_count = patches_per_image * batch_size - hs = patch_size // 2 - # Randomly pick patches. - patch_id, y, x, chan = np.ogrid[0:patch_count, -hs:hs + 1, -hs:hs + 1, 0:3] - img_id = patch_id // patches_per_image - # pylint: disable=g-no-augmented-assignment - # Need explicit addition for broadcast to work properly. - y = y + np.random.randint(hs, height - hs, size=(patch_count, 1, 1, 1)) - x = x + np.random.randint(hs, width - hs, size=(patch_count, 1, 1, 1)) - # pylint: enable=g-no-augmented-assignment - idx = ((img_id * height + y) * width + x) * channels + chan - patches = batch.flat[idx] - return patches - - patches = script_ops.py_func( - py_func_random_patches, [batch], batch.dtype, stateful=False) - return patches - - -def _normalize_patches(patches): - """Normalize patches by their mean and standard deviation. - - Args: - patches: (tensor) The batch of patches (batch, size, size, channels). - Returns: - Tensor (batch, size, size, channels) of the normalized patches. - """ - patches = array_ops.concat(patches, 0) - mean, variance = nn.moments(patches, [1, 2, 3], keep_dims=True) - patches = (patches - mean) / math_ops.sqrt(variance) - return array_ops.reshape(patches, [array_ops.shape(patches)[0], -1]) - - -def _sort_rows(matrix, num_rows): - """Sort matrix rows by the last column. - - Args: - matrix: a matrix of values (row,col). - num_rows: (int) number of sorted rows to return from the matrix. - Returns: - Tensor (num_rows, col) of the sorted matrix top K rows. - """ - tmatrix = array_ops.transpose(matrix, [1, 0]) - sorted_tmatrix = nn_ops.top_k(tmatrix, num_rows)[0] - return array_ops.transpose(sorted_tmatrix, [1, 0]) - - -def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim): - """Compute the approximate sliced Wasserstein distance. - - Args: - a: (matrix) Distribution "a" of samples (row, col). 
- b: (matrix) Distribution "b" of samples (row, col). - random_sampling_count: (int) Number of random projections to average. - random_projection_dim: (int) Dimension of the random projection space. - Returns: - Float containing the approximate distance between "a" and "b". - """ - s = array_ops.shape(a) - means = [] - for _ in range(random_sampling_count): - # Random projection matrix. - proj = random_ops.random_normal( - [array_ops.shape(a)[1], random_projection_dim]) - proj *= math_ops.rsqrt( - math_ops.reduce_sum(math_ops.square(proj), 0, keepdims=True)) - # Project both distributions and sort them. - proj_a = math_ops.matmul(a, proj) - proj_b = math_ops.matmul(b, proj) - proj_a = _sort_rows(proj_a, s[0]) - proj_b = _sort_rows(proj_b, s[0]) - # Pairwise Wasserstein distance. - wdist = math_ops.reduce_mean(math_ops.abs(proj_a - proj_b)) - means.append(wdist) - return math_ops.reduce_mean(means) - - -def _sliced_wasserstein_svd(a, b): - """Compute the approximate sliced Wasserstein distance using an SVD. - - This is not part of the paper, it's a variant with possibly more accurate - measure. - - Args: - a: (matrix) Distribution "a" of samples (row, col). - b: (matrix) Distribution "b" of samples (row, col). - Returns: - Float containing the approximate distance between "a" and "b". - """ - s = array_ops.shape(a) - # Random projection matrix. - sig, u = linalg_ops.svd(array_ops.concat([a, b], 0))[:2] - proj_a, proj_b = array_ops.split(u * sig, 2, axis=0) - proj_a = _sort_rows(proj_a[:, ::-1], s[0]) - proj_b = _sort_rows(proj_b[:, ::-1], s[0]) - # Pairwise Wasserstein distance. - wdist = math_ops.reduce_mean(math_ops.abs(proj_a - proj_b)) - return wdist - - -def sliced_wasserstein_distance(real_images, - fake_images, - resolution_min=16, - patches_per_image=64, - patch_size=7, - random_sampling_count=1, - random_projection_dim=7 * 7 * 3, - use_svd=False): - """Compute the Wasserstein distance between two distributions of images. - - Note that measure vary with the number of images. Use 8192 images to get - numbers comparable to the ones in the original paper. - - Args: - real_images: (tensor) Real images (batch, height, width, channels). - fake_images: (tensor) Fake images (batch, height, width, channels). - resolution_min: (int) Minimum resolution for the Laplacian pyramid. - patches_per_image: (int) Number of patches to extract per image per - Laplacian level. - patch_size: (int) Width of a square patch. - random_sampling_count: (int) Number of random projections to average. - random_projection_dim: (int) Dimension of the random projection space. - use_svd: experimental method to compute a more accurate distance. - Returns: - List of tuples (distance_real, distance_fake) for each level of the - Laplacian pyramid from the highest resolution to the lowest. - distance_real is the Wasserstein distance between real images - distance_fake is the Wasserstein distance between real and fake images. - Raises: - ValueError: If the inputs shapes are incorrect. Input tensor dimensions - (batch, height, width, channels) are expected to be known at graph - construction time. In addition height and width must be the same and the - number of colors should be exactly 3. Real and fake images must have the - same size. - """ - height = real_images.shape[1] - real_images.shape.assert_is_compatible_with([None, None, height, 3]) - fake_images.shape.assert_is_compatible_with(real_images.shape) - - # Select resolutions. 
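A minimal NumPy sketch of the random-projection step described above, assuming both sample sets have the same number of rows; `np_sliced_wasserstein` and its defaults are illustrative only.

```python
import numpy as np

def np_sliced_wasserstein(a, b, random_sampling_count=32,
                          random_projection_dim=8, seed=0):
    # a, b: [n_samples, n_features] with equal sample counts (assumed in this sketch).
    rng = np.random.RandomState(seed)
    means = []
    for _ in range(random_sampling_count):
        proj = rng.normal(size=(a.shape[1], random_projection_dim))
        proj /= np.sqrt(np.square(proj).sum(axis=0, keepdims=True))  # unit-norm directions
        # Project both sample sets, sort each 1-D projection, compare pointwise.
        proj_a = np.sort(a.dot(proj), axis=0)
        proj_b = np.sort(b.dot(proj), axis=0)
        means.append(np.abs(proj_a - proj_b).mean())
    return float(np.mean(means))

a = np.random.RandomState(1).uniform(size=(256, 147))  # 147 = 7 * 7 * 3 patch features
b = np.random.RandomState(2).normal(size=(256, 147))
print(np_sliced_wasserstein(a, b))
```

Sorting each 1-D projection reduces the comparison to the closed-form 1-D Wasserstein distance, which is why only a sort and a mean absolute difference are needed per projection.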
- resolution_full = int(height) - resolution_min = min(resolution_min, resolution_full) - resolution_max = resolution_full - # Base loss of detail. - resolutions = [ - 2**i - for i in range( - int(np.log2(resolution_max)), - int(np.log2(resolution_min)) - 1, -1) - ] - - # Gather patches for each level of the Laplacian pyramids. - patches_real, patches_fake, patches_test = ( - [[] for _ in resolutions] for _ in range(3)) - for lod, level in enumerate( - _laplacian_pyramid(real_images, len(resolutions))): - patches_real[lod].append( - _batch_to_patches(level, patches_per_image, patch_size)) - patches_test[lod].append( - _batch_to_patches(level, patches_per_image, patch_size)) - - for lod, level in enumerate( - _laplacian_pyramid(fake_images, len(resolutions))): - patches_fake[lod].append( - _batch_to_patches(level, patches_per_image, patch_size)) - - for lod in range(len(resolutions)): - for patches in [patches_real, patches_test, patches_fake]: - patches[lod] = _normalize_patches(patches[lod]) - - # Evaluate scores. - scores = [] - for lod in range(len(resolutions)): - if not use_svd: - scores.append( - (_sliced_wasserstein(patches_real[lod], patches_test[lod], - random_sampling_count, random_projection_dim), - _sliced_wasserstein(patches_real[lod], patches_fake[lod], - random_sampling_count, random_projection_dim))) - else: - scores.append( - (_sliced_wasserstein_svd(patches_real[lod], patches_test[lod]), - _sliced_wasserstein_svd(patches_real[lod], patches_fake[lod]))) - return scores diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py deleted file mode 100644 index ab909feae37..00000000000 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for Sliced Wasserstein Distance.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from scipy import ndimage -from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein_impl as swd -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.platform import test - - -class ClassifierMetricsTest(test.TestCase): - - def test_laplacian_pyramid(self): - # The numpy/scipy code for reference estimation comes from: - # https://github.com/tkarras/progressive_growing_of_gans - gaussian_filter = np.float32([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [ - 6, 24, 36, 24, 6 - ], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]]) / 256.0 - - def np_pyr_down(minibatch): # matches cv2.pyrDown() - assert minibatch.ndim == 4 - return ndimage.convolve( - minibatch, - gaussian_filter[np.newaxis, np.newaxis, :, :], - mode='mirror')[:, :, ::2, ::2] - - def np_pyr_up(minibatch): # matches cv2.pyrUp() - assert minibatch.ndim == 4 - s = minibatch.shape - res = np.zeros((s[0], s[1], s[2] * 2, s[3] * 2), minibatch.dtype) - res[:, :, ::2, ::2] = minibatch - return ndimage.convolve( - res, - gaussian_filter[np.newaxis, np.newaxis, :, :] * 4.0, - mode='mirror') - - def np_laplacian_pyramid(minibatch, num_levels): - # Note: there's a bug in the original SWD, fixed repeatability. - pyramid = [minibatch.astype('f').copy()] - for _ in range(1, num_levels): - pyramid.append(np_pyr_down(pyramid[-1])) - pyramid[-2] -= np_pyr_up(pyramid[-1]) - return pyramid - - data = np.random.normal(size=[256, 3, 32, 32]).astype('f') - pyramid = np_laplacian_pyramid(data, 3) - data_tf = array_ops.placeholder(dtypes.float32, [256, 32, 32, 3]) - pyramid_tf = swd._laplacian_pyramid(data_tf, 3) - with self.cached_session() as sess: - pyramid_tf = sess.run( - pyramid_tf, feed_dict={ - data_tf: data.transpose(0, 2, 3, 1) - }) - for x in range(3): - self.assertAllClose( - pyramid[x].transpose(0, 2, 3, 1), pyramid_tf[x], atol=1e-6) - - def test_sliced_wasserstein_distance(self): - """Test the distance.""" - d1 = random_ops.random_uniform([256, 32, 32, 3]) - d2 = random_ops.random_normal([256, 32, 32, 3]) - wfunc = swd.sliced_wasserstein_distance(d1, d2) - with self.cached_session() as sess: - wscores = [sess.run(x) for x in wfunc] - self.assertAllClose( - np.array([0.014, 0.014], 'f'), - np.array([x[0] for x in wscores], 'f'), - rtol=0.15) - self.assertAllClose( - np.array([0.014, 0.020], 'f'), - np.array([x[1] for x in wscores], 'f'), - rtol=0.15) - - def test_sliced_wasserstein_distance_svd(self): - """Test the distance.""" - d1 = random_ops.random_uniform([256, 32, 32, 3]) - d2 = random_ops.random_normal([256, 32, 32, 3]) - wfunc = swd.sliced_wasserstein_distance(d1, d2, use_svd=True) - with self.cached_session() as sess: - wscores = [sess.run(x) for x in wfunc] - self.assertAllClose( - np.array([0.013, 0.013], 'f'), - np.array([x[0] for x in wscores], 'f'), - rtol=0.15) - self.assertAllClose( - np.array([0.014, 0.019], 'f'), - np.array([x[1] for x in wscores], 'f'), - rtol=0.15) - - def test_swd_mismatched(self): - """Test the inputs mismatched shapes are detected.""" - d1 = random_ops.random_uniform([256, 32, 32, 3]) - d2 = random_ops.random_normal([256, 32, 31, 3]) - d3 = random_ops.random_normal([256, 31, 32, 3]) - d4 = random_ops.random_normal([255, 32, 32, 3]) - with 
self.assertRaises(ValueError): - swd.sliced_wasserstein_distance(d1, d2) - with self.assertRaises(ValueError): - swd.sliced_wasserstein_distance(d1, d3) - with self.assertRaises(ValueError): - swd.sliced_wasserstein_distance(d1, d4) - - def test_swd_not_rgb(self): - """Test that only RGB is supported.""" - d1 = random_ops.random_uniform([256, 32, 32, 1]) - d2 = random_ops.random_normal([256, 32, 32, 1]) - with self.assertRaises(ValueError): - swd.sliced_wasserstein_distance(d1, d2) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py deleted file mode 100644 index 3eb4f5db0c8..00000000000 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Common TF-GAN summaries.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python import namedtuples -from tensorflow.contrib.gan.python.eval.python import eval_utils -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import map_fn -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops.losses import util as loss_util -from tensorflow.python.summary import summary - -__all__ = [ - 'add_gan_model_image_summaries', - 'add_image_comparison_summaries', - 'add_gan_model_summaries', - 'add_regularization_loss_summaries', - 'add_cyclegan_image_summaries', - 'add_stargan_image_summaries' -] - - -def _assert_is_image(data): - data.shape.assert_has_rank(4) - data.shape[1:].assert_is_fully_defined() - - -def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): - """Adds image summaries for real and fake images. - - Args: - gan_model: A GANModel tuple. - grid_size: The size of an image grid. - model_summaries: Also add summaries of the model. - - Raises: - ValueError: If real and generated data aren't images. - """ - if isinstance(gan_model, namedtuples.CycleGANModel): - raise ValueError( - '`add_gan_model_image_summaries` does not take CycleGANModels. 
Please ' - 'use `add_cyclegan_image_summaries` instead.') - _assert_is_image(gan_model.real_data) - _assert_is_image(gan_model.generated_data) - - num_images = grid_size ** 2 - real_image_shape = gan_model.real_data.shape.as_list()[1:3] - generated_image_shape = gan_model.generated_data.shape.as_list()[1:3] - real_channels = gan_model.real_data.shape.as_list()[3] - generated_channels = gan_model.generated_data.shape.as_list()[3] - - summary.image( - 'real_data', - eval_utils.image_grid( - gan_model.real_data[:num_images], - grid_shape=(grid_size, grid_size), - image_shape=real_image_shape, - num_channels=real_channels), - max_outputs=1) - summary.image( - 'generated_data', - eval_utils.image_grid( - gan_model.generated_data[:num_images], - grid_shape=(grid_size, grid_size), - image_shape=generated_image_shape, - num_channels=generated_channels), - max_outputs=1) - - if model_summaries: - add_gan_model_summaries(gan_model) - - -def add_cyclegan_image_summaries(cyclegan_model): - """Adds image summaries for CycleGAN. - - There are two summaries, one for each generator. The first image is the - generator input, the second is the generator output, and the third is G(F(x)). - - Args: - cyclegan_model: A CycleGANModel tuple. - - Raises: - ValueError: If `cyclegan_model` isn't a CycleGANModel. - ValueError: If generated data, generator inputs, and reconstructions aren't - images. - ValueError: If the generator input, generated data, and reconstructions - aren't all the same size. - """ - if not isinstance(cyclegan_model, namedtuples.CycleGANModel): - raise ValueError('`cyclegan_model` was not a CycleGANModel. Instead, was ' - '%s' % type(cyclegan_model)) - - _assert_is_image(cyclegan_model.model_x2y.generator_inputs) - _assert_is_image(cyclegan_model.model_x2y.generated_data) - _assert_is_image(cyclegan_model.reconstructed_x) - _assert_is_image(cyclegan_model.model_y2x.generator_inputs) - _assert_is_image(cyclegan_model.model_y2x.generated_data) - _assert_is_image(cyclegan_model.reconstructed_y) - - def _add_comparison_summary(gan_model, reconstructions): - image_list = (array_ops.unstack(gan_model.generator_inputs[:1]) + - array_ops.unstack(gan_model.generated_data[:1]) + - array_ops.unstack(reconstructions[:1])) - summary.image( - 'image_comparison', eval_utils.image_reshaper( - image_list, num_cols=len(image_list)), max_outputs=1) - - with ops.name_scope('x2y_image_comparison_summaries'): - _add_comparison_summary( - cyclegan_model.model_x2y, cyclegan_model.reconstructed_x) - with ops.name_scope('y2x_image_comparison_summaries'): - _add_comparison_summary( - cyclegan_model.model_y2x, cyclegan_model.reconstructed_y) - - -def add_image_comparison_summaries(gan_model, num_comparisons=2, - display_diffs=False): - """Adds image summaries to compare triplets of images. - - The first image is the generator input, the second is the generator output, - and the third is the real data. This style of comparison is useful for - image translation problems, where the generator input is a corrupted image, - the generator output is the reconstruction, and the real data is the target. - - Args: - gan_model: A GANModel tuple. - num_comparisons: The number of image triplets to display. - display_diffs: Also display the difference between generated and target. - - Raises: - ValueError: If real data, generated data, and generator inputs aren't - images. - ValueError: If the generator input, real, and generated data aren't all the - same size. 
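A minimal NumPy sketch of the row layout this summary produces: one grid row per category (inputs, generated, real, and the diff row when `display_diffs=True`), with `num_comparisons` images per row; `comparison_strip` is an illustrative helper, not part of the TF-GAN API.

```python
import numpy as np

def comparison_strip(generator_inputs, generated_data, real_data):
    # All arguments: [N, H, W, C] arrays of identical shape (assumed in this sketch).
    diffs = np.abs(generated_data.astype(np.float32) - real_data.astype(np.float32))
    rows = [generator_inputs, generated_data, real_data, diffs]
    # Tile each category along the width, then stack the categories along the height.
    strip = np.concatenate(
        [np.concatenate(list(row), axis=1) for row in rows], axis=0)
    return strip[np.newaxis]  # [1, 4 * H, N * W, C]

inputs = np.zeros([2, 32, 32, 3], np.float32)
generated = np.ones([2, 32, 32, 3], np.float32)
real = np.full([2, 32, 32, 3], 0.5, np.float32)
print(comparison_strip(inputs, generated, real).shape)  # (1, 128, 64, 3)
```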
- """ - _assert_is_image(gan_model.generator_inputs) - _assert_is_image(gan_model.generated_data) - _assert_is_image(gan_model.real_data) - - gan_model.generated_data.shape.assert_is_compatible_with( - gan_model.generator_inputs.shape) - gan_model.real_data.shape.assert_is_compatible_with( - gan_model.generated_data.shape) - - image_list = [] - image_list.extend( - array_ops.unstack(gan_model.generator_inputs[:num_comparisons])) - image_list.extend( - array_ops.unstack(gan_model.generated_data[:num_comparisons])) - image_list.extend(array_ops.unstack(gan_model.real_data[:num_comparisons])) - if display_diffs: - generated_list = array_ops.unstack( - gan_model.generated_data[:num_comparisons]) - real_list = array_ops.unstack(gan_model.real_data[:num_comparisons]) - diffs = [ - math_ops.abs(math_ops.cast(generated, dtypes.float32) - - math_ops.cast(real, dtypes.float32)) - for generated, real in zip(generated_list, real_list) - ] - image_list.extend(diffs) - - # Reshape image and display. - summary.image( - 'image_comparison', - eval_utils.image_reshaper(image_list, num_cols=num_comparisons), - max_outputs=1) - - -def add_stargan_image_summaries(stargan_model, - num_images=2, - display_diffs=False): - """Adds image summaries to see StarGAN image results. - - If display_diffs is True, each image result has `2` rows and `num_domains + 1` - columns. - The first row looks like: - [original_image, transformed_to_domain_0, transformed_to_domain_1, ...] - The second row looks like: - [no_modification_baseline, transformed_to_domain_0-original_image, ...] - If display_diffs is False, only the first row is shown. - - IMPORTANT: - Since the model originally does not transformed the image to every domains, - we will transform them on-the-fly within this function in parallel. - - Args: - stargan_model: A StarGANModel tuple. - num_images: The number of examples/images to be transformed and shown. - display_diffs: Also display the difference between generated and target. - - Raises: - ValueError: If input_data is not images. - ValueError: If input_data_domain_label is not rank 2. - ValueError: If dimension 2 of input_data_domain_label is not fully defined. - """ - - _assert_is_image(stargan_model.input_data) - stargan_model.input_data_domain_label.shape.assert_has_rank(2) - stargan_model.input_data_domain_label.shape[1:].assert_is_fully_defined() - - num_domains = stargan_model.input_data_domain_label.get_shape().as_list()[-1] - - def _build_image(image): - """Helper function to create a result for each image on the fly.""" - - # Expand the first dimension as batch_size = 1. - images = array_ops.expand_dims(image, axis=0) - - # Tile the image num_domains times, so we can get all transformed together. - images = array_ops.tile(images, [num_domains, 1, 1, 1]) - - # Create the targets to 0, 1, 2, ..., num_domains-1. - targets = array_ops.one_hot(list(range(num_domains)), num_domains) - - with variable_scope.variable_scope( - stargan_model.generator_scope, reuse=True): - - # Add the original image. - output_images_list = [image] - - # Generate the image and add to the list. - gen_images = stargan_model.generator_fn(images, targets) - gen_images_list = array_ops.split(gen_images, num_domains) - gen_images_list = [ - array_ops.squeeze(img, axis=0) for img in gen_images_list - ] - output_images_list.extend(gen_images_list) - - # Display diffs. 
- if display_diffs: - diff_images = gen_images - images - diff_images_list = array_ops.split(diff_images, num_domains) - diff_images_list = [ - array_ops.squeeze(img, axis=0) for img in diff_images_list - ] - output_images_list.append(array_ops.zeros_like(image)) - output_images_list.extend(diff_images_list) - - # Create the final image. - final_image = eval_utils.image_reshaper( - output_images_list, num_cols=num_domains + 1) - - # Reduce the first rank. - return array_ops.squeeze(final_image, axis=0) - - summary.image( - 'stargan_image_generation', - map_fn.map_fn( - _build_image, - stargan_model.input_data[:num_images], - parallel_iterations=num_images, - back_prop=False, - swap_memory=True), - max_outputs=num_images) - - -def add_gan_model_summaries(gan_model): - """Adds typical GANModel summaries. - - Args: - gan_model: A GANModel tuple. - """ - if isinstance(gan_model, namedtuples.CycleGANModel): - with ops.name_scope('cyclegan_x2y_summaries'): - add_gan_model_summaries(gan_model.model_x2y) - with ops.name_scope('cyclegan_y2x_summaries'): - add_gan_model_summaries(gan_model.model_y2x) - return - - with ops.name_scope('generator_variables'): - for var in gan_model.generator_variables: - summary.histogram(var.name, var) - with ops.name_scope('discriminator_variables'): - for var in gan_model.discriminator_variables: - summary.histogram(var.name, var) - - -def add_regularization_loss_summaries(gan_model): - """Adds summaries for a regularization losses.. - - Args: - gan_model: A GANModel tuple. - """ - if isinstance(gan_model, namedtuples.CycleGANModel): - with ops.name_scope('cyclegan_x2y_regularization_loss_summaries'): - add_regularization_loss_summaries(gan_model.model_x2y) - with ops.name_scope('cyclegan_y2x_regularization_loss_summaries'): - add_regularization_loss_summaries(gan_model.model_y2x) - return - - if gan_model.generator_scope: - summary.scalar( - 'generator_regularization_loss', - loss_util.get_regularization_loss(gan_model.generator_scope.name)) - if gan_model.discriminator_scope: - summary.scalar( - 'discriminator_regularization_loss', - loss_util.get_regularization_loss(gan_model.discriminator_scope.name)) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py deleted file mode 100644 index 53fc7cb8ede..00000000000 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for TF-GAN summaries.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python import namedtuples -from tensorflow.contrib.gan.python.eval.python import summaries_impl as summaries -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.summary import summary - - -def generator_model(inputs): - return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs - - -def discriminator_model(inputs, _): - return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs - - -def stargan_generator_model(inputs, _): - return generator_model(inputs) - - -def get_gan_model(): - # TODO(joelshor): Find a better way of creating a variable scope. - with variable_scope.variable_scope('generator') as gen_scope: - pass - with variable_scope.variable_scope('discriminator') as dis_scope: - pass - return namedtuples.GANModel( - generator_inputs=array_ops.zeros([4, 32, 32, 3]), - generated_data=array_ops.zeros([4, 32, 32, 3]), - generator_variables=[variables.Variable(0), variables.Variable(1)], - generator_scope=gen_scope, - generator_fn=generator_model, - real_data=array_ops.ones([4, 32, 32, 3]), - discriminator_real_outputs=array_ops.ones([1, 2, 3]), - discriminator_gen_outputs=array_ops.ones([1, 2, 3]), - discriminator_variables=[variables.Variable(0)], - discriminator_scope=dis_scope, - discriminator_fn=discriminator_model) - - -def get_stargan_model(): - """Similar to get_gan_model().""" - # TODO(joelshor): Find a better way of creating a variable scope. 
- with variable_scope.variable_scope('discriminator') as dis_scope: - pass - with variable_scope.variable_scope('generator') as gen_scope: - return namedtuples.StarGANModel( - input_data=array_ops.ones([1, 2, 2, 3]), - input_data_domain_label=array_ops.ones([1, 2]), - generated_data=stargan_generator_model( - array_ops.ones([1, 2, 2, 3]), None), - generated_data_domain_target=array_ops.ones([1, 2]), - reconstructed_data=array_ops.ones([1, 2, 2, 3]), - discriminator_input_data_source_predication=array_ops.ones([1]), - discriminator_generated_data_source_predication=array_ops.ones([1]), - discriminator_input_data_domain_predication=array_ops.ones([1, 2]), - discriminator_generated_data_domain_predication=array_ops.ones([1, 2]), - generator_variables=None, - generator_scope=gen_scope, - generator_fn=stargan_generator_model, - discriminator_variables=None, - discriminator_scope=dis_scope, - discriminator_fn=discriminator_model) - - -def get_cyclegan_model(): - with variable_scope.variable_scope('x2y'): - model_x2y = get_gan_model() - with variable_scope.variable_scope('y2x'): - model_y2x = get_gan_model() - return namedtuples.CycleGANModel( - model_x2y=model_x2y, - model_y2x=model_y2x, - reconstructed_x=array_ops.zeros([4, 32, 32, 3]), - reconstructed_y=array_ops.zeros([4, 32, 32, 3])) - - -class SummariesTest(test.TestCase): - - def _test_add_gan_model_image_summaries_impl( - self, get_model_fn, expected_num_summary_ops, model_summaries): - summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2, - model_summaries=model_summaries) - - self.assertEquals(expected_num_summary_ops, - len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - with self.test_session(use_gpu=True): - variables.global_variables_initializer().run() - summary.merge_all().eval() - - def test_add_gan_model_image_summaries(self): - self._test_add_gan_model_image_summaries_impl(get_gan_model, 5, True) - - def test_add_gan_model_image_summaries_no_model(self): - self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False) - - def test_cyclegan_image_summaries_dont_work(self): - with self.assertRaises(ValueError): - summaries.add_gan_model_image_summaries(get_cyclegan_model()) - - def _test_add_gan_model_summaries_impl(self, get_model_fn, - expected_num_summary_ops): - summaries.add_gan_model_summaries(get_model_fn()) - - self.assertEquals(expected_num_summary_ops, - len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - with self.test_session(use_gpu=True): - variables.global_variables_initializer().run() - summary.merge_all().eval() - - def test_add_gan_model_summaries(self): - self._test_add_gan_model_summaries_impl(get_gan_model, 3) - - def test_add_gan_model_summaries_for_cyclegan(self): - self._test_add_gan_model_summaries_impl(get_cyclegan_model, 6) - - def _test_add_regularization_loss_summaries_impl(self, get_model_fn, - expected_num_summary_ops): - summaries.add_regularization_loss_summaries(get_model_fn()) - - self.assertEquals(expected_num_summary_ops, - len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - with self.test_session(use_gpu=True): - summary.merge_all().eval() - - def test_add_regularization_loss_summaries(self): - self._test_add_regularization_loss_summaries_impl(get_gan_model, 2) - - def test_add_regularization_loss_summaries_for_cyclegan(self): - self._test_add_regularization_loss_summaries_impl(get_cyclegan_model, 4) - - # TODO(joelshor): Add correctness test. 
- def _test_add_image_comparison_summaries_impl(self, get_model_fn, - expected_num_summary_ops): - summaries.add_image_comparison_summaries(get_model_fn(), display_diffs=True) - - self.assertEquals(expected_num_summary_ops, - len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - with self.test_session(use_gpu=True): - summary.merge_all().eval() - - def test_add_image_comparison_summaries(self): - self._test_add_image_comparison_summaries_impl(get_gan_model, 1) - - def test_add_image_comparison_summaries_for_cyclegan(self): - summaries.add_cyclegan_image_summaries(get_cyclegan_model()) - - self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - with self.test_session(use_gpu=True): - summary.merge_all().eval() - - def test_add_image_comparison_summaries_for_stargan(self): - - summaries.add_stargan_image_summaries(get_stargan_model()) - - self.assertEquals(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - - with self.test_session(use_gpu=True) as sess: - sess.run(variables.global_variables_initializer()) - summary.merge_all().eval() - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py deleted file mode 100644 index 410c3a02052..00000000000 --- a/tensorflow/contrib/gan/python/features/__init__.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2017 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TFGAN features module. - -This module includes support for virtual batch normalization, buffer replay, -conditioning, etc. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Collapse features into a single namespace. 
-# pylint: disable=unused-import,wildcard-import -from tensorflow.contrib.gan.python.features.python import clip_weights -from tensorflow.contrib.gan.python.features.python import conditioning_utils -from tensorflow.contrib.gan.python.features.python import random_tensor_pool -from tensorflow.contrib.gan.python.features.python import spectral_normalization -from tensorflow.contrib.gan.python.features.python import virtual_batchnorm - -from tensorflow.contrib.gan.python.features.python.clip_weights import * -from tensorflow.contrib.gan.python.features.python.conditioning_utils import * -from tensorflow.contrib.gan.python.features.python.random_tensor_pool import * -from tensorflow.contrib.gan.python.features.python.spectral_normalization import * -from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import * -# pylint: enable=unused-import,wildcard-import - -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = clip_weights.__all__ -_allowed_symbols += conditioning_utils.__all__ -_allowed_symbols += random_tensor_pool.__all__ -_allowed_symbols += spectral_normalization.__all__ -_allowed_symbols += virtual_batchnorm.__all__ -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights.py b/tensorflow/contrib/gan/python/features/python/clip_weights.py deleted file mode 100644 index fa76fd7928f..00000000000 --- a/tensorflow/contrib/gan/python/features/python/clip_weights.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities to clip weights.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.features.python import clip_weights_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.features.python.clip_weights_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = clip_weights_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_impl.py b/tensorflow/contrib/gan/python/features/python/clip_weights_impl.py deleted file mode 100644 index 96fbb8186d7..00000000000 --- a/tensorflow/contrib/gan/python/features/python/clip_weights_impl.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Utilities to clip weights. - -This is useful in the original formulation of the Wasserstein loss, which -requires that the discriminator be K-Lipschitz. See -https://arxiv.org/pdf/1701.07875 for more details. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.opt.python.training import variable_clipping_optimizer - - -__all__ = [ - 'clip_variables', - 'clip_discriminator_weights', -] - - -def clip_discriminator_weights(optimizer, model, weight_clip): - """Modifies an optimizer so it clips weights to a certain value. - - Args: - optimizer: An optimizer to perform variable weight clipping. - model: A GANModel namedtuple. - weight_clip: Positive python float to clip discriminator weights. Used to - enforce a K-lipschitz condition, which is useful for some GAN training - schemes (ex WGAN: https://arxiv.org/pdf/1701.07875). - - Returns: - An optimizer to perform weight clipping after updates. - - Raises: - ValueError: If `weight_clip` is less than 0. - """ - return clip_variables(optimizer, model.discriminator_variables, weight_clip) - - -def clip_variables(optimizer, variables, weight_clip): - """Modifies an optimizer so it clips weights to a certain value. - - Args: - optimizer: An optimizer to perform variable weight clipping. - variables: A list of TensorFlow variables. - weight_clip: Positive python float to clip discriminator weights. Used to - enforce a K-lipschitz condition, which is useful for some GAN training - schemes (ex WGAN: https://arxiv.org/pdf/1701.07875). - - Returns: - An optimizer to perform weight clipping after updates. - - Raises: - ValueError: If `weight_clip` is less than 0. - """ - if weight_clip < 0: - raise ValueError( - '`discriminator_weight_clip` must be positive. Instead, was %s', - weight_clip) - return variable_clipping_optimizer.VariableClippingOptimizer( - opt=optimizer, - # Do no reduction, so clipping happens per-value. - vars_to_clip_dims={var: [] for var in variables}, - max_norm=weight_clip, - use_locking=True, - colocate_clip_ops_with_vars=True) diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py deleted file mode 100644 index e4fac1976d6..00000000000 --- a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
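# A minimal stand-alone sketch of the weight-clipping idea wrapped by
# clip_variables above, assuming TF 1.x graph mode. `discriminator_vars` is a
# hypothetical list of discriminator variables; the contrib implementation
# uses VariableClippingOptimizer instead, but the effect is the same: keep
# every discriminator weight inside [-weight_clip, weight_clip] after each
# update, as the original WGAN formulation requires.
import tensorflow as tf  # assumed TF 1.x

def clipped_minimize(loss, discriminator_vars, learning_rate=1e-4,
                     weight_clip=0.01):
  """Runs one optimizer step, then clips each variable to [-c, c]."""
  if weight_clip <= 0:
    raise ValueError('`weight_clip` must be positive.')
  opt = tf.train.GradientDescentOptimizer(learning_rate)
  update_op = opt.minimize(loss, var_list=discriminator_vars)
  with tf.control_dependencies([update_op]):
    clip_ops = [
        var.assign(tf.clip_by_value(var, -weight_clip, weight_clip))
        for var in discriminator_vars
    ]
  return tf.group(*clip_ops)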
-# ============================================================================== -"""Tests for features.clip_weights.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -from tensorflow.contrib.gan.python.features.python import clip_weights_impl as clip_weights - -from tensorflow.python.ops import variables -from tensorflow.python.platform import test -from tensorflow.python.training import training - - -class ClipWeightsTest(test.TestCase): - """Tests for `discriminator_weight_clip`.""" - - def setUp(self): - super(ClipWeightsTest, self).setUp() - self.variables = [variables.Variable(2.0)] - self.tuple = collections.namedtuple( - 'VarTuple', ['discriminator_variables'])(self.variables) - - def _test_weight_clipping_helper(self, use_tuple): - loss = self.variables[0] - opt = training.GradientDescentOptimizer(1.0) - if use_tuple: - opt_clip = clip_weights.clip_variables(opt, self.variables, 0.1) - else: - opt_clip = clip_weights.clip_discriminator_weights(opt, self.tuple, 0.1) - - train_op1 = opt.minimize(loss, var_list=self.variables) - train_op2 = opt_clip.minimize(loss, var_list=self.variables) - - with self.cached_session(use_gpu=True) as sess: - sess.run(variables.global_variables_initializer()) - self.assertEqual(2.0, self.variables[0].eval()) - sess.run(train_op1) - self.assertLess(0.1, self.variables[0].eval()) - - with self.cached_session(use_gpu=True) as sess: - sess.run(variables.global_variables_initializer()) - self.assertEqual(2.0, self.variables[0].eval()) - sess.run(train_op2) - self.assertNear(0.1, self.variables[0].eval(), 1e-7) - - def test_weight_clipping_argsonly(self): - self._test_weight_clipping_helper(False) - - def test_weight_clipping_ganmodel(self): - self._test_weight_clipping_helper(True) - - def _test_incorrect_weight_clip_value_helper(self, use_tuple): - opt = training.GradientDescentOptimizer(1.0) - - if use_tuple: - with self.assertRaisesRegexp(ValueError, 'must be positive'): - clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1) - else: - with self.assertRaisesRegexp(ValueError, 'must be positive'): - clip_weights.clip_variables(opt, self.variables, weight_clip=-1) - - def test_incorrect_weight_clip_value_argsonly(self): - self._test_incorrect_weight_clip_value_helper(False) - - def test_incorrect_weight_clip_value_tuple(self): - self._test_incorrect_weight_clip_value_helper(True) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py deleted file mode 100644 index a9b8faa7126..00000000000 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Miscellaneous utilities for TFGAN code and examples.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.features.python import conditioning_utils_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.features.python.conditioning_utils_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = conditioning_utils_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py deleted file mode 100644 index 364fa4eb461..00000000000 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Miscellaneous utilities for TFGAN code and examples. - -Includes: -1) Conditioning the value of a Tensor, based on techniques from - https://arxiv.org/abs/1609.03499. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.layers.python.layers import layers -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import embedding_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope - - -__all__ = [ - 'condition_tensor', - 'condition_tensor_from_onehot', -] - - -def _get_shape(tensor): - tensor_shape = array_ops.shape(tensor) - static_tensor_shape = tensor_util.constant_value(tensor_shape) - return (static_tensor_shape if static_tensor_shape is not None else - tensor_shape) - - -def condition_tensor(tensor, conditioning): - """Condition the value of a tensor. - - Conditioning scheme based on https://arxiv.org/abs/1609.03499. - - Args: - tensor: A minibatch tensor to be conditioned. - conditioning: A minibatch Tensor of to condition on. Must be 2D, with first - dimension the same as `tensor`. - - Returns: - `tensor` conditioned on `conditioning`. - - Raises: - ValueError: If the non-batch dimensions of `tensor` aren't fully defined. - ValueError: If `conditioning` isn't at least 2D. - ValueError: If the batch dimension for the input Tensors don't match. 
- """ - tensor.shape[1:].assert_is_fully_defined() - num_features = tensor.shape[1:].num_elements() - if conditioning.shape.ndims < 2: - raise ValueError('conditioning must be at least 2D, but saw shape: %s' - % conditioning.shape) - - mapped_conditioning = layers.linear( - layers.flatten(conditioning), num_features) - if not mapped_conditioning.shape.is_compatible_with(tensor.shape): - mapped_conditioning = array_ops.reshape( - mapped_conditioning, _get_shape(tensor)) - return tensor + mapped_conditioning - - -def _one_hot_to_embedding(one_hot, embedding_size): - """Get a dense embedding vector from a one-hot encoding.""" - num_tokens = one_hot.shape[1] - label_id = math_ops.argmax(one_hot, axis=1) - embedding = variable_scope.get_variable( - 'embedding', [num_tokens, embedding_size]) - return embedding_ops.embedding_lookup( - embedding, label_id, name='token_to_embedding') - - -def _validate_onehot(one_hot_labels): - one_hot_labels.shape.assert_has_rank(2) - one_hot_labels.shape[1:].assert_is_fully_defined() - - -def condition_tensor_from_onehot(tensor, one_hot_labels, embedding_size=256): - """Condition a tensor based on a one-hot tensor. - - Conditioning scheme based on https://arxiv.org/abs/1609.03499. - - Args: - tensor: Tensor to be conditioned. - one_hot_labels: A Tensor of one-hot labels. Shape is - [batch_size, num_classes]. - embedding_size: The size of the class embedding. - - Returns: - `tensor` conditioned on `one_hot_labels`. - - Raises: - ValueError: `one_hot_labels` isn't 2D, if non-batch dimensions aren't - fully defined, or if batch sizes don't match. - """ - _validate_onehot(one_hot_labels) - - conditioning = _one_hot_to_embedding(one_hot_labels, embedding_size) - return condition_tensor(tensor, conditioning) diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py deleted file mode 100644 index f5c7d53cf2c..00000000000 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for tfgan.python.features.conditioning_utils.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.features.python import conditioning_utils_impl as conditioning_utils - -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test - - -class ConditioningUtilsTest(test.TestCase): - - def test_condition_tensor_multiple_shapes(self): - for tensor_shape in [(4, 1), (4, 2), (4, 2, 6), (None, 5, 3)]: - for conditioning_shape in [(4, 1), (4, 8), (4, 5, 3)]: - conditioning_utils.condition_tensor( - array_ops.placeholder(dtypes.float32, tensor_shape), - array_ops.placeholder(dtypes.float32, conditioning_shape)) - - def test_condition_tensor_asserts(self): - with self.assertRaisesRegexp(ValueError, 'Cannot reshape'): - conditioning_utils.condition_tensor( - array_ops.placeholder(dtypes.float32, (4, 1)), - array_ops.placeholder(dtypes.float32, (5, 1))) - - with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'): - conditioning_utils.condition_tensor( - array_ops.placeholder(dtypes.float32, (5, None)), - array_ops.placeholder(dtypes.float32, (5, 1))) - - with self.assertRaisesRegexp(ValueError, 'at least 2D'): - conditioning_utils.condition_tensor( - array_ops.placeholder(dtypes.float32, (5, 2)), - array_ops.placeholder(dtypes.float32, (5))) - - def test_condition_tensor_from_onehot(self): - conditioning_utils.condition_tensor_from_onehot( - array_ops.placeholder(dtypes.float32, (5, 4, 1)), - array_ops.placeholder(dtypes.float32, (5, 10))) - - def test_condition_tensor_from_onehot_asserts(self): - with self.assertRaisesRegexp(ValueError, 'Shape .* must have rank 2'): - conditioning_utils.condition_tensor_from_onehot( - array_ops.placeholder(dtypes.float32, (5, 1)), - array_ops.placeholder(dtypes.float32, (5))) - - with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'): - conditioning_utils.condition_tensor_from_onehot( - array_ops.placeholder(dtypes.float32, (5, 1)), - array_ops.placeholder(dtypes.float32, (5, None))) - - with self.assertRaisesRegexp(ValueError, 'Cannot reshape a tensor'): - conditioning_utils.condition_tensor_from_onehot( - array_ops.placeholder(dtypes.float32, (5, 1)), - array_ops.placeholder(dtypes.float32, (4, 6))) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py deleted file mode 100644 index ca904971fa8..00000000000 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""A tensor pool stores values from an input tensor and returns a stored one. - -See the following papers for more details. -1) `Learning from simulated and unsupervised images through adversarial - training` (https://arxiv.org/abs/1612.07828). -2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial - Networks` (https://arxiv.org/abs/1703.10593). -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.features.python import random_tensor_pool_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = random_tensor_pool_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py deleted file mode 100644 index ca2d724b49d..00000000000 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A tensor pool stores values from an input tensor and returns a stored one. - -We use this to keep a history of values created by a generator, such that -a discriminator can randomly be trained on some older samples, not just the -current one. This can help to not let the discriminator get too far ahead of the -generator and also to keep the system from oscillating, if the discriminator -forgets too fast what past samples from the generator looked like. - -See the following papers for more details. -1) `Learning from simulated and unsupervised images through adversarial - training` (https://arxiv.org/abs/1612.07828). -2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial - Networks` (https://arxiv.org/abs/1703.10593). -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.util import nest - -__all__ = [ - 'tensor_pool', -] - - -def _to_list(x): - return [x] if isinstance(x, ops.Tensor) else list(x) - - -def tensor_pool(input_values, - pool_size=50, - pooling_probability=0.5, - name='tensor_pool'): - """Queue storing input values and returning random previously stored ones. 
- - Every time the returned `output_value` is evaluated, `input_value` is - evaluated and its value either directly returned (with - `1-pooling_probability`) or stored in the pool and a random one of the samples - currently in the pool is popped and returned. As long as the pool in not fully - filled, the input_value is always directly returned, as well as stored in the - pool. Note during inference / testing, it may be appropriate to set - `pool_size` = 0 or `pooling_probability` = 0. - - Args: - input_values: An arbitrarily nested structure of `tf.Tensors`, from which to - read values to be pooled. - pool_size: An integer specifying the maximum size of the pool. Defaults to - 50. - pooling_probability: A float `Tensor` specifying the probability of getting - a value from the pool, as opposed to just the current input. - name: A string prefix for the name scope for all tensorflow ops. - - Returns: - A nested structure of `Tensor` objects with the same structure as - `input_values`. With the given probability, the Tensor values are either the - same as in `input_values` or a randomly chosen sample that was previously - inserted in the pool. - - Raises: - ValueError: If `pool_size` is negative. - """ - pool_size = int(pool_size) - if pool_size < 0: - raise ValueError('`pool_size` is negative.') - elif pool_size == 0: - return input_values - - original_input_values = input_values - input_values = nest.flatten(input_values) - - with ops.name_scope('{}_pool_queue'.format(name), - values=input_values + [pooling_probability]): - pool_queue = data_flow_ops.RandomShuffleQueue( - capacity=pool_size, - min_after_dequeue=0, - dtypes=[v.dtype for v in input_values], - shapes=None) - - # In pseudo code this code does the following: - # if not pool_full: - # enqueue(input_values) - # return input_values - # else - # dequeue_values = dequeue_random_sample() - # enqueue(input_values) - # if rand() < pooling_probability: - # return dequeue_values - # else - # return input_values - - def _get_input_value_pooled(): - enqueue_op = pool_queue.enqueue(input_values) - with ops.control_dependencies([enqueue_op]): - return [array_ops.identity(v) for v in input_values] - - def _get_random_pool_value_and_enqueue_input(): - dequeue_values = _to_list(pool_queue.dequeue()) - with ops.control_dependencies(dequeue_values): - enqueue_op = pool_queue.enqueue(input_values) - with ops.control_dependencies([enqueue_op]): - prob = random_ops.random_uniform( - (), dtype=dtypes.float32) < pooling_probability - return control_flow_ops.cond(prob, lambda: dequeue_values, - lambda: input_values) - - output_values = _to_list(control_flow_ops.cond( - pool_queue.size() < pool_size, _get_input_value_pooled, - _get_random_pool_value_and_enqueue_input)) - - # Make sure that the shape of `output_value` is set. - for input_value, output_value in zip(input_values, output_values): - output_value.set_shape(input_value.shape) - - return nest.pack_sequence_as(original_input_values, output_values) diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py deleted file mode 100644 index 3c9dfd6de02..00000000000 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
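# A brief usage sketch for the `tensor_pool` function above, assuming TF 1.x
# with the tf.contrib.gan package still available. `generator_fn` and
# `discriminator_fn` are hypothetical model functions; the point is that the
# discriminator sees a mix of fresh and historical generator samples.
import tensorflow as tf  # assumed TF 1.x

def discriminator_on_pooled_samples(noise, generator_fn, discriminator_fn,
                                    pool_size=50):
  generated = generator_fn(noise)
  # With probability 0.5 (once the pool has filled), the current batch is
  # swapped for one previously stored in the pool; either way it is enqueued.
  pooled = tf.contrib.gan.features.tensor_pool(
      generated, pool_size=pool_size, pooling_probability=0.5)
  return discriminator_fn(pooled)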
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.contrib.gan.python.features.random_tensor_pool.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import tensor_pool -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test - - -class TensorPoolTest(test.TestCase): - - def test_pool_unknown_input_shape(self): - """Checks that `input_value` can have unknown shape.""" - input_value = array_ops.placeholder( - dtype=dtypes.int32, shape=[None, None, 3]) - output_value = tensor_pool(input_value, pool_size=10) - self.assertEqual(output_value.shape.as_list(), [None, None, 3]) - - with self.session(use_gpu=True) as session: - for i in range(10): - session.run(output_value, {input_value: [[[i] * 3]]}) - session.run(output_value, {input_value: [[[i] * 3] * 2]}) - session.run(output_value, {input_value: [[[i] * 3] * 5] * 2}) - - def test_pool_sequence(self): - """Checks that values are pooled and returned maximally twice.""" - input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - output_value = tensor_pool(input_value, pool_size=10) - self.assertEqual(output_value.shape.as_list(), []) - - with self.session(use_gpu=True) as session: - outs = [] - for i in range(50): - out = session.run(output_value, {input_value: i}) - outs.append(out) - self.assertLessEqual(out, i) - - _, counts = np.unique(outs, return_counts=True) - # Check that each value is returned maximally twice. 
- self.assertTrue((counts <= 2).all()) - - def test_never_pool(self): - """Checks that setting `pooling_probability` to zero works.""" - input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - output_value = tensor_pool( - input_value, pool_size=10, pooling_probability=0.0) - self.assertEqual(output_value.shape.as_list(), []) - - with self.session(use_gpu=True) as session: - for i in range(50): - out = session.run(output_value, {input_value: i}) - self.assertEqual(out, i) - - def test_pooling_probability(self): - """Checks that `pooling_probability` works.""" - input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - pool_size = 10 - pooling_probability = 0.2 - output_value = tensor_pool( - input_value, - pool_size=pool_size, - pooling_probability=pooling_probability) - self.assertEqual(output_value.shape.as_list(), []) - - with self.session(use_gpu=True) as session: - not_pooled = 0 - total = 1000 - for i in range(total): - out = session.run(output_value, {input_value: i}) - if out == i: - not_pooled += 1 - self.assertAllClose( - (not_pooled - pool_size) / (total - pool_size), - 1 - pooling_probability, - atol=0.03) - - def test_input_values_tuple(self): - """Checks that `input_values` can be a tuple.""" - input_values = (array_ops.placeholder(dtype=dtypes.int32, shape=[]), - array_ops.placeholder(dtype=dtypes.int32, shape=[])) - output_values = tensor_pool(input_values, pool_size=3) - self.assertEqual(len(output_values), len(input_values)) - for output_value in output_values: - self.assertEqual(output_value.shape.as_list(), []) - - with self.session(use_gpu=True) as session: - for i in range(10): - outs = session.run(output_values, { - input_values[0]: i, - input_values[1]: i + 1 - }) - self.assertEqual(len(outs), len(input_values)) - self.assertEqual(outs[1] - outs[0], 1) - - def test_pool_preserves_shape(self): - t = constant_op.constant(1) - input_values = [[t, t, t], (t, t), t] - output_values = tensor_pool(input_values, pool_size=5) - print('stuff: ', output_values) - # Overall shape. - self.assertIsInstance(output_values, list) - self.assertEqual(3, len(output_values)) - # Shape of first element. - self.assertIsInstance(output_values[0], list) - self.assertEqual(3, len(output_values[0])) - # Shape of second element. - self.assertIsInstance(output_values[1], tuple) - self.assertEqual(2, len(output_values[1])) - # Shape of third element. - self.assertIsInstance(output_values[2], ops.Tensor) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization.py deleted file mode 100644 index 54d3d0a218d..00000000000 --- a/tensorflow/contrib/gan/python/features/python/spectral_normalization.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Keras-like layers and utilities that implement Spectral Normalization. - -Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato, -et al in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT- -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.features.python.spectral_normalization_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = spectral_normalization_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py deleted file mode 100644 index 9004be6229f..00000000000 --- a/tensorflow/contrib/gan/python/features/python/spectral_normalization_impl.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Keras-like layers and utilities that implement Spectral Normalization. - -Based on "Spectral Normalization for Generative Adversarial Networks" by Miyato, -et al in ICLR 2018. https://openreview.net/pdf?id=B1QRgziT- -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import contextlib -import numbers -import re - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.keras.engine import base_layer_utils as keras_base_layer_utils -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import tf_logging as logging - -__all__ = [ - 'compute_spectral_norm', 'spectral_normalize', 'spectral_norm_regularizer', - 'spectral_normalization_custom_getter', 'keras_spectral_normalization' -] - -# tf.bfloat16 should work, but tf.matmul converts those to tf.float32 which then -# can't directly be assigned back to the tf.bfloat16 variable. -_OK_DTYPES_FOR_SPECTRAL_NORM = (dtypes.float16, dtypes.float32, dtypes.float64) -_PERSISTED_U_VARIABLE_SUFFIX = 'spectral_norm_u' - - -def compute_spectral_norm(w_tensor, power_iteration_rounds=1, name=None): - """Estimates the largest singular value in the weight tensor. - - Args: - w_tensor: The weight matrix whose spectral norm should be computed. - power_iteration_rounds: The number of iterations of the power method to - perform. A higher number yields a better approximation. - name: An optional scope name. 
- - Returns: - The largest singular value (the spectral norm) of w. - """ - with variable_scope.variable_scope(name, 'spectral_norm'): - # The paper says to flatten convnet kernel weights from - # (C_out, C_in, KH, KW) to (C_out, C_in * KH * KW). But TensorFlow's Conv2D - # kernel weight shape is (KH, KW, C_in, C_out), so it should be reshaped to - # (KH * KW * C_in, C_out), and similarly for other layers that put output - # channels as last dimension. - # n.b. this means that w here is equivalent to w.T in the paper. - w = array_ops.reshape(w_tensor, (-1, w_tensor.get_shape()[-1])) - - # Persisted approximation of first left singular vector of matrix `w`. - u_var = variable_scope.get_variable( - _PERSISTED_U_VARIABLE_SUFFIX, - shape=(w.shape[0], 1), - dtype=w.dtype, - initializer=init_ops.random_normal_initializer(), - trainable=False) - u = u_var - - # Use power iteration method to approximate spectral norm. - for _ in range(power_iteration_rounds): - # `v` approximates the first right singular vector of matrix `w`. - v = nn.l2_normalize(math_ops.matmul(array_ops.transpose(w), u)) - u = nn.l2_normalize(math_ops.matmul(w, v)) - - # Update persisted approximation. - with ops.control_dependencies([u_var.assign(u, name='update_u')]): - u = array_ops.identity(u) - - u = array_ops.stop_gradient(u) - v = array_ops.stop_gradient(v) - - # Largest singular value of `w`. - spectral_norm = math_ops.matmul( - math_ops.matmul(array_ops.transpose(u), w), v) - spectral_norm.shape.assert_is_fully_defined() - spectral_norm.shape.assert_is_compatible_with([1, 1]) - - return spectral_norm[0][0] - - -def spectral_normalize(w, power_iteration_rounds=1, name=None): - """Normalizes a weight matrix by its spectral norm. - - Args: - w: The weight matrix to be normalized. - power_iteration_rounds: The number of iterations of the power method to - perform. A higher number yields a better approximation. - name: An optional scope name. - - Returns: - A normalized weight matrix tensor. - """ - with variable_scope.variable_scope(name, 'spectral_normalize'): - w_normalized = w / compute_spectral_norm( - w, power_iteration_rounds=power_iteration_rounds) - return array_ops.reshape(w_normalized, w.get_shape()) - - -def spectral_norm_regularizer(scale, power_iteration_rounds=1, scope=None): - """Returns a functions that can be used to apply spectral norm regularization. - - Small spectral norms enforce a small Lipschitz constant, which is necessary - for Wasserstein GANs. - - Args: - scale: A scalar multiplier. 0.0 disables the regularizer. - power_iteration_rounds: The number of iterations of the power method to - perform. A higher number yields a better approximation. - scope: An optional scope name. - - Returns: - A function with the signature `sn(weights)` that applies spectral norm - regularization. - - Raises: - ValueError: If scale is negative or if scale is not a float. 
- """ - if isinstance(scale, numbers.Integral): - raise ValueError('scale cannot be an integer: %s' % scale) - if isinstance(scale, numbers.Real): - if scale < 0.0: - raise ValueError( - 'Setting a scale less than 0 on a regularizer: %g' % scale) - if scale == 0.0: - logging.info('Scale of 0 disables regularizer.') - return lambda _: None - - def sn(weights, name=None): - """Applies spectral norm regularization to weights.""" - with ops.name_scope(scope, 'SpectralNormRegularizer', [weights]) as name: - scale_t = ops.convert_to_tensor( - scale, dtype=weights.dtype.base_dtype, name='scale') - return math_ops.multiply( - scale_t, - compute_spectral_norm( - weights, power_iteration_rounds=power_iteration_rounds), - name=name) - - return sn - - -def _default_name_filter(name): - """A filter function to identify common names of weight variables. - - Args: - name: The variable name. - - Returns: - Whether `name` is a standard name for a weight/kernel variables used in the - Keras, tf.layers, tf.contrib.layers or tf.contrib.slim libraries. - """ - match = re.match(r'(.*\/)?(depthwise_|pointwise_)?(weights|kernel)$', name) - return match is not None - - -def spectral_normalization_custom_getter(name_filter=_default_name_filter, - power_iteration_rounds=1): - """Custom getter that performs Spectral Normalization on a weight tensor. - - Specifically it divides the weight tensor by its largest singular value. This - is intended to stabilize GAN training, by making the discriminator satisfy a - local 1-Lipschitz constraint. - - Based on [Spectral Normalization for Generative Adversarial Networks][sn-gan]. - - [sn-gan]: https://openreview.net/forum?id=B1QRgziT- - - To reproduce an SN-GAN, apply this custom_getter to every weight tensor of - your discriminator. The last dimension of the weight tensor must be the number - of output channels. - - Apply this to layers by supplying this as the `custom_getter` of a - `tf.compat.v1.variable_scope`. For example: - - with tf.compat.v1.variable_scope('discriminator', - custom_getter=spectral_norm_getter()): - net = discriminator_fn(net) - - IMPORTANT: Keras does not respect the custom_getter supplied by the - VariableScope, so Keras users should use `keras_spectral_normalization` - instead of (or in addition to) this approach. - - It is important to carefully select to which weights you want to apply - Spectral Normalization. In general you want to normalize the kernels of - convolution and dense layers, but you do not want to normalize biases. You - also want to avoid normalizing batch normalization (and similar) variables, - but in general such layers play poorly with Spectral Normalization, since the - gamma can cancel out the normalization in other layers. By default we supply a - filter that matches the kernel variable names of the dense and convolution - layers of the tf.layers, tf.contrib.layers, tf.keras and tf.contrib.slim - libraries. If you are using anything else you'll need a custom `name_filter`. - - This custom getter internally creates a variable used to compute the spectral - norm by power iteration. It will update every time the variable is accessed, - which means the normalized discriminator weights may change slightly whilst - training the generator. Whilst unusual, this matches how the paper's authors - implement it, and in general additional rounds of power iteration can't hurt. - - Args: - name_filter: Optionally, a method that takes a Variable name as input and - returns whether this Variable should be normalized. 
- power_iteration_rounds: The number of iterations of the power method to - perform per step. A higher number yields a better approximation of the - true spectral norm. - - Returns: - A custom getter function that applies Spectral Normalization to all - Variables whose names match `name_filter`. - - Raises: - ValueError: If name_filter is not callable. - """ - if not callable(name_filter): - raise ValueError('name_filter must be callable') - - def _internal_getter(getter, name, *args, **kwargs): - """A custom getter function that applies Spectral Normalization. - - Args: - getter: The true getter to call. - name: Name of new/existing variable, in the same format as - tf.get_variable. - *args: Other positional arguments, in the same format as tf.get_variable. - **kwargs: Keyword arguments, in the same format as tf.get_variable. - - Returns: - The return value of `getter(name, *args, **kwargs)`, spectrally - normalized. - - Raises: - ValueError: If used incorrectly, or if `dtype` is not supported. - """ - if not name_filter(name): - return getter(name, *args, **kwargs) - - if name.endswith(_PERSISTED_U_VARIABLE_SUFFIX): - raise ValueError( - 'Cannot apply Spectral Normalization to internal variables created ' - 'for Spectral Normalization. Tried to normalized variable [%s]' % - name) - - if kwargs['dtype'] not in _OK_DTYPES_FOR_SPECTRAL_NORM: - raise ValueError('Disallowed data type {}'.format(kwargs['dtype'])) - - # This layer's weight Variable/PartitionedVariable. - w_tensor = getter(name, *args, **kwargs) - - if len(w_tensor.get_shape()) < 2: - raise ValueError( - 'Spectral norm can only be applied to multi-dimensional tensors') - - return spectral_normalize( - w_tensor, - power_iteration_rounds=power_iteration_rounds, - name=(name + '/spectral_normalize')) - - return _internal_getter - - -@contextlib.contextmanager -def keras_spectral_normalization(name_filter=_default_name_filter, - power_iteration_rounds=1): - """A context manager that enables Spectral Normalization for Keras. - - Keras doesn't respect the `custom_getter` in the VariableScope, so this is a - bit of a hack to make things work. - - Usage: - with keras_spectral_normalization(): - net = discriminator_fn(net) - - Args: - name_filter: Optionally, a method that takes a Variable name as input and - returns whether this Variable should be normalized. - power_iteration_rounds: The number of iterations of the power method to - perform per step. A higher number yields a better approximation of the - true spectral norm. - - Yields: - A context manager that wraps the standard Keras variable creation method - with the `spectral_normalization_custom_getter`. - """ - original_make_variable = keras_base_layer_utils.make_variable - sn_getter = spectral_normalization_custom_getter( - name_filter=name_filter, power_iteration_rounds=power_iteration_rounds) - - def make_variable_wrapper(name, *args, **kwargs): - return sn_getter(original_make_variable, name, *args, **kwargs) - - keras_base_layer_utils.make_variable = make_variable_wrapper - - yield - - keras_base_layer_utils.make_variable = original_make_variable diff --git a/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py b/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py deleted file mode 100644 index 4ea21f70ec0..00000000000 --- a/tensorflow/contrib/gan/python/features/python/spectral_normalization_test.py +++ /dev/null @@ -1,354 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
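# A minimal NumPy sketch of the power-iteration estimate used by
# compute_spectral_norm above; `w` is a 2-D weight matrix reshaped so the
# output channels are the last dimension. This illustrates the algorithm
# only, not the TF implementation or its persisted `u` variable.
import numpy as np

def estimate_spectral_norm(w, power_iteration_rounds=1):
  # u approximates the first left singular vector of w; each round of power
  # iteration refines u and its right counterpart v.
  u = np.random.normal(size=(w.shape[0], 1))
  for _ in range(power_iteration_rounds):
    v = w.T.dot(u)
    v /= np.linalg.norm(v)
    u = w.dot(v)
    u /= np.linalg.norm(u)
  # The largest singular value of w is approximately u^T w v.
  return (u.T.dot(w).dot(v)).item()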
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for features.spectral_normalization.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib import slim -from tensorflow.contrib.gan.python.features.python import spectral_normalization_impl as spectral_normalization -from tensorflow.contrib.layers.python.layers import layers as contrib_layers -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.keras.layers import convolutional as keras_convolutional -from tensorflow.python.keras.layers import core as keras_core -from tensorflow.python.layers import convolutional as layers_convolutional -from tensorflow.python.layers import core as layers_core -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class SpectralNormalizationTest(test.TestCase): - - def testComputeSpectralNorm(self): - weights = variable_scope.get_variable( - 'w', dtype=dtypes.float32, shape=[2, 3, 50, 100]) - weights = math_ops.multiply(weights, 10.0) - s = linalg_ops.svd( - array_ops.reshape(weights, [-1, weights.shape[-1]]), compute_uv=False) - true_sn = s[..., 0] - estimated_sn = spectral_normalization.compute_spectral_norm(weights) - - with self.cached_session() as sess: - sess.run(variables.global_variables_initializer()) - np_true_sn = sess.run(true_sn) - for i in range(50): - est = sess.run(estimated_sn) - if i < 1: - np_est_1 = est - if i < 4: - np_est_5 = est - if i < 9: - np_est_10 = est - np_est_50 = est - - # Check that the estimate improves with more iterations. 
- self.assertAlmostEqual(np_true_sn, np_est_50, 0) - self.assertGreater( - abs(np_true_sn - np_est_10), abs(np_true_sn - np_est_50)) - self.assertGreater( - abs(np_true_sn - np_est_5), abs(np_true_sn - np_est_10)) - self.assertGreater(abs(np_true_sn - np_est_1), abs(np_true_sn - np_est_5)) - - def testSpectralNormalize(self): - weights = variable_scope.get_variable( - 'w', dtype=dtypes.float32, shape=[2, 3, 50, 100]) - weights = math_ops.multiply(weights, 10.0) - normalized_weights = spectral_normalization.spectral_normalize( - weights, power_iteration_rounds=1) - - unnormalized_sigma = linalg_ops.svd( - array_ops.reshape(weights, [-1, weights.shape[-1]]), - compute_uv=False)[..., 0] - normalized_sigma = linalg_ops.svd( - array_ops.reshape(normalized_weights, [-1, weights.shape[-1]]), - compute_uv=False)[..., 0] - - with self.cached_session() as sess: - sess.run(variables.global_variables_initializer()) - s0 = sess.run(unnormalized_sigma) - - for i in range(50): - sigma = sess.run(normalized_sigma) - if i < 1: - s1 = sigma - if i < 5: - s5 = sigma - if i < 10: - s10 = sigma - s50 = sigma - - self.assertAlmostEqual(1., s50, 0) - self.assertGreater(abs(s10 - 1.), abs(s50 - 1.)) - self.assertGreater(abs(s5 - 1.), abs(s10 - 1.)) - self.assertGreater(abs(s1 - 1.), abs(s5 - 1.)) - self.assertGreater(abs(s0 - 1.), abs(s1 - 1.)) - - def _testLayerHelper(self, build_layer_fn, w_shape, b_shape, is_keras=False): - x = array_ops.placeholder(dtypes.float32, shape=[2, 10, 10, 3]) - - w_initial = np.random.randn(*w_shape) * 10 - w_initializer = init_ops.constant_initializer(w_initial) - b_initial = np.random.randn(*b_shape) - b_initializer = init_ops.constant_initializer(b_initial) - - if is_keras: - context_manager = spectral_normalization.keras_spectral_normalization() - else: - getter = spectral_normalization.spectral_normalization_custom_getter() - context_manager = variable_scope.variable_scope('', custom_getter=getter) - - with context_manager: - (net, - expected_normalized_vars, expected_not_normalized_vars) = build_layer_fn( - x, w_initializer, b_initializer) - - x_data = np.random.rand(*x.shape) - - with self.cached_session() as sess: - sess.run(variables.global_variables_initializer()) - - # Before running a forward pass we still expect the variables values to - # differ from the initial value because of the normalizer. - w_befores = [] - for name, var in expected_normalized_vars.items(): - w_before = sess.run(var) - w_befores.append(w_before) - self.assertFalse( - np.allclose(w_initial, w_before), - msg=('%s appears not to be normalized. Before: %s After: %s' % - (name, w_initial, w_before))) - - # Not true for the unnormalized variables. - for name, var in expected_not_normalized_vars.items(): - b_before = sess.run(var) - self.assertTrue( - np.allclose(b_initial, b_before), - msg=('%s appears to be unexpectedly normalized. ' - 'Before: %s After: %s' % (name, b_initial, b_before))) - - # Run a bunch of forward passes. - for _ in range(1000): - _ = sess.run(net, feed_dict={x: x_data}) - - # We expect this to have improved the estimate of the spectral norm, - # which should have changed the variable values and brought them close - # to the true Spectral Normalized values. - _, s, _ = np.linalg.svd(w_initial.reshape([-1, 3])) - exactly_normalized = w_initial / s[0] - for w_before, (name, var) in zip(w_befores, - expected_normalized_vars.items()): - w_after = sess.run(var) - self.assertFalse( - np.allclose(w_before, w_after, rtol=1e-8, atol=1e-8), - msg=('%s did not improve over many iterations. 
' - 'Before: %s After: %s' % (name, w_before, w_after))) - self.assertAllClose( - exactly_normalized, - w_after, - rtol=1e-4, - atol=1e-4, - msg=('Estimate of spectral norm for %s was innacurate. ' - 'Normalized matrices do not match.' - 'Estimate: %s Actual: %s' % (name, w_after, - exactly_normalized))) - - def testConv2D_Layers(self): - - def build_layer_fn(x, w_initializer, b_initializer): - layer = layers_convolutional.Conv2D( - filters=3, - kernel_size=3, - padding='same', - kernel_initializer=w_initializer, - bias_initializer=b_initializer) - net = layer.apply(x) - expected_normalized_vars = {'tf.layers.Conv2d.kernel': layer.kernel} - expected_not_normalized_vars = {'tf.layers.Conv2d.bias': layer.bias} - - return net, expected_normalized_vars, expected_not_normalized_vars - - self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,)) - - def testConv2D_ContribLayers(self): - - def build_layer_fn(x, w_initializer, b_initializer): - var_collection = { - 'weights': ['CONTRIB_LAYERS_CONV2D_WEIGHTS'], - 'biases': ['CONTRIB_LAYERS_CONV2D_BIASES'] - } - net = contrib_layers.conv2d( - x, - 3, - 3, - weights_initializer=w_initializer, - biases_initializer=b_initializer, - variables_collections=var_collection) - weight_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_WEIGHTS') - self.assertEquals(1, len(weight_vars)) - bias_vars = ops.get_collection('CONTRIB_LAYERS_CONV2D_BIASES') - self.assertEquals(1, len(bias_vars)) - expected_normalized_vars = { - 'contrib.layers.conv2d.weights': weight_vars[0] - } - expected_not_normalized_vars = { - 'contrib.layers.conv2d.bias': bias_vars[0] - } - - return net, expected_normalized_vars, expected_not_normalized_vars - - self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,)) - - def testConv2D_Slim(self): - - def build_layer_fn(x, w_initializer, b_initializer): - var_collection = { - 'weights': ['SLIM_CONV2D_WEIGHTS'], - 'biases': ['SLIM_CONV2D_BIASES'] - } - net = slim.conv2d( - x, - 3, - 3, - weights_initializer=w_initializer, - biases_initializer=b_initializer, - variables_collections=var_collection) - weight_vars = ops.get_collection('SLIM_CONV2D_WEIGHTS') - self.assertEquals(1, len(weight_vars)) - bias_vars = ops.get_collection('SLIM_CONV2D_BIASES') - self.assertEquals(1, len(bias_vars)) - expected_normalized_vars = {'slim.conv2d.weights': weight_vars[0]} - expected_not_normalized_vars = {'slim.conv2d.bias': bias_vars[0]} - - return net, expected_normalized_vars, expected_not_normalized_vars - - self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,)) - - def testConv2D_Keras(self): - - def build_layer_fn(x, w_initializer, b_initializer): - layer = keras_convolutional.Conv2D( - filters=3, - kernel_size=3, - padding='same', - kernel_initializer=w_initializer, - bias_initializer=b_initializer) - net = layer.apply(x) - expected_normalized_vars = {'keras.layers.Conv2d.kernel': layer.kernel} - expected_not_normalized_vars = {'keras.layers.Conv2d.bias': layer.bias} - - return net, expected_normalized_vars, expected_not_normalized_vars - - self._testLayerHelper(build_layer_fn, (3, 3, 3, 3), (3,), is_keras=True) - - def testFC_Layers(self): - - def build_layer_fn(x, w_initializer, b_initializer): - x = layers_core.Flatten()(x) - layer = layers_core.Dense( - units=3, - kernel_initializer=w_initializer, - bias_initializer=b_initializer) - net = layer.apply(x) - expected_normalized_vars = {'tf.layers.Dense.kernel': layer.kernel} - expected_not_normalized_vars = {'tf.layers.Dense.bias': layer.bias} - - return net, expected_normalized_vars, 
expected_not_normalized_vars - - self._testLayerHelper(build_layer_fn, (300, 3), (3,)) - - def testFC_ContribLayers(self): - - def build_layer_fn(x, w_initializer, b_initializer): - var_collection = { - 'weights': ['CONTRIB_LAYERS_FC_WEIGHTS'], - 'biases': ['CONTRIB_LAYERS_FC_BIASES'] - } - x = contrib_layers.flatten(x) - net = contrib_layers.fully_connected( - x, - 3, - weights_initializer=w_initializer, - biases_initializer=b_initializer, - variables_collections=var_collection) - weight_vars = ops.get_collection('CONTRIB_LAYERS_FC_WEIGHTS') - self.assertEquals(1, len(weight_vars)) - bias_vars = ops.get_collection('CONTRIB_LAYERS_FC_BIASES') - self.assertEquals(1, len(bias_vars)) - expected_normalized_vars = { - 'contrib.layers.fully_connected.weights': weight_vars[0] - } - expected_not_normalized_vars = { - 'contrib.layers.fully_connected.bias': bias_vars[0] - } - - return net, expected_normalized_vars, expected_not_normalized_vars - - self._testLayerHelper(build_layer_fn, (300, 3), (3,)) - - def testFC_Slim(self): - - def build_layer_fn(x, w_initializer, b_initializer): - var_collection = { - 'weights': ['SLIM_FC_WEIGHTS'], - 'biases': ['SLIM_FC_BIASES'] - } - x = slim.flatten(x) - net = slim.fully_connected( - x, - 3, - weights_initializer=w_initializer, - biases_initializer=b_initializer, - variables_collections=var_collection) - weight_vars = ops.get_collection('SLIM_FC_WEIGHTS') - self.assertEquals(1, len(weight_vars)) - bias_vars = ops.get_collection('SLIM_FC_BIASES') - self.assertEquals(1, len(bias_vars)) - expected_normalized_vars = { - 'slim.fully_connected.weights': weight_vars[0] - } - expected_not_normalized_vars = {'slim.fully_connected.bias': bias_vars[0]} - - return net, expected_normalized_vars, expected_not_normalized_vars - - self._testLayerHelper(build_layer_fn, (300, 3), (3,)) - - def testFC_Keras(self): - - def build_layer_fn(x, w_initializer, b_initializer): - x = keras_core.Flatten()(x) - layer = keras_core.Dense( - units=3, - kernel_initializer=w_initializer, - bias_initializer=b_initializer) - net = layer.apply(x) - expected_normalized_vars = {'keras.layers.Dense.kernel': layer.kernel} - expected_not_normalized_vars = {'keras.layers.Dense.bias': layer.bias} - - return net, expected_normalized_vars, expected_not_normalized_vars - - self._testLayerHelper(build_layer_fn, (300, 3), (3,), is_keras=True) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm.py deleted file mode 100644 index ea54ac01cee..00000000000 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Virtual batch normalization.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.features.python import virtual_batchnorm_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.features.python.virtual_batchnorm_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = virtual_batchnorm_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py deleted file mode 100644 index 030ce942607..00000000000 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Virtual batch normalization. - -This technique was first introduced in `Improved Techniques for Training GANs` -(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch -normalization on a minibatch, it fixes a reference subset of the data to use for -calculating normalization statistics. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import variable_scope - -__all__ = [ - 'VBN', -] - - -def _static_or_dynamic_batch_size(tensor, batch_axis): - """Returns the static or dynamic batch size.""" - batch_size = array_ops.shape(tensor)[batch_axis] - static_batch_size = tensor_util.constant_value(batch_size) - return static_batch_size or batch_size - - -def _statistics(x, axes): - """Calculate the mean and mean square of `x`. - - Modified from the implementation of `tf.nn.moments`. - - Args: - x: A `Tensor`. - axes: Array of ints. Axes along which to compute mean and variance. - - Returns: - Two `Tensor` objects: `mean` and `square mean`. - """ - # The dynamic range of fp16 is too limited to support the collection of - # sufficient statistics. As a workaround we simply perform the operations - # on 32-bit floats before converting the mean and variance back to fp16 - y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x - - # Compute true mean while keeping the dims for proper broadcasting. 
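-  # (Subtracting a stop-gradient `shift` before averaging is the usual
-  # shifted-mean trick: mean(y) == mean(y - shift) + shift exactly, while the
-  # intermediate sums stay small, which helps floating-point precision.)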
- shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True)) - - shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True) - mean = shifted_mean + shift - mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True) - - mean = array_ops.squeeze(mean, axes) - mean_squared = array_ops.squeeze(mean_squared, axes) - if x.dtype == dtypes.float16: - return (math_ops.cast(mean, dtypes.float16), - math_ops.cast(mean_squared, dtypes.float16)) - else: - return (mean, mean_squared) - - -def _validate_init_input_and_get_axis(reference_batch, axis): - """Validate input and return the used axis value.""" - if reference_batch.shape.ndims is None: - raise ValueError('`reference_batch` has unknown dimensions.') - - ndims = reference_batch.shape.ndims - if axis < 0: - used_axis = ndims + axis - else: - used_axis = axis - if used_axis < 0 or used_axis >= ndims: - raise ValueError('Value of `axis` argument ' + str(used_axis) + - ' is out of range for input with rank ' + str(ndims)) - return used_axis - - -def _validate_call_input(tensor_list, batch_dim): - """Verifies that tensor shapes are compatible, except for `batch_dim`.""" - - def _get_shape(tensor): - shape = tensor.shape.as_list() - del shape[batch_dim] - return shape - - base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0])) - for tensor in tensor_list: - base_shape.assert_is_compatible_with(_get_shape(tensor)) - - -class VBN(object): - """A class to perform virtual batch normalization. - - This technique was first introduced in `Improved Techniques for Training GANs` - (Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch - normalization on a minibatch, it fixes a reference subset of the data to use - for calculating normalization statistics. - - To do this, we calculate the reference batch mean and mean square, and modify - those statistics for each example. We use mean square instead of variance, - since it is linear. - - Note that if `center` or `scale` variables are created, they are shared - between all calls to this object. - - The `__init__` API is intended to mimic - `tf.compat.v1.layers.batch_normalization` as - closely as possible. - """ - - def __init__(self, - reference_batch, - axis=-1, - epsilon=1e-3, - center=True, - scale=True, - beta_initializer=init_ops.zeros_initializer(), - gamma_initializer=init_ops.ones_initializer(), - beta_regularizer=None, - gamma_regularizer=None, - trainable=True, - name=None, - batch_axis=0): - """Initialize virtual batch normalization object. - - We precompute the 'mean' and 'mean squared' of the reference batch, so that - `__call__` is efficient. This means that the axis must be supplied when the - object is created, not when it is called. - - We precompute 'square mean' instead of 'variance', because the square mean - can be easily adjusted on a per-example basis. - - Args: - reference_batch: A minibatch tensors. This will form the reference data - from which the normalization statistics are calculated. See - https://arxiv.org/abs/1606.03498 for more details. - axis: Integer, the axis that should be normalized (typically the features - axis). For instance, after a `Convolution2D` layer with - `data_format="channels_first"`, set `axis=1` in `BatchNormalization`. - epsilon: Small float added to variance to avoid dividing by zero. - center: If True, add offset of `beta` to normalized tensor. If False, - `beta` is ignored. - scale: If True, multiply by `gamma`. If False, `gamma` is not used. 
When - the next layer is linear (also e.g. `nn.relu`), this can be disabled - since the scaling can be done by the next layer. - beta_initializer: Initializer for the beta weight. - gamma_initializer: Initializer for the gamma weight. - beta_regularizer: Optional regularizer for the beta weight. - gamma_regularizer: Optional regularizer for the gamma weight. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). - name: String, the name of the ops. - batch_axis: The axis of the batch dimension. This dimension is treated - differently in `virtual batch normalization` vs `batch normalization`. - - Raises: - ValueError: If `reference_batch` has unknown dimensions at graph - construction. - ValueError: If `batch_axis` is the same as `axis`. - """ - axis = _validate_init_input_and_get_axis(reference_batch, axis) - self._epsilon = epsilon - self._beta = 0 - self._gamma = 1 - self._batch_axis = _validate_init_input_and_get_axis( - reference_batch, batch_axis) - - if axis == self._batch_axis: - raise ValueError('`axis` and `batch_axis` cannot be the same.') - - with variable_scope.variable_scope( - name, 'VBN', values=[reference_batch]) as self._vs: - self._reference_batch = reference_batch - - # Calculate important shapes: - # 1) Reduction axes for the reference batch - # 2) Broadcast shape, if necessary - # 3) Reduction axes for the virtual batchnormed batch - # 4) Shape for optional parameters - input_shape = self._reference_batch.shape - ndims = input_shape.ndims - reduction_axes = list(range(ndims)) - del reduction_axes[axis] - - self._broadcast_shape = [1] * len(input_shape) - self._broadcast_shape[axis] = input_shape.dims[axis] - - self._example_reduction_axes = list(range(ndims)) - del self._example_reduction_axes[max(axis, self._batch_axis)] - del self._example_reduction_axes[min(axis, self._batch_axis)] - - params_shape = self._reference_batch.shape[axis] - - # Determines whether broadcasting is needed. This is slightly different - # than in the `nn.batch_normalization` case, due to `batch_dim`. - self._needs_broadcasting = ( - sorted(self._example_reduction_axes) != list(range(ndims))[:-2]) - - # Calculate the sufficient statistics for the reference batch in a way - # that can be easily modified by additional examples. - self._ref_mean, self._ref_mean_squares = _statistics( - self._reference_batch, reduction_axes) - self._ref_variance = ( - self._ref_mean_squares - math_ops.square(self._ref_mean)) - - # Virtual batch normalization uses a weighted average between example - # statistics and the reference batch statistics. - ref_batch_size = _static_or_dynamic_batch_size(self._reference_batch, - self._batch_axis) - self._example_weight = 1. / ( - math_ops.cast(ref_batch_size, dtypes.float32) + 1.) - self._ref_weight = 1. - self._example_weight - - # Make the variables, if necessary. 
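-      # (When created, `beta` and `gamma` live in this variable scope and are
-      # reused by every subsequent `__call__`, which is what makes the
-      # center/scale parameters shared across all calls to this object.)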
- if center: - self._beta = variable_scope.get_variable( - name='beta', - shape=(params_shape,), - initializer=beta_initializer, - regularizer=beta_regularizer, - trainable=trainable) - if scale: - self._gamma = variable_scope.get_variable( - name='gamma', - shape=(params_shape,), - initializer=gamma_initializer, - regularizer=gamma_regularizer, - trainable=trainable) - - def _virtual_statistics(self, inputs, reduction_axes): - """Compute the statistics needed for virtual batch normalization.""" - cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes) - vb_mean = ( - self._example_weight * cur_mean + self._ref_weight * self._ref_mean) - vb_mean_sq = ( - self._example_weight * cur_mean_sq + - self._ref_weight * self._ref_mean_squares) - return (vb_mean, vb_mean_sq) - - def _broadcast(self, v, broadcast_shape=None): - # The exact broadcast shape depends on the current batch, not the reference - # batch, unless we're calculating the batch normalization of the reference - # batch. - b_shape = broadcast_shape or self._broadcast_shape - if self._needs_broadcasting and v is not None: - return array_ops.reshape(v, b_shape) - return v - - def reference_batch_normalization(self): - """Return the reference batch, but batch normalized.""" - with ops.name_scope(self._vs.name): - return nn.batch_normalization(self._reference_batch, - self._broadcast(self._ref_mean), - self._broadcast(self._ref_variance), - self._broadcast(self._beta), - self._broadcast(self._gamma), self._epsilon) - - def __call__(self, inputs): - """Run virtual batch normalization on inputs. - - Args: - inputs: Tensor input. - - Returns: - A virtual batch normalized version of `inputs`. - - Raises: - ValueError: If `inputs` shape isn't compatible with the reference batch. - """ - _validate_call_input([inputs, self._reference_batch], self._batch_axis) - - with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]): - # Calculate the statistics on the current input on a per-example basis. - vb_mean, vb_mean_sq = self._virtual_statistics( - inputs, self._example_reduction_axes) - vb_variance = vb_mean_sq - math_ops.square(vb_mean) - - # The exact broadcast shape of the input statistic Tensors depends on the - # current batch, not the reference batch. The parameter broadcast shape - # is independent of the shape of the input statistic Tensor dimensions. - b_shape = self._broadcast_shape[:] # deep copy - b_shape[self._batch_axis] = _static_or_dynamic_batch_size( - inputs, self._batch_axis) - return nn.batch_normalization( - inputs, self._broadcast(vb_mean, b_shape), - self._broadcast(vb_variance, b_shape), - self._broadcast(self._beta, self._broadcast_shape), - self._broadcast(self._gamma, self._broadcast_shape), self._epsilon) diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py deleted file mode 100644 index 9848f654bad..00000000000 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py +++ /dev/null @@ -1,267 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tfgan.python.features.virtual_batchnorm.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib -from tensorflow.contrib.gan.python.features.python import virtual_batchnorm_impl as virtual_batchnorm -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import random_seed -from tensorflow.python.layers import normalization -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables as variables_lib -from tensorflow.python.platform import test - - -class VirtualBatchnormTest(test.TestCase): - - def test_syntax(self): - reference_batch = array_ops.zeros([5, 3, 16, 9, 15]) - vbn = virtual_batchnorm.VBN(reference_batch, batch_axis=1) - vbn(array_ops.ones([5, 7, 16, 9, 15])) - - def test_no_broadcast_needed(self): - """When `axis` and `batch_axis` are at the end, no broadcast is needed.""" - reference_batch = array_ops.zeros([5, 3, 16, 9, 15]) - minibatch = array_ops.zeros([5, 3, 16, 3, 15]) - vbn = virtual_batchnorm.VBN(reference_batch, axis=-1, batch_axis=-2) - vbn(minibatch) - - def test_statistics(self): - """Check that `_statistics` gives the same result as `nn.moments`.""" - random_seed.set_random_seed(1234) - - tensors = random_ops.random_normal([4, 5, 7, 3]) - for axes in [(3), (0, 2), (1, 2, 3)]: - vb_mean, mean_sq = virtual_batchnorm._statistics(tensors, axes) - mom_mean, mom_var = nn.moments(tensors, axes) - vb_var = mean_sq - math_ops.square(vb_mean) - - with self.cached_session(use_gpu=True) as sess: - vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([ - vb_mean, vb_var, mom_mean, mom_var]) - - self.assertAllClose(mom_mean_np, vb_mean_np) - self.assertAllClose(mom_var_np, vb_var_np) - - def test_virtual_statistics(self): - """Check that `_virtual_statistics` gives same result as `nn.moments`.""" - random_seed.set_random_seed(1234) - - batch_axis = 0 - partial_batch = random_ops.random_normal([4, 5, 7, 3]) - single_example = random_ops.random_normal([1, 5, 7, 3]) - full_batch = array_ops.concat([partial_batch, single_example], axis=0) - - for reduction_axis in range(1, 4): - # Get `nn.moments` on the full batch. - reduction_axes = list(range(4)) - del reduction_axes[reduction_axis] - mom_mean, mom_variance = nn.moments(full_batch, reduction_axes) - - # Get virtual batch statistics. 
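-      # (With a single extra example, example_weight is 1 / (ref_batch + 1),
-      # so the weighted combination of example and reference statistics is
-      # exactly the full-batch mean / mean-square; that is why nn.moments on
-      # the concatenated batch serves as the reference value here.)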
- vb_reduction_axes = list(range(4)) - del vb_reduction_axes[reduction_axis] - del vb_reduction_axes[batch_axis] - vbn = virtual_batchnorm.VBN(partial_batch, reduction_axis) - vb_mean, mean_sq = vbn._virtual_statistics( - single_example, vb_reduction_axes) - vb_variance = mean_sq - math_ops.square(vb_mean) - # Remove singleton batch dim for easy comparisons. - vb_mean = array_ops.squeeze(vb_mean, batch_axis) - vb_variance = array_ops.squeeze(vb_variance, batch_axis) - - with self.cached_session(use_gpu=True) as sess: - vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([ - vb_mean, vb_variance, mom_mean, mom_variance]) - - self.assertAllClose(mom_mean_np, vb_mean_np) - self.assertAllClose(mom_var_np, vb_var_np) - - def test_reference_batch_normalization(self): - """Check that batch norm from VBN agrees with opensource implementation.""" - random_seed.set_random_seed(1234) - - batch = random_ops.random_normal([6, 5, 7, 3, 3]) - - for axis in range(5): - # Get `layers` batchnorm result. - bn_normalized = normalization.batch_normalization( - batch, axis, training=True) - - # Get VBN's batch normalization on reference batch. - batch_axis = 0 if axis != 0 else 1 # axis and batch_axis can't same - vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis) - vbn_normalized = vbn.reference_batch_normalization() - - with self.cached_session(use_gpu=True) as sess: - variables_lib.global_variables_initializer().run() - - bn_normalized_np, vbn_normalized_np = sess.run( - [bn_normalized, vbn_normalized]) - self.assertAllClose(bn_normalized_np, vbn_normalized_np) - - def test_same_as_batchnorm(self): - """Check that batch norm on set X is the same as ref of X / y on `y`.""" - random_seed.set_random_seed(1234) - - num_examples = 4 - examples = [random_ops.random_normal([5, 7, 3]) for _ in - range(num_examples)] - - # Get the result of the opensource batch normalization. - batch_normalized = normalization.batch_normalization( - array_ops.stack(examples), training=True) - - for i in range(num_examples): - examples_except_i = array_ops.stack(examples[:i] + examples[i+1:]) - # Get the result of VBN's batch normalization. - vbn = virtual_batchnorm.VBN(examples_except_i) - vb_normed = array_ops.squeeze( - vbn(array_ops.expand_dims(examples[i], [0])), [0]) - - with self.cached_session(use_gpu=True) as sess: - variables_lib.global_variables_initializer().run() - bn_np, vb_np = sess.run([batch_normalized, vb_normed]) - self.assertAllClose(bn_np[i, ...], vb_np) - - def test_minibatch_independent(self): - """Test that virtual batch normalized examples are independent. - - Unlike batch normalization, virtual batch normalization has the property - that the virtual batch normalized value of an example is independent of the - other examples in the minibatch. In this test, we verify this property. - """ - random_seed.set_random_seed(1234) - - # These can be random, but must be the same for all session calls. - reference_batch = constant_op.constant( - np.random.normal(size=[4, 7, 3]), dtype=dtypes.float32) - fixed_example = constant_op.constant(np.random.normal(size=[7, 3]), - dtype=dtypes.float32) - - # Get the VBN object and the virtual batch normalized value for - # `fixed_example`. 
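-    # (Independence is expected because each example is normalized with
-    # statistics computed from the reference batch plus that example alone;
-    # the other minibatch members never enter the calculation.)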
- vbn = virtual_batchnorm.VBN(reference_batch) - vbn_fixed_example = array_ops.squeeze( - vbn(array_ops.expand_dims(fixed_example, 0)), 0) - with self.session(use_gpu=True): - variables_lib.global_variables_initializer().run() - vbn_fixed_example_np = vbn_fixed_example.eval() - - # Check that the value is the same for different minibatches, and different - # sized minibatches. - for minibatch_size in range(1, 6): - examples = [random_ops.random_normal([7, 3]) for _ in - range(minibatch_size)] - - minibatch = array_ops.stack([fixed_example] + examples) - vbn_minibatch = vbn(minibatch) - cur_vbn_fixed_example = vbn_minibatch[0, ...] - with self.cached_session(use_gpu=True): - variables_lib.global_variables_initializer().run() - cur_vbn_fixed_example_np = cur_vbn_fixed_example.eval() - self.assertAllClose(vbn_fixed_example_np, cur_vbn_fixed_example_np) - - def test_variable_reuse(self): - """Test that variable scopes work and inference on a real-ish case.""" - tensor1_ref = array_ops.zeros([6, 5, 7, 3, 3]) - tensor1_examples = array_ops.zeros([4, 5, 7, 3, 3]) - tensor2_ref = array_ops.zeros([4, 2, 3]) - tensor2_examples = array_ops.zeros([2, 2, 3]) - - with variable_scope.variable_scope('dummy_scope', reuse=True): - with self.assertRaisesRegexp( - ValueError, 'does not exist, or was not created with ' - 'tf.get_variable()'): - virtual_batchnorm.VBN(tensor1_ref) - - vbn1 = virtual_batchnorm.VBN(tensor1_ref, name='vbn1') - vbn2 = virtual_batchnorm.VBN(tensor2_ref, name='vbn2') - - # Fetch reference and examples after virtual batch normalization. Also - # fetch in variable reuse case. - to_fetch = [] - - to_fetch.append(vbn1.reference_batch_normalization()) - to_fetch.append(vbn2.reference_batch_normalization()) - to_fetch.append(vbn1(tensor1_examples)) - to_fetch.append(vbn2(tensor2_examples)) - - variable_scope.get_variable_scope().reuse_variables() - - to_fetch.append(vbn1.reference_batch_normalization()) - to_fetch.append(vbn2.reference_batch_normalization()) - to_fetch.append(vbn1(tensor1_examples)) - to_fetch.append(vbn2(tensor2_examples)) - - self.assertEqual(4, len(contrib_variables_lib.get_variables())) - - with self.session(use_gpu=True) as sess: - variables_lib.global_variables_initializer().run() - sess.run(to_fetch) - - def test_invalid_input(self): - # Reference batch has unknown dimensions. - with self.assertRaisesRegexp( - ValueError, '`reference_batch` has unknown dimensions.'): - virtual_batchnorm.VBN(array_ops.placeholder(dtypes.float32), name='vbn1') - - # Axis too negative. - with self.assertRaisesRegexp( - ValueError, 'Value of `axis` argument .* is out of range'): - virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=-3, name='vbn2') - - # Axis too large. - with self.assertRaisesRegexp( - ValueError, 'Value of `axis` argument .* is out of range'): - virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=2, name='vbn3') - - # Batch axis too negative. - with self.assertRaisesRegexp( - ValueError, 'Value of `axis` argument .* is out of range'): - virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn4', batch_axis=-3) - - # Batch axis too large. - with self.assertRaisesRegexp( - ValueError, 'Value of `axis` argument .* is out of range'): - virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn5', batch_axis=2) - - # Axis and batch axis are the same. 
- with self.assertRaisesRegexp( - ValueError, '`axis` and `batch_axis` cannot be the same.'): - virtual_batchnorm.VBN(array_ops.zeros( - [1, 2]), axis=1, name='vbn6', batch_axis=1) - - # Reference Tensor and example Tensor have incompatible shapes. - tensor_ref = array_ops.zeros([5, 2, 3]) - tensor_examples = array_ops.zeros([3, 2, 3]) - vbn = virtual_batchnorm.VBN(tensor_ref, name='vbn7', batch_axis=1) - with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'): - vbn(tensor_examples) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/losses/__init__.py b/tensorflow/contrib/gan/python/losses/__init__.py deleted file mode 100644 index d9bf8ebfdf6..00000000000 --- a/tensorflow/contrib/gan/python/losses/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2017 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TFGAN losses and penalties. - -Losses can be used with individual arguments or with GANModel tuples. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Collapse losses into a single namespace. -from tensorflow.contrib.gan.python.losses.python import losses_wargs as wargs -from tensorflow.contrib.gan.python.losses.python import tuple_losses - -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.losses.python.tuple_losses import * -# pylint: enable=wildcard-import - -from tensorflow.python.util.all_util import remove_undocumented - -_allowed_symbols = ['wargs'] + tuple_losses.__all__ -remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py deleted file mode 100644 index 99bdf5b20d3..00000000000 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ /dev/null @@ -1,1030 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Losses that are useful for training GANs. 
- -The losses belong to two main groups, but there are others that do not: -1) xxxxx_generator_loss -2) xxxxx_discriminator_loss - -Example: -1) wasserstein_generator_loss -2) wasserstein_discriminator_loss - -Other example: -wasserstein_gradient_penalty - -All losses must be able to accept 1D or 2D Tensors, so as to be compatible with -patchGAN style losses (https://arxiv.org/abs/1611.07004). - -To make these losses usable in the TF-GAN framework, please create a tuple -version of the losses with `losses_utils.py`. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops.losses import losses -from tensorflow.python.ops.losses import util -from tensorflow.python.summary import summary - -__all__ = [ - 'acgan_discriminator_loss', - 'acgan_generator_loss', - 'least_squares_discriminator_loss', - 'least_squares_generator_loss', - 'modified_discriminator_loss', - 'modified_generator_loss', - 'minimax_discriminator_loss', - 'minimax_generator_loss', - 'wasserstein_discriminator_loss', - 'wasserstein_generator_loss', - 'wasserstein_gradient_penalty', - 'mutual_information_penalty', - 'combine_adversarial_loss', - 'cycle_consistency_loss', -] - - -def _to_float(tensor): - return math_ops.cast(tensor, dtypes.float32) - - -# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875). -def wasserstein_generator_loss( - discriminator_gen_outputs, - weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """Wasserstein generator loss for GANs. - - See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details. - - Args: - discriminator_gen_outputs: Discriminator output on generated data. Expected - to be in the range of (-inf, inf). - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `discriminator_gen_outputs`, and must be broadcastable to - `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or - the same as the corresponding dimension). - scope: The scope for the operations performed in computing the loss. - loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add detailed summaries for the loss. - - Returns: - A loss Tensor. The shape depends on `reduction`. 
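-
-  A minimal illustrative call (tensor values are made up):
-
-    d_on_fake = tf.constant([2.0, -1.0])
-    loss = wasserstein_generator_loss(d_on_fake)
-    # With the default reduction this is -mean(d_on_fake) == -0.5.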
- """ - with ops.name_scope(scope, 'generator_wasserstein_loss', - (discriminator_gen_outputs, weights)) as scope: - discriminator_gen_outputs = _to_float(discriminator_gen_outputs) - - loss = -discriminator_gen_outputs - loss = losses.compute_weighted_loss(loss, weights, scope, loss_collection, - reduction) - - if add_summaries: - summary.scalar('generator_wass_loss', loss) - - return loss - - -def wasserstein_discriminator_loss( - discriminator_real_outputs, - discriminator_gen_outputs, - real_weights=1.0, - generated_weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """Wasserstein discriminator loss for GANs. - - See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details. - - Args: - discriminator_real_outputs: Discriminator output on real data. - discriminator_gen_outputs: Discriminator output on generated data. Expected - to be in the range of (-inf, inf). - real_weights: Optional `Tensor` whose rank is either 0, or the same rank as - `discriminator_real_outputs`, and must be broadcastable to - `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or - the same as the corresponding dimension). - generated_weights: Same as `real_weights`, but for - `discriminator_gen_outputs`. - scope: The scope for the operations performed in computing the loss. - loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add summaries for the loss. - - Returns: - A loss Tensor. The shape depends on `reduction`. - """ - with ops.name_scope(scope, 'discriminator_wasserstein_loss', - (discriminator_real_outputs, discriminator_gen_outputs, - real_weights, generated_weights)) as scope: - discriminator_real_outputs = _to_float(discriminator_real_outputs) - discriminator_gen_outputs = _to_float(discriminator_gen_outputs) - discriminator_real_outputs.shape.assert_is_compatible_with( - discriminator_gen_outputs.shape) - - loss_on_generated = losses.compute_weighted_loss( - discriminator_gen_outputs, - generated_weights, - scope, - loss_collection=None, - reduction=reduction) - loss_on_real = losses.compute_weighted_loss( - discriminator_real_outputs, - real_weights, - scope, - loss_collection=None, - reduction=reduction) - loss = loss_on_generated - loss_on_real - util.add_loss(loss, loss_collection) - - if add_summaries: - summary.scalar('discriminator_gen_wass_loss', loss_on_generated) - summary.scalar('discriminator_real_wass_loss', loss_on_real) - summary.scalar('discriminator_wass_loss', loss) - - return loss - - -# ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs` -# (https://arxiv.org/abs/1610.09585). -def acgan_discriminator_loss(discriminator_real_classification_logits, - discriminator_gen_classification_logits, - one_hot_labels, - label_smoothing=0.0, - real_weights=1.0, - generated_weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """ACGAN loss for the discriminator. - - The ACGAN loss adds a classification loss to the conditional discriminator. - Therefore, the discriminator must output a tuple consisting of - (1) the real/fake prediction and - (2) the logits for the classification (usually the last conv layer, - flattened). 
- - For more details: - ACGAN: https://arxiv.org/abs/1610.09585 - - Args: - discriminator_real_classification_logits: Classification logits for real - data. - discriminator_gen_classification_logits: Classification logits for generated - data. - one_hot_labels: A Tensor holding one-hot labels for the batch. - label_smoothing: A float in [0, 1]. If greater than 0, smooth the labels for - "discriminator on real data" as suggested in - https://arxiv.org/pdf/1701.00160 - real_weights: Optional `Tensor` whose rank is either 0, or the same rank as - `discriminator_real_outputs`, and must be broadcastable to - `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or - the same as the corresponding dimension). - generated_weights: Same as `real_weights`, but for - `discriminator_gen_classification_logits`. - scope: The scope for the operations performed in computing the loss. - loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add summaries for the loss. - - Returns: - A loss Tensor. Shape depends on `reduction`. - - Raises: - TypeError: If the discriminator does not output a tuple. - """ - with ops.name_scope( - scope, 'acgan_discriminator_loss', - (discriminator_real_classification_logits, - discriminator_gen_classification_logits, one_hot_labels)) as scope: - loss_on_generated = losses.softmax_cross_entropy( - one_hot_labels, - discriminator_gen_classification_logits, - weights=generated_weights, - scope=scope, - loss_collection=None, - reduction=reduction) - loss_on_real = losses.softmax_cross_entropy( - one_hot_labels, - discriminator_real_classification_logits, - weights=real_weights, - label_smoothing=label_smoothing, - scope=scope, - loss_collection=None, - reduction=reduction) - loss = loss_on_generated + loss_on_real - util.add_loss(loss, loss_collection) - - if add_summaries: - summary.scalar('discriminator_gen_ac_loss', loss_on_generated) - summary.scalar('discriminator_real_ac_loss', loss_on_real) - summary.scalar('discriminator_ac_loss', loss) - - return loss - - -def acgan_generator_loss(discriminator_gen_classification_logits, - one_hot_labels, - weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """ACGAN loss for the generator. - - The ACGAN loss adds a classification loss to the conditional discriminator. - Therefore, the discriminator must output a tuple consisting of - (1) the real/fake prediction and - (2) the logits for the classification (usually the last conv layer, - flattened). - - For more details: - ACGAN: https://arxiv.org/abs/1610.09585 - - Args: - discriminator_gen_classification_logits: Classification logits for generated - data. - one_hot_labels: A Tensor holding one-hot labels for the batch. - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `discriminator_gen_classification_logits`, and must be broadcastable to - `discriminator_gen_classification_logits` (i.e., all dimensions must be - either `1`, or the same as the corresponding dimension). - scope: The scope for the operations performed in computing the loss. - loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add summaries for the loss. - - Returns: - A loss Tensor. Shape depends on `reduction`. 
- - Raises: - ValueError: if arg module not either `generator` or `discriminator` - TypeError: if the discriminator does not output a tuple. - """ - with ops.name_scope( - scope, 'acgan_generator_loss', - (discriminator_gen_classification_logits, one_hot_labels)) as scope: - loss = losses.softmax_cross_entropy( - one_hot_labels, - discriminator_gen_classification_logits, - weights=weights, - scope=scope, - loss_collection=loss_collection, - reduction=reduction) - - if add_summaries: - summary.scalar('generator_ac_loss', loss) - - return loss - - -# Wasserstein Gradient Penalty losses from `Improved Training of Wasserstein -# GANs` (https://arxiv.org/abs/1704.00028). - - -def wasserstein_gradient_penalty( - real_data, - generated_data, - generator_inputs, - discriminator_fn, - discriminator_scope, - epsilon=1e-10, - target=1.0, - one_sided=False, - weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """The gradient penalty for the Wasserstein discriminator loss. - - See `Improved Training of Wasserstein GANs` - (https://arxiv.org/abs/1704.00028) for more details. - - Args: - real_data: Real data. - generated_data: Output of the generator. - generator_inputs: Exact argument to pass to the generator, which is used as - optional conditioning to the discriminator. - discriminator_fn: A discriminator function that conforms to TF-GAN API. - discriminator_scope: If not `None`, reuse discriminators from this scope. - epsilon: A small positive number added for numerical stability when - computing the gradient norm. - target: Optional Python number or `Tensor` indicating the target value of - gradient norm. Defaults to 1.0. - one_sided: If `True`, penalty proposed in https://arxiv.org/abs/1709.08894 - is used. Defaults to `False`. - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `real_data` and `generated_data`, and must be broadcastable to them (i.e., - all dimensions must be either `1`, or the same as the corresponding - dimension). - scope: The scope for the operations performed in computing the loss. - loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add summaries for the loss. - - Returns: - A loss Tensor. The shape depends on `reduction`. - - Raises: - ValueError: If the rank of data Tensors is unknown. - """ - with ops.name_scope(scope, 'wasserstein_gradient_penalty', - (real_data, generated_data)) as scope: - real_data = ops.convert_to_tensor(real_data) - generated_data = ops.convert_to_tensor(generated_data) - if real_data.shape.ndims is None: - raise ValueError('`real_data` can\'t have unknown rank.') - if generated_data.shape.ndims is None: - raise ValueError('`generated_data` can\'t have unknown rank.') - - differences = generated_data - real_data - batch_size = differences.shape.dims[0].value or array_ops.shape( - differences)[0] - alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1) - alpha = random_ops.random_uniform(shape=alpha_shape) - interpolates = real_data + (alpha * differences) - - with ops.name_scope(None): # Clear scope so update ops are added properly. - # Reuse variables if variables already exists. 
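-      # (The code below follows the WGAN-GP recipe: interpolate between real
-      # and generated points, differentiate the discriminator output with
-      # respect to the interpolates, and penalize the squared deviation of
-      # the gradient's L2 norm from `target`.)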
- with variable_scope.variable_scope( - discriminator_scope, - 'gpenalty_dscope', - reuse=variable_scope.AUTO_REUSE): - disc_interpolates = discriminator_fn(interpolates, generator_inputs) - - if isinstance(disc_interpolates, tuple): - # ACGAN case: disc outputs more than one tensor - disc_interpolates = disc_interpolates[0] - - gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0] - gradient_squares = math_ops.reduce_sum( - math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims))) - # Propagate shape information, if possible. - if isinstance(batch_size, int): - gradient_squares.set_shape([batch_size] + - gradient_squares.shape.as_list()[1:]) - # For numerical stability, add epsilon to the sum before taking the square - # root. Note tf.norm does not add epsilon. - slopes = math_ops.sqrt(gradient_squares + epsilon) - penalties = slopes / target - 1.0 - if one_sided: - penalties = math_ops.maximum(0., penalties) - penalties_squared = math_ops.square(penalties) - penalty = losses.compute_weighted_loss( - penalties_squared, - weights, - scope=scope, - loss_collection=loss_collection, - reduction=reduction) - - if add_summaries: - summary.scalar('gradient_penalty_loss', penalty) - - return penalty - - -# Original losses from `Generative Adversarial Nets` -# (https://arxiv.org/abs/1406.2661). - - -def minimax_discriminator_loss( - discriminator_real_outputs, - discriminator_gen_outputs, - label_smoothing=0.25, - real_weights=1.0, - generated_weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """Original minimax discriminator loss for GANs, with label smoothing. - - Note that the authors don't recommend using this loss. A more practically - useful loss is `modified_discriminator_loss`. - - L = - real_weights * log(sigmoid(D(x))) - - generated_weights * log(1 - sigmoid(D(G(z)))) - - See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more - details. - - Args: - discriminator_real_outputs: Discriminator output on real data. - discriminator_gen_outputs: Discriminator output on generated data. Expected - to be in the range of (-inf, inf). - label_smoothing: The amount of smoothing for positive labels. This technique - is taken from `Improved Techniques for Training GANs` - (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. - real_weights: Optional `Tensor` whose rank is either 0, or the same rank as - `real_data`, and must be broadcastable to `real_data` (i.e., all - dimensions must be either `1`, or the same as the corresponding - dimension). - generated_weights: Same as `real_weights`, but for `generated_data`. - scope: The scope for the operations performed in computing the loss. - loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add summaries for the loss. - - Returns: - A loss Tensor. The shape depends on `reduction`. 
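-
-  With the default `label_smoothing=0.25`, the effective target for real data
-  is 0.875 rather than 1.0, since `tf.compat.v1.losses.sigmoid_cross_entropy`
-  smooths labels as `labels * (1 - label_smoothing) + 0.5 * label_smoothing`.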
- """ - with ops.name_scope( - scope, 'discriminator_minimax_loss', - (discriminator_real_outputs, discriminator_gen_outputs, real_weights, - generated_weights, label_smoothing)) as scope: - - # -log((1 - label_smoothing) - sigmoid(D(x))) - loss_on_real = losses.sigmoid_cross_entropy( - array_ops.ones_like(discriminator_real_outputs), - discriminator_real_outputs, - real_weights, - label_smoothing, - scope, - loss_collection=None, - reduction=reduction) - # -log(- sigmoid(D(G(x)))) - loss_on_generated = losses.sigmoid_cross_entropy( - array_ops.zeros_like(discriminator_gen_outputs), - discriminator_gen_outputs, - generated_weights, - scope=scope, - loss_collection=None, - reduction=reduction) - - loss = loss_on_real + loss_on_generated - util.add_loss(loss, loss_collection) - - if add_summaries: - summary.scalar('discriminator_gen_minimax_loss', loss_on_generated) - summary.scalar('discriminator_real_minimax_loss', loss_on_real) - summary.scalar('discriminator_minimax_loss', loss) - - return loss - - -def minimax_generator_loss(discriminator_gen_outputs, - label_smoothing=0.0, - weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """Original minimax generator loss for GANs. - - Note that the authors don't recommend using this loss. A more practically - useful loss is `modified_generator_loss`. - - L = log(sigmoid(D(x))) + log(1 - sigmoid(D(G(z)))) - - See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more - details. - - Args: - discriminator_gen_outputs: Discriminator output on generated data. Expected - to be in the range of (-inf, inf). - label_smoothing: The amount of smoothing for positive labels. This technique - is taken from `Improved Techniques for Training GANs` - (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `discriminator_gen_outputs`, and must be broadcastable to - `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or - the same as the corresponding dimension). - scope: The scope for the operations performed in computing the loss. - loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add summaries for the loss. - - Returns: - A loss Tensor. The shape depends on `reduction`. - """ - with ops.name_scope(scope, 'generator_minimax_loss') as scope: - loss = -minimax_discriminator_loss( - array_ops.ones_like(discriminator_gen_outputs), - discriminator_gen_outputs, - label_smoothing, - weights, - weights, - scope, - loss_collection, - reduction, - add_summaries=False) - - if add_summaries: - summary.scalar('generator_minimax_loss', loss) - - return loss - - -def modified_discriminator_loss( - discriminator_real_outputs, - discriminator_gen_outputs, - label_smoothing=0.25, - real_weights=1.0, - generated_weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """Same as minimax discriminator loss. - - See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more - details. - - Args: - discriminator_real_outputs: Discriminator output on real data. - discriminator_gen_outputs: Discriminator output on generated data. Expected - to be in the range of (-inf, inf). - label_smoothing: The amount of smoothing for positive labels. 
This technique - is taken from `Improved Techniques for Training GANs` - (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. - real_weights: Optional `Tensor` whose rank is either 0, or the same rank as - `discriminator_gen_outputs`, and must be broadcastable to - `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or - the same as the corresponding dimension). - generated_weights: Same as `real_weights`, but for - `discriminator_gen_outputs`. - scope: The scope for the operations performed in computing the loss. - loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add summaries for the loss. - - Returns: - A loss Tensor. The shape depends on `reduction`. - """ - return minimax_discriminator_loss(discriminator_real_outputs, - discriminator_gen_outputs, label_smoothing, - real_weights, generated_weights, scope or - 'discriminator_modified_loss', - loss_collection, reduction, add_summaries) - - -def modified_generator_loss(discriminator_gen_outputs, - label_smoothing=0.0, - weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """Modified generator loss for GANs. - - L = -log(sigmoid(D(G(z)))) - - This is the trick used in the original paper to avoid vanishing gradients - early in training. See `Generative Adversarial Nets` - (https://arxiv.org/abs/1406.2661) for more details. - - Args: - discriminator_gen_outputs: Discriminator output on generated data. Expected - to be in the range of (-inf, inf). - label_smoothing: The amount of smoothing for positive labels. This technique - is taken from `Improved Techniques for Training GANs` - (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `discriminator_gen_outputs`, and must be broadcastable to `labels` (i.e., - all dimensions must be either `1`, or the same as the corresponding - dimension). - scope: The scope for the operations performed in computing the loss. - loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add summaries for the loss. - - Returns: - A loss Tensor. The shape depends on `reduction`. - """ - with ops.name_scope(scope, 'generator_modified_loss', - [discriminator_gen_outputs]) as scope: - loss = losses.sigmoid_cross_entropy( - array_ops.ones_like(discriminator_gen_outputs), - discriminator_gen_outputs, weights, label_smoothing, scope, - loss_collection, reduction) - - if add_summaries: - summary.scalar('generator_modified_loss', loss) - - return loss - - -# Least Squares loss from `Least Squares Generative Adversarial Networks` -# (https://arxiv.org/abs/1611.04076). - - -def least_squares_generator_loss( - discriminator_gen_outputs, - real_label=1, - weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """Least squares generator loss. - - This loss comes from `Least Squares Generative Adversarial Networks` - (https://arxiv.org/abs/1611.04076). - - L = 1/2 * (D(G(z)) - `real_label`) ** 2 - - where D(y) are discriminator logits. - - Args: - discriminator_gen_outputs: Discriminator output on generated data. Expected - to be in the range of (-inf, inf). 
- real_label: The value that the generator is trying to get the discriminator - to output on generated data. - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `discriminator_gen_outputs`, and must be broadcastable to - `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or - the same as the corresponding dimension). - scope: The scope for the operations performed in computing the loss. - loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add summaries for the loss. - - Returns: - A loss Tensor. The shape depends on `reduction`. - """ - with ops.name_scope(scope, 'lsq_generator_loss', - (discriminator_gen_outputs, real_label)) as scope: - discriminator_gen_outputs = _to_float(discriminator_gen_outputs) - loss = math_ops.squared_difference(discriminator_gen_outputs, - real_label) / 2.0 - loss = losses.compute_weighted_loss(loss, weights, scope, loss_collection, - reduction) - - if add_summaries: - summary.scalar('generator_lsq_loss', loss) - - return loss - - -def least_squares_discriminator_loss( - discriminator_real_outputs, - discriminator_gen_outputs, - real_label=1, - fake_label=0, - real_weights=1.0, - generated_weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """Least squares discriminator loss. - - This loss comes from `Least Squares Generative Adversarial Networks` - (https://arxiv.org/abs/1611.04076). - - L = 1/2 * (D(x) - `real`) ** 2 + - 1/2 * (D(G(z)) - `fake_label`) ** 2 - - where D(y) are discriminator logits. - - Args: - discriminator_real_outputs: Discriminator output on real data. - discriminator_gen_outputs: Discriminator output on generated data. Expected - to be in the range of (-inf, inf). - real_label: The value that the discriminator tries to output for real data. - fake_label: The value that the discriminator tries to output for fake data. - real_weights: Optional `Tensor` whose rank is either 0, or the same rank as - `discriminator_real_outputs`, and must be broadcastable to - `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or - the same as the corresponding dimension). - generated_weights: Same as `real_weights`, but for - `discriminator_gen_outputs`. - scope: The scope for the operations performed in computing the loss. - loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add summaries for the loss. - - Returns: - A loss Tensor. The shape depends on `reduction`. 
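-
-  Illustrative example (scalar logits, values made up): with the default
-  labels, D(x) = 0.8 and D(G(z)) = 0.3 give
-  L = 0.5 * (0.8 - 1)**2 + 0.5 * (0.3 - 0)**2 = 0.065.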
- """ - with ops.name_scope(scope, 'lsq_discriminator_loss', - (discriminator_gen_outputs, real_label)) as scope: - discriminator_real_outputs = _to_float(discriminator_real_outputs) - discriminator_gen_outputs = _to_float(discriminator_gen_outputs) - discriminator_real_outputs.shape.assert_is_compatible_with( - discriminator_gen_outputs.shape) - - real_losses = math_ops.squared_difference(discriminator_real_outputs, - real_label) / 2.0 - fake_losses = math_ops.squared_difference(discriminator_gen_outputs, - fake_label) / 2.0 - - loss_on_real = losses.compute_weighted_loss( - real_losses, - real_weights, - scope, - loss_collection=None, - reduction=reduction) - loss_on_generated = losses.compute_weighted_loss( - fake_losses, - generated_weights, - scope, - loss_collection=None, - reduction=reduction) - - loss = loss_on_real + loss_on_generated - util.add_loss(loss, loss_collection) - - if add_summaries: - summary.scalar('discriminator_gen_lsq_loss', loss_on_generated) - summary.scalar('discriminator_real_lsq_loss', loss_on_real) - summary.scalar('discriminator_lsq_loss', loss) - - return loss - - -# InfoGAN loss from `InfoGAN: Interpretable Representation Learning by -# `Information Maximizing Generative Adversarial Nets` -# https://arxiv.org/abs/1606.03657 - - -def _validate_distributions(distributions): - if not isinstance(distributions, (list, tuple)): - raise ValueError('`distributions` must be a list or tuple. Instead, ' - 'found %s.' % type(distributions)) - for x in distributions: - # We used to check with `isinstance(x, tf.compat.v1.distributions.Distribution)`. - # However, distributions have migrated to `tfp.distributions.Distribution`, - # which is a new code repo, so we can't check this way anymore until - # TF-GAN is migrated to a new repo as well. - # This new check is not sufficient, but is a useful heuristic for now. - if not callable(getattr(x, 'log_prob', None)): - raise ValueError('`distributions` must be a list of `Distributions`. ' - 'Instead, found %s.' % type(x)) - - -def _validate_information_penalty_inputs(structured_generator_inputs, - predicted_distributions): - """Validate input to `mutual_information_penalty`.""" - _validate_distributions(predicted_distributions) - if len(structured_generator_inputs) != len(predicted_distributions): - raise ValueError( - '`structured_generator_inputs` length %i must be the same ' - 'as `predicted_distributions` length %i.' % - (len(structured_generator_inputs), len(predicted_distributions))) - - -def mutual_information_penalty( - structured_generator_inputs, - predicted_distributions, - weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False): - """Returns a penalty on the mutual information in an InfoGAN model. - - This loss comes from an InfoGAN paper https://arxiv.org/abs/1606.03657. - - Args: - structured_generator_inputs: A list of Tensors representing the random noise - that must have high mutual information with the generator output. List - length should match `predicted_distributions`. - predicted_distributions: A list of `tfp.distributions.Distribution`s. - Predicted by the recognizer, and used to evaluate the likelihood of the - structured noise. List length should match `structured_generator_inputs`. - weights: Optional `Tensor` whose rank is either 0, or the same dimensions as - `structured_generator_inputs`. - scope: The scope for the operations performed in computing the loss. 
- loss_collection: collection to which this loss will be added. - reduction: A `tf.compat.v1.losses.Reduction` to apply to loss. - add_summaries: Whether or not to add summaries for the loss. - - Returns: - A scalar Tensor representing the mutual information loss. - """ - _validate_information_penalty_inputs(structured_generator_inputs, - predicted_distributions) - - with ops.name_scope(scope, 'mutual_information_loss') as scope: - # Calculate the negative log-likelihood of the reconstructed noise. - log_probs = [ - math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in zip( - predicted_distributions, structured_generator_inputs) - ] - loss = -1 * losses.compute_weighted_loss( - log_probs, - weights, - scope, - loss_collection=loss_collection, - reduction=reduction) - - if add_summaries: - summary.scalar('mutual_information_penalty', loss) - - return loss - - -def _numerically_stable_global_norm(tensor_list): - """Compute the global norm of a list of Tensors, with improved stability. - - The global norm computation sometimes overflows due to the intermediate L2 - step. To avoid this, we divide by a cheap-to-compute max over the - matrix elements. - - Args: - tensor_list: A list of tensors, or `None`. - - Returns: - A scalar tensor with the global norm. - """ - if all(x is None for x in tensor_list): - return 0.0 - - list_max = math_ops.reduce_max([ - math_ops.reduce_max(math_ops.abs(x)) for x in tensor_list if x is not None - ]) - return list_max * clip_ops.global_norm( - [x / list_max for x in tensor_list if x is not None]) - - -def _used_weight(weights_list): - for weight in weights_list: - if weight is not None: - return tensor_util.constant_value(ops.convert_to_tensor(weight)) - - -def _validate_args(losses_list, weight_factor, gradient_ratio): - for loss in losses_list: - loss.shape.assert_is_compatible_with([]) - if weight_factor is None and gradient_ratio is None: - raise ValueError( - '`weight_factor` and `gradient_ratio` cannot both be `None.`') - if weight_factor is not None and gradient_ratio is not None: - raise ValueError( - '`weight_factor` and `gradient_ratio` cannot both be specified.') - - -# TODO(joelshor): Add ability to pass in gradients, to avoid recomputing. -def combine_adversarial_loss(main_loss, - adversarial_loss, - weight_factor=None, - gradient_ratio=None, - gradient_ratio_epsilon=1e-6, - variables=None, - scalar_summaries=True, - gradient_summaries=True, - scope=None): - """Utility to combine main and adversarial losses. - - This utility combines the main and adversarial losses in one of two ways. - 1) Fixed coefficient on adversarial loss. Use `weight_factor` in this case. - 2) Fixed ratio of gradients. Use `gradient_ratio` in this case. This is often - used to make sure both losses affect weights roughly equally, as in - https://arxiv.org/pdf/1705.05823. - - One can optionally also visualize the scalar and gradient behavior of the - losses. - - Args: - main_loss: A floating scalar Tensor indicating the main loss. - adversarial_loss: A floating scalar Tensor indication the adversarial loss. - weight_factor: If not `None`, the coefficient by which to multiply the - adversarial loss. Exactly one of this and `gradient_ratio` must be - non-None. - gradient_ratio: If not `None`, the ratio of the magnitude of the gradients. - Specifically, gradient_ratio = grad_mag(main_loss) / - grad_mag(adversarial_loss) Exactly one of this and `weight_factor` must be - non-None. 
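# [Editorial sketch] What the `gradient_ratio` option above does, using plain
# scalar stand-ins for the two gradient norms (values mirror
# CombineAdversarialLossTest later in this change): the adversarial coefficient is
# chosen so that, after scaling, grad_mag(main) / grad_mag(scaled adversarial)
# equals the requested ratio.
main_grad_mag = 2.0      # stand-in for ||d main_loss / d variables||
adv_grad_mag = 3.0       # stand-in for ||d adversarial_loss / d variables||
gradient_ratio = 0.5     # requested grad_mag(main) / grad_mag(adversarial)
epsilon = 1e-6           # gradient_ratio_epsilon

adv_coeff = (main_grad_mag / (adv_grad_mag + epsilon)) / gradient_ratio
scaled_adv_grad_mag = adv_coeff * adv_grad_mag
print(adv_coeff)                            # ~1.3333
print(main_grad_mag / scaled_adv_grad_mag)  # ~0.5, the requested ratio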
- gradient_ratio_epsilon: An epsilon to add to the adversarial loss - coefficient denominator, to avoid division-by-zero. - variables: List of variables to calculate gradients with respect to. If not - present, defaults to all trainable variables. - scalar_summaries: Create scalar summaries of losses. - gradient_summaries: Create gradient summaries of losses. - scope: Optional name scope. - - Returns: - A floating scalar Tensor indicating the desired combined loss. - - Raises: - ValueError: Malformed input. - """ - _validate_args([main_loss, adversarial_loss], weight_factor, gradient_ratio) - if variables is None: - variables = contrib_variables_lib.get_trainable_variables() - - with ops.name_scope( - scope, 'adversarial_loss', values=[main_loss, adversarial_loss]): - # Compute gradients if we will need them. - if gradient_summaries or gradient_ratio is not None: - main_loss_grad_mag = _numerically_stable_global_norm( - gradients_impl.gradients(main_loss, variables)) - adv_loss_grad_mag = _numerically_stable_global_norm( - gradients_impl.gradients(adversarial_loss, variables)) - - # Add summaries, if applicable. - if scalar_summaries: - summary.scalar('main_loss', main_loss) - summary.scalar('adversarial_loss', adversarial_loss) - if gradient_summaries: - summary.scalar('main_loss_gradients', main_loss_grad_mag) - summary.scalar('adversarial_loss_gradients', adv_loss_grad_mag) - - # Combine losses in the appropriate way. - # If `weight_factor` is always `0`, avoid computing the adversarial loss - # tensor entirely. - if _used_weight((weight_factor, gradient_ratio)) == 0: - final_loss = main_loss - elif weight_factor is not None: - final_loss = ( - main_loss + array_ops.stop_gradient(weight_factor) * adversarial_loss) - elif gradient_ratio is not None: - grad_mag_ratio = main_loss_grad_mag / ( - adv_loss_grad_mag + gradient_ratio_epsilon) - adv_coeff = grad_mag_ratio / gradient_ratio - summary.scalar('adversarial_coefficient', adv_coeff) - final_loss = ( - main_loss + array_ops.stop_gradient(adv_coeff) * adversarial_loss) - - return final_loss - - -def cycle_consistency_loss(data_x, - reconstructed_data_x, - data_y, - reconstructed_data_y, - scope=None, - add_summaries=False): - """Defines the cycle consistency loss. - - The cyclegan model has two partial models where `model_x2y` generator F maps - data set X to Y, `model_y2x` generator G maps data set Y to X. For a `data_x` - in data set X, we could reconstruct it by - * reconstructed_data_x = G(F(data_x)) - Similarly - * reconstructed_data_y = F(G(data_y)) - - The cycle consistency loss is about the difference between data and - reconstructed data, namely - * loss_x2x = |data_x - G(F(data_x))| (L1-norm) - * loss_y2y = |data_y - F(G(data_y))| (L1-norm) - * loss = (loss_x2x + loss_y2y) / 2 - where `loss` is the final result. - - For the L1-norm, we follow the original implementation: - https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua - we use L1-norm of pixel-wise error normalized by data size such that - `cycle_loss_weight` can be specified independent of image size. - - See https://arxiv.org/abs/1703.10593 for more details. - - Args: - data_x: A `Tensor` of data X. - reconstructed_data_x: A `Tensor` of reconstructed data X. - data_y: A `Tensor` of data Y. - reconstructed_data_y: A `Tensor` of reconstructed data Y. - scope: The scope for the operations performed in computing the loss. - Defaults to None. - add_summaries: Whether or not to add detailed summaries for the loss. - Defaults to False. 
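# [Editorial sketch] The cycle consistency loss described above, reproduced in
# NumPy. The numbers are the ones checked by CycleConsistencyLossTest later in
# this change (expected value 5.25); each direction is a mean absolute
# reconstruction error, and the two directions are averaged.
import numpy as np

data_x = np.array([[1.0, 2, 3], [4, 5, 6]])
reconstructed_data_x = np.array([[7.0, 8, 9], [10, 11, 12]])
data_y = np.array([1.0, 9])
reconstructed_data_y = np.array([-2.0, 3])

loss_x2x = np.mean(np.abs(data_x - reconstructed_data_x))  # 6.0
loss_y2y = np.mean(np.abs(data_y - reconstructed_data_y))  # 4.5
print((loss_x2x + loss_y2y) / 2.0)                         # 5.25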
- - Returns: - A scalar `Tensor` of cycle consistency loss. - """ - - with ops.name_scope( - scope, - 'cycle_consistency_loss', - values=[data_x, reconstructed_data_x, data_y, reconstructed_data_y]): - loss_x2x = losses.absolute_difference(data_x, reconstructed_data_x) - loss_y2y = losses.absolute_difference(data_y, reconstructed_data_y) - loss = (loss_x2x + loss_y2y) / 2.0 - if add_summaries: - summary.scalar('cycle_consistency_loss_x2x', loss_x2x) - summary.scalar('cycle_consistency_loss_y2y', loss_y2y) - summary.scalar('cycle_consistency_loss', loss) - - return loss diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py deleted file mode 100644 index 44ee0f52696..00000000000 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ /dev/null @@ -1,701 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for TFGAN losses.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.ops.distributions import categorical -from tensorflow.python.ops.distributions import normal -from tensorflow.python.ops.losses import losses as tf_losses -from tensorflow.python.platform import test - - -# TODO(joelshor): Use `parameterized` tests when opensourced. 
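# [Editorial sketch] The hard-coded expected values in the loss tests below can be
# re-derived by hand. A NumPy check of the least-squares losses
# (https://arxiv.org/abs/1611.04076) using the same logits as `init_constants`,
# assuming the default SUM_BY_NONZERO_WEIGHTS reduction (a mean over the four
# equally weighted entries); the helper names are illustrative, not TF-GAN APIs.
import numpy as np

real_logits = np.array([-5.0, 1.4, 12.5, 2.7])   # discriminator on real data
gen_logits = np.array([10.0, 4.4, -5.5, 3.6])    # discriminator on generated data

def lsq_generator_loss(d_gen, real_label=1.0):
  # L = 1/2 * (D(G(z)) - real_label) ** 2, averaged over entries.
  return np.mean((d_gen - real_label) ** 2 / 2.0)

def lsq_discriminator_loss(d_real, d_gen, real_label=1.0, fake_label=0.0):
  # Real and generated terms are reduced separately, then summed.
  return (np.mean((d_real - real_label) ** 2 / 2.0) +
          np.mean((d_gen - fake_label) ** 2 / 2.0))

print(lsq_generator_loss(gen_logits))                    # 17.69625
print(lsq_discriminator_loss(real_logits, gen_logits))   # 41.73375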
-class _LossesTest(object): - - def init_constants(self): - self._discriminator_real_outputs_np = [-5.0, 1.4, 12.5, 2.7] - self._discriminator_gen_outputs_np = [10.0, 4.4, -5.5, 3.6] - self._weights = 2.3 - self._discriminator_real_outputs = constant_op.constant( - self._discriminator_real_outputs_np, dtype=dtypes.float32) - self._discriminator_gen_outputs = constant_op.constant( - self._discriminator_gen_outputs_np, dtype=dtypes.float32) - - def test_generator_all_correct(self): - loss = self._g_loss_fn(self._discriminator_gen_outputs) - self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) - self.assertEqual(self._generator_loss_name, loss.op.name) - with self.cached_session(): - self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) - - def test_discriminator_all_correct(self): - loss = self._d_loss_fn( - self._discriminator_real_outputs, self._discriminator_gen_outputs) - self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) - self.assertEqual(self._discriminator_loss_name, loss.op.name) - with self.cached_session(): - self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) - - def test_generator_loss_collection(self): - self.assertEqual(0, len(ops.get_collection('collection'))) - self._g_loss_fn( - self._discriminator_gen_outputs, loss_collection='collection') - self.assertEqual(1, len(ops.get_collection('collection'))) - - def test_discriminator_loss_collection(self): - self.assertEqual(0, len(ops.get_collection('collection'))) - self._d_loss_fn( - self._discriminator_real_outputs, self._discriminator_gen_outputs, - loss_collection='collection') - self.assertEqual(1, len(ops.get_collection('collection'))) - - def test_generator_no_reduction(self): - loss = self._g_loss_fn( - self._discriminator_gen_outputs, reduction=tf_losses.Reduction.NONE) - self.assertAllEqual([4], loss.shape) - - def test_discriminator_no_reduction(self): - loss = self._d_loss_fn( - self._discriminator_real_outputs, self._discriminator_gen_outputs, - reduction=tf_losses.Reduction.NONE) - self.assertAllEqual([4], loss.shape) - - def test_generator_patch(self): - loss = self._g_loss_fn( - array_ops.reshape(self._discriminator_gen_outputs, [2, 2])) - self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) - with self.cached_session(): - self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) - - def test_discriminator_patch(self): - loss = self._d_loss_fn( - array_ops.reshape(self._discriminator_real_outputs, [2, 2]), - array_ops.reshape(self._discriminator_gen_outputs, [2, 2])) - self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) - with self.cached_session(): - self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) - - def test_generator_loss_with_placeholder_for_logits(self): - logits = array_ops.placeholder(dtypes.float32, shape=(None, 4)) - weights = array_ops.ones_like(logits, dtype=dtypes.float32) - - loss = self._g_loss_fn(logits, weights=weights) - self.assertEqual(logits.dtype, loss.dtype) - - with self.cached_session() as sess: - loss = sess.run(loss, - feed_dict={ - logits: [[10.0, 4.4, -5.5, 3.6]], - }) - self.assertAlmostEqual(self._expected_g_loss, loss, 5) - - def test_discriminator_loss_with_placeholder_for_logits(self): - logits = array_ops.placeholder(dtypes.float32, shape=(None, 4)) - logits2 = array_ops.placeholder(dtypes.float32, shape=(None, 4)) - real_weights = array_ops.ones_like(logits, dtype=dtypes.float32) - generated_weights = array_ops.ones_like(logits, dtype=dtypes.float32) - - loss = self._d_loss_fn( - 
logits, logits2, real_weights=real_weights, - generated_weights=generated_weights) - - with self.cached_session() as sess: - loss = sess.run(loss, - feed_dict={ - logits: [self._discriminator_real_outputs_np], - logits2: [self._discriminator_gen_outputs_np], - }) - self.assertAlmostEqual(self._expected_d_loss, loss, 5) - - def test_generator_with_python_scalar_weight(self): - loss = self._g_loss_fn( - self._discriminator_gen_outputs, weights=self._weights) - with self.cached_session(): - self.assertAlmostEqual(self._expected_g_loss * self._weights, - loss.eval(), 4) - - def test_discriminator_with_python_scalar_weight(self): - loss = self._d_loss_fn( - self._discriminator_real_outputs, self._discriminator_gen_outputs, - real_weights=self._weights, generated_weights=self._weights) - with self.cached_session(): - self.assertAlmostEqual(self._expected_d_loss * self._weights, - loss.eval(), 4) - - def test_generator_with_scalar_tensor_weight(self): - loss = self._g_loss_fn(self._discriminator_gen_outputs, - weights=constant_op.constant(self._weights)) - with self.cached_session(): - self.assertAlmostEqual(self._expected_g_loss * self._weights, - loss.eval(), 4) - - def test_discriminator_with_scalar_tensor_weight(self): - weights = constant_op.constant(self._weights) - loss = self._d_loss_fn( - self._discriminator_real_outputs, self._discriminator_gen_outputs, - real_weights=weights, generated_weights=weights) - with self.cached_session(): - self.assertAlmostEqual(self._expected_d_loss * self._weights, - loss.eval(), 4) - - def test_generator_add_summaries(self): - self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - self._g_loss_fn(self._discriminator_gen_outputs, add_summaries=True) - self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - - def test_discriminator_add_summaries(self): - self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - self._d_loss_fn( - self._discriminator_real_outputs, self._discriminator_gen_outputs, - add_summaries=True) - self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - - -class LeastSquaresLossTest(test.TestCase, _LossesTest): - """Tests for least_squares_xxx_loss.""" - - def setUp(self): - super(LeastSquaresLossTest, self).setUp() - self.init_constants() - self._expected_g_loss = 17.69625 - self._expected_d_loss = 41.73375 - self._generator_loss_name = 'lsq_generator_loss/value' - self._discriminator_loss_name = 'lsq_discriminator_loss/add' - self._g_loss_fn = tfgan_losses.least_squares_generator_loss - self._d_loss_fn = tfgan_losses.least_squares_discriminator_loss - - -class ModifiedLossTest(test.TestCase, _LossesTest): - """Tests for modified_xxx_loss.""" - - def setUp(self): - super(ModifiedLossTest, self).setUp() - self.init_constants() - self._expected_g_loss = 1.38582 - self._expected_d_loss = 6.19637 - self._generator_loss_name = 'generator_modified_loss/value' - self._discriminator_loss_name = 'discriminator_modified_loss/add_1' - self._g_loss_fn = tfgan_losses.modified_generator_loss - self._d_loss_fn = tfgan_losses.modified_discriminator_loss - - -class MinimaxLossTest(test.TestCase, _LossesTest): - """Tests for minimax_xxx_loss.""" - - def setUp(self): - super(MinimaxLossTest, self).setUp() - self.init_constants() - self._expected_g_loss = -4.82408 - self._expected_d_loss = 6.19637 - self._generator_loss_name = 'generator_minimax_loss/Neg' - self._discriminator_loss_name = 'discriminator_minimax_loss/add_1' - self._g_loss_fn = tfgan_losses.minimax_generator_loss - self._d_loss_fn 
= tfgan_losses.minimax_discriminator_loss - - -class WassersteinLossTest(test.TestCase, _LossesTest): - """Tests for wasserstein_xxx_loss.""" - - def setUp(self): - super(WassersteinLossTest, self).setUp() - self.init_constants() - self._expected_g_loss = -3.12500 - self._expected_d_loss = 0.22500 - self._generator_loss_name = 'generator_wasserstein_loss/value' - self._discriminator_loss_name = 'discriminator_wasserstein_loss/sub' - self._g_loss_fn = tfgan_losses.wasserstein_generator_loss - self._d_loss_fn = tfgan_losses.wasserstein_discriminator_loss - - -# TODO(joelshor): Use `parameterized` tests when opensourced. -# TODO(joelshor): Refactor this test to use the same code as the other losses. -class ACGANLossTest(test.TestCase): - """Tests for wasserstein_xxx_loss.""" - - def setUp(self): - super(ACGANLossTest, self).setUp() - self._g_loss_fn = tfgan_losses.acgan_generator_loss - self._d_loss_fn = tfgan_losses.acgan_discriminator_loss - self._discriminator_gen_classification_logits_np = [[10.0, 4.4, -5.5, 3.6], - [-4.0, 4.4, 5.2, 4.6], - [1.1, 2.4, -3.5, 5.6], - [1.1, 2.4, -3.5, 5.6]] - self._discriminator_real_classification_logits_np = [[-2.0, 0.4, 12.5, 2.7], - [-1.2, 1.9, 12.3, 2.6], - [-2.4, -1.7, 2.5, 2.7], - [1.1, 2.4, -3.5, 5.6]] - self._one_hot_labels_np = [[0, 1, 0, 0], - [0, 0, 1, 0], - [1, 0, 0, 0], - [1, 0, 0, 0]] - self._weights = 2.3 - - self._discriminator_gen_classification_logits = constant_op.constant( - self._discriminator_gen_classification_logits_np, dtype=dtypes.float32) - self._discriminator_real_classification_logits = constant_op.constant( - self._discriminator_real_classification_logits_np, dtype=dtypes.float32) - self._one_hot_labels = constant_op.constant( - self._one_hot_labels_np, dtype=dtypes.float32) - self._generator_kwargs = { - 'discriminator_gen_classification_logits': - self._discriminator_gen_classification_logits, - 'one_hot_labels': self._one_hot_labels, - } - self._discriminator_kwargs = { - 'discriminator_gen_classification_logits': - self._discriminator_gen_classification_logits, - 'discriminator_real_classification_logits': - self._discriminator_real_classification_logits, - 'one_hot_labels': self._one_hot_labels, - } - self._generator_loss_name = 'acgan_generator_loss/value' - self._discriminator_loss_name = 'acgan_discriminator_loss/add' - self._expected_g_loss = 3.84974 - self._expected_d_loss = 9.43950 - - def test_generator_all_correct(self): - loss = self._g_loss_fn(**self._generator_kwargs) - self.assertEqual( - self._discriminator_gen_classification_logits.dtype, loss.dtype) - self.assertEqual(self._generator_loss_name, loss.op.name) - with self.cached_session(): - self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) - - def test_discriminator_all_correct(self): - loss = self._d_loss_fn(**self._discriminator_kwargs) - self.assertEqual( - self._discriminator_gen_classification_logits.dtype, loss.dtype) - self.assertEqual(self._discriminator_loss_name, loss.op.name) - with self.cached_session(): - self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) - - def test_generator_loss_collection(self): - self.assertEqual(0, len(ops.get_collection('collection'))) - self._g_loss_fn(loss_collection='collection', **self._generator_kwargs) - self.assertEqual(1, len(ops.get_collection('collection'))) - - def test_discriminator_loss_collection(self): - self.assertEqual(0, len(ops.get_collection('collection'))) - self._d_loss_fn(loss_collection='collection', **self._discriminator_kwargs) - self.assertEqual(1, 
len(ops.get_collection('collection'))) - - def test_generator_no_reduction(self): - loss = self._g_loss_fn( - reduction=tf_losses.Reduction.NONE, **self._generator_kwargs) - self.assertAllEqual([4], loss.shape) - - def test_discriminator_no_reduction(self): - loss = self._d_loss_fn( - reduction=tf_losses.Reduction.NONE, **self._discriminator_kwargs) - self.assertAllEqual([4], loss.shape) - - def test_generator_patch(self): - patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in - self._generator_kwargs.items()} - loss = self._g_loss_fn(**patch_args) - with self.cached_session(): - self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) - - def test_discriminator_patch(self): - patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in - self._discriminator_kwargs.items()} - loss = self._d_loss_fn(**patch_args) - with self.cached_session(): - self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) - - def test_generator_loss_with_placeholder_for_logits(self): - gen_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4)) - one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4)) - - loss = self._g_loss_fn(gen_logits, one_hot_labels) - with self.cached_session() as sess: - loss = sess.run( - loss, feed_dict={ - gen_logits: self._discriminator_gen_classification_logits_np, - one_hot_labels: self._one_hot_labels_np, - }) - self.assertAlmostEqual(self._expected_g_loss, loss, 5) - - def test_discriminator_loss_with_placeholder_for_logits_and_weights(self): - gen_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4)) - real_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4)) - one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4)) - - loss = self._d_loss_fn(gen_logits, real_logits, one_hot_labels) - - with self.cached_session() as sess: - loss = sess.run( - loss, feed_dict={ - gen_logits: self._discriminator_gen_classification_logits_np, - real_logits: self._discriminator_real_classification_logits_np, - one_hot_labels: self._one_hot_labels_np, - }) - self.assertAlmostEqual(self._expected_d_loss, loss, 5) - - def test_generator_with_python_scalar_weight(self): - loss = self._g_loss_fn(weights=self._weights, **self._generator_kwargs) - with self.cached_session(): - self.assertAlmostEqual(self._expected_g_loss * self._weights, - loss.eval(), 4) - - def test_discriminator_with_python_scalar_weight(self): - loss = self._d_loss_fn( - real_weights=self._weights, generated_weights=self._weights, - **self._discriminator_kwargs) - with self.cached_session(): - self.assertAlmostEqual(self._expected_d_loss * self._weights, - loss.eval(), 4) - - def test_generator_with_scalar_tensor_weight(self): - loss = self._g_loss_fn( - weights=constant_op.constant(self._weights), **self._generator_kwargs) - with self.cached_session(): - self.assertAlmostEqual(self._expected_g_loss * self._weights, - loss.eval(), 4) - - def test_discriminator_with_scalar_tensor_weight(self): - weights = constant_op.constant(self._weights) - loss = self._d_loss_fn(real_weights=weights, generated_weights=weights, - **self._discriminator_kwargs) - with self.cached_session(): - self.assertAlmostEqual(self._expected_d_loss * self._weights, - loss.eval(), 4) - - def test_generator_add_summaries(self): - self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - self._g_loss_fn(add_summaries=True, **self._generator_kwargs) - self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - - def test_discriminator_add_summaries(self): - 
self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - self._d_loss_fn(add_summaries=True, **self._discriminator_kwargs) - self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) - - -class _PenaltyTest(object): - - def test_all_correct(self): - loss = self._penalty_fn(**self._kwargs) - self.assertEqual(self._expected_dtype, loss.dtype) - # NOTE: Op names will change, it is inappropriate to include them in tests. - # See go/tf-breaking-change. - # self.assertEqual(self._expected_op_name, loss.op.name) - with self.cached_session(): - variables.global_variables_initializer().run() - self.assertAlmostEqual(self._expected_loss, loss.eval(), 6) - - def test_loss_collection(self): - self.assertEqual(0, len(ops.get_collection('collection'))) - self._penalty_fn(loss_collection='collection', **self._kwargs) - self.assertEqual(1, len(ops.get_collection('collection'))) - - def test_no_reduction(self): - loss = self._penalty_fn(reduction=tf_losses.Reduction.NONE, **self._kwargs) - self.assertAllEqual([self._batch_size], loss.shape) - - def test_python_scalar_weight(self): - loss = self._penalty_fn(weights=2.3, **self._kwargs) - with self.cached_session(): - variables.global_variables_initializer().run() - self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3) - - def test_scalar_tensor_weight(self): - loss = self._penalty_fn(weights=constant_op.constant(2.3), **self._kwargs) - with self.cached_session(): - variables.global_variables_initializer().run() - self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3) - - -class GradientPenaltyTest(test.TestCase, _PenaltyTest): - """Tests for wasserstein_gradient_penalty.""" - - def setUp(self): - super(GradientPenaltyTest, self).setUp() - self._penalty_fn = tfgan_losses.wasserstein_gradient_penalty - self._generated_data_np = [[3.1, 2.3, -12.3, 32.1]] - self._real_data_np = [[-12.3, 23.2, 16.3, -43.2]] - self._expected_dtype = dtypes.float32 - - with variable_scope.variable_scope('fake_scope') as self._scope: - self._discriminator_fn(0.0, 0.0) - - self._kwargs = { - 'generated_data': constant_op.constant( - self._generated_data_np, dtype=self._expected_dtype), - 'real_data': constant_op.constant( - self._real_data_np, dtype=self._expected_dtype), - 'generator_inputs': None, - 'discriminator_fn': self._discriminator_fn, - 'discriminator_scope': self._scope, - } - self._expected_loss = 9.00000 - self._expected_op_name = 'wasserstein_gradient_penalty/value' - self._batch_size = 1 - - def _discriminator_fn(self, inputs, _): - ops.add_to_collection('fake_update_ops', constant_op.constant(1.0)) - return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs - - def test_loss_with_placeholder(self): - generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) - real_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) - - loss = tfgan_losses.wasserstein_gradient_penalty( - generated_data, - real_data, - self._kwargs['generator_inputs'], - self._kwargs['discriminator_fn'], - self._kwargs['discriminator_scope']) - self.assertEqual(generated_data.dtype, loss.dtype) - - with self.cached_session() as sess: - variables.global_variables_initializer().run() - loss = sess.run(loss, - feed_dict={ - generated_data: self._generated_data_np, - real_data: self._real_data_np, - }) - self.assertAlmostEqual(self._expected_loss, loss, 5) - - def test_loss_using_one_sided_mode(self): - generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) - real_data = 
array_ops.placeholder(dtypes.float32, shape=(None, None)) - - loss = tfgan_losses.wasserstein_gradient_penalty( - generated_data, - real_data, - self._kwargs['generator_inputs'], - self._kwargs['discriminator_fn'], - self._kwargs['discriminator_scope'], - one_sided=True) - self.assertEqual(generated_data.dtype, loss.dtype) - - with self.cached_session() as sess: - variables.global_variables_initializer().run() - loss = sess.run(loss, - feed_dict={ - generated_data: self._generated_data_np, - real_data: self._real_data_np, - }) - self.assertAlmostEqual(self._expected_loss, loss, 5) - - def test_loss_with_gradient_norm_target(self): - """Test loss value with non default gradient norm target.""" - generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) - real_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) - - loss = tfgan_losses.wasserstein_gradient_penalty( - generated_data, - real_data, - self._kwargs['generator_inputs'], - self._kwargs['discriminator_fn'], - self._kwargs['discriminator_scope'], - target=2.0) - - with self.cached_session() as sess: - variables.global_variables_initializer().run() - loss = sess.run( - loss, - feed_dict={ - generated_data: self._generated_data_np, - real_data: self._real_data_np, - }) - self.assertAlmostEqual(1.0, loss, 5) - - def test_reuses_scope(self): - """Test that gradient penalty reuses discriminator scope.""" - num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) - tfgan_losses.wasserstein_gradient_penalty(**self._kwargs) - self.assertEqual( - num_vars, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) - - def test_works_with_get_collection(self): - """Tests that gradient penalty works inside other scopes.""" - # We ran the discriminator once in the setup, so there should be an op - # already in the collection. - self.assertEqual(1, len(ops.get_collection( - 'fake_update_ops', self._kwargs['discriminator_scope'].name))) - - # Make sure the op is added to the collection even if it's in a name scope. - with ops.name_scope('loss'): - tfgan_losses.wasserstein_gradient_penalty(**self._kwargs) - self.assertEqual(2, len(ops.get_collection( - 'fake_update_ops', self._kwargs['discriminator_scope'].name))) - - # Make sure the op is added to the collection even if it's in a variable - # scope. 
- with variable_scope.variable_scope('loss_vscope'): - tfgan_losses.wasserstein_gradient_penalty(**self._kwargs) - self.assertEqual(3, len(ops.get_collection( - 'fake_update_ops', self._kwargs['discriminator_scope'].name))) - - -class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): - """Tests for mutual_information_penalty.""" - - def setUp(self): - super(MutualInformationPenaltyTest, self).setUp() - self._penalty_fn = tfgan_losses.mutual_information_penalty - self._structured_generator_inputs = [1.0, 2.0] - self._predicted_distributions = [categorical.Categorical(logits=[1.0, 2.0]), - normal.Normal([0.0], [1.0])] - self._expected_dtype = dtypes.float32 - - self._kwargs = { - 'structured_generator_inputs': self._structured_generator_inputs, - 'predicted_distributions': self._predicted_distributions, - } - self._expected_loss = 1.61610 - self._expected_op_name = 'mutual_information_loss/mul_1' - self._batch_size = 2 - - -class CombineAdversarialLossTest(test.TestCase): - """Tests for combine_adversarial_loss.""" - - def setUp(self): - super(CombineAdversarialLossTest, self).setUp() - self._generated_data_np = [[3.1, 2.3, -12.3, 32.1]] - self._real_data_np = [[-12.3, 23.2, 16.3, -43.2]] - self._generated_data = constant_op.constant( - self._generated_data_np, dtype=dtypes.float32) - self._real_data = constant_op.constant( - self._real_data_np, dtype=dtypes.float32) - self._generated_inputs = None - self._expected_loss = 9.00000 - - def _test_correct_helper(self, use_weight_factor): - variable_list = [variables.Variable(1.0)] - main_loss = variable_list[0] * 2 - adversarial_loss = variable_list[0] * 3 - gradient_ratio_epsilon = 1e-6 - if use_weight_factor: - weight_factor = constant_op.constant(2.0) - gradient_ratio = None - adv_coeff = 2.0 - expected_loss = 1.0 * 2 + adv_coeff * 1.0 * 3 - else: - weight_factor = None - gradient_ratio = constant_op.constant(0.5) - adv_coeff = 2.0 / (3 * 0.5 + gradient_ratio_epsilon) - expected_loss = 1.0 * 2 + adv_coeff * 1.0 * 3 - combined_loss = tfgan_losses.combine_adversarial_loss( - main_loss, - adversarial_loss, - weight_factor=weight_factor, - gradient_ratio=gradient_ratio, - gradient_ratio_epsilon=gradient_ratio_epsilon, - variables=variable_list) - - with self.test_session(use_gpu=True): - variables.global_variables_initializer().run() - self.assertNear(expected_loss, combined_loss.eval(), 1e-5) - - def test_correct_useweightfactor(self): - self._test_correct_helper(True) - - def test_correct_nouseweightfactor(self): - self._test_correct_helper(False) - - def _test_no_weight_skips_adversarial_loss_helper(self, use_weight_factor): - """Test the 0 adversarial weight or grad ratio skips adversarial loss.""" - main_loss = constant_op.constant(1.0) - adversarial_loss = constant_op.constant(1.0) - - weight_factor = 0.0 if use_weight_factor else None - gradient_ratio = None if use_weight_factor else 0.0 - - combined_loss = tfgan_losses.combine_adversarial_loss( - main_loss, - adversarial_loss, - weight_factor=weight_factor, - gradient_ratio=gradient_ratio, - gradient_summaries=False) - - with self.test_session(use_gpu=True): - self.assertEqual(1.0, combined_loss.eval()) - - def test_no_weight_skips_adversarial_loss_useweightfactor(self): - self._test_no_weight_skips_adversarial_loss_helper(True) - - def test_no_weight_skips_adversarial_loss_nouseweightfactor(self): - self._test_no_weight_skips_adversarial_loss_helper(False) - - def test_stable_global_norm_avoids_overflow(self): - tensors = [array_ops.ones([4]), array_ops.ones([4, 4]) * 1e19, 
None] - gnorm_is_inf = math_ops.is_inf(clip_ops.global_norm(tensors)) - stable_gnorm_is_inf = math_ops.is_inf( - tfgan_losses._numerically_stable_global_norm(tensors)) - - with self.test_session(use_gpu=True): - self.assertTrue(gnorm_is_inf.eval()) - self.assertFalse(stable_gnorm_is_inf.eval()) - - def test_stable_global_norm_unchanged(self): - """Test that preconditioning doesn't change global norm value.""" - random_seed.set_random_seed(1234) - tensors = [random_ops.random_uniform([3]*i, -10.0, 10.0) for i in range(6)] - gnorm = clip_ops.global_norm(tensors) - precond_gnorm = tfgan_losses._numerically_stable_global_norm(tensors) - - with self.test_session(use_gpu=True) as sess: - for _ in range(10): # spot check closeness on more than one sample. - gnorm_np, precond_gnorm_np = sess.run([gnorm, precond_gnorm]) - self.assertNear(gnorm_np, precond_gnorm_np, 1e-4) - - -class CycleConsistencyLossTest(test.TestCase): - """Tests for cycle_consistency_loss.""" - - def setUp(self): - super(CycleConsistencyLossTest, self).setUp() - - self._data_x_np = [[1.0, 2, 3], [4, 5, 6]] - self._reconstructed_data_x_np = [[7.0, 8, 9], [10, 11, 12]] - self._data_y_np = [1.0, 9] - self._reconstructed_data_y_np = [-2.0, 3] - - self._data_x = constant_op.constant(self._data_x_np, dtype=dtypes.float32) - self._reconstructed_data_x = constant_op.constant( - self._reconstructed_data_x_np, dtype=dtypes.float32) - self._data_y = constant_op.constant(self._data_y_np, dtype=dtypes.float32) - self._reconstructed_data_y = constant_op.constant( - self._reconstructed_data_y_np, dtype=dtypes.float32) - - def test_correct_loss(self): - loss = tfgan_losses.cycle_consistency_loss( - self._data_x, self._reconstructed_data_x, self._data_y, - self._reconstructed_data_y) - with self.test_session(use_gpu=True): - variables.global_variables_initializer().run() - self.assertNear(5.25, loss.eval(), 1e-5) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gan/python/losses/python/losses_wargs.py b/tensorflow/contrib/gan/python/losses/python/losses_wargs.py deleted file mode 100644 index f212bdcf30b..00000000000 --- a/tensorflow/contrib/gan/python/losses/python/losses_wargs.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2017 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TFGAN grouped API. 
Please see README.md for details and usage.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.losses.python import losses_impl -from tensorflow.contrib.gan.python.losses.python.losses_impl import * -# pylint: enable=wildcard-import - -from tensorflow.python.util.all_util import remove_undocumented - -remove_undocumented(__name__, losses_impl.__all__) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses.py deleted file mode 100644 index 1a50b3f5880..00000000000 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TFGAN utilities for loss functions that accept GANModel namedtuples.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl -from tensorflow.contrib.gan.python.losses.python.tuple_losses_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = tuple_losses_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py deleted file mode 100644 index 76e57df7f64..00000000000 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TF-GAN utilities for loss functions that accept GANModel namedtuples. - -The losses and penalties in this file all correspond to losses in -`losses_impl.py`. Losses in that file take individual arguments, whereas in this -file they take a `GANModel` tuple. 
For example: - -losses_impl.py: - ```python - def wasserstein_discriminator_loss( - discriminator_real_outputs, - discriminator_gen_outputs, - real_weights=1.0, - generated_weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False) - ``` - -tuple_losses_impl.py: - ```python - def wasserstein_discriminator_loss( - gan_model, - real_weights=1.0, - generated_weights=1.0, - scope=None, - loss_collection=ops.GraphKeys.LOSSES, - reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, - add_summaries=False) - ``` - - - -Example usage: - ```python - # `tfgan.losses.wargs` losses take individual arguments. - w_loss = tfgan.losses.wargs.wasserstein_discriminator_loss( - discriminator_real_outputs, - discriminator_gen_outputs) - - # `tfgan.losses` losses take GANModel namedtuples. - w_loss2 = tfgan.losses.wasserstein_discriminator_loss(gan_model) - ``` -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python import namedtuples -from tensorflow.contrib.gan.python.losses.python import losses_impl -from tensorflow.python.util import tf_inspect - - -__all__ = [ - 'acgan_discriminator_loss', - 'acgan_generator_loss', - 'least_squares_discriminator_loss', - 'least_squares_generator_loss', - 'modified_discriminator_loss', - 'modified_generator_loss', - 'minimax_discriminator_loss', - 'minimax_generator_loss', - 'wasserstein_discriminator_loss', - 'wasserstein_generator_loss', - 'wasserstein_gradient_penalty', - 'mutual_information_penalty', - 'combine_adversarial_loss', - 'cycle_consistency_loss', - 'stargan_generator_loss_wrapper', - 'stargan_discriminator_loss_wrapper', - 'stargan_gradient_penalty_wrapper' -] - - -def _args_to_gan_model(loss_fn): - """Converts a loss taking individual args to one taking a GANModel namedtuple. - - The new function has the same name as the original one. - - Args: - loss_fn: A python function taking a `GANModel` object and returning a loss - Tensor calculated from that object. The shape of the loss depends on - `reduction`. - - Returns: - A new function that takes a GANModel namedtuples and returns the same loss. - """ - # Match arguments in `loss_fn` to elements of `namedtuple`. - # TODO(joelshor): Properly handle `varargs` and `keywords`. - argspec = tf_inspect.getargspec(loss_fn) - defaults = argspec.defaults or [] - - required_args = set(argspec.args[:-len(defaults)]) - args_with_defaults = argspec.args[-len(defaults):] - default_args_dict = dict(zip(args_with_defaults, defaults)) - - def new_loss_fn(gan_model, **kwargs): # pylint:disable=missing-docstring - def _asdict(namedtuple): - """Returns a namedtuple as a dictionary. - - This is required because `_asdict()` in Python 3.x.x is broken in classes - that inherit from `collections.namedtuple`. See - https://bugs.python.org/issue24931 for more details. - - Args: - namedtuple: An object that inherits from `collections.namedtuple`. - - Returns: - A dictionary version of the tuple. - """ - return {k: getattr(namedtuple, k) for k in namedtuple._fields} - gan_model_dict = _asdict(gan_model) - - # Make sure non-tuple required args are supplied. - args_from_tuple = set(argspec.args).intersection(set(gan_model._fields)) - required_args_not_from_tuple = required_args - args_from_tuple - for arg in required_args_not_from_tuple: - if arg not in kwargs: - raise ValueError('`%s` must be supplied to %s loss function.' 
% ( - arg, loss_fn.__name__)) - - # Make sure tuple args aren't also supplied as keyword args. - ambiguous_args = set(gan_model._fields).intersection(set(kwargs.keys())) - if ambiguous_args: - raise ValueError( - 'The following args are present in both the tuple and keyword args ' - 'for %s: %s' % (loss_fn.__name__, ambiguous_args)) - - # Add required args to arg dictionary. - required_args_from_tuple = required_args.intersection(args_from_tuple) - for arg in required_args_from_tuple: - assert arg not in kwargs - kwargs[arg] = gan_model_dict[arg] - - # Add arguments that have defaults. - for arg in default_args_dict: - val_from_tuple = gan_model_dict[arg] if arg in gan_model_dict else None - val_from_kwargs = kwargs[arg] if arg in kwargs else None - assert not (val_from_tuple is not None and val_from_kwargs is not None) - kwargs[arg] = (val_from_tuple if val_from_tuple is not None else - val_from_kwargs if val_from_kwargs is not None else - default_args_dict[arg]) - - return loss_fn(**kwargs) - - new_docstring = """The gan_model version of %s.""" % loss_fn.__name__ - new_loss_fn.__docstring__ = new_docstring - new_loss_fn.__name__ = loss_fn.__name__ - new_loss_fn.__module__ = loss_fn.__module__ - return new_loss_fn - - -# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875). -wasserstein_generator_loss = _args_to_gan_model( - losses_impl.wasserstein_generator_loss) -wasserstein_discriminator_loss = _args_to_gan_model( - losses_impl.wasserstein_discriminator_loss) -wasserstein_gradient_penalty = _args_to_gan_model( - losses_impl.wasserstein_gradient_penalty) - -# ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs` -# (https://arxiv.org/abs/1610.09585). -acgan_discriminator_loss = _args_to_gan_model( - losses_impl.acgan_discriminator_loss) -acgan_generator_loss = _args_to_gan_model( - losses_impl.acgan_generator_loss) - - -# Original losses from `Generative Adversarial Nets` -# (https://arxiv.org/abs/1406.2661). -minimax_discriminator_loss = _args_to_gan_model( - losses_impl.minimax_discriminator_loss) -minimax_generator_loss = _args_to_gan_model( - losses_impl.minimax_generator_loss) -modified_discriminator_loss = _args_to_gan_model( - losses_impl.modified_discriminator_loss) -modified_generator_loss = _args_to_gan_model( - losses_impl.modified_generator_loss) - - -# Least Squares loss from `Least Squares Generative Adversarial Networks` -# (https://arxiv.org/abs/1611.04076). -least_squares_generator_loss = _args_to_gan_model( - losses_impl.least_squares_generator_loss) -least_squares_discriminator_loss = _args_to_gan_model( - losses_impl.least_squares_discriminator_loss) - - -# InfoGAN loss from `InfoGAN: Interpretable Representation Learning by -# `Information Maximizing Generative Adversarial Nets` -# https://arxiv.org/abs/1606.03657 -mutual_information_penalty = _args_to_gan_model( - losses_impl.mutual_information_penalty) - - -def combine_adversarial_loss(gan_loss, - gan_model, - non_adversarial_loss, - weight_factor=None, - gradient_ratio=None, - gradient_ratio_epsilon=1e-6, - scalar_summaries=True, - gradient_summaries=True): - """Combine adversarial loss and main loss. - - Uses `combine_adversarial_loss` to combine the losses, and returns - a modified GANLoss namedtuple. - - Args: - gan_loss: A GANLoss namedtuple. Assume the GANLoss.generator_loss is the - adversarial loss. - gan_model: A GANModel namedtuple. Used to access the generator's variables. - non_adversarial_loss: Same as `main_loss` from - `combine_adversarial_loss`. 
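# [Editorial sketch] The essence of `_args_to_gan_model` above, reduced to plain
# Python 3: arguments whose names match fields of the model namedtuple are pulled
# from the tuple, everything else is passed through as keyword arguments. All
# names below are made up for illustration; the real wrapper additionally handles
# defaults, missing required args, and tuple/keyword conflicts.
import collections
import inspect

FakeModel = collections.namedtuple('FakeModel', ['real_outputs', 'gen_outputs'])

def toy_loss(real_outputs, gen_outputs, weight=1.0):
  return weight * (sum(real_outputs) - sum(gen_outputs))

def args_to_model(loss_fn):
  arg_names = set(inspect.signature(loss_fn).parameters)
  def new_loss_fn(model, **kwargs):
    tuple_args = {k: getattr(model, k) for k in model._fields if k in arg_names}
    return loss_fn(**tuple_args, **kwargs)
  return new_loss_fn

model_loss = args_to_model(toy_loss)
print(model_loss(FakeModel([1.0, 2.0], [0.5]), weight=2.0))  # 5.0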
- weight_factor: Same as `weight_factor` from - `combine_adversarial_loss`. - gradient_ratio: Same as `gradient_ratio` from - `combine_adversarial_loss`. - gradient_ratio_epsilon: Same as `gradient_ratio_epsilon` from - `combine_adversarial_loss`. - scalar_summaries: Same as `scalar_summaries` from - `combine_adversarial_loss`. - gradient_summaries: Same as `gradient_summaries` from - `combine_adversarial_loss`. - - Returns: - A modified GANLoss namedtuple, with `non_adversarial_loss` included - appropriately. - """ - combined_loss = losses_impl.combine_adversarial_loss( - non_adversarial_loss, - gan_loss.generator_loss, - weight_factor, - gradient_ratio, - gradient_ratio_epsilon, - gan_model.generator_variables, - scalar_summaries, - gradient_summaries) - return gan_loss._replace(generator_loss=combined_loss) - - -def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False): - """Defines the cycle consistency loss. - - Uses `cycle_consistency_loss` to compute the cycle consistency loss for a - `cyclegan_model`. - - Args: - cyclegan_model: A `CycleGANModel` namedtuple. - scope: The scope for the operations performed in computing the loss. - Defaults to None. - add_summaries: Whether or not to add detailed summaries for the loss. - Defaults to False. - - Returns: - A scalar `Tensor` of cycle consistency loss. - - Raises: - ValueError: If `cyclegan_model` is not a `CycleGANModel` namedtuple. - """ - if not isinstance(cyclegan_model, namedtuples.CycleGANModel): - raise ValueError( - '`cyclegan_model` must be a `CycleGANModel`. Instead, was %s.' % - type(cyclegan_model)) - return losses_impl.cycle_consistency_loss( - cyclegan_model.model_x2y.generator_inputs, cyclegan_model.reconstructed_x, - cyclegan_model.model_y2x.generator_inputs, cyclegan_model.reconstructed_y, - scope, add_summaries) - - -def stargan_generator_loss_wrapper(loss_fn): - """Convert a generator loss function to take a StarGANModel. - - The new function has the same name as the original one. - - Args: - loss_fn: A python function taking Discriminator's real/fake prediction for - generated data. - - Returns: - A new function that takes a StarGANModel namedtuple and returns the same - loss. - """ - - def new_loss_fn(stargan_model, **kwargs): - return loss_fn( - stargan_model.discriminator_generated_data_source_predication, **kwargs) - - new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ - new_loss_fn.__docstring__ = new_docstring - new_loss_fn.__name__ = loss_fn.__name__ - new_loss_fn.__module__ = loss_fn.__module__ - return new_loss_fn - - -def stargan_discriminator_loss_wrapper(loss_fn): - """Convert a discriminator loss function to take a StarGANModel. - - The new function has the same name as the original one. - - Args: - loss_fn: A python function taking Discriminator's real/fake prediction for - real data and generated data. - - Returns: - A new function that takes a StarGANModel namedtuple and returns the same - loss. - """ - - def new_loss_fn(stargan_model, **kwargs): - return loss_fn( - stargan_model.discriminator_input_data_source_predication, - stargan_model.discriminator_generated_data_source_predication, **kwargs) - - new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ - new_loss_fn.__docstring__ = new_docstring - new_loss_fn.__name__ = loss_fn.__name__ - new_loss_fn.__module__ = loss_fn.__module__ - return new_loss_fn - - -def stargan_gradient_penalty_wrapper(loss_fn): - """Convert a gradient penalty function to take a StarGANModel. 
- - The new function has the same name as the original one. - - Args: - loss_fn: A python function taking real_data, generated_data, - generator_inputs for Discriminator's condition (i.e. number of domains), - discriminator_fn, and discriminator_scope. - - Returns: - A new function that takes a StarGANModel namedtuple and returns the same - loss. - """ - - def new_loss_fn(stargan_model, **kwargs): - num_domains = stargan_model.input_data_domain_label.shape.as_list()[-1] - return loss_fn( - real_data=stargan_model.input_data, - generated_data=stargan_model.generated_data, - generator_inputs=num_domains, - discriminator_fn=stargan_model.discriminator_fn, - discriminator_scope=stargan_model.discriminator_scope, - **kwargs) - - new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__ - new_loss_fn.__docstring__ = new_docstring - new_loss_fn.__name__ = loss_fn.__name__ - new_loss_fn.__module__ = loss_fn.__module__ - return new_loss_fn diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py deleted file mode 100644 index 25d74a8c23d..00000000000 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py +++ /dev/null @@ -1,299 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for contrib.gan.python.losses.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -import numpy as np - -from tensorflow.contrib import layers -from tensorflow.contrib.gan.python import namedtuples -from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl -from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class ArgsToGanModelTest(test.TestCase): - - def test_args_to_gan_model(self): - """Test `_args_to_gan_model`.""" - tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg3']) - - def args_loss(arg1, arg2, arg3=3, arg4=4): - return arg1 + arg2 + arg3 + arg4 - - gan_model_loss = tfgan_losses._args_to_gan_model(args_loss) - - # Value is correct. - self.assertEqual(1 + 2 + 5 + 6, - gan_model_loss(tuple_type(1, 2), arg2=5, arg4=6)) - - # Uses tuple argument with defaults. - self.assertEqual(1 + 5 + 3 + 7, - gan_model_loss(tuple_type(1, None), arg2=5, arg4=7)) - - # Uses non-tuple argument with defaults. - self.assertEqual(1 + 5 + 2 + 4, - gan_model_loss(tuple_type(1, 2), arg2=5)) - - # Requires non-tuple, non-default arguments. 
- with self.assertRaisesRegexp(ValueError, '`arg2` must be supplied'): - gan_model_loss(tuple_type(1, 2)) - - # Can't pass tuple argument outside tuple. - with self.assertRaisesRegexp( - ValueError, 'present in both the tuple and keyword args'): - gan_model_loss(tuple_type(1, 2), arg2=1, arg3=5) - - def test_args_to_gan_model_name(self): - """Test that `_args_to_gan_model` produces correctly named functions.""" - def loss_fn(x): - return x - new_loss_fn = tfgan_losses._args_to_gan_model(loss_fn) - self.assertEqual('loss_fn', new_loss_fn.__name__) - self.assertTrue('The gan_model version of' in new_loss_fn.__docstring__) - - def test_tuple_respects_optional_args(self): - """Test that optional args can be changed with tuple losses.""" - tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2']) - def args_loss(arg1, arg2, arg3=3): - return arg1 + 2 * arg2 + 3 * arg3 - - loss_fn = tfgan_losses._args_to_gan_model(args_loss) - loss = loss_fn(tuple_type(arg1=-1, arg2=2), arg3=4) - - # If `arg3` were not set properly, this value would be different. - self.assertEqual(-1 + 2 * 2 + 3 * 4, loss) - - def test_works_with_child_classes(self): - """`args_to_gan_model` should work with classes derived from namedtuple.""" - tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2']) - - class InheritedType(tuple_type): - pass - def args_loss(arg1, arg2, arg3=3): - return arg1 + 2 * arg2 + 3 * arg3 - - loss_fn = tfgan_losses._args_to_gan_model(args_loss) - loss = loss_fn(InheritedType(arg1=-1, arg2=2), arg3=4) - - # If `arg3` were not set properly, this value would be different. - self.assertEqual(-1 + 2 * 2 + 3 * 4, loss) - - -class ConsistentLossesTest(test.TestCase): - - pass - - -def _tuple_from_dict(args_dict): - return collections.namedtuple('Tuple', args_dict.keys())(**args_dict) - - -def add_loss_consistency_test(test_class, loss_name_str, loss_args): - tuple_loss = getattr(tfgan_losses, loss_name_str) - arg_loss = getattr(tfgan_losses.losses_impl, loss_name_str) - - def consistency_test(self): - self.assertEqual(arg_loss.__name__, tuple_loss.__name__) - with self.cached_session(): - self.assertEqual(arg_loss(**loss_args).eval(), - tuple_loss(_tuple_from_dict(loss_args)).eval()) - - test_name = 'test_loss_consistency_%s' % loss_name_str - setattr(test_class, test_name, consistency_test) - - -# A list of consistency tests which need to be manually written. 
-manual_tests = [ - 'acgan_discriminator_loss', - 'acgan_generator_loss', - 'combine_adversarial_loss', - 'mutual_information_penalty', - 'wasserstein_gradient_penalty', - 'cycle_consistency_loss', - 'stargan_generator_loss_wrapper', - 'stargan_discriminator_loss_wrapper', - 'stargan_gradient_penalty_wrapper' -] - -discriminator_keyword_args = { - 'discriminator_real_outputs': np.array([[3.4, 2.3, -2.3], - [6.3, -2.1, 0.2]]), - 'discriminator_gen_outputs': np.array([[6.2, -1.5, 2.3], - [-2.9, -5.1, 0.1]]), -} -generator_keyword_args = { - 'discriminator_gen_outputs': np.array([[6.2, -1.5, 2.3], - [-2.9, -5.1, 0.1]]), -} - - -class CycleConsistencyLossTest(test.TestCase): - - def setUp(self): - super(CycleConsistencyLossTest, self).setUp() - - def _partial_model(generator_inputs_np): - model = namedtuples.GANModel(*[None] * 11) - return model._replace( - generator_inputs=constant_op.constant( - generator_inputs_np, dtype=dtypes.float32)) - - self._model_x2y = _partial_model([1, 2]) - self._model_y2x = _partial_model([5, 6]) - - def test_model_type(self): - """Test the input model type for `cycle_consistency_loss`.""" - with self.assertRaises(ValueError): - tfgan_losses.cycle_consistency_loss(self._model_x2y) - - def test_correct_loss(self): - """Test the output of `cycle_consistency_loss`.""" - loss = tfgan_losses.cycle_consistency_loss( - namedtuples.CycleGANModel( - model_x2y=self._model_x2y, - model_y2x=self._model_y2x, - reconstructed_x=constant_op.constant([9, 8], dtype=dtypes.float32), - reconstructed_y=constant_op.constant([7, 2], dtype=dtypes.float32))) - with self.test_session(use_gpu=True): - variables.global_variables_initializer().run() - self.assertNear(5.0, loss.eval(), 1e-5) - - -class StarGANLossWrapperTest(test.TestCase): - - def setUp(self): - - super(StarGANLossWrapperTest, self).setUp() - - self.input_data = array_ops.ones([1, 2, 2, 3]) - self.input_data_domain_label = constant_op.constant([[0, 1]]) - self.generated_data = array_ops.ones([1, 2, 2, 3]) - self.discriminator_input_data_source_predication = array_ops.ones([1]) - self.discriminator_generated_data_source_predication = array_ops.ones([1]) - - def _discriminator_fn(inputs, num_domains): - """Differentiable dummy discriminator for StarGAN.""" - hidden = layers.flatten(inputs) - output_src = math_ops.reduce_mean(hidden, axis=1) - output_cls = layers.fully_connected( - inputs=hidden, - num_outputs=num_domains, - activation_fn=None, - normalizer_fn=None, - biases_initializer=None) - return output_src, output_cls - - with variable_scope.variable_scope('discriminator') as dis_scope: - pass - - self.model = namedtuples.StarGANModel( - input_data=self.input_data, - input_data_domain_label=self.input_data_domain_label, - generated_data=self.generated_data, - generated_data_domain_target=None, - reconstructed_data=None, - discriminator_input_data_source_predication=self. - discriminator_input_data_source_predication, - discriminator_generated_data_source_predication=self. 
- discriminator_generated_data_source_predication, - discriminator_input_data_domain_predication=None, - discriminator_generated_data_domain_predication=None, - generator_variables=None, - generator_scope=None, - generator_fn=None, - discriminator_variables=None, - discriminator_scope=dis_scope, - discriminator_fn=_discriminator_fn) - - self.discriminator_fn = _discriminator_fn - self.discriminator_scope = dis_scope - - def test_stargan_generator_loss_wrapper(self): - """Test StarGAN generator loss wrapper.""" - loss_fn = tfgan_losses_impl.wasserstein_generator_loss - wrapped_loss_fn = tfgan_losses.stargan_generator_loss_wrapper(loss_fn) - - loss_result_tensor = loss_fn( - self.discriminator_generated_data_source_predication) - wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - - with self.cached_session() as sess: - sess.run(variables.global_variables_initializer()) - loss_result, wrapped_loss_result = sess.run( - [loss_result_tensor, wrapped_loss_result_tensor]) - self.assertAlmostEqual(loss_result, wrapped_loss_result) - - def test_stargan_discriminator_loss_wrapper(self): - """Test StarGAN discriminator loss wrapper.""" - loss_fn = tfgan_losses_impl.wasserstein_discriminator_loss - wrapped_loss_fn = tfgan_losses.stargan_discriminator_loss_wrapper(loss_fn) - - loss_result_tensor = loss_fn( - self.discriminator_generated_data_source_predication, - self.discriminator_generated_data_source_predication) - wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - - with self.cached_session() as sess: - sess.run(variables.global_variables_initializer()) - loss_result, wrapped_loss_result = sess.run( - [loss_result_tensor, wrapped_loss_result_tensor]) - self.assertAlmostEqual(loss_result, wrapped_loss_result) - - def test_stargan_gradient_penalty_wrapper(self): - """Test StaGAN gradient penalty wrapper. - - Notes: - The random interpolates are handled by given setting the reconstruction to - be the same as the input. - - """ - loss_fn = tfgan_losses_impl.wasserstein_gradient_penalty - wrapped_loss_fn = tfgan_losses.stargan_gradient_penalty_wrapper(loss_fn) - - loss_result_tensor = loss_fn( - real_data=self.input_data, - generated_data=self.generated_data, - generator_inputs=self.input_data_domain_label.shape.as_list()[-1], - discriminator_fn=self.discriminator_fn, - discriminator_scope=self.discriminator_scope) - wrapped_loss_result_tensor = wrapped_loss_fn(self.model) - - with self.cached_session() as sess: - sess.run(variables.global_variables_initializer()) - loss_result, wrapped_loss_result = sess.run( - [loss_result_tensor, wrapped_loss_result_tensor]) - self.assertAlmostEqual(loss_result, wrapped_loss_result) - - -if __name__ == '__main__': - for loss_name in tfgan_losses.__all__: - if loss_name in manual_tests: continue - keyword_args = (generator_keyword_args if 'generator' in loss_name else - discriminator_keyword_args) - add_loss_consistency_test(ConsistentLossesTest, loss_name, keyword_args) - - test.main() diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py deleted file mode 100644 index 73dfee4fdee..00000000000 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Named tuples for TF-GAN. - -TF-GAN training occurs in four steps, and each step communicates with the next -step via one of these named tuples. At each step, you can either use a TF-GAN -helper function in `train.py`, or you can manually construct a tuple. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -__all__ = [ - 'GANModel', - 'InfoGANModel', - 'ACGANModel', - 'CycleGANModel', - 'StarGANModel', - 'GANLoss', - 'CycleGANLoss', - 'GANTrainOps', - 'GANTrainSteps', -] - - -class GANModel( - collections.namedtuple('GANModel', ( - 'generator_inputs', - 'generated_data', - 'generator_variables', - 'generator_scope', - 'generator_fn', - 'real_data', - 'discriminator_real_outputs', - 'discriminator_gen_outputs', - 'discriminator_variables', - 'discriminator_scope', - 'discriminator_fn', - ))): - """A GANModel contains all the pieces needed for GAN training. - - Generative Adversarial Networks (https://arxiv.org/abs/1406.2661) attempt - to create an implicit generative model of data by solving a two agent game. - The generator generates candidate examples that are supposed to match the - data distribution, and the discriminator aims to tell the real examples - apart from the generated samples. - - Args: - generator_inputs: The random noise source that acts as input to the - generator. - generated_data: The generated output data of the GAN. - generator_variables: A list of all generator variables. - generator_scope: Variable scope all generator variables live in. - generator_fn: The generator function. - real_data: A tensor or real data. - discriminator_real_outputs: The discriminator's output on real data. - discriminator_gen_outputs: The discriminator's output on generated data. - discriminator_variables: A list of all discriminator variables. - discriminator_scope: Variable scope all discriminator variables live in. - discriminator_fn: The discriminator function. - """ - - -# TODO(joelshor): Have this class inherit from `GANModel`. -class InfoGANModel( - collections.namedtuple('InfoGANModel', GANModel._fields + ( - 'structured_generator_inputs', - 'predicted_distributions', - 'discriminator_and_aux_fn', - ))): - """An InfoGANModel contains all the pieces needed for InfoGAN training. - - See https://arxiv.org/abs/1606.03657 for more details. - - Args: - structured_generator_inputs: A list of Tensors representing the random noise - that must have high mutual information with the generator output. List - length should match `predicted_distributions`. - predicted_distributions: A list of `tfp.distributions.Distribution`s. - Predicted by the recognizer, and used to evaluate the likelihood of the - structured noise. List length should match `structured_generator_inputs`. - discriminator_and_aux_fn: The original discriminator function that returns - a tuple of (logits, `predicted_distributions`). 
- """ - - -class ACGANModel( - collections.namedtuple('ACGANModel', GANModel._fields + - ('one_hot_labels', - 'discriminator_real_classification_logits', - 'discriminator_gen_classification_logits',))): - """An ACGANModel contains all the pieces needed for ACGAN training. - - See https://arxiv.org/abs/1610.09585 for more details. - - Args: - one_hot_labels: A Tensor holding one-hot-labels for the batch. - discriminator_real_classification_logits: Classification logits for real - data. - discriminator_gen_classification_logits: Classification logits for generated - data. - """ - - -class CycleGANModel( - collections.namedtuple( - 'CycleGANModel', - ('model_x2y', 'model_y2x', 'reconstructed_x', 'reconstructed_y'))): - """An CycleGANModel contains all the pieces needed for CycleGAN training. - - The model `model_x2y` generator F maps data set X to Y, while the model - `model_y2x` generator G maps data set Y to X. - - See https://arxiv.org/abs/1703.10593 for more details. - - Args: - model_x2y: A `GANModel` namedtuple whose generator maps data set X to Y. - model_y2x: A `GANModel` namedtuple whose generator maps data set Y to X. - reconstructed_x: A `Tensor` of reconstructed data X which is G(F(X)). - reconstructed_y: A `Tensor` of reconstructed data Y which is F(G(Y)). - """ - - -class StarGANModel( - collections.namedtuple('StarGANModel', ( - 'input_data', - 'input_data_domain_label', - 'generated_data', - 'generated_data_domain_target', - 'reconstructed_data', - 'discriminator_input_data_source_predication', - 'discriminator_generated_data_source_predication', - 'discriminator_input_data_domain_predication', - 'discriminator_generated_data_domain_predication', - 'generator_variables', - 'generator_scope', - 'generator_fn', - 'discriminator_variables', - 'discriminator_scope', - 'discriminator_fn', - ))): - """A StarGANModel contains all the pieces needed for StarGAN training. - - Args: - input_data: The real images that need to be transferred by the generator. - input_data_domain_label: The real domain labels associated with the real - images. - generated_data: The generated images produced by the generator. It has the - same shape as the input_data. - generated_data_domain_target: The target domain that the generated images - belong to. It has the same shape as the input_data_domain_label. - reconstructed_data: The reconstructed images produced by the G(enerator). - reconstructed_data = G(G(input_data, generated_data_domain_target), - input_data_domain_label). - discriminator_input_data_source: The discriminator's output for predicting - the source (real/generated) of input_data. - discriminator_generated_data_source: The discriminator's output for - predicting the source (real/generated) of generated_data. - discriminator_input_data_domain_predication: The discriminator's output for - predicting the domain_label for the input_data. - discriminator_generated_data_domain_predication: The discriminatorr's output - for predicting the domain_target for the generated_data. - generator_variables: A list of all generator variables. - generator_scope: Variable scope all generator variables live in. - generator_fn: The generator function. - discriminator_variables: A list of all discriminator variables. - discriminator_scope: Variable scope all discriminator variables live in. - discriminator_fn: The discriminator function. - """ - - -class GANLoss( - collections.namedtuple('GANLoss', ( - 'generator_loss', - 'discriminator_loss' - ))): - """GANLoss contains the generator and discriminator losses. 
- - Args: - generator_loss: A tensor for the generator loss. - discriminator_loss: A tensor for the discriminator loss. - """ - - -class CycleGANLoss( - collections.namedtuple('CycleGANLoss', ('loss_x2y', 'loss_y2x'))): - """CycleGANLoss contains the losses for `CycleGANModel`. - - See https://arxiv.org/abs/1703.10593 for more details. - - Args: - loss_x2y: A `GANLoss` namedtuple representing the loss of `model_x2y`. - loss_y2x: A `GANLoss` namedtuple representing the loss of `model_y2x`. - """ - - -class GANTrainOps( - collections.namedtuple('GANTrainOps', ( - 'generator_train_op', - 'discriminator_train_op', - 'global_step_inc_op', - 'train_hooks' - ))): - """GANTrainOps contains the training ops. - - Args: - generator_train_op: Op that performs a generator update step. - discriminator_train_op: Op that performs a discriminator update step. - global_step_inc_op: Op that increments the shared global step. - train_hooks: a list or tuple containing hooks related to training that need - to be populated when training ops are instantiated. Used primarily for - sync hooks. - """ - - def __new__(cls, generator_train_op, discriminator_train_op, - global_step_inc_op, train_hooks=()): - return super(GANTrainOps, cls).__new__(cls, generator_train_op, - discriminator_train_op, - global_step_inc_op, train_hooks) - - -class GANTrainSteps( - collections.namedtuple('GANTrainSteps', ( - 'generator_train_steps', - 'discriminator_train_steps' - ))): - """Contains configuration for the GAN Training. - - Args: - generator_train_steps: Number of generator steps to take in each GAN step. - discriminator_train_steps: Number of discriminator steps to take in each GAN - step. - """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py deleted file mode 100644 index 422e16f0bfe..00000000000 --- a/tensorflow/contrib/gan/python/train.py +++ /dev/null @@ -1,1318 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""The TF-GAN project provides a lightweight GAN training/testing framework. - -This file contains the core helper functions to create and train a GAN model. -See the README or examples in `tensorflow_models` for details on how to use. - -TF-GAN training occurs in four steps: -1) Create a model -2) Add a loss -3) Create train ops -4) Run the train ops - -The functions in this file are organized around these four steps. Each function -corresponds to one of the steps. 
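For readers landing on this removal, here is a minimal reference sketch (not part of this patch) of the four-step workflow the module docstring above describes. It assumes TF 1.x with `tf.contrib.gan` still available; the one-layer networks, tensor shapes, and the `/tmp/tfgan_sketch` logdir are illustrative placeholders, not values taken from this code.

```python
# Illustrative sketch of the removed tf.contrib.gan (TF-GAN) workflow.
# Assumes TF 1.x; networks, shapes, and logdir below are placeholders.
import tensorflow as tf

tfgan = tf.contrib.gan


def generator_fn(noise):
  # Toy generator: one dense layer producing flattened 28x28 "images".
  return tf.layers.dense(noise, 28 * 28, activation=tf.tanh)


def discriminator_fn(data, unused_conditioning):
  # Toy discriminator: one unbounded logit per example.
  return tf.layers.dense(tf.layers.flatten(data), 1)


noise = tf.random_normal([32, 64])
real_images = tf.random_normal([32, 28 * 28])  # stand-in for a real input pipeline

# 1) Create a model.
model = tfgan.gan_model(generator_fn, discriminator_fn,
                        real_data=real_images, generator_inputs=noise)

# 2) Add a loss (Wasserstein losses by default, plus a gradient penalty).
loss = tfgan.gan_loss(model, gradient_penalty_weight=10.0)

# 3) Create train ops.
train_ops = tfgan.gan_train_ops(
    model, loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5))

# 4) Run the train ops.
tfgan.gan_train(train_ops, logdir='/tmp/tfgan_sketch',
                hooks=[tf.train.StopAtStepHook(num_steps=1000)])
```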
-""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.framework.python.ops import variables as variables_lib -from tensorflow.contrib.gan.python import losses as tfgan_losses -from tensorflow.contrib.gan.python import namedtuples -from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl -from tensorflow.contrib.slim.python.slim import learning as slim_learning -from tensorflow.contrib.training.python.training import training -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops.losses import losses -from tensorflow.python.summary import summary -from tensorflow.python.training import session_run_hook -from tensorflow.python.training import sync_replicas_optimizer -from tensorflow.python.training import training_util - -__all__ = [ - 'gan_model', - 'infogan_model', - 'acgan_model', - 'cyclegan_model', - 'stargan_model', - 'gan_loss', - 'cyclegan_loss', - 'stargan_loss', - 'gan_train_ops', - 'gan_train', - 'get_sequential_train_hooks', - 'get_joint_train_hooks', - 'get_sequential_train_steps', - 'RunTrainOpsHook', -] - - -def gan_model( - # Lambdas defining models. - generator_fn, - discriminator_fn, - # Real data and conditioning. - real_data, - generator_inputs, - # Optional scopes. - generator_scope='Generator', - discriminator_scope='Discriminator', - # Options. - check_shapes=True): - """Returns GAN model outputs and variables. - - Args: - generator_fn: A python lambda that takes `generator_inputs` as inputs and - returns the outputs of the GAN generator. - discriminator_fn: A python lambda that takes `real_data`/`generated data` - and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. - real_data: A Tensor representing the real data. - generator_inputs: A Tensor or list of Tensors to the generator. In the - vanilla GAN case, this might be a single noise Tensor. In the conditional - GAN case, this might be the generator's conditioning. - generator_scope: Optional generator variable scope. Useful if you want to - reuse a subgraph that has already been created. - discriminator_scope: Optional discriminator variable scope. Useful if you - want to reuse a subgraph that has already been created. - check_shapes: If `True`, check that generator produces Tensors that are the - same shape as real data. Otherwise, skip this check. - - Returns: - A GANModel namedtuple. - - Raises: - ValueError: If the generator outputs a Tensor that isn't the same shape as - `real_data`. 
- """ - # Create models - with variable_scope.variable_scope(generator_scope) as gen_scope: - generator_inputs = _convert_tensor_or_l_or_d(generator_inputs) - generated_data = generator_fn(generator_inputs) - with variable_scope.variable_scope(discriminator_scope) as dis_scope: - discriminator_gen_outputs = discriminator_fn(generated_data, - generator_inputs) - with variable_scope.variable_scope(dis_scope, reuse=True): - real_data = _convert_tensor_or_l_or_d(real_data) - discriminator_real_outputs = discriminator_fn(real_data, generator_inputs) - - if check_shapes: - if not generated_data.shape.is_compatible_with(real_data.shape): - raise ValueError( - 'Generator output shape (%s) must be the same shape as real data ' - '(%s).' % (generated_data.shape, real_data.shape)) - - # Get model-specific variables. - generator_variables = variables_lib.get_trainable_variables(gen_scope) - discriminator_variables = variables_lib.get_trainable_variables(dis_scope) - - return namedtuples.GANModel(generator_inputs, generated_data, - generator_variables, gen_scope, generator_fn, - real_data, discriminator_real_outputs, - discriminator_gen_outputs, - discriminator_variables, dis_scope, - discriminator_fn) - - -def infogan_model( - # Lambdas defining models. - generator_fn, - discriminator_fn, - # Real data and conditioning. - real_data, - unstructured_generator_inputs, - structured_generator_inputs, - # Optional scopes. - generator_scope='Generator', - discriminator_scope='Discriminator'): - """Returns an InfoGAN model outputs and variables. - - See https://arxiv.org/abs/1606.03657 for more details. - - Args: - generator_fn: A python lambda that takes a list of Tensors as inputs and - returns the outputs of the GAN generator. - discriminator_fn: A python lambda that takes `real_data`/`generated data` - and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list). - `logits` are in the range [-inf, inf], and `distribution_list` is a list - of Tensorflow distributions representing the predicted noise distribution - of the ith structure noise. - real_data: A Tensor representing the real data. - unstructured_generator_inputs: A list of Tensors to the generator. These - tensors represent the unstructured noise or conditioning. - structured_generator_inputs: A list of Tensors to the generator. These - tensors must have high mutual information with the recognizer. - generator_scope: Optional generator variable scope. Useful if you want to - reuse a subgraph that has already been created. - discriminator_scope: Optional discriminator variable scope. Useful if you - want to reuse a subgraph that has already been created. - - Returns: - An InfoGANModel namedtuple. - - Raises: - ValueError: If the generator outputs a Tensor that isn't the same shape as - `real_data`. - ValueError: If the discriminator output is malformed. 
- """ - # Create models - with variable_scope.variable_scope(generator_scope) as gen_scope: - unstructured_generator_inputs = _convert_tensor_or_l_or_d( - unstructured_generator_inputs) - structured_generator_inputs = _convert_tensor_or_l_or_d( - structured_generator_inputs) - generator_inputs = ( - unstructured_generator_inputs + structured_generator_inputs) - generated_data = generator_fn(generator_inputs) - with variable_scope.variable_scope(discriminator_scope) as disc_scope: - dis_gen_outputs, predicted_distributions = discriminator_fn( - generated_data, generator_inputs) - _validate_distributions(predicted_distributions, structured_generator_inputs) - with variable_scope.variable_scope(disc_scope, reuse=True): - real_data = ops.convert_to_tensor(real_data) - dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs) - - if not generated_data.get_shape().is_compatible_with(real_data.get_shape()): - raise ValueError( - 'Generator output shape (%s) must be the same shape as real data ' - '(%s).' % (generated_data.get_shape(), real_data.get_shape())) - - # Get model-specific variables. - generator_variables = variables_lib.get_trainable_variables(gen_scope) - discriminator_variables = variables_lib.get_trainable_variables(disc_scope) - - return namedtuples.InfoGANModel( - generator_inputs, - generated_data, - generator_variables, - gen_scope, - generator_fn, - real_data, - dis_real_outputs, - dis_gen_outputs, - discriminator_variables, - disc_scope, - lambda x, y: discriminator_fn(x, y)[0], # conform to non-InfoGAN API - structured_generator_inputs, - predicted_distributions, - discriminator_fn) - - -def acgan_model( - # Lambdas defining models. - generator_fn, - discriminator_fn, - # Real data and conditioning. - real_data, - generator_inputs, - one_hot_labels, - # Optional scopes. - generator_scope='Generator', - discriminator_scope='Discriminator', - # Options. - check_shapes=True): - """Returns an ACGANModel contains all the pieces needed for ACGAN training. - - The `acgan_model` is the same as the `gan_model` with the only difference - being that the discriminator additionally outputs logits to classify the input - (real or generated). - Therefore, an explicit field holding one_hot_labels is necessary, as well as a - discriminator_fn that outputs a 2-tuple holding the logits for real/fake and - classification. - - See https://arxiv.org/abs/1610.09585 for more details. - - Args: - generator_fn: A python lambda that takes `generator_inputs` as inputs and - returns the outputs of the GAN generator. - discriminator_fn: A python lambda that takes `real_data`/`generated data` - and `generator_inputs`. Outputs a tuple consisting of two Tensors: (1) - real/fake logits in the range [-inf, inf] (2) classification logits in - the range [-inf, inf] - real_data: A Tensor representing the real data. - generator_inputs: A Tensor or list of Tensors to the generator. In the - vanilla GAN case, this might be a single noise Tensor. In the conditional - GAN case, this might be the generator's conditioning. - one_hot_labels: A Tensor holding one-hot-labels for the batch. Needed by - acgan_loss. - generator_scope: Optional generator variable scope. Useful if you want to - reuse a subgraph that has already been created. - discriminator_scope: Optional discriminator variable scope. Useful if you - want to reuse a subgraph that has already been created. - check_shapes: If `True`, check that generator produces Tensors that are the - same shape as real data. Otherwise, skip this check. 
- - Returns: - A ACGANModel namedtuple. - - Raises: - ValueError: If the generator outputs a Tensor that isn't the same shape as - `real_data`. - TypeError: If the discriminator does not output a tuple consisting of - (discrimination logits, classification logits). - """ - # Create models - with variable_scope.variable_scope(generator_scope) as gen_scope: - generator_inputs = _convert_tensor_or_l_or_d(generator_inputs) - generated_data = generator_fn(generator_inputs) - with variable_scope.variable_scope(discriminator_scope) as dis_scope: - with ops.name_scope(dis_scope.name + '/generated/'): - (discriminator_gen_outputs, discriminator_gen_classification_logits - ) = _validate_acgan_discriminator_outputs( - discriminator_fn(generated_data, generator_inputs)) - with variable_scope.variable_scope(dis_scope, reuse=True): - with ops.name_scope(dis_scope.name + '/real/'): - real_data = ops.convert_to_tensor(real_data) - (discriminator_real_outputs, discriminator_real_classification_logits - ) = _validate_acgan_discriminator_outputs( - discriminator_fn(real_data, generator_inputs)) - if check_shapes: - if not generated_data.shape.is_compatible_with(real_data.shape): - raise ValueError( - 'Generator output shape (%s) must be the same shape as real data ' - '(%s).' % (generated_data.shape, real_data.shape)) - - # Get model-specific variables. - generator_variables = variables_lib.get_trainable_variables(gen_scope) - discriminator_variables = variables_lib.get_trainable_variables(dis_scope) - - return namedtuples.ACGANModel(generator_inputs, generated_data, - generator_variables, gen_scope, generator_fn, - real_data, discriminator_real_outputs, - discriminator_gen_outputs, - discriminator_variables, dis_scope, - discriminator_fn, one_hot_labels, - discriminator_real_classification_logits, - discriminator_gen_classification_logits) - - -def cyclegan_model( - # Lambdas defining models. - generator_fn, - discriminator_fn, - # data X and Y. - data_x, - data_y, - # Optional scopes. - generator_scope='Generator', - discriminator_scope='Discriminator', - model_x2y_scope='ModelX2Y', - model_y2x_scope='ModelY2X', - # Options. - check_shapes=True): - """Returns a CycleGAN model outputs and variables. - - See https://arxiv.org/abs/1703.10593 for more details. - - Args: - generator_fn: A python lambda that takes `data_x` or `data_y` as inputs and - returns the outputs of the GAN generator. - discriminator_fn: A python lambda that takes `real_data`/`generated data` - and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. - data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`. - data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`. - generator_scope: Optional generator variable scope. Useful if you want to - reuse a subgraph that has already been created. Defaults to 'Generator'. - discriminator_scope: Optional discriminator variable scope. Useful if you - want to reuse a subgraph that has already been created. Defaults to - 'Discriminator'. - model_x2y_scope: Optional variable scope for model x2y variables. Defaults - to 'ModelX2Y'. - model_y2x_scope: Optional variable scope for model y2x variables. Defaults - to 'ModelY2X'. - check_shapes: If `True`, check that generator produces Tensors that are the - same shape as `data_x` (`data_y`). Otherwise, skip this check. - - Returns: - A `CycleGANModel` namedtuple. - - Raises: - ValueError: If `check_shapes` is True and `data_x` or the generator output - does not have the same shape as `data_y`. - """ - - # Create models. 
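As a hedged usage sketch of the `cyclegan_model` / `cyclegan_loss` pair documented above (not part of this patch): the single-layer networks and random tensors below are placeholder assumptions standing in for real image pipelines and architectures.

```python
# Illustrative CycleGAN construction with the removed tf.contrib.gan API.
import tensorflow as tf

tfgan = tf.contrib.gan


def cyclegan_generator_fn(images):
  # Toy image-to-image generator that preserves the input shape.
  return tf.layers.conv2d(images, 3, 3, padding='same', activation=tf.tanh)


def cyclegan_discriminator_fn(images, unused_conditioning):
  return tf.layers.dense(tf.layers.flatten(images), 1)


data_x = tf.random_normal([8, 32, 32, 3])  # stand-in for dataset X
data_y = tf.random_normal([8, 32, 32, 3])  # stand-in for dataset Y

cyc_model = tfgan.cyclegan_model(
    cyclegan_generator_fn, cyclegan_discriminator_fn, data_x, data_y)

# Least-squares adversarial losses by default, plus the weighted cycle
# consistency term added to each partial generator loss.
cyc_loss = tfgan.cyclegan_loss(cyc_model, cycle_consistency_loss_weight=10.0)
```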
- def _define_partial_model(input_data, output_data): - return gan_model( - generator_fn=generator_fn, - discriminator_fn=discriminator_fn, - real_data=output_data, - generator_inputs=input_data, - generator_scope=generator_scope, - discriminator_scope=discriminator_scope, - check_shapes=check_shapes) - - with variable_scope.variable_scope(model_x2y_scope): - model_x2y = _define_partial_model(data_x, data_y) - with variable_scope.variable_scope(model_y2x_scope): - model_y2x = _define_partial_model(data_y, data_x) - - with variable_scope.variable_scope(model_y2x.generator_scope, reuse=True): - reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data) - with variable_scope.variable_scope(model_x2y.generator_scope, reuse=True): - reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data) - - return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x, - reconstructed_y) - - -def stargan_model(generator_fn, - discriminator_fn, - input_data, - input_data_domain_label, - generator_scope='Generator', - discriminator_scope='Discriminator'): - """Returns a StarGAN model outputs and variables. - - See https://arxiv.org/abs/1711.09020 for more details. - - Args: - generator_fn: A python lambda that takes `inputs` and `targets` as inputs - and returns 'generated_data' as the transformed version of `input` based - on the `target`. `input` has shape (n, h, w, c), `targets` has shape (n, - num_domains), and `generated_data` has the same shape as `input`. - discriminator_fn: A python lambda that takes `inputs` and `num_domains` as - inputs and returns a tuple (`source_prediction`, `domain_prediction`). - `source_prediction` represents the source(real/generated) prediction by - the discriminator, and `domain_prediction` represents the domain - prediction/classification by the discriminator. `source_prediction` has - shape (n) and `domain_prediction` has shape (n, num_domains). - input_data: Tensor or a list of tensor of shape (n, h, w, c) representing - the real input images. - input_data_domain_label: Tensor or a list of tensor of shape (batch_size, - num_domains) representing the domain label associated with the real - images. - generator_scope: Optional generator variable scope. Useful if you want to - reuse a subgraph that has already been created. - discriminator_scope: Optional discriminator variable scope. Useful if you - want to reuse a subgraph that has already been created. - - Returns: - StarGANModel nametuple return the tensor that are needed to compute the - loss. - - Raises: - ValueError: If the shape of `input_data_domain_label` is not rank 2 or fully - defined in every dimensions. - """ - - # Convert to tensor. - input_data = _convert_tensor_or_l_or_d(input_data) - input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label) - - # Convert list of tensor to a single tensor if applicable. - if isinstance(input_data, (list, tuple)): - input_data = array_ops.concat( - [ops.convert_to_tensor(x) for x in input_data], 0) - if isinstance(input_data_domain_label, (list, tuple)): - input_data_domain_label = array_ops.concat( - [ops.convert_to_tensor(x) for x in input_data_domain_label], 0) - - # Get batch_size, num_domains from the labels. - input_data_domain_label.shape.assert_has_rank(2) - input_data_domain_label.shape.assert_is_fully_defined() - batch_size, num_domains = input_data_domain_label.shape.as_list() - - # Transform input_data to random target domains. 
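A hedged sketch of the `stargan_model` / `stargan_loss` contract described in the docstring above (not part of this patch): the tiny networks, 64x64 images, and 5 domains are illustrative assumptions; the only requirements taken from the source are the `(images, target_domain)` generator signature and the `(source_prediction, domain_prediction)` discriminator return value.

```python
# Illustrative StarGAN construction with the removed tf.contrib.gan API.
import tensorflow as tf

tfgan = tf.contrib.gan


def stargan_generator_fn(images, target_domain):
  # Tile the one-hot target domain over the spatial dimensions, concatenate
  # it to the image channels, and map back to a 3-channel image.
  n, h, w, _ = images.shape.as_list()
  codes = tf.tile(tf.reshape(target_domain, [n, 1, 1, -1]), [1, h, w, 1])
  net = tf.concat([images, codes], axis=3)
  return tf.layers.conv2d(net, 3, 3, padding='same', activation=tf.tanh)


def stargan_discriminator_fn(images, num_domains):
  # Returns (source logit of shape [n], domain logits of shape [n, num_domains]).
  flat = tf.layers.flatten(images)
  source_logit = tf.squeeze(tf.layers.dense(flat, 1), axis=1)
  domain_logits = tf.layers.dense(flat, num_domains)
  return source_logit, domain_logits


images = tf.random_normal([8, 64, 64, 3])   # stand-in for real images
labels = tf.one_hot(tf.random_uniform([8], maxval=5, dtype=tf.int32), 5)

star_model = tfgan.stargan_model(
    stargan_generator_fn, stargan_discriminator_fn, images, labels)
star_loss = tfgan.stargan_loss(star_model)
```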
- with variable_scope.variable_scope(generator_scope) as generator_scope: - generated_data_domain_target = _generate_stargan_random_domain_target( - batch_size, num_domains) - generated_data = generator_fn(input_data, generated_data_domain_target) - - # Transform generated_data back to the original input_data domain. - with variable_scope.variable_scope(generator_scope, reuse=True): - reconstructed_data = generator_fn(generated_data, input_data_domain_label) - - # Predict source and domain for the generated_data using the discriminator. - with variable_scope.variable_scope( - discriminator_scope) as discriminator_scope: - disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn( - generated_data, num_domains) - - # Predict source and domain for the input_data using the discriminator. - with variable_scope.variable_scope(discriminator_scope, reuse=True): - disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn( - input_data, num_domains) - - # Collect trainable variables from the neural networks. - generator_variables = variables_lib.get_trainable_variables(generator_scope) - discriminator_variables = variables_lib.get_trainable_variables( - discriminator_scope) - - # Create the StarGANModel namedtuple. - return namedtuples.StarGANModel( - input_data=input_data, - input_data_domain_label=input_data_domain_label, - generated_data=generated_data, - generated_data_domain_target=generated_data_domain_target, - reconstructed_data=reconstructed_data, - discriminator_input_data_source_predication=disc_input_data_source_pred, - discriminator_generated_data_source_predication=disc_gen_data_source_pred, - discriminator_input_data_domain_predication=disc_input_data_domain_pred, - discriminator_generated_data_domain_predication=disc_gen_data_domain_pred, - generator_variables=generator_variables, - generator_scope=generator_scope, - generator_fn=generator_fn, - discriminator_variables=discriminator_variables, - discriminator_scope=discriminator_scope, - discriminator_fn=discriminator_fn) - - -def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'): - if isinstance(aux_loss_weight, ops.Tensor): - aux_loss_weight.shape.assert_is_compatible_with([]) - with ops.control_dependencies( - [check_ops.assert_greater_equal(aux_loss_weight, 0.0)]): - aux_loss_weight = array_ops.identity(aux_loss_weight) - elif aux_loss_weight is not None and aux_loss_weight < 0: - raise ValueError('`%s` must be greater than 0. Instead, was %s' % - (name, aux_loss_weight)) - return aux_loss_weight - - -def _use_aux_loss(aux_loss_weight): - if aux_loss_weight is not None: - if not isinstance(aux_loss_weight, ops.Tensor): - return aux_loss_weight > 0 - else: - return True - else: - return False - - -def _tensor_pool_adjusted_model(model, tensor_pool_fn): - """Adjusts model using `tensor_pool_fn`. - - Args: - model: A GANModel tuple. - tensor_pool_fn: A function that takes (generated_data, generator_inputs), - stores them in an internal pool and returns a previously stored - (generated_data, generator_inputs) with some probability. For example - tfgan.features.tensor_pool. - - Returns: - A new GANModel tuple where discriminator outputs are adjusted by taking - pooled generator outputs as inputs. Returns the original model if - `tensor_pool_fn` is None. - - Raises: - ValueError: If tensor pool does not support the `model`. 
- """ - if isinstance(model, namedtuples.GANModel): - pooled_generator_inputs, pooled_generated_data = tensor_pool_fn( - (model.generator_inputs, model.generated_data)) - with variable_scope.variable_scope(model.discriminator_scope, reuse=True): - dis_gen_outputs = model.discriminator_fn(pooled_generated_data, - pooled_generator_inputs) - return model._replace( - generator_inputs=pooled_generator_inputs, - generated_data=pooled_generated_data, - discriminator_gen_outputs=dis_gen_outputs) - elif isinstance(model, namedtuples.ACGANModel): - pooled_generator_inputs, pooled_generated_data = tensor_pool_fn( - (model.generator_inputs, model.generated_data)) - with variable_scope.variable_scope(model.discriminator_scope, reuse=True): - (pooled_discriminator_gen_outputs, - pooled_discriminator_gen_classification_logits) = model.discriminator_fn( - pooled_generated_data, pooled_generator_inputs) - return model._replace( - generator_inputs=pooled_generator_inputs, - generated_data=pooled_generated_data, - discriminator_gen_outputs=pooled_discriminator_gen_outputs, - discriminator_gen_classification_logits=pooled_discriminator_gen_classification_logits # pylint: disable=line-too-long - ) - elif isinstance(model, namedtuples.InfoGANModel): - pooled_generator_inputs, pooled_generated_data, pooled_structured_input = ( - tensor_pool_fn((model.generator_inputs, model.generated_data, - model.structured_generator_inputs))) - with variable_scope.variable_scope(model.discriminator_scope, reuse=True): - (pooled_discriminator_gen_outputs, - pooled_predicted_distributions) = model.discriminator_and_aux_fn( - pooled_generated_data, pooled_generator_inputs) - return model._replace( - generator_inputs=pooled_generator_inputs, - generated_data=pooled_generated_data, - structured_generator_inputs=pooled_structured_input, - discriminator_gen_outputs=pooled_discriminator_gen_outputs, - predicted_distributions=pooled_predicted_distributions) - else: - raise ValueError('Tensor pool does not support `model`: %s.' % type(model)) - - -def gan_loss( - # GANModel. - model, - # Loss functions. - generator_loss_fn=tfgan_losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss, - # Auxiliary losses. - gradient_penalty_weight=None, - gradient_penalty_epsilon=1e-10, - gradient_penalty_target=1.0, - gradient_penalty_one_sided=False, - mutual_information_penalty_weight=None, - aux_cond_generator_weight=None, - aux_cond_discriminator_weight=None, - tensor_pool_fn=None, - # Options. - add_summaries=True): - """Returns losses necessary to train generator and discriminator. - - Args: - model: A GANModel tuple. - generator_loss_fn: The loss function on the generator. Takes a GANModel - tuple. - discriminator_loss_fn: The loss function on the discriminator. Takes a - GANModel tuple. - gradient_penalty_weight: If not `None`, must be a non-negative Python number - or Tensor indicating how much to weight the gradient penalty. See - https://arxiv.org/pdf/1704.00028.pdf for more details. - gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the - small positive value used by the gradient penalty function for numerical - stability. Note some applications will need to increase this value to - avoid NaNs. - gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python - number or `Tensor` indicating the target value of gradient norm. See the - CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0. 
- gradient_penalty_one_sided: If `True`, penalty proposed in - https://arxiv.org/abs/1709.08894 is used. Defaults to `False`. - mutual_information_penalty_weight: If not `None`, must be a non-negative - Python number or Tensor indicating how much to weight the mutual - information penalty. See https://arxiv.org/abs/1606.03657 for more - details. - aux_cond_generator_weight: If not None: add a classification loss as in - https://arxiv.org/abs/1610.09585 - aux_cond_discriminator_weight: If not None: add a classification loss as in - https://arxiv.org/abs/1610.09585 - tensor_pool_fn: A function that takes (generated_data, generator_inputs), - stores them in an internal pool and returns previous stored - (generated_data, generator_inputs). For example - `tf.gan.features.tensor_pool`. Defaults to None (not using tensor pool). - add_summaries: Whether or not to add summaries for the losses. - - Returns: - A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes - regularization losses. - - Raises: - ValueError: If any of the auxiliary loss weights is provided and negative. - ValueError: If `mutual_information_penalty_weight` is provided, but the - `model` isn't an `InfoGANModel`. - """ - # Validate arguments. - gradient_penalty_weight = _validate_aux_loss_weight( - gradient_penalty_weight, 'gradient_penalty_weight') - mutual_information_penalty_weight = _validate_aux_loss_weight( - mutual_information_penalty_weight, 'infogan_weight') - aux_cond_generator_weight = _validate_aux_loss_weight( - aux_cond_generator_weight, 'aux_cond_generator_weight') - aux_cond_discriminator_weight = _validate_aux_loss_weight( - aux_cond_discriminator_weight, 'aux_cond_discriminator_weight') - - # Verify configuration for mutual information penalty - if (_use_aux_loss(mutual_information_penalty_weight) and - not isinstance(model, namedtuples.InfoGANModel)): - raise ValueError( - 'When `mutual_information_penalty_weight` is provided, `model` must be ' - 'an `InfoGANModel`. Instead, was %s.' % type(model)) - - # Verify configuration for mutual auxiliary condition loss (ACGAN). - if ((_use_aux_loss(aux_cond_generator_weight) or - _use_aux_loss(aux_cond_discriminator_weight)) and - not isinstance(model, namedtuples.ACGANModel)): - raise ValueError( - 'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` ' - 'is provided, `model` must be an `ACGANModel`. Instead, was %s.' % - type(model)) - - # Optionally create pooled model. - if tensor_pool_fn: - pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn) - else: - pooled_model = model - - # Create standard losses. - gen_loss = generator_loss_fn(model, add_summaries=add_summaries) - dis_loss = discriminator_loss_fn(pooled_model, add_summaries=add_summaries) - - # Add optional extra losses. 
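To make the optional penalties and the tensor-pool hook above concrete, a hedged sketch (not part of this patch): it combines a one-sided WGAN-GP penalty with a history pool of generated samples via `tfgan.features.tensor_pool`, which the docstring above names as the intended `tensor_pool_fn`. The toy model, pool size, and weights are arbitrary illustrations.

```python
# Illustrative use of gan_loss auxiliary options with the removed API.
import functools

import tensorflow as tf

tfgan = tf.contrib.gan


def generator_fn(noise):
  return tf.layers.dense(noise, 64, activation=tf.tanh)


def discriminator_fn(data, unused_conditioning):
  return tf.layers.dense(tf.layers.flatten(data), 1)


model = tfgan.gan_model(generator_fn, discriminator_fn,
                        real_data=tf.random_normal([16, 64]),
                        generator_inputs=tf.random_normal([16, 8]))

loss = tfgan.gan_loss(
    model,
    gradient_penalty_weight=10.0,        # WGAN-GP penalty on the discriminator
    gradient_penalty_one_sided=True,     # one-sided variant (arXiv:1709.08894)
    tensor_pool_fn=functools.partial(tfgan.features.tensor_pool, pool_size=50))
```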
- if _use_aux_loss(gradient_penalty_weight): - gp_loss = tfgan_losses.wasserstein_gradient_penalty( - pooled_model, - epsilon=gradient_penalty_epsilon, - target=gradient_penalty_target, - one_sided=gradient_penalty_one_sided, - add_summaries=add_summaries) - dis_loss += gradient_penalty_weight * gp_loss - if _use_aux_loss(mutual_information_penalty_weight): - gen_info_loss = tfgan_losses.mutual_information_penalty( - model, add_summaries=add_summaries) - if tensor_pool_fn is None: - dis_info_loss = gen_info_loss - else: - dis_info_loss = tfgan_losses.mutual_information_penalty( - pooled_model, add_summaries=add_summaries) - gen_loss += mutual_information_penalty_weight * gen_info_loss - dis_loss += mutual_information_penalty_weight * dis_info_loss - if _use_aux_loss(aux_cond_generator_weight): - ac_gen_loss = tfgan_losses.acgan_generator_loss( - model, add_summaries=add_summaries) - gen_loss += aux_cond_generator_weight * ac_gen_loss - if _use_aux_loss(aux_cond_discriminator_weight): - ac_disc_loss = tfgan_losses.acgan_discriminator_loss( - pooled_model, add_summaries=add_summaries) - dis_loss += aux_cond_discriminator_weight * ac_disc_loss - # Gathers auxiliary losses. - if model.generator_scope: - gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name) - else: - gen_reg_loss = 0 - if model.discriminator_scope: - dis_reg_loss = losses.get_regularization_loss( - model.discriminator_scope.name) - else: - dis_reg_loss = 0 - - return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss) - - -def cyclegan_loss( - model, - # Loss functions. - generator_loss_fn=tfgan_losses.least_squares_generator_loss, - discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss, - # Auxiliary losses. - cycle_consistency_loss_fn=tfgan_losses.cycle_consistency_loss, - cycle_consistency_loss_weight=10.0, - # Options - **kwargs): - """Returns the losses for a `CycleGANModel`. - - See https://arxiv.org/abs/1703.10593 for more details. - - Args: - model: A `CycleGANModel` namedtuple. - generator_loss_fn: The loss function on the generator. Takes a `GANModel` - named tuple. - discriminator_loss_fn: The loss function on the discriminator. Takes a - `GANModel` namedtuple. - cycle_consistency_loss_fn: The cycle consistency loss function. Takes a - `CycleGANModel` namedtuple. - cycle_consistency_loss_weight: A non-negative Python number or a scalar - `Tensor` indicating how much to weigh the cycle consistency loss. - **kwargs: Keyword args to pass directly to `gan_loss` to construct the loss - for each partial model of `model`. - - Returns: - A `CycleGANLoss` namedtuple. - - Raises: - ValueError: If `model` is not a `CycleGANModel` namedtuple. - """ - # Sanity checks. - if not isinstance(model, namedtuples.CycleGANModel): - raise ValueError('`model` must be a `CycleGANModel`. Instead, was %s.' % - type(model)) - - # Defines cycle consistency loss. - cycle_consistency_loss = cycle_consistency_loss_fn( - model, add_summaries=kwargs.get('add_summaries', True)) - cycle_consistency_loss_weight = _validate_aux_loss_weight( - cycle_consistency_loss_weight, 'cycle_consistency_loss_weight') - aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss - - # Defines losses for each partial model. 
- def _partial_loss(partial_model): - partial_loss = gan_loss( - partial_model, - generator_loss_fn=generator_loss_fn, - discriminator_loss_fn=discriminator_loss_fn, - **kwargs) - return partial_loss._replace(generator_loss=partial_loss.generator_loss + - aux_loss) - - with ops.name_scope('cyclegan_loss_x2y'): - loss_x2y = _partial_loss(model.model_x2y) - with ops.name_scope('cyclegan_loss_y2x'): - loss_y2x = _partial_loss(model.model_y2x) - - return namedtuples.CycleGANLoss(loss_x2y, loss_y2x) - - -# Begin google-internal -# The four major parts can be found here: http://screen/tMRMBAohDYG. -# End google-internal -def stargan_loss( - model, - generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper( - tfgan_losses_impl.wasserstein_generator_loss), - discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper( - tfgan_losses_impl.wasserstein_discriminator_loss), - gradient_penalty_weight=10.0, - gradient_penalty_epsilon=1e-10, - gradient_penalty_target=1.0, - gradient_penalty_one_sided=False, - reconstruction_loss_fn=losses.absolute_difference, - reconstruction_loss_weight=10.0, - classification_loss_fn=losses.softmax_cross_entropy, - classification_loss_weight=1.0, - classification_one_hot=True, - add_summaries=True): - """StarGAN Loss. - - Args: - model: (StarGAN) Model output of the stargan_model() function call. - generator_loss_fn: The loss function on the generator. Takes a - `StarGANModel` named tuple. - discriminator_loss_fn: The loss function on the discriminator. Takes a - `StarGANModel` namedtuple. - gradient_penalty_weight: (float) Gradient penalty weight. Default to 10 per - the original paper https://arxiv.org/abs/1711.09020. Set to 0 or None to - turn off gradient penalty. - gradient_penalty_epsilon: (float) A small positive number added for - numerical stability when computing the gradient norm. - gradient_penalty_target: (float, or tf.float `Tensor`) The target value of - gradient norm. Defaults to 1.0. - gradient_penalty_one_sided: (bool) If `True`, penalty proposed in - https://arxiv.org/abs/1709.08894 is used. Defaults to `False`. - reconstruction_loss_fn: The reconstruction loss function. Default to L1-norm - and the function must conform to the `tf.losses` API. - reconstruction_loss_weight: Reconstruction loss weight. Default to 10.0. - classification_loss_fn: The loss function on the discriminator's ability to - classify domain of the input. Default to one-hot softmax cross entropy - loss, and the function must conform to the `tf.losses` API. - classification_loss_weight: (float) Classification loss weight. Default to - 1.0. - classification_one_hot: (bool) If the label is one hot representation. - Default to True. If False, classification classification_loss_fn need to - be sigmoid cross entropy loss instead. - add_summaries: (bool) Add the loss to the summary - - Returns: - GANLoss namedtuple where we have generator loss and discriminator loss. - - Raises: - ValueError: If input StarGANModel.input_data_domain_label does not have rank - 2, or dimension 2 is not defined. - """ - - def _classification_loss_helper(true_labels, predict_logits, scope_name): - """Classification Loss Function Helper. - - Args: - true_labels: Tensor of shape [batch_size, num_domains] representing the - label where each row is an one-hot vector. - predict_logits: Tensor of shape [batch_size, num_domains] representing the - predicted label logit, which is UNSCALED output from the NN. - scope_name: (string) Name scope of the loss component. 
- - Returns: - Single scalar tensor representing the classification loss. - """ - - with ops.name_scope(scope_name, values=(true_labels, predict_logits)): - - loss = classification_loss_fn( - onehot_labels=true_labels, logits=predict_logits) - - if not classification_one_hot: - loss = math_ops.reduce_sum(loss, axis=1) - loss = math_ops.reduce_mean(loss) - - if add_summaries: - summary.scalar(scope_name, loss) - - return loss - - # Check input shape. - model.input_data_domain_label.shape.assert_has_rank(2) - model.input_data_domain_label.shape[1:].assert_is_fully_defined() - - # Adversarial Loss. - generator_loss = generator_loss_fn(model, add_summaries=add_summaries) - discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries) - - # Gradient Penalty. - if _use_aux_loss(gradient_penalty_weight): - gradient_penalty_fn = tfgan_losses.stargan_gradient_penalty_wrapper( - tfgan_losses_impl.wasserstein_gradient_penalty) - discriminator_loss += gradient_penalty_fn( - model, - epsilon=gradient_penalty_epsilon, - target=gradient_penalty_target, - one_sided=gradient_penalty_one_sided, - add_summaries=add_summaries) * gradient_penalty_weight - - # Reconstruction Loss. - reconstruction_loss = reconstruction_loss_fn(model.input_data, - model.reconstructed_data) - generator_loss += reconstruction_loss * reconstruction_loss_weight - if add_summaries: - summary.scalar('reconstruction_loss', reconstruction_loss) - - # Classification Loss. - generator_loss += _classification_loss_helper( - true_labels=model.generated_data_domain_target, - predict_logits=model.discriminator_generated_data_domain_predication, - scope_name='generator_classification_loss') * classification_loss_weight - discriminator_loss += _classification_loss_helper( - true_labels=model.input_data_domain_label, - predict_logits=model.discriminator_input_data_domain_predication, - scope_name='discriminator_classification_loss' - ) * classification_loss_weight - - return namedtuples.GANLoss(generator_loss, discriminator_loss) - - -def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True): - """Gets generator and discriminator update ops. - - Args: - kwargs: A dictionary of kwargs to be passed to `create_train_op`. - `update_ops` is removed, if present. - gen_scope: A scope for the generator. - dis_scope: A scope for the discriminator. - check_for_unused_ops: A Python bool. If `True`, throw Exception if there are - unused update ops. - - Returns: - A 2-tuple of (generator update ops, discriminator train ops). - - Raises: - ValueError: If there are update ops outside of the generator or - discriminator scopes. - """ - if 'update_ops' in kwargs: - update_ops = set(kwargs['update_ops']) - del kwargs['update_ops'] - else: - update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS)) - - all_gen_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, gen_scope)) - all_dis_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, dis_scope)) - - if check_for_unused_ops: - unused_ops = update_ops - all_gen_ops - all_dis_ops - if unused_ops: - raise ValueError('There are unused update ops: %s' % unused_ops) - - gen_update_ops = list(all_gen_ops & update_ops) - dis_update_ops = list(all_dis_ops & update_ops) - - return gen_update_ops, dis_update_ops - - -def gan_train_ops( - model, - loss, - generator_optimizer, - discriminator_optimizer, - check_for_unused_update_ops=True, - is_chief=True, - # Optional args to pass directly to the `create_train_op`. - **kwargs): - """Returns GAN train ops. 
- - The highest-level call in TF-GAN. It is composed of functions that can also - be called, should a user require more control over some part of the GAN - training process. - - Args: - model: A GANModel. - loss: A GANLoss. - generator_optimizer: The optimizer for generator updates. - discriminator_optimizer: The optimizer for the discriminator updates. - check_for_unused_update_ops: If `True`, throws an exception if there are - update ops outside of the generator or discriminator scopes. - is_chief: Specifies whether or not the training is being run by the primary - replica during replica training. - **kwargs: Keyword args to pass directly to `training.create_train_op` for - both the generator and discriminator train op. - - Returns: - A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that can - be used to train a generator/discriminator pair. - """ - if isinstance(model, namedtuples.CycleGANModel): - # Get and store all arguments other than model and loss from locals. - # Contents of locals should not be modified, may not affect values. So make - # a copy. https://docs.python.org/2/library/functions.html#locals. - saved_params = dict(locals()) - saved_params.pop('model', None) - saved_params.pop('loss', None) - kwargs = saved_params.pop('kwargs', {}) - saved_params.update(kwargs) - with ops.name_scope('cyclegan_x2y_train'): - train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y, - **saved_params) - with ops.name_scope('cyclegan_y2x_train'): - train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x, - **saved_params) - return namedtuples.GANTrainOps( - (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op), - (train_ops_x2y.discriminator_train_op, - train_ops_y2x.discriminator_train_op), - training_util.get_or_create_global_step().assign_add(1)) - - # Create global step increment op. - global_step = training_util.get_or_create_global_step() - global_step_inc = global_step.assign_add(1) - - # Get generator and discriminator update ops. We split them so that update - # ops aren't accidentally run multiple times. For now, throw an error if - # there are update ops that aren't associated with either the generator or - # the discriminator. Might modify the `kwargs` dictionary. - gen_update_ops, dis_update_ops = _get_update_ops( - kwargs, model.generator_scope.name, model.discriminator_scope.name, - check_for_unused_update_ops) - - # Get the sync hooks if these are needed. - sync_hooks = [] - - generator_global_step = None - if isinstance(generator_optimizer, - sync_replicas_optimizer.SyncReplicasOptimizer): - # TODO(joelshor): Figure out a way to get this work without including the - # dummy global step in the checkpoint. - # WARNING: Making this variable a local variable causes sync replicas to - # hang forever. 
- generator_global_step = variable_scope.get_variable( - 'dummy_global_step_generator', - shape=[], - dtype=global_step.dtype.base_dtype, - initializer=init_ops.zeros_initializer(), - trainable=False, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - gen_update_ops += [generator_global_step.assign(global_step)] - sync_hooks.append(generator_optimizer.make_session_run_hook(is_chief)) - with ops.name_scope('generator_train'): - gen_train_op = training.create_train_op( - total_loss=loss.generator_loss, - optimizer=generator_optimizer, - variables_to_train=model.generator_variables, - global_step=generator_global_step, - update_ops=gen_update_ops, - **kwargs) - - discriminator_global_step = None - if isinstance(discriminator_optimizer, - sync_replicas_optimizer.SyncReplicasOptimizer): - # See comment above `generator_global_step`. - discriminator_global_step = variable_scope.get_variable( - 'dummy_global_step_discriminator', - shape=[], - dtype=global_step.dtype.base_dtype, - initializer=init_ops.zeros_initializer(), - trainable=False, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - dis_update_ops += [discriminator_global_step.assign(global_step)] - sync_hooks.append(discriminator_optimizer.make_session_run_hook(is_chief)) - with ops.name_scope('discriminator_train'): - disc_train_op = training.create_train_op( - total_loss=loss.discriminator_loss, - optimizer=discriminator_optimizer, - variables_to_train=model.discriminator_variables, - global_step=discriminator_global_step, - update_ops=dis_update_ops, - **kwargs) - - return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc, - sync_hooks) - - -# TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive -# Image Compression` (https://arxiv.org/abs/1705.05823) -class RunTrainOpsHook(session_run_hook.SessionRunHook): - """A hook to run train ops a fixed number of times.""" - - def __init__(self, train_ops, train_steps): - """Run train ops a certain number of times. - - Args: - train_ops: A train op or iterable of train ops to run. - train_steps: The number of times to run the op(s). - """ - if not isinstance(train_ops, (list, tuple)): - train_ops = [train_ops] - self._train_ops = train_ops - self._train_steps = train_steps - - def before_run(self, run_context): - for _ in range(self._train_steps): - run_context.session.run(self._train_ops) - - -def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): - """Returns a hooks function for sequential GAN training. - - Args: - train_steps: A `GANTrainSteps` tuple that determines how many generator and - discriminator training steps to take. - - Returns: - A function that takes a GANTrainOps tuple and returns a list of hooks. - """ - - def get_hooks(train_ops): - generator_hook = RunTrainOpsHook(train_ops.generator_train_op, - train_steps.generator_train_steps) - discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op, - train_steps.discriminator_train_steps) - return [generator_hook, discriminator_hook] + list(train_ops.train_hooks) - - return get_hooks - - -def _num_joint_steps(train_steps): - g_steps = train_steps.generator_train_steps - d_steps = train_steps.discriminator_train_steps - # Get the number of each type of step that should be run. 
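# (Editorial aside, not part of the deleted module.) Worked example of the
# split computed below, assuming GANTrainSteps(generator_train_steps=3,
# discriminator_train_steps=5):
#   num_d_and_g_steps = min(3, 5) = 3   # joint generator + discriminator steps
#   num_g_steps       = 3 - 3     = 0   # generator-only steps
#   num_d_steps       = 5 - 3     = 2   # discriminator-only steps
# which is why `get_joint_train_hooks` needs only 5 session calls where
# `get_sequential_train_hooks` needs 8.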
- num_d_and_g_steps = min(g_steps, d_steps) - num_g_steps = g_steps - num_d_and_g_steps - num_d_steps = d_steps - num_d_and_g_steps - - return num_d_and_g_steps, num_g_steps, num_d_steps - - -def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)): - """Returns a hooks function for joint GAN training. - - When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON - ALL OPTIMIZERS TO AVOID RACE CONDITIONS. - - The order of steps taken is: - 1) Combined generator and discriminator steps - 2) Generator only steps, if any remain - 3) Discriminator only steps, if any remain - - **NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates - for the generator and discriminator simultaneously whenever possible. This - reduces the number of `tf.compat.v1.Session` calls, and can also change the - training - semantics. - - To illustrate the difference look at the following example: - - `train_steps=namedtuples.GANTrainSteps(3, 5)` will cause - `get_sequential_train_hooks` to make 8 session calls: - 1) 3 generator steps - 2) 5 discriminator steps - - In contrast, `get_joint_train_steps` will make 5 session calls: - 1) 3 generator + discriminator steps - 2) 2 discriminator steps - - Args: - train_steps: A `GANTrainSteps` tuple that determines how many generator and - discriminator training steps to take. - - Returns: - A function that takes a GANTrainOps tuple and returns a list of hooks. - """ - num_d_and_g_steps, num_g_steps, num_d_steps = _num_joint_steps(train_steps) - - def get_hooks(train_ops): - g_op = train_ops.generator_train_op - d_op = train_ops.discriminator_train_op - - joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps) - g_hook = RunTrainOpsHook(g_op, num_g_steps) - d_hook = RunTrainOpsHook(d_op, num_d_steps) - - return [joint_hook, g_hook, d_hook] + list(train_ops.train_hooks) - - return get_hooks - - -# TODO(joelshor): This function currently returns the global step. Find a -# good way for it to return the generator, discriminator, and final losses. -def gan_train(train_ops, - logdir, - get_hooks_fn=get_sequential_train_hooks(), - master='', - is_chief=True, - scaffold=None, - hooks=None, - chief_only_hooks=None, - save_checkpoint_secs=600, - save_summaries_steps=100, - config=None): - """A wrapper around `contrib.training.train` that uses GAN hooks. - - Args: - train_ops: A GANTrainOps named tuple. - logdir: The directory where the graph and checkpoints are saved. - get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. - master: The URL of the master. - is_chief: Specifies whether or not the training is being run by the primary - replica during replica training. - scaffold: An tf.compat.v1.train.Scaffold instance. - hooks: List of `tf.estimator.SessionRunHook` callbacks which are run inside - the training loop. - chief_only_hooks: List of `tf.estimator.SessionRunHook` instances which are - run inside the training loop for the chief trainer only. - save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved - using a default checkpoint saver. If `save_checkpoint_secs` is set to - `None`, then the default checkpoint saver isn't used. - save_summaries_steps: The frequency, in number of global steps, that the - summaries are written to disk using a default summary saver. If - `save_summaries_steps` is set to `None`, then the default summary saver - isn't used. - config: An instance of `tf.compat.v1.ConfigProto`. - - Returns: - Output of the call to `training.train`. 
- """ - new_hooks = get_hooks_fn(train_ops) - if hooks is not None: - hooks = list(hooks) + list(new_hooks) - else: - hooks = new_hooks - return training.train( - train_ops.global_step_inc_op, - logdir, - master=master, - is_chief=is_chief, - scaffold=scaffold, - hooks=hooks, - chief_only_hooks=chief_only_hooks, - save_checkpoint_secs=save_checkpoint_secs, - save_summaries_steps=save_summaries_steps, - config=config) - - -def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)): - """Returns a thin wrapper around slim.learning.train_step, for GANs. - - This function is to provide support for the Supervisor. For new code, please - use `MonitoredSession` and `get_sequential_train_hooks`. - - Args: - train_steps: A `GANTrainSteps` tuple that determines how many generator and - discriminator training steps to take. - - Returns: - A function that can be used for `train_step_fn` for GANs. - """ - - def sequential_train_steps(sess, train_ops, global_step, train_step_kwargs): - """A thin wrapper around slim.learning.train_step, for GANs. - - Args: - sess: A Tensorflow session. - train_ops: A GANTrainOps tuple of train ops to run. - global_step: The global step. - train_step_kwargs: Dictionary controlling `train_step` behavior. - - Returns: - A scalar final loss and a bool whether or not the train loop should stop. - """ - # Only run `should_stop` at the end, if required. Make a local copy of - # `train_step_kwargs`, if necessary, so as not to modify the caller's - # dictionary. - should_stop_op, train_kwargs = None, train_step_kwargs - if 'should_stop' in train_step_kwargs: - should_stop_op = train_step_kwargs['should_stop'] - train_kwargs = train_step_kwargs.copy() - del train_kwargs['should_stop'] - - # Run generator training steps. - gen_loss = 0 - for _ in range(train_steps.generator_train_steps): - cur_gen_loss, _ = slim_learning.train_step(sess, - train_ops.generator_train_op, - global_step, train_kwargs) - gen_loss += cur_gen_loss - - # Run discriminator training steps. - dis_loss = 0 - for _ in range(train_steps.discriminator_train_steps): - cur_dis_loss, _ = slim_learning.train_step( - sess, train_ops.discriminator_train_op, global_step, train_kwargs) - dis_loss += cur_dis_loss - - sess.run(train_ops.global_step_inc_op) - - # Run the `should_stop` op after the global step has been incremented, so - # that the `should_stop` aligns with the proper `global_step` count. - if should_stop_op is not None: - should_stop = sess.run(should_stop_op) - else: - should_stop = False - - return gen_loss + dis_loss, should_stop - - return sequential_train_steps - - -# Helpers - - -def _convert_tensor_or_l_or_d(tensor_or_l_or_d): - """Convert input, list of inputs, or dictionary of inputs to Tensors.""" - if isinstance(tensor_or_l_or_d, (list, tuple)): - return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d] - elif isinstance(tensor_or_l_or_d, dict): - return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()} - else: - return ops.convert_to_tensor(tensor_or_l_or_d) - - -def _validate_distributions(distributions_l, noise_l): - if not isinstance(distributions_l, (tuple, list)): - raise ValueError('`predicted_distributions` must be a list. Instead, found ' - '%s.' % type(distributions_l)) - if len(distributions_l) != len(noise_l): - raise ValueError('Length of `predicted_distributions` %i must be the same ' - 'as the length of structured noise %i.' 
% - (len(distributions_l), len(noise_l))) - - -def _validate_acgan_discriminator_outputs(discriminator_output): - try: - a, b = discriminator_output - except (TypeError, ValueError): - raise TypeError( - 'A discriminator function for ACGAN must output a tuple ' - 'consisting of (discrimination logits, classification logits).') - return a, b - - -def _generate_stargan_random_domain_target(batch_size, num_domains): - """Generate random domain label. - - Args: - batch_size: (int) Number of random domain label. - num_domains: (int) Number of domains representing with the label. - - Returns: - Tensor of shape (batch_size, num_domains) representing random label. - """ - domain_idx = random_ops.random_uniform([batch_size], - minval=0, - maxval=num_domains, - dtype=dtypes.int32) - - return array_ops.one_hot(domain_idx, num_domains) diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py deleted file mode 100644 index 841f25cd7f1..00000000000 --- a/tensorflow/contrib/gan/python/train_test.py +++ /dev/null @@ -1,1144 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for gan.python.train.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from absl.testing import parameterized -import numpy as np - -from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import variables as variables_lib -from tensorflow.contrib.gan.python import namedtuples -from tensorflow.contrib.gan.python import train -from tensorflow.contrib.gan.python.features.python import random_tensor_pool -from tensorflow.contrib.slim.python.slim import learning as slim_learning -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.ops.distributions import categorical -from tensorflow.python.platform import test -from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import coordinator -from tensorflow.python.training import gradient_descent -from tensorflow.python.training import sync_replicas_optimizer -from tensorflow.python.training import training_util - - -def generator_model(inputs): - return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs - - -class Generator(object): - - def __call__(self, inputs): - return generator_model(inputs) - - -def infogan_generator_model(inputs): - return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs[0] - - -class InfoGANGenerator(object): - - 
def __call__(self, inputs): - return infogan_generator_model(inputs) - - -def discriminator_model(inputs, _): - return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs - - -class Discriminator(object): - - def __call__(self, inputs, _): - return discriminator_model(inputs, _) - - -def infogan_discriminator_model(inputs, _): - return (variable_scope.get_variable('dummy_d', initializer=2.0) * inputs, - [categorical.Categorical([1.0])]) - - -class InfoGANDiscriminator(object): - - def __call__(self, inputs, _): - return infogan_discriminator_model(inputs, _) - - -def acgan_discriminator_model(inputs, _, num_classes=10): - return ( - discriminator_model(inputs, _), - array_ops.one_hot( - # TODO(haeusser): infer batch size from input - random_ops.random_uniform( - [3], maxval=num_classes, dtype=dtypes.int32), - num_classes)) - - -class ACGANDiscriminator(object): - - def __call__(self, inputs, _, num_classes=10): - return ( - discriminator_model(inputs, _), - array_ops.one_hot( - # TODO(haeusser): infer batch size from input - random_ops.random_uniform( - [3], maxval=num_classes, dtype=dtypes.int32), - num_classes)) - - -def stargan_generator_model(inputs, _): - """Dummy generator for StarGAN.""" - - return variable_scope.get_variable('dummy_g', initializer=0.5) * inputs - - -class StarGANGenerator(object): - - def __call__(self, inputs, _): - return stargan_generator_model(inputs, _) - - -def stargan_discriminator_model(inputs, num_domains): - """Differentiable dummy discriminator for StarGAN.""" - - hidden = layers.flatten(inputs) - - output_src = math_ops.reduce_mean(hidden, axis=1) - - output_cls = layers.fully_connected( - inputs=hidden, - num_outputs=num_domains, - activation_fn=None, - normalizer_fn=None, - biases_initializer=None) - return output_src, output_cls - - -class StarGANDiscriminator(object): - - def __call__(self, inputs, num_domains): - return stargan_discriminator_model(inputs, num_domains) - - -def get_gan_model(): - # TODO(joelshor): Find a better way of creating a variable scope. 
- with variable_scope.variable_scope('generator') as gen_scope: - pass - with variable_scope.variable_scope('discriminator') as dis_scope: - pass - return namedtuples.GANModel( - generator_inputs=None, - generated_data=None, - generator_variables=None, - generator_scope=gen_scope, - generator_fn=generator_model, - real_data=array_ops.ones([1, 2, 3]), - discriminator_real_outputs=array_ops.ones([1, 2, 3]), - discriminator_gen_outputs=array_ops.ones([1, 2, 3]), - discriminator_variables=None, - discriminator_scope=dis_scope, - discriminator_fn=discriminator_model) - - -def get_callable_gan_model(): - ganmodel = get_gan_model() - return ganmodel._replace( - generator_fn=Generator(), discriminator_fn=Discriminator()) - - -def create_gan_model(): - return train.gan_model( - generator_model, - discriminator_model, - real_data=array_ops.zeros([1, 2]), - generator_inputs=random_ops.random_normal([1, 2])) - - -def create_callable_gan_model(): - return train.gan_model( - Generator(), - Discriminator(), - real_data=array_ops.zeros([1, 2]), - generator_inputs=random_ops.random_normal([1, 2])) - - -def get_infogan_model(): - return namedtuples.InfoGANModel( - *get_gan_model(), - structured_generator_inputs=[constant_op.constant(0)], - predicted_distributions=[categorical.Categorical([1.0])], - discriminator_and_aux_fn=infogan_discriminator_model) - - -def get_callable_infogan_model(): - return namedtuples.InfoGANModel( - *get_callable_gan_model(), - structured_generator_inputs=[constant_op.constant(0)], - predicted_distributions=[categorical.Categorical([1.0])], - discriminator_and_aux_fn=infogan_discriminator_model) - - -def create_infogan_model(): - return train.infogan_model( - infogan_generator_model, - infogan_discriminator_model, - real_data=array_ops.zeros([1, 2]), - unstructured_generator_inputs=[], - structured_generator_inputs=[random_ops.random_normal([1, 2])]) - - -def create_callable_infogan_model(): - return train.infogan_model( - InfoGANGenerator(), - InfoGANDiscriminator(), - real_data=array_ops.zeros([1, 2]), - unstructured_generator_inputs=[], - structured_generator_inputs=[random_ops.random_normal([1, 2])]) - - -def get_acgan_model(): - return namedtuples.ACGANModel( - *get_gan_model(), - one_hot_labels=array_ops.one_hot([0, 1, 2], 10), - discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10), - discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10)) - - -def get_callable_acgan_model(): - return namedtuples.ACGANModel( - *get_callable_gan_model(), - one_hot_labels=array_ops.one_hot([0, 1, 2], 10), - discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10), - discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10)) - - -def create_acgan_model(): - return train.acgan_model( - generator_model, - acgan_discriminator_model, - real_data=array_ops.zeros([1, 2]), - generator_inputs=random_ops.random_normal([1, 2]), - one_hot_labels=array_ops.one_hot([0, 1, 2], 10)) - - -def create_callable_acgan_model(): - return train.acgan_model( - Generator(), - ACGANDiscriminator(), - real_data=array_ops.zeros([1, 2]), - generator_inputs=random_ops.random_normal([1, 2]), - one_hot_labels=array_ops.one_hot([0, 1, 2], 10)) - - -def get_cyclegan_model(): - return namedtuples.CycleGANModel( - model_x2y=get_gan_model(), - model_y2x=get_gan_model(), - reconstructed_x=array_ops.ones([1, 2, 3]), - reconstructed_y=array_ops.zeros([1, 2, 3])) - - -def get_callable_cyclegan_model(): - return namedtuples.CycleGANModel( - 
model_x2y=get_callable_gan_model(), - model_y2x=get_callable_gan_model(), - reconstructed_x=array_ops.ones([1, 2, 3]), - reconstructed_y=array_ops.zeros([1, 2, 3])) - - -def create_cyclegan_model(): - return train.cyclegan_model( - generator_model, - discriminator_model, - data_x=array_ops.zeros([1, 2]), - data_y=array_ops.ones([1, 2])) - - -def create_callable_cyclegan_model(): - return train.cyclegan_model( - Generator(), - Discriminator(), - data_x=array_ops.zeros([1, 2]), - data_y=array_ops.ones([1, 2])) - - -def get_stargan_model(): - """Similar to get_gan_model().""" - # TODO(joelshor): Find a better way of creating a variable scope. - with variable_scope.variable_scope('generator') as gen_scope: - pass - with variable_scope.variable_scope('discriminator') as dis_scope: - pass - return namedtuples.StarGANModel( - input_data=array_ops.ones([1, 2, 2, 3]), - input_data_domain_label=array_ops.ones([1, 2]), - generated_data=array_ops.ones([1, 2, 2, 3]), - generated_data_domain_target=array_ops.ones([1, 2]), - reconstructed_data=array_ops.ones([1, 2, 2, 3]), - discriminator_input_data_source_predication=array_ops.ones([1]), - discriminator_generated_data_source_predication=array_ops.ones([1]), - discriminator_input_data_domain_predication=array_ops.ones([1, 2]), - discriminator_generated_data_domain_predication=array_ops.ones([1, 2]), - generator_variables=None, - generator_scope=gen_scope, - generator_fn=stargan_generator_model, - discriminator_variables=None, - discriminator_scope=dis_scope, - discriminator_fn=stargan_discriminator_model) - - -def get_callable_stargan_model(): - model = get_stargan_model() - return model._replace( - generator_fn=StarGANGenerator(), discriminator_fn=StarGANDiscriminator()) - - -def create_stargan_model(): - return train.stargan_model( - stargan_generator_model, stargan_discriminator_model, - array_ops.ones([1, 2, 2, 3]), array_ops.ones([1, 2])) - - -def create_callable_stargan_model(): - return train.stargan_model(StarGANGenerator(), StarGANDiscriminator(), - array_ops.ones([1, 2, 2, 3]), - array_ops.ones([1, 2])) - - -def get_sync_optimizer(): - return sync_replicas_optimizer.SyncReplicasOptimizer( - gradient_descent.GradientDescentOptimizer(learning_rate=1.0), - replicas_to_aggregate=1) - - -class GANModelTest(test.TestCase, parameterized.TestCase): - """Tests for `gan_model`.""" - - @parameterized.named_parameters( - ('gan', get_gan_model, namedtuples.GANModel), - ('callable_gan', get_callable_gan_model, namedtuples.GANModel), - ('infogan', get_infogan_model, namedtuples.InfoGANModel), - ('callable_infogan', get_callable_infogan_model, - namedtuples.InfoGANModel), - ('acgan', get_acgan_model, namedtuples.ACGANModel), - ('callable_acgan', get_callable_acgan_model, namedtuples.ACGANModel), - ('cyclegan', get_cyclegan_model, namedtuples.CycleGANModel), - ('callable_cyclegan', get_callable_cyclegan_model, - namedtuples.CycleGANModel), - ('stargan', get_stargan_model, namedtuples.StarGANModel), - ('callabel_stargan', get_callable_stargan_model, namedtuples.StarGANModel) - ) - def test_output_type(self, create_fn, expected_tuple_type): - """Test that output type is as expected.""" - self.assertIsInstance(create_fn(), expected_tuple_type) - - def test_no_shape_check(self): - - def dummy_generator_model(_): - return (None, None) - - def dummy_discriminator_model(data, conditioning): # pylint: disable=unused-argument - return 1 - - with self.assertRaisesRegexp(AttributeError, 'object has no attribute'): - train.gan_model( - dummy_generator_model, - 
dummy_discriminator_model, - real_data=array_ops.zeros([1, 2]), - generator_inputs=array_ops.zeros([1]), - check_shapes=True) - train.gan_model( - dummy_generator_model, - dummy_discriminator_model, - real_data=array_ops.zeros([1, 2]), - generator_inputs=array_ops.zeros([1]), - check_shapes=False) - - -class StarGANModelTest(test.TestCase): - """Tests for `stargan_model`.""" - - @staticmethod - def create_input_and_label_tensor(batch_size, img_size, c_size, num_domains): - input_tensor_list = [] - label_tensor_list = [] - for _ in range(num_domains): - input_tensor_list.append( - random_ops.random_uniform((batch_size, img_size, img_size, c_size))) - domain_idx = random_ops.random_uniform( - [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32) - label_tensor_list.append(array_ops.one_hot(domain_idx, num_domains)) - return input_tensor_list, label_tensor_list - - def test_generate_stargan_random_domain_target(self): - batch_size = 8 - domain_numbers = 3 - - target_tensor = train._generate_stargan_random_domain_target( - batch_size, domain_numbers) - - with self.cached_session() as sess: - targets = sess.run(target_tensor) - self.assertTupleEqual((batch_size, domain_numbers), targets.shape) - for target in targets: - self.assertEqual(1, np.sum(target)) - self.assertEqual(1, np.max(target)) - - def test_stargan_model_output_type(self): - batch_size = 2 - img_size = 16 - c_size = 3 - num_domains = 5 - - input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor( - batch_size, img_size, c_size, num_domains) - model = train.stargan_model( - generator_fn=stargan_generator_model, - discriminator_fn=stargan_discriminator_model, - input_data=input_tensor, - input_data_domain_label=label_tensor) - - self.assertIsInstance(model, namedtuples.StarGANModel) - self.assertTrue(isinstance(model.discriminator_variables, list)) - self.assertTrue(isinstance(model.generator_variables, list)) - self.assertIsInstance(model.discriminator_scope, - variable_scope.VariableScope) - self.assertTrue(model.generator_scope, variable_scope.VariableScope) - self.assertTrue(callable(model.discriminator_fn)) - self.assertTrue(callable(model.generator_fn)) - - def test_stargan_model_generator_output(self): - batch_size = 2 - img_size = 16 - c_size = 3 - num_domains = 5 - - input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor( - batch_size, img_size, c_size, num_domains) - model = train.stargan_model( - generator_fn=stargan_generator_model, - discriminator_fn=stargan_discriminator_model, - input_data=input_tensor, - input_data_domain_label=label_tensor) - - with self.test_session(use_gpu=True) as sess: - - sess.run(variables.global_variables_initializer()) - - input_data, generated_data, reconstructed_data = sess.run( - [model.input_data, model.generated_data, model.reconstructed_data]) - self.assertTupleEqual( - (batch_size * num_domains, img_size, img_size, c_size), - input_data.shape) - self.assertTupleEqual( - (batch_size * num_domains, img_size, img_size, c_size), - generated_data.shape) - self.assertTupleEqual( - (batch_size * num_domains, img_size, img_size, c_size), - reconstructed_data.shape) - - def test_stargan_model_discriminator_output(self): - batch_size = 2 - img_size = 16 - c_size = 3 - num_domains = 5 - - input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor( - batch_size, img_size, c_size, num_domains) - model = train.stargan_model( - generator_fn=stargan_generator_model, - discriminator_fn=stargan_discriminator_model, - 
input_data=input_tensor, - input_data_domain_label=label_tensor) - - with self.test_session(use_gpu=True) as sess: - - sess.run(variables.global_variables_initializer()) - - disc_input_data_source_pred, disc_gen_data_source_pred = sess.run([ - model.discriminator_input_data_source_predication, - model.discriminator_generated_data_source_predication - ]) - self.assertEqual(1, len(disc_input_data_source_pred.shape)) - self.assertEqual(batch_size * num_domains, - disc_input_data_source_pred.shape[0]) - self.assertEqual(1, len(disc_gen_data_source_pred.shape)) - self.assertEqual(batch_size * num_domains, - disc_gen_data_source_pred.shape[0]) - - input_label, disc_input_label, gen_label, disc_gen_label = sess.run([ - model.input_data_domain_label, - model.discriminator_input_data_domain_predication, - model.generated_data_domain_target, - model.discriminator_generated_data_domain_predication - ]) - self.assertTupleEqual((batch_size * num_domains, num_domains), - input_label.shape) - self.assertTupleEqual((batch_size * num_domains, num_domains), - disc_input_label.shape) - self.assertTupleEqual((batch_size * num_domains, num_domains), - gen_label.shape) - self.assertTupleEqual((batch_size * num_domains, num_domains), - disc_gen_label.shape) - - -class GANLossTest(test.TestCase, parameterized.TestCase): - """Tests for `gan_loss`.""" - - @parameterized.named_parameters( - ('gan', get_gan_model), - ('callable_gan', get_callable_gan_model), - ('infogan', get_infogan_model), - ('callable_infogan', get_callable_infogan_model), - ('acgan', get_acgan_model), - ('callable_acgan', get_callable_acgan_model), - ) - def test_output_type(self, get_gan_model_fn): - """Test output type.""" - loss = train.gan_loss(get_gan_model_fn(), add_summaries=True) - self.assertIsInstance(loss, namedtuples.GANLoss) - self.assertNotEmpty(ops.get_collection(ops.GraphKeys.SUMMARIES)) - - @parameterized.named_parameters( - ('cyclegan', create_cyclegan_model), - ('callable_cyclegan', create_callable_cyclegan_model), - ) - def test_cyclegan_output_type(self, get_gan_model_fn): - loss = train.cyclegan_loss(get_gan_model_fn(), add_summaries=True) - self.assertIsInstance(loss, namedtuples.CycleGANLoss) - self.assertNotEmpty(ops.get_collection(ops.GraphKeys.SUMMARIES)) - - @parameterized.named_parameters( - ('gan', create_gan_model, False), - ('gan_one_sided', create_gan_model, True), - ('callable_gan', create_callable_gan_model, False), - ('callable_gan_one_sided', create_callable_gan_model, True), - ('infogan', create_infogan_model, False), - ('infogan_one_sided', create_infogan_model, True), - ('callable_infogan', create_callable_infogan_model, False), - ('callable_infogan_one_sided', create_callable_infogan_model, True), - ('acgan', create_acgan_model, False), - ('acgan_one_sided', create_acgan_model, True), - ('callable_acgan', create_callable_acgan_model, False), - ('callable_acgan_one_sided', create_callable_acgan_model, True), - ) - def test_grad_penalty(self, create_gan_model_fn, one_sided): - """Test gradient penalty option.""" - model = create_gan_model_fn() - loss = train.gan_loss(model) - loss_gp = train.gan_loss( - model, - gradient_penalty_weight=1.0, - gradient_penalty_one_sided=one_sided) - self.assertIsInstance(loss_gp, namedtuples.GANLoss) - - # Check values. 
- with self.test_session(use_gpu=True) as sess: - variables.global_variables_initializer().run() - loss_gen_np, loss_gen_gp_np = sess.run( - [loss.generator_loss, loss_gp.generator_loss]) - loss_dis_np, loss_dis_gp_np = sess.run( - [loss.discriminator_loss, loss_gp.discriminator_loss]) - - self.assertEqual(loss_gen_np, loss_gen_gp_np) - self.assertLess(loss_dis_np, loss_dis_gp_np) - - @parameterized.named_parameters( - ('infogan', get_infogan_model), - ('callable_infogan', get_callable_infogan_model), - ) - def test_mutual_info_penalty(self, create_gan_model_fn): - """Test mutual information penalty option.""" - train.gan_loss( - create_gan_model_fn(), - mutual_information_penalty_weight=constant_op.constant(1.0)) - - @parameterized.named_parameters( - ('gan', get_gan_model), - ('callable_gan', get_callable_gan_model), - ('infogan', get_infogan_model), - ('callable_infogan', get_callable_infogan_model), - ('acgan', get_acgan_model), - ('callable_acgan', get_callable_acgan_model), - ) - def test_regularization_helper(self, get_gan_model_fn): - """Test regularization loss.""" - # Evaluate losses without regularization. - no_reg_loss = train.gan_loss(get_gan_model_fn()) - with self.test_session(use_gpu=True): - no_reg_loss_gen_np = no_reg_loss.generator_loss.eval() - no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval() - - with ops.name_scope(get_gan_model_fn().generator_scope.name): - ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, - constant_op.constant(3.0)) - with ops.name_scope(get_gan_model_fn().discriminator_scope.name): - ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES, - constant_op.constant(2.0)) - - # Check that losses now include the correct regularization values. - reg_loss = train.gan_loss(get_gan_model_fn()) - with self.test_session(use_gpu=True): - reg_loss_gen_np = reg_loss.generator_loss.eval() - reg_loss_dis_np = reg_loss.discriminator_loss.eval() - - self.assertEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np) - self.assertEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np) - - @parameterized.named_parameters( - ('notcallable', create_acgan_model), - ('callable', create_callable_acgan_model), - ) - def test_acgan(self, create_gan_model_fn): - """Test that ACGAN models work.""" - model = create_gan_model_fn() - loss = train.gan_loss(model) - loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0) - loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0) - self.assertIsInstance(loss, namedtuples.GANLoss) - self.assertIsInstance(loss_ac_gen, namedtuples.GANLoss) - self.assertIsInstance(loss_ac_dis, namedtuples.GANLoss) - - # Check values. 
- with self.test_session(use_gpu=True) as sess: - variables.global_variables_initializer().run() - loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run([ - loss.generator_loss, loss_ac_gen.generator_loss, - loss_ac_dis.generator_loss - ]) - loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run([ - loss.discriminator_loss, loss_ac_gen.discriminator_loss, - loss_ac_dis.discriminator_loss - ]) - - self.assertLess(loss_gen_np, loss_dis_np) - self.assertTrue(np.isscalar(loss_ac_gen_gen_np)) - self.assertTrue(np.isscalar(loss_ac_dis_gen_np)) - self.assertTrue(np.isscalar(loss_ac_gen_dis_np)) - self.assertTrue(np.isscalar(loss_ac_dis_dis_np)) - - @parameterized.named_parameters( - ('notcallable', create_cyclegan_model), - ('callable', create_callable_cyclegan_model), - ) - def test_cyclegan(self, create_gan_model_fn): - """Test that CycleGan models work.""" - model = create_gan_model_fn() - loss = train.cyclegan_loss(model) - self.assertIsInstance(loss, namedtuples.CycleGANLoss) - - # Check values. - with self.test_session(use_gpu=True) as sess: - variables.global_variables_initializer().run() - (loss_x2y_gen_np, loss_x2y_dis_np, loss_y2x_gen_np, - loss_y2x_dis_np) = sess.run([ - loss.loss_x2y.generator_loss, loss.loss_x2y.discriminator_loss, - loss.loss_y2x.generator_loss, loss.loss_y2x.discriminator_loss - ]) - - self.assertGreater(loss_x2y_gen_np, loss_x2y_dis_np) - self.assertGreater(loss_y2x_gen_np, loss_y2x_dis_np) - self.assertTrue(np.isscalar(loss_x2y_gen_np)) - self.assertTrue(np.isscalar(loss_x2y_dis_np)) - self.assertTrue(np.isscalar(loss_y2x_gen_np)) - self.assertTrue(np.isscalar(loss_y2x_dis_np)) - - @parameterized.named_parameters( - ('notcallable', create_stargan_model), - ('callable', create_callable_stargan_model), - ) - def test_stargan(self, create_gan_model_fn): - - model = create_gan_model_fn() - model_loss = train.stargan_loss(model) - - self.assertIsInstance(model_loss, namedtuples.GANLoss) - - with self.cached_session() as sess: - - sess.run(variables.global_variables_initializer()) - - gen_loss, disc_loss = sess.run( - [model_loss.generator_loss, model_loss.discriminator_loss]) - - self.assertTrue(np.isscalar(gen_loss)) - self.assertTrue(np.isscalar(disc_loss)) - - @parameterized.named_parameters( - ('gan', create_gan_model), - ('callable_gan', create_callable_gan_model), - ('infogan', create_infogan_model), - ('callable_infogan', create_callable_infogan_model), - ('acgan', create_acgan_model), - ('callable_acgan', create_callable_acgan_model), - ) - def test_tensor_pool(self, create_gan_model_fn): - """Test tensor pool option.""" - model = create_gan_model_fn() - tensor_pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=5) - loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) - self.assertIsInstance(loss, namedtuples.GANLoss) - - # Check values. 
- with self.test_session(use_gpu=True) as sess: - variables.global_variables_initializer().run() - for _ in range(10): - sess.run([loss.generator_loss, loss.discriminator_loss]) - - def test_discriminator_only_sees_pool(self): - """Checks that discriminator only sees pooled values.""" - def checker_gen_fn(_): - return constant_op.constant(0.0) - model = train.gan_model( - checker_gen_fn, - discriminator_model, - real_data=array_ops.zeros([]), - generator_inputs=random_ops.random_normal([])) - def tensor_pool_fn(_): - return (random_ops.random_uniform([]), random_ops.random_uniform([])) - def checker_dis_fn(inputs, _): - """Discriminator that checks that it only sees pooled Tensors.""" - self.assertFalse(constant_op.is_constant(inputs)) - return inputs - model = model._replace( - discriminator_fn=checker_dis_fn) - train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) - - def test_doesnt_crash_when_in_nested_scope(self): - with variable_scope.variable_scope('outer_scope'): - gan_model = train.gan_model( - generator_model, - discriminator_model, - real_data=array_ops.zeros([1, 2]), - generator_inputs=random_ops.random_normal([1, 2])) - - # This should work inside a scope. - train.gan_loss(gan_model, gradient_penalty_weight=1.0) - - # This should also work outside a scope. - train.gan_loss(gan_model, gradient_penalty_weight=1.0) - - -class TensorPoolAdjusteModelTest(test.TestCase): - - def _check_tensor_pool_adjusted_model_outputs( - self, tensor1, tensor2, pool_size): - history_values = [] - with self.test_session(use_gpu=True) as sess: - variables.global_variables_initializer().run() - for i in range(2 * pool_size): - t1, t2 = sess.run([tensor1, tensor2]) - history_values.append(t1) - if i < pool_size: - # For [0, pool_size), the pool is not full, tensor1 should be equal - # to tensor2 as the pool. - self.assertAllEqual(t1, t2) - else: - # For [pool_size, ?), the pool is full, tensor2 must be equal to some - # historical values of tensor1 (which is previously stored in the - # pool). - self.assertTrue(any((v == t2).all() for v in history_values)) - - def _make_new_model_and_check(self, model, pool_size): - pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=pool_size) - new_model = train._tensor_pool_adjusted_model(model, pool_fn) - # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' - self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) - self.assertIsNot(new_model.discriminator_gen_outputs, - model.discriminator_gen_outputs) - - return new_model - - def test_tensor_pool_adjusted_model_gan(self): - """Test `_tensor_pool_adjusted_model` for gan model.""" - pool_size = 5 - model = create_gan_model() - new_model = self._make_new_model_and_check(model, pool_size) - - # Check values. - self._check_tensor_pool_adjusted_model_outputs( - model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, - pool_size) - - def test_tensor_pool_adjusted_model_infogan(self): - """Test _tensor_pool_adjusted_model for infogan model.""" - pool_size = 5 - model = create_infogan_model() - new_model = self._make_new_model_and_check(model, pool_size) - - # Check values. 
- self.assertIsNot(new_model.predicted_distributions, - model.predicted_distributions) - self._check_tensor_pool_adjusted_model_outputs( - model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, - pool_size) - - def test_tensor_pool_adjusted_model_acgan(self): - """Test _tensor_pool_adjusted_model for acgan model.""" - pool_size = 5 - model = create_acgan_model() - new_model = self._make_new_model_and_check(model, pool_size) - - # Check values. - self.assertIsNot(new_model.discriminator_gen_classification_logits, - model.discriminator_gen_classification_logits) - self._check_tensor_pool_adjusted_model_outputs( - model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, - pool_size) - - -class GANTrainOpsTest(test.TestCase, parameterized.TestCase): - """Tests for `gan_train_ops`.""" - - @parameterized.named_parameters( - ('gan', create_gan_model), - ('callable_gan', create_callable_gan_model), - ('infogan', create_infogan_model), - ('callable_infogan', create_callable_infogan_model), - ('acgan', create_acgan_model), - ('callable_acgan', create_callable_acgan_model), - ) - def test_output_type(self, create_gan_model_fn): - model = create_gan_model_fn() - loss = train.gan_loss(model) - - g_opt = gradient_descent.GradientDescentOptimizer(1.0) - d_opt = gradient_descent.GradientDescentOptimizer(1.0) - train_ops = train.gan_train_ops( - model, - loss, - g_opt, - d_opt, - summarize_gradients=True, - colocate_gradients_with_ops=True) - - self.assertIsInstance(train_ops, namedtuples.GANTrainOps) - - # Make sure there are no training hooks populated accidentally. - self.assertEmpty(train_ops.train_hooks) - - # TODO(joelshor): Add a test to check that custom update op is run. - @parameterized.named_parameters( - ('gan', create_gan_model, False), - ('gan_provideupdates', create_gan_model, True), - ('callable_gan', create_callable_gan_model, False), - ('callable_gan_provideupdates', create_callable_gan_model, True), - ('infogan', create_infogan_model, False), - ('infogan_provideupdates', create_infogan_model, True), - ('callable_infogan', create_callable_infogan_model, False), - ('callable_infogan_provideupdates', create_callable_infogan_model, True), - ('acgan', create_acgan_model, False), - ('acgan_provideupdates', create_acgan_model, True), - ('callable_acgan', create_callable_acgan_model, False), - ('callable_acgan_provideupdates', create_callable_acgan_model, True), - ) - def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops): - model = create_gan_model_fn() - loss = train.gan_loss(model) - - # Add generator and discriminator update ops. - with variable_scope.variable_scope(model.generator_scope): - gen_update_count = variable_scope.get_variable('gen_count', initializer=0) - gen_update_op = gen_update_count.assign_add(1) - ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op) - with variable_scope.variable_scope(model.discriminator_scope): - dis_update_count = variable_scope.get_variable('dis_count', initializer=0) - dis_update_op = dis_update_count.assign_add(1) - ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op) - - # Add an update op outside the generator and discriminator scopes. 
- if provide_update_ops: - kwargs = { - 'update_ops': [ - constant_op.constant(1.0), gen_update_op, dis_update_op - ] - } - else: - ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0)) - kwargs = {} - - g_opt = gradient_descent.GradientDescentOptimizer(1.0) - d_opt = gradient_descent.GradientDescentOptimizer(1.0) - - with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'): - train.gan_train_ops( - model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs) - train_ops = train.gan_train_ops( - model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs) - - with self.test_session(use_gpu=True) as sess: - sess.run(variables.global_variables_initializer()) - self.assertEqual(0, gen_update_count.eval()) - self.assertEqual(0, dis_update_count.eval()) - - train_ops.generator_train_op.eval() - self.assertEqual(1, gen_update_count.eval()) - self.assertEqual(0, dis_update_count.eval()) - - train_ops.discriminator_train_op.eval() - self.assertEqual(1, gen_update_count.eval()) - self.assertEqual(1, dis_update_count.eval()) - - @parameterized.named_parameters( - ('gan', create_gan_model, False), - ('callable_gan', create_callable_gan_model, False), - ('infogan', create_infogan_model, False), - ('callable_infogan', create_callable_infogan_model, False), - ('acgan', create_acgan_model, False), - ('callable_acgan', create_callable_acgan_model, False), - ('gan_canbeint32', create_gan_model, True), - ) - def test_sync_replicas(self, create_gan_model_fn, create_global_step): - model = create_gan_model_fn() - loss = train.gan_loss(model) - num_trainable_vars = len(variables_lib.get_trainable_variables()) - - if create_global_step: - gstep = variable_scope.get_variable( - 'custom_gstep', dtype=dtypes.int32, initializer=0, trainable=False) - ops.add_to_collection(ops.GraphKeys.GLOBAL_STEP, gstep) - - g_opt = get_sync_optimizer() - d_opt = get_sync_optimizer() - train_ops = train.gan_train_ops( - model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt) - self.assertIsInstance(train_ops, namedtuples.GANTrainOps) - # No new trainable variables should have been added. - self.assertLen(variables_lib.get_trainable_variables(), num_trainable_vars) - - # Sync hooks should be populated in the GANTrainOps. - self.assertLen(train_ops.train_hooks, 2) - for hook in train_ops.train_hooks: - self.assertIsInstance( - hook, sync_replicas_optimizer._SyncReplicasOptimizerHook) - sync_opts = [hook._sync_optimizer for hook in train_ops.train_hooks] - self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) - - g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1) - d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1) - - # Check that update op is run properly. - global_step = training_util.get_or_create_global_step() - with self.test_session(use_gpu=True) as sess: - variables.global_variables_initializer().run() - variables.local_variables_initializer().run() - - g_opt.chief_init_op.run() - d_opt.chief_init_op.run() - - gstep_before = global_step.eval() - - # Start required queue runner for SyncReplicasOptimizer. - coord = coordinator.Coordinator() - g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord) - d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord) - - g_sync_init_op.run() - d_sync_init_op.run() - - train_ops.generator_train_op.eval() - # Check that global step wasn't incremented. 
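# (Editorial aside, not part of the deleted test.) The assertions below hold
# because `gan_train_ops` gives each SyncReplicasOptimizer its own dummy
# 'dummy_global_step_*' variable; only the separate `global_step_inc_op`
# advances the real global step, so running the individual train ops leaves it
# unchanged.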
- self.assertEqual(gstep_before, global_step.eval()) - - train_ops.discriminator_train_op.eval() - # Check that global step wasn't incremented. - self.assertEqual(gstep_before, global_step.eval()) - - coord.request_stop() - coord.join(g_threads + d_threads) - - @parameterized.named_parameters( - ('is_chief', True), - ('is_not_chief', False), - ) - def test_is_chief_in_train_hooks(self, is_chief): - """Make sure is_chief is propagated correctly to sync hooks.""" - model = create_gan_model() - loss = train.gan_loss(model) - g_opt = get_sync_optimizer() - d_opt = get_sync_optimizer() - train_ops = train.gan_train_ops( - model, - loss, - g_opt, - d_opt, - is_chief=is_chief, - summarize_gradients=True, - colocate_gradients_with_ops=True) - - self.assertLen(train_ops.train_hooks, 2) - for hook in train_ops.train_hooks: - self.assertIsInstance( - hook, sync_replicas_optimizer._SyncReplicasOptimizerHook) - is_chief_list = [hook._is_chief for hook in train_ops.train_hooks] - self.assertListEqual(is_chief_list, [is_chief, is_chief]) - - -class GANTrainTest(test.TestCase, parameterized.TestCase): - """Tests for `gan_train`.""" - - def _gan_train_ops(self, generator_add, discriminator_add): - step = training_util.create_global_step() - # Increment the global count every time a train op is run so we can count - # the number of times they're run. - # NOTE: `use_locking=True` is required to avoid race conditions with - # joint training. - train_ops = namedtuples.GANTrainOps( - generator_train_op=step.assign_add(generator_add, use_locking=True), - discriminator_train_op=step.assign_add( - discriminator_add, use_locking=True), - global_step_inc_op=step.assign_add(1)) - return train_ops - - @parameterized.named_parameters( - ('gan', create_gan_model), - ('callable_gan', create_callable_gan_model), - ('infogan', create_infogan_model), - ('callable_infogan', create_callable_infogan_model), - ('acgan', create_acgan_model), - ('callable_acgan', create_callable_acgan_model), - ) - def test_run_helper(self, create_gan_model_fn): - random_seed.set_random_seed(1234) - model = create_gan_model_fn() - loss = train.gan_loss(model) - - g_opt = gradient_descent.GradientDescentOptimizer(1.0) - d_opt = gradient_descent.GradientDescentOptimizer(1.0) - train_ops = train.gan_train_ops(model, loss, g_opt, d_opt) - - final_step = train.gan_train( - train_ops, - logdir='', - hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)]) - self.assertTrue(np.isscalar(final_step)) - self.assertEqual(2, final_step) - - @parameterized.named_parameters( - ('seq_train_steps', train.get_sequential_train_hooks), - ('efficient_seq_train_steps', train.get_joint_train_hooks), - ) - def test_multiple_steps(self, get_hooks_fn_fn): - """Test multiple train steps.""" - train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100) - train_steps = namedtuples.GANTrainSteps( - generator_train_steps=3, discriminator_train_steps=4) - final_step = train.gan_train( - train_ops, - get_hooks_fn=get_hooks_fn_fn(train_steps), - logdir='', - hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)]) - - self.assertTrue(np.isscalar(final_step)) - self.assertEqual(1 + 3 * 10 + 4 * 100, final_step) - - def test_supervisor_run_gan_model_train_ops_multiple_steps(self): - step = training_util.create_global_step() - train_ops = namedtuples.GANTrainOps( - generator_train_op=constant_op.constant(3.0), - discriminator_train_op=constant_op.constant(2.0), - global_step_inc_op=step.assign_add(1)) - train_steps = namedtuples.GANTrainSteps( - 
generator_train_steps=3, discriminator_train_steps=4) - - final_loss = slim_learning.train( - train_op=train_ops, - logdir='', - global_step=step, - number_of_steps=1, - train_step_fn=train.get_sequential_train_steps(train_steps)) - self.assertTrue(np.isscalar(final_loss)) - self.assertEqual(17.0, final_loss) - - @parameterized.named_parameters( - ('gan', create_gan_model), - ('callable_gan', create_callable_gan_model), - ('infogan', create_infogan_model), - ('callable_infogan', create_callable_infogan_model), - ('acgan', create_acgan_model), - ('callable_acgan', create_callable_acgan_model), - ) - def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn): - model = create_gan_model_fn() - loss = train.gan_loss(model) - - g_opt = get_sync_optimizer() - d_opt = get_sync_optimizer() - train_ops = train.gan_train_ops( - model, - loss, - g_opt, - d_opt, - summarize_gradients=True, - colocate_gradients_with_ops=True) - - sequential_train_hooks = train.get_sequential_train_hooks()(train_ops) - self.assertLen(sequential_train_hooks, 4) - sync_opts = [ - hook._sync_optimizer for hook in sequential_train_hooks if - isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)] - self.assertLen(sync_opts, 2) - self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) - - joint_train_hooks = train.get_joint_train_hooks()(train_ops) - self.assertLen(joint_train_hooks, 5) - sync_opts = [ - hook._sync_optimizer for hook in joint_train_hooks if - isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)] - self.assertLen(sync_opts, 2) - self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt))) - - -class PatchGANTest(test.TestCase, parameterized.TestCase): - """Tests that functions work on PatchGAN style output.""" - - @parameterized.named_parameters( - ('gan', create_gan_model), - ('callable_gan', create_callable_gan_model), - ('infogan', create_infogan_model), - ('callable_infogan', create_callable_infogan_model), - ('acgan', create_acgan_model), - ('callable_acgan', create_callable_acgan_model), - ) - def test_patchgan(self, create_gan_model_fn): - """Ensure that patch-based discriminators work end-to-end.""" - random_seed.set_random_seed(1234) - model = create_gan_model_fn() - loss = train.gan_loss(model) - - g_opt = gradient_descent.GradientDescentOptimizer(1.0) - d_opt = gradient_descent.GradientDescentOptimizer(1.0) - train_ops = train.gan_train_ops(model, loss, g_opt, d_opt) - - final_step = train.gan_train( - train_ops, - logdir='', - hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)]) - self.assertTrue(np.isscalar(final_step)) - self.assertEqual(2, final_step) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index 0e8a493e15e..1eead8bff44 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -3,7 +3,7 @@ # For platform specific build config load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library_cc", ) diff --git a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc index c0b40194faf..4988ce6d2fe 100644 --- a/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_collective_executor_mgr.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/core/distributed_runtime/request_id.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/unbounded_work_queue.h" namespace tensorflow { @@ -65,12 +66,12 @@ class RecvBufCall : public CancellableCall { class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { public: - CollectiveRemoteAccessDistributed(const DeviceMgr* dev_mgr, - DeviceResolverInterface* dev_resolver, - WorkerCacheInterface* worker_cache, - int64 step_id, - RemoteMemoryManager* remote_memory_manager) - : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id), + CollectiveRemoteAccessDistributed( + const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver, + std::shared_ptr work_queue, + WorkerCacheInterface* worker_cache, int64 step_id, + RemoteMemoryManager* remote_memory_manager) + : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, work_queue, step_id), worker_cache_(worker_cache), remote_memory_manager_(remote_memory_manager) {} @@ -152,7 +153,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { CollectiveExecutor* GdrCollectiveExecutorMgr::Create(int64 step_id) { CollectiveRemoteAccessDistributed* rma = new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(), - worker_cache_, step_id, + work_queue_, worker_cache_, step_id, remote_memory_manager_); return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_, &gpu_ring_order_); diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index 4744a9ee9a8..51f6201005a 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -163,7 +163,7 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { recv_args, step_id_, parsed.FullKey()); // Record "call" in active_ so that it can be aborted cleanly. 
- RegisterCall(call); + RegisterCall(call, recv_args); // RendezvousMgr already aborted, shouldn't send RPC call any more if (!call->status().ok()) { diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD index f4bed99e2dc..0683a90610b 100644 --- a/tensorflow/contrib/graph_editor/BUILD +++ b/tensorflow/contrib/graph_editor/BUILD @@ -22,7 +22,6 @@ py_library( "util.py", ], srcs_version = "PY2AND3", - tags = ["no_oss"], # b/133250576, deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", @@ -46,7 +45,6 @@ py_library( name = "match", srcs = ["tests/match.py"], srcs_version = "PY2AND3", - tags = ["no_oss"], # b/133250576, deps = [ ":graph_editor_py", "//tensorflow/python:framework_ops", @@ -59,7 +57,6 @@ py_test( srcs = ["tests/util_test.py"], python_version = "PY2", srcs_version = "PY2AND3", - tags = ["no_oss"], # b/133250576, deps = [ ":graph_editor_py", "//tensorflow/python:client_testlib", @@ -73,7 +70,6 @@ py_test( srcs = ["tests/select_test.py"], python_version = "PY2", srcs_version = "PY2AND3", - tags = ["no_oss"], # b/133250576, deps = [ ":graph_editor_py", "//tensorflow/python:client_testlib", @@ -87,7 +83,6 @@ py_test( srcs = ["tests/match_test.py"], python_version = "PY2", srcs_version = "PY2AND3", - tags = ["no_oss"], # b/133250576, deps = [ ":match", "//tensorflow/python:client_testlib", @@ -101,7 +96,6 @@ py_test( srcs = ["tests/subgraph_test.py"], python_version = "PY2", srcs_version = "PY2AND3", - tags = ["no_oss"], # b/133250576, deps = [ ":graph_editor_py", "//tensorflow/python:client_testlib", @@ -115,7 +109,6 @@ py_test( srcs = ["tests/reroute_test.py"], python_version = "PY2", srcs_version = "PY2AND3", - tags = ["no_oss"], # b/133250576, deps = [ ":graph_editor_py", ":match", @@ -130,7 +123,6 @@ py_test( srcs = ["tests/edit_test.py"], python_version = "PY2", srcs_version = "PY2AND3", - tags = ["no_oss"], # b/133250576, deps = [ ":graph_editor_py", ":match", @@ -145,7 +137,6 @@ py_test( srcs = ["tests/transform_test.py"], python_version = "PY2", srcs_version = "PY2AND3", - tags = ["no_oss"], # b/133250576, deps = [ ":graph_editor_py", ":match", diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py index 4b53d182f34..543c1da7e33 100644 --- a/tensorflow/contrib/graph_editor/util.py +++ b/tensorflow/contrib/graph_editor/util.py @@ -19,11 +19,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import re from six import iteritems from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops as tf_array_ops +from tensorflow.python.util.compat import collections_abc __all__ = [ "make_list_of_op", @@ -157,7 +157,7 @@ def transform_tree(tree, fn, iterable_type=tuple): res = tree.__new__(type(tree), (transform_tree(child, fn) for child in tree)) return res - elif isinstance(tree, collections.Sequence): + elif isinstance(tree, collections_abc.Sequence): res = tree.__new__(type(tree)) res.__init__(transform_tree(child, fn) for child in tree) return res diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc index 2bf6097d013..243c2a40298 100644 --- a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc +++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc @@ -31,12 +31,13 @@ class SequenceFileReader { new io::BufferedInputStream(file, kSequenceFileBufferSize)) {} Status 
ReadHeader() { - string version; + tstring version; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &version)); - if (version.substr(0, 3) != "SEQ" || version[3] != 6) { + StringPiece version_view(version); + if (version_view.substr(0, 3) != "SEQ" || version[3] != 6) { return errors::InvalidArgument( "sequence file header must starts with `SEQ6`, received \"", - version.substr(0, 3), static_cast(version[3]), "\""); + version_view.substr(0, 3), static_cast(version[3]), "\""); } TF_RETURN_IF_ERROR(ReadString(&key_class_name_)); TF_RETURN_IF_ERROR(ReadString(&value_class_name_)); @@ -50,7 +51,7 @@ class SequenceFileReader { "' is currently not supported"); } - string buffer; + tstring buffer; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(2, &buffer)); compression_ = buffer[0]; block_compression_ = buffer[1]; @@ -84,12 +85,12 @@ class SequenceFileReader { return Status::OK(); } - Status ReadRecord(string* key, string* value) { + Status ReadRecord(tstring* key, tstring* value) { uint32 length = 0; TF_RETURN_IF_ERROR(ReadUInt32(&length)); if (length == static_cast<uint32>(-1)) { // Sync marker. - string sync_marker; + tstring sync_marker; TF_RETURN_IF_ERROR( input_stream_->ReadNBytes(kSyncMarkerSize, &sync_marker)); if (sync_marker != sync_marker_) { @@ -114,7 +115,7 @@ class SequenceFileReader { return Status::OK(); } - Status ReadString(string* value) { + Status ReadString(tstring* value) { int64 length = 0; TF_RETURN_IF_ERROR(ReadVInt(&length)); if (value == nullptr) { @@ -124,7 +125,7 @@ class SequenceFileReader { } Status ReadUInt32(uint32* value) { - string buffer; + tstring buffer; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(4, &buffer)); *value = ((static_cast<uint32>(buffer[0]) << 24) | static_cast<uint32>(buffer[1]) << 16) | @@ -134,7 +135,7 @@ class SequenceFileReader { } Status ReadVInt(int64* value) { - string buffer; + tstring buffer; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(1, &buffer)); if (buffer[0] >= -112) { *value = static_cast<int64>(buffer[0]); @@ -167,12 +168,12 @@ class SequenceFileReader { private: std::unique_ptr<io::BufferedInputStream> input_stream_; - string key_class_name_; - string value_class_name_; - string sync_marker_; + tstring key_class_name_; + tstring value_class_name_; + tstring sync_marker_; bool compression_; bool block_compression_; - string compression_codec_class_name_; + tstring compression_codec_class_name_; TF_DISALLOW_COPY_AND_ASSIGN(SequenceFileReader); }; class SequenceFileDatasetOp : public DatasetOpKernel { @@ -198,7 +199,7 @@ class SequenceFileDatasetOp : public DatasetOpKernel { std::vector<string> filenames; filenames.reserve(filenames_tensor->NumElements()); for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.push_back(filenames_tensor->flat<string>()(i)); + filenames.push_back(filenames_tensor->flat<tstring>()(i)); } *output = new Dataset(ctx, filenames, output_types_); @@ -233,6 +234,8 @@ class SequenceFileDatasetOp : public DatasetOpKernel { return "SequenceFileDatasetOp::Dataset"; } + Status CheckExternalState() const override { return Status::OK(); } + protected: Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, @@ -256,17 +259,17 @@ class SequenceFileDatasetOp : public DatasetOpKernel { do { // We are currently processing a file, so try to read the next record.
if (reader_) { - string key, value; + tstring key, value; Status status = reader_->ReadRecord(&key, &value); if (!errors::IsOutOfRange(status)) { TF_RETURN_IF_ERROR(status); Tensor key_tensor(ctx->allocator({}), DT_STRING, {}); - key_tensor.scalar<string>()() = key; + key_tensor.scalar<tstring>()() = std::move(key); out_tensors->emplace_back(std::move(key_tensor)); Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); - value_tensor.scalar<string>()() = value; + value_tensor.scalar<tstring>()() = std::move(value); out_tensors->emplace_back(std::move(value_tensor)); *end_of_sequence = false; diff --git a/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc index 4218ec05f2c..41c9a8b1f49 100644 --- a/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.cc @@ -73,7 +73,7 @@ Status BinaryObjectParser::Parse(uint8_t** ptr, } case STRING: { out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({})); - out_tensors->back().scalar<string>()() = ParseString(ptr); + out_tensors->back().scalar<tstring>()() = ParseString(ptr); break; } case DATE: { @@ -150,7 +150,7 @@ Status BinaryObjectParser::Parse(uint8_t** ptr, out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({length})); for (int32_t i = 0; i < length; i++) - out_tensors->back().vec<string>()(i) = ParseString(ptr); + out_tensors->back().vec<tstring>()(i) = ParseString(ptr); break; } case DATE_ARR: { diff --git a/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc index ce8972f1e7f..67a84b99cff 100644 --- a/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.cc @@ -379,7 +379,7 @@ Status IgniteDatasetIterator::LoadNextPage() { Status IgniteDatasetIterator::ReceivePage(int32_t page_size) { remainder_ = page_size; - page_ = std::unique_ptr<uint8_t>(new uint8_t[remainder_]); + page_ = std::unique_ptr<uint8_t[]>(new uint8_t[remainder_]); ptr_ = page_.get(); uint64 start = Env::Default()->NowMicros(); diff --git a/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h index 5868c2cb67f..2e5051105a9 100644 --- a/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_iterator.h @@ -74,7 +74,7 @@ class IgniteDatasetIterator : public DatasetIterator { mutex mutex_; - std::unique_ptr<uint8_t> page_; + std::unique_ptr<uint8_t[]> page_; uint8_t* ptr_; }; diff --git a/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc index e3593ac6c7a..c28dbeae079 100644 --- a/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc +++ b/tensorflow/contrib/ignite/kernels/dataset/ignite_dataset_ops.cc @@ -71,8 +71,8 @@ class IgniteDatasetOp : public DatasetOpKernel { using DatasetOpKernel::DatasetOpKernel; void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - string cache_name = ""; - string host = ""; + tstring cache_name = ""; + tstring host = ""; int32 port = -1; bool local = false; int32 part = -1; @@ -96,17 +96,17 @@ class IgniteDatasetOp : public DatasetOpKernel { const char* env_cert_password = std::getenv("IGNITE_DATASET_CERT_PASSWORD"); if (env_cache_name) { - cache_name = string(env_cache_name); + cache_name =
env_cache_name; } else { - OP_REQUIRES_OK(ctx, data::ParseScalarArgument<std::string>(ctx, "cache_name", - &cache_name)); + OP_REQUIRES_OK(ctx, data::ParseScalarArgument<tstring>(ctx, "cache_name", + &cache_name)); } if (env_host) { - host = string(env_host); + host = env_host; } else { OP_REQUIRES_OK(ctx, - data::ParseScalarArgument<std::string>(ctx, "host", &host)); + data::ParseScalarArgument<tstring>(ctx, "host", &host)); } if (env_port) { @@ -145,13 +145,13 @@ class IgniteDatasetOp : public DatasetOpKernel { ctx, data::ParseScalarArgument<int32>(ctx, "page_size", &page_size)); } - if (env_username) username = string(env_username); + if (env_username) username = env_username; - if (env_password) password = string(env_password); + if (env_password) password = env_password; - if (env_certfile) certfile = string(env_certfile); + if (env_certfile) certfile = env_certfile; - if (env_keyfile) keyfile = string(env_keyfile); + if (env_keyfile) keyfile = env_keyfile; if (env_cert_password) cert_password = string(env_cert_password); diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index 200e3476a9e..4b14b9e08cf 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -111,9 +111,6 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], - tags = [ - "notap", # b/136286905 - ], ) tf_custom_op_library( diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.cc b/tensorflow/contrib/image/kernels/segmentation_ops.cc index 93722896233..b9d615613cc 100644 --- a/tensorflow/contrib/image/kernels/segmentation_ops.cc +++ b/tensorflow/contrib/image/kernels/segmentation_ops.cc @@ -128,7 +128,7 @@ struct ImageConnectedComponentsFunctor { // Connected components (arguably) make sense for number, bool, and string types TF_CALL_NUMBER_TYPES(REGISTER_IMAGE_CONNECTED_COMPONENTS); TF_CALL_bool(REGISTER_IMAGE_CONNECTED_COMPONENTS); -TF_CALL_string(REGISTER_IMAGE_CONNECTED_COMPONENTS); +TF_CALL_tstring(REGISTER_IMAGE_CONNECTED_COMPONENTS); #undef REGISTER_IMAGE_CONNECTED_COMPONENTS // TODO(ringwalt): Implement on GPU. We probably want to stick to the original diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index 05ba9155c40..96f6af2ac51 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -506,7 +506,7 @@ def connected_components(images): # constructing multiple additional large tensors. components_flat = array_ops.reshape(components, [-1]) unique_ids, id_index = array_ops.unique(components_flat) - id_is_zero = array_ops.where(math_ops.equal(unique_ids, 0))[:, 0] + id_is_zero = array_ops.where_v2(math_ops.equal(unique_ids, 0))[:, 0] # Map each nonzero id to consecutive values. nonzero_consecutive_ids = math_ops.range( array_ops.shape(unique_ids)[0] - array_ops.shape(id_is_zero)[0]) + 1 diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index 2b0bcf64019..dfc6af3e558 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -48,7 +48,7 @@ def single_image_random_dot_stereograms(depth_values, corrupt the encode 3-D data within the image. Based upon [this - paper](http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper).
+ paper](https://www.cs.waikato.ac.nz/~ihw/papers/94-HWT-SI-IHW-SIRDS-paper.pdf). This outputs a SIRDS image as picture_out.png: diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD index 777399184e8..4fd9e2c5b95 100644 --- a/tensorflow/contrib/input_pipeline/BUILD +++ b/tensorflow/contrib/input_pipeline/BUILD @@ -12,7 +12,7 @@ load( "tf_kernel_library", ) load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_kernel_tests_linkstatic", ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") diff --git a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc index 886f6798150..d5da76a753f 100644 --- a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc +++ b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc @@ -30,7 +30,7 @@ class ObtainNextOp : public OpKernel { const Tensor* list; OP_REQUIRES_OK(ctx, ctx->input("list", &list)); int64 num_elements = list->NumElements(); - auto list_flat = list->flat<string>(); + auto list_flat = list->flat<tstring>(); // Allocate output. Tensor* output_tensor = nullptr; @@ -48,7 +48,7 @@ class ObtainNextOp : public OpKernel { *pos = (*pos + 1) % num_elements; // Assign value to output. - output_tensor->scalar<string>()() = list_flat(*pos); + output_tensor->scalar<tstring>()() = list_flat(*pos); } }; diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index bb0d4c178dc..a3875bb4a19 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -33,15 +33,15 @@ class KafkaDatasetOp : public DatasetOpKernel { std::vector<string> topics; topics.reserve(topics_tensor->NumElements()); for (int i = 0; i < topics_tensor->NumElements(); ++i) { - topics.push_back(topics_tensor->flat<string>()(i)); + topics.push_back(topics_tensor->flat<tstring>()(i)); } std::string servers = ""; OP_REQUIRES_OK( - ctx, data::ParseScalarArgument<std::string>(ctx, "servers", &servers)); + ctx, data::ParseScalarArgument<tstring>(ctx, "servers", &servers)); std::string group = ""; - OP_REQUIRES_OK( - ctx, data::ParseScalarArgument<std::string>(ctx, "group", &group)); + OP_REQUIRES_OK(ctx, + data::ParseScalarArgument<tstring>(ctx, "group", &group)); bool eof = false; OP_REQUIRES_OK(ctx, data::ParseScalarArgument<bool>(ctx, "eof", &eof)); int64 timeout = -1; @@ -128,9 +128,9 @@ class KafkaDatasetOp : public DatasetOpKernel { if (message->err() == RdKafka::ERR_NO_ERROR) { // Produce the line as output.
Tensor line_tensor(cpu_allocator(), DT_STRING, {}); - line_tensor.scalar<string>()() = - std::string(static_cast<const char*>(message->payload()), - message->len()); + line_tensor.scalar<tstring>()().assign( + static_cast<const char*>(message->payload()), + message->len()); out_tensors->emplace_back(std::move(line_tensor)); *end_of_sequence = false; // Sync offset diff --git a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc index 8919d5efedf..88d1aa1bd22 100644 --- a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc +++ b/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc @@ -148,11 +148,11 @@ class KinesisDatasetOp : public DatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { std::string stream = ""; - OP_REQUIRES_OK( - ctx, data::ParseScalarArgument<std::string>(ctx, "stream", &stream)); + OP_REQUIRES_OK(ctx, + data::ParseScalarArgument<tstring>(ctx, "stream", &stream)); std::string shard = ""; - OP_REQUIRES_OK( - ctx, data::ParseScalarArgument<std::string>(ctx, "shard", &shard)); + OP_REQUIRES_OK(ctx, + data::ParseScalarArgument<tstring>(ctx, "shard", &shard)); bool read_indefinitely = true; OP_REQUIRES_OK(ctx, data::ParseScalarArgument<bool>( ctx, "read_indefinitely", &read_indefinitely)); diff --git a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py index 1783a07fac9..3a257d81887 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py @@ -21,11 +21,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import functools import re from tensorflow.python.util import tf_inspect +from tensorflow.python.util.compat import collections_abc # used for register_type_abbreviation and _type_repr below.
_TYPE_ABBREVIATIONS = {} @@ -114,7 +114,7 @@ class Sequence(_SingleArgumentType): """ def __instancecheck__(self, instance): - return (isinstance(instance, collections.Sequence) and + return (isinstance(instance, collections_abc.Sequence) and all(isinstance(x, self._type) for x in instance)) @@ -130,9 +130,9 @@ class Collection(_SingleArgumentType): """ def __instancecheck__(self, instance): - return (isinstance(instance, collections.Iterable) and - isinstance(instance, collections.Sized) and - isinstance(instance, collections.Container) and + return (isinstance(instance, collections_abc.Iterable) and + isinstance(instance, collections_abc.Sized) and + isinstance(instance, collections_abc.Container) and all(isinstance(x, self._type) for x in instance)) @@ -157,7 +157,7 @@ class Mapping(_TwoArgumentType): def __instancecheck__(self, instance): key_type, value_type = self._types # pylint: disable=unbalanced-tuple-unpacking - return (isinstance(instance, collections.Mapping) and + return (isinstance(instance, collections_abc.Mapping) and all(isinstance(k, key_type) for k in instance.keys()) and all(isinstance(k, value_type) for k in instance.values())) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py index b0961e5b3a2..394254cbd90 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/core.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py @@ -41,11 +41,12 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util.compat import collections_abc # pylint: disable=invalid-name # Types coercible to Axis.labels -# We use this instead of collections.Sequence to exclude strings. +# We use this instead of collections_abc.Sequence to exclude strings. LabelsLike = tc.Union(np.ndarray, range, list, tuple) # Types coercible to a tf.compat.v1.Dimension @@ -195,7 +196,7 @@ def as_axis(axis_data): return axis -class Axes(collections.Mapping): +class Axes(collections_abc.Mapping): """Axis names and indices for a tensor. It is an ordered mapping, with keys given by axis name and values given @@ -719,7 +720,7 @@ def transpose(labeled_tensor, axis_order=None, name=None): @tc.accepts(LabeledTensorLike, tc.Collection( tc.Union(string_types, - tc.Tuple(string_types, collections.Hashable))), + tc.Tuple(string_types, collections_abc.Hashable))), tc.Optional(string_types)) def expand_dims(labeled_tensor, axes, name=None): """Insert dimensions of size 1. @@ -1055,7 +1056,7 @@ def align(labeled_tensor_0, labeled_tensor_1, name=None): @tc.returns(types.FunctionType) -@tc.accepts(string_types, collections.Callable) +@tc.accepts(string_types, collections_abc.Callable) def define_unary_op(op_name, elementwise_function): """Define a unary operation for labeled tensors. @@ -1124,7 +1125,7 @@ sigmoid = define_unary_op('sigmoid', math_ops.sigmoid) @tc.returns(types.FunctionType) -@tc.accepts(string_types, collections.Callable) +@tc.accepts(string_types, collections_abc.Callable) def define_binary_op(op_name, elementwise_function): """Define a binary operation that broadcasts labeled tensors. 
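The `_typecheck.py` and `core.py` hunks above (and the `ops.py` hunk that follows) all make the same substitution: the abstract base classes `Mapping`, `Sequence`, `Hashable`, and `Callable` are resolved through `collections_abc` instead of the top-level `collections` module, which newer Python 3 releases no longer re-export. A minimal sketch of the compatibility pattern behind `tensorflow.python.util.compat.collections_abc`; the helper function below is illustrative only and not part of this change:

```python
# Illustrative sketch: resolve the collections ABCs so the same import
# works on both Python 2 and Python 3, which is what the compat module
# referenced in these hunks provides.
try:
  import collections.abc as collections_abc  # Python 3.3+
except ImportError:
  import collections as collections_abc      # Python 2 fallback


def is_nonstring_sequence(value):
  """Returns True for list/tuple-like values but not for strings."""
  return (isinstance(value, collections_abc.Sequence) and
          not isinstance(value, (str, bytes)))


print(is_nonstring_sequence([1, 2, 3]))  # True
print(is_nonstring_sequence("abc"))      # False
```
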
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index a04e3772799..35ab141a18f 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import types import numpy as np @@ -34,6 +33,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import numerics from tensorflow.python.ops import random_ops from tensorflow.python.training import input # pylint: disable=redefined-builtin +from tensorflow.python.util.compat import collections_abc @tc.returns(core.LabeledTensor) @@ -52,7 +52,7 @@ def _gather_1d_on_axis(labeled_tensor, indexer, axis, name=None): @tc.returns(core.LabeledTensor) @tc.accepts(core.LabeledTensorLike, tc.Mapping(string_types, - tc.Union(slice, collections.Hashable, list)), + tc.Union(slice, collections_abc.Hashable, list)), tc.Optional(string_types)) def select(labeled_tensor, selection, name=None): """Slice out a subset of the tensor. @@ -111,8 +111,8 @@ def select(labeled_tensor, selection, name=None): slices[axis_name] = slice(start, stop) # Needs to be after checking for slices, since slice objects claim to be - # instances of collections.Hashable but hash() on them fails. - elif isinstance(value, collections.Hashable): + # instances of collections_abc.Hashable but hash() on them fails. + elif isinstance(value, collections_abc.Hashable): slices[axis_name] = axis.index(value) elif isinstance(value, list): @@ -400,7 +400,7 @@ def rename_axis(labeled_tensor, existing_name, new_name, name=None): @tc.returns(tc.List(core.LabeledTensor)) -@tc.accepts(string_types, collections.Callable, int, bool, +@tc.accepts(string_types, collections_abc.Callable, int, bool, tc.Collection(core.LabeledTensorLike), bool, tc.Optional(string_types)) def _batch_helper(default_name, @@ -606,7 +606,7 @@ def random_crop(labeled_tensor, shape_map, seed=None, name=None): # TODO(shoyer): Allow the user to select the axis over which to map. @tc.returns(core.LabeledTensor) -@tc.accepts(collections.Callable, core.LabeledTensorLike, +@tc.accepts(collections_abc.Callable, core.LabeledTensorLike, tc.Optional(string_types)) def map_fn(fn, labeled_tensor, name=None): """Map on the list of tensors unpacked from labeled_tensor. @@ -661,7 +661,7 @@ def map_fn(fn, labeled_tensor, name=None): @tc.returns(core.LabeledTensor) -@tc.accepts(collections.Callable, core.LabeledTensorLike, +@tc.accepts(collections_abc.Callable, core.LabeledTensorLike, core.LabeledTensorLike, tc.Optional(string_types)) def foldl(fn, labeled_tensor, initial_value, name=None): """Left fold on the list of tensors unpacked from labeled_tensor. @@ -754,7 +754,7 @@ def squeeze(labeled_tensor, axis_names=None, name=None): # pylint: disable=invalid-name ReduceAxis = tc.Union(string_types, - tc.Tuple(string_types, collections.Hashable)) + tc.Tuple(string_types, collections_abc.Hashable)) ReduceAxes = tc.Optional(tc.Union(ReduceAxis, tc.Collection(ReduceAxis))) # pylint: enable=invalid-name @@ -876,7 +876,7 @@ def matmul(a, b, name=None): @tc.returns(types.FunctionType) -@tc.accepts(string_types, collections.Callable) +@tc.accepts(string_types, collections_abc.Callable) def define_reduce_op(op_name, reduce_fn): """Define a reduction op for labeled tensors. 
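Both labeled_tensor files lean on the `tc.accepts` and `tc.returns` decorators, which run `isinstance` checks against these ABCs at call time; that is why references such as `collections.Callable` and `collections.Hashable` had to move to `collections_abc` as well. A simplified, hypothetical decorator in the same spirit (not the actual `_typecheck` implementation) is sketched below:

```python
# Simplified sketch of an @accepts-style decorator for positional
# arguments; the real labeled_tensor._typecheck module is more general.
import functools

try:
  import collections.abc as collections_abc
except ImportError:
  import collections as collections_abc


def accepts(*arg_types):
  """Raises TypeError when positional arguments have unexpected types."""
  def decorator(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      for value, expected in zip(args, arg_types):
        if not isinstance(value, expected):
          raise TypeError('%r is not an instance of %r' % (value, expected))
      return fn(*args, **kwargs)
    return wrapper
  return decorator


@accepts(str, collections_abc.Callable)
def define_unary_op(op_name, elementwise_function):
  return op_name, elementwise_function
```
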
diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 8e410006c16..6010b072418 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -77,6 +77,8 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", visibility = [ "//learning/brain:__subpackages__", + "//learning/lib/ami/simple_ml/link_other_ml_tools/tensorflow:__subpackages__", + "//storage/d/analysis/prefetch:__pkg__", "//tensorflow:__subpackages__", "//tensorflow_model_optimization:__subpackages__", "//third_party/py/tf_slim:__subpackages__", @@ -154,6 +156,7 @@ cuda_py_test( "//tensorflow/python:variables", "//tensorflow/python/ops/losses:losses", ], + xla_enable_strict_auto_jit = False, ) py_test( diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc index ee4b0373ef7..3fe4bd48748 100644 --- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc +++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc @@ -78,16 +78,16 @@ template <> int64 SparseTensorColumn::Feature(int64 batch, int64 n) const { const int64 start = feature_start_indices_[batch]; if (DT_STRING == values_.dtype()) - return Fingerprint64(values_.vec().data()[start + n]); + return Fingerprint64(values_.vec().data()[start + n]); return values_.vec().data()[start + n]; } // InternalType is string or StringPiece when using StringCrosser. template <> -string SparseTensorColumn::Feature(int64 batch, int64 n) const { +tstring SparseTensorColumn::Feature(int64 batch, int64 n) const { const int64 start = feature_start_indices_[batch]; if (DT_STRING == values_.dtype()) - return values_.vec().data()[start + n]; + return values_.vec().data()[start + n]; return std::to_string(values_.vec().data()[start + n]); } @@ -95,7 +95,7 @@ template <> StringPiece SparseTensorColumn::Feature(int64 batch, int64 n) const { const int64 start = feature_start_indices_[batch]; - return values_.vec().data()[start + n]; + return values_.vec().data()[start + n]; } // A column that is backed by a dense tensor. @@ -118,21 +118,21 @@ class DenseTensorColumn : public ColumnInterface { template <> int64 DenseTensorColumn::Feature(int64 batch, int64 n) const { if (DT_STRING == tensor_.dtype()) - return Fingerprint64(tensor_.matrix()(batch, n)); + return Fingerprint64(tensor_.matrix()(batch, n)); return tensor_.matrix()(batch, n); } // Internal type is string or StringPiece when using StringCrosser. template <> -string DenseTensorColumn::Feature(int64 batch, int64 n) const { - if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); +tstring DenseTensorColumn::Feature(int64 batch, int64 n) const { + if (DT_STRING == tensor_.dtype()) return tensor_.matrix()(batch, n); return std::to_string(tensor_.matrix()(batch, n)); } template <> StringPiece DenseTensorColumn::Feature(int64 batch, int64 n) const { - return tensor_.matrix()(batch, n); + return tensor_.matrix()(batch, n); } // Updates Output tensors with sparse crosses. 
@@ -310,7 +310,7 @@ struct CrossTraits; template struct CrossTraits { typedef StringCrosser Crosser; - typedef OutputUpdater Updater; + typedef OutputUpdater Updater; }; template <> @@ -598,20 +598,20 @@ class SparseFeatureCrossOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross") .Device(DEVICE_CPU) - .TypeConstraint("out_type") - .TypeConstraint("internal_type"), + .TypeConstraint("out_type") + .TypeConstraint("internal_type"), SparseFeatureCrossOp); REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross") .Device(DEVICE_CPU) - .TypeConstraint("out_type") + .TypeConstraint("out_type") .TypeConstraint("internal_type"), - SparseFeatureCrossOp); + SparseFeatureCrossOp); REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross") .Device(DEVICE_CPU) .TypeConstraint("out_type") - .TypeConstraint("internal_type"), + .TypeConstraint("internal_type"), SparseFeatureCrossOp); REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross") @@ -624,20 +624,20 @@ REGISTER_KERNEL_BUILDER(Name("SparseFeatureCross") // crosses features. REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2") .Device(DEVICE_CPU) - .TypeConstraint("out_type") - .TypeConstraint("internal_type"), + .TypeConstraint("out_type") + .TypeConstraint("internal_type"), SparseFeatureCrossOp); REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2") .Device(DEVICE_CPU) - .TypeConstraint("out_type") + .TypeConstraint("out_type") .TypeConstraint("internal_type"), - SparseFeatureCrossOp); + SparseFeatureCrossOp); REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2") .Device(DEVICE_CPU) .TypeConstraint("out_type") - .TypeConstraint("internal_type"), + .TypeConstraint("internal_type"), SparseFeatureCrossOp); REGISTER_KERNEL_BUILDER(Name("SparseFeatureCrossV2") diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index e47a52a7072..385dcc0d80a 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -155,6 +155,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation from tensorflow.python.util import nest +from tensorflow.python.util.compat import collections_abc # Imports the core `InputLayer` symbol in contrib during development. InputLayer = fc_core.InputLayer # pylint: disable=invalid-name @@ -1403,7 +1404,7 @@ def shared_embedding_columns(sparse_id_columns, least one element of `sparse_id_columns` is not a `SparseColumn` or a `WeightedSparseColumn`. 
""" - if (not isinstance(sparse_id_columns, collections.Sequence) or + if (not isinstance(sparse_id_columns, collections_abc.Sequence) or isinstance(sparse_id_columns, six.string_types)): raise TypeError( "sparse_id_columns must be a non-string sequence (ex: list or tuple) " diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 65e8d75e5c5..d48edc027a2 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -25,6 +25,7 @@ py_library( srcs_version = "PY2AND3", visibility = [ "//learning/brain:__subpackages__", + "//storage/d/analysis/prefetch:__pkg__", "//tensorflow:__subpackages__", "//video/youtube/personalization:__subpackages__", ], diff --git a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py index 99f22d182cd..a15bbce515b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py @@ -19,12 +19,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import os import numpy as np import six +from tensorflow.python.util.compat import collections_abc + def _pprint(d): return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()]) @@ -55,7 +56,7 @@ class _BaseEstimator(object): for key in param_names: value = getattr(self, key, None) - if isinstance(value, collections.Callable): + if isinstance(value, collections_abc.Callable): continue # XXX: should we rather test if instance of estimator? diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 5ce5c02cc63..fcabbf69425 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -162,7 +162,7 @@ class ModelFnOps( loss_shape = loss.get_shape() if loss_shape.num_elements() not in (None, 1): raise ValueError('Loss must be scalar: %s.' % loss) - if not loss_shape.is_compatible_with(tensor_shape.scalar()): + if not loss_shape.is_compatible_with(tensor_shape.TensorShape([])): loss = array_ops.reshape(loss, []) # Validate predictions. diff --git a/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc index 720c74e3de5..f35453f267e 100644 --- a/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc +++ b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc @@ -36,7 +36,7 @@ class DecodeLibsvmOp : public OpKernel { void Compute(OpKernelContext* ctx) override { const Tensor* input_tensor; OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); - const auto& input_flat = input_tensor->flat(); + const auto& input_flat = input_tensor->flat(); Tensor* label_tensor; OP_REQUIRES_OK( diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index fa8dad938d7..8e75fcb666a 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -74,7 +74,7 @@ HOST_GENDIR := $(MAKEFILE_DIR)/gen/host_obj/ # Settings for the host compiler. 
HOST_CXX := $(CC_PREFIX) gcc -HOST_CXXFLAGS := --std=c++11 +HOST_CXXFLAGS := --std=c++14 HOST_LDOPTS := ifeq ($(HAS_GEN_HOST_PROTOC),true) HOST_LDOPTS += -L$(MAKEFILE_DIR)/gen/protobuf-host/lib @@ -185,7 +185,7 @@ ifneq ($(TARGET),ANDROID) OPTFLAGS += -march=native endif -CXXFLAGS := --std=c++11 -DIS_SLIM_BUILD -fno-exceptions -DNDEBUG $(OPTFLAGS) +CXXFLAGS := --std=c++14 -DIS_SLIM_BUILD -fno-exceptions -DNDEBUG $(OPTFLAGS) LDFLAGS := \ -L/usr/local/lib DEPFLAGS = -MT $@ -MMD -MP -MF $(DEPDIR)/$*.Td @@ -416,7 +416,7 @@ $(MARCH_OPTION) \ ifeq ($(BUILD_FOR_TEGRA),1) NVCC := $(JETPACK)/cuda/bin/nvcc - NVCCFLAGS := -x=cu -D__CUDACC__ -DNVCC -DANDROID_TEGRA -ccbin $(NDK_ROOT)/toolchains/$(TOOLCHAIN)/prebuilt/$(ANDROID_HOST_OS_ARCH)/bin/$(BIN_PREFIX)-g++ --std c++11 --expt-relaxed-constexpr -m64 -gencode arch=compute_53,\"code=sm_53\" -gencode arch=compute_62,\"code=sm_62\" -DEIGEN_AVOID_STL_ARRAY -DTENSORFLOW_USE_EIGEN_THREADPOOL -DLANG_CXX11 -DEIGEN_HAS_C99_MATH -DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=5.3 + NVCCFLAGS := -x=cu -D__CUDACC__ -DNVCC -DANDROID_TEGRA -ccbin $(NDK_ROOT)/toolchains/$(TOOLCHAIN)/prebuilt/$(ANDROID_HOST_OS_ARCH)/bin/$(BIN_PREFIX)-g++ --std c++14 --expt-relaxed-constexpr -m64 -gencode arch=compute_53,\"code=sm_53\" -gencode arch=compute_62,\"code=sm_62\" -DEIGEN_AVOID_STL_ARRAY -DTENSORFLOW_USE_EIGEN_THREADPOOL -DLANG_CXX14 -DEIGEN_HAS_C99_MATH -DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=5.3 CXXFLAGS4NVCC =\ -DIS_SLIM_BUILD \ -DANDROID_TEGRA \ @@ -433,7 +433,7 @@ $(MARCH_OPTION) \ -DANDROID_TEGRA \ -DEIGEN_AVOID_STL_ARRAY \ -DEIGEN_HAS_C99_MATH \ --DLANG_CXX11 -DTENSORFLOW_USE_EIGEN_THREADPOOL -DTF_EXTRA_CUDA_CAPABILITIES=5.3 +-DLANG_CXX14 -DTENSORFLOW_USE_EIGEN_THREADPOOL -DTF_EXTRA_CUDA_CAPABILITIES=5.3 INCLUDES += \ -Itensorflow/core/kernels \ @@ -655,8 +655,7 @@ $(wildcard tensorflow/core/util/*/*.cc) \ $(wildcard tensorflow/contrib/makefile/downloads/double_conversion/double-conversion/*.cc) \ tensorflow/core/profiler/internal/profiler_interface.cc \ tensorflow/core/profiler/internal/traceme_recorder.cc \ -tensorflow/core/profiler/lib/profiler_session.cc \ -tensorflow/core/profiler/lib/traceme.cc \ +$(wildcard tensorflow/core/profiler/lib/*.cc) \ tensorflow/core/util/version_info.cc # Remove duplicates (for version_info.cc) CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) @@ -860,7 +859,7 @@ $(OBJDIR)%.o: %.cc | $(PBT_GEN_FILES) $(OBJDIR)%.o: %.c @mkdir -p $(dir $@) @mkdir -p $(dir $(DEPDIR)$*) - $(CXX) $(patsubst --std=c++11,--std=c99, $(CXXFLAGS)) -x c $(DEPFLAGS) \ + $(CXX) $(patsubst --std=c++14,--std=c99, $(CXXFLAGS)) -x c $(DEPFLAGS) \ $(INCLUDES) -c $< -o $@ @mv -f $(DEPDIR)/$*.Td $(DEPDIR)/$*.d diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md index 1293e59cbcb..7ace5d970ac 100644 --- a/tensorflow/contrib/makefile/README.md +++ b/tensorflow/contrib/makefile/README.md @@ -87,9 +87,11 @@ need to install the standalone toolchain, however. Assign your NDK location to $NDK_ROOT: ```bash -export NDK_ROOT=/absolute/path/to/NDK/android-ndk-rxxx/ +export NDK_ROOT=/absolute/path/to/NDK/android-ndk-r14b ``` +Note : libtensorflow-core.a cannot be compiled with any ndk version above r14b. 
+ Download the graph if you haven't already: ```bash diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 1feca44f6e5..6cf1145021c 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -27,9 +27,9 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'https://bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -GEMMLOWP_URL="$(grep -o 'http://mirror.tensorflow.org/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" -NSYNC_URL="$(grep -o 'http://mirror.tensorflow.org/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +NSYNC_URL="$(grep -o 'https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" # Note: The protobuf repo needs to be cloned due to its submodules. # These variables contain the GitHub repo and the sha, from `tensorflow/workspace.bzl`, @@ -37,7 +37,7 @@ NSYNC_URL="$(grep -o 'http://mirror.tensorflow.org/github.com/google/nsync/.*tar readonly PROTOBUF_REPO="https://github.com/protocolbuffers/protobuf.git" readonly PROTOBUF_TAG="$(grep -o 'https://github.com/protocolbuffers/protobuf/archive/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1 | awk '{print substr($0, index($0, "archive") + 8, index($0, "tar") - index($0, "archive") - 9) }')" -# TODO (yongtang): Replace the following with 'http://mirror.tensorflow.org/github.com/google/re2/.*tar\.gz' once +# TODO (yongtang): Replace the following with 'https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/re2/.*tar\.gz' once # the archive has been propagated in mirror.tensorflow.org. RE2_URL="$(grep -o 'https://github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" FFT2D_URL="$(grep -o 'http.*fft2d\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" @@ -46,8 +46,8 @@ ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_ CUB_URL="$(grep -o 'https.*cub/archive.*zip' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" # Required for TensorFlow Lite Flex runtime. -FARMHASH_URL="http://mirror.tensorflow.org/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" -FLATBUFFERS_URL="http://mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.11.0.tar.gz" +FARMHASH_URL="https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" +FLATBUFFERS_URL="https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v1.11.0.tar.gz" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. @@ -140,7 +140,7 @@ replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#s replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DATA );#static uint64x2_t p2ul_CONJ_XOR;// = vld1q_u64( p2ul_conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" # TODO(satok): Remove this once protobuf/autogen.sh is fixed. 
-replace_by_sed 's#https://googlemock.googlecode.com/files/gmock-1.7.0.zip#http://download.tensorflow.org/deps/gmock-1.7.0.zip#' \ +replace_by_sed 's#https://googlemock.googlecode.com/files/gmock-1.7.0.zip#https://storage.googleapis.com/download.tensorflow.org/deps/gmock-1.7.0.zip#' \ "${DOWNLOADS_DIR}/protobuf/autogen.sh" cat "third_party/eigen3/gebp_neon.patch" | patch "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/products/GeneralBlockPanelKernel.h" diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt index d7ad266f678..95f2d186dc5 100644 --- a/tensorflow/contrib/makefile/proto_text_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt @@ -30,13 +30,9 @@ tensorflow/core/lib/random/distribution_sampler.cc tensorflow/core/lib/random/random.cc tensorflow/core/lib/random/simple_philox.cc tensorflow/core/lib/random/weighted_picker.cc -tensorflow/core/lib/strings/numbers.cc tensorflow/core/lib/strings/ordered_code.cc tensorflow/core/lib/strings/proto_text_util.cc -tensorflow/core/lib/strings/scanner.cc -tensorflow/core/lib/strings/str_util.cc tensorflow/core/lib/strings/strcat.cc -tensorflow/core/lib/strings/stringprintf.cc tensorflow/core/lib/wav/wav_io.cc tensorflow/core/platform/cpu_info.cc tensorflow/core/platform/default/logging.cc @@ -44,9 +40,9 @@ tensorflow/core/platform/default/mutex.cc tensorflow/core/platform/default/tracing.cc tensorflow/core/platform/denormal.cc tensorflow/core/platform/env.cc -tensorflow/core/platform/env_time.cc tensorflow/core/platform/file_system.cc tensorflow/core/platform/file_system_helper.cc +tensorflow/core/platform/numbers.cc tensorflow/core/platform/posix/env.cc tensorflow/core/platform/posix/env_time.cc tensorflow/core/platform/posix/error.cc @@ -55,7 +51,10 @@ tensorflow/core/platform/posix/port.cc tensorflow/core/platform/posix/posix_file_system.cc tensorflow/core/platform/protobuf.cc tensorflow/core/platform/protobuf_util.cc +tensorflow/core/platform/scanner.cc tensorflow/core/platform/setround.cc +tensorflow/core/platform/stringprintf.cc +tensorflow/core/platform/str_util.cc tensorflow/core/platform/tensor_coding.cc tensorflow/core/platform/tracing.cc tensorflow/tools/proto_text/gen_proto_text_functions.cc diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index e284353f2b0..73e19c0814a 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -129,6 +129,7 @@ tensorflow/core/kernels/function_ops.cc tensorflow/core/kernels/fused_batch_norm_op.cc tensorflow/core/kernels/fused_eigen_output_kernels.cc tensorflow/core/kernels/gather_functor.cc +tensorflow/core/kernels/gather_functor_batched.cc tensorflow/core/kernels/gather_nd_op.cc tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc @@ -245,6 +246,7 @@ tensorflow/core/kernels/slice_op_cpu_impl_4.cc tensorflow/core/kernels/slice_op_cpu_impl_5.cc tensorflow/core/kernels/slice_op_cpu_impl_6.cc tensorflow/core/kernels/slice_op_cpu_impl_7.cc +tensorflow/core/kernels/slice_op_cpu_impl_8.cc tensorflow/core/kernels/softmax_op.cc tensorflow/core/kernels/softplus_op.cc tensorflow/core/kernels/softsign_op.cc @@ -273,6 +275,7 @@ tensorflow/core/kernels/strided_slice_op_inst_4.cc tensorflow/core/kernels/strided_slice_op_inst_5.cc tensorflow/core/kernels/strided_slice_op_inst_6.cc tensorflow/core/kernels/strided_slice_op_inst_7.cc +tensorflow/core/kernels/strided_slice_op_inst_8.cc 
tensorflow/core/kernels/string_join_op.cc tensorflow/core/kernels/string_util.cc tensorflow/core/kernels/tensor_array.cc @@ -297,6 +300,46 @@ tensorflow/core/kernels/variable_ops.cc tensorflow/core/kernels/where_op.cc tensorflow/core/kernels/xent_op.cc tensorflow/core/kernels/xsmm_conv2d.cc +tensorflow/core/kernels/data/batch_dataset_op.cc +tensorflow/core/kernels/data/cache_dataset_ops.cc +tensorflow/core/kernels/data/cache_ops.cc +tensorflow/core/kernels/data/captured_function.cc +tensorflow/core/kernels/data/concatenate_dataset_op.cc +tensorflow/core/kernels/data/dataset_utils.cc +tensorflow/core/kernels/data/filter_dataset_op.cc +tensorflow/core/kernels/data/flat_map_dataset_op.cc +tensorflow/core/kernels/data/generator_dataset_op.cc +tensorflow/core/kernels/data/interleave_dataset_op.cc +tensorflow/core/kernels/data/iterator_ops.cc +tensorflow/core/kernels/data/map_dataset_op.cc +tensorflow/core/kernels/data/map_defun_op.cc +tensorflow/core/kernels/data/model_dataset_op.cc +tensorflow/core/kernels/data/multi_device_iterator_ops.cc +tensorflow/core/kernels/data/name_utils.cc +tensorflow/core/kernels/data/optional_ops.cc +tensorflow/core/kernels/data/optional_ops.cu.cc +tensorflow/core/kernels/data/padded_batch_dataset_op.cc +tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +tensorflow/core/kernels/data/parallel_map_dataset_op.cc +tensorflow/core/kernels/data/parallel_map_iterator.cc +tensorflow/core/kernels/data/prefetch_autotuner.cc +tensorflow/core/kernels/data/prefetch_dataset_op.cc +tensorflow/core/kernels/data/random_seed_ops.cc +tensorflow/core/kernels/data/range_dataset_op.cc +tensorflow/core/kernels/data/repeat_dataset_op.cc +tensorflow/core/kernels/data/shard_dataset_op.cc +tensorflow/core/kernels/data/shuffle_dataset_op.cc +tensorflow/core/kernels/data/single_threaded_executor.cc +tensorflow/core/kernels/data/skip_dataset_op.cc +tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +tensorflow/core/kernels/data/stats_utils.cc +tensorflow/core/kernels/data/take_dataset_op.cc +tensorflow/core/kernels/data/tensor_dataset_op.cc +tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +tensorflow/core/kernels/data/unbounded_thread_pool.cc +tensorflow/core/kernels/data/window_dataset.cc +tensorflow/core/kernels/data/window_dataset_op.cc +tensorflow/core/kernels/data/zip_dataset_op.cc tensorflow/core/ops/array_grad.cc tensorflow/core/ops/array_ops.cc tensorflow/core/ops/audio_ops.cc diff --git a/tensorflow/contrib/memory_stats/BUILD b/tensorflow/contrib/memory_stats/BUILD index 352b2d61084..765c93b06e5 100644 --- a/tensorflow/contrib/memory_stats/BUILD +++ b/tensorflow/contrib/memory_stats/BUILD @@ -102,4 +102,5 @@ cuda_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", ], + xla_enable_strict_auto_jit = False, ) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index eae04c7ba3e..e46263b48a6 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1161,8 +1161,9 @@ def streaming_dynamic_auc(labels, and performing the final calculation using all of the concatenated values. Args: - labels: A `Tensor` of ground truth labels with the same shape as `labels` - and with values of 0 or 1 whose values are castable to `int64`. + labels: A `Tensor` of ground truth labels with the same shape as + `predictions` and with values of 0 or 1 whose values are castable to + `int64`. 
predictions: A `Tensor` of predictions whose values are castable to `float64`. Will be flattened into a 1-D `Tensor`. curve: The name of the curve for which to compute AUC, 'ROC' for the @@ -3640,7 +3641,8 @@ def streaming_concat(values, next_shape = array_ops.stack([next_size] + fixed_shape) new_value = array_ops.zeros(next_shape, dtype=values.dtype) old_value = array.value() - assign_op = state_ops.assign(array, new_value, validate_shape=False) + with ops.control_dependencies([old_value]): + assign_op = state_ops.assign(array, new_value, validate_shape=False) with ops.control_dependencies([assign_op]): copy_op = array[:size].assign(old_value[:size]) # return value needs to be the same dtype as no_op() for cond diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index aec07241e7a..906bebe3b82 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -1734,9 +1735,10 @@ class StreamingAUCTest(test.TestCase): predictions = constant_op.constant( [1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) - _, update_op = metrics.streaming_auc(predictions, labels) - sess.run(variables.local_variables_initializer()) - self.assertRaises(errors_impl.InvalidArgumentError, update_op.eval) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + r'predictions must be in \[0, 1\]'): + _, _ = metrics.streaming_auc(predictions, labels) + # Error detected statically; no need to run the op. def testAllCorrect(self): self.allCorrectAsExpected('ROC') @@ -6718,6 +6720,7 @@ class StreamingConcatTest(test.TestCase): def setUp(self): ops.reset_default_graph() + variable_scope.enable_resource_variables() def testVars(self): metrics.streaming_concat(values=array_ops.ones((10,))) diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index 388384a492f..30375c7f56e 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -172,9 +172,11 @@ def get_pruning_hparams(): nbins: integer number of bins to use for histogram computation block_height: integer - number of rows in a block (defaults to 1) + number of rows in a block (defaults to 1), can be -1 in which + case it is set to the size of the corresponding weight tensor. block_width: integer - number of cols in a block (defaults to 1) + number of cols in a block (defaults to 1), can be -1 in which + case it is set to the size of the corresponding weight tensor. 
block_pooling_function: string Whether to perform average (AVG) or max (MAX) pooling in the block (default: AVG) @@ -489,6 +491,10 @@ class Pruning(object): if squeezed_weights.get_shape().ndims != 2 or block_dims == [1, 1]: return self._update_mask(weights, threshold) + for i in range(2): + if block_dims[i] == -1: + block_dims[i] = squeezed_weights.get_shape()[i] + if self._block_pooling_function not in ['AVG', 'MAX']: raise ValueError('Unknown pooling function for block sparsity: %s' % self._block_pooling_function) diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index 58080ad050d..1a925caab96 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -129,7 +129,7 @@ class PruningTest(test.TestCase): mask_val = new_mask.eval() self.assertAllEqual(mask_val, expected_mask) - def testBlockMasking(self): + def testBlockMaskingWithNonnegativeBlockDimensions(self): param_list = ["block_height=2", "block_width=2", "threshold_decay=0"] weights_avg = constant_op.constant( @@ -146,6 +146,25 @@ class PruningTest(test.TestCase): self._blockMasking(param_list + ["block_pooling_function=AVG"], weights_avg, expected_mask) + def testBlockMaskingWithNegativeBlockDimensions(self): + param_list = ["block_height=1", "block_width=-1", "threshold_decay=0"] + + weights_avg = constant_op.constant([[0.1, 0.1, 0.1, 0.1], + [0.2, 0.2, 0.2, 0.2], + [0.3, 0.3, 0.3, 0.3], + [0.3, 0.3, 0.4, 0.4]]) + weights_max = constant_op.constant([[0.1, 0.0, 0.1, 0.0], + [0.0, 0.1, 0.0, 0.2], + [0.3, 0.0, 0.3, 0.0], + [0.0, -0.3, 0.0, 0.4]]) + expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], + [1., 1., 1., 1.], [1., 1., 1., 1.]] + + self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max, + expected_mask) + self._blockMasking(param_list + ["block_pooling_function=AVG"], weights_avg, + expected_mask) + def testBlockMaskingWithHigherDimensions(self): param_list = ["block_height=2", "block_width=2", "threshold_decay=0"] diff --git a/tensorflow/contrib/mpi/BUILD b/tensorflow/contrib/mpi/BUILD deleted file mode 100644 index 23f90cf77ef..00000000000 --- a/tensorflow/contrib/mpi/BUILD +++ /dev/null @@ -1,93 +0,0 @@ -# Description: -# MPI based communication interfaces and implementations for TensorFlow. 
- -package(default_visibility = [ - "//tensorflow:__subpackages__", -]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -# For platform specific build config -load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_proto_library_cc", -) - -tf_proto_library_cc( - name = "mpi_msg_proto", - srcs = ["mpi_msg.proto"], - cc_api_version = 2, - protodeps = ["//tensorflow/core:worker_proto"], - visibility = [ - "//tensorflow:__subpackages__", - ], -) - -cc_library( - name = "mpi_utils", - srcs = ["mpi_utils.cc"], - hdrs = ["mpi_utils.h"], - deps = [ - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//third_party/mpi", - ], -) - -cc_library( - name = "mpi_rendezvous_mgr", - srcs = ["mpi_rendezvous_mgr.cc"], - hdrs = ["mpi_rendezvous_mgr.h"], - deps = [ - ":mpi_msg_proto_cc", - ":mpi_utils", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:gpu_runtime", - "//tensorflow/core:lib", - "//tensorflow/core:protos_cc", - "//tensorflow/core:worker_proto_cc", - "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", - "//tensorflow/core/distributed_runtime:recent_request_ids", - "//tensorflow/core/distributed_runtime:request_id", - "//tensorflow/core/distributed_runtime:session_mgr", - "//tensorflow/core/distributed_runtime:tensor_coding", - "//tensorflow/core/distributed_runtime:worker_env", - "//third_party/mpi", - ], -) - -cc_library( - name = "mpi_server_lib", - srcs = ["mpi_server_lib.cc"], - hdrs = ["mpi_server_lib.h"], - linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel - deps = [ - ":mpi_rendezvous_mgr", - "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", - ], - alwayslink = 1, -) diff --git a/tensorflow/contrib/mpi/README.md b/tensorflow/contrib/mpi/README.md deleted file mode 100644 index 75cb8230483..00000000000 --- a/tensorflow/contrib/mpi/README.md +++ /dev/null @@ -1,94 +0,0 @@ -## How to compile and use MPI-enabled TensorFlow - -1. Follow the regular TF compilation instructions. During configure step, if you want MPI support, answer yes to this question: - - ```Do you wish to build TensorFlow with MPI support [y/N]``` - -2. To turn on the MPI connection, add the protocol "grpc+mpi" in the server definition: - - ```server = tf.train.Server(cluster, job_name="local", task_index=0, protocol='grpc+mpi') # default protocol is 'grpc'``` - -## Overview - -By using this protocol TensorFlow can take advantage of the high performance networking primitives that are offered via the MPI API. This enables TensorFlow to take advantage of high performance low latency networks such as Infiniband. These changes are largely transparent to the user who only has to change the offered protocol and launch the script using the 'mpirun' launcher. For example: - ```mpirun -np 2 python my_neuralnet.py ``` - - - - - -## Runtime options - -The following environment variables can be set to modify the behavior at runtime: - -**MPI_DISABLED=[0,1]** - -This environment variable allows you to disable the MPI path before launch (e.g. for performance or correctness testing). 
- -**MPI_OPTIMAL_PATH=[0,1]** - -When set to 0 it will use the default path where tensors are encoded to ProtoText before being copied to a remote process. When set to 1 a more optimal path will be taken where only the tensor description is encoded while the actual tensor data is transferred directly from the source buffer to the destination buffer. -This path is disabled by default as it requires that the MPI library can directly access the pointer to the data. For CPU backed buffers this is no problem, however for GPU backed buffers this requires MPI libraries that are built with CUDA support (CUDA Aware). When using non-CUDA aware MPI libraries and GPU buffers you will get segmentation faults. - - - -## Known problems - -For certain complex neural nets the implementation sometimes crashes inside the MPI libraries. This seems to be related to memory allocations/routines that register the memory for the Infiniband transfers. (The crashes do not happen when all MPI processes are within the same physical machine). - -**MVAPICH** -- The problem manifests itself with a segmentation fault inside a memory copy routine and during startup you will get the following warning: "WARNING: Error in initializing MVAPICH2 ptmalloc library. Continuing without InfiniBand registration cache support." - -**OpenMPI** -- With OpenMPI corrupt data will be received resulting in an assertion or the MPI library will print an error and exit. The error is "Attempt to free memory that is still in use by an ongoing MPI communication. MPI job will now abort." - -## Implementation details - - -The implementation takes over the responsibility for sending and receiving tensors between separate processes. This is facilitated by TensorFlow's ability to support different protocols. In this particular implementation, the standard gRPC library is used for all administrative operations while the MPI functions take over the tensor exchanges. On the sending side the tensors are placed in the standard waiting tables and nothing is changed there. On the receiving side the RecvFromRemoteAsync function is newly implemented and instead of requesting the data via gRPC the data is now requested via MPI calls. - -To this end once the code is loaded a dedicated thread will be launched that handles all MPI operations. This thread will loop through a set of operations: - -* Send requests placed on the request queue to the sending process -Once a request for a tensor is received two callbacks are created. The first one is to request the tensor and the second one is executed once the requested data has arrived. To this end the request is placed in a queue and will be sent once the MPI thread services the queue. This sending is done using non-blocking MPI_Isend operations. - -* Send tensor data in response to a request call -Once a request has arrived from a remote process the request is forwarded to the original TensorFlow code which looks up the tensor in the waiting table. Once the tensor has been found a callback is executed which places the found tensor on the sendQueue for the MPI thread. Once the sendQueue is served the tensor data will be send using non-blocking send operations (MP_Isend) to the remote process. - -* Receive tensor request -The MPI thread will check if there are any incoming tensor request messages on the communication lines using MPI_Iprobe. Once a request has been received it will be passed on to the standard TensorFlow code and eventually will be placed on the sendQueue. 
- -* Receive tensor -At some point after a request has been sent the remote process will transmit the tensor. This tensor will be received and we look-up the callback that is associated with this tensor in our request table and execute the callback on the received data. - - -In the implementation all send operations are non-blocking, all probe operations are non-blocking and all receive-operations are blocking. The receive-operations are only executed after the probe has determined that there is something to receive. -The MPI processes identify each other using an MPI process ID. The TensorFlow gRPC processes identify each other using a name. During launch we create a mapping between the TensorFlow process name and the MPI process ID to allow the processes to communicate with the correct destinations when using MPI operations. - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/contrib/mpi/mpi_msg.proto b/tensorflow/contrib/mpi/mpi_msg.proto deleted file mode 100644 index 36f1504901c..00000000000 --- a/tensorflow/contrib/mpi/mpi_msg.proto +++ /dev/null @@ -1,19 +0,0 @@ - -syntax = "proto3"; - -package tensorflow; -option cc_enable_arenas = true; - -import "tensorflow/core/protobuf/worker.proto"; - - -message MPIRecvTensorResponse { - RecvTensorResponse response = 1; - bool singleSend = 2; - string key = 3; - int64 step_id = 4; - uint64 checksum = 5; -} - - - diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc deleted file mode 100644 index c2e1edb1366..00000000000 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.cc +++ /dev/null @@ -1,321 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifdef TENSORFLOW_USE_MPI - -#include "tensorflow/contrib/mpi/mpi_rendezvous_mgr.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/device_mgr.h" -#include "tensorflow/core/common_runtime/gpu/gpu_util.h" -#include "tensorflow/core/distributed_runtime/session_mgr.h" -#include "tensorflow/core/distributed_runtime/tensor_coding.h" -#include "tensorflow/core/framework/allocator.h" - -namespace tensorflow { - -MPIRendezvousMgr::MPIRendezvousMgr(const WorkerEnv* env) - : BaseRendezvousMgr(env), - worker_env_2(env), - use_optimal_transfer_(false), - recv_tensor_recent_request_ids_(100000) { - const char* mpienv = getenv("MPI_OPTIMAL_PATH"); - if (mpienv && mpienv[0] == '1') { - LOG(INFO) << "MPI Optimal copy path enabled (Requires CUDA-Aware MPI when " - "using GPUs)\n"; - use_optimal_transfer_ = true; - } - - // extract worker-name - auto parsed = env->local_devices[0]->parsed_name(); - const std::string task_id = - strings::StrCat(parsed.job, ":", parsed.replica, ":", parsed.task); - - mpiutils_ = new MPIUtils(task_id); - background_thread_ = - std::thread(&MPIRendezvousMgr::MPIBackgroundThread, this); -} - -BaseRemoteRendezvous* MPIRendezvousMgr::Create(int64 step_id, - const WorkerEnv* worker_env) { - return new MPIRemoteRendezvous(worker_env, step_id, mpiutils_, this); -} - -void MPIRemoteRendezvous::RecvFromRemoteAsync( - const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, - DoneCallback done) { - Status s = Status::OK(); - MPIRequestTensorCall* rendezvous_call = new MPIRequestTensorCall(); - - VLOG(2) << "MPI User requested " << parsed.FullKey() - << " @ step: " << step_id_; - - std::string src_task = strings::StrCat( - parsed.src.job, ":", parsed.src.replica, ":", parsed.src.task); - const int dst = mpiutils_->GetSourceID(src_task); - - Device* dst_device; - if (s.ok()) { - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); - CHECK(s.ok()) << "Device lookup failed"; - } else { - done(s, Args(), recv_args, Tensor{}, false); - return; - } - - // Set properties of the request object and create the request function - rendezvous_call->Init(parsed, step_id_); - - std::function request_call = [parsed, dst, rendezvous_call]() { - // Use MPI_Alloc_mem here to force allocation inside MPI thread - // this is not optimal, but prevents memory corruption and segmentation - // faults during inter-server transfers... 
- MPI_CHECK(MPI_Alloc_mem(rendezvous_call->request_buffer_size_, - MPI_INFO_NULL, &rendezvous_call->request_buffer_)); - rendezvous_call->req_.SerializeToArray( - rendezvous_call->request_buffer_, - rendezvous_call->request_buffer_size_); - MPI_CHECK(MPI_Isend(rendezvous_call->request_buffer_, - rendezvous_call->request_buffer_size_, MPI_CHAR, dst, - TAG_REQTENSOR, MPI_COMM_WORLD, - &rendezvous_call->mpi_request_)); - }; - - // Create the function which is called when the Tensor is send by remote - const int64 temp1 = step_id_; - rendezvous_call->recv_call_ = - [this, parsed, recv_args, done, dst, temp1, - rendezvous_call](MPIRecvTensorResponse mpi_response) { - Status s; - Device* dst_device; - if (s.ok()) { - s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); - CHECK(s.ok()) << "Device lookup failed"; - } - - VLOG(3) << "MPI Received tensor " << parsed.FullKey() - << " @ step: " << temp1 - << " single-send: " << mpi_response.singlesend(); - - Tensor val; - if (mpi_response.singlesend()) { - dst_device->MakeTensorFromProto(mpi_response.response().tensor(), - recv_args.alloc_attrs, &val); - } else { - TensorResponse tr; - tr.InitAlloc(dst_device, recv_args.alloc_attrs); - tr.InitPartial(mpi_response.response(), AllocationAttributes()); - const size_t nBytes = tr.tensor().TotalBytes(); - void* data = const_cast(DMAHelper::base(&tr.tensor())); - MPI_Status status; - MPI_CHECK(MPI_Recv(data, static_cast(nBytes), MPI_BYTE, dst, - TAG_SENDTENSOR2, MPI_COMM_WORLD, &status)); - val = std::move(tr.tensor()); - } - - done(s, Args(), recv_args, val, mpi_response.response().is_dead()); - }; - - MPIRendezvousMgr* mgr = - reinterpret_cast(this->rendezvous_mgr_); - mgr->QueueRequest(string(parsed.FullKey()), step_id_, std::move(request_call), - rendezvous_call); -} - -MPIRemoteRendezvous::~MPIRemoteRendezvous() {} - -/* - * Add the request for one of our Tensors by a remote process - * to the local send/table. The here created callback will - * be called once the Tensor data has arrived and is - * ready to be send to the remote requester. 
- */ -void MPIRendezvousMgr::AddRequest(RecvTensorRequest request, - const int mpi_dst) { - TF_CHECK_OK(recv_tensor_recent_request_ids_.TrackUnique( - request.request_id(), "RecvTensor (MPIRendezvousMgr)", request)); - const int64 step_id = request.step_id(); - const std::string& key = request.rendezvous_key(); - Rendezvous::ParsedKey parsed; - TF_CHECK_OK(Rendezvous::ParseKey(key, &parsed)); - - MPIRecvTensorCallBack send_cb = [this, mpi_dst, parsed]( - const Status& status, - const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, - const Tensor& val, bool is_dead, - MPISendTensorCall* mpi_send_call) { - // TODO(jbedorf) this should be a loop over max size - CHECK(mpi_send_call->mRes_.ByteSize() < INT_MAX) - << "Buffer too large for single transfer"; - MPI_CHECK(MPI_Alloc_mem(mpi_send_call->mRes_.ByteSize(), MPI_INFO_NULL, - &mpi_send_call->send_buffer_)); - mpi_send_call->mRes_.SerializeToArray(mpi_send_call->send_buffer_, - mpi_send_call->mRes_.ByteSize()); - - MPI_CHECK(MPI_Isend(mpi_send_call->send_buffer_, - static_cast(mpi_send_call->mRes_.ByteSize()), - MPI_CHAR, mpi_dst, TAG_SENDTENSOR, MPI_COMM_WORLD, - &(mpi_send_call->msg1_))); - MPI_CHECK(MPI_Test(&mpi_send_call->msg1_, &mpi_send_call->done1_, - MPI_STATUS_IGNORE)); - - if (!mpi_send_call->mRes_.singlesend()) { - const int tensor_size = static_cast(val.TotalBytes()); - void* temp = const_cast(DMAHelper::base(&val)); - - // If the MPI library is not GPU aware there should be a data transfer - // here to get the data on the host. - // if(src_dev->tensorflow_gpu_device_info()) //memcpy to send_buffer2_ - - // TODO(jbedorf) this should be a loop over max size - MPI_CHECK(MPI_Isend(temp, tensor_size, MPI_CHAR, mpi_dst, TAG_SENDTENSOR2, - MPI_COMM_WORLD, &mpi_send_call->msg2_)); - mpi_send_call->done2_ = 0; - } - return mpi_send_call; - }; - - // Wrapper around the read callback to place the callback on our queue - Rendezvous::DoneCallback done_cb = - [this, parsed, step_id, send_cb]( - const Status& status, const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) { - if (!status.ok()) { - CHECK(status.ok()) - << "RecvLocalAsync was not ok, key: " << parsed.FullKey() - << " step: " << step_id - << " error message: " << status.error_message(); - return; - } - - VLOG(3) << "MPI Sending tensor " << parsed.FullKey() - << " @ step: " << step_id << std::endl; - - auto mpi_send_call = new MPISendTensorCall(); - mpi_send_call->Init(parsed, step_id, is_dead); - - Device* src_dev = nullptr; - Status s = this->worker_env_2->device_mgr->LookupDevice( - parsed.src_device, &src_dev); - CHECK(s.ok()) << "src device not found"; - - // Control if shape and data should be send together or if we can - // optimize it in two different transfers, thereby reducing memory - // copies - bool doOptimalTransfer = true; - if (!DataTypeCanUseMemcpy(val.dtype())) doOptimalTransfer = false; - if (val.TotalBytes() < 1024) doOptimalTransfer = false; - - doOptimalTransfer = doOptimalTransfer && use_optimal_transfer_; - - if (doOptimalTransfer) { - // First send the Tensor description and in a follow up transfer the - // data - mpi_send_call->mRes_.mutable_response()->mutable_tensor()->set_dtype( - val.dtype()); - val.shape().AsProto(mpi_send_call->mRes_.mutable_response() - ->mutable_tensor() - ->mutable_tensor_shape()); - mpi_send_call->mRes_.set_singlesend(false); - } else { - // Send the Tensor description and data in a single transfer - if (src_dev->tensorflow_gpu_device_info() && - 
(!send_args.alloc_attrs.on_host())) { - Notification n; - GPUUtil::SetProtoFromGPU( - val, src_dev, send_args.device_context, - mpi_send_call->mRes_.mutable_response()->mutable_tensor(), - is_dead, [&n, &s](const Status& s_) { - s = s_; - n.Notify(); - }); - n.WaitForNotification(); - } else { - val.AsProtoTensorContent( - mpi_send_call->mRes_.mutable_response()->mutable_tensor()); - } - } - - std::function res = std::bind( - send_cb, status, send_args, recv_args, val, is_dead, mpi_send_call); - - SendQueueEntry req(string(parsed.FullKey()), std::move(res)); - - this->QueueSendRequest(req); - - // Wait for the notification that indicates the tensor has been - // successfully transmitted to the remote process. Only needed if we - // have not parsed the tensor to proto - if (doOptimalTransfer) mpi_send_call->n_.WaitForNotification(); - }; // done_cb - - worker_env_2->compute_pool->Schedule([this, step_id, parsed, done_cb]() { - this->RecvLocalAsync(step_id, parsed, done_cb); - }); -} - -void MPIRendezvousMgr::MPIBackgroundThread() { - std::list> active_sends; - - while (1) { - MPI_Status status; - - // Check for incoming Tensor requests - RecvTensorRequest request; - if (ProbeForData(TAG_REQTENSOR, &status, &request)) { - this->AddRequest(request, status.MPI_SOURCE); - } - - // Check for incoming Tensor reply - MPIRecvTensorResponse mRes; - if (ProbeForData(TAG_SENDTENSOR, &status, &mRes)) { - const int64 step_id = mRes.step_id(); - std::string key = mRes.key(); - - std::shared_ptr call; - GetRecvCall(step_id, key, &call); - call->recv_call_(mRes); - RemoveRecvCall(step_id, key); - } - - // Remove sends that have been completed - active_sends.remove_if( - [](std::unique_ptr& i) { return i->IsFinished(); }); - - // send a Tensor request - RequestQueueEntry req; - if (GetRequest(&req)) req.second(); - - // Send a Tensor response - SendQueueEntry send; - if (GetResponse(&send)) { - std::unique_ptr p(send.second()); - active_sends.push_back(std::move(p)); - } - - // std::this_thread::sleep_for(std::chrono::microseconds(1)); - } -} - -} // namespace tensorflow -#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h b/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h deleted file mode 100644 index 90140fcab31..00000000000 --- a/tensorflow/contrib/mpi/mpi_rendezvous_mgr.h +++ /dev/null @@ -1,255 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_MPI_MPI_RENDEZVOUS_MGR_H_ -#define TENSORFLOW_CONTRIB_MPI_MPI_RENDEZVOUS_MGR_H_ - -#ifdef TENSORFLOW_USE_MPI - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "tensorflow/contrib/mpi/mpi_msg.pb.h" -#include "tensorflow/contrib/mpi/mpi_utils.h" -#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/recent_request_ids.h" -#include "tensorflow/core/distributed_runtime/request_id.h" -#include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/core/protobuf/worker.pb.h" - -#define TAG_REQTENSOR 1010 -#define TAG_SENDTENSOR 2020 -#define TAG_SENDTENSOR2 3030 - -namespace tensorflow { - -class MPISendTensorCall { - public: - char* send_buffer_; - char* send_buffer2_; - - MPI_Request msg1_; - MPI_Request msg2_; - int done1_; // Int instead of bool for simpler IsFinished logic - int done2_; - MPIRecvTensorResponse mRes_; - Notification n_; - - MPISendTensorCall() - : send_buffer_(nullptr), send_buffer2_(nullptr), done1_(1), done2_(1) {} - - ~MPISendTensorCall() { - MPI_CHECK(MPI_Wait(&msg1_, MPI_STATUS_IGNORE)); - n_.Notify(); - MPI_CHECK(MPI_Free_mem(send_buffer_)); - // delete[] send_buffer_; - delete[] send_buffer2_; - } - - MPISendTensorCall(MPISendTensorCall&&) = delete; - - void Init(const Rendezvous::ParsedKey& parsed, const int64 step_id, - const bool is_dead) { - mRes_.set_key(string(parsed.FullKey())); - mRes_.set_step_id(step_id); - mRes_.mutable_response()->set_is_dead(is_dead); - mRes_.mutable_response()->set_send_start_micros( - Env::Default()->NowMicros()); - mRes_.set_singlesend(true); - } - - bool IsFinished() { - MPI_Status status; - if (!done1_) MPI_CHECK(MPI_Test(&msg1_, &done1_, &status)); - if (!done2_) MPI_CHECK(MPI_Test(&msg2_, &done2_, &status)); - return done1_ && done2_; - } -}; - -class MPIRequestTensorCall { - public: - Rendezvous::DoneCallback done_; - RecvTensorRequest req_; - MPI_Request mpi_request_; - char* request_buffer_; - size_t request_buffer_size_; - std::function recv_call_; - - MPIRequestTensorCall() : request_buffer_(nullptr) {} - ~MPIRequestTensorCall() { - MPI_CHECK(MPI_Wait(&mpi_request_, MPI_STATUS_IGNORE)); - // delete[] request_buffer_; - MPI_CHECK(MPI_Free_mem(request_buffer_)); - } - - void Init(const Rendezvous::ParsedKey& parsed, const int64 step_id) { - req_.set_step_id(step_id); - req_.set_rendezvous_key(parsed.FullKey().data(), parsed.FullKey().size()); - req_.set_request_id(GetUniqueRequestId()); - request_buffer_size_ = req_.ByteSize(); - // request_buffer_ = new char[request_buffer_size_]; - // req_.SerializeToArray(request_buffer_, request_buffer_size_); - } -}; - -class MPIRemoteRendezvous : public BaseRemoteRendezvous { - public: - MPIRemoteRendezvous(const WorkerEnv* env, int64 step_id, const MPIUtils* util, - BaseRendezvousMgr* mgr_) - : BaseRemoteRendezvous(env, step_id), - mpiutils_(util), - rendezvous_mgr_(mgr_) {} - - protected: - void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, - const Rendezvous::Args& args, - DoneCallback done) override; - - private: - ~MPIRemoteRendezvous() override; - - const MPIUtils* mpiutils_; - BaseRendezvousMgr* rendezvous_mgr_; - - TF_DISALLOW_COPY_AND_ASSIGN(MPIRemoteRendezvous); -}; - -class MPIRendezvousMgr : public BaseRendezvousMgr { - public: - explicit MPIRendezvousMgr(const WorkerEnv* env); - ~MPIRendezvousMgr() { - delete 
mpiutils_; - fprintf(stderr, "Delete MPIRendezvousMgr \n"); - // TODO(jbedorf) stop background_thread_ - MPI_CHECK(MPI_Finalize()); - } - - void QueueRequest(std::string key, int64 step_id, - std::function request_call, - MPIRequestTensorCall* rCall) { - mutex_lock l(mrq_); - request_queue_.push(RequestQueueEntry(key, std::move(request_call))); - const std::string key_id = strings::StrCat(key, "_", step_id); - recv_tensor_map_[key_id] = std::shared_ptr(rCall); - } - - protected: - BaseRemoteRendezvous* Create(int64 step_id, - const WorkerEnv* worker_env) override; - - private: - typedef std::function - MPIRecvTensorCallBack; - - typedef std::pair> RequestQueueEntry; - typedef std::pair> - SendQueueEntry; - - const WorkerEnv* worker_env_2; - std::thread background_thread_; - MPIUtils* mpiutils_; - bool use_optimal_transfer_; - - mutex msq_; - mutex mrq_; - - std::queue send_queue_ GUARDED_BY(msq_); - std::queue request_queue_ GUARDED_BY(mrq_); - std::map> recv_tensor_map_ - GUARDED_BY(mrq_); - - RecentRequestIds recv_tensor_recent_request_ids_; - - void AddRequest(RecvTensorRequest, const int); - void MPIBackgroundThread(); - - void QueueSendRequest(SendQueueEntry req) { - mutex_lock l(msq_); - send_queue_.push(req); - } - - void GetRecvCall(const int64 step_id, const std::string& key, - std::shared_ptr* call) { - mutex_lock l(mrq_); - - const std::string key_id = strings::StrCat(key, "_", step_id); - if (recv_tensor_map_.find(key_id) == recv_tensor_map_.end()) { - LOG(FATAL) << "Key/step not found in recv_tensor_map_, step: " << step_id - << " key: " << key << std::endl; - } - *call = recv_tensor_map_[key_id]; - } - - void RemoveRecvCall(const int64 step_id, const std::string& key) { - mutex_lock l(mrq_); - const std::string key_id = strings::StrCat(key, "_", step_id); - recv_tensor_map_.erase(key_id); - } - - bool GetRequest(RequestQueueEntry* req) { - mutex_lock l(mrq_); - if (!request_queue_.empty()) { - *req = request_queue_.front(); - request_queue_.pop(); - return true; - } - return false; - } - - bool GetResponse(SendQueueEntry* send) { - mutex_lock l(msq_); - if (!send_queue_.empty()) { - *send = send_queue_.front(); - send_queue_.pop(); - return true; - } - return false; - } - - template - int ProbeForData(const int tag, MPI_Status* status, T* obj) { - int flag = 0, msg_size = 0; - MPI_Message msg; - // Receive the message, probe as size is variable - MPI_CHECK( - MPI_Improbe(MPI_ANY_SOURCE, tag, MPI_COMM_WORLD, &flag, &msg, status)); - if (flag) { - MPI_CHECK(MPI_Get_count(status, MPI_CHAR, &msg_size)); - MPI_Status stat2; - std::vector request_buffer_(msg_size); - MPI_Mrecv(&request_buffer_[0], msg_size, MPI_CHAR, &msg, &stat2); - bool res = obj->ParseFromArray(&request_buffer_[0], msg_size); - CHECK(res) << "Failed to parse incomming message"; - } - return flag; - } - - TF_DISALLOW_COPY_AND_ASSIGN(MPIRendezvousMgr); -}; // MPIRendezvousMgr -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI -#endif // TENSORFLOW_CONTRIB_MPI_MPI_RENDEZVOUS_MGR_H_ diff --git a/tensorflow/contrib/mpi/mpi_server_lib.cc b/tensorflow/contrib/mpi/mpi_server_lib.cc deleted file mode 100644 index e44e10af081..00000000000 --- a/tensorflow/contrib/mpi/mpi_server_lib.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifdef TENSORFLOW_USE_MPI - -#include "tensorflow/contrib/mpi/mpi_server_lib.h" - -#include -#include - -#include "grpc/support/alloc.h" - -#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/env.h" - -namespace tensorflow { - -namespace { -// static utility function -RendezvousMgrInterface* NewMPIRendezvousMgr(const WorkerEnv* env) { - // Runtime check to disable the MPI path - const char* mpienv = getenv("MPI_DISABLED"); - if (mpienv && mpienv[0] == '1') { - LOG(INFO) << "MPI path disabled by environment variable\n"; - return new RpcRendezvousMgr(env); - } else { - return new MPIRendezvousMgr(env); - } -} - -} // namespace - -MPIServer::MPIServer(const ServerDef& server_def, Env* env) - : GrpcServer(server_def, env) {} - -MPIServer::~MPIServer() { - TF_CHECK_OK(Stop()); - TF_CHECK_OK(Join()); -} - -Status MPIServer::Init(ServiceInitFunction service_func, - RendezvousMgrCreationFunction rendezvous_mgr_func) { - GrpcServerOptions opts; - opts.service_func = service_func; - opts.rendezvous_mgr_func = rendezvous_mgr_func; - Status s = GrpcServer::Init(opts); - return s; -} - -Status MPIServer::Start() { - Status s = GrpcServer::Start(); - return s; -} - -Status MPIServer::Join() { - Status s = GrpcServer::Join(); - return s; -} - -/* static */ -Status MPIServer::Create(const ServerDef& server_def, Env* env, - std::unique_ptr* out_server) { - std::unique_ptr ret(new MPIServer(server_def, Env::Default())); - ServiceInitFunction service_func = nullptr; - TF_RETURN_IF_ERROR(ret->Init(service_func, NewMPIRendezvousMgr)); - *out_server = std::move(ret); - return Status::OK(); -} - -namespace { - -class MPIServerFactory : public ServerFactory { - public: - bool AcceptsOptions(const ServerDef& server_def) override { - return server_def.protocol() == "grpc+mpi"; - } - - Status NewServer(const ServerDef& server_def, - std::unique_ptr* out_server) override { - return MPIServer::Create(server_def, Env::Default(), out_server); - } -}; - -// Registers a `ServerFactory` for `MPIServer` instances. -class MPIServerRegistrar { - public: - MPIServerRegistrar() { - gpr_allocation_functions alloc_fns; - alloc_fns.malloc_fn = port::Malloc; - alloc_fns.realloc_fn = port::Realloc; - alloc_fns.free_fn = port::Free; - gpr_set_allocation_functions(alloc_fns); - ServerFactory::Register("MPI_SERVER", new MPIServerFactory()); - } -}; -static MPIServerRegistrar registrar; - -} // namespace -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi/mpi_server_lib.h b/tensorflow/contrib/mpi/mpi_server_lib.h deleted file mode 100644 index 736f6922a15..00000000000 --- a/tensorflow/contrib/mpi/mpi_server_lib.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_MPI_MPI_SERVER_LIB_H_ -#define TENSORFLOW_CONTRIB_MPI_MPI_SERVER_LIB_H_ - -#ifdef TENSORFLOW_USE_MPI - -#include - -#include "tensorflow/contrib/mpi/mpi_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" - -namespace tensorflow { - -class MPIServer : public GrpcServer { - protected: - MPIServer(const ServerDef& server_def, Env* env); - - public: - static Status Create(const ServerDef& server_def, Env* env, - std::unique_ptr* out_server); - - // Destruction is only supported in the factory method. Clean - // shutdown is not currently implemented for this server type. - ~MPIServer() override; - - // Implementations of ServerInterface methods. - Status Start() override; - Status Join() override; - - protected: - Status Init(ServiceInitFunction service_func, - RendezvousMgrCreationFunction rendezvous_mgr_func); - Status ChannelCacheFactory(const ServerDef& server_def, - GrpcChannelCache** channel_cache); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI -#endif // TENSORFLOW_CONTRIB_MPI_MPI_SERVER_LIB_H_ diff --git a/tensorflow/contrib/mpi/mpi_utils.cc b/tensorflow/contrib/mpi/mpi_utils.cc deleted file mode 100644 index 8184b856264..00000000000 --- a/tensorflow/contrib/mpi/mpi_utils.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifdef TENSORFLOW_USE_MPI - -#include "tensorflow/contrib/mpi/mpi_utils.h" -namespace tensorflow { - -#define max_worker_name_length 128 - -MPIUtils::MPIUtils(const std::string& worker_name) { - InitMPI(); - // Connect the MPI process IDs to the worker names that are used by TF. 
- // Gather the names of all the active processes (name can't be longer than - // 128 bytes) - int proc_id = 0, number_of_procs = 1; - char my_name[max_worker_name_length]; - MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &proc_id)); - MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &number_of_procs)); - - CHECK(worker_name.size() < max_worker_name_length) - << "Specified worker name is too long."; - snprintf(my_name, max_worker_name_length, worker_name.c_str()); - std::vector worker_names(number_of_procs * max_worker_name_length); - MPI_CHECK(MPI_Allgather(my_name, max_worker_name_length, MPI_CHAR, - &worker_names[0], max_worker_name_length, MPI_CHAR, - MPI_COMM_WORLD)); - - if (proc_id == 0) LOG(INFO) << "MPI process-ID to gRPC server name map: \n"; - for (int i = 0; i < number_of_procs; i++) { - name_to_id_[std::string(&worker_names[i * 128])] = i; - if (proc_id == 0) - LOG(INFO) << "Process: " << i - << "\tgRPC-name: " << std::string(&worker_names[i * 128]) - << std::endl; - } -} - -void MPIUtils::InitMPI() { - // Initialize the MPI environment if that hasn't been done - int flag = 0; - MPI_CHECK(MPI_Initialized(&flag)); - if (!flag) { - int proc_id = 0, number_of_procs = 1, len = -1; - char my_host_name[max_worker_name_length]; - // MPI_CHECK(MPI_Init_thread(0, 0, MPI_THREAD_MULTIPLE, &flag)); - MPI_CHECK(MPI_Init(0, 0)); - MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &proc_id)); - MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &number_of_procs)); - MPI_CHECK(MPI_Get_processor_name(my_host_name, &len)); - fprintf(stderr, - "MPI Environment initialized. Process id: %d Total processes: %d " - "|| Hostname: %s \n", - proc_id, number_of_procs, my_host_name); - } -} - -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi/mpi_utils.h b/tensorflow/contrib/mpi/mpi_utils.h deleted file mode 100644 index 4091925fc0d..00000000000 --- a/tensorflow/contrib/mpi/mpi_utils.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_MPI_MPI_UTILS_H_ -#define TENSORFLOW_CONTRIB_MPI_MPI_UTILS_H_ - -#ifdef TENSORFLOW_USE_MPI - -#include -#include -#include - -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/lib/strings/str_util.h" - -// Skip MPI C++ bindings support, this matches the usage in other places -#define OMPI_SKIP_MPICXX -#include "third_party/mpi/mpi.h" -#define MPI_CHECK(cmd) \ - do { \ - int mpi_errno = cmd; \ - if (MPI_SUCCESS != mpi_errno) { \ - fprintf(stderr, "[%s:%d] MPI call failed with %d \n", __FILE__, \ - __LINE__, mpi_errno); \ - exit(EXIT_FAILURE); \ - } \ - assert(MPI_SUCCESS == mpi_errno); \ - } while (false) - -namespace tensorflow { -class MPIUtils { - public: - explicit MPIUtils(const std::string& worker_name); - - const int GetSourceID(const std::string& task_id) const { - auto it = name_to_id_.find(task_id); - if (it == name_to_id_.end()) { - LOG(FATAL) << "Failed to convert worker name to MPI index: " << task_id; - } - return it->second; - } - - private: - void InitMPI(); - - std::map name_to_id_; -}; -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI -#endif // TENSORFLOW_CONTRIB_MPI_MPI_UTILS_H_ diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD deleted file mode 100644 index 5e848c9e7cf..00000000000 --- a/tensorflow/contrib/mpi_collectives/BUILD +++ /dev/null @@ -1,128 +0,0 @@ -# Ops that communicate with other processes via MPI. - -package(default_visibility = [ - "//tensorflow:__subpackages__", -]) - -licenses(["notice"]) # Apache 2.0 - -load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_additional_mpi_lib_defines", - "tf_proto_library_cc", -) - -tf_proto_library_cc( - name = "mpi_message_proto", - srcs = ["mpi_message.proto"], - cc_api_version = 2, - protodeps = ["//tensorflow/core:protos_all"], - visibility = [ - "//tensorflow:__subpackages__", - ], -) - -cc_library( - name = "mpi_defines", - defines = tf_additional_mpi_lib_defines(), -) - -load( - "//tensorflow:tensorflow.bzl", - "tf_custom_op_library", - "tf_custom_op_py_library", - "tf_gen_op_libs", - "tf_gen_op_wrapper_py", - "tf_kernel_library", - "tf_py_test", -) - -tf_custom_op_library( - name = "python/ops/_mpi_ops.so", - srcs = [ - "kernels/mpi_ops.cc", - "kernels/ring.cc", - "kernels/ring.h", - "ops/mpi_ops.cc", - ], - gpu_srcs = [ - "kernels/ring.cu.cc", - "kernels/ring.h", - ], - deps = [ - ":mpi_defines", - ":mpi_message_proto_cc", - "//third_party/mpi", - ], -) - -tf_kernel_library( - name = "mpi_ops_kernels", - srcs = [ - "kernels/mpi_ops.cc", - "kernels/ring.cc", - ], - hdrs = [ - "kernels/ring.h", - ], - gpu_srcs = [ - "kernels/ring.cu.cc", - ], - deps = [ - ":mpi_defines", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:gpu_headers_lib", - "//tensorflow/core:lib", - "//tensorflow/core:proto_text", - "//tensorflow/core:stream_executor", - ], - # TODO: Include? 
alwayslink = 1, -) - -tf_gen_op_libs( - op_lib_names = ["mpi_ops"], -) - -tf_gen_op_wrapper_py( - name = "mpi_ops", - deps = [":mpi_ops_op_lib"], -) - -tf_custom_op_py_library( - name = "mpi_collectives_py", - srcs = [ - "__init__.py", - "python/ops/mpi_ops.py", - ], - dso = [ - ":python/ops/_mpi_ops.so", - ], - kernels = [ - ":mpi_ops_kernels", - ":mpi_ops_op_lib", - ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":mpi_ops", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:device", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:util", - ], -) - -tf_py_test( - name = "mpi_ops_test", - srcs = ["mpi_ops_test.py"], - additional_deps = [ - "//tensorflow:tensorflow_py", - "//tensorflow/python:platform", - ], - data = [ - ":python/ops/_mpi_ops.so", - ], - tags = ["manual"], -) diff --git a/tensorflow/contrib/mpi_collectives/README.md b/tensorflow/contrib/mpi_collectives/README.md deleted file mode 100644 index c5e1a8c37e3..00000000000 --- a/tensorflow/contrib/mpi_collectives/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# MPI TensorFlow integration - -Tensorflow MPI integration allows communicating between different TensorFlow -processes using MPI. This enables training across multiple nodes and GPUs -using high-speed interconnects. diff --git a/tensorflow/contrib/mpi_collectives/__init__.py b/tensorflow/contrib/mpi_collectives/__init__.py deleted file mode 100644 index 52029cbc36a..00000000000 --- a/tensorflow/contrib/mpi_collectives/__init__.py +++ /dev/null @@ -1,275 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# pylint: disable=g-short-docstring-punctuation -"""## Communicating Between Processes with MPI - -TensorFlow natively provides inter-device communication through send and -receive ops and inter-node communication through Distributed TensorFlow, based -on the same send and receive abstractions. On HPC clusters where Infiniband or -other high-speed node interconnects are available, these can end up being -insufficient for synchronous data-parallel training (without asynchronous -gradient descent). This module implements a variety of MPI ops which can take -advantage of hardware-specific MPI libraries for efficient communication. - -In order to use this module, TensorFlow must be built with an MPI library, -which can be provided to the `./configure` script at build time. As a user of -TensorFlow, you will need to build TensorFlow yourself to select the MPI -library to use; to do so, follow the [instructions for building TensorFlow from -source](https://www.tensorflow.org/get_started/os_setup#installing_from_sources). - -### Utility Ops - -In addition to reductions and gathers, this module provides utility operations -for detecting the running MPI configuration. 
- -Example: - -```python -import tensorflow.contrib.mpi_collectives as mpi - -# Use `mpi.Session` instead of `tf.Session` -with mpi.Session() as session: - rank = session.run(mpi.rank()) - print("My MPI Rank:", rank) - - if rank == 0: - print("MPI Size:", session.run(mpi.size())) -``` - -@@init -@@size -@@rank -@@local_rank - -### Ring Allreduce and Allgather - -When summing or averaging tensors across many processes, communication can -easily become a bottleneck. A naive implementation will send all the tensor -values to the same process, perform the reduction, and then broadcast the -values back to all other processes, effectively creating a synchronous -parameter server in one process. However, the process responsible for -performing the reduction will have to receive and send a massive amount of data -which scales with the number of processes *and* the number of parameters in the -model. - -Instead of centralizing the reduction and having one primary reducer, we can -implement a distributed allreduce or allgather. A bandwidth-optimal allreduce -will end up sending 2(N - 1) values for every value in the input tensor, -and can be implemented with a ring allreduce [1]. (Intuitively, a linear reduce -requires at least (N - 1) sends between the different nodes, and a broadcast of -the result also requires (N - 1) sends, for a total of 2 (N - 1); these two -steps cannot be combined in a clever way to reduce the number of required -sends.) This module implements bandwidth-optimal ring allreduce and ring -allgather operations using MPI; by choosing a hardware-appropriate MPI -implementation (such as OpenMPI with CUDA-IPC support), you can train large -models with synchronous gradient descent with minimal communication overhead. - -In addition to the `allreduce` and `allgather` functions, a convenience -`DistributedOptimizer` wrapper is provided to simplify using these functions -for reducing model gradients. - -Example: - -```python -import tensorflow as tf -from tensorflow.contrib import mpi_collectives as mpi - -# Construct a simple linear regression model to optimize -W = tf.get_variable("W", shape=[20, 1], dtype=tf.float32) -B = tf.get_variable("B", shape=[1, 1], dtype=tf.float32) -inputs = tf.placeholder("Inputs", shape=[None, 20]) -outputs = tf.placeholder("Outputs", shape=[None, 1]) -loss = tf.nn.l2_loss(tf.matmul(inputs, W) + B - outputs) - -# Training using MPI allreduce with DistributedOptimizer -optimizer = mpi.DistributedOptimizer(tf.train.AdamOptimizer()) -train = optimizer.minimize(loss) - -# Average loss over all ranks, for printing. -# Do not pass this to an optimizer! -avg_loss = mpi.allreduce(loss) - -# On different ranks, feed different input data. -with mpi.Session() as session: - rank = session.run(mpi.rank()) - batch_inputs, batch_outputs = construct_batch_for_rank(rank) - feed_dict = {inputs: batch_inputs, outputs: batch_outputs} - _, l = session.run([train, avg_loss], feed_dict=feed_dict) - print("Average Loss:", l) -``` - -[1] Patarasuk, Pitch and Yuan, Xin. "Bandwidth Optimal All-reduce Algorithms -for Clusters of Workstations". 
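As an aside, the scatter-reduce / allgather pattern described above can be simulated in a few lines of NumPy. The sketch below is purely illustrative and uses our own naming (`simulate_ring_allreduce` is not part of this module); the production path is the MPI/CUDA implementation in `kernels/ring.cc` and `kernels/ring.cu.cc`.

```python
import numpy as np

def simulate_ring_allreduce(rank_vectors):
    """Simulate the ring allreduce pattern for N = len(rank_vectors) ranks.

    Each entry of `rank_vectors` is that rank's equal-length 1-D array.
    Returns the summed vector as every rank would see it after the
    scatter-reduce and allgather phases.
    """
    n = len(rank_vectors)
    # Every rank splits its vector into n chunks; chunk i is fully reduced
    # at rank (i - 1) % n after the scatter-reduce phase.
    chunks = [list(np.array_split(np.array(v, dtype=np.float64), n))
              for v in rank_vectors]

    # Phase 1: scatter-reduce. In each of the n-1 steps, rank r sends chunk
    # (r - step) % n to rank (r + 1) % n, which adds it to its own copy.
    for step in range(n - 1):
        for r in range(n):
            idx = (r - step) % n
            chunks[(r + 1) % n][idx] = chunks[(r + 1) % n][idx] + chunks[r][idx]

    # Phase 2: allgather. In each of the n-1 steps, rank r forwards its fully
    # reduced chunk (r + 1 - step) % n to rank (r + 1) % n, which overwrites
    # its own copy.
    for step in range(n - 1):
        for r in range(n):
            idx = (r + 1 - step) % n
            chunks[(r + 1) % n][idx] = chunks[r][idx].copy()

    # All ranks now hold identical, fully reduced data; return rank 0's view.
    return np.concatenate(chunks[0])

# Sanity check: four simulated ranks, eight values each.
vectors = [np.arange(8.0) * (r + 1) for r in range(4)]
assert np.allclose(simulate_ring_allreduce(vectors), sum(vectors))
```

Each simulated rank sends one chunk per step, so over both phases it transmits roughly 2(N - 1)/N times its input size, which is what makes the ring schedule bandwidth-optimal.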
- -@@Session -@@DistributedOptimizer -@@allreduce -@@allgather -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import init -from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import size -from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import rank -from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import local_rank -from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import allgather -from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import _allreduce - - -def allreduce(tensor, average=True): - """Perform an MPI allreduce on a tf.Tensor or tf.IndexedSlices. - - Arguments: - tensor: tf.Tensor, tf.Variable, or tf.IndexedSlices to reduce. - The shape of the input must be identical across all ranks. - average: If True, computes the average over all ranks. - Otherwise, computes the sum over all ranks. - - This function performs a bandwidth-optimal ring allreduce on the input - tensor. If the input is an tf.IndexedSlices, the function instead does an - allgather on the values and the indices, effectively doing an allreduce on - the represented tensor. - """ - if isinstance(tensor, tf.IndexedSlices): - # For IndexedSlices, do two allgathers intead of an allreduce. - mpi_size = tf.cast(size(), tensor.values.dtype) - values = allgather(tensor.values) - indices = allgather(tensor.indices) - - # To make this operation into an average, divide all gathered values by - # the MPI size. - new_values = tf.div(values, mpi_size) if average else values - return tf.IndexedSlices(new_values, indices, - dense_shape=tensor.dense_shape) - else: - mpi_size = tf.cast(size(), tensor.dtype) - summed_tensor = _allreduce(tensor) - new_tensor = (tf.div(summed_tensor, mpi_size) - if average else summed_tensor) - return new_tensor - - -class DistributedOptimizer(tf.train.Optimizer): - """An optimizer that wraps another tf.Optimizer, using an MPI allreduce to - average gradient values before applying gradients to model weights.""" - - def __init__(self, optimizer, name=None, use_locking=False): - """Construct a new DistributedOptimizer, which uses another optimizer - under the hood for computing single-process gradient values and - applying gradient updates after the gradient values have been averaged - across all the MPI ranks. - - Args: - optimizer: Optimizer to use for computing gradients and applying updates. - name: Optional name prefix for the operations created when applying - gradients. Defaults to "Distributed" followed by the provided - optimizer type. - use_locking: Whether to use locking when updating variables. See - Optimizer.__init__ for more info. - """ - if name is None: - name = "Distributed{}".format(type(optimizer).__name__) - - self._optimizer = optimizer - super(DistributedOptimizer, self).__init__( - name=name, use_locking=use_locking) - - def compute_gradients(self, *args, **kwargs): - """Compute gradients of all trainable variables. - - See Optimizer.compute_gradients() for more info. - - In DistributedOptimizer, compute_gradients() is overridden to also - allreduce the gradients before returning them. 
- """ - gradients = (super(DistributedOptimizer, self) - .compute_gradients(*args, **kwargs)) - return [(allreduce(gradient), var) for (gradient, var) in gradients] - - def _apply_dense(self, *args, **kwargs): - """Calls this same method on the underlying optimizer.""" - return self._optimizer._apply_dense(*args, **kwargs) - - def _apply_sparse(self, *args, **kwargs): - """Calls this same method on the underlying optimizer.""" - return self._optimizer._apply_sparse(*args, **kwargs) - - def _apply_sparse_duplicate_indices(self, *args, **kwargs): - """Calls this same method on the underlying optimizer.""" - return self._optimizer._apply_sparse_duplicate_indices(*args, - **kwargs) - - def _prepare(self, *args, **kwargs): - """Calls this same method on the underlying optimizer.""" - return self._optimizer._prepare(*args, **kwargs) - - def _create_slots(self, *args, **kwargs): - """Calls this same method on the underlying optimizer.""" - return self._optimizer._create_slots(*args, **kwargs) - - def _valid_dtypes(self, *args, **kwargs): - """Calls this same method on the underlying optimizer.""" - return self._optimizer._valid_dtypes(*args, **kwargs) - - def _finish(self, *args, **kwargs): - """Calls this same method on the underlying optimizer.""" - return self._optimizer._finish(*args, **kwargs) - - -class Session(tf.Session): - """A class for running TensorFlow operations, with copies of the same graph - running distributed across different MPI nodes. - - The primary difference between `tf.Session` and - `tf.contrib.mpi_collectives.Session` is that the MPI `Session` ensures that - the `Session` options are correct for use with `tf.contrib.mpi`, and - initializes MPI immediately upon the start of the session. - """ - - def __init__(self, target='', graph=None, config=None): - """Creates a new TensorFlow MPI session. - - Unlike a normal `tf.Session`, an MPI Session may only use a single GPU, - which must be specified in advance before the session is initialized. - In addition, it only uses a single graph evaluation thread, and - initializes MPI immediately upon starting. - - If no `graph` argument is specified when constructing the session, - the default graph will be launched in the session. If you are - using more than one graph (created with `tf.Graph()` in the same - process, you will have to use different sessions for each graph, - but each graph can be used in multiple sessions. In this case, it - is often clearer to pass the graph to be launched explicitly to - the session constructor. - - Args: - target: (Optional.) The execution engine to connect to. - graph: (Optional.) The `Graph` to be launched (described above). - config: (Optional.) A `ConfigProto` protocol buffer with configuration - options for the session. - """ - super(Session, self).__init__(target, graph, config=config) - - # Initialize MPI on the relevant device. - # TODO: Move this to library load and eliminate mpi.Session() - if graph is None: - graph = tf.get_default_graph() - with graph.as_default(): - self.run(init()) diff --git a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc deleted file mode 100644 index e4b0c2c6541..00000000000 --- a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc +++ /dev/null @@ -1,1132 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifdef TENSORFLOW_USE_MPI - -#include -#include -#include - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/mutex.h" - -#define EIGEN_USE_THREADS - -#if GOOGLE_CUDA -#include -#include "tensorflow/stream_executor/stream.h" -#endif - -#include "tensorflow/stream_executor/lib/statusor.h" - -#define OMPI_SKIP_MPICXX -#include "third_party/mpi/mpi.h" -#include "tensorflow/contrib/mpi_collectives/kernels/ring.h" -#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h" - -/* - * MPI Allreduce and Allgather Ops for TensorFlow. - * - * TensorFlow natively provides inter-device communication through send and - * receive ops and inter-node communication through Distributed TensorFlow, - * based on the same send and receive abstractions. These end up being - * insufficient for synchronous data-parallel training on HPC clusters where - * Infiniband or other high-speed interconnects are available. This module - * implements MPI ops for allgather and allreduce, which do bandwidth-optimal - * gathers and reductions and can take advantage of hardware-optimized - * communication libraries through the MPI implementation. - * - * The primary logic of the allreduce and allgather are in RingAllgather() and - * RingAllreduce(). The background thread which facilitates MPI operations is - * run in BackgroundThreadLoop(). The provided MPI ops are: - * – MPIInit: - * Initialize MPI on a given device (CPU or GPU). - * Should only be run on a single device in every process. - * – MPISize: - * Get the number of MPI processes in the global communicator. - * – MPIRank: - * Get the rank of the current MPI process in the global communicator. - * – MPILocalRank: - * Get the local rank of the current MPI process within its node. - * – MPIAllreduce: - * Perform an allreduce on a Tensor, returning the sum - * across all MPI processes in the global communicator. - * – MPIAllgather: - * Perform an allgather on a Tensor, returning the concatenation of - * the tensor on the first dimension across all MPI processes in the - * global communicator. - * - */ - -template -using StatusOr = stream_executor::port::StatusOr; - -using CPUDevice = Eigen::ThreadPoolDevice; -using GPUDevice = Eigen::GpuDevice; - -namespace tensorflow { -namespace contrib { -namespace mpi_collectives { - -// Make sure template specializations are generated in the ring.cu.cc and the -// ring.cc file, not in this file. 
-extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, Tensor*, - Tensor*); -extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, - Tensor*, Tensor*); -extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, Tensor*, - Tensor*); -extern template Status RingAllgather(OpKernelContext*, - const Tensor*, - const std::vector&, - Tensor*); -extern template Status RingAllgather( - OpKernelContext*, const Tensor*, const std::vector&, Tensor*); -extern template Status RingAllgather( - OpKernelContext*, const Tensor*, const std::vector&, Tensor*); -extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, Tensor*, - Tensor*); -extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, - Tensor*, Tensor*); -extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, Tensor*, - Tensor*); -extern template Status RingAllgather(OpKernelContext*, - const Tensor*, - const std::vector&, - Tensor*); -extern template Status RingAllgather( - OpKernelContext*, const Tensor*, const std::vector&, Tensor*); -extern template Status RingAllgather( - OpKernelContext*, const Tensor*, const std::vector&, Tensor*); - -namespace { - -// Return true if the templated type is GPUDevice, otherwise false. -template -bool IsGPUDevice(); -template <> -bool IsGPUDevice() { - return true; -}; -template <> -bool IsGPUDevice() { - return false; -}; - -// A callback to call after the MPI communication completes. Since the -// allreduce and allgather ops are asynchronous, this callback is what resumes -// computation after the reduction is completed. -typedef std::function)> CommunicationDoneCallback; - -struct CollectiveOpRecord { - // The rank performing this piece of the op - int rank; - - // The name of the op/tensor to be reduced - std::string name; - - // The op's kernel context - OpKernelContext* context; - - // Data type of the op - DataType dtype; - - // The input tensor - const Tensor* in_t; - - // Allgather: Vector of per-rank first-dimension sizes - std::vector sizes_vec; - - // The temp tensor for intermediate results - Tensor temp_t; - - // The output tensor - Tensor* out_t; - - // Whether to run this op on the gpu - bool on_gpu; - - // The callback to call after the op has completed - CommunicationDoneCallback callback; -}; - -// Table storing Tensors to be reduced, keyed by unique name. -// This table contains everything necessary to do the reduction -typedef std::unordered_map TensorTable; - -// Table for storing Tensor metadata on rank zero. This is used for error -// checking and size calculations, as well as determining when a reduction is -// ready to be done (when all nodes are ready to do it). -typedef std::unordered_map > MessageTable; - -// The global state required for the MPI ops. -// -// MPI is a library that stores a lot of global per-program state and often -// requires running on a single thread. As a result, we have to have a single -// background thread responsible for all MPI operations, and communicate with -// that background thread through global state. -struct MPIGlobalState { - // An atomic boolean which is set to true when MPI is initialized. - // This ensures that MPI_Init is never called twice. - std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT; - - // Condition variable to wait for initialization - condition_variable cv; - - // Whether MPI_Init has been completed on the background thread. - bool initialization_done = false; - - // Whether MPI_Init succeeded on the background thread. 
- Status init_status; - - // A mutex that needs to be used whenever MPI operations touch - // shared structures. - mutex mu; - - // Tensors waiting to be allreduced or allgathered. - TensorTable tensor_table; - - // Queue of MPI requests waiting to be sent to the coordinator node. - std::queue message_queue; - - // Background thread running MPI communication. - std::thread background_thread; - - // Whether the background thread should shutdown. - bool shut_down = false; - - // Only exists on the coordinator node (rank zero). Maintains a count of - // how many nodes are ready to allreduce every tensor (keyed by tensor - // name). - std::unique_ptr message_table; - - // The MPI rank, local rank, and size. - int rank = 0; - int local_rank = 0; - int size = 1; - - // The device that MPI was initialized on. (-1 for no GPU) - int device = -1; - - // The CUDA stream used for data transfers and within-allreduce operations. - // A naive implementation would use the TensorFlow StreamExecutor CUDA - // stream. However, the allreduce and allgather require doing memory copies - // and kernel executions (for accumulation of values on the GPU). However, - // the subsequent operations must wait for those operations to complete, - // otherwise MPI (which uses its own stream internally) will begin the data - // transfers before the CUDA calls are complete. In order to wait for those - // CUDA operations, if we were using the TensorFlow stream, we would have - // to synchronize that stream; however, other TensorFlow threads may be - // submitting more work to that stream, so synchronizing on it can cause - // the allreduce to be delayed, waiting for compute totally unrelated to it - // in other parts of the graph. Overlaying memory transfers and compute - // during backpropagation is crucial for good performance, so we cannot use - // the TensorFlow stream, and must use our own stream. -#if GOOGLE_CUDA - cudaStream_t stream; - std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT; -#endif - - ~MPIGlobalState() { - // Make sure that the destructor of the background thread is safe to - // call. If a thread is still joinable (not detached or complete) its - // destructor cannot be called. - if (background_thread.joinable()) { - shut_down = true; - background_thread.join(); - } - } -}; - -// All the MPI state that must be stored globally per-process. -static MPIGlobalState mpi_global; - -// For clarify in argument lists. -#define RANK_ZERO 0 - -// A tag used for all coordinator messaging. -#define TAG_NOTIFY 1 - -// Store the MPIRequest for a name, and return whether the total count of -// MPIRequests for that tensor is now equal to the MPI size (and thus we are -// ready to reduce the tensor). -bool IncrementTensorCount(std::unique_ptr& message_table, - MPIRequest msg, int mpi_size) { - auto name = msg.tensor_name(); - auto table_iter = message_table->find(name); - if (table_iter == message_table->end()) { - message_table->emplace(name, std::vector({msg})); - table_iter = message_table->find(name); - } else { - table_iter->second.push_back(msg); - } - - int count = table_iter->second.size(); - return count == mpi_size; -} - -// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse -// instructing all ranks to start the reduction to all ranks. The MPIResponse -// also contains error messages in case the submitted MPIRequests were not -// valid (for example, contained mismatched shapes or types). -// -// Constructing the MPIResponse, thus, requires a whole lot of error checking. 
-MPIResponse ConstructMPIResponse(std::unique_ptr& message_table, - std::string name) { - bool error = false; - auto it = message_table->find(name); - assert(it != message_table->end()); - - std::vector requests = it->second; - assert(requests.size() > 0); - - std::ostringstream error_message_stream; - - // Check that all data types being reduced or gathered are identical - auto data_type = requests[0].tensor_type(); - for (unsigned int i = 1; i < requests.size(); i++) { - auto request_type = requests[i].tensor_type(); - if (data_type != request_type) { - error = true; - error_message_stream << "Mismatched data types: One rank had type " - << DataType_Name(data_type) - << ", but another rank had type " - << DataType_Name(request_type) << "."; - break; - } - } - - // Check that all requested operations are the same - auto message_type = requests[0].request_type(); - for (unsigned int i = 1; i < requests.size(); i++) { - if (error) { - break; - } - - auto request_type = requests[i].request_type(); - if (message_type != request_type) { - error = true; - error_message_stream << "Mismatched MPI operations: One rank did an " - << message_type << ", but another rank did an " - << request_type << "."; - break; - } - } - - // If we are doing an allreduce, check that all tensor shapes - // are identical - if (message_type == MPIRequest::ALLREDUCE) { - TensorShape tensor_shape = requests[0].tensor_shape(); - for (unsigned int i = 1; i < requests.size(); i++) { - if (error) { - break; - } - - TensorShape request_shape = requests[i].tensor_shape(); - if (tensor_shape != request_shape) { - error = true; - error_message_stream << "Mismatched allreduce tensor shapes: " - << "One rank reduced a tensor of shape " - << tensor_shape.DebugString() - << ", but another rank sent a tensor of shape " - << request_shape.DebugString() << "."; - break; - } - } - } - - // If we are doing an allgather, make sure all but the first dimension are - // the same. The first dimension may be different and the output tensor is - // the sum of the first dimension. Collect the sizes by rank. 
- if (message_type == MPIRequest::ALLGATHER) { - TensorShape tensor_shape = requests[0].tensor_shape(); - - if (tensor_shape.dims() == 0) { - error = true; - error_message_stream << "Rank zero tried to gather a rank-zero tensor."; - } - - for (unsigned int i = 1; i < requests.size(); i++) { - if (error) { - break; - } - - TensorShape request_shape = requests[i].tensor_shape(); - if (tensor_shape.dims() != request_shape.dims()) { - error = true; - error_message_stream << "Mismatched allgather tensor shapes: " - << "One rank gathered a tensor of rank " - << tensor_shape.dims() - << ", but another rank sent a tensor of rank " - << request_shape.dims() << "."; - break; - } - - for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) { - if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) { - error = true; - error_message_stream - << "Mismatched allgather tensor shapes: " - << "One rank gathered a tensor with dimension " << dim - << " equal to " << tensor_shape.dim_size(dim) - << ", but another rank sent a tensor with dimension " << dim - << " equal to " << request_shape.dim_size(dim) << "."; - break; - } - } - } - } - - MPIResponse response; - response.set_tensor_name(name); - if (error) { - std::string error_message = error_message_stream.str(); - response.set_response_type(MPIResponse::ERROR); - response.set_error_message(error_message); - } else { - auto response_type = MPIResponse::ERROR; - if (message_type == MPIRequest::ALLREDUCE) { - response_type = MPIResponse::ALLREDUCE; - } else { - response_type = MPIResponse::ALLGATHER; - } - response.set_response_type(response_type); - } - - // Clear all queued up requests for this name. They are now taken care of - // by the constructed MPI response. - message_table->erase(it); - - return response; -} - -// Process an MPIResponse by doing a reduction, a gather, or raising an error. -void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) { - OpKernelContext* context; - const Tensor* input_tensor; - std::vector sizes_vec; - Tensor temp_tensor; - Tensor* output_tensor; - CommunicationDoneCallback callback; - bool on_gpu; - { - // Lock on the tensor table. - mutex_lock guard(mpi_global.mu); - - // We should never fail at finding this key in the tensor table. - auto name = response.tensor_name(); - auto iter = tensor_table.find(name); - assert(iter != tensor_table.end()); - - assert(response.response_type() == MPIResponse::ALLREDUCE || - response.response_type() == MPIResponse::ALLGATHER || - response.response_type() == MPIResponse::ERROR); - - CollectiveOpRecord record = iter->second; - context = record.context; - input_tensor = record.in_t; - sizes_vec = record.sizes_vec; - temp_tensor = record.temp_t; - output_tensor = record.out_t; - on_gpu = record.on_gpu; - callback = record.callback; - - // Clear the tensor table of this tensor and its callbacks; the rest of - // this function takes care of it. - tensor_table.erase(iter); - } - - // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't - // link to non-existent symbols. -#if GOOGLE_CUDA -#define GPU_DEVICE_IF_CUDA GPUDevice -#else -#define GPU_DEVICE_IF_CUDA CPUDevice -#endif - - Status status; - auto dtype = input_tensor->dtype(); - if (response.response_type() == MPIResponse::ALLGATHER) { - if (dtype == DT_FLOAT) { - status = on_gpu ? RingAllgather( - context, input_tensor, sizes_vec, output_tensor) - : RingAllgather( - context, input_tensor, sizes_vec, output_tensor); - } else if (dtype == DT_INT32) { - status = on_gpu ? 
RingAllgather( - context, input_tensor, sizes_vec, output_tensor) - : RingAllgather(context, input_tensor, - sizes_vec, output_tensor); - } else if (dtype == DT_INT64) { - status = on_gpu ? RingAllgather( - context, input_tensor, sizes_vec, output_tensor) - : RingAllgather( - context, input_tensor, sizes_vec, output_tensor); - } else { - status = errors::Unknown("Invalid tensor type for MPI allgather."); - } - } else if (response.response_type() == MPIResponse::ALLREDUCE) { - if (dtype == DT_FLOAT) { - status = on_gpu ? RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor) - : RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor); - } else if (dtype == DT_INT32) { - status = on_gpu ? RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor) - : RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor); - } else if (dtype == DT_INT64) { - status = on_gpu ? RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor) - : RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor); - } else { - status = errors::Unknown("Invalid tensor type for MPI allreduce."); - } - } else if (response.response_type() == MPIResponse::ERROR) { - status = errors::FailedPrecondition(response.error_message()); - } - - if (status.ok()) { - callback(StatusOr(*output_tensor)); - } else { - callback(StatusOr(status)); - } -} - -// The MPI background thread loop coordinates all the MPI processes and the -// tensor reductions. The design of the communicator mechanism is limited by a -// few considerations: -// -// 1. Some MPI implementations require all MPI calls to happen from a -// single thread. Since TensorFlow may use several threads for graph -// processing, this means we must have our own dedicated thread for -// dealing with MPI. -// 2. We want to gracefully handle errors, when MPI processes do not -// properly agree upon what should happen (such as mismatched types or -// shapes). To do so requires the MPI processes to know about the shapes -// and types of the relevant tensors on the other processes. -// 3. The MPI reductions and gathers should be able to happen in parallel -// with other ongoing operations. Since MPI uses an internal -// (inaccessible) GPU stream separate from the TF GPUDevice streams, we -// cannot explicitly synchronize memcpys or kernels with it. As a result, -// MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper -// ordering of memcpys and kernels with respect to TF streams. -// 4. NOTE: We cannot guarantee that all the MPI processes reduce their -// tensors in the same order. Thus, there must be a way to ensure the -// reduction memcpys and kernels occur for correct tensors across all -// ranks at the same time. We choose to use a coordinator (rank ID 0) to -// gather and trigger the reduction operations that are ready to execute. -// -// The coordinator currently follows a master-worker paradigm. Rank zero acts -// as the master (the "coordinator"), whereas all other ranks are simply -// workers. Each rank runs its own background thread which progresses in ticks. -// In each tick, the following actions happen: -// -// a) The workers send any available MPIRequests to the coordinator. These -// MPIRequests indicate what the worker would like to do (i.e. which -// tensor they would like to gather or reduce, as well as their shape and -// type). They repeat this for every tensor that they would like to -// operate on after that tensor's collective op has executed ComputeAsync. 
-// -// b) The workers send an empty "DONE" message to the coordinator to -// indicate that there are no more tensors they wish to operate on. -// -// c) The coordinator receives the MPIRequests from the workers, as well -// as from its own TensorFlow ops, and stores them in a request table. The -// coordinator continues to receive MPIRequest messages until it has -// received MPI_SIZE number of empty "DONE" messages. -// -// d) The coordinator finds all tensors that are ready to be reduced, -// gathered, or all operations that result in an error. For each of those, -// it sends an MPIResponse to all the workers. When no more MPIResponses -// are available, it sends a "DONE" response to the workers. If the -// process is being shutdown, it instead sends a "SHUTDOWN" response. -// -// e) The workers listen for MPIResponse messages, processing each one by -// doing the required reduce or gather, until they receive a "DONE" -// response from the coordinator. At that point, the tick ends. -// If instead of "DONE" they receive "SHUTDOWN", they exit their -// background loop. -// TODO: Use the global mpi_global state variable instead of a local one -void BackgroundThreadLoop() { -#if GOOGLE_CUDA - // Set the device, so that this thread uses the same GPU context as the - // calling thread. - // TODO: Ensure that this is operating correctly. The background thread - // needs to be able to control all GPUs that the rank has access to, and - // might be more than 1 GPU. Tensors could be resident in any of the - // GPUs, so the background thread's accumulate and copy kernels might need - // to correctly set the device and it might be necessary for the background - // thread to manage multiple streams. - cudaSetDevice(mpi_global.device); - cudaStreamCreate(&mpi_global.stream); -#endif - - // Initialize MPI. This must happen on the background thread, since not all - // MPI implementations support being called from multiple threads. - auto init_result = MPI_Init(NULL, NULL); - if (init_result != MPI_SUCCESS) { - mpi_global.init_status = - errors::Unknown("Could not initialize MPI; MPI_Init() failed."); - mpi_global.initialization_done = true; - mpi_global.cv.notify_all(); - return; - } else { - mpi_global.init_status = Status::OK(); - } - - // Get MPI rank to determine if we are rank zero. - int rank; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - bool is_coordinator = rank == 0; - - // Get MPI size to determine how many tensors to wait for before reducing. - int size; - MPI_Comm_size(MPI_COMM_WORLD, &size); - - // Determine local rank by querying the local communicator. - MPI_Comm local_comm; - MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, - &local_comm); - int local_rank; - MPI_Comm_rank(local_comm, &local_rank); - - mpi_global.rank = rank; - mpi_global.local_rank = local_rank; - mpi_global.size = size; - mpi_global.initialization_done = true; - - // Notify calling thread that initialization is complete - mpi_global.cv.notify_all(); - - // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD! - // Initialize the tensor count table. No tensors are available yet. - if (is_coordinator) { - mpi_global.message_table = - std::unique_ptr(new MessageTable()); - } - - // The coordinator sends a SHUTDOWN message to trigger shutdown. - bool should_shut_down = false; - do { - // TODO: Eliminate the need for thread sleep by making all activity - // depend on other activity (e.g. condition or MPI waits). 
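// A standalone sketch of the zero-length "DONE" handshake from steps (a)-(e)
// above, stripped of the TensorFlow op machinery. The tag value mirrors
// TAG_NOTIFY; everything else (including the suggested binary name) is
// illustrative. Run with, e.g., `mpirun -np 4 ./done_handshake`.
#include <mpi.h>
#include <cstdio>

static const int kTagNotify = 1;

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank, size;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  if (rank == 0) {
    // Coordinator: a zero-length message counts as one rank's "DONE".
    int completed_ranks = 1;  // the coordinator itself is already done
    while (completed_ranks != size) {
      MPI_Status status;
      MPI_Probe(MPI_ANY_SOURCE, kTagNotify, MPI_COMM_WORLD, &status);
      int msg_length;
      MPI_Get_count(&status, MPI_BYTE, &msg_length);
      MPI_Recv(nullptr, 0, MPI_BYTE, status.MPI_SOURCE, kTagNotify,
               MPI_COMM_WORLD, MPI_STATUS_IGNORE);
      if (msg_length == 0) ++completed_ranks;
    }
    std::printf("coordinator: all %d ranks reported DONE\n", size);
  } else {
    // Worker: signal that it has nothing more to enqueue this tick.
    MPI_Send(nullptr, 0, MPI_BYTE, 0, kTagNotify, MPI_COMM_WORLD);
  }

  MPI_Finalize();
  return 0;
}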
- std::this_thread::sleep_for(std::chrono::milliseconds(1)); - - // Copy the data structures from global state under this lock. - // However, don't keep the lock for the rest of the loop, so that - // enqueued stream callbacks can continue. - std::queue message_queue; - { - mutex_lock guard(mpi_global.mu); - while (!mpi_global.message_queue.empty()) { - MPIRequest message = mpi_global.message_queue.front(); - mpi_global.message_queue.pop(); - message_queue.push(message); - } - } - - // Collect all tensors that are ready to be reduced. Record them in the - // tensor count table (rank zero) or send them to rank zero to be - // recorded (everyone else). - std::vector ready_to_reduce; - while (!message_queue.empty()) { - // Pop the first available message message - MPIRequest message = message_queue.front(); - message_queue.pop(); - - if (is_coordinator) { - bool reduce = - IncrementTensorCount(mpi_global.message_table, message, size); - if (reduce) { - ready_to_reduce.push_back(message.tensor_name()); - } - } else { - std::string encoded_message; - message.SerializeToString(&encoded_message); - MPI_Send(encoded_message.c_str(), encoded_message.length() + 1, - MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD); - } - } - - // Rank zero has put all its own tensors in the tensor count table. - // Now, it should count all the tensors that are coming from other - // ranks at this tick. It should keep getting tensors until it gets a - // DONE message from all the other ranks. - if (is_coordinator) { - // Count of DONE messages. Keep receiving messages until the number - // of messages is equal to the number of processes. Initialize to - // one since the coordinator is effectively done. - int completed_ranks = 1; - while (completed_ranks != size) { - MPI_Status status; - MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status); - - // Find number of characters in message (including zero byte). - int source_rank = status.MPI_SOURCE; - int msg_length; - MPI_Get_count(&status, MPI_BYTE, &msg_length); - - // If the length is zero, this is a DONE message. - if (msg_length == 0) { - completed_ranks++; - MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD, - &status); - continue; - } - - // Get tensor name from MPI into an std::string. - char* buffer = new char[msg_length]; - MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY, - MPI_COMM_WORLD, &status); - std::string received_data(buffer); - delete[] buffer; - - MPIRequest received_message; - received_message.ParseFromString(received_data); - auto received_name = received_message.tensor_name(); - - bool reduce = IncrementTensorCount(mpi_global.message_table, - received_message, size); - if (reduce) { - ready_to_reduce.push_back(received_name); - } - } - - // At this point, rank zero should have a fully updated tensor - // count table and should know all the tensors that need to be - // reduced or gathered, and everyone else should have sent all - // their information to rank zero. We can now do reductions and - // gathers; rank zero will choose which ones and in what order, - // and will notify the other ranks before doing each reduction. 
- for (int i = 0; i < ready_to_reduce.size(); i++) { - // Notify all nodes which tensor we'd like to reduce now - auto name = ready_to_reduce[i]; - MPIResponse response = - ConstructMPIResponse(mpi_global.message_table, name); - - std::string encoded_response; - response.SerializeToString(&encoded_response); - for (int r = 1; r < size; r++) { - MPI_Send(encoded_response.c_str(), encoded_response.length() + 1, - MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD); - } - - // Perform the reduction. All nodes should end up performing - // the same reduction. - PerformCollectiveOp(mpi_global.tensor_table, response); - } - - // Notify all nodes that we are done with the reductions for this - // tick. - MPIResponse done_response; - should_shut_down = mpi_global.shut_down; - done_response.set_response_type( - mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE); - std::string encoded_response; - done_response.SerializeToString(&encoded_response); - for (int r = 1; r < size; r++) { - MPI_Send(encoded_response.c_str(), encoded_response.length() + 1, - MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD); - } - } else { - // Notify the coordinator that this node is done sending messages. - // A DONE message is encoded as a zero-length message. - MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD); - - // Receive names for tensors to reduce from rank zero. Once we - // receive a empty DONE message, stop waiting for more names. - while (true) { - MPI_Status status; - MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status); - - // Find number of characters in message (including zero byte). - int msg_length; - MPI_Get_count(&status, MPI_BYTE, &msg_length); - - // Get tensor name from MPI into an std::string. - char* buffer = new char[msg_length]; - MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD, - &status); - std::string received_message(buffer); - delete[] buffer; - - MPIResponse response; - response.ParseFromString(received_message); - if (response.response_type() == MPIResponse::DONE) { - // No more messages this tick - break; - } else if (response.response_type() == MPIResponse::SHUTDOWN) { - // No more messages this tick, and the background thread - // should shut down - should_shut_down = true; - break; - } else { - // Process the current message - PerformCollectiveOp(mpi_global.tensor_table, response); - } - } - } - } while (!should_shut_down); - - MPI_Finalize(); -} - -// Initialize MPI and start the MPI background thread. Ensure that this is -// only done once no matter how many times this function is called. -Status InitializeMPIOnce(bool gpu) { - // Ensure MPI is only initialized once. - if (mpi_global.initialized_flag.test_and_set()) return mpi_global.init_status; - - mpi_global.device = -1; -#if GOOGLE_CUDA - if (gpu) { - cudaGetDevice(&mpi_global.device); - } -#endif - - // Start the MPI background thread, which assumes MPI is initialized - // TODO: Change this to a Tensorflow thread - mpi_global.background_thread = std::thread(BackgroundThreadLoop); - - // Wait to ensure that the background thread has finished initializing MPI - mutex_lock guard(mpi_global.mu); - mpi_global.cv.wait(guard); - if (!mpi_global.initialization_done) { - mpi_global.init_status = - errors::Unknown("Failed to wait for MPI initialization."); - } - - return mpi_global.init_status; -} - -// Check that MPI is initialized. 
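// A minimal, MPI-free sketch of the initialize-once pattern used above: an
// atomic flag guarantees a single initialization, a background thread does
// the one-time work, and a condition variable lets the caller wait for it.
// All names here are illustrative; this is not the contrib code itself, and
// it waits on a predicate as a standard idiom for this pattern.
#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

struct GlobalStateSketch {
  std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;
  std::mutex mu;
  std::condition_variable cv;
  bool initialization_done = false;
  std::thread background_thread;
} g_state;

void BackgroundInit() {
  // ... expensive one-time setup would go here (e.g. MPI_Init) ...
  std::lock_guard<std::mutex> guard(g_state.mu);
  g_state.initialization_done = true;
  g_state.cv.notify_all();
}

void InitializeOnce() {
  // test_and_set returns the previous value, so only the first caller
  // proceeds past this line.
  if (g_state.initialized_flag.test_and_set()) return;
  g_state.background_thread = std::thread(BackgroundInit);

  // Wait with a predicate so a missed notification cannot hang the caller.
  std::unique_lock<std::mutex> guard(g_state.mu);
  g_state.cv.wait(guard, [] { return g_state.initialization_done; });
}

int main() {
  InitializeOnce();
  std::printf("initialized\n");
  g_state.background_thread.join();
  return 0;
}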
-Status IsMPIInitialized() { - if (!mpi_global.initialization_done) { - return errors::FailedPrecondition( - "MPI has not been initialized; use tf.contrib.mpi.Session."); - } - return Status::OK(); -} - -// This function (called from the callback set up in MPIAll*Op::ComputeAsync) -// only adds the op's record into the local op queue (to track the op's -// progress), and sends a message to the coordinator indicating that this rank -// is ready to begin. The MPI background thread will handle the MPI message. -void EnqueueTensorCollective(CollectiveOpRecord record, - MPIRequest::RequestType rtype) { - const Tensor* input_tensor = record.in_t; - MPIRequest message; - message.set_request_rank(record.rank); - message.set_tensor_name(record.name); - message.set_tensor_type(record.dtype); - message.set_request_type(rtype); - input_tensor->shape().AsProto(message.mutable_tensor_shape()); - - mutex_lock guard(mpi_global.mu); - mpi_global.tensor_table.emplace(record.name, record); - mpi_global.message_queue.push(message); -} - -} // namespace - -#if GOOGLE_CUDA -cudaStream_t CudaStreamForMPI() { return mpi_global.stream; } -#endif - -// Op to initialize MPI in the current process. The settings used in the -// configuration are the same that must be used for all future MPI ops. -template -class MPIInitOp : public OpKernel { - public: - explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - bool on_gpu = IsGPUDevice(); - OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu)); - } -}; - -REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU), - MPIInitOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU), - MPIInitOp); -#endif - -// Op to get the current MPI Size. -template -class MPISizeOp : public OpKernel { - public: - explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - OP_REQUIRES_OK(context, IsMPIInitialized()); - - // Write integer to output tensor - Tensor* output; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({}), &output)); - - auto flat = output->flat(); - flat(0) = mpi_global.size; - } -}; - -REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU), - MPISizeOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"), - MPISizeOp); -#endif - -// Op to get the current MPI Rank. -template -class MPIRankOp : public OpKernel { - public: - explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - OP_REQUIRES_OK(context, IsMPIInitialized()); - - // Write integer to output tensor - Tensor* output; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({}), &output)); - - auto flat = output->flat(); - flat(0) = mpi_global.rank; - } -}; - -REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU), - MPIRankOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"), - MPIRankOp); -#endif - -// Op to get the current local MPI Rank. 
-template -class MPILocalRankOp : public OpKernel { - public: - explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - OP_REQUIRES_OK(context, IsMPIInitialized()); - - // Write integer to output tensor - Tensor* output; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({}), &output)); - - auto flat = output->flat(); - flat(0) = mpi_global.local_rank; - } -}; - -REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU), - MPILocalRankOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER( - Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"), - MPILocalRankOp); -#endif - -template -class MPIAllreduceOp : public AsyncOpKernel { - public: - explicit MPIAllreduceOp(OpKernelConstruction* context) - : AsyncOpKernel(context) {} - - // Although this op is handled asynchronously, the ComputeAsync call is - // very inexpensive. It only sets up a CollectiveOpRecord and places it - // in the table for the background thread to handle. Thus, we do not need - // a TF pool thread to perform the op. - bool IsExpensive() override { return false; } - - void ComputeAsync(OpKernelContext* context, DoneCallback done) override { - OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done); - const Tensor* input_tensor = &context->input(0); - Tensor* output_tensor; - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_output(0, input_tensor->shape(), &output_tensor), - done); - - // Record allocated on stack so op can fail without memory leak - CollectiveOpRecord record; - record.name = name(); - record.context = context; - record.in_t = input_tensor; - record.out_t = output_tensor; - record.on_gpu = IsGPUDevice(); - record.dtype = input_tensor->dtype(); - - const size_t temp_size = - (input_tensor->NumElements() + mpi_global.size - 1) / mpi_global.size; - TensorShape temp_shape; - temp_shape.AddDim(temp_size); - OP_REQUIRES_OK_ASYNC(context, - context->allocate_temp(input_tensor->dtype(), - temp_shape, &record.temp_t), - done); - - auto allreduce_done_callback = [done, context](StatusOr status) { - context->SetStatus(status.status()); - done(); - }; - record.callback = allreduce_done_callback; - - auto allreduce_launch_callback = [record] { - EnqueueTensorCollective(record, MPIRequest::ALLREDUCE); - }; - - // If we are on a CPU, our device context will be null and we can't - // get a stream to enqueue this on. On a CPU this op is called when the - // data is already available, so we can just immediately do the - // allreduce; we don't have to wait for the data to get populated. -#if GOOGLE_CUDA - auto device_context = context->op_device_context(); - if (device_context == nullptr) { - allreduce_launch_callback(); - } else { - auto stream = device_context->stream(); - stream->ThenDoHostCallback(allreduce_launch_callback); - } -#else - allreduce_launch_callback(); -#endif - } -}; - -REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU), - MPIAllreduceOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU), - MPIAllreduceOp); -#endif - -template -class MPIAllgatherOp : public AsyncOpKernel { - public: - explicit MPIAllgatherOp(OpKernelConstruction* context) - : AsyncOpKernel(context) {} - - // Although this op is handled asynchronously, the ComputeAsync call is - // very inexpensive. It only sets up a CollectiveOpRecord and places it - // in the table for the background thread to handle. Thus, we do not need - // a TF pool thread to perform the op. 
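// A small sketch of the "cheap ComputeAsync" pattern shared by the collective
// ops above: the op only enqueues a record plus a completion callback, and a
// background worker performs the heavy work and fires the callback later.
// This is plain C++ with illustrative names; it is not the TensorFlow kernel
// API, and the real ops hand records to the MPI background thread instead.
#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <queue>
#include <string>
#include <thread>

struct WorkRecord {
  std::string name;
  std::function<void()> done;  // analogous to the op's DoneCallback
};

std::queue<WorkRecord> g_queue;
std::mutex g_mu;
std::condition_variable g_cv;
bool g_shutdown = false;

// Analogous to ComputeAsync: constant-time work on the calling thread.
void EnqueueWork(WorkRecord record) {
  std::lock_guard<std::mutex> guard(g_mu);
  g_queue.push(std::move(record));
  g_cv.notify_one();
}

void WorkerLoop() {
  while (true) {
    WorkRecord record;
    {
      std::unique_lock<std::mutex> guard(g_mu);
      g_cv.wait(guard, [] { return g_shutdown || !g_queue.empty(); });
      if (g_queue.empty()) return;  // shutdown requested, queue drained
      record = std::move(g_queue.front());
      g_queue.pop();
    }
    // ... the actual collective (reduce or gather) would run here ...
    record.done();
  }
}

int main() {
  std::thread worker(WorkerLoop);
  EnqueueWork({"tensor_a", [] { std::printf("tensor_a done\n"); }});
  EnqueueWork({"tensor_b", [] { std::printf("tensor_b done\n"); }});
  {
    std::lock_guard<std::mutex> guard(g_mu);
    g_shutdown = true;
  }
  g_cv.notify_one();
  worker.join();
  return 0;
}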
- bool IsExpensive() override { return false; } - - void ComputeAsync(OpKernelContext* context, DoneCallback done) override { - OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done); - const Tensor* input_tensor = &context->input(0); - const Tensor* sizing_tensor = &context->input(1); - - // Record allocated on stack so op can fail without memory leak - CollectiveOpRecord record; - record.name = name(); - record.context = context; - record.in_t = input_tensor; - record.on_gpu = IsGPUDevice(); - - // Construct the output size from the sizing tensor - size_t output_first_dim = 0; - if (sizing_tensor->shape().dims() == 0) { - // 0-dim sizing_tensor implies that the op is just gathering - // a single element from each rank - output_first_dim = mpi_global.size; - for (int i = 0; i < mpi_global.size; i++) { - record.sizes_vec.push_back(1); - } - } else { - // Collect the total output tensor sizing from the sizing tensor - // NOTE: The sizing tensor is forced to be placed on the CPU by - // declaring the input as HostMemory, so it is valid to read it here. - const int64* sizing_array = - (const int64*)sizing_tensor->tensor_data().data(); - for (int i = 0; i < mpi_global.size; i++) { - record.sizes_vec.push_back(sizing_array[i]); - output_first_dim += sizing_array[i]; - } - } - - TensorShape output_shape; - output_shape.AddDim(output_first_dim); - for (int i = 1; i < input_tensor->shape().dims(); i++) { - output_shape.AddDim(input_tensor->shape().dim_size(i)); - } - - Tensor* output_tensor; - OP_REQUIRES_OK_ASYNC( - context, context->allocate_output(0, output_shape, &output_tensor), - done); - - record.out_t = output_tensor; - record.dtype = input_tensor->dtype(); - - auto allgather_done_callback = [done, context](StatusOr status) { - context->SetStatus(status.status()); - done(); - }; - record.callback = allgather_done_callback; - - auto allgather_launch_callback = [record] { - EnqueueTensorCollective(record, MPIRequest::ALLGATHER); - }; - - // If we are on a CPU, our device context will be null and we can't - // get a stream to enqueue this on. On a CPU this op is called when the - // data is already available, so we can just immediately do the - // allgather; we don't have to wait for the data to get populated. -#if GOOGLE_CUDA - auto device_context = context->op_device_context(); - if (device_context == nullptr) { - allgather_launch_callback(); - } else { - auto stream = device_context->stream(); - stream->ThenDoHostCallback(allgather_launch_callback); - } -#else - allgather_launch_callback(); -#endif - } -}; - -REGISTER_KERNEL_BUILDER( - Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"), - MPIAllgatherOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER( - Name("MPIAllgather").Device(DEVICE_GPU).HostMemory("sizes"), - MPIAllgatherOp); -#endif - -} // namespace mpi_collectives -} // namespace contrib -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/kernels/ring.cc b/tensorflow/contrib/mpi_collectives/kernels/ring.cc deleted file mode 100644 index 8970ceb1a20..00000000000 --- a/tensorflow/contrib/mpi_collectives/kernels/ring.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifdef TENSORFLOW_USE_MPI - -#define EIGEN_USE_THREADS - -#include "tensorflow/contrib/mpi_collectives/kernels/ring.h" - -namespace tensorflow { -namespace contrib { -namespace mpi_collectives { - -using CPUDevice = Eigen::ThreadPoolDevice; - -extern template MPI_Datatype MPIType(); -extern template MPI_Datatype MPIType(); -extern template MPI_Datatype MPIType(); -extern template DataType TensorFlowDataType(); -extern template DataType TensorFlowDataType(); -extern template DataType TensorFlowDataType(); - -// Generate all necessary specializations for RingAllreduce. -template Status RingAllreduce(OpKernelContext*, const Tensor*, - Tensor*, Tensor*); -template Status RingAllreduce(OpKernelContext*, - const Tensor*, Tensor*, - Tensor*); -template Status RingAllreduce(OpKernelContext*, const Tensor*, - Tensor*, Tensor*); - -// Generate all necessary specializations for RingAllgather. -template Status RingAllgather(OpKernelContext*, const Tensor*, - const std::vector&, - Tensor*); -template Status RingAllgather(OpKernelContext*, - const Tensor*, - const std::vector&, - Tensor*); -template Status RingAllgather(OpKernelContext*, const Tensor*, - const std::vector&, - Tensor*); - -// Copy data on a CPU using a straight-forward memcpy. -template <> -void CopyTensorData(void* dst, void* src, size_t size) { - std::memcpy(dst, src, size); -}; - -// Accumulate values on a CPU. -#define GENERATE_ACCUMULATE(type) \ - template <> \ - void AccumulateTensorData(type * dst, type * src, \ - size_t size) { \ - for (unsigned int i = 0; i < size; i++) { \ - dst[i] += src[i]; \ - } \ - }; -GENERATE_ACCUMULATE(int); -GENERATE_ACCUMULATE(long long); -GENERATE_ACCUMULATE(float); -#undef GENERATE_ACCUMULATE - -} // namespace mpi_collectives -} // namespace contrib -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc b/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc deleted file mode 100644 index 572e19cb904..00000000000 --- a/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifdef TENSORFLOW_USE_MPI - -#if GOOGLE_CUDA - -#define EIGEN_USE_GPU - -#include "tensorflow/contrib/mpi_collectives/kernels/ring.h" -#include "tensorflow/core/util/gpu_kernel_helper.h" -#include "tensorflow/core/util/gpu_launch_config.h" - -namespace tensorflow { -namespace contrib { -namespace mpi_collectives { - -using CPUDevice = Eigen::ThreadPoolDevice; - -template <> -MPI_Datatype MPIType() { - return MPI_FLOAT; -}; -template <> -MPI_Datatype MPIType() { - return MPI_INT; -}; -template <> -MPI_Datatype MPIType() { - return MPI_LONG_LONG; -}; - -template <> -DataType TensorFlowDataType() { - return DT_FLOAT; -}; -template <> -DataType TensorFlowDataType() { - return DT_INT32; -}; -template <> -DataType TensorFlowDataType() { - return DT_INT64; -}; - -// Generate all necessary specializations for RingAllreduce. -template Status RingAllreduce(OpKernelContext*, const Tensor*, - Tensor*, Tensor*); -template Status RingAllreduce(OpKernelContext*, - const Tensor*, Tensor*, - Tensor*); -template Status RingAllreduce(OpKernelContext*, const Tensor*, - Tensor*, Tensor*); - -// Generate all necessary specializations for RingAllgather. -template Status RingAllgather(OpKernelContext*, const Tensor*, - const std::vector&, - Tensor*); -template Status RingAllgather(OpKernelContext*, - const Tensor*, - const std::vector&, - Tensor*); -template Status RingAllgather(OpKernelContext*, const Tensor*, - const std::vector&, - Tensor*); - -// Synchronously copy data on the GPU, using a different stream than the default -// and than TensorFlow to avoid synchronizing on operations unrelated to the -// allreduce. -template <> -void CopyTensorData(void* dst, void* src, size_t size) { - auto stream = CudaStreamForMPI(); - cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream); - cudaStreamSynchronize(stream); -}; - -// Elementwise accumulation kernel for GPU. -template -__global__ void elemwise_accum(T* out, const T* in, const size_t N) { - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - out[i] += in[i]; - } -} - -// Synchronously accumulate tensors on the GPU, using a different stream than -// the default and than TensorFlow to avoid synchronizing on operations -// unrelated to the allreduce. -#define GENERATE_ACCUMULATE(type) \ - template <> \ - void AccumulateTensorData(type * dst, type * src, \ - size_t size) { \ - auto stream = CudaStreamForMPI(); \ - TF_CHECK_OK(GpuLaunchKernel(elemwise_accum, 32, 256, 0, stream, dst, \ - src, size)); \ - cudaStreamSynchronize(stream); \ - }; -GENERATE_ACCUMULATE(int); -GENERATE_ACCUMULATE(long long); -GENERATE_ACCUMULATE(float); -#undef GENERATE_ACCUMULATE - -} // namespace mpi_collectives -} // namespace contrib -} // namespace tensorflow -#endif // GOOGLE_CUDA - -#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/kernels/ring.h b/tensorflow/contrib/mpi_collectives/kernels/ring.h deleted file mode 100644 index c001615d3ff..00000000000 --- a/tensorflow/contrib/mpi_collectives/kernels/ring.h +++ /dev/null @@ -1,327 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_MPI_H_ -#define TENSORFLOW_CONTRIB_MPI_H_ - -#ifdef TENSORFLOW_USE_MPI - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/tensor_types.h" - -#if GOOGLE_CUDA -#include "cuda_runtime.h" -#endif - -// Needed to avoid header issues with C++-supporting MPI implementations -#define OMPI_SKIP_MPICXX -#include "third_party/mpi/mpi.h" - -#define TAG_TENSOR 12 - -namespace tensorflow { -namespace contrib { -namespace mpi_collectives { - -using CPUDevice = Eigen::ThreadPoolDevice; -using GPUDevice = Eigen::GpuDevice; - -// Convert from templated types to values we can pass to MPI. -template -MPI_Datatype MPIType(); - -// Convert from templated types to TensorFlow data types. -template -DataType TensorFlowDataType(); - -#define MPI_REQUIRES_OK(MPI_STATUS) \ - if ((MPI_STATUS) != MPI_SUCCESS) { \ - return errors::Unknown("MPI operation failed unexpectedly."); \ - } - -// Copy data from one tensor to another tensor. -// This uses a custom CUDA stream on GPU, which is necessary to overlay the -// backpropagation computations with the allreduce. -template -void CopyTensorData(void* destination, void* source, size_t size); - -// Add a tensor into another tensor, accumulating in place. -// This uses a custom CUDA stream on GPU, which is necessary to overlay the -// backpropagation computations with the allreduce. -template -void AccumulateTensorData(T* destination, T* source, size_t size); - -// We need to get the right stream for doing CUDA memory transfers and -// operations, which is possibly different from the standard TensorFlow stream. -#if GOOGLE_CUDA -cudaStream_t CudaStreamForMPI(); -#endif - -/* Perform a ring allreduce on the data. Allocate the necessary output tensor - * and store it in the output parameter. - * - * Assumes that all MPI processes are doing an allreduce of the same tensor, - * with the same dimensions. - * - * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the - * allreduce, the nodes involved are arranged in a ring: - * - * .--0--. - * / \ - * 3 1 - * \ / - * *--2--* - * - * Each node always sends to the next clockwise node in the ring, and receives - * from the previous one. - * - * The allreduce is done in two parts: a scatter-reduce and an allgather. In - * the scatter reduce, a reduction is done, so that each node ends up with a - * chunk of the final output tensor which has contributions from all other - * nodes. In the allgather, those chunks are distributed among all the nodes, - * so that all nodes have the entire output tensor. - * - * Both of these operations are done by dividing the input tensor into N - * evenly sized chunks (where N is the number of nodes in the ring). - * - * The scatter-reduce is done in N-1 steps. 
In the ith step, node j will send - * the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it in to - * its existing data for that chunk. For example, in the first iteration with - * the ring depicted above, you will have the following transfers: - * - * Segment 0: Node 0 --> Node 1 - * Segment 1: Node 1 --> Node 2 - * Segment 2: Node 2 --> Node 3 - * Segment 3: Node 3 --> Node 0 - * - * In the second iteration, you'll have the following transfers: - * - * Segment 0: Node 1 --> Node 2 - * Segment 1: Node 2 --> Node 3 - * Segment 2: Node 3 --> Node 0 - * Segment 3: Node 0 --> Node 1 - * - * After this iteration, Node 2 has 3 of the four contributions to Segment 0. - * The last iteration has the following transfers: - * - * Segment 0: Node 2 --> Node 3 - * Segment 1: Node 3 --> Node 0 - * Segment 2: Node 0 --> Node 1 - * Segment 3: Node 1 --> Node 2 - * - * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0 - * has the fully accumulated Segment 1; and so on. The scatter-reduce is - * complete. - * - * Next, the allgather distributes these fully accumulated chunks across all - * nodes. Communication proceeds in the same ring, once again in N-1 steps. At - * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i). - * For example, at the first iteration, the following transfers will occur: - * - * Segment 0: Node 3 --> Node 0 - * Segment 1: Node 0 --> Node 1 - * Segment 2: Node 1 --> Node 2 - * Segment 3: Node 2 --> Node 3 - * - * After the first iteration, Node 0 will have a fully accumulated Segment 0 - * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its - * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3. - * After this has continued for N - 1 iterations, all nodes will have a the - * fully accumulated tensor. - * - * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the - * allgather. Each send will contain K / N bytes, if there are K bytes in the - * original tensor on every node. Thus, each node sends and receives 2K(N - 1)/N - * bytes of data, and the performance of the allreduce (assuming no latency in - * connections) is constrained by the slowest interconnect between the nodes. 
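// A self-contained MPI demonstration of the ring allreduce just described.
// For simplicity it assumes the element count divides evenly by the number of
// ranks, whereas the kernel that follows spreads the remainder over the first
// segments. As a rough check of the 2K(N-1)/N figure: with K = 1 GB per rank
// and N = 4 ranks, each rank sends (and receives) about 1.5 GB in total.
// Compile with an MPI toolchain, e.g. `mpicxx ring_demo.cc` (name illustrative).
#include <mpi.h>
#include <cstdio>
#include <vector>

void RingAllreduceDemo(std::vector<float>& data) {
  int n, r;
  MPI_Comm_size(MPI_COMM_WORLD, &n);
  MPI_Comm_rank(MPI_COMM_WORLD, &r);

  const int seg = static_cast<int>(data.size()) / n;
  std::vector<float> recv_buf(seg);
  const int send_to = (r + 1) % n;
  const int recv_from = (r - 1 + n) % n;

  // Scatter-reduce: after N-1 steps each rank owns one fully reduced segment.
  for (int i = 0; i < n - 1; ++i) {
    const int send_seg = ((r - i) + n) % n;
    const int recv_seg = ((r - i - 1) + n) % n;
    MPI_Sendrecv(data.data() + send_seg * seg, seg, MPI_FLOAT, send_to, 0,
                 recv_buf.data(), seg, MPI_FLOAT, recv_from, 0,
                 MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    for (int j = 0; j < seg; ++j) data[recv_seg * seg + j] += recv_buf[j];
  }

  // Allgather: circulate the reduced segments so every rank has all of them.
  for (int i = 0; i < n - 1; ++i) {
    const int send_seg = ((r - i + 1) + n) % n;
    const int recv_seg = ((r - i) + n) % n;
    MPI_Sendrecv(data.data() + send_seg * seg, seg, MPI_FLOAT, send_to, 0,
                 data.data() + recv_seg * seg, seg, MPI_FLOAT, recv_from, 0,
                 MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  }
}

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int n, r;
  MPI_Comm_size(MPI_COMM_WORLD, &n);
  MPI_Comm_rank(MPI_COMM_WORLD, &r);
  // Every element on rank r starts as r+1, so the allreduced value of every
  // element should be n*(n+1)/2.
  std::vector<float> data(8 * n, static_cast<float>(r + 1));
  RingAllreduceDemo(data);
  if (r == 0) std::printf("data[0] after allreduce: %f\n", data[0]);
  MPI_Finalize();
  return 0;
}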
- * - */ -template -Status RingAllreduce(OpKernelContext* context, const Tensor* input, - Tensor* temp, Tensor* output) { - // Acquire MPI size and rank - int n, r; - MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); - MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); - - T* buffer = (T*)output->tensor_data().data(); - - CopyTensorData((void*)buffer, (void*)input->tensor_data().data(), - output->tensor_data().size()); - - // Calculate segment sizes and segment ends - const size_t elements_to_reduce = input->NumElements(); - const size_t segment_size = elements_to_reduce / n; - std::vector segment_sizes(n, segment_size); - - const size_t residual = elements_to_reduce % n; - for (size_t i = 0; i < residual; ++i) { - segment_sizes[i]++; - } - - std::vector segment_starts(n); - segment_starts[0] = 0; - for (size_t i = 1; i < segment_starts.size(); ++i) { - segment_starts[i] = segment_starts[i - 1] + segment_sizes[i - 1]; - } - - assert(segment_starts[n - 1] + segment_sizes[n - 1] == elements_to_reduce); - - T* segment_recv = (T*)temp->tensor_data().data(); - - // Receive from your left neighbor with wrap-around - const size_t recv_from = ((r - 1) + n) % n; - - // Send to your right neighbor with wrap-around - const size_t send_to = (r + 1) % n; - - MPI_Status recv_status; - MPI_Request recv_req; - - // Now start ring. At every step, for every rank, we iterate through - // segments with wraparound and send and recv from our neighbors and reduce - // locally. At the i'th iteration, rank r, sends segment (r-i) and receives - // segment (r-i-1). - for (int i = 0; i < n - 1; i++) { - const size_t send_seg_id = ((r - i) + n) % n; - const size_t recv_seg_id = ((r - i - 1) + n) % n; - - T* segment_send = &(buffer[segment_starts[send_seg_id]]); - - MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id], - MPIType(), recv_from, TAG_TENSOR, - MPI_COMM_WORLD, &recv_req)); - - MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id], - MPIType(), send_to, TAG_TENSOR, - MPI_COMM_WORLD)); - - T* segment_update = &(buffer[segment_starts[recv_seg_id]]); - - // Wait for recv to complete before reduction - MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status)); - - const size_t recv_seg_size = segment_sizes[recv_seg_id]; - AccumulateTensorData(segment_update, segment_recv, - recv_seg_size); - } - - // Now start pipelined ring allgather. At every step, for every rank, we - // iterate through segments with wraparound and send and recv from our - // neighbors. At the i'th iteration, rank r, sends segment (r-i+1) and - // receives segment (r-i). - for (size_t i = 0; i < n - 1; ++i) { - const size_t send_seg_id = ((r - i + 1) + n) % n; - const size_t recv_seg_id = ((r - i) + n) % n; - - // Segment to send - at every iteration we send segment (r-i+1) - T* segment_send = &(buffer[segment_starts[send_seg_id]]); - - // Segment to recv - at every iteration we receive segment (r-i) - T* segment_recv = &(buffer[segment_starts[recv_seg_id]]); - - MPI_REQUIRES_OK(MPI_Sendrecv( - segment_send, segment_sizes[send_seg_id], MPIType(), send_to, - TAG_TENSOR, segment_recv, segment_sizes[recv_seg_id], MPIType(), - recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); - } - - return Status::OK(); -} - -// Perform a ring allgather on a Tensor. Other ranks may allgather with a -// tensor which differs in the first dimension only; all other dimensions must -// be the same. -// -// For more information on the ring allgather, read the documentation for the -// ring allreduce, which includes a ring allgather. 
-template -Status RingAllgather(OpKernelContext* context, const Tensor* input, - const std::vector& sizes, Tensor* output) { - // Acquire MPI size and rank - int n, r; - MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); - MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); - - assert(sizes.size() == n); - assert(input->dim_size(0) == sizes[r]); - - // Compute number of elements in every "row". We can't compute number of - // elements in every chunks, because those chunks are variable length. - size_t elements_per_row = 1; - for (int i = 1; i < input->shape().dims(); i++) { - elements_per_row *= input->dim_size(i); - } - - // Copy data from input tensor to correct place in output tensor. - std::vector segment_starts(n); - segment_starts[0] = 0; - for (int i = 1; i < n; i++) { - segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1]; - } - size_t offset = segment_starts[r]; - - // Copy data to the right offset for this rank. - T* buffer = (T*)output->tensor_data().data(); - CopyTensorData((void*)(buffer + offset), - (void*)input->tensor_data().data(), - elements_per_row * sizes[r] * sizeof(T)); - - // Receive from your left neighbor with wrap-around - const size_t recv_from = ((r - 1) + n) % n; - - // Send to your right neighbor with wrap-around - const size_t send_to = (r + 1) % n; - - // Perform a ring allgather. At every step, for every rank, we iterate - // through segments with wraparound and send and recv from our neighbors. - // At the i'th iteration, rank r, sends segment (r-i) and receives segment - // (r-1-i). - MPI_Status recv_status; - for (size_t i = 0; i < n - 1; ++i) { - const size_t send_seg_id = ((r - i) + n) % n; - const size_t recv_seg_id = ((r - i - 1) + n) % n; - - // Segment to send - at every iteration we send segment (r-i) - size_t offset_send = segment_starts[send_seg_id]; - size_t rows_send = sizes[send_seg_id]; - T* segment_send = &(buffer[offset_send]); - - // Segment to recv - at every iteration we receive segment (r-1-i) - size_t offset_recv = segment_starts[recv_seg_id]; - size_t rows_recv = sizes[recv_seg_id]; - T* segment_recv = &(buffer[offset_recv]); - - MPI_REQUIRES_OK(MPI_Sendrecv( - segment_send, elements_per_row * rows_send, MPIType(), send_to, - TAG_TENSOR, segment_recv, elements_per_row * rows_recv, MPIType(), - recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); - } - - return Status::OK(); -} - -} // namespace mpi_collectives -} // namespace contrib -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI - -#undef TENSORFLOW_CONTRIB_MPI_H_ -#endif // TENSORFLOW_CONTRIB_MPI_H_ diff --git a/tensorflow/contrib/mpi_collectives/mpi_allgather_test.py b/tensorflow/contrib/mpi_collectives/mpi_allgather_test.py deleted file mode 100644 index c23dd33d579..00000000000 --- a/tensorflow/contrib/mpi_collectives/mpi_allgather_test.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import numpy as np -import tensorflow as tf -import tensorflow.contrib.mpi_collectives as mpi -from tensorflow.python.platform import test - - -average_allgather = False - - -class AllgatherTest(test.TestCase): - def checkAllgather(self, num_ranks, all_gathered, local_gathered): - # Ensure that indices match. - all_gat_ind = np.sort(all_gathered.indices) - loc_gat_ind = np.sort(local_gathered.indices) - assert(len(loc_gat_ind) == len(all_gat_ind)) - for i in range(len(loc_gat_ind)): - assert(loc_gat_ind[i] == all_gat_ind[i]) - - # For each index, verify same values. - local_checked = [] - for i in range(len(local_gathered.indices)): - local_checked.append(False) - for i in range(len(all_gathered.indices)): - all_index = all_gathered.indices[i] - # TODO(jthestness): Make this lookup quicker using sorting. - loc_index = -1 - for j in range(len(local_gathered.indices)): - if local_gathered.indices[j] == all_index and not local_checked[j]: - loc_index = j - local_checked[j] = True - break - assert(loc_index >= 0) - correct_output = local_gathered.values[loc_index][0] - if average_allgather: - correct_output = correct_output / float(num_ranks) - assert(all_gathered.values[i][0] == correct_output) - - - def test_mpi_allgather(self): - # Get MPI rank - my_rank = int(os.environ['PMI_RANK']) - num_ranks = int(os.environ['PMI_SIZE']) - - indices_per_rank = 100 - tensor_width = 10 - - # Create IndexedSlices for each rank, some with overlapping indices. - to_gather_indices = [] - to_gather_values = [] - to_gather = [] - for rank_id in range(num_ranks): - indices = [] - values = [] - my_multiple = rank_id + 1 - current_index = my_multiple - for i in range(indices_per_rank): - indices.append(current_index) - ones_tensor = tf.ones([tensor_width]) - values.append(tf.multiply(ones_tensor, - tf.fill(ones_tensor.get_shape(), - float(current_index)))) - current_index += my_multiple - concat_ind = tf.stack(indices) - concat_vals = tf.stack(values) - to_gather_indices.append(concat_ind) - to_gather_values.append(concat_vals) - to_gather.append(tf.IndexedSlices(concat_vals, concat_ind)) - - # Collect the local IndexedSlices (indices and values) to create - # correct IndexedSlices output. - correct_gather_indices = tf.concat(to_gather_indices, 0) - correct_gather_values = tf.concat(to_gather_values, 0) - correct_gather = tf.IndexedSlices(correct_gather_values, - correct_gather_indices) - - all_gather = mpi.allreduce(to_gather[my_rank], average_allgather) - - # NOTE: This assumes that device IDs are numbered the same as ranks. - gpu_options = tf.GPUOptions(visible_device_list=str(my_rank)) - config = tf.ConfigProto(gpu_options=gpu_options) - - # MPI Session to test allgather. - with mpi.Session(config=config) as sess: - sess.run(tf.global_variables_initializer()) - - all_gathered, local_gathered = sess.run([all_gather, correct_gather]) - - # Compare all_gathered with local_gathered. - self.checkAllgather(num_ranks, all_gathered, local_gathered) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py b/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py deleted file mode 100644 index 001f9170bc0..00000000000 --- a/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import numpy as np -import tensorflow as tf -import tensorflow.contrib.mpi_collectives as mpi -from tensorflow.python.platform import test - - -average_allreduce = False -max_wrong_count = -1 - - -class AllreduceTest(test.TestCase): - def dumpFailure(self, my_rank, out_loc_red, my_correct, out_all_red, - our_correct): - # Find reduced/allreduced indices that are wrong and print all the - # values from output, slices, reduced, allreduced, so we can debug - # which is incorrect: - wrong_count = 0 - red_dims = out_loc_red.shape - assert(len(red_dims) == 2) - for i in range(red_dims[0]): - for j in range(red_dims[1]): - suffix = "" - if out_loc_red[i][j] != my_correct[i][j] or \ - out_all_red[i][j] != our_correct[i][j]: - suffix = "WRONG" - wrong_count += 1 - print("{}\t{}\t{}\t{}\t{}\t{}" - .format(my_rank, i, j, out_loc_red[i][j], - out_all_red[i][j], suffix), flush=True) - if max_wrong_count > 0 and wrong_count >= max_wrong_count: - return - - def test_mpi_allreduce(self): - # Get MPI rank - my_rank = int(os.environ['PMI_RANK']) - num_ranks = int(os.environ['PMI_SIZE']) - - stages = 13 - batch_size = 1331 - hidden_size = batch_size - out_size = batch_size - - # Input placeholder (batch_size x hidden) - init to 1s - inputs = tf.placeholder(tf.float32, shape=(batch_size, hidden_size), - name="Input") - - # Large matrices (hidden x out_dim) - init random - weights = [] - for i in range(stages): - initer = tf.constant_initializer(pow(2.0, i + 1.0)) - weights.append(tf.get_variable("weights_{}".format(i), - shape=(hidden_size, out_size), - dtype=tf.float32, - initializer=initer)) - - # Calculate output through dependent allreduces - stage_input = inputs - for i in range(stages): - inter_output = tf.add(stage_input, weights[i], - name="add_red_{}".format(i)) - stage_input = mpi.allreduce(inter_output, - average=average_allreduce) - - all_reduced = stage_input - - # Local reduced output for verification - local_input = inputs - for i in range(stages): - inter_output = tf.add(local_input, weights[i], - name="addin_loc_{}".format(i)) - my_reducer = tf.Variable(initial_value=np.ones((hidden_size, out_size)), - dtype=tf.float32, name="loc_redr_{}".format(i)) - for r in range(num_ranks): - my_reducer = tf.add(my_reducer, inter_output, - name="add_loc_{}_{}".format(i, r)) - if average_allreduce: - local_input = tf.div(my_reducer, num_ranks, - name="div_loc_{}".format(i)) - else: - local_input = my_reducer - - local_reduced = local_input - - # NOTE: This assumes that device IDs are numbered the same as ranks - gpu_options = tf.GPUOptions(visible_device_list=str(my_rank)) - config = tf.ConfigProto(gpu_options=gpu_options) - - # MPI Session to test allreduce - with mpi.Session(config=config) as sess: - 
sess.run(tf.global_variables_initializer()) - - input_feed = np.ones((batch_size, hidden_size), dtype=np.float32) - our_output = input_feed[0][0] - spread_var = 100 - input_feed = input_feed + my_rank * spread_var - my_output = input_feed[0][0] - for i in range(stages): - curr_feed = my_output + pow(2.0, i + 1.0) - my_output = curr_feed * num_ranks + 1 - curr_our_feed = our_output + pow(2.0, i + 1.0) - if i == 0: - sum_ranks = num_ranks * (num_ranks - 1) / 2 - our_output = curr_our_feed * num_ranks + \ - spread_var * sum_ranks - else: - our_output = curr_our_feed * num_ranks - - print("rank {}: My output is {}".format(my_rank, my_output)) - my_correct = np.zeros((batch_size, hidden_size), dtype=np.float32) - my_correct = my_correct + my_output - print("rank {}: Our output is {}".format(my_rank, our_output)) - our_correct = np.zeros((batch_size, hidden_size), dtype=np.float32) - our_correct = our_correct + our_output - - for i in range(1000): - if i % 100 == 0: - print("{}: iter {}".format(my_rank, i), flush=True) - feed_dict = {inputs: input_feed} - out_all_red, out_loc_red \ - = sess.run([all_reduced, local_reduced], - feed_dict=feed_dict) - - if not np.allclose(out_loc_red, my_correct) or \ - not np.allclose(out_all_red, our_correct): - print("Test incorrect on iter {}".format(i), flush=True) - self.dumpFailure(my_rank, out_loc_red, my_correct, out_all_red, - our_correct) - assert(np.allclose(out_loc_red, my_correct) and - np.allclose(out_all_red, our_correct)) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/mpi_collectives/mpi_message.proto b/tensorflow/contrib/mpi_collectives/mpi_message.proto deleted file mode 100644 index afbce981ae1..00000000000 --- a/tensorflow/contrib/mpi_collectives/mpi_message.proto +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto3"; - -package tensorflow.contrib.mpi_collectives; - -import "tensorflow/core/framework/tensor_shape.proto"; -import "tensorflow/core/framework/types.proto"; - -// An MPIRequest is a message sent from a rank greater than zero to the -// coordinator (rank zero), informing the coordinator of an operation that -// the rank wants to do and the tensor that it wants to apply the operation to. -message MPIRequest { - enum RequestType { - ALLREDUCE = 0; - ALLGATHER = 1; - } - - // The request rank is necessary to create a consistent ordering of results, - // for example in the allgather where the order of outputs should be sorted - // by rank. - int32 request_rank = 1; - RequestType request_type = 2; - DataType tensor_type = 3; - string tensor_name = 4; - TensorShapeProto tensor_shape = 5; -}; - -// An MPIResponse is a message sent from the coordinator (rank zero) to a rank -// greater than zero, informing the rank of an operation should be performed -// now. 
If the operation requested would result in an error (for example, due -// to a type or shape mismatch), then the MPIResponse can contain an error and -// an error message instead. Finally, an MPIResponse can be a DONE message (if -// there are no more tensors to reduce on this tick of the background loop) or -// SHUTDOWN if all MPI processes should shut down. -message MPIResponse { - enum ResponseType { - ALLREDUCE = 0; - ALLGATHER = 1; - ERROR = 2; - DONE = 3; - SHUTDOWN = 4; - } - - // Empty if the type is DONE or SHUTDOWN. - ResponseType response_type = 1; - string tensor_name = 2; - - // Empty unless response_type is ERROR. - string error_message = 3; -}; diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/mpi_ops.cc deleted file mode 100644 index 475297ca921..00000000000 --- a/tensorflow/contrib/mpi_collectives/mpi_ops.cc +++ /dev/null @@ -1,1236 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifdef TENSORFLOW_USE_MPI - -#include -#include -#include - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/mutex.h" - -#define EIGEN_USE_THREADS - -#if GOOGLE_CUDA -#include -#include "tensorflow/stream_executor/stream.h" -#endif - -#include "tensorflow/stream_executor/lib/statusor.h" - -#define OMPI_SKIP_MPICXX -#include "third_party/mpi/mpi.h" -#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h" -#include "tensorflow/contrib/mpi_collectives/ring.h" - -/* - * MPI Allreduce and Allgather Ops for TensorFlow. - * - * TensorFlow natively provides inter-device communication through send and - * receive ops and inter-node communication through Distributed TensorFlow, - * based on the same send and receive abstractions. These end up being - * insufficient for synchronous data-parallel training on HPC clusters where - * Infiniband or other high-speed interconnects are available. This module - * implements MPI ops for allgather and allreduce, which do bandwidth-optimal - * gathers and reductions and can take advantage of hardware-optimized - * communication libraries through the MPI implementation. - * - * The primary logic of the allreduce and allgather are in RingAllgather() and - * RingAllreduce(). The background thread which facilitates MPI operations is - * run in BackgroundThreadLoop(). The provided MPI ops are: - * – MPIInit: - * Initialize MPI on a given device (CPU or GPU). - * Should only be run on a single device in every process. - * – MPISize: - * Get the number of MPI processes in the global communicator. - * – MPIRank: - * Get the rank of the current MPI process in the global communicator. - * – MPILocalRank: - * Get the local rank of the current MPI process within its node. 
- * – MPIAllreduce: - * Perform an allreduce on a Tensor, returning the sum - * across all MPI processes in the global communicator. - * – MPIAllgather: - * Perform an allgather on a Tensor, returning the concatenation of - * the tensor on the first dimension across all MPI processes in the - * global communicator. - * - */ - -template -using StatusOr = se::port::StatusOr; - -using CPUDevice = Eigen::ThreadPoolDevice; -using GPUDevice = Eigen::GpuDevice; - -namespace tensorflow { -namespace contrib { -namespace mpi { - -// Make sure template specializations are generated in the ring.cu.cc and the -// ring.cc file, not in this file. -extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, Tensor*, - Tensor*); -extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, - Tensor*, Tensor*); -extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, Tensor*, - Tensor*); -extern template Status RingAllgather(OpKernelContext*, - const Tensor*, - const std::vector&, - Tensor*); -extern template Status RingAllgather( - OpKernelContext*, const Tensor*, const std::vector&, Tensor*); -extern template Status RingAllgather( - OpKernelContext*, const Tensor*, const std::vector&, Tensor*); -extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, Tensor*, - Tensor*); -extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, - Tensor*, Tensor*); -extern template Status RingAllreduce(OpKernelContext*, - const Tensor*, Tensor*, - Tensor*); -extern template Status RingAllgather(OpKernelContext*, - const Tensor*, - const std::vector&, - Tensor*); -extern template Status RingAllgather( - OpKernelContext*, const Tensor*, const std::vector&, Tensor*); -extern template Status RingAllgather( - OpKernelContext*, const Tensor*, const std::vector&, Tensor*); - -namespace { - -// Return true if the templated type is GPUDevice, otherwise false. -template -bool IsGPUDevice(); -template <> -bool IsGPUDevice() { - return true; -}; -template <> -bool IsGPUDevice() { - return false; -}; - -// A callback to call after the MPI communication completes. Since the -// allreduce and allgather ops are asynchronous, this callback is what resumes -// computation after the reduction is completed. -typedef std::function)> CommunicationDoneCallback; - -struct CollectiveOpRecord { - // The rank performing this piece of the op - int rank; - - // The name of the op/tensor to be reduced - std::string name; - - // The op's kernel context - OpKernelContext* context; - - // Data type of the op - DataType dtype; - - // The input tensor - const Tensor* in_t; - - // Allgather: Vector of per-rank first-dimension sizes - std::vector sizes_vec; - - // The temp tensor for intermediate results - Tensor temp_t; - - // The output tensor - Tensor* out_t; - - // Whether to run this op on the gpu - bool on_gpu; - - // The callback to call after the op has completed - CommunicationDoneCallback callback; -}; - -// Table storing Tensors to be reduced, keyed by unique name. -// This table contains everything necessary to do the reduction -typedef std::unordered_map TensorTable; - -// Table for storing Tensor metadata on rank zero. This is used for error -// checking and size calculations, as well as determining when a reduction is -// ready to be done (when all nodes are ready to do it). -typedef std::unordered_map > MessageTable; - -// The global state required for the MPI ops. 
-// -// MPI is a library that stores a lot of global per-program state and often -// requires running on a single thread. As a result, we have to have a single -// background thread responsible for all MPI operations, and communicate with -// that background thread through global state. -struct MPIGlobalState { - // An atomic boolean which is set to true when MPI is initialized. - // This ensures that MPI_Init is never called twice. - std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT; - - // Condition variable to wait for initialization - condition_variable cv; - - // Whether MPI_Init has been completed on the background thread. - bool initialization_done = false; - - // Whether MPI_Init succeeded on the background thread. - Status init_status; - - // A mutex that needs to be used whenever MPI operations touch - // shared structures. - mutex mu; - - // Tensors waiting to be allreduced or allgathered. - TensorTable tensor_table; - - // Queue of MPI requests waiting to be sent to the coordinator node. - std::queue message_queue; - - // Background thread running MPI communication. - std::thread background_thread; - - // Whether the background thread should shutdown. - bool shut_down = false; - - // Only exists on the coordinator node (rank zero). Maintains a count of - // how many nodes are ready to allreduce every tensor (keyed by tensor - // name). - std::unique_ptr message_table; - - // The MPI rank, local rank, and size. - int rank = 0; - int local_rank = 0; - int size = 1; - - // The device that MPI was initialized on. (-1 for no GPU) - int device = -1; - - // The CUDA stream used for data transfers and within-allreduce operations. - // A naive implementation would use the TensorFlow StreamExecutor CUDA - // stream. However, the allreduce and allgather require doing memory copies - // and kernel executions (for accumulation of values on the GPU). However, - // the subsequent operations must wait for those operations to complete, - // otherwise MPI (which uses its own stream internally) will begin the data - // transfers before the CUDA calls are complete. In order to wait for those - // CUDA operations, if we were using the TensorFlow stream, we would have - // to synchronize that stream; however, other TensorFlow threads may be - // submitting more work to that stream, so synchronizing on it can cause - // the allreduce to be delayed, waiting for compute totally unrelated to it - // in other parts of the graph. Overlaying memory transfers and compute - // during backpropagation is crucial for good performance, so we cannot use - // the TensorFlow stream, and must use our own stream. -#if GOOGLE_CUDA - cudaStream_t stream; - std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT; -#endif - - ~MPIGlobalState() { - // Make sure that the destructor of the background thread is safe to - // call. If a thread is still joinable (not detached or complete) its - // destructor cannot be called. - if (background_thread.joinable()) { - shut_down = true; - background_thread.join(); - } - } -}; - -// All the MPI state that must be stored globally per-process. -static MPIGlobalState mpi_global; - -// For clarify in argument lists. -#define RANK_ZERO 0 - -// A tag used for all coordinator messaging. -#define TAG_NOTIFY 1 - -// Store the MPIRequest for a name, and return whether the total count of -// MPIRequests for that tensor is now equal to the MPI size (and thus we are -// ready to reduce the tensor). 
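[Editor's note] The comment above describes the bookkeeping that IncrementTensorCount performs on the coordinator: requests are grouped by tensor name, and a tensor becomes ready once every rank has submitted a request for it. Before the C++ implementation that follows, here is a minimal, MPI-free Python sketch of that idea; the names are illustrative and are not part of the deleted sources.

from collections import defaultdict

def increment_tensor_count(message_table, request, mpi_size):
    """Record one rank's request for a tensor; return True once every rank
    has asked for it and the tensor is therefore ready to reduce."""
    message_table[request["tensor_name"]].append(request)
    return len(message_table[request["tensor_name"]]) == mpi_size

table = defaultdict(list)
for rank in range(3):
    ready = increment_tensor_count(
        table, {"tensor_name": "grad/layer0", "request_rank": rank}, mpi_size=3)
print(ready)  # True: the third and final rank has just announced the tensor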
-bool IncrementTensorCount(std::unique_ptr& message_table, - MPIRequest msg, int mpi_size) { - auto name = msg.tensor_name(); - auto table_iter = message_table->find(name); - if (table_iter == message_table->end()) { - message_table->emplace(name, std::vector({msg})); - table_iter = message_table->find(name); - } else { - table_iter->second.push_back(msg); - } - - int count = table_iter->second.size(); - return count == mpi_size; -} - -// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse -// instructing all ranks to start the reduction to all ranks. The MPIResponse -// also contains error messages in case the submitted MPIRequests were not -// valid (for example, contained mismatched shapes or types). -// -// Constructing the MPIResponse, thus, requires a whole lot of error checking. -MPIResponse ConstructMPIResponse(std::unique_ptr& message_table, - std::string name) { - bool error = false; - auto it = message_table->find(name); - assert(it != message_table->end()); - - std::vector requests = it->second; - assert(requests.size() > 0); - - std::ostringstream error_message_stream; - - // Check that all data types being reduced or gathered are identical - auto data_type = requests[0].tensor_type(); - for (unsigned int i = 1; i < requests.size(); i++) { - auto request_type = requests[i].tensor_type(); - if (data_type != request_type) { - error = true; - error_message_stream << "Mismatched data types: One rank had type " - << DataType_Name(data_type) - << ", but another rank had type " - << DataType_Name(request_type) << "."; - break; - } - } - - // Check that all requested operations are the same - auto message_type = requests[0].request_type(); - for (unsigned int i = 1; i < requests.size(); i++) { - if (error) { - break; - } - - auto request_type = requests[i].request_type(); - if (message_type != request_type) { - error = true; - error_message_stream << "Mismatched MPI operations: One rank did an " - << message_type << ", but another rank did an " - << request_type << "."; - break; - } - } - - // If we are doing an allreduce, check that all tensor shapes - // are identical - if (message_type == MPIRequest::ALLREDUCE) { - TensorShape tensor_shape = requests[0].tensor_shape(); - for (unsigned int i = 1; i < requests.size(); i++) { - if (error) { - break; - } - - TensorShape request_shape = requests[i].tensor_shape(); - if (tensor_shape != request_shape) { - error = true; - error_message_stream << "Mismatched allreduce tensor shapes: " - << "One rank reduced a tensor of shape " - << tensor_shape.DebugString() - << ", but another rank sent a tensor of shape " - << request_shape.DebugString() << "."; - break; - } - } - } - - // If we are doing an allgather, make sure all but the first dimension are - // the same. The first dimension may be different and the output tensor is - // the sum of the first dimension. Collect the sizes by rank. 
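[Editor's note] As the comment above notes, an allgather only requires the non-leading dimensions to agree across ranks, whereas an allreduce requires identical shapes. A small illustrative Python check of the allgather rule, under the same assumptions as the shape validation that follows (helper name is made up, not from the deleted file):

def shapes_compatible_for_allgather(shapes):
    """Every rank must send a tensor of the same rank whose dimensions agree
    everywhere except dimension 0, which may differ; the gathered result's
    first dimension is then the sum of the per-rank first dimensions."""
    first = shapes[0]
    return all(len(s) == len(first) and s[1:] == first[1:] for s in shapes[1:])

print(shapes_compatible_for_allgather([(4, 17, 3), (2, 17, 3)]))  # True
print(shapes_compatible_for_allgather([(4, 17, 3), (4, 16, 3)]))  # False
print(sum(s[0] for s in [(4, 17, 3), (2, 17, 3)]))                # 6 rows gathered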
- if (message_type == MPIRequest::ALLGATHER) { - TensorShape tensor_shape = requests[0].tensor_shape(); - - if (tensor_shape.dims() == 0) { - error = true; - error_message_stream << "Rank zero tried to gather a rank-zero tensor."; - } - - for (unsigned int i = 1; i < requests.size(); i++) { - if (error) { - break; - } - - TensorShape request_shape = requests[i].tensor_shape(); - if (tensor_shape.dims() != request_shape.dims()) { - error = true; - error_message_stream << "Mismatched allgather tensor shapes: " - << "One rank gathered a tensor of rank " - << tensor_shape.dims() - << ", but another rank sent a tensor of rank " - << request_shape.dims() << "."; - break; - } - - for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) { - if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) { - error = true; - error_message_stream - << "Mismatched allgather tensor shapes: " - << "One rank gathered a tensor with dimension " << dim - << " equal to " << tensor_shape.dim_size(dim) - << ", but another rank sent a tensor with dimension " << dim - << " equal to " << request_shape.dim_size(dim) << "."; - break; - } - } - } - } - - MPIResponse response; - response.set_tensor_name(name); - if (error) { - std::string error_message = error_message_stream.str(); - response.set_response_type(MPIResponse::ERROR); - response.set_error_message(error_message); - } else { - auto response_type = MPIResponse::ERROR; - if (message_type == MPIRequest::ALLREDUCE) { - response_type = MPIResponse::ALLREDUCE; - } else { - response_type = MPIResponse::ALLGATHER; - } - response.set_response_type(response_type); - } - - // Clear all queued up requests for this name. They are now taken care of - // by the constructed MPI response. - message_table->erase(it); - - return response; -} - -// Process an MPIResponse by doing a reduction, a gather, or raising an error. -void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) { - OpKernelContext* context; - const Tensor* input_tensor; - std::vector sizes_vec; - Tensor temp_tensor; - Tensor* output_tensor; - CommunicationDoneCallback callback; - bool on_gpu; - { - // Lock on the tensor table. - mutex_lock guard(mpi_global.mu); - - // We should never fail at finding this key in the tensor table. - auto name = response.tensor_name(); - auto iter = tensor_table.find(name); - assert(iter != tensor_table.end()); - - assert(response.response_type() == MPIResponse::ALLREDUCE || - response.response_type() == MPIResponse::ALLGATHER || - response.response_type() == MPIResponse::ERROR); - - CollectiveOpRecord record = iter->second; - context = record.context; - input_tensor = record.in_t; - sizes_vec = record.sizes_vec; - temp_tensor = record.temp_t; - output_tensor = record.out_t; - on_gpu = record.on_gpu; - callback = record.callback; - - // Clear the tensor table of this tensor and its callbacks; the rest of - // this function takes care of it. - tensor_table.erase(iter); - } - - // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't - // link to non-existent symbols. -#if GOOGLE_CUDA -#define GPU_DEVICE_IF_CUDA GPUDevice -#else -#define GPU_DEVICE_IF_CUDA CPUDevice -#endif - - Status status; - auto dtype = input_tensor->dtype(); - if (response.response_type() == MPIResponse::ALLGATHER) { - if (dtype == DT_FLOAT) { - status = on_gpu ? RingAllgather( - context, input_tensor, sizes_vec, output_tensor) - : RingAllgather( - context, input_tensor, sizes_vec, output_tensor); - } else if (dtype == DT_INT32) { - status = on_gpu ? 
RingAllgather( - context, input_tensor, sizes_vec, output_tensor) - : RingAllgather(context, input_tensor, - sizes_vec, output_tensor); - } else if (dtype == DT_INT64) { - status = on_gpu ? RingAllgather( - context, input_tensor, sizes_vec, output_tensor) - : RingAllgather( - context, input_tensor, sizes_vec, output_tensor); - } else { - status = errors::Unknown("Invalid tensor type for MPI allgather."); - } - } else if (response.response_type() == MPIResponse::ALLREDUCE) { - if (dtype == DT_FLOAT) { - status = on_gpu ? RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor) - : RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor); - } else if (dtype == DT_INT32) { - status = on_gpu ? RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor) - : RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor); - } else if (dtype == DT_INT64) { - status = on_gpu ? RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor) - : RingAllreduce( - context, input_tensor, &temp_tensor, output_tensor); - } else { - status = errors::Unknown("Invalid tensor type for MPI allreduce."); - } - } else if (response.response_type() == MPIResponse::ERROR) { - status = errors::FailedPrecondition(response.error_message()); - } - - if (status.ok()) { - callback(StatusOr(*output_tensor)); - } else { - callback(StatusOr(status)); - } -} - -// The MPI background thread loop coordinates all the MPI processes and the -// tensor reductions. The design of the communicator mechanism is limited by a -// few considerations: -// -// 1. Some MPI implementations require all MPI calls to happen from a -// single thread. Since TensorFlow may use several threads for graph -// processing, this means we must have our own dedicated thread for -// dealing with MPI. -// 2. We want to gracefully handle errors, when MPI processes do not -// properly agree upon what should happen (such as mismatched types or -// shapes). To do so requires the MPI processes to know about the shapes -// and types of the relevant tensors on the other processes. -// 3. The MPI reductions and gathers should be able to happen in parallel -// with other ongoing operations. Since MPI uses an internal -// (inaccessible) GPU stream separate from the TF GPUDevice streams, we -// cannot explicitly synchronize memcpys or kernels with it. As a result, -// MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper -// ordering of memcpys and kernels with respect to TF streams. -// 4. NOTE: We cannot guarantee that all the MPI processes reduce their -// tensors in the same order. Thus, there must be a way to ensure the -// reduction memcpys and kernels occur for correct tensors across all -// ranks at the same time. We choose to use a coordinator (rank ID 0) to -// gather and trigger the reduction operations that are ready to execute. -// -// The coordinator currently follows a master-worker paradigm. Rank zero acts -// as the master (the "coordinator"), whereas all other ranks are simply -// workers. Each rank runs its own background thread which progresses in ticks. -// In each tick, the following actions happen: -// -// a) The workers send any available MPIRequests to the coordinator. These -// MPIRequests indicate what the worker would like to do (i.e. which -// tensor they would like to gather or reduce, as well as their shape and -// type). They repeat this for every tensor that they would like to -// operate on after that tensor's collective op has executed ComputeAsync. 
-// -// b) The workers send an empty "DONE" message to the coordinator to -// indicate that there are no more tensors they wish to operate on. -// -// c) The coordinator receives the MPIRequests from the workers, as well -// as from its own TensorFlow ops, and stores them in a request table. The -// coordinator continues to receive MPIRequest messages until it has -// received MPI_SIZE number of empty "DONE" messages. -// -// d) The coordinator finds all tensors that are ready to be reduced, -// gathered, or all operations that result in an error. For each of those, -// it sends an MPIResponse to all the workers. When no more MPIResponses -// are available, it sends a "DONE" response to the workers. If the -// process is being shutdown, it instead sends a "SHUTDOWN" response. -// -// e) The workers listen for MPIResponse messages, processing each one by -// doing the required reduce or gather, until they receive a "DONE" -// response from the coordinator. At that point, the tick ends. -// If instead of "DONE" they receive "SHUTDOWN", they exit their -// background loop. -// TODO: Use the global mpi_global state variable instead of a local one -void BackgroundThreadLoop() { -#if GOOGLE_CUDA - // Set the device, so that this thread uses the same GPU context as the - // calling thread. - // TODO: Ensure that this is operating correctly. The background thread - // needs to be able to control all GPUs that the rank has access to, and - // might be more than 1 GPU. Tensors could be resident in any of the - // GPUs, so the background thread's accumulate and copy kernels might need - // to correctly set the device and it might be necessary for the background - // thread to manage multiple streams. - cudaSetDevice(mpi_global.device); - cudaStreamCreate(&mpi_global.stream); -#endif - - // Initialize MPI. This must happen on the background thread, since not all - // MPI implementations support being called from multiple threads. - auto init_result = MPI_Init(NULL, NULL); - if (init_result != MPI_SUCCESS) { - mpi_global.init_status = - errors::Unknown("Could not initialize MPI; MPI_Init() failed."); - mpi_global.initialization_done = true; - mpi_global.cv.notify_all(); - return; - } else { - mpi_global.init_status = Status::OK(); - } - - // Get MPI rank to determine if we are rank zero. - int rank; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); - bool is_coordinator = rank == 0; - - // Get MPI size to determine how many tensors to wait for before reducing. - int size; - MPI_Comm_size(MPI_COMM_WORLD, &size); - - // Determine local rank by querying the local communicator. - MPI_Comm local_comm; - MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL, - &local_comm); - int local_rank; - MPI_Comm_rank(local_comm, &local_rank); - - mpi_global.rank = rank; - mpi_global.local_rank = local_rank; - mpi_global.size = size; - mpi_global.initialization_done = true; - - // Notify calling thread that initialization is complete - mpi_global.cv.notify_all(); - - // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD! - // Initialize the tensor count table. No tensors are available yet. - if (is_coordinator) { - mpi_global.message_table = - std::unique_ptr(new MessageTable()); - } - - // The coordinator sends a SHUTDOWN message to trigger shutdown. - bool should_shut_down = false; - do { - // TODO: Eliminate the need for thread sleep by making all activity - // depend on other activity (e.g. condition or MPI waits). 
- std::this_thread::sleep_for(std::chrono::milliseconds(1)); - - // Copy the data structures from global state under this lock. - // However, don't keep the lock for the rest of the loop, so that - // enqueued stream callbacks can continue. - std::queue message_queue; - { - mutex_lock guard(mpi_global.mu); - while (!mpi_global.message_queue.empty()) { - MPIRequest message = mpi_global.message_queue.front(); - mpi_global.message_queue.pop(); - message_queue.push(message); - } - } - - // Collect all tensors that are ready to be reduced. Record them in the - // tensor count table (rank zero) or send them to rank zero to be - // recorded (everyone else). - std::vector ready_to_reduce; - while (!message_queue.empty()) { - // Pop the first available message message - MPIRequest message = message_queue.front(); - message_queue.pop(); - - if (is_coordinator) { - bool reduce = - IncrementTensorCount(mpi_global.message_table, message, size); - if (reduce) { - ready_to_reduce.push_back(message.tensor_name()); - } - } else { - std::string encoded_message; - message.SerializeToString(&encoded_message); - MPI_Send(encoded_message.c_str(), encoded_message.length() + 1, - MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD); - } - } - - // Rank zero has put all its own tensors in the tensor count table. - // Now, it should count all the tensors that are coming from other - // ranks at this tick. It should keep getting tensors until it gets a - // DONE message from all the other ranks. - if (is_coordinator) { - // Count of DONE messages. Keep receiving messages until the number - // of messages is equal to the number of processes. Initialize to - // one since the coordinator is effectively done. - int completed_ranks = 1; - while (completed_ranks != size) { - MPI_Status status; - MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status); - - // Find number of characters in message (including zero byte). - int source_rank = status.MPI_SOURCE; - int msg_length; - MPI_Get_count(&status, MPI_BYTE, &msg_length); - - // If the length is zero, this is a DONE message. - if (msg_length == 0) { - completed_ranks++; - MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD, - &status); - continue; - } - - // Get tensor name from MPI into an std::string. - char* buffer = new char[msg_length]; - MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY, - MPI_COMM_WORLD, &status); - std::string received_data(buffer); - delete[] buffer; - - MPIRequest received_message; - received_message.ParseFromString(received_data); - auto received_name = received_message.tensor_name(); - - bool reduce = IncrementTensorCount(mpi_global.message_table, - received_message, size); - if (reduce) { - ready_to_reduce.push_back(received_name); - } - } - - // At this point, rank zero should have a fully updated tensor - // count table and should know all the tensors that need to be - // reduced or gathered, and everyone else should have sent all - // their information to rank zero. We can now do reductions and - // gathers; rank zero will choose which ones and in what order, - // and will notify the other ranks before doing each reduction. 
- for (int i = 0; i < ready_to_reduce.size(); i++) { - // Notify all nodes which tensor we'd like to reduce now - auto name = ready_to_reduce[i]; - MPIResponse response = - ConstructMPIResponse(mpi_global.message_table, name); - - std::string encoded_response; - response.SerializeToString(&encoded_response); - for (int r = 1; r < size; r++) { - MPI_Send(encoded_response.c_str(), encoded_response.length() + 1, - MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD); - } - - // Perform the reduction. All nodes should end up performing - // the same reduction. - PerformCollectiveOp(mpi_global.tensor_table, response); - } - - // Notify all nodes that we are done with the reductions for this - // tick. - MPIResponse done_response; - should_shut_down = mpi_global.shut_down; - done_response.set_response_type( - mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE); - std::string encoded_response; - done_response.SerializeToString(&encoded_response); - for (int r = 1; r < size; r++) { - MPI_Send(encoded_response.c_str(), encoded_response.length() + 1, - MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD); - } - } else { - // Notify the coordinator that this node is done sending messages. - // A DONE message is encoded as a zero-length message. - MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD); - - // Receive names for tensors to reduce from rank zero. Once we - // receive a empty DONE message, stop waiting for more names. - while (true) { - MPI_Status status; - MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status); - - // Find number of characters in message (including zero byte). - int msg_length; - MPI_Get_count(&status, MPI_BYTE, &msg_length); - - // Get tensor name from MPI into an std::string. - char* buffer = new char[msg_length]; - MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD, - &status); - std::string received_message(buffer); - delete[] buffer; - - MPIResponse response; - response.ParseFromString(received_message); - if (response.response_type() == MPIResponse::DONE) { - // No more messages this tick - break; - } else if (response.response_type() == MPIResponse::SHUTDOWN) { - // No more messages this tick, and the background thread - // should shut down - should_shut_down = true; - break; - } else { - // Process the current message - PerformCollectiveOp(mpi_global.tensor_table, response); - } - } - } - } while (!should_shut_down); - - MPI_Finalize(); -} - -// Initialize MPI and start the MPI background thread. Ensure that this is -// only done once no matter how many times this function is called. -Status InitializeMPIOnce(bool gpu) { - // Ensure MPI is only initialized once. - if (mpi_global.initialized_flag.test_and_set()) return mpi_global.init_status; - - mpi_global.device = -1; -#if GOOGLE_CUDA - if (gpu) { - cudaGetDevice(&mpi_global.device); - } -#endif - - // Start the MPI background thread, which assumes MPI is initialized - // TODO: Change this to a Tensorflow thread - mpi_global.background_thread = std::thread(BackgroundThreadLoop); - - // Wait to ensure that the background thread has finished initializing MPI - mutex_lock guard(mpi_global.mu); - mpi_global.cv.wait(guard); - if (!mpi_global.initialization_done) { - mpi_global.init_status = - errors::Unknown("Failed to wait for MPI initialization."); - } - - return mpi_global.init_status; -} - -// Check that MPI is initialized. 
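[Editor's note] The background-thread loop above realizes the coordinator protocol laid out in the earlier comment block: workers send requests plus a DONE marker, and rank zero answers with a response for every tensor that all ranks have announced, followed by DONE or SHUTDOWN. The following single-process Python sketch condenses one tick of that exchange with plain lists in place of MPI messages; it is an illustration of the control flow only, not a reimplementation of the deleted op.

from collections import defaultdict

def coordinator_tick(per_rank_requests, mpi_size, shutting_down=False):
    """One tick at rank zero: count the requests received from every rank and
    answer with a response per tensor that all ranks announced, then DONE or
    SHUTDOWN, mirroring steps (c) and (d) of the protocol comment above."""
    counts = defaultdict(int)
    for requests in per_rank_requests:
        for name in requests:
            counts[name] += 1
    responses = [("ALLREDUCE", name) for name, c in counts.items() if c == mpi_size]
    responses.append(("SHUTDOWN" if shutting_down else "DONE", None))
    return responses

# Only "a" was announced by all three ranks this tick, so only "a" is reduced.
print(coordinator_tick([["a", "b"], ["a", "b"], ["a"]], mpi_size=3))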
-Status IsMPIInitialized() { - if (!mpi_global.initialization_done) { - return errors::FailedPrecondition( - "MPI has not been initialized; use tf.contrib.mpi.Session."); - } - return Status::OK(); -} - -// This function (called from the callback set up in MPIAll*Op::ComputeAsync) -// only adds the op's record into the local op queue (to track the op's -// progress), and sends a message to the coordinator indicating that this rank -// is ready to begin. The MPI background thread will handle the MPI message. -void EnqueueTensorCollective(CollectiveOpRecord record, - MPIRequest::RequestType rtype) { - const Tensor* input_tensor = record.in_t; - MPIRequest message; - message.set_request_rank(record.rank); - message.set_tensor_name(record.name); - message.set_tensor_type(record.dtype); - message.set_request_type(rtype); - input_tensor->shape().AsProto(message.mutable_tensor_shape()); - - mutex_lock guard(mpi_global.mu); - mpi_global.tensor_table.emplace(record.name, record); - mpi_global.message_queue.push(message); -} - -} // namespace - -#if GOOGLE_CUDA -cudaStream_t CudaStreamForMPI() { return mpi_global.stream; } -#endif - -// Op to initialize MPI in the current process. The settings used in the -// configuration are the same that must be used for all future MPI ops. -template -class MPIInitOp : public OpKernel { - public: - explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - bool on_gpu = IsGPUDevice(); - OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu)); - } -}; - -REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU), - MPIInitOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU), - MPIInitOp); -#endif - -REGISTER_OP("MPIInit").Doc(R"doc( -Initialize MPI for the current process. - -If this is run on a GPU, then that GPU must be used for all future MPI -operations. If it is run on CPU, then all future MPI operations must also -run on CPU. -)doc"); - -// Op to get the current MPI Size. -template -class MPISizeOp : public OpKernel { - public: - explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - OP_REQUIRES_OK(context, IsMPIInitialized()); - - // Write integer to output tensor - Tensor* output; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({}), &output)); - - auto flat = output->flat(); - flat(0) = mpi_global.size; - } -}; - -REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU), - MPISizeOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"), - MPISizeOp); -#endif - -REGISTER_OP("MPISize") - .Output("size: int32") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); - }) - .Doc(R"doc( -Returns the number of running MPI processes. - -More precisely, returns the number of MPI processes in the group associated -with the MPI_COMM_WORLD communicator. - -size: Size of the MPI group. -)doc"); - -// Op to get the current MPI Rank. 
-template -class MPIRankOp : public OpKernel { - public: - explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - OP_REQUIRES_OK(context, IsMPIInitialized()); - - // Write integer to output tensor - Tensor* output; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({}), &output)); - - auto flat = output->flat(); - flat(0) = mpi_global.rank; - } -}; - -REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU), - MPIRankOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"), - MPIRankOp); -#endif - -REGISTER_OP("MPIRank") - .Output("rank: int32") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); - }) - .Doc(R"doc( -Returns the index of the current process in the MPI group. - -More precisely, returns the rank of the calling process in the MPI_COMM_WORLD -communicator. - -rank: Rank of the calling process. -)doc"); - -// Op to get the current local MPI Rank. -template -class MPILocalRankOp : public OpKernel { - public: - explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - OP_REQUIRES_OK(context, IsMPIInitialized()); - - // Write integer to output tensor - Tensor* output; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({}), &output)); - - auto flat = output->flat(); - flat(0) = mpi_global.local_rank; - } -}; - -REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU), - MPILocalRankOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER( - Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"), - MPILocalRankOp); -#endif - -REGISTER_OP("MPILocalRank") - .Output("rank: int32") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); - }) - .Doc(R"doc( -Returns the index of the current process in the node it is on. - -More precisely, returns the rank of the calling process in communicator that -only spans the MPI processes running on that node. - -rank: Rank of the calling process on the node it is on. -)doc"); - -template -class MPIAllreduceOp : public AsyncOpKernel { - public: - explicit MPIAllreduceOp(OpKernelConstruction* context) - : AsyncOpKernel(context) {} - - // Although this op is handled asynchronously, the ComputeAsync call is - // very inexpensive. It only sets up a CollectiveOpRecord and places it - // in the table for the background thread to handle. Thus, we do not need - // a TF pool thread to perform the op. 
- bool IsExpensive() override { return false; } - - void ComputeAsync(OpKernelContext* context, DoneCallback done) override { - OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done); - const Tensor* input_tensor = &context->input(0); - Tensor* output_tensor; - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_output(0, input_tensor->shape(), &output_tensor), - done); - - // Record allocated on stack so op can fail without memory leak - CollectiveOpRecord record; - record.name = name(); - record.context = context; - record.in_t = input_tensor; - record.out_t = output_tensor; - record.on_gpu = IsGPUDevice(); - record.dtype = input_tensor->dtype(); - - const size_t temp_size = - (input_tensor->NumElements() + mpi_global.size - 1) / mpi_global.size; - TensorShape temp_shape; - temp_shape.AddDim(temp_size); - OP_REQUIRES_OK_ASYNC(context, - context->allocate_temp(input_tensor->dtype(), - temp_shape, &record.temp_t), - done); - - auto allreduce_done_callback = [done, context](StatusOr status) { - context->SetStatus(status.status()); - done(); - }; - record.callback = allreduce_done_callback; - - auto allreduce_launch_callback = [record] { - EnqueueTensorCollective(record, MPIRequest::ALLREDUCE); - }; - - // If we are on a CPU, our device context will be null and we can't - // get a stream to enqueue this on. On a CPU this op is called when the - // data is already available, so we can just immediately do the - // allreduce; we don't have to wait for the data to get populated. -#if GOOGLE_CUDA - auto device_context = context->op_device_context(); - if (device_context == nullptr) { - allreduce_launch_callback(); - } else { - auto stream = device_context->stream(); - stream->ThenDoHostCallback(allreduce_launch_callback); - } -#else - allreduce_launch_callback(); -#endif - } -}; - -REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU), - MPIAllreduceOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU), - MPIAllreduceOp); -#endif - -REGISTER_OP("MPIAllreduce") - .Attr("T: {int32, int64, float32}") - .Input("tensor: T") - .Output("sum: T") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->input(0)); - return Status::OK(); - }) - .Doc(R"doc( -Perform an MPI Allreduce on a tensor. All other processes that do a reduction -on a tensor with the same name must have the same dimension for that tensor. -Tensors are reduced with other tensors that have the same node name for the -allreduce. - -Arguments - tensor: A tensor to reduce. - -Output - sum: A tensor with the same shape as `tensor`, summed across all - MPI processes. -)doc"); - -template -class MPIAllgatherOp : public AsyncOpKernel { - public: - explicit MPIAllgatherOp(OpKernelConstruction* context) - : AsyncOpKernel(context) {} - - // Although this op is handled asynchronously, the ComputeAsync call is - // very inexpensive. It only sets up a CollectiveOpRecord and places it - // in the table for the background thread to handle. Thus, we do not need - // a TF pool thread to perform the op. 
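[Editor's note] Both MPIAllreduceOp above and the MPIAllgatherOp that follows are AsyncOpKernels whose ComputeAsync only files a CollectiveOpRecord and returns; the background thread later performs the collective and fires the stored callback to resume the graph. A generic Python sketch of that hand-off pattern, using only the standard library and illustrative names (the summation stands in for the real ring allreduce):

import queue
import threading

work_queue = queue.Queue()

def background_loop():
    """Drain enqueued records and fire each record's completion callback,
    the way the MPI background thread resumes a pending async op."""
    while True:
        record = work_queue.get()
        if record is None:        # shutdown marker
            break
        values, on_done = record
        on_done(sum(values))      # stand-in for the actual ring allreduce

worker = threading.Thread(target=background_loop)
worker.start()

finished = threading.Event()

def on_done(result):
    print("reduced:", result)
    finished.set()

work_queue.put(([1, 2, 3], on_done))   # what ComputeAsync would enqueue
finished.wait()
work_queue.put(None)
worker.join()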
- bool IsExpensive() override { return false; } - - void ComputeAsync(OpKernelContext* context, DoneCallback done) override { - OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done); - const Tensor* input_tensor = &context->input(0); - const Tensor* sizing_tensor = &context->input(1); - - // Record allocated on stack so op can fail without memory leak - CollectiveOpRecord record; - record.name = name(); - record.context = context; - record.in_t = input_tensor; - record.on_gpu = IsGPUDevice(); - - // Construct the output size from the sizing tensor - size_t output_first_dim = 0; - if (sizing_tensor->shape().dims() == 0) { - // 0-dim sizing_tensor implies that the op is just gathering - // a single element from each rank - output_first_dim = mpi_global.size; - for (int i = 0; i < mpi_global.size; i++) { - record.sizes_vec.push_back(1); - } - } else { - // Collect the total output tensor sizing from the sizing tensor - // NOTE: The sizing tensor is forced to be placed on the CPU by - // declaring the input as HostMemory, so it is valid to read it here. - const int64* sizing_array = - (const int64*)sizing_tensor->tensor_data().data(); - for (int i = 0; i < mpi_global.size; i++) { - record.sizes_vec.push_back(sizing_array[i]); - output_first_dim += sizing_array[i]; - } - } - - TensorShape output_shape; - output_shape.AddDim(output_first_dim); - for (int i = 1; i < input_tensor->shape().dims(); i++) { - output_shape.AddDim(input_tensor->shape().dim_size(i)); - } - - Tensor* output_tensor; - OP_REQUIRES_OK_ASYNC( - context, context->allocate_output(0, output_shape, &output_tensor), - done); - - record.out_t = output_tensor; - record.dtype = input_tensor->dtype(); - - auto allgather_done_callback = [done, context](StatusOr status) { - context->SetStatus(status.status()); - done(); - }; - record.callback = allgather_done_callback; - - auto allgather_launch_callback = [record] { - EnqueueTensorCollective(record, MPIRequest::ALLGATHER); - }; - - // If we are on a CPU, our device context will be null and we can't - // get a stream to enqueue this on. On a CPU this op is called when the - // data is already available, so we can just immediately do the - // allgather; we don't have to wait for the data to get populated. -#if GOOGLE_CUDA - auto device_context = context->op_device_context(); - if (device_context == nullptr) { - allgather_launch_callback(); - } else { - auto stream = device_context->stream(); - stream->ThenDoHostCallback(allgather_launch_callback); - } -#else - allgather_launch_callback(); -#endif - } -}; - -REGISTER_OP("MPIAllgather") - .Attr("T: {int32, int64, float32}") - .Attr("S: {int64}") - .Input("tensor: T") - .Input("sizes: S") - .Output("gathered: T") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle output; - TF_RETURN_IF_ERROR( - c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output)); - c->set_output(0, output); - return Status::OK(); - }) - .Doc(R"doc( -Perform an MPI Allgather on a tensor. All other processes that do a gather on a -tensor with the same name must have the same rank for that tensor, and have the -same dimension on all but the first dimension. - -Arguments - tensor: A tensor to gather. - sizes: A tensor containing the first-dimension sizes of tensors to be - gathered from other ranks - -Output - gathered: A tensor with the same shape as `tensor` except for the first - dimension, which is the sum of dimensions in `sizes`. 
-)doc"); - -REGISTER_KERNEL_BUILDER( - Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"), - MPIAllgatherOp); -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER( - Name("MPIAllgather").Device(DEVICE_GPU).HostMemory("sizes"), - MPIAllgatherOp); -#endif - -} // namespace mpi -} // namespace contrib -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.py b/tensorflow/contrib/mpi_collectives/mpi_ops.py deleted file mode 100644 index bd7096d9cee..00000000000 --- a/tensorflow/contrib/mpi_collectives/mpi_ops.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= -"""Inter-process communication using MPI.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -from tensorflow.python.framework import errors -from tensorflow.python.framework import load_library -from tensorflow.python.framework import ops -from tensorflow.python.platform import resource_loader -from tensorflow.python.platform import tf_logging as logging - - -def _load_library(name, op_list=None): - """Loads a .so file containing the specified operators. - - Args: - name: The name of the .so file to load. - op_list: A list of names of operators that the library should have. If None - then the .so file's contents will not be verified. - - Raises: - NameError if one of the required ops is missing. - """ - try: - filename = resource_loader.get_path_to_datafile(name) - library = load_library.load_op_library(filename) - for expected_op in (op_list or []): - for lib_op in library.OP_LIST.op: - if lib_op.name == expected_op: - break - else: - raise NameError('Could not find operator %s in dynamic library %s' % - (expected_op, name)) - return library - except errors.NotFoundError: - logging.warning('%s file could not be loaded.', name) - - -MPI_LIB = _load_library( - 'mpi_collectives.so', - ['MPISize', 'MPIRank', 'MPILocalRank', 'MPIAllgather', 'MPIAllreduce']) - - -def size(name=None): - """An op which returns the number of MPI processes. - - This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the - size of the global communicator. - - Returns: - An integer scalar containing the number of MPI processes. - """ - return MPI_LIB.mpi_size(name=name) - - -ops.NotDifferentiable('MPISize') - - -def rank(name=None): - """An op which returns the MPI rank of the calling process. - - This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the - rank of the current process in the global communicator. - - Returns: - An integer scalar with the MPI rank of the calling process. - """ - return MPI_LIB.mpi_rank(name=name) - - -ops.NotDifferentiable('MPIRank') - - -def init(name=None): - """An op which initializes MPI on the device on which it is run. 
- - All future MPI ops must be run on the same device that the `init` op was run - on. - """ - return MPI_LIB.mpi_init(name=name) - - -ops.NotDifferentiable('MPIInit') - - -def local_rank(name=None): - """An op which returns the local MPI rank of the calling process, within the - node that it is running on. For example, if there are seven processes running - on a node, their local ranks will be zero through six, inclusive. - - This is equivalent to running `MPI_Comm_rank(...)` on a new communicator - which only includes processes on the same node. - - Returns: - An integer scalar with the local MPI rank of the calling process. - """ - return MPI_LIB.mpi_local_rank(name=name) - - -ops.NotDifferentiable('MPILocalRank') - - -def _allreduce(tensor, name=None): - """An op which sums an input tensor over all the MPI processes. - - The reduction operation is keyed by the name of the op. The tensor type and - shape must be the same on all MPI processes for a given name. The reduction - will not start until all processes are ready to send and receive the tensor. - - Returns: - A tensor of the same shape and type as `tensor`, summed across all - processes. - """ - return MPI_LIB.mpi_allreduce(tensor, name=name) - - -ops.NotDifferentiable('MPIAllreduce') - - -def allgather(tensor, name=None): - """An op which concatenates the input tensor with the same input tensor on - all other MPI processes. - - The concatenation is done on the first dimension, so the input tensors on the - different processes must have the same rank and shape, except for the first - dimension, which is allowed to be different. - - Returns: - A tensor of the same type as `tensor`, concatenated on dimension zero - across all processes. The shape is identical to the input shape, except for - the first dimension, which may be greater and is the sum of all first - dimensions of the tensors in different MPI processes. - """ - # Specify that first allgather is to collect the tensor gather sizes, - # indicated by passing in a scalar (0-D tensor) of value 0 - sizes_flag = tf.constant(0, dtype=tf.int64, name='size_flag_const') - my_size = tf.slice( - tf.shape(tensor, out_type=tf.int64), [0], [1], name='size_slice') - if name is None: - name = 'allgather' - sizing_name = '{}_sizing'.format(name) - sizes = MPI_LIB.mpi_allgather(my_size, sizes_flag, name=sizing_name) - return MPI_LIB.mpi_allgather(tensor, sizes, name=name) - - -ops.NotDifferentiable('MPIAllgather') diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops_test.py b/tensorflow/contrib/mpi_collectives/mpi_ops_test.py deleted file mode 100644 index 48e5c0a0c70..00000000000 --- a/tensorflow/contrib/mpi_collectives/mpi_ops_test.py +++ /dev/null @@ -1,296 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================= - -"""Tests for tensorflow.contrib.mpi_collectives.mpi_ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os.path -import itertools - -import tensorflow as tf - -import tensorflow.contrib.mpi_collectives as mpi - - -def mpi_env_rank_and_size(): - """Get MPI rank and size from environment variables and return them as a - tuple of integers. - - Most MPI implementations have an `mpirun` or `mpiexec` command that will - run an MPI executable and set up all communication necessary between the - different processors. As part of that set up, they will set environment - variables that contain the rank and size of the MPI_COMM_WORLD - communicator. We can read those environment variables from Python in order - to ensure that `mpi.rank()` and `mpi.size()` return the expected values. - - Since MPI is just a standard, not an implementation, implementations - typically choose their own environment variable names. This function tries - to support several different implementation, but really it only needs to - support whatever implementation we want to use for the TensorFlow test - suite. - - If this is not running under MPI, then defaults of rank zero and size one - are returned. (This is appropriate because when you call MPI_Init in an - application not started with mpirun, it will create a new independent - communicator with only one process in it.) - """ - rank_env = "PMI_RANK OMPI_COMM_WORLD_RANK".split() - size_env = "PMI_SIZE OMPI_COMM_WORLD_SIZE".split() - - for rank_var, size_var in zip(rank_env, size_env): - rank = os.environ.get(rank_var) - size = os.environ.get(size_var) - if rank is not None and size is not None: - return int(rank), int(size) - - # Default to rank zero and size one if there are no environment variables - return 0, 1 - - -class MPITests(tf.test.TestCase): - """ - Tests for MPI ops in tensorflow.contrib.mpi_collectives. - """ - - def test_mpi_rank(self): - """Test that the rank returned by mpi.rank() is correct.""" - true_rank, _ = mpi_env_rank_and_size() - with self.test_session() as session: - rank = session.run(mpi.rank()) - self.assertEqual(true_rank, rank) - - def test_mpi_size(self): - """Test that the size returned by mpi.size() is correct.""" - _, true_size = mpi_env_rank_and_size() - with self.test_session() as session: - size = session.run(mpi.size()) - self.assertEqual(true_size, size) - - def test_mpi_allreduce_cpu(self): - """Test on CPU that the allreduce correctly sums 1D, 2D, 3D tensors.""" - with self.test_session() as session: - size = session.run(mpi.size()) - - dtypes = [tf.int32, tf.float32] - dims = [1, 2, 3] - for dtype, dim in itertools.product(dtypes, dims): - tf.set_random_seed(1234) - tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype) - summed = mpi.allreduce(tensor, average=False) - multiplied = tensor * size - max_difference = tf.reduce_max(tf.abs(summed - multiplied)) - - # Threshold for floating point equality depends on number of - # ranks, since we're comparing against precise multiplication. - if size <= 3: - threshold = 0 - elif size < 10: - threshold = 1e-4 - elif size < 15: - threshold = 5e-4 - else: - break - - diff = session.run(max_difference) - self.assertTrue(diff <= threshold, - "mpi.allreduce produces incorrect results") - - def test_mpi_allreduce_gpu(self): - """Test that the allreduce works on GPUs. 
- - This test will crash badly if used with an MPI implementation that does - not support GPU memory transfers directly, as it will call MPI_Send on - a GPU data pointer.""" - # Only do this test if there are GPUs available. - if not tf.test.is_gpu_available(cuda_only=True): - return - - no_gpus = tf.GPUOptions(visible_device_list="") - cpu_config = tf.ConfigProto(gpu_options=no_gpus) - with self.test_session(config=cpu_config) as session: - local_rank = session.run(mpi.local_rank()) - - one_gpu = tf.GPUOptions(visible_device_list=str(local_rank)) - gpu_config = tf.ConfigProto(gpu_options=one_gpu) - with self.test_session(config=gpu_config) as session: - size = session.run(mpi.size()) - - dtype = tf.float32 - dim = 3 - with tf.device("/gpu:0"): - tf.set_random_seed(1234) - tensor = tf.random_uniform([17] * dim, -100, 100, dtype=dtype) - summed = mpi.allreduce(tensor, average=False) - multiplied = tensor * size - max_difference = tf.reduce_max(tf.abs(summed - multiplied)) - - # Threshold for floating point equality depends on number of - # ranks, since we're comparing against precise multiplication. - if size <= 3: - threshold = 0 - elif size < 10: - threshold = 1e-4 - elif size < 15: - threshold = 5e-4 - else: - return - - diff = session.run(max_difference) - self.assertTrue(diff <= threshold, - "mpi.allreduce on GPU produces incorrect results") - - def test_mpi_allreduce_error(self): - """Test that the allreduce raises an error if different ranks try to - send tensors of different rank or dimension.""" - with self.test_session() as session: - rank = session.run(mpi.rank()) - size = session.run(mpi.size()) - - # This test does not apply if there is only one worker. - if size == 1: - return - - # Same rank, different dimension - tf.set_random_seed(1234) - dims = [17 + rank] * 3 - tensor = tf.random_uniform(dims, -1.0, 1.0) - with self.assertRaises(tf.errors.FailedPreconditionError): - session.run(mpi.allreduce(tensor)) - - # Same number of elements, different rank - tf.set_random_seed(1234) - if rank == 0: - dims = [17, 23 * 57] - else: - dims = [17, 23, 57] - tensor = tf.random_uniform(dims, -1.0, 1.0) - with self.assertRaises(tf.errors.FailedPreconditionError): - session.run(mpi.allreduce(tensor)) - - def test_mpi_allreduce_type_error(self): - """Test that the allreduce raises an error if different ranks try to - send tensors of different type.""" - with self.test_session() as session: - rank = session.run(mpi.rank()) - size = session.run(mpi.size()) - - # This test does not apply if there is only one worker. 
- if size == 1: - return - - # Same rank, different dimension - dims = [17] * 3 - tensor = tf.ones(dims, dtype=tf.int32 if rank % 2 == 0 else tf.float32) - with self.assertRaises(tf.errors.FailedPreconditionError): - session.run(mpi.allreduce(tensor)) - - def test_mpi_allgather(self): - """Test that the allgather correctly gathers 1D, 2D, 3D tensors.""" - with self.test_session() as session: - size = session.run(mpi.size()) - rank = session.run(mpi.rank()) - - dtypes = tf.int32, tf.float32 - dims = 1, 2, 3 - for dtype, dim in itertools.product(dtypes, dims): - tensor = tf.ones([17] * dim, dtype=dtype) * rank - gathered = mpi.allgather(tensor) - - gathered_tensor = session.run(gathered) - self.assertEqual(list(gathered_tensor.shape), - [17 * size] + [17] * (dim - 1)) - - for i in range(size): - rank_tensor = tf.slice(gathered_tensor, [i * 17] + [0] * (dim - 1), - [17] + [-1] * (dim - 1)) - self.assertEqual(list(rank_tensor.shape), [17] * dim) - self.assertTrue(session.run(tf.reduce_all(tf.equal(rank_tensor, i))), - "mpi.allgather produces incorrect gathered tensor") - - def test_mpi_allgather_variable_size(self): - """Test that the allgather correctly gathers 1D, 2D, 3D tensors, - even if those tensors have different sizes along the first dim.""" - with self.test_session() as session: - size = session.run(mpi.size()) - rank = session.run(mpi.rank()) - - dtypes = tf.int32, tf.float32 - dims = 1, 2, 3 - for dtype, dim in itertools.product(dtypes, dims): - # Support tests up to MPI Size of 35 - if size > 35: - break - - tensor_sizes = [17, 32, 81, 12, 15, 23, 22] * 5 - tensor_sizes = tensor_sizes[:size] - - tensor = tf.ones([tensor_sizes[rank]] + [17] * (dim - 1), - dtype=dtype) * rank - gathered = mpi.allgather(tensor) - - gathered_tensor = session.run(gathered) - expected_size = sum(tensor_sizes) - self.assertEqual(list(gathered_tensor.shape), - [expected_size] + [17] * (dim - 1)) - - for i in range(size): - rank_size = [tensor_sizes[i]] + [17] * (dim - 1) - rank_tensor = tf.slice(gathered, - [sum(tensor_sizes[:i])] + [0] * (dim - 1), - rank_size) - self.assertEqual(list(rank_tensor.shape), rank_size) - self.assertTrue(session.run(tf.reduce_all(tf.equal(rank_tensor, i))), - "mpi.allgather produces incorrect gathered tensor") - - def test_mpi_allgather_error(self): - """Test that the allgather returns an error if any dimension besides - the first is different among the tensors being gathered.""" - with self.test_session() as session: - rank = session.run(mpi.rank()) - size = session.run(mpi.size()) - - # This test does not apply if there is only one worker. - if size == 1: - return - - tensor_size = [17] * 3 - tensor_size[1] = 10 * (rank + 1) - tensor = tf.ones(tensor_size, dtype=tf.float32) * rank - with self.assertRaises(tf.errors.FailedPreconditionError): - session.run(mpi.allgather(tensor)) - - def test_mpi_allgather_type_error(self): - """Test that the allgather returns an error if the types being gathered - differ among the processes""" - with self.test_session() as session: - rank = session.run(mpi.rank()) - size = session.run(mpi.size()) - - # This test does not apply if there is only one worker. 
- if size == 1: - return - - tensor_size = [17] * 3 - dtype = tf.int32 if rank % 2 == 0 else tf.float32 - tensor = tf.ones(tensor_size, dtype=dtype) * rank - with self.assertRaises(tf.errors.FailedPreconditionError): - session.run(mpi.allgather(tensor)) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc deleted file mode 100644 index 18e6bb61cff..00000000000 --- a/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifdef TENSORFLOW_USE_MPI - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { -namespace contrib { -namespace mpi_collectives { - -REGISTER_OP("MPIInit").Doc(R"doc( -Initialize MPI for the current process. - -If this is run on a GPU, then that GPU must be used for all future MPI -operations. If it is run on CPU, then all future MPI operations must also -run on CPU. -)doc"); - -REGISTER_OP("MPISize") - .Output("size: int32") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); - }) - .Doc(R"doc( -Returns the number of running MPI processes. - -More precisely, returns the number of MPI processes in the group associated -with the MPI_COMM_WORLD communicator. - -size: Size of the MPI group. -)doc"); - -REGISTER_OP("MPIRank") - .Output("rank: int32") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); - }) - .Doc(R"doc( -Returns the index of the current process in the MPI group. - -More precisely, returns the rank of the calling process in the MPI_COMM_WORLD -communicator. - -rank: Rank of the calling process. -)doc"); - -REGISTER_OP("MPILocalRank") - .Output("rank: int32") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); - }) - .Doc(R"doc( -Returns the index of the current process in the node it is on. - -More precisely, returns the rank of the calling process in communicator that -only spans the MPI processes running on that node. - -rank: Rank of the calling process on the node it is on. -)doc"); - -REGISTER_OP("MPIAllreduce") - .Attr("T: {int32, int64, float32}") - .Input("tensor: T") - .Output("sum: T") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->input(0)); - return Status::OK(); - }) - .Doc(R"doc( -Perform an MPI Allreduce on a tensor. All other processes that do a reduction -on a tensor with the same name must have the same dimension for that tensor. -Tensors are reduced with other tensors that have the same node name for the -allreduce. - -Arguments - tensor: A tensor to reduce. - -Output - sum: A tensor with the same shape as `tensor`, summed across all - MPI processes. 
-)doc"); - -REGISTER_OP("MPIAllgather") - .Attr("T: {int32, int64, float32}") - .Attr("S: {int64}") - .Input("tensor: T") - .Input("sizes: S") - .Output("gathered: T") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle output; - TF_RETURN_IF_ERROR( - c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output)); - c->set_output(0, output); - return Status::OK(); - }) - .Doc(R"doc( -Perform an MPI Allgather on a tensor. All other processes that do a gather on a -tensor with the same name must have the same rank for that tensor, and have the -same dimension on all but the first dimension. - -Arguments - tensor: A tensor to gather. - sizes: A tensor containing the first-dimension sizes of tensors to be - gathered from other ranks - -Output - gathered: A tensor with the same shape as `tensor` except for the first - dimension, which is the sum of dimensions in `sizes`. -)doc"); - -} // namespace mpi_collectives -} // namespace contrib -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py b/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py deleted file mode 100644 index 2fbefef0d36..00000000000 --- a/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= -"""Inter-process communication using MPI.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -from tensorflow.contrib.mpi_collectives.ops import gen_mpi_ops -from tensorflow.contrib.util import loader -from tensorflow.python.framework import ops -from tensorflow.python.platform import resource_loader - -_mpi_ops_so = loader.load_op_library( - resource_loader.get_path_to_datafile('_mpi_ops.so')) - - -def size(name=None): - """An op which returns the number of MPI processes. - - This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the - size of the global communicator. - - Returns: - An integer scalar containing the number of MPI processes. - """ - return gen_mpi_ops.mpi_size(name=name) - - -ops.NotDifferentiable('MPISize') - - -def rank(name=None): - """An op which returns the MPI rank of the calling process. - - This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the - rank of the current process in the global communicator. - - Returns: - An integer scalar with the MPI rank of the calling process. - """ - return gen_mpi_ops.mpi_rank(name=name) - - -ops.NotDifferentiable('MPIRank') - - -def init(name=None): - """An op which initializes MPI on the device on which it is run. - - All future MPI ops must be run on the same device that the `init` op was run - on. 
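The MPIAllgather registration above takes an explicit `sizes` input and produces an output whose first dimension is the sum of those sizes; the Python `allgather()` wrapper further down obtains `sizes` itself with a preliminary allgather of each rank's first-dimension length. A NumPy sketch of that shape arithmetic, using hypothetical per-rank sizes borrowed from the variable-size test:

```python
import numpy as np

# Hypothetical first-dimension sizes on four ranks.
local_rows = [17, 32, 81, 12]

# Step 1: gather the sizes so every rank knows how much data to expect.
sizes = np.array(local_rows, dtype=np.int64)

# Step 2: gather the payloads and concatenate along dimension zero.
payloads = [np.full((rows, 5), rank, dtype=np.float32)
            for rank, rows in enumerate(local_rows)]
gathered = np.concatenate(payloads, axis=0)
assert gathered.shape == (int(sizes.sum()), 5)
```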
- """ - return gen_mpi_ops.mpi_init(name=name) - - -ops.NotDifferentiable('MPIInit') - - -def local_rank(name=None): - """An op which returns the local MPI rank of the calling process, within the - node that it is running on. For example, if there are seven processes running - on a node, their local ranks will be zero through six, inclusive. - - This is equivalent to running `MPI_Comm_rank(...)` on a new communicator - which only includes processes on the same node. - - Returns: - An integer scalar with the local MPI rank of the calling process. - """ - return gen_mpi_ops.mpi_local_rank(name=name) - - -ops.NotDifferentiable('MPILocalRank') - - -def _allreduce(tensor, name=None): - """An op which sums an input tensor over all the MPI processes. - - The reduction operation is keyed by the name of the op. The tensor type and - shape must be the same on all MPI processes for a given name. The reduction - will not start until all processes are ready to send and receive the tensor. - - Returns: - A tensor of the same shape and type as `tensor`, summed across all - processes. - """ - return gen_mpi_ops.mpi_allreduce(tensor, name=name) - - -ops.NotDifferentiable('MPIAllreduce') - - -def allgather(tensor, name=None): - """An op which concatenates the input tensor with the same input tensor on - all other MPI processes. - - The concatenation is done on the first dimension, so the input tensors on the - different processes must have the same rank and shape, except for the first - dimension, which is allowed to be different. - - Returns: - A tensor of the same type as `tensor`, concatenated on dimension zero - across all processes. The shape is identical to the input shape, except for - the first dimension, which may be greater and is the sum of all first - dimensions of the tensors in different MPI processes. - """ - # Specify that first allgather is to collect the tensor gather sizes, - # indicated by passing in a scalar (0-D tensor) of value 0 - sizes_flag = tf.constant(0, dtype=tf.int64, name='size_flag_const') - my_size = tf.slice( - tf.shape(tensor, out_type=tf.int64), [0], [1], name='size_slice') - if name is None: - name = 'allgather' - sizing_name = '{}_sizing'.format(name) - sizes = gen_mpi_ops.mpi_allgather(my_size, sizes_flag, name=sizing_name) - return gen_mpi_ops.mpi_allgather(tensor, sizes, name=name) - - -ops.NotDifferentiable('MPIAllgather') diff --git a/tensorflow/contrib/mpi_collectives/ring.cc b/tensorflow/contrib/mpi_collectives/ring.cc deleted file mode 100644 index d93233eb210..00000000000 --- a/tensorflow/contrib/mpi_collectives/ring.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifdef TENSORFLOW_USE_MPI - -#define EIGEN_USE_THREADS - -#include "tensorflow/contrib/mpi_collectives/ring.h" - -namespace tensorflow { -namespace contrib { -namespace mpi { - -using CPUDevice = Eigen::ThreadPoolDevice; - -extern template MPI_Datatype MPIType(); -extern template MPI_Datatype MPIType(); -extern template MPI_Datatype MPIType(); -extern template DataType TensorFlowDataType(); -extern template DataType TensorFlowDataType(); -extern template DataType TensorFlowDataType(); - -// Generate all necessary specializations for RingAllreduce. -template Status RingAllreduce(OpKernelContext*, const Tensor*, - Tensor*, Tensor*); -template Status RingAllreduce(OpKernelContext*, - const Tensor*, Tensor*, - Tensor*); -template Status RingAllreduce(OpKernelContext*, const Tensor*, - Tensor*, Tensor*); - -// Generate all necessary specializations for RingAllgather. -template Status RingAllgather(OpKernelContext*, const Tensor*, - const std::vector&, - Tensor*); -template Status RingAllgather(OpKernelContext*, - const Tensor*, - const std::vector&, - Tensor*); -template Status RingAllgather(OpKernelContext*, const Tensor*, - const std::vector&, - Tensor*); - -// Copy data on a CPU using a straight-forward memcpy. -template <> -void CopyTensorData(void* dst, void* src, size_t size) { - std::memcpy(dst, src, size); -}; - -// Accumulate values on a CPU. -#define GENERATE_ACCUMULATE(type) \ - template <> \ - void AccumulateTensorData(type * dst, type * src, \ - size_t size) { \ - for (unsigned int i = 0; i < size; i++) { \ - dst[i] += src[i]; \ - } \ - }; -GENERATE_ACCUMULATE(int); -GENERATE_ACCUMULATE(long long); -GENERATE_ACCUMULATE(float); -#undef GENERATE_ACCUMULATE - -} // namespace mpi -} // namespace contrib -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/ring.cu.cc b/tensorflow/contrib/mpi_collectives/ring.cu.cc deleted file mode 100644 index 401d1caa514..00000000000 --- a/tensorflow/contrib/mpi_collectives/ring.cu.cc +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifdef TENSORFLOW_USE_MPI - -#if GOOGLE_CUDA - -#define EIGEN_USE_GPU - -#include "tensorflow/contrib/mpi_collectives/ring.h" - -namespace tensorflow { -namespace contrib { -namespace mpi { - -using CPUDevice = Eigen::ThreadPoolDevice; - -template <> -MPI_Datatype MPIType() { - return MPI_FLOAT; -}; -template <> -MPI_Datatype MPIType() { - return MPI_INT; -}; -template <> -MPI_Datatype MPIType() { - return MPI_LONG_LONG; -}; - -template <> -DataType TensorFlowDataType() { - return DT_FLOAT; -}; -template <> -DataType TensorFlowDataType() { - return DT_INT32; -}; -template <> -DataType TensorFlowDataType() { - return DT_INT64; -}; - -// Generate all necessary specializations for RingAllreduce. 
-template Status RingAllreduce<GPUDevice, int>(OpKernelContext*, const Tensor*,
-                                              Tensor*, Tensor*);
-template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
-                                                    const Tensor*, Tensor*,
-                                                    Tensor*);
-template Status RingAllreduce<GPUDevice, float>(OpKernelContext*, const Tensor*,
-                                                Tensor*, Tensor*);
-
-// Generate all necessary specializations for RingAllgather.
-template Status RingAllgather<GPUDevice, int>(OpKernelContext*, const Tensor*,
-                                              const std::vector<size_t>&,
-                                              Tensor*);
-template Status RingAllgather<GPUDevice, long long>(OpKernelContext*,
-                                                    const Tensor*,
-                                                    const std::vector<size_t>&,
-                                                    Tensor*);
-template Status RingAllgather<GPUDevice, float>(OpKernelContext*, const Tensor*,
-                                                const std::vector<size_t>&,
-                                                Tensor*);
-
-// Synchronously copy data on the GPU, using a different stream than the default
-// and than TensorFlow to avoid synchronizing on operations unrelated to the
-// allreduce.
-template <>
-void CopyTensorData<GPUDevice>(void* dst, void* src, size_t size) {
-  auto stream = CudaStreamForMPI();
-  cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
-  cudaStreamSynchronize(stream);
-};
-
-// Elementwise accumulation kernel for GPU.
-template <typename T>
-__global__ void elemwise_accum(T* out, const T* in, const size_t N) {
-  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    out[i] += in[i];
-  }
-}
-
-// Synchronously accumulate tensors on the GPU, using a different stream than
-// the default and than TensorFlow to avoid synchronizing on operations
-// unrelated to the allreduce.
-#define GENERATE_ACCUMULATE(type)                                          \
-  template <>                                                              \
-  void AccumulateTensorData<GPUDevice, type>(type * dst, type * src,       \
-                                             size_t size) {                \
-    auto stream = CudaStreamForMPI();                                      \
-    TF_CHECK_OK(GpuLaunchKernel(elemwise_accum<type>, 32, 256, 0, stream,  \
-                                dst, src, size));                          \
-    cudaStreamSynchronize(stream);                                         \
-  };
-GENERATE_ACCUMULATE(int);
-GENERATE_ACCUMULATE(long long);
-GENERATE_ACCUMULATE(float);
-#undef GENERATE_ACCUMULATE
-
-}  // namespace mpi
-}  // namespace contrib
-}  // namespace tensorflow
-#endif  // GOOGLE_CUDA
-
-#endif  // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/ring.h
deleted file mode 100644
index 9b5d52e1b64..00000000000
--- a/tensorflow/contrib/mpi_collectives/ring.h
+++ /dev/null
@@ -1,327 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_MPI_H_ -#define TENSORFLOW_CONTRIB_MPI_H_ - -#ifdef TENSORFLOW_USE_MPI - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/tensor_types.h" - -#if GOOGLE_CUDA -#include "cuda_runtime.h" -#endif - -// Needed to avoid header issues with C++-supporting MPI implementations -#define OMPI_SKIP_MPICXX -#include "third_party/mpi/mpi.h" - -#define TAG_TENSOR 12 - -namespace tensorflow { -namespace contrib { -namespace mpi { - -using CPUDevice = Eigen::ThreadPoolDevice; -using GPUDevice = Eigen::GpuDevice; - -// Convert from templated types to values we can pass to MPI. -template -MPI_Datatype MPIType(); - -// Convert from templated types to TensorFlow data types. -template -DataType TensorFlowDataType(); - -#define MPI_REQUIRES_OK(MPI_STATUS) \ - if ((MPI_STATUS) != MPI_SUCCESS) { \ - return errors::Unknown("MPI operation failed unexpectedly."); \ - } - -// Copy data from one tensor to another tensor. -// This uses a custom CUDA stream on GPU, which is necessary to overlay the -// backpropagation computations with the allreduce. -template -void CopyTensorData(void* destination, void* source, size_t size); - -// Add a tensor into another tensor, accumulating in place. -// This uses a custom CUDA stream on GPU, which is necessary to overlay the -// backpropagation computations with the allreduce. -template -void AccumulateTensorData(T* destination, T* source, size_t size); - -// We need to get the right stream for doing CUDA memory transfers and -// operations, which is possibly different from the standard TensorFlow stream. -#if GOOGLE_CUDA -cudaStream_t CudaStreamForMPI(); -#endif - -/* Perform a ring allreduce on the data. Allocate the necessary output tensor - * and store it in the output parameter. - * - * Assumes that all MPI processes are doing an allreduce of the same tensor, - * with the same dimensions. - * - * A ring allreduce is a bandwidth-optimal way to do an allreduce. To do the - * allreduce, the nodes involved are arranged in a ring: - * - * .--0--. - * / \ - * 3 1 - * \ / - * *--2--* - * - * Each node always sends to the next clockwise node in the ring, and receives - * from the previous one. - * - * The allreduce is done in two parts: a scatter-reduce and an allgather. In - * the scatter reduce, a reduction is done, so that each node ends up with a - * chunk of the final output tensor which has contributions from all other - * nodes. In the allgather, those chunks are distributed among all the nodes, - * so that all nodes have the entire output tensor. - * - * Both of these operations are done by dividing the input tensor into N - * evenly sized chunks (where N is the number of nodes in the ring). - * - * The scatter-reduce is done in N-1 steps. In the ith step, node j will send - * the (j - i)th chunk and receive the (j - i - 1)th chunk, adding it in to - * its existing data for that chunk. 
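The scatter-reduce schedule described in the comment above is mechanical, so the transfer tables that follow in this comment can be regenerated from the (j - i) / (j - i - 1) rule. A short sketch, assuming N = 4 nodes as in the ring diagram:

```python
N = 4  # number of nodes in the ring, as in the diagram above

# Scatter-reduce: at step i, node j sends segment (j - i) mod N to its right
# neighbor and receives segment (j - i - 1) mod N from its left neighbor.
for i in range(N - 1):
    print("iteration", i)
    for j in range(N):
        send_seg = (j - i) % N
        recv_seg = (j - i - 1) % N
        print("  node %d sends segment %d to node %d, receives segment %d"
              % (j, send_seg, (j + 1) % N, recv_seg))
```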
For example, in the first iteration with - * the ring depicted above, you will have the following transfers: - * - * Segment 0: Node 0 --> Node 1 - * Segment 1: Node 1 --> Node 2 - * Segment 2: Node 2 --> Node 3 - * Segment 3: Node 3 --> Node 0 - * - * In the second iteration, you'll have the following transfers: - * - * Segment 0: Node 1 --> Node 2 - * Segment 1: Node 2 --> Node 3 - * Segment 2: Node 3 --> Node 0 - * Segment 3: Node 0 --> Node 1 - * - * After this iteration, Node 2 has 3 of the four contributions to Segment 0. - * The last iteration has the following transfers: - * - * Segment 0: Node 2 --> Node 3 - * Segment 1: Node 3 --> Node 0 - * Segment 2: Node 0 --> Node 1 - * Segment 3: Node 1 --> Node 2 - * - * After this iteration, Node 3 has the fully accumulated Segment 0; Node 0 - * has the fully accumulated Segment 1; and so on. The scatter-reduce is - * complete. - * - * Next, the allgather distributes these fully accumulated chunks across all - * nodes. Communication proceeds in the same ring, once again in N-1 steps. At - * the ith step, node j will send chunk (j - i + 1) and receive chunk (j - i). - * For example, at the first iteration, the following transfers will occur: - * - * Segment 0: Node 3 --> Node 0 - * Segment 1: Node 0 --> Node 1 - * Segment 2: Node 1 --> Node 2 - * Segment 3: Node 2 --> Node 3 - * - * After the first iteration, Node 0 will have a fully accumulated Segment 0 - * (from Node 3) and Segment 1. In the next iteration, Node 0 will send its - * just-received Segment 0 onward to Node 1, and receive Segment 3 from Node 3. - * After this has continued for N - 1 iterations, all nodes will have a the - * fully accumulated tensor. - * - * Each node will do (N-1) sends for the scatter-reduce and (N-1) sends for the - * allgather. Each send will contain K / N bytes, if there are K bytes in the - * original tensor on every node. Thus, each node sends and receives 2K(N - 1)/N - * bytes of data, and the performance of the allreduce (assuming no latency in - * connections) is constrained by the slowest interconnect between the nodes. - * - */ -template -Status RingAllreduce(OpKernelContext* context, const Tensor* input, - Tensor* temp, Tensor* output) { - // Acquire MPI size and rank - int n, r; - MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); - MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); - - T* buffer = (T*)output->tensor_data().data(); - - CopyTensorData((void*)buffer, (void*)input->tensor_data().data(), - output->tensor_data().size()); - - // Calculate segment sizes and segment ends - const size_t elements_to_reduce = input->NumElements(); - const size_t segment_size = elements_to_reduce / n; - std::vector segment_sizes(n, segment_size); - - const size_t residual = elements_to_reduce % n; - for (size_t i = 0; i < residual; ++i) { - segment_sizes[i]++; - } - - std::vector segment_starts(n); - segment_starts[0] = 0; - for (size_t i = 1; i < segment_starts.size(); ++i) { - segment_starts[i] = segment_starts[i - 1] + segment_sizes[i - 1]; - } - - assert(segment_starts[n - 1] + segment_sizes[n - 1] == elements_to_reduce); - - T* segment_recv = (T*)temp->tensor_data().data(); - - // Receive from your left neighbor with wrap-around - const size_t recv_from = ((r - 1) + n) % n; - - // Send to your right neighbor with wrap-around - const size_t send_to = (r + 1) % n; - - MPI_Status recv_status; - MPI_Request recv_req; - - // Now start ring. 
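A single-process NumPy sketch of the arithmetic in the scatter-reduce and allgather loops that follow, including the uneven segment split where the residual goes to the first segments. It models only the buffer updates, not the MPI transfers, and checks the result against a plain sum:

```python
import numpy as np

def ring_allreduce_sim(rank_data):
    """Simulate the scatter-reduce and allgather phases on one process."""
    n = len(rank_data)
    count = rank_data[0].size
    # Segment sizes: count // n each, with the residual spread over the first segments.
    sizes = [count // n + (1 if i < count % n else 0) for i in range(n)]
    starts = [0]
    for size in sizes[:-1]:
        starts.append(starts[-1] + size)
    buf = [d.astype(np.float64).copy() for d in rank_data]  # one buffer per rank

    def seg(r, s):
        return buf[r][starts[s]:starts[s] + sizes[s]]

    # Phase 1: scatter-reduce. At step i, rank r accumulates the segment
    # (r - i - 1) it receives from its left neighbor into its own copy.
    for i in range(n - 1):
        incoming = [seg((r - 1) % n, (r - i - 1) % n).copy() for r in range(n)]
        for r in range(n):
            seg(r, (r - i - 1) % n)[:] += incoming[r]

    # Phase 2: allgather. At step i, rank r overwrites its segment (r - i)
    # with the fully reduced copy arriving from its left neighbor.
    for i in range(n - 1):
        incoming = [seg((r - 1) % n, (r - i) % n).copy() for r in range(n)]
        for r in range(n):
            seg(r, (r - i) % n)[:] = incoming[r]
    return buf

data = [np.arange(10, dtype=np.float64) * (r + 1) for r in range(4)]
assert all(np.allclose(b, np.sum(data, axis=0)) for b in ring_allreduce_sim(data))
```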
At every step, for every rank, we iterate through - // segments with wraparound and send and recv from our neighbors and reduce - // locally. At the i'th iteration, rank r, sends segment (r-i) and receives - // segment (r-i-1). - for (int i = 0; i < n - 1; i++) { - const size_t send_seg_id = ((r - i) + n) % n; - const size_t recv_seg_id = ((r - i - 1) + n) % n; - - T* segment_send = &(buffer[segment_starts[send_seg_id]]); - - MPI_REQUIRES_OK(MPI_Irecv(segment_recv, segment_sizes[recv_seg_id], - MPIType(), recv_from, TAG_TENSOR, - MPI_COMM_WORLD, &recv_req)); - - MPI_REQUIRES_OK(MPI_Send(segment_send, segment_sizes[send_seg_id], - MPIType(), send_to, TAG_TENSOR, - MPI_COMM_WORLD)); - - T* segment_update = &(buffer[segment_starts[recv_seg_id]]); - - // Wait for recv to complete before reduction - MPI_REQUIRES_OK(MPI_Wait(&recv_req, &recv_status)); - - const size_t recv_seg_size = segment_sizes[recv_seg_id]; - AccumulateTensorData(segment_update, segment_recv, - recv_seg_size); - } - - // Now start pipelined ring allgather. At every step, for every rank, we - // iterate through segments with wraparound and send and recv from our - // neighbors. At the i'th iteration, rank r, sends segment (r-i+1) and - // receives segment (r-i). - for (size_t i = 0; i < n - 1; ++i) { - const size_t send_seg_id = ((r - i + 1) + n) % n; - const size_t recv_seg_id = ((r - i) + n) % n; - - // Segment to send - at every iteration we send segment (r-i+1) - T* segment_send = &(buffer[segment_starts[send_seg_id]]); - - // Segment to recv - at every iteration we receive segment (r-i) - T* segment_recv = &(buffer[segment_starts[recv_seg_id]]); - - MPI_REQUIRES_OK(MPI_Sendrecv( - segment_send, segment_sizes[send_seg_id], MPIType(), send_to, - TAG_TENSOR, segment_recv, segment_sizes[recv_seg_id], MPIType(), - recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); - } - - return Status::OK(); -} - -// Perform a ring allgather on a Tensor. Other ranks may allgather with a -// tensor which differs in the first dimension only; all other dimensions must -// be the same. -// -// For more information on the ring allgather, read the documentation for the -// ring allreduce, which includes a ring allgather. -template -Status RingAllgather(OpKernelContext* context, const Tensor* input, - const std::vector& sizes, Tensor* output) { - // Acquire MPI size and rank - int n, r; - MPI_REQUIRES_OK(MPI_Comm_size(MPI_COMM_WORLD, &n)); - MPI_REQUIRES_OK(MPI_Comm_rank(MPI_COMM_WORLD, &r)); - - assert(sizes.size() == n); - assert(input->dim_size(0) == sizes[r]); - - // Compute number of elements in every "row". We can't compute number of - // elements in every chunks, because those chunks are variable length. - size_t elements_per_row = 1; - for (int i = 1; i < input->shape().dims(); i++) { - elements_per_row *= input->dim_size(i); - } - - // Copy data from input tensor to correct place in output tensor. - std::vector segment_starts(n); - segment_starts[0] = 0; - for (int i = 1; i < n; i++) { - segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1]; - } - size_t offset = segment_starts[r]; - - // Copy data to the right offset for this rank. 
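The offset computation at this point of RingAllgather (a prefix sum of per-rank row counts, scaled by the number of elements per row) is easy to check in NumPy; the sizes below are hypothetical:

```python
import numpy as np

# Hypothetical per-rank first-dimension sizes; all other dimensions must match.
sizes = [3, 1, 4, 2]
elements_per_row = 5

# The deleted code computes flat element offsets:
#   segment_starts[i] = segment_starts[i - 1] + elements_per_row * sizes[i - 1]
# which is just the row prefix sum scaled by elements_per_row.
row_starts = np.cumsum([0] + sizes[:-1])
flat_starts = row_starts * elements_per_row  # the `offset = segment_starts[r]` values

output = np.empty((sum(sizes), elements_per_row))
for r, rows in enumerate(sizes):
    chunk = np.full((rows, elements_per_row), r, dtype=np.float64)  # rank r's tensor
    output[row_starts[r]:row_starts[r] + rows] = chunk

# Rank r's block starts at row sum(sizes[:r]).
assert [int(output[s, 0]) for s in row_starts] == [0, 1, 2, 3]
```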
- T* buffer = (T*)output->tensor_data().data(); - CopyTensorData((void*)(buffer + offset), - (void*)input->tensor_data().data(), - elements_per_row * sizes[r] * sizeof(T)); - - // Receive from your left neighbor with wrap-around - const size_t recv_from = ((r - 1) + n) % n; - - // Send to your right neighbor with wrap-around - const size_t send_to = (r + 1) % n; - - // Perform a ring allgather. At every step, for every rank, we iterate - // through segments with wraparound and send and recv from our neighbors. - // At the i'th iteration, rank r, sends segment (r-i) and receives segment - // (r-1-i). - MPI_Status recv_status; - for (size_t i = 0; i < n - 1; ++i) { - const size_t send_seg_id = ((r - i) + n) % n; - const size_t recv_seg_id = ((r - i - 1) + n) % n; - - // Segment to send - at every iteration we send segment (r-i) - size_t offset_send = segment_starts[send_seg_id]; - size_t rows_send = sizes[send_seg_id]; - T* segment_send = &(buffer[offset_send]); - - // Segment to recv - at every iteration we receive segment (r-1-i) - size_t offset_recv = segment_starts[recv_seg_id]; - size_t rows_recv = sizes[recv_seg_id]; - T* segment_recv = &(buffer[offset_recv]); - - MPI_REQUIRES_OK(MPI_Sendrecv( - segment_send, elements_per_row * rows_send, MPIType(), send_to, - TAG_TENSOR, segment_recv, elements_per_row * rows_recv, MPIType(), - recv_from, TAG_TENSOR, MPI_COMM_WORLD, &recv_status)); - } - - return Status::OK(); -} - -} // namespace mpi -} // namespace contrib -} // namespace tensorflow - -#endif // TENSORFLOW_USE_MPI - -#undef TENSORFLOW_CONTRIB_MPI_H_ -#endif // TENSORFLOW_CONTRIB_MPI_H_ diff --git a/tensorflow/contrib/nn/python/ops/alpha_dropout.py b/tensorflow/contrib/nn/python/ops/alpha_dropout.py index 2b64a78c223..ad9f223f302 100644 --- a/tensorflow/contrib/nn/python/ops/alpha_dropout.py +++ b/tensorflow/contrib/nn/python/ops/alpha_dropout.py @@ -19,12 +19,11 @@ from __future__ import print_function import numbers from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops def alpha_dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name @@ -61,7 +60,7 @@ def alpha_dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylin keep_prob = ops.convert_to_tensor(keep_prob, dtype=x.dtype, name="keep_prob") - keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) + keep_prob.get_shape().assert_has_rank(0) # Do nothing if we know keep_prob == 1 if tensor_util.constant_value(keep_prob) == 1: diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer.py b/tensorflow/contrib/opt/python/training/lars_optimizer.py index bc18177b6d0..0c06f4d7f36 100644 --- a/tensorflow/contrib/opt/python/training/lars_optimizer.py +++ b/tensorflow/contrib/opt/python/training/lars_optimizer.py @@ -113,28 +113,30 @@ class LARSOptimizer(optimizer.Optimizer): (g_norm + self._weight_decay * w_norm + self._epsilon)), 1.0), 1.0) scaled_lr = self._learning_rate * trust_ratio - return scaled_lr + # Add the weight regularization gradient + grad = grad + self._weight_decay * var + return scaled_lr, grad def _apply_dense(self, grad, var): - scaled_lr = self.compute_lr(grad, var) + scaled_lr, grad = self.compute_lr(grad, var) mom = 
self.get_slot(var, "momentum") return training_ops.apply_momentum( var, mom, - scaled_lr, - grad, + math_ops.cast(1.0, var.dtype.base_dtype), + grad * scaled_lr, self._momentum, use_locking=False, use_nesterov=self._use_nesterov) def _resource_apply_dense(self, grad, var): - scaled_lr = self.compute_lr(grad, var) + scaled_lr, grad = self.compute_lr(grad, var) mom = self.get_slot(var, "momentum") return training_ops.resource_apply_momentum( var.handle, mom.handle, - scaled_lr, - grad, + math_ops.cast(1.0, var.dtype.base_dtype), + grad * scaled_lr, self._momentum, use_locking=False, use_nesterov=self._use_nesterov) diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py index b76db763da0..8c135a21bc2 100644 --- a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py @@ -67,9 +67,10 @@ class LARSOptimizerTest(test.TestCase): g_norm = np.linalg.norm(grad_np.flatten(), ord=2) trust_ratio = eeta * w_norm / (g_norm + wd_np * w_norm + ep_np) scaled_lr = lr_np * trust_ratio + grad_np = grad_np + wd_np * var_np - vel_np = m_np * vel_np + grad_np - var_np -= scaled_lr * vel_np + vel_np = m_np * vel_np + scaled_lr * grad_np + var_np -= vel_np self.assertAllClose(var_np, post_var) self.assertAllClose(vel_np, post_vel) @@ -115,9 +116,10 @@ class LARSOptimizerTest(test.TestCase): g_norm = np.linalg.norm(grad_np.flatten(), ord=2) trust_ratio = eeta * w_norm / (g_norm + wd_np * w_norm + ep_np) scaled_lr = lr_np * trust_ratio + grad_np = grad_np + wd_np * var_np - vel_np = m_np * vel_np + grad_np - var_np -= scaled_lr * vel_np + vel_np = m_np * vel_np + scaled_lr * grad_np + var_np -= vel_np self.assertAllClose(var_np, post_var) self.assertAllClose(vel_np, post_vel) diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer.py b/tensorflow/contrib/opt/python/training/nadam_optimizer.py index 960826407b6..046c6ee83fd 100644 --- a/tensorflow/contrib/opt/python/training/nadam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/nadam_optimizer.py @@ -24,14 +24,37 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import adam from tensorflow.python.training import training_ops +from tensorflow.python.util import deprecation class NadamOptimizer(adam.AdamOptimizer): """Optimizer that implements the Nadam algorithm. See [Dozat, T., 2015](http://cs229.stanford.edu/proj2015/054_report.pdf). + + WARNING: due to a known issue this optimizer does not use nesterov momentum + on TPUs or when using XLA in general. This is deprecated; instead prefer + tf.keras.optimizers.Nadam which does the right thing. """ + @deprecation.deprecated( + None, "WARNING: wrong behavior with XLA. 
Use tf.keras.optimizers.Nadam.") + def __init__( + self, + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + use_locking=False, + name="Adam"): + super(NadamOptimizer, self).__init__( + learning_rate=learning_rate, + beta1=beta1, + beta2=beta2, + epsilon=epsilon, + use_locking=use_locking, + name=name) + def _apply_dense(self, grad, var): m = self.get_slot(var, "m") v = self.get_slot(var, "v") diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py index e2bcee51130..233503b911e 100644 --- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -356,10 +356,10 @@ class MomentumWOptimizer(DecoupledWeightDecayExtension, class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): """Optimizer that implements the Adam algorithm with weight decay. - This is an implementation of the AdamW optimizer described in "Fixing - Weight Decay Regularization in Adam" by Loshchilov & Hutter + This is an implementation of the AdamW optimizer described in ["Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter] (https://arxiv.org/abs/1711.05101) - ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + ([pdf](https://arxiv.org/pdf/1711.05101.pdf)). It computes the update step of `train.AdamOptimizer` and additionally decays the variable. Note that this is different from adding L2 regularization on diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index f61e28bbc7e..a90647deed0 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -39,7 +39,8 @@ _RELU_TYPES = {'Relu', 'Relu6'} _QUANTIZATION_OP = {'FakeQuantWithMinMaxVars'} _VALID_SRC_OP = {'Add', 'AddV2', 'Mul'} _INTERMEDIATE_OP = {'Add', 'AddV2', 'Mul'} -_PASS_THROUGH_OP = {'Reshape', 'Identity', 'BatchToSpaceND', 'SpaceToBatchND'} +_PASS_THROUGH_OP = {'Reshape', 'Identity', 'BatchToSpaceND', 'SpaceToBatchND', + 'MaxPool', 'Max'} _VALID_ACTIVATION_OP = {'Relu', 'Relu6'} diff --git a/tensorflow/contrib/reduce_slice_ops/BUILD b/tensorflow/contrib/reduce_slice_ops/BUILD index c98ae649f3e..aeb2c67317e 100644 --- a/tensorflow/contrib/reduce_slice_ops/BUILD +++ b/tensorflow/contrib/reduce_slice_ops/BUILD @@ -1,7 +1,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_kernel_tests_linkstatic") +load("//tensorflow/core/platform:default/build_config.bzl", "tf_kernel_tests_linkstatic") package( licenses = ["notice"], # Apache 2.0 diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 4f8186c7394..78ea6374220 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -227,9 +227,6 @@ def _block_lstm(seq_len_max, # pylint: enable=invalid-name -_lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"] - - @ops.RegisterGradient("LSTMBlockCell") def _LSTMBlockCellGrad(op, *grad): """Gradient for LSTMBlockCell.""" @@ -247,7 +244,7 @@ def _LSTMBlockCellGrad(op, *grad): if cell_size is None: raise ValueError("cell_size from `cs_prev` should not be None.") - 
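Returning to the LARS optimizer hunks earlier in this diff: the change folds the weight-decay term into the gradient and moves the scaled learning rate into the velocity update, so `apply_momentum` is now called with a learning rate of 1.0. A NumPy sketch of the updated rule, mirroring the revised test (the clamping applied to the trust ratio in `compute_lr` is omitted here):

```python
import numpy as np

def lars_step(var, grad, vel, lr, eeta, wd, eps, momentum):
    """One dense LARS step as exercised by the updated test (a sketch)."""
    w_norm = np.linalg.norm(var.flatten(), ord=2)
    g_norm = np.linalg.norm(grad.flatten(), ord=2)
    trust_ratio = eeta * w_norm / (g_norm + wd * w_norm + eps)
    scaled_lr = lr * trust_ratio
    grad = grad + wd * var                    # weight-decay gradient, added by this change
    vel = momentum * vel + scaled_lr * grad   # scaled_lr now lives in the velocity
    var = var - vel                           # apply_momentum is called with lr == 1.0
    return var, vel

var = np.array([1.0, 2.0])
grad = np.array([0.1, -0.2])
vel = np.zeros(2)
var, vel = lars_step(var, grad, vel, lr=0.1, eeta=0.001, wd=1e-4, eps=1e-5, momentum=0.9)
```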
(cs_prev_grad, dicfo, wci_grad, wcf_grad, + (cs_prev_grad, dgates, wci_grad, wcf_grad, wco_grad) = gen_rnn_ops.lstm_block_cell_grad( x=x, cs_prev=cs_prev, @@ -267,8 +264,8 @@ def _LSTMBlockCellGrad(op, *grad): h_grad=h_grad, use_peephole=op.get_attr("use_peephole")) - # Backprop from dicfo to xh. - xh_grad = math_ops.matmul(dicfo, w, transpose_b=True) + # Backprop from dgates to xh. + xh_grad = math_ops.matmul(dgates, w, transpose_b=True) x_grad = array_ops.slice(xh_grad, (0, 0), (batch_size, input_size)) x_grad.get_shape().merge_with(x.get_shape()) @@ -277,13 +274,13 @@ def _LSTMBlockCellGrad(op, *grad): (batch_size, cell_size)) h_prev_grad.get_shape().merge_with(h_prev.get_shape()) - # Backprop from dicfo to w. + # Backprop from dgates to w. xh = array_ops.concat([x, h_prev], 1) - w_grad = math_ops.matmul(xh, dicfo, transpose_a=True) + w_grad = math_ops.matmul(xh, dgates, transpose_a=True) w_grad.get_shape().merge_with(w.get_shape()) - # Backprop from dicfo to b. - b_grad = nn_ops.bias_add_grad(dicfo) + # Backprop from dgates to b. + b_grad = nn_ops.bias_add_grad(dgates) b_grad.get_shape().merge_with(b.get_shape()) return (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 75710ea4190..c0939c84c44 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -1948,7 +1948,9 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell): in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ - super(PhasedLSTMCell, self).__init__(_reuse=reuse) + # We pass autocast=False because this layer can accept inputs of different + # dtypes, so we do not want to automatically cast them to the same dtype. + super(PhasedLSTMCell, self).__init__(_reuse=reuse, autocast=False) self._num_units = num_units self._use_peepholes = use_peepholes self._leak = leak diff --git a/tensorflow/contrib/rpc/BUILD b/tensorflow/contrib/rpc/BUILD index a037be78387..f092af17a90 100644 --- a/tensorflow/contrib/rpc/BUILD +++ b/tensorflow/contrib/rpc/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") +load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static") package( default_visibility = ["//visibility:public"], diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD index 47413aa8692..db197d10cd8 100644 --- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD +++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD @@ -1,7 +1,7 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") -load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static") +load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library") # Placeholder for loading internal BUILD rule. 
package( diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index 6d8c50177d4..3f9400a6748 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -251,6 +251,7 @@ cuda_py_test( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + xla_enable_strict_auto_jit = False, ) cuda_py_test( diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py index 66a464dc218..824c8dad43d 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_v2_test.py @@ -149,7 +149,8 @@ class AttentionMechanismTest(test.TestCase, parameterized.TestCase): x_test = np.random.randint(vocab, size=(self.batch, self.timestep)) y = np.random.randn(self.batch, self.timestep) model = keras.models.Model([inputs, query, state], score) - model.compile("rmsprop", "mse") + # TODO(b/138592586): Run with single-execution-path + model.compile("rmsprop", "mse", experimental_run_tf_function=False) model.fit([x, self.query, self.state], (y, y)) y_ref = model.predict_on_batch([x_test, self.query, self.state]) @@ -159,6 +160,9 @@ class AttentionMechanismTest(test.TestCase, parameterized.TestCase): config, custom_objects={attention_cls.__name__: attention_cls}) loaded_model.set_weights(weights) + # TODO(b/138592586): Run with single-execution-path + loaded_model.compile("rmsprop", "mse", experimental_run_tf_function=False) + y = loaded_model.predict_on_batch([x_test, self.query, self.state]) self.assertAllClose(y_ref, y) @@ -405,11 +409,13 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): memory_sequence_length=self.encoder_sequence_length, normalize=True, dtype=dtype) - cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid") - cell = wrapper.AttentionWrapper(cell, attention_mechanism) + cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid", + dtype=dtype) + cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype) sampler = sampler_py.TrainingSampler() - my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler, + dtype=dtype) final_outputs, final_state, _ = my_decoder( decoder_inputs, @@ -432,11 +438,13 @@ class AttentionWrapperV2Test(test.TestCase, parameterized.TestCase): scale=True, dtype=dtype, ) - cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid") - cell = wrapper.AttentionWrapper(cell, attention_mechanism) + cell = keras.layers.LSTMCell(self.units, recurrent_activation="sigmoid", + dtype=dtype) + cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype) sampler = sampler_py.TrainingSampler() - my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler) + my_decoder = basic_decoder.BasicDecoderV2(cell=cell, sampler=sampler, + dtype=dtype) final_outputs, final_state, _ = my_decoder( decoder_inputs, diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 6360d1cfdc1..343e5f4be69 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -407,8 +407,8 @@ class TestLargeBeamStep(test.TestCase): log_prob_neg_inf = array_ops.ones( 
[self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.Inf - log_probs = array_ops.where(log_prob_mask, log_prob_zeros, - log_prob_neg_inf) + log_probs = array_ops.where_v2(log_prob_mask, log_prob_zeros, + log_prob_neg_inf) return log_probs log_probs = get_probs() diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index a9215e88000..0e19d1e3205 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -2147,7 +2147,8 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): initial_cell_state=None, name=None, attention_layer=None, - attention_fn=None): + attention_fn=None, + dtype=None): """Construct the `AttentionWrapper`. **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in @@ -2224,6 +2225,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): (attention_mechanism, cell_output, attention_state, attention_layer) and outputs (attention, alignments, next_attention_state). If provided, the attention_layer_size should be the size of the outputs of attention_fn. + dtype: The cell dtype Raises: TypeError: `attention_layer_size` is not None and (`attention_mechanism` @@ -2232,7 +2234,7 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): is a list, and its length does not match that of `attention_layer_size`; if `attention_layer_size` and `attention_layer` are set simultaneously. """ - super(AttentionWrapper, self).__init__(name=name) + super(AttentionWrapper, self).__init__(name=name, dtype=dtype) rnn_cell_impl.assert_like_rnncell("cell", cell) if isinstance(attention_mechanism, (list, tuple)): self._is_multi = True diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 5e4f5f53cd7..737d6866283 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -10,7 +10,7 @@ load( "py_test", "tf_cc_test", ) -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library") # Placeholder: load("//tensorflow:tensorflow.bzl", "tf_portable_proto_lib") package( diff --git a/tensorflow/contrib/session_bundle/bundle_shim_test.cc b/tensorflow/contrib/session_bundle/bundle_shim_test.cc index 815beb73a02..121fc2239dd 100644 --- a/tensorflow/contrib/session_bundle/bundle_shim_test.cc +++ b/tensorflow/contrib/session_bundle/bundle_shim_test.cc @@ -47,11 +47,11 @@ void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle, const string& input_tensor_name, const string& output_tensor_name) { // Validate the half plus two behavior. - std::vector serialized_examples; + std::vector serialized_examples; for (float x : {0, 1, 2, 3}) { serialized_examples.push_back(MakeSerializedExample(x)); } - Tensor input = test::AsTensor(serialized_examples, TensorShape({4})); + Tensor input = test::AsTensor(serialized_examples, TensorShape({4})); std::vector outputs; TF_ASSERT_OK(saved_model_bundle.session->Run( diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc index a690d9b129a..996e4ce0b80 100644 --- a/tensorflow/contrib/session_bundle/session_bundle.cc +++ b/tensorflow/contrib/session_bundle/session_bundle.cc @@ -72,7 +72,7 @@ Status GetMetaGraphDefFromExport(const StringPiece export_dir, // Creates a string tensor. 
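Several hunks in this diff (beam_search_decoder_test above, and later learning_test, tensor_forest, sampling_ops, sequence_queueing_state_saver) migrate `array_ops.where` to `array_ops.where_v2`, whose documented behavior is to broadcast the two branches the way `np.where` does, which is why a scalar branch such as the `1.0` in learning_test can be passed directly. A NumPy sketch of that semantics (not TensorFlow code):

```python
import numpy as np

cond = np.array([True, False, True])

# where_v2 semantics match np.where: branch values broadcast against cond,
# so a scalar per branch is enough.
print(np.where(cond, 0.0, -np.inf))        # [  0. -inf   0.]

# Broadcasting also works for mixed shapes.
x = np.zeros((3, 2))
print(np.where(cond[:, None], x, -1.0))    # rows 0 and 2 kept, row 1 set to -1
```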
Tensor CreateStringTensor(const string& value) { Tensor tensor(DT_STRING, TensorShape({})); - tensor.scalar()() = value; + tensor.scalar()() = value; return tensor; } diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc index 9e4b1c72195..108806e3328 100644 --- a/tensorflow/contrib/session_bundle/session_bundle_test.cc +++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc @@ -97,11 +97,11 @@ void CheckRegressionSignature(const Signatures& signatures, const string output_name = regression_signature.output().tensor_name(); // Validate the half plus two behavior. - std::vector serialized_examples; + std::vector serialized_examples; for (float x : {0, 1, 2, 3}) { serialized_examples.push_back(MakeSerializedExample(x)); } - Tensor input = test::AsTensor(serialized_examples, TensorShape({4})); + Tensor input = test::AsTensor(serialized_examples, TensorShape({4})); std::vector outputs; TF_ASSERT_OK( bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs)); @@ -146,13 +146,13 @@ void CheckSessionBundle(const string& export_path, ASSERT_EQ(2, path_outputs.size()); // Validate the two asset file tensors are set by the init_op and include the // base_path and asset directory. - test::ExpectTensorEqual( - test::AsTensor({io::JoinPath(asset_path, "hello1.txt")}, - TensorShape({})), + test::ExpectTensorEqual( + test::AsTensor({io::JoinPath(asset_path, "hello1.txt")}, + TensorShape({})), path_outputs[0]); - test::ExpectTensorEqual( - test::AsTensor({io::JoinPath(asset_path, "hello2.txt")}, - TensorShape({})), + test::ExpectTensorEqual( + test::AsTensor({io::JoinPath(asset_path, "hello2.txt")}, + TensorShape({})), path_outputs[1]); Signatures signatures; diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py index c457d44e07b..dec5cbc6d22 100644 --- a/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py +++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader_test.py @@ -144,14 +144,16 @@ class ParallelReaderTest(test.TestCase): capacity=55, min_after_dequeue=28, dtypes=[dtypes_lib.string, dtypes_lib.string], - shapes=[tensor_shape.scalar(), tensor_shape.scalar()]) + shapes=[tensor_shape.TensorShape([]), + tensor_shape.TensorShape([])]) self._verify_read_up_to_out(shared_queue) def testReadUpToFromFIFOQueue(self): shared_queue = data_flow_ops.FIFOQueue( capacity=99, dtypes=[dtypes_lib.string, dtypes_lib.string], - shapes=[tensor_shape.scalar(), tensor_shape.scalar()]) + shapes=[tensor_shape.TensorShape([]), + tensor_shape.TensorShape([])]) self._verify_read_up_to_out(shared_queue) diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index 8fca63292e6..381d5941e5a 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -443,11 +443,9 @@ class Image(ItemHandler): """Decodes a raw image.""" return parsing_ops.decode_raw(image_buffer, out_type=self._dtype) - pred_fn_pairs = { - math_ops.logical_or( - math_ops.equal(image_format, 'raw'), - math_ops.equal(image_format, 'RAW')): decode_raw, - } + pred_fn_pairs = [(math_ops.logical_or( + math_ops.equal(image_format, 'raw'), + math_ops.equal(image_format, 'RAW')), decode_raw)] image = control_flow_ops.case( pred_fn_pairs, default=check_jpeg, exclusive=True) diff --git 
a/tensorflow/contrib/slim/python/slim/learning_test.py b/tensorflow/contrib/slim/python/slim/learning_test.py index 5db4fe02b8e..aefc07696b9 100644 --- a/tensorflow/contrib/slim/python/slim/learning_test.py +++ b/tensorflow/contrib/slim/python/slim/learning_test.py @@ -197,7 +197,8 @@ class MultiplyGradientsTest(test.TestCase): gradient = constant_op.constant(self._grad_vec, dtype=dtypes.float32) variable = variables_lib.Variable(array_ops.zeros_like(gradient)) multiplier_flag = variables_lib.Variable(True) - tensor_multiplier = array_ops.where(multiplier_flag, self._multiplier, 1.0) + tensor_multiplier = array_ops.where_v2(multiplier_flag, self._multiplier, + 1.0) grad_to_var = (gradient, variable) gradient_multipliers = {variable: tensor_multiplier} diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD index 69cbb120ef8..7bb73f5a415 100644 --- a/tensorflow/contrib/sparsemax/BUILD +++ b/tensorflow/contrib/sparsemax/BUILD @@ -9,7 +9,7 @@ load( "tf_py_test", ) load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_kernel_tests_linkstatic", ) diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index fdd7e1e1ee3..ca246f912be 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -8,7 +8,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") +load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static") package( default_visibility = ["//visibility:public"], diff --git a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py index 926e4dda916..a8a5b574691 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py +++ b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py @@ -17,8 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections - from tensorflow.contrib import layers from tensorflow.contrib.framework.python.ops import variables as framework_variables @@ -29,6 +27,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variables from tensorflow.python.training import adagrad +from tensorflow.python.util.compat import collections_abc class HybridModel(object): @@ -66,7 +65,7 @@ class HybridModel(object): # If this is a collection of layers, return the mean of their inference # results. - if isinstance(layer, collections.Iterable): + if isinstance(layer, collections_abc.Iterable): return math_ops.reduce_mean( array_ops.stack([l.inference_graph(data) for l in layer]), 0) # If this is a single layer, return its inference result. 
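The hybrid_model hunk above replaces `collections.Iterable` with `collections_abc.Iterable` from `tensorflow.python.util.compat`. Outside of TensorFlow, the equivalent compatibility shim is the usual try/except import; a minimal sketch:

```python
try:
    # Python 3: the abstract base classes live in collections.abc
    # (the old aliases in the collections module were later removed).
    from collections import abc as collections_abc
except ImportError:
    # Python 2 fallback.
    import collections as collections_abc

print(isinstance([1, 2, 3], collections_abc.Iterable))  # True
print(isinstance(42, collections_abc.Iterable))         # False
```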
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc index 94650fe108b..5f997c2fba0 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -52,7 +52,7 @@ class CreateTreeVariableOp : public OpKernel { auto* result = new DecisionTreeResource(param_proto_); if (!ParseProtoUnlimited(result->mutable_decision_tree(), - tree_config_t->scalar()())) { + tree_config_t->scalar()())) { result->Unref(); OP_REQUIRES(context, false, errors::InvalidArgument("Unable to parse tree config.")); @@ -85,7 +85,7 @@ class TreeSerializeOp : public OpKernel { Tensor* output_config_t = nullptr; OP_REQUIRES_OK( context, context->allocate_output(0, TensorShape(), &output_config_t)); - output_config_t->scalar()() = + output_config_t->scalar()() = decision_tree_resource->decision_tree().SerializeAsString(); } }; @@ -116,7 +116,7 @@ class TreeDeserializeOp : public OpKernel { decision_trees::Model* config = decision_tree_resource->mutable_decision_tree(); OP_REQUIRES(context, - ParseProtoUnlimited(config, tree_config_t->scalar()()), + ParseProtoUnlimited(config, tree_config_t->scalar()()), errors::InvalidArgument("Unable to parse tree config.")); decision_tree_resource->MaybeInitialize(); } @@ -224,7 +224,7 @@ class TreePredictionsV4Op : public OpKernel { : 0); OP_REQUIRES_OK(context, context->allocate_output(1, output_paths_shape, &output_tree_paths)); - auto out_paths = output_tree_paths->unaligned_flat(); + auto out_paths = output_tree_paths->unaligned_flat(); // TODO(gilberth): If this slows down inference too much, consider having // a filter that only serializes paths for the predicted label that we're diff --git a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc index b21a9179777..fcea240dee9 100644 --- a/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc +++ b/tensorflow/contrib/tensor_forest/kernels/reinterpret_string_to_float_op.cc @@ -38,7 +38,7 @@ float Convert(const string& in) { void Evaluate(const Tensor& input_data, Tensor output_data, int32 start, int32 end) { auto out_data = output_data.unaligned_flat(); - const auto in_data = input_data.unaligned_flat(); + const auto in_data = input_data.unaligned_flat(); for (int32 i = start; i < end; ++i) { out_data(i) = Convert(in_data(i)); diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index ede6e1abc9f..e4693cf68dc 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -56,7 +56,7 @@ class CreateFertileStatsVariableOp : public OpKernel { errors::InvalidArgument("Stats config must be a scalar.")); auto* result = new FertileStatsResource(param_proto_); FertileStats stats; - if (!ParseProtoUnlimited(&stats, stats_config_t->scalar()())) { + if (!ParseProtoUnlimited(&stats, stats_config_t->scalar()())) { result->Unref(); OP_REQUIRES(context, false, errors::InvalidArgument("Unable to parse stats config.")); @@ -98,7 +98,7 @@ class FertileStatsSerializeOp : public OpKernel { FertileStats stats; fertile_stats_resource->PackToProto(&stats); - output_config_t->scalar()() = stats.SerializeAsString(); + output_config_t->scalar()() = stats.SerializeAsString(); } private: @@ -128,9 +128,10 @@ class FertileStatsDeserializeOp : public OpKernel { // Deallocate all 
the previous objects on the resource. fertile_stats_resource->Reset(); FertileStats stats; - OP_REQUIRES(context, - ParseProtoUnlimited(&stats, stats_config_t->scalar()()), - errors::InvalidArgument("Unable to parse stats config.")); + OP_REQUIRES( + context, + ParseProtoUnlimited(&stats, stats_config_t->scalar()()), + errors::InvalidArgument("Unable to parse stats config.")); fertile_stats_resource->ExtractFromProto(stats); fertile_stats_resource->MaybeInitialize(); diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD index d205b255402..71bfa5bbb8c 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD +++ b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD @@ -1,7 +1,7 @@ # TensorFlow code for training random forests. load("//tensorflow:tensorflow.bzl", "tf_cc_test") -load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static") +load("//tensorflow/core/platform:default/build_config_root.bzl", "if_static") package( default_visibility = ["//visibility:public"], diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc index f4a7058ddb8..417cb6f7420 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc @@ -103,7 +103,7 @@ float CandidateGraphRunner::SplitScore() { void CandidateGraphRunner::GetSplit(decision_trees::BinaryNode* node) { std::vector outputs; RunOp(kNoOp, TensorNameValueList(), {kGetSplitName}, &outputs); - ParseProtoUnlimited(node, outputs[0].unaligned_flat()(0)); + ParseProtoUnlimited(node, outputs[0].unaligned_flat()(0)); const auto& oblique = split_.inequality_left_child_test().oblique(); auto* new_split = node->mutable_inequality_left_child_test()->mutable_oblique(); diff --git a/tensorflow/contrib/tensor_forest/proto/BUILD b/tensorflow/contrib/tensor_forest/proto/BUILD index efa696fffe6..702dbed7fc0 100644 --- a/tensorflow/contrib/tensor_forest/proto/BUILD +++ b/tensorflow/contrib/tensor_forest/proto/BUILD @@ -1,4 +1,4 @@ -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library") package( default_visibility = ["//visibility:public"], diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index df10997d633..623e52ca0b6 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -461,7 +461,7 @@ class RandomForestGraphs(object): mask = math_ops.less( r, array_ops.ones_like(r) * self.params.bagging_fraction) - gather_indices = array_ops.squeeze(array_ops.where(mask), axis=[1]) + gather_indices = array_ops.squeeze(array_ops.where_v2(mask), axis=[1]) # TODO(thomaswc): Calculate out-of-bag data and labels, and store # them for use in calculating statistics later. tree_data = array_ops.gather(processed_dense_features, gather_indices) diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD index e5efe4b16d8..801fe67b069 100644 --- a/tensorflow/contrib/tensorboard/BUILD +++ b/tensorflow/contrib/tensorboard/BUILD @@ -2,7 +2,7 @@ # TensorBoard module containing volatile or experimental code. 
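The tensor_forest hunk above builds bagging indices with `where_v2` followed by `squeeze(..., axis=[1])`. The same index arithmetic in NumPy, with a hypothetical row count and bagging fraction:

```python
import numpy as np

rng = np.random.default_rng(0)
num_rows, bagging_fraction = 10, 0.6

r = rng.uniform(size=num_rows)
mask = r < bagging_fraction
# where on a boolean vector yields [num_selected, 1] indices; dropping the
# trailing axis is what squeeze(..., axis=[1]) does in the hunk.
gather_indices = np.argwhere(mask).squeeze(axis=1)

data = rng.normal(size=(num_rows, 4))
tree_data = data[gather_indices]            # the gather
assert tree_data.shape == (int(mask.sum()), 4)
```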
# For platform specific build config -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library") load("//tensorflow:tensorflow.bzl", "py_test") package( diff --git a/tensorflow/contrib/text/kernels/skip_gram_kernels.cc b/tensorflow/contrib/text/kernels/skip_gram_kernels.cc index 3cd0b5f72b5..198388599e8 100644 --- a/tensorflow/contrib/text/kernels/skip_gram_kernels.cc +++ b/tensorflow/contrib/text/kernels/skip_gram_kernels.cc @@ -128,7 +128,7 @@ class SkipGramGenerateCandidatesOp : public OpKernel { .TypeConstraint("T"), \ SkipGramGenerateCandidatesOp) -REGISTER_KERNEL(string); +REGISTER_KERNEL(tstring); REGISTER_KERNEL(int64); REGISTER_KERNEL(int32); REGISTER_KERNEL(int16); diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 94a51abb762..017d08f5f60 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -1,7 +1,6 @@ # Description: # contains parts of TensorFlow that are experimental or unstable and which are not supported. - -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load("//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library") load("//tensorflow:tensorflow.bzl", "py_test") package( @@ -174,7 +173,7 @@ py_test( py_test( name = "sampling_ops_test", - size = "small", + size = "medium", srcs = ["python/training/sampling_ops_test.py"], python_version = "PY2", srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py index 10f3f88f3eb..fddcf1e4f62 100644 --- a/tensorflow/contrib/training/python/training/bucket_ops.py +++ b/tensorflow/contrib/training/python/training/bucket_ops.py @@ -212,7 +212,7 @@ def bucket(tensors, else static_batch_size) bucket_shapes = [ - tensor_shape.vector(maybe_static_batch_size).concatenate(s) + tensor_shape.TensorShape([maybe_static_batch_size]).concatenate(s) for s in bucket_queues[0].shapes ] # top_queue is a PaddingFIFOQueue even if the bucket queues are regular FIFO @@ -222,7 +222,7 @@ def bucket(tensors, top_queue = data_flow_ops.PaddingFIFOQueue( capacity=capacity, dtypes=[dtypes.int32] + types, - shapes=[tensor_shape.scalar()] + bucket_shapes, + shapes=[tensor_shape.TensorShape([])] + bucket_shapes, shared_name=shared_name, name="top_queue") @@ -399,11 +399,11 @@ def bucket_by_sequence_length(input_length, conditions_c = math_ops.logical_and( math_ops.less_equal(buckets_min, input_length), math_ops.less(input_length, buckets_max)) - which_bucket = math_ops.reduce_min(array_ops.where(conditions_c)) + which_bucket = math_ops.reduce_min(array_ops.where_v2(conditions_c)) which_bucket = math_ops.cast(which_bucket, dtypes.int32) if shapes is not None: - shapes = [tensor_shape.scalar()] + shapes + shapes = [tensor_shape.TensorShape([])] + shapes _, dequeued = bucket( tensors=[input_length] + tensor_list, diff --git a/tensorflow/contrib/training/python/training/sampling_ops.py b/tensorflow/contrib/training/python/training/sampling_ops.py index 849b77d6095..257cc4fce21 100644 --- a/tensorflow/contrib/training/python/training/sampling_ops.py +++ b/tensorflow/contrib/training/python/training/sampling_ops.py @@ -417,7 +417,7 @@ def _calculate_acceptance_probabilities(init_probs, target_probs): ratio_l = target_probs / init_probs # Replace NaNs with 0s. 
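Several hunks here (parallel_reader_test, bucket_ops above, and related changes) replace the deprecated `tensor_shape.scalar()` and `tensor_shape.vector(n)` helpers with explicit `TensorShape` constructors; the replacements are value-for-value equivalent. A quick sketch, assuming a TensorFlow installation:

```python
import tensorflow as tf

scalar_shape = tf.TensorShape([])     # what tensor_shape.scalar() used to return
vector_shape = tf.TensorShape([32])   # what tensor_shape.vector(32) used to return

print(scalar_shape)                   # ()
print(vector_shape)                   # (32,)
print(vector_shape.concatenate([5]))  # (32, 5), as used for bucket_shapes above
```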
- ratio_l = array_ops.where( + ratio_l = array_ops.where_v2( math_ops.is_nan(ratio_l), array_ops.zeros_like(ratio_l), ratio_l) # Calculate list of acceptance probabilities. diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py index e44c4f8c0ef..02baf4e071e 100644 --- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py +++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py @@ -594,7 +594,7 @@ class NextQueuedSequenceBatch(object): # unless we explicitly tie them to CPU. with ops.colocate_with(self._state_saver._capacity_queue.queue_ref): indices_where_not_done = array_ops.reshape( - array_ops.where( + array_ops.where_v2( math_ops.logical_not(self._state_saver._sequence_is_done)), [-1]) keeping_next_key = array_ops.gather( diff --git a/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc index 096ca0f0cf9..1207a338f39 100644 --- a/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc +++ b/tensorflow/contrib/util/convert_graphdef_memmapped_format_test.cc @@ -98,10 +98,10 @@ TEST(ConvertGraphdefMemmappedFormatTest, NotSupportedTypesConvert) { constexpr int kTensorHeight = 100; const TensorShape kTestTensorShape({kTensorWidth, kTensorHeight}); Tensor test_tensor1(DT_STRING, kTestTensorShape); - test::FillFn<string>(&test_tensor1, [](int) -> string { return "ABC"; }); + test::FillFn<tstring>(&test_tensor1, [](int) -> string { return "ABC"; }); Tensor test_tensor2(DT_STRING, kTestTensorShape); - test::FillFn<string>(&test_tensor2, [](int) -> string { return "XYZ"; }); + test::FillFn<tstring>(&test_tensor2, [](int) -> string { return "XYZ"; }); auto root = Scope::NewRootScope().ExitOnError(); Output m = ops::Add(root, test_tensor1, test_tensor2); const string result_name = m.node()->name(); diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index fac783b7d5f..b0035269d40 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -5,7 +5,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_library") # For platform specific build config load( - "//tensorflow/core:platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_proto_library_cc", ) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index ca158b3486b..aa607fa8257 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -7,7 +7,7 @@ # ":protos_all_cc" - exports all core TensorFlow protos # ":protos_all_py" - py_proto_library version (Google-internal) # ":lib" - exports the public non-test headers for: -# platform/: Platform-specific code and external dependencies +# //third_party/tensorflow/core/platform:: Platform-specific code and external dependencies # lib/: Low-level libraries that are not TensorFlow-specific # ":test" - test equivalent of ":lib".
# This is currently public, but may be made internal in the @@ -104,7 +104,7 @@ load("//tensorflow:tensorflow.bzl", "tf_version_info_genrule") # For platform specific build config load( - ":platform/default/build_config.bzl", + "//tensorflow/core/platform:default/build_config.bzl", "tf_additional_all_protos", "tf_additional_cloud_kernel_deps", "tf_additional_cloud_op_deps", @@ -112,36 +112,25 @@ load( "tf_additional_cupti_wrapper_deps", "tf_additional_device_tracer_cuda_deps", "tf_additional_device_tracer_deps", - "tf_additional_device_tracer_srcs", "tf_additional_device_tracer_test_flags", "tf_additional_gdr_lib_defines", "tf_additional_human_readable_json_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", - "tf_additional_lib_hdrs", - "tf_additional_lib_srcs", "tf_additional_libdevice_data", "tf_additional_libdevice_deps", - "tf_additional_libdevice_srcs", "tf_additional_minimal_lib_srcs", "tf_additional_monitoring_hdrs", - "tf_additional_monitoring_srcs", - "tf_additional_mpi_lib_defines", "tf_additional_numa_copts", "tf_additional_numa_deps", "tf_additional_numa_lib_defines", - "tf_additional_proto_hdrs", - "tf_additional_proto_srcs", "tf_additional_test_deps", - "tf_additional_test_srcs", "tf_additional_verbs_lib_defines", "tf_grpc_service_all", "tf_jspb_proto_library", "tf_kernel_tests_linkstatic", "tf_lib_proto_compiler_deps", "tf_lib_proto_parsing_deps", - "tf_platform_hdrs", - "tf_platform_srcs", "tf_proto_library", "tf_proto_library_cc", "tf_protos_all", @@ -151,10 +140,11 @@ load( "tf_pyclif_proto_library", ) load( - ":platform/default/build_config_root.bzl", + "//tensorflow/core/platform:default/build_config_root.bzl", "if_dynamic_kernels", "if_static", "tf_cuda_tests_tags", + "tf_gpu_tests_tags", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") @@ -179,6 +169,7 @@ package_group( name = "dependency_whitelist", packages = [ "//learning/freud/topic_models/tensorflow/...", + "//perftools/accelerators/xprof/api/...", "//quality/webanswers/brain/tokenization/custom_tf_ops/kernels/...", ], ) @@ -242,7 +233,7 @@ COMMON_PROTO_SRCS = [ ] ERROR_CODES_PROTO_SRCS = [ - "lib/core/error_codes.proto", + "//tensorflow/core/lib/core:error_codes.proto", ] # LINT.ThenChange(//tensorflow/core/android_proto_config.asciipb) @@ -277,11 +268,27 @@ tf_proto_library( make_default_target_header_only = True, protodeps = [ ":protos_all_proto", - ":error_codes_proto", + "//tensorflow/core/lib/core:error_codes_proto", ], visibility = ["//visibility:public"], ) +tf_generate_proto_text_sources( + name = "attr_value_proto_text", + srcs = [ + "framework/attr_value.proto", + "framework/resource_handle.proto", + "framework/tensor.proto", + "framework/tensor_shape.proto", + "framework/types.proto", + ], + srcs_relative_dir = "tensorflow/core/", + deps = [ + ":lib_internal", + ":protos_all_proto_cc", + ], +) + tf_jspb_proto_library( name = "protos_all_jspb_proto", visibility = ["//visibility:public"], @@ -321,45 +328,34 @@ tf_proto_library( visibility = ["//visibility:public"], ) -# Minimal lib to detect platform -cc_library( - name = "lib_platform", - hdrs = [ - "platform/platform.h", - ], -) - filegroup( name = "platform_base_hdrs", srcs = [ - "platform/byte_order.h", - "platform/cord.h", - "platform/env_time.h", - "platform/logging.h", - "platform/macros.h", - "platform/platform_strings.h", - "platform/types.h", + "//tensorflow/core/platform:byte_order.h", + "//tensorflow/core/platform:cord.h", + 
"//tensorflow/core/platform:env_time.h", + "//tensorflow/core/platform:logging.h", + "//tensorflow/core/platform:macros.h", + "//tensorflow/core/platform:platform_strings.h", + "//tensorflow/core/platform:tstring.h", + "//tensorflow/core/platform:types.h", ], visibility = ["//visibility:private"], ) cc_library( name = "platform_base", - srcs = tf_platform_hdrs([ - "integral_types.h", - "logging.h", - ]) + tf_platform_srcs([ - "logging.cc", - "env_time.cc", - ]) + [ - "platform/env_time.cc", - ], hdrs = [":platform_base_hdrs"], copts = tf_copts(), tags = ["avoid_dep"], visibility = [":__subpackages__"], deps = [ - ":lib_platform", + "//tensorflow/core/platform", + "//tensorflow/core/platform:byte_order", + "//tensorflow/core/platform:env_time", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:types", "//tensorflow/core/platform/default/build_config:base", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", @@ -379,13 +375,13 @@ cc_library( filegroup( name = "platform_port_hdrs", srcs = [ - "platform/cpu_info.h", - "platform/dynamic_annotations.h", - "platform/init_main.h", - "platform/mem.h", - "platform/mutex.h", - "platform/numa.h", - "platform/thread_annotations.h", + "//tensorflow/core/platform:cpu_info.h", + "//tensorflow/core/platform:dynamic_annotations.h", + "//tensorflow/core/platform:init_main.h", + "//tensorflow/core/platform:mem.h", + "//tensorflow/core/platform:mutex.h", + "//tensorflow/core/platform:numa.h", + "//tensorflow/core/platform:thread_annotations.h", ], visibility = ["//visibility:private"], ) @@ -394,24 +390,18 @@ filegroup( filegroup( name = "platform_port_internal_hdrs", srcs = [ - "platform/demangle.h", - "platform/host_info.h", - "platform/snappy.h", + "//tensorflow/core/platform:demangle.h", + "//tensorflow/core/platform:host_info.h", + "//tensorflow/core/platform:snappy.h", ], visibility = ["//visibility:private"], ) cc_library( name = "platform_port", - srcs = tf_platform_hdrs([ - "cpu_info.h", - "dynamic_annotations.h", - "thread_annotations.h", - "mutex.h", - ]) + tf_platform_srcs([ - "port.cc", - ]) + [ - "platform/cpu_info.cc", + srcs = [ + "//tensorflow/core/platform:cpu_info.cc", + "//tensorflow/core/platform:legacy_platform_port_srcs", ], hdrs = [ ":platform_port_hdrs", @@ -420,7 +410,7 @@ cc_library( copts = tf_copts() + tf_additional_numa_copts(), visibility = [":__subpackages__"], deps = [ - ":lib_platform", + "//tensorflow/core/platform:platform", ":platform_base", "@com_google_absl//absl/base", "//tensorflow/core/platform/default/build_config:port", @@ -431,7 +421,7 @@ cc_library( filegroup( name = "platform_protobuf_hdrs", srcs = [ - "platform/protobuf.h", + "//tensorflow/core/platform:protobuf.h", ], visibility = ["//visibility:private"], ) @@ -440,19 +430,18 @@ filegroup( filegroup( name = "platform_protobuf_internal_hdrs", srcs = [ - "platform/protobuf_internal.h", + "//tensorflow/core/platform:protobuf_internal.h", ], visibility = ["//visibility:private"], ) cc_library( name = "platform_protobuf", - srcs = tf_platform_hdrs([ - "protobuf.h", - ]) + [ - "platform/protobuf.cc", - "platform/protobuf_util.cc", - "lib/core/status.h", + srcs = [ + "//tensorflow/core/lib/core:legacy_lib_core_status_header", + "//tensorflow/core/platform:protobuf.cc", + "//tensorflow/core/platform:protobuf.h", + "//tensorflow/core/platform:protobuf_util.cc", ], hdrs = [ ":platform_protobuf_hdrs", @@ -461,9 +450,9 @@ cc_library( copts = tf_copts(), visibility = [":__subpackages__"], deps = [ - 
":lib_platform", ":platform_base", ":platform_port", + "//tensorflow/core/platform", "//tensorflow/core/platform/default/build_config:protobuf", "@com_google_protobuf//:protobuf", ], @@ -473,7 +462,7 @@ cc_library( name = "grpc_services", srcs = [], hdrs = [ - "platform/grpc_services.h", + "//tensorflow/core/platform:grpc_services.h", ], copts = tf_copts(), visibility = ["//visibility:public"], @@ -482,8 +471,8 @@ cc_library( cc_library( name = "human_readable_json", - srcs = tf_platform_srcs(["human_readable_json.cc"]), - hdrs = ["platform/human_readable_json.h"], + srcs = ["//tensorflow/core/platform:legacy_human_readable_json_src"], + hdrs = ["//tensorflow/core/platform:human_readable_json.h"], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ @@ -494,8 +483,8 @@ cc_library( cc_library( name = "logger", - srcs = ["platform/logger.cc"], - hdrs = ["platform/logger.h"], + srcs = ["//tensorflow/core/platform:logger.cc"], + hdrs = ["//tensorflow/core/platform:logger.h"], copts = tf_copts(), visibility = ["//visibility:public"], deps = [ @@ -509,9 +498,9 @@ cc_library( filegroup( name = "platform_env_hdrs", srcs = [ - "platform/env.h", - "platform/file_statistics.h", - "platform/file_system.h", + "//tensorflow/core/platform:env.h", + "//tensorflow/core/platform:file_statistics.h", + "//tensorflow/core/platform:file_system.h", ], visibility = ["//visibility:private"], ) @@ -520,21 +509,17 @@ filegroup( filegroup( name = "platform_env_internal_hdrs", srcs = [ - "platform/load_library.h", + "//tensorflow/core/platform:load_library.h", ], visibility = ["//visibility:private"], ) cc_library( name = "platform_env", - srcs = tf_platform_srcs([ - "env.cc", - "load_library.cc", - ]) + tf_platform_hdrs([ - "wide_char.h", - ]) + [ - "platform/env.cc", - "platform/file_system.cc", + srcs = [ + "//tensorflow/core/platform:env.cc", + "//tensorflow/core/platform:file_system.cc", + "//tensorflow/core/platform:legacy_platform_env_srcs", ], hdrs = [ ":platform_env_hdrs", @@ -546,13 +531,13 @@ cc_library( "//tensorflow/c:__subpackages__", ], deps = [ - ":error_codes_proto_cc", ":lib", ":lib_internal", - ":lib_platform", ":platform_base", ":platform_port", ":platform_protobuf", + "//tensorflow/core/lib/core:error_codes_proto_cc", + "//tensorflow/core/platform", "//tensorflow/core/platform/default/build_config:env", "//tensorflow/core/platform/default/build_config:port", ], @@ -561,19 +546,17 @@ cc_library( filegroup( name = "platform_file_system_hdrs", srcs = [ - "platform/file_system_helper.h", - "platform/null_file_system.h", + "//tensorflow/core/platform:file_system_helper.h", + "//tensorflow/core/platform:null_file_system.h", ], visibility = ["//visibility:private"], ) cc_library( name = "platform_file_system", - srcs = tf_platform_srcs([ - ]) + tf_platform_hdrs([ - "windows_file_system.h", - ]) + [ - "platform/file_system_helper.cc", + srcs = [ + "//tensorflow/core/platform:file_system_helper.cc", + "//tensorflow/core/platform:legacy_file_system_hdrs", ], hdrs = [ ":platform_file_system_hdrs", @@ -582,83 +565,71 @@ cc_library( visibility = [":__subpackages__"], deps = [ ":lib", - ":lib_platform", ":platform_env", + "//tensorflow/core/platform", ], ) -cc_library( - name = "platform_strings", - srcs = tf_platform_srcs([ - "platform/platform_strings.cc", - "platform/platform_strings_computed.h", - ]), - hdrs = [ - "platform/platform_strings.h", - ], - visibility = [":__subpackages__"], - deps = [":lib"], -) - filegroup( name = "platform_other_hdrs", srcs = [ - "platform/abi.h", - 
"platform/context.h", - "platform/cpu_feature_guard.h", - "platform/error.h", - "platform/fingerprint.h", - "platform/monitoring.h", - "platform/net.h", - "platform/notification.h", - "platform/prefetch.h", - "platform/profile_utils/android_armv7a_cpu_utils_helper.h", - "platform/profile_utils/clock_cycle_profiler.h", - "platform/profile_utils/cpu_utils.h", - "platform/profile_utils/i_cpu_utils_helper.h", - "platform/stacktrace.h", - "platform/stacktrace_handler.h", - "platform/strong_hash.h", - "platform/subprocess.h", + "//tensorflow/core/platform:abi.h", + "//tensorflow/core/platform:context.h", + "//tensorflow/core/platform:cpu_feature_guard.h", + "//tensorflow/core/platform:error.h", + "//tensorflow/core/platform:fingerprint.h", + "//tensorflow/core/platform:monitoring.h", + "//tensorflow/core/platform:net.h", + "//tensorflow/core/platform:notification.h", + "//tensorflow/core/platform:prefetch.h", + "//tensorflow/core/platform:profile_utils/android_armv7a_cpu_utils_helper.h", + "//tensorflow/core/platform:profile_utils/clock_cycle_profiler.h", + "//tensorflow/core/platform:profile_utils/cpu_utils.h", + "//tensorflow/core/platform:profile_utils/i_cpu_utils_helper.h", + "//tensorflow/core/platform:stacktrace.h", + "//tensorflow/core/platform:stacktrace_handler.h", + "//tensorflow/core/platform:strong_hash.h", + "//tensorflow/core/platform:subprocess.h", ] + tf_additional_monitoring_hdrs(), visibility = ["//visibility:private"], ) +tf_cc_test( + name = "platform_unbounded_work_queue_test", + srcs = ["//tensorflow/core/platform:unbounded_work_queue_test.cc"], + deps = [ + ":framework", + ":lib", + ":lib_internal", + ":lib_test_internal", + ":test", + ":test_main", + "@com_google_absl//absl/memory", + ], +) + # Headers that are not exported as part of ":lib". 
filegroup( name = "platform_other_internal_hdrs", srcs = [ - "platform/denormal.h", - "platform/setround.h", - "platform/tracing.h", + "//tensorflow/core/platform:denormal.h", + "//tensorflow/core/platform:setround.h", + "//tensorflow/core/platform:tracing.h", ], visibility = ["//visibility:private"], ) cc_library( name = "platform_other", - srcs = tf_platform_srcs([ - "subprocess.cc", - "net.cc", - "tracing.cc", - ]) + tf_platform_hdrs([ - "tracing.h", - "error.h", - "context.h", - "fingerprint.h", - "notification.h", - "stacktrace.h", - "strong_hash.h", - "subprocess.h", - "tracing_impl.h", - ]) + [ - "platform/cpu_feature_guard.cc", - "platform/setround.cc", - "platform/tracing.cc", - "platform/denormal.cc", - "platform/profile_utils/android_armv7a_cpu_utils_helper.cc", - "platform/profile_utils/clock_cycle_profiler.cc", - "platform/profile_utils/cpu_utils.cc", + srcs = [ + "//tensorflow/core/platform:cpu_feature_guard.cc", + "//tensorflow/core/platform:denormal.cc", + "//tensorflow/core/platform:legacy_platform_other_srcs", + "//tensorflow/core/platform:profile_utils/android_armv7a_cpu_utils_helper.cc", + "//tensorflow/core/platform:profile_utils/clock_cycle_profiler.cc", + "//tensorflow/core/platform:profile_utils/cpu_utils.cc", + "//tensorflow/core/platform:setround.cc", + "//tensorflow/core/platform:tracing.cc", ], hdrs = [ ":platform_other_hdrs", @@ -668,11 +639,14 @@ cc_library( visibility = [":__subpackages__"], deps = [ ":lib", - ":lib_platform", ":platform_base", ":platform_env", ":platform_port", ":platform_protobuf", + "//tensorflow/core/platform", + "//tensorflow/core/platform:abi", + "//tensorflow/core/platform:annotation", + "//tensorflow/core/platform:stacktrace", "//tensorflow/core/platform/default/build_config:other", "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/platform/default/build_config:port", @@ -684,34 +658,42 @@ cc_library( # don't have to depend on lib/platformlib. 
cc_library( name = "lib_proto_parsing", - srcs = glob(tf_additional_proto_srcs()), + srcs = [ + "//tensorflow/core/platform:protobuf.cc", + ], hdrs = [ - "lib/core/errors.h", - "lib/core/status.h", - "lib/core/stringpiece.h", - "lib/strings/numbers.h", - "lib/strings/strcat.h", - "platform/init_main.h", - "platform/logging.h", - "platform/macros.h", - "platform/platform.h", - "platform/protobuf.h", - "platform/types.h", - "platform/windows/cpu_info.h", - "lib/bfloat16/bfloat16.h", - ] + tf_additional_proto_hdrs(), + "//tensorflow/core/lib/bfloat16:bfloat16.h", + "//tensorflow/core/lib/core:legacy_lib_proto_parsing_headers", + "//tensorflow/core/lib/strings:legacy_lib_proto_parsing_headers", + "//tensorflow/core/platform:init_main.h", + "//tensorflow/core/platform:legacy_proto_hdrs", + "//tensorflow/core/platform:logging.h", + "//tensorflow/core/platform:macros.h", + "//tensorflow/core/platform:platform.h", + "//tensorflow/core/platform:protobuf.h", + "//tensorflow/core/platform:stringpiece.h", + "//tensorflow/core/platform:tstring.h", + "//tensorflow/core/platform:types.h", + ], copts = tf_copts(), deps = tf_lib_proto_parsing_deps() + [ ":platform_base", "@com_google_absl//absl/strings", "@double_conversion//:double-conversion", + "//tensorflow/core/lib/bfloat16", + "//tensorflow/core/platform:cpu_info", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:platform", + "//tensorflow/core/platform:stringpiece", + "//tensorflow/core/platform:types", ], ) cc_library( name = "lib_proto_compiler", hdrs = [ - "platform/protobuf_compiler.h", + "//tensorflow/core/platform:protobuf_compiler.h", ], copts = tf_copts(), deps = tf_lib_proto_compiler_deps() + [ @@ -726,26 +708,6 @@ cc_library( cc_library( name = "lib", hdrs = [ - "lib/bfloat16/bfloat16.h", - "lib/core/arena.h", - "lib/core/bitmap.h", - "lib/core/bits.h", - "lib/core/coding.h", - "lib/core/errors.h", - "lib/core/notification.h", - "lib/core/raw_coding.h", - "lib/core/status.h", - "lib/core/stringpiece.h", - "lib/core/threadpool.h", - "lib/core/threadpool_interface.h", - "lib/gtl/array_slice.h", - "lib/gtl/cleanup.h", - "lib/gtl/compactptrset.h", - "lib/gtl/flatmap.h", - "lib/gtl/flatset.h", - "lib/gtl/inlined_vector.h", - "lib/gtl/optional.h", - "lib/gtl/priority_queue_util.h", "lib/hash/crc32c.h", "lib/hash/hash.h", "lib/histogram/histogram.h", @@ -760,32 +722,31 @@ cc_library( "lib/io/table.h", "lib/io/table_builder.h", "lib/io/table_options.h", - "lib/math/math_util.h", "lib/monitoring/collected_metrics.h", "lib/monitoring/collection_registry.h", "lib/monitoring/counter.h", "lib/monitoring/gauge.h", "lib/monitoring/metric_def.h", "lib/monitoring/sampler.h", - "lib/random/distribution_sampler.h", - "lib/random/philox_random.h", - "lib/random/random_distributions.h", - "lib/random/simple_philox.h", - "lib/strings/numbers.h", - "lib/strings/proto_serialization.h", - "lib/strings/str_util.h", - "lib/strings/strcat.h", - "lib/strings/stringprintf.h", ":platform_base_hdrs", ":platform_env_hdrs", ":platform_file_system_hdrs", ":platform_other_hdrs", ":platform_port_hdrs", ":platform_protobuf_hdrs", + "//tensorflow/core/lib/bfloat16:bfloat16.h", + "//tensorflow/core/lib/core:legacy_lib_core_headers", + "//tensorflow/core/lib/gtl:legacy_lib_gtl_headers", + "//tensorflow/core/lib/math:math_util.h", + "//tensorflow/core/lib/random:legacy_lib_random_headers", + "//tensorflow/core/lib/strings:legacy_lib_string_headers", ], visibility = ["//visibility:public"], deps = [ ":lib_internal", + 
"//tensorflow/core/platform:str_util", + "//tensorflow/core/platform:stringpiece", + "//tensorflow/core/platform:stringprintf", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -797,7 +758,7 @@ cc_library( cc_library( name = "lib_experimental", hdrs = [ - "lib/core/threadpool_options.h", + "//tensorflow/core/lib/core:legacy_lib_core_threadpool_options_header", ], visibility = [ ":experimental_access", @@ -820,45 +781,13 @@ cc_library( ], ) -cc_library( - name = "abi", - srcs = ["platform/abi.cc"], - hdrs = ["platform/abi.h"], - deps = [":platform_base"], -) - -cc_library( - name = "stacktrace", - srcs = glob(["platform/*/stacktrace.h"]), - hdrs = ["platform/stacktrace.h"], - deps = [ - ":abi", - ":lib_platform", - "//tensorflow/core/platform/default/build_config:stacktrace", - ], -) - -cc_library( - name = "stacktrace_handler", - srcs = ["platform/stacktrace_handler.cc"], - hdrs = ["platform/stacktrace_handler.h"], - deps = [ - ":abi", - ":lib_platform", - ":stacktrace", - ], -) - -# Libraries that will eventually be moved into lib/core -# Note that stringpiece_test can't be place here yet, because we are -# required to use tf_cc_test, and that rule will change / into _ +# DEPRECATED: use platform:stringpiece instead. cc_library( name = "core_stringpiece", - hdrs = ["lib/core/stringpiece.h"], + hdrs = ["//tensorflow/core/lib/core:legacy_lib_core_stringpiece_header"], copts = tf_copts(), deps = [ - ":platform_base", - "@com_google_absl//absl/strings", + "//tensorflow/core/platform:stringpiece", ], ) @@ -869,14 +798,15 @@ cc_library( name = "test", testonly = 1, srcs = [ - "platform/test.cc", "util/reporter.cc", - ] + tf_additional_test_srcs(), + "//tensorflow/core/platform:legacy_test_srcs", + "//tensorflow/core/platform:test.cc", + ], hdrs = [ - "lib/core/status_test_util.h", - "platform/test.h", - "platform/test_benchmark.h", "util/reporter.h", + "//tensorflow/core/lib/core:legacy_lib_core_status_test_util_header", + "//tensorflow/core/platform:test.h", + "//tensorflow/core/platform:test_benchmark.h", ], copts = tf_copts(), linkopts = select({ @@ -902,16 +832,16 @@ cc_library( name = "test_lite", testonly = 1, srcs = [ - "platform/test.cc", + "//tensorflow/core/platform:test.cc", ], hdrs = [ - "platform/test.h", - "platform/test_benchmark.h", + "//tensorflow/core/platform:test.h", + "//tensorflow/core/platform:test_benchmark.h", ], copts = tf_copts(), deps = [ - ":lib_platform", ":platform_base", + "//tensorflow/core/platform", "//tensorflow/core/platform/default/build_config:gtest", ], ) @@ -1162,35 +1092,39 @@ cc_library( cc_library( name = "framework_lite", - srcs = tf_additional_minimal_lib_srcs(), + srcs = [ + "//tensorflow/core/platform:legacy_minimal_lib_srcs", + ], hdrs = [ "framework/numeric_types.h", "framework/tensor_types.h", "framework/type_traits.h", - "lib/bfloat16/bfloat16.h", - "platform/byte_order.h", - "platform/default/dynamic_annotations.h", - "platform/default/integral_types.h", - "platform/default/logging.h", - "platform/default/mutex.h", - "platform/default/thread_annotations.h", - "platform/dynamic_annotations.h", - "platform/macros.h", - "platform/mutex.h", - "platform/platform.h", - "platform/prefetch.h", - "platform/protobuf.h", - "platform/thread_annotations.h", - "platform/types.h", - "platform/cpu_info.h", - ] + if_windows(["platform/windows/integral_types.h"]), + "//tensorflow/core/lib/bfloat16:bfloat16.h", + "//tensorflow/core/platform:byte_order.h", + 
"//tensorflow/core/platform:default/dynamic_annotations.h", + "//tensorflow/core/platform:default/integral_types.h", + "//tensorflow/core/platform:default/logging.h", + "//tensorflow/core/platform:default/mutex.h", + "//tensorflow/core/platform:dynamic_annotations.h", + "//tensorflow/core/platform:macros.h", + "//tensorflow/core/platform:mutex.h", + "//tensorflow/core/platform:platform.h", + "//tensorflow/core/platform:prefetch.h", + "//tensorflow/core/platform:protobuf.h", + "//tensorflow/core/platform:thread_annotations.h", + "//tensorflow/core/platform:tstring.h", + "//tensorflow/core/platform:types.h", + "//tensorflow/core/platform:cpu_info.h", + ] + if_windows(["//tensorflow/core/platform:windows/integral_types.h"]), visibility = ["//visibility:public"], deps = [ "@nsync//:nsync_cpp", ] + [ "//third_party/eigen3", + "//tensorflow/core/lib/bfloat16", "//tensorflow/core/platform/default/build_config:minimal", + "//tensorflow/core/platform:types", ], ) @@ -1384,6 +1318,36 @@ tf_gen_op_libs( "ragged_conversion_ops", "ragged_math_ops", ], + deps = [":ragged_to_dense_util"], +) + +cc_library( + name = "ragged_to_dense_util", + srcs = [ + "ops/ragged_to_dense_util.cc", + ], + hdrs = [ + "ops/ragged_to_dense_util.h", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "ragged_to_dense_util_test", + srcs = [ + "ops/ragged_to_dense_util_test.cc", + ], + deps = [ + ":ragged_to_dense_util", + ":test", + ":testlib", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@com_google_googletest//:gtest_main", + ], ) cc_library( @@ -1805,11 +1769,21 @@ filegroup( filegroup( name = "mobile_srcs_no_runtime", srcs = [ - ":protos_all_proto_text_srcs", - ":error_codes_proto_text_srcs", + ":attr_value_proto_text_srcs", "//tensorflow/core/platform/default/build_config:android_srcs", "//tensorflow/core/util/ctc:android_srcs", + "//tensorflow/core/platform:legacy_srcs_no_runtime", "//tensorflow/core/profiler:mobile_srcs", + "//tensorflow/core/lib/bfloat16:bfloat16.h", + "//tensorflow/core/lib/bfloat16:bfloat16.cc", + "//tensorflow/core/lib/core:legacy_lib_core_all_headers", + "//tensorflow/core/lib/core:legacy_lib_core_all_srcs", + "//tensorflow/core/lib/gtl:legacy_lib_gtl_all_headers", + "//tensorflow/core/lib/random:legacy_lib_random_all_headers", + "//tensorflow/core/lib/random:legacy_lib_random_all_srcs", + "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers", + "//tensorflow/core/lib/strings:legacy_lib_strings_all_srcs", + "//tensorflow/core/lib/math:math_util.h", ] + glob( [ "client/**/*.cc", @@ -1817,8 +1791,6 @@ filegroup( "framework/**/*.cc", "lib/**/*.h", "lib/**/*.cc", - "platform/**/*.h", - "platform/**/*.cc", "public/**/*.h", "util/**/*.h", "util/**/*.cc", @@ -1839,22 +1811,6 @@ filegroup( "util/events_writer.*", "util/stats_calculator.*", "util/reporter.*", - "platform/**/cuda_libdevice_path.*", - "platform/**/logger.cc", - # Exclude env_time and logging to avoid collisions with - # :platform_base, a common dependency for downstream targets. 
- "platform/**/env_time.cc", - "platform/**/logging.cc", - "platform/default/test_benchmark.*", - "platform/cuda.h", - "platform/rocm.h", - "platform/google/**/*", - "platform/hadoop/**/*", - "platform/gif.h", - "platform/jpeg.h", - "platform/png.h", - "platform/stream_executor.*", - "platform/windows/**/*", "user_ops/**/*.cu.cc", "util/ctc/*.h", "util/ctc/*.cc", @@ -2109,7 +2065,7 @@ filegroup( filegroup( name = "android_test_srcs", # TODO(andrewharp/nhua): - # make more test-related sources portable e.g. "platform/test.cc", + # make more test-related sources portable e.g. "//tensorflow/core/platform:test.cc", srcs = [ ":framework/fake_input.cc", ":framework/fake_input.h", @@ -2117,10 +2073,10 @@ filegroup( ":framework/shape_inference_testutil.h", ":framework/tensor_testutil.cc", ":framework/tensor_testutil.h", - ":platform/test.cc", - ":platform/test.h", ":util/reporter.cc", ":util/reporter.h", + "//tensorflow/core/platform:test.cc", + "//tensorflow/core/platform:test.h", ], visibility = ["//visibility:public"], ) @@ -2133,9 +2089,9 @@ filegroup( ":framework/shape_inference_testutil.h", ":framework/tensor_testutil.cc", ":framework/tensor_testutil.h", - ":platform/test.h", ":util/reporter.cc", ":util/reporter.h", + "//tensorflow/core/platform:test.h", ], visibility = ["//visibility:public"], ) @@ -2298,6 +2254,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "protobuf/graph_debug_info_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "protobuf/graph_debug_info.proto", + visibility = ["//visibility:public"], +) + tf_pyclif_proto_library( name = "protobuf/meta_graph_pyclif", proto_lib = ":protos_all_cc", @@ -2400,36 +2363,32 @@ tf_proto_library_cc( ], ) -LIB_INTERNAL_PRIVATE_HEADERS = ["framework/resource_handle.h"] + glob( +LIB_INTERNAL_PRIVATE_HEADERS = [ + "framework/resource_handle.h", + "//tensorflow/core/platform:legacy_lib_internal_headers", + "//tensorflow/core/platform:scanner.h", + "//tensorflow/core/platform:str_util.h", + "//tensorflow/core/lib/bfloat16:bfloat16.h", + "//tensorflow/core/lib/core:legacy_lib_core_all_headers", + "//tensorflow/core/lib/gtl:legacy_lib_gtl_all_headers", + "//tensorflow/core/lib/random:legacy_lib_random_all_headers", + "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers", + "//tensorflow/core/lib/math:math_util.h", +] + glob( [ "lib/**/*.h", - "platform/*.h", - "platform/profile_utils/**/*.h", ], exclude = [ "**/*test*", "lib/gif/**/*", "lib/jpeg/**/*", "lib/png/**/*", - "platform/gif.h", - "platform/jpeg.h", - "platform/png.h", - "platform/**/cuda.h", - "platform/**/rocm.h", - "platform/**/stream_executor.h", ], ) -LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ - "lib/core/blocking_counter.h", - "lib/core/refcount.h", - "lib/gtl/edit_distance.h", - "lib/gtl/int_type.h", - "lib/gtl/iterator_range.h", - "lib/gtl/manual_constructor.h", - "lib/gtl/map_util.h", - "lib/gtl/stl_util.h", - "lib/gtl/top_n.h", +LIB_INTERNAL_PUBLIC_HEADERS = [ + "//tensorflow/core/lib/core:legacy_lib_internal_core_headers", + "//tensorflow/core/lib/gtl:legacy_lib_internal_public_gtl_headers", "lib/hash/hash.h", "lib/io/inputbuffer.h", "lib/io/iterator.h", @@ -2442,25 +2401,21 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ "lib/monitoring/mobile_gauge.h", "lib/monitoring/mobile_sampler.h", "lib/png/png_io.h", - "lib/random/random.h", - "lib/random/random_distributions.h", - "lib/random/weighted_picker.h", - "lib/strings/base64.h", - "lib/strings/ordered_code.h", - 
"lib/strings/proto_text_util.h", - "lib/strings/proto_serialization.h", - "lib/strings/scanner.h", + "//tensorflow/core/lib/random:legacy_lib_internal_public_random_headers", + "//tensorflow/core/lib/strings:legacy_lib_internal_public_string_headers", "lib/wav/wav_io.h", - "platform/demangle.h", - "platform/denormal.h", - "platform/host_info.h", - "platform/platform.h", - "platform/monitoring.h", - "platform/protobuf_internal.h", - "platform/setround.h", - "platform/snappy.h", - "platform/tensor_coding.h", - "platform/tracing.h", + "//tensorflow/core/platform:demangle.h", + "//tensorflow/core/platform:denormal.h", + "//tensorflow/core/platform:host_info.h", + "//tensorflow/core/platform:platform.h", + "//tensorflow/core/platform:monitoring.h", + "//tensorflow/core/platform:protobuf_internal.h", + "//tensorflow/core/platform:setround.h", + "//tensorflow/core/platform:snappy.h", + "//tensorflow/core/platform:tensor_coding.h", + "//tensorflow/core/platform:tracing.h", + "//tensorflow/core/platform:unbounded_work_queue.h", + "//tensorflow/core/platform:legacy_platform_lib_hdrs", "util/env_var.h", ] @@ -2469,7 +2424,6 @@ LIB_INTERNAL_DEFINES = ( tf_additional_lib_defines() + [ "TF_USE_SNAPPY", ] + tf_additional_verbs_lib_defines() + - tf_additional_mpi_lib_defines() + tf_additional_gdr_lib_defines() + tf_additional_numa_lib_defines() ) @@ -2490,6 +2444,7 @@ cc_library( ], }), deps = tf_additional_lib_deps() + [ + "//tensorflow/core/platform:annotation", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "//third_party/eigen3", @@ -2503,8 +2458,6 @@ cc_library( srcs = LIB_INTERNAL_PRIVATE_HEADERS + glob( [ "lib/**/*.cc", - "platform/*.cc", - "platform/profile_utils/**/*.cc", "util/env_var.cc", ], exclude = [ @@ -2514,46 +2467,34 @@ cc_library( "lib/gif/**/*", "lib/jpeg/**/*", "lib/png/**/*", - "platform/**/env_time.cc", - "platform/**/monitoring.cc", - "platform/**/cuda_libdevice_path.cc", - "platform/**/device_tracer.cc", - "platform/**/logger.cc", - "platform/**/logging.cc", - "platform/**/human_readable_json.cc", - "platform/abi.cc", - "platform/protobuf.cc", ], - ) + tf_additional_lib_srcs( - exclude = [ - "**/*test*", - "platform/**/cuda.h", - "platform/**/cuda_libdevice_path.cc", - "platform/**/rocm.h", - "platform/**/monitoring.cc", - "platform/**/stream_executor.h", - "platform/**/env_time.cc", - "platform/**/device_tracer.cc", - "platform/**/logger.cc", - "platform/**/logging.cc", - "platform/**/human_readable_json.cc", - "platform/abi.cc", - ] + - # Protobuf deps already included through the ":lib_proto_parsing" - # dependency. 
- tf_additional_proto_srcs(), - ) + tf_additional_monitoring_srcs(), + ) + [ + "//tensorflow/core/platform:legacy_monitoring_srcs", + "//tensorflow/core/platform:legacy_platform_lib_srcs", + "//tensorflow/core/platform:legacy_lib_internal_srcs", + "//tensorflow/core/lib/core:legacy_lib_core_all_srcs", + "//tensorflow/core/lib/random:legacy_lib_random_all_srcs", + "//tensorflow/core/lib/strings:legacy_lib_strings_all_srcs", + ], hdrs = LIB_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), defines = LIB_INTERNAL_DEFINES, deps = tf_additional_lib_deps() + [ + ":core_stringpiece", ":lib_hash_crc32c_accelerate_internal", ":lib_proto_parsing", - ":abi", - ":core_stringpiece", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "//third_party/eigen3", + "//tensorflow/core/lib/bfloat16", + "//tensorflow/core/platform:abi", + "//tensorflow/core/platform:annotation", + "//tensorflow/core/platform:cpu_info", + "//tensorflow/core/platform:numbers", + "//tensorflow/core/platform:platform_strings", + "//tensorflow/core/platform:scanner", + "//tensorflow/core/platform:stringprintf", + "//tensorflow/core/platform:str_util", "//tensorflow/core/platform/default/build_config:platformlib", "@snappy", "@zlib_archive//:zlib", @@ -2575,7 +2516,7 @@ cc_library( name = "gif_internal", srcs = [ "lib/gif/gif_io.cc", - "platform/gif.h", + "//tensorflow/core/platform:gif.h", ], hdrs = ["lib/gif/gif_io.h"], copts = tf_copts(), @@ -2596,7 +2537,7 @@ cc_library( srcs = [ "lib/jpeg/jpeg_handle.cc", "lib/jpeg/jpeg_mem.cc", - "platform/jpeg.h", + "//tensorflow/core/platform:jpeg.h", ], hdrs = [ "lib/jpeg/jpeg_handle.h", @@ -2619,18 +2560,19 @@ cc_library( name = "png_internal", srcs = ["lib/png/png_io.cc"], hdrs = [ - "lib/bfloat16/bfloat16.h", - "lib/core/stringpiece.h", "lib/png/png_io.h", - "platform/byte_order.h", - "platform/cpu_info.h", - "platform/default/integral_types.h", - "platform/default/logging.h", - "platform/logging.h", - "platform/macros.h", - "platform/platform.h", - "platform/png.h", - "platform/types.h", + "//tensorflow/core/lib/bfloat16:bfloat16.h", + "//tensorflow/core/lib/core:legacy_lib_core_stringpiece_header", + "//tensorflow/core/platform:byte_order.h", + "//tensorflow/core/platform:cpu_info.h", + "//tensorflow/core/platform:default/integral_types.h", + "//tensorflow/core/platform:default/logging.h", + "//tensorflow/core/platform:logging.h", + "//tensorflow/core/platform:macros.h", + "//tensorflow/core/platform:platform.h", + "//tensorflow/core/platform:png.h", + "//tensorflow/core/platform:tstring.h", + "//tensorflow/core/platform:types.h", ], copts = tf_copts(), linkopts = select({ @@ -2651,13 +2593,14 @@ cc_library( cc_library( name = "tflite_portable_logging", hdrs = [ - "lib/bfloat16/bfloat16.h", - "platform/default/integral_types.h", - "platform/default/logging.h", - "platform/logging.h", - "platform/macros.h", - "platform/platform.h", - "platform/types.h", + "//tensorflow/core/lib/bfloat16:bfloat16.h", + "//tensorflow/core/platform:default/integral_types.h", + "//tensorflow/core/platform:default/logging.h", + "//tensorflow/core/platform:logging.h", + "//tensorflow/core/platform:macros.h", + "//tensorflow/core/platform:platform.h", + "//tensorflow/core/platform:tstring.h", + "//tensorflow/core/platform:types.h", ], copts = tf_copts(), linkopts = ["-ldl"], @@ -2672,26 +2615,30 @@ cc_library( srcs = if_android([ "lib/jpeg/jpeg_handle.cc", "lib/jpeg/jpeg_mem.cc", - "platform/jpeg.h", + "//tensorflow/core/platform:jpeg.h", ]), hdrs = [ - "lib/bfloat16/bfloat16.h", - 
"lib/core/stringpiece.h", "lib/jpeg/jpeg_handle.h", "lib/jpeg/jpeg_mem.h", - "platform/default/dynamic_annotations.h", - "platform/default/integral_types.h", - "platform/default/logging.h", - "platform/dynamic_annotations.h", - "platform/logging.h", - "platform/macros.h", - "platform/mem.h", - "platform/platform.h", - "platform/types.h", + "//tensorflow/core/lib/bfloat16:bfloat16.h", + "//tensorflow/core/lib/core:legacy_lib_core_stringpiece_header", + "//tensorflow/core/platform:default/dynamic_annotations.h", + "//tensorflow/core/platform:default/integral_types.h", + "//tensorflow/core/platform:default/logging.h", + "//tensorflow/core/platform:dynamic_annotations.h", + "//tensorflow/core/platform:logging.h", + "//tensorflow/core/platform:macros.h", + "//tensorflow/core/platform:mem.h", + "//tensorflow/core/platform:platform.h", + "//tensorflow/core/platform:stringpiece.h", + "//tensorflow/core/platform:tstring.h", + "//tensorflow/core/platform:types.h", ], copts = tf_copts(), linkopts = ["-ldl"], deps = [ + ":core_stringpiece", + "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform/default/build_config:jpeg", "//tensorflow/core/platform/default/build_config:logging", "@com_google_absl//absl/base:core_headers", @@ -2703,24 +2650,24 @@ cc_library( name = "android_gif_internal", srcs = if_android([ "lib/gif/gif_io.cc", - "platform/gif.h", - "lib/strings/strcat.h", - "lib/strings/numbers.h", + "//tensorflow/core/platform:gif.h", + "//tensorflow/core/lib/strings:legacy_lib_android_gif_internal_string_headers", ]), hdrs = [ - "lib/bfloat16/bfloat16.h", - "lib/core/stringpiece.h", "lib/gif/gif_io.h", - "lib/gtl/cleanup.h", - "platform/default/dynamic_annotations.h", - "platform/default/integral_types.h", - "platform/default/logging.h", - "platform/dynamic_annotations.h", - "platform/logging.h", - "platform/macros.h", - "platform/mem.h", - "platform/platform.h", - "platform/types.h", + "//tensorflow/core/lib/bfloat16:bfloat16.h", + "//tensorflow/core/lib/core:legacy_lib_core_stringpiece_header", + "//tensorflow/core/lib/gtl:legacy_android_gif_internal_headers", + "//tensorflow/core/platform:default/dynamic_annotations.h", + "//tensorflow/core/platform:default/integral_types.h", + "//tensorflow/core/platform:default/logging.h", + "//tensorflow/core/platform:dynamic_annotations.h", + "//tensorflow/core/platform:logging.h", + "//tensorflow/core/platform:macros.h", + "//tensorflow/core/platform:mem.h", + "//tensorflow/core/platform:platform.h", + "//tensorflow/core/platform:tstring.h", + "//tensorflow/core/platform:types.h", ], copts = tf_copts(), linkopts = ["-ldl"], @@ -2736,20 +2683,21 @@ cc_library( name = "android_png_internal", srcs = if_android([ "lib/png/png_io.cc", - "platform/png.h", + "//tensorflow/core/platform:png.h", ]), hdrs = [ - "lib/bfloat16/bfloat16.h", - "lib/core/stringpiece.h", "lib/png/png_io.h", - "platform/byte_order.h", - "platform/cpu_info.h", - "platform/default/integral_types.h", - "platform/default/logging.h", - "platform/logging.h", - "platform/macros.h", - "platform/platform.h", - "platform/types.h", + "//tensorflow/core/lib/bfloat16:bfloat16.h", + "//tensorflow/core/lib/core:legacy_lib_core_stringpiece_header", + "//tensorflow/core/platform:byte_order.h", + "//tensorflow/core/platform:cpu_info.h", + "//tensorflow/core/platform:default/integral_types.h", + "//tensorflow/core/platform:default/logging.h", + "//tensorflow/core/platform:logging.h", + "//tensorflow/core/platform:macros.h", + "//tensorflow/core/platform:platform.h", + 
"//tensorflow/core/platform:tstring.h", + "//tensorflow/core/platform:types.h", ], copts = tf_copts(), linkopts = ["-ldl"], @@ -2760,59 +2708,19 @@ cc_library( ], ) -tf_proto_library( - name = "error_codes_proto", - srcs = ERROR_CODES_PROTO_SRCS, - cc_api_version = 2, - make_default_target_header_only = True, - provide_cc_alias = True, -) - -tf_generate_proto_text_sources( - name = "error_codes_proto_text", - srcs = ERROR_CODES_PROTO_SRCS, - protodeps = [], - srcs_relative_dir = "tensorflow/core/", - deps = [ - ":error_codes_proto_cc", - ":lib_internal", - ], -) - tf_proto_library( name = "protos_all_proto", srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS, cc_api_version = 2, make_default_target_header_only = True, protodeps = [ - ":error_codes_proto", + "//tensorflow/core/lib/core:error_codes_proto", ], ) -tf_generate_proto_text_sources( - name = "protos_all_proto_text", - srcs = COMMON_PROTO_SRCS, - protodeps = ERROR_CODES_PROTO_SRCS, - srcs_relative_dir = "tensorflow/core/", - visibility = ["//visibility:public"], - deps = [ - ":error_codes_proto_text", - ":lib_internal", - ":protos_all_proto_cc", - ], -) - -cc_library( - name = "proto_text", - hdrs = [ - ":error_codes_proto_text_hdrs", - ":protos_all_proto_text_hdrs", - ], - deps = [ - ":lib", - ":lib_internal", - ":protos_all_cc", - ], +alias( + name = "error_codes_proto_cc", + actual = "//tensorflow/core/lib/core:error_codes_proto_cc", ) tf_version_info_genrule() @@ -2975,11 +2883,10 @@ tf_cuda_library( deps = [ ":allocator_registry_impl", ":allocator", + ":attr_value_proto_text", ":feature_util", ":lib", ":lib_internal", - ":protos_all_proto_text", - ":error_codes_proto_text", ":protos_all_cc", ":stats_calculator_portable", ":version_lib", @@ -3030,11 +2937,11 @@ cc_header_only_library( tf_cuda_library( name = "stream_executor", - srcs = ["platform/stream_executor.h"], + srcs = ["//tensorflow/core/platform:stream_executor.h"], hdrs = [ - "platform/cuda.h", - "platform/rocm.h", - "platform/stream_executor.h", + "//tensorflow/core/platform:cuda.h", + "//tensorflow/core/platform:rocm.h", + "//tensorflow/core/platform:stream_executor.h", ], deps = [ "//tensorflow/core/platform/default/build_config:stream_executor", @@ -3045,9 +2952,9 @@ tf_cuda_library( # and does not include any cuda dependencies. 
cc_library( name = "stream_executor_no_cuda", - srcs = ["platform/stream_executor.h"], + srcs = ["//tensorflow/core/platform:stream_executor.h"], hdrs = [ - "platform/stream_executor_no_cuda.h", + "//tensorflow/core/platform:stream_executor_no_cuda.h", ], visibility = ["//visibility:public"], deps = [ @@ -3121,7 +3028,6 @@ tf_cuda_library( ":framework_internal", ":lib", ":lib_internal", - ":proto_text", ":protos_all_cc", "//third_party/eigen3", "@com_google_absl//absl/container:flat_hash_map", @@ -3180,7 +3086,6 @@ tf_cuda_library( ":framework_internal", ":lib", ":lib_internal", - ":proto_text", ":protos_all_cc", "@com_google_absl//absl/container:flat_hash_set", "//third_party/eigen3", @@ -3328,7 +3233,6 @@ tf_cuda_library( ":framework_internal", ":lib", ":lib_internal", - ":proto_text", ":protos_all_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", @@ -3347,7 +3251,6 @@ tf_cuda_library( hdrs = CORE_CPU_LIB_HEADERS, deps = [ ":core_cpu_base", - ":proto_text", "//tensorflow/core/grappler:grappler_item", ] + if_static([":core_cpu_impl"]) + tf_protos_all() + tf_protos_grappler(), ) @@ -3357,7 +3260,6 @@ tf_cuda_library( hdrs = CORE_CPU_LIB_HEADERS, deps = [ ":core_cpu_base_no_ops", - ":proto_text", "//tensorflow/core/grappler:grappler_item", ] + tf_protos_all() + tf_protos_grappler(), ) @@ -3375,7 +3277,6 @@ tf_cuda_library( ":framework", ":graph", ":lib", - ":proto_text", ":protos_all_cc", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -3411,6 +3312,7 @@ cc_library( ":lib", ":lib_internal", ":shared_counter", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -3427,7 +3329,7 @@ cc_library( cc_library( name = "regexp_internal", hdrs = [ - "platform/regexp.h", + "//tensorflow/core/platform:regexp.h", ], visibility = [ "//tensorflow/compiler:__subpackages__", @@ -3454,7 +3356,6 @@ tf_cuda_library( ":lib", ":lib_experimental", ":lib_internal", - ":proto_text", ":protos_all_cc", "//tensorflow/core/debug:debug_graph_utils", "//tensorflow/core/kernels:function_ops", @@ -3479,7 +3380,6 @@ cc_library( ":framework", ":lib", ":lib_internal", - ":proto_text", ":protos_all_cc", ], alwayslink = 1, @@ -3487,7 +3387,9 @@ cc_library( tf_cuda_library( name = "device_tracer", - srcs = tf_additional_device_tracer_srcs(), + srcs = [ + "//tensorflow/core/platform:legacy_device_tracer_srcs", + ], copts = tf_copts(), cuda_deps = tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps(), visibility = [ @@ -3705,7 +3607,6 @@ cc_library( ":framework_internal", ":lib", ":lib_internal", - ":proto_text", "//third_party/eigen3", "@local_config_sycl//sycl", ], @@ -3719,11 +3620,11 @@ cc_library( name = "lib_test_internal", testonly = 1, hdrs = [ - "lib/gtl/manual_constructor.h", "lib/io/block.h", "lib/io/block_builder.h", "lib/io/format.h", - "lib/random/philox_random_test_utils.h", + "//tensorflow/core/lib/gtl:legacy_lib_test_internal_headers", + "//tensorflow/core/lib/random:legacy_lib_test_internal_headers", ], deps = [ ":lib", @@ -3762,7 +3663,7 @@ cc_library( cc_library( name = "test_main", testonly = 1, - srcs = ["platform/test_main.cc"], + srcs = ["//tensorflow/core/platform:test_main.cc"], copts = tf_copts(), linkopts = select({ "//tensorflow:windows": [], @@ -3783,16 +3684,16 @@ cc_library( cc_library( name = "test_lite_main", testonly = 1, - srcs = ["platform/test_main.cc"], + srcs = ["//tensorflow/core/platform:test_main.cc"], copts = tf_copts(), deps = [ # TODO(ahentz): we don't want to depend on "lib" here. 
It used to be # that "core_stringpiece" was enough but that recently changed and # we now need at least "str_util". ":lib", - ":lib_platform", - ":stacktrace_handler", ":test_lite", + "//tensorflow/core/platform", + "//tensorflow/core/platform:stacktrace_handler", "//tensorflow/core/platform/default/build_config:test_lite_main", ], alwayslink = 1, @@ -3802,25 +3703,6 @@ tf_cc_tests( name = "low_level_library_tests", size = "small", srcs = [ - "lib/core/arena_test.cc", - "lib/core/bitmap_test.cc", - "lib/core/blocking_counter_test.cc", - "lib/core/coding_test.cc", - "lib/core/notification_test.cc", - "lib/core/refcount_test.cc", - "lib/core/status_test.cc", - "lib/core/stringpiece_test.cc", - "lib/core/threadpool_test.cc", - "lib/gtl/cleanup_test.cc", - "lib/gtl/compactptrset_test.cc", - "lib/gtl/edit_distance_test.cc", - "lib/gtl/flatmap_test.cc", - "lib/gtl/flatset_test.cc", - "lib/gtl/int_type_test.cc", - "lib/gtl/iterator_range_test.cc", - "lib/gtl/manual_constructor_test.cc", - "lib/gtl/map_util_test.cc", - "lib/gtl/top_n_test.cc", "lib/hash/crc32c_test.cc", "lib/hash/hash_test.cc", "lib/histogram/histogram_test.cc", @@ -3834,33 +3716,31 @@ tf_cc_tests( "lib/io/snappy/snappy_buffers_test.cc", "lib/io/table_test.cc", "lib/io/zlib_buffers_test.cc", - "lib/math/math_util_test.cc", "lib/monitoring/collection_registry_test.cc", "lib/monitoring/counter_test.cc", "lib/monitoring/gauge_test.cc", "lib/monitoring/metric_def_test.cc", "lib/monitoring/sampler_test.cc", - "lib/random/distribution_sampler_test.cc", - "lib/random/philox_random_test.cc", - "lib/random/random_test.cc", - "lib/random/simple_philox_test.cc", - "lib/strings/base64_test.cc", - "lib/strings/numbers_test.cc", - "lib/strings/scanner_test.cc", - "lib/strings/str_util_test.cc", - "lib/strings/strcat_test.cc", - "lib/strings/stringprintf_test.cc", "lib/wav/wav_io_test.cc", - "platform/fingerprint_test.cc", - "platform/integral_types_test.cc", - "platform/logging_test.cc", - "platform/mutex_test.cc", - "platform/net_test.cc", - "platform/port_test.cc", - "platform/profile_utils/cpu_utils_test.cc", - "platform/stacktrace_handler_test.cc", - "platform/subprocess_test.cc", - "platform/vmodule_benchmark_test.cc", + "//tensorflow/core/lib/core:legacy_lib_core_all_tests", + "//tensorflow/core/lib/gtl:legacy_lib_gtl_tests", + "//tensorflow/core/lib/math:math_util_test.cc", + "//tensorflow/core/lib/random:legacy_lib_random_tests", + "//tensorflow/core/lib/strings:legacy_low_level_library_tests", + "//tensorflow/core/platform:fingerprint_test.cc", + "//tensorflow/core/platform:integral_types_test.cc", + "//tensorflow/core/platform:logging_test.cc", + "//tensorflow/core/platform:mutex_test.cc", + "//tensorflow/core/platform:net_test.cc", + "//tensorflow/core/platform:port_test.cc", + "//tensorflow/core/platform:profile_utils/cpu_utils_test.cc", + "//tensorflow/core/platform:scanner_test.cc", + "//tensorflow/core/platform:stacktrace_handler_test.cc", + "//tensorflow/core/platform:str_util_test.cc", + "//tensorflow/core/platform:stringpiece_test.cc", + "//tensorflow/core/platform:stringprintf_test.cc", + "//tensorflow/core/platform:subprocess_test.cc", + "//tensorflow/core/platform:vmodule_benchmark_test.cc", ], deps = [ ":core_cpu_internal", @@ -3870,6 +3750,10 @@ tf_cc_tests( ":protos_all_cc", ":test", ":test_main", + "//tensorflow/core/platform:scanner", + "//tensorflow/core/platform:str_util", + "//tensorflow/core/platform:stringpiece", + "//tensorflow/core/platform:stringprintf", "//third_party/eigen3", "@com_google_absl//absl/strings", 
"@com_google_absl//absl/synchronization", @@ -3879,7 +3763,7 @@ tf_cc_tests( tf_cc_test( name = "vmodule_test", - srcs = ["platform/vmodule_test.cc"], + srcs = ["//tensorflow/core/platform:vmodule_test.cc"], tags = ["optonly"], deps = [ ":lib", @@ -3894,7 +3778,7 @@ tf_cc_test( tf_cc_test( name = "lib_random_random_distributions_test", - srcs = ["lib/random/random_distributions_test.cc"], + srcs = ["//tensorflow/core/lib/random:legacy_lib_random_random_distributions_test"], tags = ["optonly"], deps = [ ":lib", @@ -3910,18 +3794,18 @@ tf_cc_test( tf_cc_test( name = "platform_strings_test", size = "small", - srcs = ["platform/platform_strings_test.cc"], + srcs = ["//tensorflow/core/platform:platform_strings_test.cc"], features = ["-dynamic_link_test_srcs"], # see go/dynamic_link_test_srcs deps = [ ":lib", - ":platform_strings", + "//tensorflow/core/platform:platform_strings", ], ) tf_cc_test( name = "platform_env_test", size = "small", - srcs = ["platform/env_test.cc"], + srcs = ["//tensorflow/core/platform:env_test.cc"], deps = [ ":lib", ":lib_internal", @@ -3936,7 +3820,7 @@ tf_cc_test( tf_cc_test( name = "platform_fake_python_env_test", size = "small", - srcs = ["platform/fake_python_env_test.cc"], + srcs = ["//tensorflow/core/platform:fake_python_env_test.cc"], args = [ "/some/path/to/pythontest.runfiles/org_tensorflow/stuff/to/run.py", ], @@ -3959,7 +3843,7 @@ tf_cc_test( tf_cc_test( name = "platform_abi_test", size = "small", - srcs = ["platform/abi_test.cc"], + srcs = ["//tensorflow/core/platform:abi_test.cc"], deps = [ ":framework", ":lib", @@ -3975,7 +3859,7 @@ tf_cc_test( tf_cc_test( name = "platform_numa_test", size = "small", - srcs = ["platform/numa_test.cc"], + srcs = ["//tensorflow/core/platform:numa_test.cc"], tags = [ # This test will not pass unless it has access to all NUMA nodes # on the executing machine. 
@@ -3997,7 +3881,7 @@ tf_cc_test( tf_cc_test( name = "platform_setround_test", size = "small", - srcs = ["platform/setround_test.cc"], + srcs = ["//tensorflow/core/platform:setround_test.cc"], tags = [ "noasan", "noclang", @@ -4016,7 +3900,7 @@ tf_cc_test( tf_cc_test( name = "platform_file_system_test", size = "small", - srcs = ["platform/file_system_test.cc"], + srcs = ["//tensorflow/core/platform:file_system_test.cc"], deps = [ ":lib", ":lib_internal", @@ -4067,7 +3951,7 @@ tf_cc_test( tf_cc_test( name = "lib_strings_ordered_code_test", - srcs = ["lib/strings/ordered_code_test.cc"], + srcs = ["//tensorflow/core/lib/strings:legacy_strings_ordered_code_test"], extra_copts = ["$(STACK_FRAME_UNLIMITED)"], # Tests initialize large vectors deps = [ ":lib", @@ -4079,7 +3963,7 @@ tf_cc_test( tf_cc_test( name = "lib_strings_proto_serialization_test", - srcs = ["lib/strings/proto_serialization_test.cc"], + srcs = ["//tensorflow/core/lib/strings:legacy_strings_proto_serialization_test"], deps = [ ":lib", ":lib_internal", @@ -4094,7 +3978,7 @@ tf_cc_test( tf_cc_test( name = "lib_random_weighted_picker_test", size = "medium", - srcs = ["lib/random/weighted_picker_test.cc"], + srcs = ["//tensorflow/core/lib/random:legacy_lib_random_random_weighted_picker_test"], deps = [ ":lib", ":lib_internal", @@ -4586,6 +4470,20 @@ tf_cuda_cc_test( ], ) +tf_cc_test_gpu( + name = "rocm_rocdl_path_test", + size = "small", + srcs = ["//tensorflow/core/platform:rocm_rocdl_path_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_gpu_tests_tags(), + deps = [ + ":lib", + ":test", + ":test_main", + "//tensorflow/core/platform:rocm_rocdl_path", + ], +) + tf_cuda_only_cc_test( name = "util_gpu_kernel_helper_test", srcs = [ @@ -4658,7 +4556,7 @@ tf_cc_test( size = "small", srcs = ["common_runtime/constant_folding_test.cc"], linkstatic = tf_kernel_tests_linkstatic(), - tags = tf_cuda_tests_tags(), + tags = tf_cuda_tests_tags() + ["no_rocm"], deps = [ ":core", ":core_cpu", @@ -4724,12 +4622,14 @@ tf_cuda_cc_test( size = "small", srcs = ["common_runtime/process_function_library_runtime_test.cc"], linkstatic = tf_kernel_tests_linkstatic(), + tags = ["no_rocm"], deps = [ ":core_cpu", ":core_cpu_internal", ":framework", ":framework_internal", ":lib", + ":protos_all_cc", ":test", ":test_main", ":testlib", @@ -5328,7 +5228,7 @@ tf_cc_test( tf_cc_test_gpu( name = "device_tracer_test", size = "small", - srcs = ["platform/device_tracer_test.cc"], + srcs = ["//tensorflow/core/platform:device_tracer_test.cc"], args = ["--heap_check=local"] + tf_additional_device_tracer_test_flags(), linkstatic = tf_kernel_tests_linkstatic(), @@ -5534,11 +5434,19 @@ filegroup( testonly = 1, srcs = [ # A simple key-value store: + # 0 : 'b' + # 1 : 'b' + # ... + # 9 : 'b' + # Which is then overwritten with: # 0 : 'a' # 1 : 'b' # ... # 9 : 'j' "lib/lmdb/testdata/data.mdb", + # LMDB, being a memory-mapped database, uses a different file format on + # big-endian systems. 
+ "lib/lmdb/testdata/data_bigendian.mdb", ], visibility = ["//visibility:public"], ) @@ -5552,10 +5460,12 @@ filegroup( cc_library( name = "cuda_libdevice_path", - srcs = tf_additional_libdevice_srcs(), - hdrs = ["platform/cuda_libdevice_path.h"], + srcs = [ + "//tensorflow/core/platform:legacy_libdevice_srcs", + ], copts = tf_copts(), data = tf_additional_libdevice_data(), + textual_hdrs = ["//tensorflow/core/platform:cuda_libdevice_path.h"], visibility = ["//visibility:public"], deps = [ ":lib", @@ -5569,9 +5479,9 @@ transitive_hdrs( ":core_cpu", ":framework", ":lib", - ":platform_strings", ":protos_all_cc", ":stream_executor", + "//tensorflow/core/platform:platform_strings", ], ) diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousMemoryCache.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousMemoryCache.pbtxt new file mode 100644 index 00000000000..7c6d161b236 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AnonymousMemoryCache.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "AnonymousMemoryCache" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousRandomSeedGenerator.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousRandomSeedGenerator.pbtxt new file mode 100644 index 00000000000..327a0682dc8 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AnonymousRandomSeedGenerator.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "AnonymousRandomSeedGenerator" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyAdagradV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyAdagradV2.pbtxt new file mode 100644 index 00000000000..07366bfd367 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ApplyAdagradV2.pbtxt @@ -0,0 +1,53 @@ +op { + graph_op_name: "ApplyAdagradV2" + visibility: HIDDEN + in_arg { + name: "var" + description: <