diff --git a/.gitignore b/.gitignore index be75938ec40..828bbe9bd33 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ Podfile.lock /tensorflow/contrib/lite/examples/ios/simple/data/*.txt /tensorflow/contrib/lite/examples/ios/simple/data/*.tflite xcuserdata/** +/api_init_files_list.txt # Android .gradle diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3dad41a88c8..8669c25c452 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,16 @@ # Contributing guidelines +## Pull Request Checklist + +Before sending your pull request, make sure you have followed this list: + +- Read the [contributing guidelines](CONTRIBUTING.md). +- Read the [Code of Conduct](CODE_OF_CONDUCT.md). +- Ensure you have signed the [Contributor License Agreement (CLA)](https://cla.developers.google.com/). +- Check that your changes are consistent with the [guidelines](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#general-guidelines-and-philosophy-for-contribution). +- Check that your changes are consistent with the [Coding Style](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#c-coding-style). +- Run [Unit Tests](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md#running-unit-tests). + ## How to become a contributor and submit your own code ### Contributor License Agreements diff --git a/README.md b/README.md index e1a50c87e26..6fb4486d0de 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ ----------------- -| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | -|-----------------|---------------------|------------------|-------------------|---------------|---------------| -| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | ![Build Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) +| **`Documentation`** | +|-----------------| +| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while @@ -40,15 +40,6 @@ environment to install the nightly TensorFlow build. We support CPU and GPU packages on Linux, Mac, and Windows.
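+
+For example, the nightly packages can be installed directly with pip (the `tf-nightly` and `tf-nightly-gpu` names below match the PyPI links in the build-status table later in this README):
+
+```shell
+# CPU-only nightly package
+$ pip install tf-nightly
+
+# GPU nightly package; a working CUDA/cuDNN setup is assumed
+$ pip install tf-nightly-gpu
+```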
-**Individual whl files** -* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/)) -* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) / [Python 3.6](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/)) -* Mac CPU-only: [Python 
2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) -* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/)) -* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/)) -* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) -([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) - #### *Try your first TensorFlow program* ```shell $ python @@ -82,6 +73,30 @@ The TensorFlow project strives to abide by generally accepted best practices in [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) + +## Continuous build status + +### Official Builds + +| Build Type | Status | Artifacts | +| --- | --- | --- | +| **Linux CPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Linux GPU** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/ubuntu-gpu-cc.png) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Linux XLA** | TBA | TBA | +| **MacOS** | ![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/macos-py2-cc.png) | [pypi](https://pypi.org/project/tf-nightly/) | +| **Windows CPU** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | 
[pypi](https://pypi.org/project/tf-nightly/) | +| **Windows GPU** | [![Status](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/badge/icon)](http://ci.tensorflow.org/job/tf-master-win-gpu-cmake/) | [pypi](https://pypi.org/project/tf-nightly-gpu/) | +| **Android** | [![Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) [build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/) | + + +### Community Supported Builds + +| Build Type | Status | Artifacts | +| --- | --- | --- | +| **IBM s390x** | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | TBA | +| **IBM ppc64le CPU** | [![Build Status](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/badge/icon)](http://powerci.osuosl.org/job/TensorFlow_Ubuntu_16.04_CPU/) | TBA | + + ## For more information * [TensorFlow Website](https://www.tensorflow.org) diff --git a/RELEASE.md b/RELEASE.md index e8459531748..84d9d52868e 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,62 @@ +# Release 1.8.0 + +## Major Features And Improvements +* Can now pass `tf.contrib.distribute.MirroredStrategy()` to `tf.estimator.RunConfig()` to run an Estimator model on multiple GPUs on one machine. +* Add `tf.contrib.data.prefetch_to_device()`, which supports prefetching to GPU memory. +* Added Gradient Boosted Trees as pre-made Estimators: BoostedTreesClassifier, BoostedTreesRegressor. +* Add 3rd generation pipeline config for Cloud TPUs which improves performance and usability. +* `tf.contrib.bayesflow` is moving out to its own repo. +* Added `tf.contrib.{proto,rpc}` to allow generic proto parsing and RPC communication[1](#rpc-issue). + +## Bug Fixes and Other Changes +* `tf.data`: + * Add `tf.contrib.data.prefetch_to_device`, which enables prefetching dataset elements to GPU memory. + * Add `tf.contrib.data.AUTOTUNE`, which allows the tf.data runtime to automatically tune the prefetch buffer sizes based on your system and environment. + * Add `tf.contrib.data.make_csv_dataset` for building datasets of CSV files. +* Eager Execution: + * With eager execution Datasets can now be used as standard Python iterators (`for batch in dataset:`). Both `Dataset.__iter__()` and `Dataset.make_one_shot_iterator()` can now be used to create iterators when eager execution is enabled. A short example appears at the end of these notes. + * Automatic device placement has been enabled (i.e., use a GPU if available automatically, without requiring an explicit `with tf.device("/gpu:0")`) (Fixes #14133) + * `tf.GradientTape` has moved out of contrib. +* `tf.keras`: + * Added the Fashion-MNIST dataset. + * New data preprocessing functions: `image/random_brightness`, `sequence/TimeseriesGenerator`, and `text/hashing_trick`. +* Accelerated Linear Algebra (XLA): + * Select and scatter in reference util and evaluator now use lexicographical order to break ties. +* TensorFlow Debugger (tfdbg) CLI: + * During tensor-filter operations, allow exclusion of nodes by regular expressions. + * Fix spurious background colors in some text terminals.
+* `tf.contrib`: + * Add meta-distribution BatchReshape which reshapes batch dimensions. + * `tf.contrib.layers.recompute_grad` works for explicit gradient checkpointing on TPU. + * Add `tf.contrib.framework.argsort`. + * Allow `DNNBoostedTreeCombinedEstimator` to work with core versions of feature columns and losses. + * Add non-linear image warping ops: `tf.contrib.image.sparse_image_warp`, `tf.contrib.image.dense_image_warp`, and `tf.contrib.image.interpolate_spline`. + * Fix bug in `tf.contrib.opt.MultitaskOptimizerWrapper` where types of tensors were mismatched. +* Other: + * Low-level graph construction now calls the TensorFlow C API. This change should be invisible to most users, but can be disabled by setting the environment variable `TF_C_API_GRAPH_CONSTRUCTION=0` in this release. Future releases will remove the ability to disable this change. Please [file a bug](https://github.com/tensorflow/tensorflow/issues/new) if you find yourself using this escape hatch. + * Add description of shapes and a pointer to tutorial notebook in `tf.distributions.Distribution`. + * Update scatter operations: + * Add `tf.scatter_min` and `tf.scatter_max`. + * Extend scatter operations to work with a scalar update parameter. + * Move cuDNN RNN ops to core for use in TensorFlow codebase only. + * Add `float64` support for `Conv2d`, `Conv2dBackpropInput`, and `Conv2dBackpropFilter`. + * Add `float64` support for `AvgPool`/`AvgPoolGrad`. + * Make graph name scopes thread-local so that they work correctly in multi-threaded environments. + * Update nsync synchronization library to avoid slow primitives on Linux. + * Removed need to put nsync/public on C include path when building custom ops. + * Add `tf.image.psnr`, `tf.image.ssim`, `tf.image.ssim_multiscale`, `tf.image.image_gradients`, `tf.image.sobel_edges`. + * Add links to https://js.tensorflow.org. + * Fix non-uniformity of orthogonal matrices. + * Fix bug where multi-image Estimator eval summaries were not displayed correctly. + +<a name="rpc-issue"><sup>1</sup></a> The cancellation logic of the RPC op contains a concurrency error. A fix has been submitted to master and will be part of the next release.
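+
+A minimal sketch of the eager-mode `tf.data` iteration mentioned above (illustrative only; the random input is just a stand-in for real data):
+
+```python
+import tensorflow as tf
+
+tf.enable_eager_execution()
+
+# A tiny in-memory dataset, batched into groups of 4 elements.
+dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([8, 2])).batch(4)
+
+# Under eager execution a Dataset behaves like a standard Python iterable.
+for batch in dataset:
+  print(batch.shape)  # prints (4, 2)
+```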
+ +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Aghasy, Alan Du, Alan Lee, Alan Yee, Alex Wiltschko, Animesh Karnewar, Ankit Gupta, Anton Matosov, Aris L, Ben Barsdell, Brent Yi, Brett Koonce, Carl Thomé, cbockman, Chikanaga Tomoyuki, Chris Tava, CéDric Deltheil, Dahan Gong, Dalmo Cirne, Daniel Erenrich, David Norman, DavidNorman, Edd Wilder-James, Fanjin Zeng, Felix Abecassis, fo40225, George Sterpu, Giovanni Terlingen, Gor Baghdasaryan, Guillaume Klein, Hanchen Li, Ilya Polenov, Jakub Kolodziejczyk, Jason Sadler, Jayaram Bobba, Jerry Liu, jinghuangintel, Jiongyan Zhang (张炯衍), Joel Shor, Jong Wook Kim, Julian Eisenschlos, Karl Lessard, Krish Ravindranath, Loo Rong Jie, Lukas Geiger, Luke Iwanski, Mahmoud Abuzaina, ManHyuk, Marvin Richter, Maximilian Mitchell, Mohammad Ashraf Bhuiyan, msofka, Mustafa Kasap, Nathan Burnham, Nathan Luehr, Naveen Marri, ngc92, nio1814, Oleg Zabluda, Ou Changkun, Panos Ipeirotis, Paul Van Eck, Peter Lee, Piotr Czapla, qjivy, Rholais Lii, Rodrigo Formigone, Russell Klopfer, ryantimjohn, Sang Han, SebastiáN RamíRez, shengfuintel, Siby Jose Plathottam, Silver Chan, Stanislaw Antol, Taehoon Lee, Tarang Chugh, Ted Chang, Thomas Bastiani, Xian Xu, Xiaoming (Jason) Cui, Yan Facai (颜发才), yaox12, Yashal Shakti Kanungo, Yong Tang, Yuan (Terry) Tang, Yuxin Wu, Ziyue(Louis) Lu + # Release 1.7.0 ## Major Features And Improvements @@ -177,7 +236,7 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田 * Add `complex64` support to XLA compiler. * `bfloat` support is now added to XLA infrastructure. * Make `ClusterSpec` propagation work with XLA devices. - * Use a determinisitic executor to generate XLA graph. + * Use a deterministic executor to generate XLA graph. * `tf.contrib`: * `tf.contrib.distributions`: * Add `tf.contrib.distributions.Autoregressive`. diff --git a/SECURITY.md b/SECURITY.md index a5ce3a62ee2..01886b613e5 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -173,7 +173,7 @@ the progress being made towards a fix and announcement. In addition, please include the following information along with your report: * Your name and affiliation (if any). -* A description the technical details of the vulnerabilities. It is very +* A description of the technical details of the vulnerabilities. It is very important to let us know how we can reproduce your findings. * An explanation who can exploit this vulnerability, and what they gain when doing so -- write an attack scenario. 
This will help us evaluate your report diff --git a/WORKSPACE b/WORKSPACE index 11c5cdb2070..4ddfb9a3832 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "6691c58a2cd30a86776dd9bb34898b041e37136f2dc7e24cadaeaf599c95c657", - strip_prefix = "rules_closure-08039ba8ca59f64248bb3b6ae016460fe9c9914f", + sha256 = "a38539c5b5c358548e75b44141b4ab637bba7c4dc02b46b1f62a96d6433f56ae", + strip_prefix = "rules_closure-dbb96841cc0a5fb2664c37822803b06dab20c7d1", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", - "https://github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", # 2018-01-16 + "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/dbb96841cc0a5fb2664c37822803b06dab20c7d1.tar.gz", # 2018-04-13 ], ) diff --git a/configure.py b/configure.py index 8fb89791116..b6c32543cf7 100644 --- a/configure.py +++ b/configure.py @@ -226,8 +226,6 @@ def setup_python(environ_cp): # Set-up env variables used by python_configure.bzl write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) - write_to_bazelrc('build --force_python=py%s' % python_major_version) - write_to_bazelrc('build --host_force_python=py%s' % python_major_version) write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) environ_cp['PYTHON_BIN_PATH'] = python_bin_path @@ -500,10 +498,6 @@ def set_cc_opt_flags(environ_cp): if not is_ppc64le() and not is_windows(): write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') - # TODO(mikecase): Remove these default defines once we are able to get - # TF Lite targets building without them. - write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') - write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') def set_tf_cuda_clang(environ_cp): """set TF_CUDA_CLANG action_env. @@ -847,8 +841,8 @@ def reformat_version_sequence(version_str, sequence_count): def set_tf_cuda_version(environ_cp): """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION.""" ask_cuda_version = ( - 'Please specify the CUDA SDK version you want to use, ' - 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION + 'Please specify the CUDA SDK version you want to use. ' + '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): # Configure the Cuda SDK version to use. 
@@ -1228,6 +1222,9 @@ def set_tf_cuda_compute_capabilities(environ_cp): ask_cuda_compute_capabilities, default_cuda_compute_capabilities) # Check whether all capabilities from the input is valid all_valid = True + # Remove all whitespace characters that users may insert by accident + # before splitting the string, as stray whitespace would make the validity check below fail tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split()) for compute_capability in tf_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) if not m: diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 18eeb281680..b86b277ac32 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -2097,7 +2097,7 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, for (int i = 0; i < size; ++i) { TensorId id = results.missing_unused_input_map_keys[i]; - tf_results->missing_unused_key_names_data.push_back(id.first.ToString()); + tf_results->missing_unused_key_names_data.push_back(std::string(id.first)); tf_results->missing_unused_key_names[i] = tf_results->missing_unused_key_names_data.back().c_str(); tf_results->missing_unused_key_indexes[i] = id.second; diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 9678ee926fc..95b04f9058a 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -184,6 +184,7 @@ library { return std::move(functions[0]); } +#if not defined(PLATFORM_WINDOWS) // On success, returns a set of TF_Function instances encoding a dataset // node stack that reads a Imagenet TFRecordFile dataset from `file_path`, and // sets `dataset_name` to the created dataset name. The returned functions must @@ -7076,7 +7077,9 @@ library { return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); #endif } +#endif +#if not defined(PLATFORM_WINDOWS) // On success, returns a set of TF_Function instances encoding a dataset // node stack that reads an MNIST file dataset from `file_path`, and // sets `dataset_name` to the created dataset name. The returned functions must @@ -8221,6 +8224,7 @@ library { return CreateFunctionsFromTextProto(func_def, &mutate_proto_func, status); #endif } +#endif // Adds the input functions to `graph`. On success, returns the created // IteratorGetNext node. @@ -8314,6 +8318,13 @@ TF_Operation* TF_MakeFakeIteratorGetNextWithDatasets(TF_Graph* graph, TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( TF_Graph* graph, const char* file_path, int batch_size, unsigned char is_mnist, TF_Status* status) { +#if defined(PLATFORM_WINDOWS) + // TODO(ashankar): get these functions working on Windows.
+ status->status = tensorflow::errors::Unimplemented( + "TF_MakeFileBasedIteratorGetNextWithDatasets in the experimental C API " + "is not implemented for Windows"); + return nullptr; +#else tensorflow::Status s; std::string dataset_name; @@ -8355,4 +8366,92 @@ TF_Operation* TF_MakeFileBasedIteratorGetNextWithDatasets( << graph->graph.ToGraphDefDebug().DebugString(); return getnext_node; +#endif +} + +TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, int tensor_id, + TF_Status* status) { + assert(session); + { + tensorflow::mutex_lock c(session->graph->mu); + VLOG(1) << "Dequeuing named tensor with id " << tensor_id + << ", with input graph: " + << session->graph->graph.ToGraphDefDebug().DebugString(); + } + + TF_Operation* dequeue_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("fifo_queue_dequeue_", tensor_id).c_str()); + if (dequeue_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the dequeue node in the TF graph."); + return nullptr; + } + + VLOG(1) << "Running the dequeue op"; + TF_Output output{dequeue_op, 0}; + TF_Tensor* ret; + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0, + // output related parameters + /*outputs*/ &output, /*output_values*/ &ret, + /*noutputs*/ 1, + /*targets*/ nullptr, /*ntargets*/ 0, + /*run_metadata*/ nullptr, status); + if (VLOG_IS_ON(1) && status->status.ok()) { + tensorflow::Tensor tensor; + if (tensorflow::TF_TensorToTensor(ret, &tensor).ok()) { + VLOG(1) << "Dequeued tensor content: " << tensor.DebugString(); + } + } + return ret; +} + +void TF_EnqueueNamedTensor(TF_Session* session, int tensor_id, + TF_Tensor* tensor, TF_Status* status) { + assert(session); + { + tensorflow::mutex_lock c(session->graph->mu); + if (VLOG_IS_ON(1)) { + VLOG(1) << "Enqueuing named tensor with id " << tensor_id + << ", with input graph: " + << session->graph->graph.ToGraphDefDebug().DebugString(); + tensorflow::Tensor internal_tensor; + if (tensorflow::TF_TensorToTensor(tensor, &internal_tensor).ok()) { + VLOG(1) << "Enqueu'ing tensor content: " + << internal_tensor.DebugString(); + } + } + } + + TF_Operation* enqueue_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("fifo_queue_enqueue_", tensor_id).c_str()); + if (enqueue_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the enqueue node in the TF graph."); + return; + } + + TF_Operation* placeholder_op = TF_GraphOperationByName( + session->graph, + tensorflow::strings::StrCat("arg_tensor_enqueue_", tensor_id).c_str()); + if (placeholder_op == nullptr) { + status->status = tensorflow::errors::Internal( + "Unable to find the placeholder node as input to enqueue in the TF " + "graph."); + return; + } + + VLOG(1) << "Running the enqueue op"; + TF_Output input{placeholder_op, 0}; + TF_SessionRun(session, /*run_options*/ nullptr, + // input related parameters + /*inputs*/ &input, /*input_values*/ &tensor, /*ninputs*/ 1, + // output related parameters + /*outputs*/ nullptr, /*output_values*/ nullptr, /*noutputs*/ 0, + /*targets*/ &enqueue_op, /*ntargets*/ 1, + /*run_metadata*/ nullptr, status); + VLOG(1) << "Enqueuing is done."; } diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index 88cb173cd25..20bdace40f1 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -86,6 +86,35 @@ TF_CAPI_EXPORT extern TF_Operation* 
TF_MakeFileBasedIteratorGetNextWithDatasets( TF_Graph* graph, const char* file_path, int batch_size, unsigned char is_mnist, TF_Status* status); +// On success, dequeues a tensor from a TF-managed FifoQueue given by +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_dequeue_<tensor_id>", to be executed by this API call. + +// Caller must call TF_DeleteTensor() over the returned tensor. If the queue is +// empty, this call is blocked. +// +// Tensors are enqueued via the corresponding TF enqueue op. +// TODO(hongm): Add support for `timeout_ms`. +TF_CAPI_EXPORT extern TF_Tensor* TF_DequeueNamedTensor(TF_Session* session, + int tensor_id, + TF_Status* status); + +// On success, enqueues `tensor` into a TF-managed FifoQueue given by +// `tensor_id`, associated with `session`. There must be a graph node named +// "fifo_queue_enqueue_<tensor_id>", to be executed by this API call. It reads +// from a placeholder node "arg_tensor_enqueue_<tensor_id>". +// +// `tensor` is still owned by the caller. This call will be blocked if the queue +// has reached its capacity, and will be unblocked when the queued tensors again +// drop below the capacity due to dequeuing. +// +// Tensors are dequeued via the corresponding TF dequeue op. +// TODO(hongm): Add support for `timeout_ms`. +TF_CAPI_EXPORT extern void TF_EnqueueNamedTensor(TF_Session* session, + int tensor_id, + TF_Tensor* tensor, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index ca80db23ed3..577f10c5e69 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -1368,7 +1368,7 @@ TEST(CAPI, SavedModel) { } const tensorflow::string input_op_name = - tensorflow::ParseTensorName(input_name).first.ToString(); + std::string(tensorflow::ParseTensorName(input_name).first); TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); @@ -1376,7 +1376,7 @@ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); const tensorflow::string output_op_name = - tensorflow::ParseTensorName(output_name).first.ToString(); + std::string(tensorflow::ParseTensorName(output_name).first); TF_Operation* output_op = TF_GraphOperationByName(graph, output_op_name.c_str()); ASSERT_TRUE(output_op != nullptr); @@ -1700,7 +1700,7 @@ TEST_F(CApiGradientsTest, OpWithNoGradientRegistered_NoGradInputs) { TestGradientsError(false); } -// REGISTER_OP for CApiTestAttributesTest test cases. +// REGISTER_OP for CApiAttributesTest test cases. // Registers two ops, each with a single attribute called 'v'. // The attribute in one op will have a type 'type', the other // will have list(type). diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index cd19cf8d624..c16aba666ee 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -20,6 +20,7 @@ limitations under the License.
#include #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index b1f7bdaa542..74bc25a491a 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -125,7 +125,7 @@ CheckpointReader::BuildV2VarMaps() { const auto& slice_proto = entry.slices(i); CHECK(filtered_keys .insert(EncodeTensorNameSlice( - v2_reader_->key().ToString() /* full var's name */, + std::string(v2_reader_->key()) /* full var's name */, TensorSlice(slice_proto))) .second); } @@ -138,11 +138,11 @@ CheckpointReader::BuildV2VarMaps() { new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { - if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue; + if (filtered_keys.count(std::string(v2_reader_->key())) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - string key = v2_reader_->key().ToString(); + string key = std::string(v2_reader_->key()); (*var_to_shape_map)[key] = TensorShape(entry.shape()); (*var_to_data_type_map)[key] = DataType(entry.dtype()); } diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index a2d96357ac8..9ce781fab0b 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -24,14 +24,13 @@ tf_cuda_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", + "//tensorflow/core/common_runtime/eager:attr_builder", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", - "//tensorflow/core/common_runtime/eager:execute_node", "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/common_runtime/eager:copy_to_device_node", @@ -49,6 +48,18 @@ tf_cuda_library( ], "//conditions:default": [], }) + [ + "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core:gpu_runtime", ], ) @@ -59,7 +70,6 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = [ ":c_api", - ":runtime", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", @@ -69,10 +79,23 @@ tf_cuda_library( "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/common_runtime/eager:attr_builder", 
"//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:eager_executor", + "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/common_runtime/eager:kernel_and_device", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/distributed_runtime:remote_device", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/eager:remote_tensor_handle", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", + "//tensorflow/core/distributed_runtime/rpc:grpc_worker_service", + "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", ], ) @@ -91,47 +114,7 @@ tf_cuda_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", - ], -) - -tf_cuda_library( - name = "runtime", - srcs = ["runtime.cc"], - hdrs = ["runtime.h"], - copts = tf_copts(), - visibility = ["//tensorflow:internal"], - deps = select({ - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib_lite", - ], - "//conditions:default": [ - "//tensorflow/c:c_api", - "//tensorflow/core:core_cpu", - "//tensorflow/core/common_runtime/eager:kernel_and_device", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - ], - }), -) - -tf_cc_test( - name = "runtime_test", - srcs = ["runtime_test.cc"], - deps = [ - ":runtime", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:client_session", - "//tensorflow/cc:ops", - "//tensorflow/cc:scope", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime/rpc/eager:eager_grpc_server_lib", ], ) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index c96a38dec3e..216210c88c1 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #ifdef TENSORFLOW_EAGER_USE_XLA #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #endif // TENSORFLOW_EAGER_USE_XLA @@ -32,16 +31,22 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h" #include "tensorflow/core/common_runtime/eager/execute.h" -#include "tensorflow/core/common_runtime/eager/execute_node.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" @@ -72,6 +77,121 @@ string DeviceName(const tensorflow::Device* d) { std::atomic_int_fast64_t func_id_generator(0); #endif // TENSORFLOW_EAGER_USE_XLA +tensorflow::Status GetAllRemoteDevices( + const std::vector& remote_workers, + tensorflow::WorkerCacheInterface* worker_cache, + std::unique_ptr* device_mgr) { + std::vector remote_devices; + tensorflow::Status status; + // TODO(nareshmodi) do this in parallel instead of serially. + for (const string& remote_worker : remote_workers) { + tensorflow::Notification n; + tensorflow::NewRemoteDevices( + tensorflow::Env::Default(), worker_cache, remote_worker, + [&status, &n, &remote_devices]( + const tensorflow::Status& s, + std::vector* devices) { + status = s; + if (s.ok()) { + for (tensorflow::Device* d : *devices) { + remote_devices.push_back(d); + } + } + n.Notify(); + }); + n.WaitForNotification(); + } + std::unique_ptr remote_device_mgr( + new tensorflow::DeviceMgr(remote_devices)); + + TF_RETURN_IF_ERROR(status); + + *device_mgr = std::move(remote_device_mgr); + return tensorflow::Status::OK(); +} + +tensorflow::Status CreateRemoteContexts( + const std::vector& remote_workers, + tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, + tensorflow::gtl::FlatMap* remote_contexts) { + for (int i = 0; i < remote_workers.size(); i++) { + const string& remote_worker = remote_workers[i]; + + tensorflow::eager::CreateContextRequest request; + tensorflow::eager::CreateContextResponse response; + tensorflow::DeviceNameUtils::ParsedName parsed_name; + if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, + &parsed_name)) { + return tensorflow::errors::InvalidArgument( + "Unable to parse ", remote_worker, " as a device name"); + } + request.mutable_server_def()->set_job_name(parsed_name.job); + request.mutable_server_def()->set_task_index(parsed_name.task); + request.set_async(async); + auto* eager_client = remote_eager_workers->GetClient(remote_worker); + if (eager_client == nullptr) { + return tensorflow::errors::Internal( + "Cannot find a client for the given target:", remote_worker); + } + tensorflow::Notification n; + tensorflow::Status status; + // TODO(nareshmodi) do this in parallel instead of serially. 
+ eager_client->CreateContextAsync( + &request, &response, [&status, &n](const tensorflow::Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + TF_RETURN_IF_ERROR(status); + + remote_contexts->emplace(remote_worker, response.context_id()); + } + return tensorflow::Status::OK(); +} + +tensorflow::Status NewRemoteAwareTFE_Context(const TFE_ContextOptions* opts, + TFE_Context** ctx) { + string worker_name = tensorflow::strings::StrCat( + "/job:", opts->server_def.job_name(), + "/replica:0/task:", opts->server_def.task_index()); + std::unique_ptr server; + TF_RETURN_IF_ERROR( + tensorflow::eager::EagerGrpcServer::Create(opts->server_def, &server)); + + TF_RETURN_IF_ERROR(server->Start()); + + std::vector remote_workers; + server->master_env()->worker_cache->ListWorkers(&remote_workers); + remote_workers.erase( + std::remove(remote_workers.begin(), remote_workers.end(), worker_name), + remote_workers.end()); + + std::unique_ptr remote_device_mgr; + TF_RETURN_IF_ERROR(GetAllRemoteDevices( + remote_workers, server->master_env()->worker_cache, &remote_device_mgr)); + + std::shared_ptr channel_cache = + server->channel_cache(); + std::unique_ptr remote_eager_workers( + tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); + + // Initialize remote eager workers. + tensorflow::gtl::FlatMap remote_contexts; + TF_RETURN_IF_ERROR(CreateRemoteContexts(remote_workers, + remote_eager_workers.get(), + opts->async, &remote_contexts)); + + tensorflow::RemoteRendezvous* r = + server->worker_env()->rendezvous_mgr->Find(0); + + auto* device_mgr = server->worker_env()->device_mgr; + *ctx = new TFE_Context(opts->session_options.options, opts->policy, + opts->async, device_mgr, r, std::move(server), + std::move(remote_eager_workers), + std::move(remote_device_mgr), remote_contexts); + + return tensorflow::Status::OK(); +} } // namespace extern "C" { @@ -92,6 +212,15 @@ void TFE_ContextOptionsSetDevicePlacementPolicy( options->policy = policy; } +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status) { + if (!options->server_def.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Invalid tensorflow.ServerDef protocol buffer"); + } +} + TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, unsigned char async, TF_Status* status) { @@ -101,28 +230,35 @@ TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx, void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { + if (!opts->server_def.job_name().empty()) { + TFE_Context* ctx = nullptr; + status->status = NewRemoteAwareTFE_Context(opts, &ctx); + return ctx; + } + std::vector devices; status->status = tensorflow::DeviceFactory::AddDevices( opts->session_options.options, "/job:localhost/replica:0/task:0", &devices); - if (!status->status.ok()) { - return nullptr; - } + if (!status->status.ok()) return nullptr; std::unique_ptr device_mgr( new tensorflow::DeviceMgr(devices)); + tensorflow::Rendezvous* r = new tensorflow::IntraProcessRendezvous(device_mgr.get()); + return new TFE_Context(opts->session_options.options, opts->policy, opts->async, std::move(device_mgr), r); } -void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { - delete ctx; -} +void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { delete ctx; } TF_DeviceList* 
TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { TF_DeviceList* list = new TF_DeviceList; - ctx->context.device_mgr()->ListDeviceAttributes(&list->response); + ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response); + if (ctx->context.remote_device_mgr()) { + ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response); + } return list; } @@ -220,9 +356,6 @@ TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { } return retval; } -} // extern "C" - -extern "C" { TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { @@ -242,21 +375,18 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, void TFE_DeleteOp(TFE_Op* op) { delete op; } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { - tensorflow::Device* d = nullptr; - if (device_name != nullptr && strlen(device_name) > 0) { - status->status = op->ctx->context.FindDeviceByName(device_name, &d); - } - op->device = d; + status->status = op->operation.SetDevice(device_name); } const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { - tensorflow::Device* device = - (op->device == nullptr) ? op->ctx->context.HostCPU() : op->device; + tensorflow::Device* device = (op->operation.Device() == nullptr) + ? op->operation.EagerContext()->HostCPU() + : op->operation.Device(); return device->name().c_str(); } void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { - op->use_xla = enable; + op->operation.SetUseXla(enable); #ifndef TENSORFLOW_EAGER_USE_XLA LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not " "built with XLA support."; @@ -264,22 +394,20 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { - h->handle->Ref(); - op->inputs.push_back(h->handle); - op->attrs.NumInputs(op->inputs.size()); + op->operation.AddInput(h->handle); } TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status) { TF_AttrType ret; - if (op->is_function()) { + if (op->operation.is_function()) { status->status = tensorflow::errors::Unimplemented( "TODO(apassos): Support for attributes for TensorFlow functions is not " "ready yet."); return TF_ATTR_INT; // The compiler requires that we return something. } - status->status = - tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list); + status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(), + attr_name, &ret, is_list); return ret; } @@ -298,23 +426,24 @@ TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, } void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { - op->attrs.Set(attr_name, value); + op->operation.MutableAttrs()->Set(attr_name, value); } void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) { - op->attrs.Set(attr_name, static_cast(value)); + op->operation.MutableAttrs()->Set(attr_name, static_cast(value)); } void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) { - op->attrs.Set(attr_name, value); + op->operation.MutableAttrs()->Set(attr_name, value); } void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) { - op->attrs.Set(attr_name, (value == 0) ? false : true); + op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? 
false : true); } void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) { - op->attrs.Set(attr_name, static_cast(value)); + op->operation.MutableAttrs()->Set(attr_name, + static_cast(value)); } void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, @@ -336,23 +465,24 @@ void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims, proto.add_dim()->set_size(dims[d]); } } - op->attrs.Set(attr_name, proto); + op->operation.MutableAttrs()->Set(attr_name, proto); } void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value) { tensorflow::AttrValue attr_value; tensorflow::NameAttrList* func = attr_value.mutable_func(); - func->set_name(value->name); - value->attrs.FillAttrValueMap(func->mutable_attr()); - op->attrs.Set(attr_name, attr_value); + func->set_name(value->operation.Name()); + value->operation.Attrs().FillAttrValueMap(func->mutable_attr()); + op->operation.MutableAttrs()->Set(attr_name, attr_value); } #define TFE_OP_SET_ATTR_LIST(fn, type) \ void fn(TFE_Op* op, const char* attr_name, const type* values, \ int num_values) { \ - op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice( \ - values, num_values)); \ + op->operation.MutableAttrs()->Set( \ + attr_name, \ + tensorflow::gtl::ArraySlice(values, num_values)); \ } TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*) TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float) @@ -360,14 +490,14 @@ TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float) void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, const int64_t* values, int num_values) { - op->attrs.Set(attr_name, - tensorflow::gtl::ArraySlice( - reinterpret_cast(values), num_values)); + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice( + reinterpret_cast(values), num_values)); } void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, const TF_DataType* values, int num_values) { - op->attrs.Set( + op->operation.MutableAttrs()->Set( attr_name, tensorflow::gtl::ArraySlice( reinterpret_cast(values), num_values)); @@ -379,8 +509,8 @@ void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, for (int i = 0; i < num_values; ++i) { b[i] = values[i]; } - op->attrs.Set(attr_name, - tensorflow::gtl::ArraySlice(b.get(), num_values)); + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice(b.get(), num_values)); } void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, @@ -410,9 +540,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, } } } - op->attrs.Set(attr_name, - tensorflow::gtl::ArraySlice( - proto.get(), num_values)); + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice( + proto.get(), num_values)); } void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, @@ -420,534 +550,25 @@ void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, std::unique_ptr funcs( new tensorflow::NameAttrList[num_values]); for (int i = 0; i < num_values; i++) { - funcs[i].set_name(value[i]->name); - value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr()); + funcs[i].set_name(value[i]->operation.Name()); + value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr()); } - op->attrs.Set(attr_name, - tensorflow::gtl::ArraySlice( - funcs.get(), num_values)); + op->operation.MutableAttrs()->Set( + attr_name, tensorflow::gtl::ArraySlice( + funcs.get(), num_values)); } -} // extern "C" - -namespace { - -// Initializes the step stats if needed. 
-void MaybeInitializeStepStats(tensorflow::StepStats* step_stats, - tensorflow::EagerContext* ctx) { - // Lazily initialize the RunMetadata with information about all devices if - // this is the first call. - while (step_stats->dev_stats_size() < ctx->devices()->size()) { - int device_idx = step_stats->dev_stats_size(); - auto* dev_stats = step_stats->add_dev_stats(); - dev_stats->set_device(ctx->devices()->at(device_idx)->name()); - } -} - -int StepStatsDeviceIndex(tensorflow::StepStats* step_stats, - tensorflow::EagerContext* ctx, - tensorflow::Device* device) { - // Find the current device's index. - if (device == nullptr) { - device = ctx->HostCPU(); - } - for (int i = 0; i < ctx->devices()->size(); ++i) { - if (ctx->devices()->at(i) == device || - ctx->devices()->at(i)->name() == device->name()) { - return i; - } - } - // TODO(apassos) do not fall back to host CPU if device is unknown. - return 0; -} - -tensorflow::Status ValidateInputTypeAndPlacement( - tensorflow::EagerContext* ctx, tensorflow::Device* op_device, TFE_Op* op, - const tensorflow::OpKernel* kernel, tensorflow::RunMetadata* run_metadata) { - tensorflow::Device* host_device = ctx->HostCPU(); - const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); - if (memtypes.size() != op->inputs.size()) { - return tensorflow::errors::InvalidArgument( - "expected ", memtypes.size(), " inputs, got ", op->inputs.size()); - } - for (int i = 0; i < op->inputs.size(); ++i) { - const tensorflow::Device* expected_device = - memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device; - tensorflow::TensorHandle* handle = op->inputs[i]; - tensorflow::Device* handle_device = nullptr; - TF_RETURN_IF_ERROR(handle->Device(&handle_device)); - const tensorflow::Device* actual_device = - handle_device == nullptr ? host_device : handle_device; - if (expected_device != actual_device) { - switch (ctx->GetDevicePlacementPolicy()) { - case tensorflow::DEVICE_PLACEMENT_SILENT_FOR_INT32: - // TODO(xpan): See if we could bubble python related error up - // to python level. - if (handle->dtype == tensorflow::DT_INT32) { - // Note: enabling silent copies of int32 tensors to match behavior - // of graph mode. - break; - } - TF_FALLTHROUGH_INTENDED; - case tensorflow::DEVICE_PLACEMENT_EXPLICIT: - return tensorflow::errors::InvalidArgument( - "Tensors on conflicting devices:" - " cannot compute ", - op->name, " as input #", i, " was expected to be on ", - expected_device->name(), " but is actually on ", - actual_device->name(), " (operation running on ", - op_device->name(), ")", - " Tensors can be copied explicitly using .gpu() or .cpu() " - "methods," - " or transparently copied by using tf.enable_eager_execution(" - "device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors " - "between devices" - " may slow down your model"); - case tensorflow::DEVICE_PLACEMENT_WARN: - LOG(WARNING) << "before computing " << op->name << " input #" << i - << " was expected to be on " << expected_device->name() - << " but is actually on " << actual_device->name() - << " (operation running on " << op_device->name() - << "). This triggers a copy which can be a performance " - "bottleneck."; - break; - case tensorflow::DEVICE_PLACEMENT_SILENT: // Do nothing. - break; - } - // We are only here if the policy is warn or silent copies, so we should - // trigger a copy. 
- auto pre_time = tensorflow::Env::Default()->NowMicros(); - tensorflow::TensorHandle* copied_tensor = nullptr; - tensorflow::Status status = tensorflow::EagerCopyToDevice( - handle, ctx, expected_device->name().c_str(), &copied_tensor); - if (run_metadata != nullptr) { - auto* step_stats = run_metadata->mutable_step_stats(); - MaybeInitializeStepStats(step_stats, ctx); - // Record the sending on the source device for now. - int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device); - auto* dev_stats = step_stats->mutable_dev_stats(device_idx); - auto* node_stats = dev_stats->add_node_stats(); - node_stats->set_node_name("_Send"); - node_stats->set_all_start_micros(pre_time); - node_stats->set_op_end_rel_micros( - tensorflow::Env::Default()->NowMicros() - pre_time); - } - if (!status.ok()) { - if (copied_tensor != nullptr) copied_tensor->Unref(); - return tensorflow::errors::Internal( - "Failed copying input tensor from ", actual_device->name(), " to ", - expected_device->name(), " in order to run ", op->name, ": ", - status.error_message()); - } - handle->Unref(); - handle = copied_tensor; - op->inputs[i] = copied_tensor; - } - if (handle->dtype != kernel->input_type(i)) { - return tensorflow::errors::InvalidArgument( - "cannot compute ", op->name, " as input #", i, - " was expected to be a ", - tensorflow::DataTypeString(kernel->input_type(i)), - " tensor but is a ", tensorflow::DataTypeString(handle->dtype), - " tensor"); - } - } - return tensorflow::Status::OK(); -} - -tensorflow::Device* SelectDevice(const tensorflow::NodeDef& ndef, - TFE_Context* ctx, TF_Status* status) { - tensorflow::DeviceSet ds; - for (tensorflow::Device* d : *ctx->context.devices()) { - ds.AddDevice(d); - } - tensorflow::DeviceTypeVector final_devices; - status->status = tensorflow::SupportedDeviceTypesForNode( - ds.PrioritizedDeviceTypeList(), ndef, &final_devices); - if (!status->status.ok()) { - return nullptr; - } - if (final_devices.empty()) { - status->status = tensorflow::errors::Internal( - "Could not find valid device for node ", ndef.DebugString()); - return nullptr; - } - for (tensorflow::Device* d : *ctx->context.devices()) { - if (d->device_type() == final_devices[0].type_string()) { - return d; - } - } - status->status = tensorflow::errors::Unknown( - "Could not find a device for node ", ndef.DebugString()); - return nullptr; -} - - -#ifdef TENSORFLOW_EAGER_USE_XLA -// Synthesizes and returns a wrapper function over `op`, which must be a -// primitive op (e.g. matmul). -// -// The wrapper function conforms to the function signature expected by -// _XlaLaunchOp, with input params ordered by . For example, if the op has input params , they will be reordered to as the input params to the synthesized function. -// -// It populates `const_input_types`, `arg_input_types` and -// `op_input_to_func_input` based on the reordering results, that the caller can -// use them to build an _XlaLaunchOp. On error, it returns NULL, and sets -// `status` accordingly. -const tensorflow::FunctionDef* OpToFunction( - TFE_Op* op, std::vector* const_input_types, - std::vector* arg_input_types, - tensorflow::gtl::FlatMap* op_input_to_func_input, - TF_Status* status) { - DCHECK(!op->is_function()); - - tensorflow::FunctionDef fdef; - - // Get the OpDef of the op we are trying to encapsulate. 
- TFE_Context* ctx = op->ctx; - const tensorflow::OpRegistrationData* op_data; - { - status->status = ctx->context.FindFunctionOpData(op->name, &op_data); - if (!status->status.ok()) { - return nullptr; - } - } - const tensorflow::OpDef& op_def = op_data->op_def; - - tensorflow::OpDef* signature = fdef.mutable_signature(); - - // Handle constant inputs. - const std::unordered_set const_inputs( - *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name)); - - // First add place holders for the input args, so that we can refer to them by - // position in the next loop. Also tally up the resource inputs. - int num_resource_inputs = 0; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) { - ++num_resource_inputs; - } - signature->add_input_arg(); - } - - // Now we map the input params from `op_def` to `signature`, where the param - // ordering for `signature` is: . - int const_index = 0; - int arg_index = const_inputs.size(); - int resource_index = op_def.input_arg_size() - num_resource_inputs; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i); - tensorflow::OpDef::ArgDef* func_input_arg = nullptr; - if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) { - VLOG(1) << "For const input, mapping op input " << i << " to func input " - << const_index; - (*op_input_to_func_input)[i] = const_index; - func_input_arg = signature->mutable_input_arg(const_index++); - const_input_types->push_back( - static_cast(op->inputs[i]->dtype)); - } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) { - VLOG(1) << "For resource input, mapping op input " << i - << " to func input " << resource_index; - (*op_input_to_func_input)[i] = resource_index; - func_input_arg = signature->mutable_input_arg(resource_index++); - } else { - VLOG(1) << "For arg input, mapping op input " << i << " to func input " - << arg_index; - (*op_input_to_func_input)[i] = arg_index; - func_input_arg = signature->mutable_input_arg(arg_index++); - arg_input_types->push_back( - static_cast(op->inputs[i]->dtype)); - } - - func_input_arg->set_name(op_input_arg.name()); - func_input_arg->set_type(op->inputs[i]->dtype); - } - VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString(); - - // Resources args are at the end of the function input params, and we should - // have iterated over all of them. - DCHECK_EQ(signature->input_arg_size(), resource_index); - - // Make the synthesized function's name unique. - signature->set_name(tensorflow::strings::StrCat( - op_def.name(), func_id_generator.fetch_add(1))); - - // Add the node def and set its input names to match op_def's names. - const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); - DCHECK_EQ(signature->input_arg_size(), ndef.input_size()); - *fdef.add_node_def() = ndef; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name()); - } - VLOG(1) << "Added NodeDef: " << fdef.DebugString(); - - // Fix the output names and set output types. 
- for (int i = 0; i < op_def.output_arg_size(); ++i) { - tensorflow::OpDef::ArgDef* arg = signature->add_output_arg(); - const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i); - const string& out_tensor_name = tensorflow::strings::StrCat( - ndef.name(), ":", op_def_arg.name(), ":", 0); - arg->set_name(op_def_arg.name()); - (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name; - const string& type_attr = op_def_arg.type_attr(); - if (!type_attr.empty()) { - auto i = ndef.attr().find(type_attr); - if (i == ndef.attr().end()) { - status->status = tensorflow::errors::InvalidArgument( - tensorflow::strings::StrCat("Could not find attr ", type_attr, - " in NodeDef ", ndef.DebugString())); - return nullptr; - } - arg->set_type(i->second.type()); - } - } - VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString(); - - status->status = ctx->context.AddFunctionDef(fdef); - if (!status->status.ok()) return nullptr; - const auto ret = ctx->context.FindFunctionDef(signature->name()); - DCHECK(ret != nullptr); - return ret; -} - -// Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed -// via XLA. -std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { - VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name; - auto launch_op = - std::unique_ptr(TFE_NewOp(op->ctx, "_XlaLaunch", status)); - if (TF_GetCode(status) != TF_OK) return nullptr; - if (op->device) { - TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status); - if (TF_GetCode(status) != TF_OK) return nullptr; - } - - const tensorflow::FunctionDef* fdef; - { - fdef = op->ctx->context.FindFunctionDef(op->name); - } - std::vector const_input_types; - std::vector arg_input_types; - tensorflow::gtl::FlatMap op_input_to_func_input; - if (fdef == nullptr) { - // See if this is a primitive op, and if so create a function for it, so - // that _XlaLaunchOp can access it. - fdef = OpToFunction(op, &const_input_types, &arg_input_types, - &op_input_to_func_input, status); - if (!status->status.ok()) return nullptr; - } else { - // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for - // functions, so we need to find another way to handle constant inputs. - for (int i = const_input_types.size(); - i < fdef->signature().input_arg_size(); ++i) { - VLOG(1) << "Adding Targs from input arg " << i; - const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i); - arg_input_types.push_back(static_cast(arg.type())); - } - } - DCHECK(fdef != nullptr); - - // Copy inputs and their devices. - // Since input param reordering may have occurred between `op` and `launch_op` - // via `op_input_to_func_input`, adjust the actual inputs accordingly. - launch_op->inputs = op->inputs; - for (tensorflow::TensorHandle* h : launch_op->inputs) { - h->Ref(); - } - if (!op_input_to_func_input.empty()) { - DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size()); - for (int i = 0; i < op_input_to_func_input.size(); ++i) { - VLOG(1) << "mapping op input " << i << " to func input " - << op_input_to_func_input[i]; - - launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i]; - } - } - launch_op->attrs.NumInputs(op->inputs.size()); - - TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(), - const_input_types.size()); - - // Set Targs and Nresources attrs. 
- TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(), - arg_input_types.size()); - const int num_resource_inputs = fdef->signature().input_arg_size() - - const_input_types.size() - - arg_input_types.size(); - TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs); - - // Set Tresults attr. - std::vector tresults; - for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) { - tresults.push_back(static_cast(arg.type())); - } - TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(), - tresults.size()); - - // Set function attr. - tensorflow::AttrValue attr_value; - tensorflow::NameAttrList* func = attr_value.mutable_func(); - func->set_name(fdef->signature().name()); - launch_op->attrs.Set("function", attr_value); - - return launch_op; -} -#endif // TENSORFLOW_EAGER_USE_XLA - -} // namespace - -extern "C" { void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { - TFE_Context* ctx = op->ctx; - status->status = ctx->context.GetStatus(); + tensorflow::gtl::InlinedVector handle_retvals( + *num_retvals); + status->status = + tensorflow::EagerExecute(&op->operation, &handle_retvals, num_retvals); if (!status->status.ok()) { return; } -#ifdef TENSORFLOW_EAGER_USE_XLA - std::unique_ptr xla_launch_op; - if (op->use_xla && op->name != "_XlaLaunch") { - xla_launch_op = BuildXlaLaunch(op, status); - if (!status->status.ok()) { - return; - } - op = xla_launch_op.get(); - } -#endif // TENSORFLOW_EAGER_USE_XLA - // Ensure all resource-touching ops run in the device the resource is, - // regardless of anything else that has been specified. This is identical to - // the graph mode behavior. - for (int i = 0; i < op->inputs.size(); ++i) { - tensorflow::Device* input_op_device = nullptr; - status->status = op->inputs[i]->OpDevice(&input_op_device); - if (!status->status.ok()) return; - VLOG(2) << "for op " << op->name << " input " << i << " " - << tensorflow::DataTypeString(op->inputs[i]->dtype) << " " - << (input_op_device == nullptr ? "cpu" : input_op_device->name()) - << " " << (op->device == nullptr ? "cpu" : op->device->name()); - if (op->inputs[i]->dtype == tensorflow::DT_RESOURCE && - (input_op_device != op->device || input_op_device == nullptr)) { - tensorflow::Device* d = - input_op_device == nullptr ? ctx->context.HostCPU() : input_op_device; - VLOG(1) << "Changing device of operation " << op->name << " to " - << d->name() << " because input #" << i - << " is a resource in this device."; - op->device = d; - } - } - tensorflow::Device* device = op->device; - - tensorflow::Fprint128 cache_key = - op->attrs.CacheKey(device == nullptr ? "unspecified" : device->name()); - tensorflow::KernelAndDevice* kernel = ctx->context.GetCachedKernel(cache_key); - if (kernel == nullptr) { - const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); - if (device == nullptr) { - device = SelectDevice(ndef, ctx, status); - if (!status->status.ok()) { - return; - } - } - CHECK(device != nullptr); - if (ctx->context.LogDevicePlacement()) { - LOG(INFO) << "Executing op " << ndef.op() << " in device " - << device->name(); - } - kernel = new tensorflow::KernelAndDevice(ctx->context.GetRendezvous()); - // Knowledge of the implementation of Init (and in-turn - // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def - // will be accessed, so grab on to the lock. - // See WARNING comment in Execute (before kernel->Run) - would be nice to - // rework to avoid this subtlety. 
- tensorflow::tf_shared_lock l(*ctx->context.FunctionsMu()); - status->status = tensorflow::KernelAndDevice::Init( - ndef, ctx->context.func_lib(device), kernel); - if (!status->status.ok()) { - delete kernel; - return; - } - // Update output_dtypes inside `kernel`. - const tensorflow::OpDef* op_def = nullptr; - const tensorflow::FunctionDef* function_def = - ctx->context.FuncLibDef()->Find(ndef.op()); - if (function_def != nullptr) { - op_def = &(function_def->signature()); - } - if (op_def == nullptr) { - status->status = OpDefForOp(ndef.op().c_str(), &op_def); - if (!status->status.ok()) { - return; - } - } - tensorflow::DataTypeVector input_dtypes; - status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes, - kernel->mutable_output_dtypes()); - if (!status->status.ok()) { - return; - } - ctx->context.AddKernelToCache(cache_key, kernel); - } - const tensorflow::DataTypeVector& output_dtypes = kernel->output_dtypes(); - const int output_dtypes_size = output_dtypes.size(); - if (output_dtypes_size > *num_retvals) { - TF_SetStatus(status, TF_INVALID_ARGUMENT, - tensorflow::strings::StrCat("Expecting ", output_dtypes.size(), - " outputs, but *num_retvals is ", - *num_retvals) - .c_str()); - return; - } - *num_retvals = output_dtypes_size; - if (device == nullptr) { - // TODO(apassos) debug how the assignment below might return a different - // device from the one requested above. - device = kernel->device(); - } - status->status = ValidateInputTypeAndPlacement( - &ctx->context, device, op, kernel->kernel(), - ctx->context.ShouldStoreMetadata() ? ctx->context.RunMetadataProto() - : nullptr); - if (!status->status.ok()) return; - std::unique_ptr maybe_stats; - if (ctx->context.ShouldStoreMetadata()) { - maybe_stats.reset(new tensorflow::NodeExecStats); - maybe_stats->set_node_name(op->name); - maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros()); - maybe_stats->set_op_start_rel_micros(0); - maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); - // TODO(apassos) track referenced tensors - } - if (ctx->context.Async()) { - // Note that for async mode, execution order will make sure that all - // input handles are ready before executing them. - // TODO(agarwal): Consider executing "cheap" kernels inline for performance. - tensorflow::gtl::InlinedVector handle_retvals( - *num_retvals); - tensorflow::uint64 id = op->ctx->context.NextId(); - for (int i = 0; i < *num_retvals; ++i) { - tensorflow::TensorHandle* h = - new tensorflow::TensorHandle(id, output_dtypes[i], &op->ctx->context); - retvals[i] = new TFE_TensorHandle(h); - handle_retvals[i] = h; - } - tensorflow::EagerNode* node = new tensorflow::ExecuteNode( - id, &op->ctx->context, op->device, op->inputs, kernel, - maybe_stats.release(), output_dtypes, handle_retvals); - ctx->context.ExecutorAdd(node); - } else { - // Execute checks if retvals[i] is nullptr or not to figure if it needs to - // allocate it. 
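// Illustrative caller-side sketch, not part of the patch: with TFE_Execute now
// delegating to tensorflow::EagerExecute (see the + lines earlier in this
// hunk), the public contract is unchanged -- *num_retvals is in/out, carrying
// the capacity of `retvals` on entry and the number of outputs actually
// produced on return. `op` is assumed to be any already-built TFE_Op.
#include "tensorflow/c/eager/c_api.h"

void RunSingleOutputOp(TFE_Op* op, TF_Status* status) {
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;  // capacity of `retvals` on input
  TFE_Execute(op, retvals, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return;
  for (int i = 0; i < num_retvals; ++i) {
    TFE_DeleteTensorHandle(retvals[i]);  // release outputs once consumed
  }
}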
- std::vector handle_retvals(*num_retvals, - nullptr); - status->status = tensorflow::EagerExecute( - &op->ctx->context, op->device, op->inputs, kernel, maybe_stats.get(), - handle_retvals.data(), *num_retvals); - for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = new TFE_TensorHandle(handle_retvals[i]); - } + for (int i = 0; i < *num_retvals; ++i) { + retvals[i] = new TFE_TensorHandle(handle_retvals[i]); } } @@ -1090,10 +711,3 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, } } } // namespace tensorflow - - -TFE_Op::~TFE_Op() { - for (tensorflow::TensorHandle* h : inputs) { - h->Unref(); - } -} diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index c06ce84a8c5..574a097e0d6 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -81,6 +81,16 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetAsync(TFE_ContextOptions*, TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); +// A tensorflow.ServerDef specifies remote workers (in addition to the current +// workers name). Operations created on this context can then be executed on +// any of these remote workers by setting an appropriate device. +// +// If the following is set, all servers identified by the +// ServerDef must be up when the context is created. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetServerDef( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status); + // Destroy an options object. TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 05dc64f5217..2b8384d7203 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -28,14 +28,23 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" -#include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/distributed_runtime/remote_device.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" +#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/server_lib.h" +#include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -45,12 +54,12 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/public/version.h" - struct TFE_ContextOptions { TF_SessionOptions session_options; // true if async execution is enabled. 
bool async = false; TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_SILENT}; + tensorflow::ServerDef server_def; }; struct TFE_Context { @@ -64,6 +73,23 @@ struct TFE_Context { default_policy), async, std::move(device_mgr), rendezvous) {} + explicit TFE_Context( + const tensorflow::SessionOptions& opts, + TFE_ContextDevicePlacementPolicy default_policy, bool async, + tensorflow::DeviceMgr* local_device_mgr, + tensorflow::Rendezvous* rendezvous, + std::unique_ptr server, + std::unique_ptr remote_eager_workers, + std::unique_ptr remote_device_mgr, + const tensorflow::gtl::FlatMap& + remote_contexts) + : context(opts, + static_cast( + default_policy), + async, local_device_mgr, rendezvous, std::move(server), + std::move(remote_eager_workers), std::move(remote_device_mgr), + remote_contexts) {} + tensorflow::EagerContext context; }; @@ -85,19 +111,9 @@ struct TFE_Op { // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a // primitive operation. TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) - : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} + : operation(&ctx->context, op, t) {} - ~TFE_Op(); - - bool const is_function() const { return attr_types == nullptr; } - - TFE_Context* ctx; // Must outlive the TFE_Op. - const tensorflow::string name; - tensorflow::AttrBuilder attrs; - const tensorflow::AttrTypeMap* attr_types; - tensorflow::gtl::InlinedVector inputs; - tensorflow::Device* device; - bool use_xla = false; + tensorflow::EagerOperation operation; }; namespace tensorflow { diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 701175e4943..49646bb7359 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include +#include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -23,7 +24,9 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tensorflow_server.pb.h" using tensorflow::string; @@ -220,6 +223,103 @@ TEST(CAPI, Context) { TF_DeleteStatus(status); } +tensorflow::ServerDef GetServerDef(int num_tasks) { + tensorflow::ServerDef server_def; + server_def.set_protocol("grpc"); + server_def.set_job_name("localhost"); + server_def.set_task_index(0); + tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); + tensorflow::JobDef* job_def = cluster_def->add_job(); + job_def->set_name("localhost"); + for (int i = 0; i < num_tasks; i++) { + int port = tensorflow::testing::PickUnusedPortOrDie(); + job_def->mutable_tasks()->insert( + {i, tensorflow::strings::StrCat("localhost:", port)}); + } + return server_def; +} + +void TestRemoteExecute(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. 
+ string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE( + tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), + status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + auto* h0_task1 = + TFE_TensorHandleCopyToDevice(h0_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + auto* h1_task1 = + TFE_TensorHandleCopyToDevice(h1_task0, ctx, remote_device_name, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_Op* matmul = MatMulOp(ctx, h0_task1, h1_task1); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(h0_task1); + TFE_DeleteTensorHandle(h1_task1); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. 
+ worker_server.release(); +} + +TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } +TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 97c323b8722..dcc2357b71a 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -130,13 +130,15 @@ class GradientTape { } } - bool ShouldRecord(gtl::ArraySlice tensor_ids); + bool ShouldRecord(gtl::ArraySlice tensor_ids, + gtl::ArraySlice dtypes); void Watch(int64 tensor_id); void RecordOperation(const string& op_type, gtl::ArraySlice output_tensors, gtl::ArraySlice input_tensor_id, + gtl::ArraySlice input_dtypes, BackwardFunction* backward_function, const std::function& backward_function_deleter); @@ -170,12 +172,32 @@ class GradientTape { // Template instantiations here +inline bool IsDtypeTrainable(DataType dtype) { + switch (dtype) { + case DT_HALF: + case DT_BFLOAT16: + case DT_FLOAT: + case DT_DOUBLE: + case DT_COMPLEX64: + case DT_COMPLEX128: + case DT_RESOURCE: + case DT_VARIANT: + return true; + default: + return false; + } +} + template bool GradientTape::ShouldRecord( - gtl::ArraySlice tensor_ids) { - for (int64 i : tensor_ids) { - if (tensor_tape_.find(i) != tensor_tape_.end()) { - return true; + gtl::ArraySlice tensor_ids, + gtl::ArraySlice dtypes) { + CHECK_EQ(tensor_ids.size(), dtypes.size()); + for (int i = 0; i < tensor_ids.size(); ++i) { + if (tensor_tape_.find(tensor_ids[i]) != tensor_tape_.end()) { + if (IsDtypeTrainable(dtypes[i])) { + return true; + } } } return false; @@ -189,9 +211,11 @@ void GradientTape::Watch(int64 tensor_id) { template void GradientTape::RecordOperation( const string& op_type, gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, BackwardFunction* backward_function, + gtl::ArraySlice input_tensor_id, + gtl::ArraySlice input_dtypes, + BackwardFunction* backward_function, const std::function& backward_function_deleter) { - if (!ShouldRecord(input_tensor_id)) { + if (!ShouldRecord(input_tensor_id, input_dtypes)) { backward_function_deleter(); return; } @@ -380,49 +404,39 @@ Status InitialGradients(const VSpace& vspace, gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, const OpTape& op_tape, - const gtl::FlatMap& tensor_usage_counts, gtl::FlatMap>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; - if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) { - if (!output_gradients.empty() && output_gradients[i] != nullptr) { - // TODO(apassos) figure out how to print debugging information here. 
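// Standalone illustration, not TensorFlow code: the new ShouldRecord overload
// in tape.h above consults both the tensor ids and their dtypes, so a watched
// tensor of a non-differentiable dtype (e.g. an int32 index) no longer forces
// the tape to record an operation. ToyDType and IsTrainable are hypothetical
// stand-ins that mirror IsDtypeTrainable for the purpose of the example.
#include <cstddef>
#include <cstdint>
#include <unordered_set>
#include <vector>

enum class ToyDType { kInt32, kFloat, kHalf, kResource };

bool IsTrainable(ToyDType d) {
  return d == ToyDType::kFloat || d == ToyDType::kHalf ||
         d == ToyDType::kResource;
}

// Record only if some watched input can actually carry a gradient.
bool ShouldRecord(const std::unordered_set<int64_t>& watched,
                  const std::vector<int64_t>& ids,
                  const std::vector<ToyDType>& dtypes) {
  for (std::size_t i = 0; i < ids.size(); ++i) {
    if (watched.count(ids[i]) > 0 && IsTrainable(dtypes[i])) return true;
  }
  return false;
}
// Example: watched = {7}, ids = {7}, dtypes = {kInt32} -> false (skipped)
//          watched = {7}, ids = {7}, dtypes = {kFloat} -> true  (recorded)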
- return errors::InvalidArgument( - "A gradient was provided for a tensor which is used as part of the " - "computation."); - } - } else { - if (output_gradients.empty() || output_gradients[i] == nullptr) { - auto tensor_it = tensor_tape.find(id); - if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { - auto op_it = op_tape.find(tensor_it->second); - if (op_it == op_tape.end()) { - return errors::Internal( - "Internal state of the gradient tape is invalid: " - "failed to find operation producing a tensor"); + if (output_gradients.empty() || output_gradients[i] == nullptr) { + auto tensor_it = tensor_tape.find(id); + if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { + auto op_it = op_tape.find(tensor_it->second); + if (op_it == op_tape.end()) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "failed to find operation producing a tensor"); + } + bool found = false; + for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { + if (op_it->second.output_tensor_info[j].id == id) { + found = true; + (*result)[id].push_back( + vspace.Ones(op_it->second.output_tensor_info[j].shape, + op_it->second.output_tensor_info[j].dtype)); + break; } - bool found = false; - for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { - if (op_it->second.output_tensor_info[j].id == id) { - found = true; - (*result)[id].push_back( - vspace.Ones(op_it->second.output_tensor_info[j].shape, - op_it->second.output_tensor_info[j].dtype)); - break; - } - } - if (!found) { - return errors::Internal( - "Internal state of the gradient tape is invalid: " - "none of operations outputs match expected tensor"); - } - } else { - // No record of the target tensor found on the tape, so no gradient - // needs to be computed from it. Do nothing. + } + if (!found) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "none of operations outputs match expected tensor"); } } else { - (*result)[id].push_back(output_gradients[i]); + // No record of the target tensor found on the tape, so no gradient + // needs to be computed from it. Do nothing. 
} + } else { + (*result)[id].push_back(output_gradients[i]); } } return Status::OK(); @@ -451,8 +465,7 @@ Status GradientTape::ComputeGradient( InitialStack(state.op_tape, state.op_missing_tensor); gtl::FlatMap> gradients; Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, - tensor_tape_, state.op_tape, - state.tensor_usage_counts, &gradients); + tensor_tape_, state.op_tape, &gradients); auto cleanup = [this, &state]() { if (!persistent_) { // Release all backprop functions diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 93155998b86..e18fdf6c57b 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) { session->extend_before_run = false; } -std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { +std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { Node* node = &output.oper->node; CppShapeInferenceResult::HandleData handle_data; handle_data.set_is_set(true); @@ -135,4 +135,30 @@ std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) { return result; } +void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, + const void* proto, size_t proto_len, + TF_Status* status) { + tensorflow::CppShapeInferenceResult::HandleData handle_data; + if (!handle_data.ParseFromArray(proto, proto_len)) { + status->status = tensorflow::errors::InvalidArgument( + "Couldn't deserialize HandleData proto"); + return; + } + DCHECK(handle_data.is_set()); + + tensorflow::mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&output.oper->node); + + std::vector shapes_and_types; + for (const auto& shape_and_type_proto : handle_data.shape_and_type()) { + tensorflow::shape_inference::ShapeHandle shape; + status->status = + ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape); + if (status->status.ok()) return; + shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype()); + } + ic->set_output_handle_shapes_and_types(output.index, shapes_and_types); +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index 2d4c8cd9ed7..4bcb5bde62c 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -55,9 +55,15 @@ void ExtendSession(TF_Session* session, TF_Status* status); // Returns the serialized CppShapeInferenceResult::HandleData proto for // `output` if its a resource tensor, or otherwise returns the empty string. -// TODO(b/74620627): remove when _USE_C_SHAPES is removed -std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output); +std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output); +// Sets `output` based on `proto`, which should be a serialized +// CppShapeInferenceResult::HandleData proto. +// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string +// because I couldn't get SWIG to work otherwise. 
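// Hedged sketch, not part of the patch: the early return in the loop added to
// SetResourceHandleShapeAndType above reads as firing when
// MakeShapeFromShapeProto *succeeds*, which would skip recording any shapes.
// The presumably intended flow -- bail out on failure, otherwise accumulate
// every (shape, dtype) pair before publishing them -- would look like the
// fragment below (`ic`, `handle_data`, `output`, `status` as in that function).
std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
  tensorflow::shape_inference::ShapeHandle shape;
  status->status =
      ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
  if (!status->status.ok()) return;  // stop on the first malformed shape
  shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
}
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);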
+void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output, + const void* proto, size_t proto_len, + TF_Status* status); } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d73121c7b70..d6a4f141b6b 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -440,7 +440,7 @@ string AvoidCPPKeywords(StringPiece name) { if (IsCPPKeyword(name)) { return strings::StrCat(name, "_"); } - return name.ToString(); + return std::string(name); } void InferArgAttributes(const OpDef::ArgDef& arg, diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index c143b978338..62a889181e7 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -220,7 +220,7 @@ std::unordered_set Scope::Impl::GetColocationConstraints( for (const string& entry : node_constraints) { StringPiece s(entry); if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) { - current_constraints.insert(s.ToString()); + current_constraints.insert(std::string(s)); } } } else { diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 6545e4ee3eb..ff348fadb24 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -385,6 +385,42 @@ Status MirrorPadGradGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad); +Status StridedSliceGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + Input x = Shape(scope, op.input(0)); + Input begin = op.input(1); + Input end = op.input(2); + Input strides = op.input(3); + int64 begin_mask; + int64 end_mask; + int64 ellipsis_mask; + int64 new_axis_mask; + int64 shrink_axis_mask; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask)); + TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask)); + grad_outputs->push_back( + StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0], + StridedSliceGrad::BeginMask(begin_mask) + .EndMask(end_mask) + .EllipsisMask(ellipsis_mask) + .NewAxisMask(new_axis_mask) + .ShrinkAxisMask(shrink_axis_mask))); + // No gradients returned for begin, end and strides + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index 4a215fcc929..de3bd0fc9e2 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -354,5 +354,29 @@ TEST_F(ArrayGradTest, MirrorPadGradGrad_Symmetric) { RunTest(x, x_shape, y, y_shape); } +TEST_F(ArrayGradTest, StridedSliceGrad) { + TensorShape x_shape({6, 4, 4}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + + // y = x[2:6:2, 1:3, 1:3] + auto y = StridedSlice(scope_, x, {2, 1, 1}, {6, 3, 3}, {2, 1, 1}); + // y.shape = [2, 2, 2]; + RunTest(x, 
x_shape, y, {2, 2, 2}); + + // y = x[2:6:2, 1:3, 1:3] + // begin_mask = 1<<1 (ignore begin_index = 1) + // end_mask = 1<<2 (ignore end_index = 2) + y = StridedSlice(scope_, x, {2, 1, 1}, {6, 3, 3}, {2, 1, 1}, + StridedSlice::BeginMask(1 << 1).EndMask(1 << 2)); + // y.shape = [2, 3, 3]; + RunTest(x, x_shape, y, {2, 3, 3}); + + // y = [tf.newaxis, 2:6:2, 1:3, 1:3] + y = StridedSlice(scope_, x, {0, 2, 1, 1}, {0, 6, 3, 3}, {1, 2, 1, 1}, + StridedSlice::NewAxisMask(1 << 0)); + // y.shape = [1, 2, 2, 2]; + RunTest(x, x_shape, y, {1, 2, 2, 2}); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 1b4c7c26880..fd7b6fe6625 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -31,7 +31,6 @@ using ops::AddN; using ops::BatchMatMul; using ops::Const; using ops::Div; -using ops::Greater; using ops::MatMul; using ops::Max; using ops::Maximum; @@ -46,7 +45,6 @@ using ops::RealDiv; using ops::SquaredDifference; using ops::Sub; using ops::Sum; -using ops::Where3; // TODO(andydavis) Test gradient function against numeric gradients output. // TODO(andydavis) As more gradients are added move common test functions diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index 4ddddcb5863..23e9dc40d23 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/cc/tools/freeze_saved_model.h" +#include #include #include "tensorflow/core/framework/attr_value.pb.h" @@ -71,6 +72,15 @@ void GetNodeNameToNodeDefMap( } } +// Strips off the tensor part of the tensor_name to get the node_name. +const string GetNodeNameFromTensorName(string tensor_name) { + if (tensor_name[0] == '^') { + tensor_name.erase(0, 1); + } + std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); + return tensor_name_parts[0]; +} + // Gets the set of node names needed by `outputs` and the corresponding set of // variable nodes to convert. void GetReachableNodesAndVariables( @@ -83,10 +93,8 @@ void GetReachableNodesAndVariables( new std::unordered_set({"Variable", "VariableV2", "VarHandleOp"}); std::queue nodes_to_visit; - for (const string& tensor_name : outputs) { - // We need to strip off the tensor part to get the node name. - std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); - nodes_to_visit.push(tensor_name_parts[0]); + for (const string& output_tensor_name : outputs) { + nodes_to_visit.push(GetNodeNameFromTensorName(output_tensor_name)); } // We do a traversal backwards from the outputs specified in the MetaGraphDef. 
while (!nodes_to_visit.empty()) { @@ -100,8 +108,8 @@ void GetReachableNodesAndVariables( if (kVariableTypes->find(node->op()) != kVariableTypes->end()) { variable_node_names->insert(node->name()); } - for (const string& input : node->input()) { - nodes_to_visit.push(input); + for (const string& input_tensor_name : node->input()) { + nodes_to_visit.push(GetNodeNameFromTensorName(input_tensor_name)); } } } diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index cd35fd3b95d..979b23c3fc5 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -351,6 +351,56 @@ TEST_F(FreezeTest, GraphDefWithNoVariables) { GraphDefEqual(frozen_graph_def, graph_def); } +TEST_F(FreezeTest, GraphDefWithMultiOutputOperation) { + // Tensors from operations with multiple outputs get tensor suffixes when used + // in input fields of following nodes, i.e. split:0, split:1. + // Test that we traverse those correctly. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), {10.0f, 10.0f}, {2}); + Output axis = ops::Const(scope.WithOpName("axis"), 0, {}); + OutputList split = ops::Split(scope.WithOpName("split"), axis, a, 2).output; + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), split[1], b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + +TEST_F(FreezeTest, GraphDefWithControlDependency) { + // Inputs that are control dependencies get tensor prefixes, + // i.e. ^control_dependency. + // Test that we traverse those correctly. 
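// Standalone illustration, not part of the patch: the GetNodeNameFromTensorName
// helper added to freeze_saved_model.cc maps both control-dependency inputs
// ("^node") and multi-output tensor names ("node:1") back to plain node names,
// which is what the two new freeze_saved_model tests exercise.
// NodeNameFromTensorName below is a hypothetical mirror of that helper.
#include <cassert>
#include <string>

std::string NodeNameFromTensorName(std::string tensor_name) {
  if (!tensor_name.empty() && tensor_name[0] == '^') {
    tensor_name.erase(0, 1);  // drop the control-dependency marker
  }
  return tensor_name.substr(0, tensor_name.find(':'));  // drop ":<output>"
}

int main() {
  assert(NodeNameFromTensorName("^source") == "source");
  assert(NodeNameFromTensorName("split:1") == "split");
  assert(NodeNameFromTensorName("c") == "c");
  return 0;
}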
+ SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output source = ops::Const(scope.WithOpName("source"), 10.0f, {}); + Output a = ops::Const(scope.WithOpName("a").WithControlDependencies(source), + {10.0f, 10.0f}, {2}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + TEST_F(FreezeTest, GraphDefWithoutDependentVariables) { TestFreezeGraphWithoutDependentVariables(false); } diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 19e6bf68e77..2119c8ec47f 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -214,7 +214,6 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@llvm//:core", - "@llvm//:execution_engine", "@llvm//:support", "@llvm//:target", ], diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 2cae85e8965..0025842aead 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -333,6 +333,20 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, R"(#include "tensorflow/compiler/xla/xla_data.pb.h")" : ""; + const string include_hlo_profile_printer_data_proto = + opts.gen_hlo_profile_printer_data + ? R"(#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h")" + : ""; + + // When HLO profiling is disabled we only forward declare the + // HloProfilePrinter protobuf. So we can only conditionally emit this code + // calling HloProfilePrinter::profile_counters_size. + const string assign_profile_counters_size = + opts.gen_hlo_profile_printer_data + ? "data->profile_counters_size = " + "data->hlo_profile_printer_data->profile_counters_size();" + : ""; + // Use a poor-man's text templating mechanism; first populate the full header // with placeholder tokens, and then rewrite the tokens with real values. *header = @@ -348,6 +362,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, #define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) {{INCLUDE_XLA_DATA_PROTO}} +{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}} #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" @@ -418,6 +433,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); + data->hlo_profile_printer_data = StaticHloProfilePrinterData(); + {{ASSIGN_PROFILE_COUNTERS_SIZE}} return data; }(); return *kStaticData; @@ -487,6 +504,13 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; return kShape; } + + // Metadata that can be used to pretty-print profile counters. 
+ static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() { + static const xla::HloProfilePrinterData* kHloProfilePrinterData = + {{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}; + return kHloProfilePrinterData; + } }; {{NS_END}} @@ -501,35 +525,41 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, + {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, + {"{{DECLS_FROM_OBJ_FILE}}", + str_util::Join(metadata_result.header_variable_decls, "\n")}, {"{{ENTRY}}", compile_result.entry_point}, + {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}", + metadata_result.hlo_profile_printer_data_access_shim}, {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, + {"{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}", + include_hlo_profile_printer_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, + {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", + metadata_result.program_shape_access_shim}, {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, - {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}, - {"{{DECLS_FROM_OBJ_FILE}}", - str_util::Join(metadata_result.header_variable_decls, "\n")}, - {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", - metadata_result.program_shape_access_shim}}; + {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}}; str_util::ReplaceAllPairs(header, rewrites); return Status::OK(); } -static string CreateUniqueIdentifierForProgramShape(const CodegenOpts& opts) { +static string CreateUniqueIdentifier(const CodegenOpts& opts, + StringPiece suffix) { string result = "__tfcompile"; for (const string& n : opts.namespaces) { strings::StrAppend(&result, "_", n); } - strings::StrAppend(&result, "_", opts.class_name, "_ProgramShape"); + strings::StrAppend(&result, "_", opts.class_name, "_", suffix); return result; } @@ -550,18 +580,31 @@ Status GenerateMetadata(const CodegenOpts& opts, // When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffer gives // a shim that evaluates to nullptr, which is what we want. 
+ ProtobufToEmbed program_shape_protobuf{ + CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape", + program_shape.get()}; + + ProtobufToEmbed hlo_profile_printer_data_protobuf{ + CreateUniqueIdentifier(opts, "HloProfilePrinterData"), + "xla::HloProfilePrinterData", + compile_result.aot->hlo_profile_printer_data()}; + TF_ASSIGN_OR_RETURN( - EmbeddedProtocolBuffer embedded_program_shape, - CreateEmbeddedProtocolBuffer(opts.target_triple, - CreateUniqueIdentifierForProgramShape(opts), - "xla::ProgramShape", program_shape.get())); + EmbeddedProtocolBuffers embedded_protobufs, + CreateEmbeddedProtocolBuffers( + opts.target_triple, + {program_shape_protobuf, hlo_profile_printer_data_protobuf})); metadata_result->program_shape_access_shim = - std::move(embedded_program_shape.cpp_shim_expression); + std::move(embedded_protobufs.cpp_shims[0].expression); + metadata_result->hlo_profile_printer_data_access_shim = + std::move(embedded_protobufs.cpp_shims[1].expression); metadata_result->header_variable_decls.emplace_back( - std::move(embedded_program_shape.cpp_variable_decl)); + std::move(embedded_protobufs.cpp_shims[0].variable_decl)); + metadata_result->header_variable_decls.emplace_back( + std::move(embedded_protobufs.cpp_shims[1].variable_decl)); metadata_result->object_file_data = - std::move(embedded_program_shape.object_file_data); + std::move(embedded_protobufs.object_file_data); return Status::OK(); } diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 3430b1f96cf..83f2d3ee11d 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -44,6 +44,10 @@ struct CodegenOpts { // If true, generate program shape data for the ProgramShape method. bool gen_program_shape = false; + + // If true, emit a serialized HloProfilePrinterData protobuf that can be used + // to pretty print HLO profile counters. + bool gen_hlo_profile_printer_data = false; }; // Describes a generated metadata object file. @@ -57,6 +61,12 @@ struct MetadataResult { // GenerateMetadata. string program_shape_access_shim; + // hlo_profile_printer_data_access_shim is a C++ expression that constructs + // the xla::HloProfilePrinterData instance for the CompileResult passed to + // GenerateMetadata. If the xla::HloProfilePrinterData is null then this is a + // C++ expression that evaluates to nullptr at runtime. + string hlo_profile_printer_data_access_shim; + // The contents of the object (".o") file. 
string object_file_data; }; diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 2642536c4f6..29bc9c13b88 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -172,7 +172,7 @@ TEST(CodegenTest, Golden) { fetch->set_name("myfetch"); CompileResult compile_result; compile_result.aot.reset( - new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5)); + new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index ac3b5873318..6641d45e830 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -10,6 +10,7 @@ #define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard) #include "tensorflow/compiler/xla/xla_data.pb.h" + #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" @@ -23,6 +24,7 @@ extern "C" void entry_point( extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[]; + namespace foo { namespace bar { @@ -54,9 +56,9 @@ namespace bar { // // Memory stats: // arg bytes total: 104 -// arg bytes aligned: 128 +// arg bytes aligned: 192 // temp bytes total: 126 -// temp bytes aligned: 224 +// temp bytes aligned: 320 class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. @@ -82,6 +84,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); + data->hlo_profile_printer_data = StaticHloProfilePrinterData(); + return data; }(); return *kStaticData; @@ -243,6 +247,13 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { }(); return kShape; } + + // Metadata that can be used to pretty-print profile counters. + static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() { + static const xla::HloProfilePrinterData* kHloProfilePrinterData = + nullptr; + return kHloProfilePrinterData; + } }; } // end namespace bar diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 7c833878818..bbc35da2ef6 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -44,7 +44,7 @@ namespace { // Compiles the XLA computation into executable code. Status CompileXla(xla::CompileOnlyClient* client, - const xla::Computation& computation, + const xla::XlaComputation& computation, const xla::cpu::CpuAotCompilationOptions& aot_opts, CompileResult* compile_result) { // Retrieves arg and result layouts from the computation. @@ -62,7 +62,7 @@ Status CompileXla(xla::CompileOnlyClient* client, for (int i = 0; i < pshape->parameters_size(); ++i) { arg_layouts.push_back(pshape->mutable_parameters(i)); } - xla::CompileOnlyClient::AotComputationInstance instance; + xla::CompileOnlyClient::AotXlaComputationInstance instance; instance.computation = &computation; instance.argument_layouts = std::move(arg_layouts); instance.result_layout = &pshape->result(); @@ -88,20 +88,19 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, // Converts the graph into an XLA computation, and compiles the // computation. 
// TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client? - namespace gpu = perftools::gputools; - gpu::Platform* cpu_platform = - gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie(); + se::Platform* cpu_platform = + se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie(); xla::CompileOnlyClient* client = xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) .ValueOrDie(); - xla::Computation computation; + xla::XlaComputation computation; TF_RETURN_IF_ERROR( ConvertGraphDefToXla(graph_def, config, client, &computation)); if (!flags.out_session_module.empty()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr module, + TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); - // Serialize the SessionModule deterministically so that all the outputs of - // a tf_library genrule are deterministic. + // Serialize the HloSnapshot deterministically so that all the outputs of a + // tf_library genrule are deterministic. string proto; TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto)); TF_RETURN_IF_ERROR( @@ -111,6 +110,7 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, flags.target_triple, flags.target_cpu, flags.target_features, flags.entry_point, xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic); + return CompileXla(client, computation, aot_opts, compile_result); } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index 6489929a576..63d22de1ca4 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "llvm/ADT/Triple.h" -#include "llvm/ExecutionEngine/ObjectMemoryBuffer.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" @@ -37,9 +36,8 @@ namespace tfcompile { using xla::llvm_ir::AsStringRef; -static std::unique_ptr CreateModuleWithEmbeddedProtocolBuffer( - llvm::LLVMContext* llvm_context, llvm::TargetMachine* target_machine, - const ::tensorflow::protobuf::MessageLite& proto, +static void AddEmbeddedProtocolBufferToLlvmModule( + llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto, StringPiece unique_identifier, string* protobuf_array_symbol_name, int64* protobuf_array_size) { string protobuf_array_contents = proto.SerializeAsString(); @@ -47,19 +45,14 @@ static std::unique_ptr CreateModuleWithEmbeddedProtocolBuffer( strings::StrCat(unique_identifier, "_protobuf_array_contents"); *protobuf_array_size = protobuf_array_contents.size(); - std::unique_ptr module = - MakeUnique("embedded_data_module", *llvm_context); - llvm::Constant* protobuf_array_initializer = - llvm::ConstantDataArray::getString(*llvm_context, + llvm::ConstantDataArray::getString(module->getContext(), AsStringRef(protobuf_array_contents), /*AddNull=*/false); new llvm::GlobalVariable( *module, protobuf_array_initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name)); - - return module; } static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, @@ -116,42 +109,44 @@ GetTargetMachineFromTriple(StringPiece target_triple) { /*Features=*/"", llvm::TargetOptions(), llvm::None)); } -StatusOr CreateEmbeddedProtocolBuffer( - StringPiece target_triple, StringPiece symbol_prefix, - StringPiece qualified_cpp_protobuf_name, - const 
::tensorflow::protobuf::MessageLite* proto) { +StatusOr CreateEmbeddedProtocolBuffers( + StringPiece target_triple, + gtl::ArraySlice protobufs_to_embed) { TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, GetTargetMachineFromTriple(target_triple)); llvm::LLVMContext llvm_context; - string object_file, cpp_shim, cpp_variable_decl; + std::unique_ptr module_with_serialized_proto = + MakeUnique("embedded_data_module", llvm_context); - if (proto) { - string protobuf_array_symbol_name; - int64 protobuf_array_size; + EmbeddedProtocolBuffers result; - std::unique_ptr module_with_serialized_proto = - CreateModuleWithEmbeddedProtocolBuffer( - &llvm_context, target_machine.get(), *proto, symbol_prefix, - &protobuf_array_symbol_name, &protobuf_array_size); - TF_ASSIGN_OR_RETURN(object_file, - CodegenModule(target_machine.get(), - std::move(module_with_serialized_proto))); - cpp_shim = CreateCPPShimExpression(qualified_cpp_protobuf_name, - protobuf_array_symbol_name, - protobuf_array_size); + for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) { + string cpp_shim, cpp_variable_decl; + if (protobuf_to_embed.message) { + string protobuf_array_symbol_name; + int64 protobuf_array_size; - cpp_variable_decl = strings::StrCat("extern \"C\" char ", - protobuf_array_symbol_name, "[];"); - } else { - TF_ASSIGN_OR_RETURN( - object_file, - CodegenModule(target_machine.get(), - MakeUnique("empty_module", llvm_context))); - cpp_shim = "nullptr"; + AddEmbeddedProtocolBufferToLlvmModule( + module_with_serialized_proto.get(), *protobuf_to_embed.message, + protobuf_to_embed.symbol_prefix, &protobuf_array_symbol_name, + &protobuf_array_size); + cpp_shim = CreateCPPShimExpression( + protobuf_to_embed.qualified_cpp_protobuf_name, + protobuf_array_symbol_name, protobuf_array_size); + + cpp_variable_decl = strings::StrCat("extern \"C\" char ", + protobuf_array_symbol_name, "[];"); + } else { + cpp_shim = "nullptr"; + } + result.cpp_shims.push_back({cpp_shim, cpp_variable_decl}); } - return {{cpp_shim, cpp_variable_decl, object_file}}; + TF_ASSIGN_OR_RETURN(result.object_file_data, + CodegenModule(target_machine.get(), + std::move(module_with_serialized_proto))); + return result; } } // namespace tfcompile diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index 8436e0ff67f..ebfe4806c20 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -21,51 +21,70 @@ limitations under the License. #define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { namespace tfcompile { using xla::StatusOr; -// Represents a protocol buffer embedded into an object file and describes a way -// to access it at runtime. -struct EmbeddedProtocolBuffer { - // cpp_shim_expression is a C++ expression that creates an instance of said - // protocol buffer when executed. - string cpp_shim_expression; +// Represents a set of protocol buffers embedded into an object file and +// describes how to access them at runtime. +struct EmbeddedProtocolBuffers { + // Each instance CPPShim describes how to generate C++ code to instantiate a + // protobuf instance from the corresponding static data emitted into the + // object file. 
+ struct CPPShim { + // `expression` is a C++ expression that creates an instance of said + // protocol buffer when executed. + string expression; - // cpp_variable_decl is an "extern C" array declaration that is used in - // cpp_shim_expression. It must be visible wherever cpp_shim_expression is - // emitted. - string cpp_variable_decl; + // `variable_decl` is an "extern C" array declaration that is used in + // `expression`. It must be visible wherever `expression` is emitted. + string variable_decl; + }; - // The contents of the object (".o") file the protocol buffer is embbed in. - // This needs to be linked in to any program that wants to execute - // cpp_variable_decl . + // Each cpp_shim corresponds to one embedded protocol buffer. + std::vector cpp_shims; + + // The contents of the object (".o") file the protocol buffers are embbed in. + // This needs to be linked in to any program that wants to execute any of the + // expressions in `cpp_shims`. string object_file_data; }; -// Creates an object file that contains `proto`. -// -// `proto` is allowed to be nullptr, in which case the generated C++ shim -// expression is just `nullptr`, and the generated object file does not define -// any symbols. +// Describes a protocol buffer to embed into an object file. +struct ProtobufToEmbed { + // `symbol_prefix` is prefix that is guaranteed to be unique across the binary + // or DSO the generated object file will be linked into. + string symbol_prefix; + + // `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++ + // namespace qualified) protocol buffer name. This is only used in + // CPPShim::expression so relatively qualified names are fine as long as + // they're valid wherever CPPShim::expression is emitted. + string qualified_cpp_protobuf_name; + + // `message` is the protocol buffer to be embedded. It is allowed to be + // nullptr, in which case the generated C++ shim expression is just `nullptr`, + // and the generated object file does not define any symbols. + const ::tensorflow::protobuf::MessageLite* message; +}; + +// Embeds a a sequence of protocol buffers into an object file. // // `target_triple` is the target triple for the target architecture for the // generated object file. // -// `symbol_prefix` is prefix that is guaranteed to be unique across the binary -// or DSO the generated object file will be linked into. -// -// `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++ -// namespace qualified) protocol buffer name. This needs is only used in -// EmbeddedProtocolBuffer::cpp_shim_expression so relatively qualified -// names are fine as long as they're valid wherever cpp_shim_expression -// is emitted. -StatusOr CreateEmbeddedProtocolBuffer( - StringPiece target_triple, StringPiece symbol_prefix, - StringPiece qualified_cpp_protobuf_name, - const ::tensorflow::protobuf::MessageLite* proto); +// `protobufs_to_embed` describes the protocol buffers to embed into the +// resulting object file. The C++ shim for protobufs_to_embed[i] is +// cpp_shims[i] in the returned EmbeddedProtocolBuffers instance. The contents +// of all the protocol buffers are embedded into a single .o file whose content +// is stored in the object_file_data field in the returned +// EmbeddedProtocolBuffers instance. 
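To make the shim contract above concrete, here is a hand-written sketch of how an embedded array and its CPPShim could be consumed. The symbol name, the message type, and the array size below are illustrative assumptions; the real expression is whatever CreateCPPShimExpression emits for the ProtobufToEmbed in question.

#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"

// Hypothetical symbol emitted into the object file for one embedded message;
// this is the `variable_decl` part of the shim.
extern "C" char my_config_protobuf_array_contents[];

// Rough shape of the `expression` part: rebuild the message from the bytes
// that were serialized into the object file. kArraySize stands in for the
// protobuf_array_size recorded at embedding time.
tensorflow::tf2xla::Config* LoadEmbeddedConfig() {
  constexpr int kArraySize = 128;  // illustrative value
  auto* config = new tensorflow::tf2xla::Config;
  if (!config->ParseFromArray(my_config_protobuf_array_contents, kArraySize)) {
    delete config;
    return nullptr;
  }
  return config;
}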
+StatusOr CreateEmbeddedProtocolBuffers( + StringPiece target_triple, + gtl::ArraySlice protobufs_to_embed); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/aot/runtime.h index d085864f001..d1a669ceb17 100644 --- a/tensorflow/compiler/aot/runtime.h +++ b/tensorflow/compiler/aot/runtime.h @@ -25,8 +25,8 @@ namespace tensorflow { namespace tfcompile { namespace runtime { -// Align to 32-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. -static constexpr size_t kAlign = 32; +// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment. +static constexpr size_t kAlign = 64; // aligned_buffer_bytes returns the sum of each size in `sizes`, skipping -1 // values. There are `n` entries in `sizes`. Each buffer is aligned to kAlign diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc index 6d603a02eb4..06ec623eb2d 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/aot/runtime_test.cc @@ -24,7 +24,7 @@ namespace runtime { namespace { TEST(Runtime, AlignmentValue) { - // We've chosen 32 byte alignment for the tfcompile runtime to mimic the + // We've chosen 64 byte alignment for the tfcompile runtime to mimic the // regular tensorflow allocator, which was chosen to play nicely with Eigen. // The tfcompile runtime also has a requirement that comes from the xla // generated code, on the relation: buffer_size >= 16 ? 2 * sizeof(void*) : 8 @@ -39,13 +39,13 @@ TEST(Runtime, AlignedBufferBytes) { EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0); static constexpr intptr_t sizesB[1] = {3}; - EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 32); + EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 64); static constexpr intptr_t sizesC[1] = {32}; - EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 32); + EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 64); static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3}; - EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 192); + EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 320); } void* add_ptr(void* base, uintptr_t delta) { @@ -101,11 +101,11 @@ TEST(Runtime, MallocFreeContiguousBuffers) { EXPECT_NE(base, nullptr); EXPECT_EQ(bufD[0], add_ptr(base, 0)); EXPECT_EQ(bufD[1], nullptr); - EXPECT_EQ(bufD[2], add_ptr(base, 32)); + EXPECT_EQ(bufD[2], add_ptr(base, 64)); EXPECT_EQ(bufD[3], nullptr); - EXPECT_EQ(bufD[4], add_ptr(base, 64)); - EXPECT_EQ(bufD[5], add_ptr(base, 128)); - EXPECT_EQ(bufD[6], add_ptr(base, 160)); + EXPECT_EQ(bufD[4], add_ptr(base, 128)); + EXPECT_EQ(bufD[5], add_ptr(base, 192)); + EXPECT_EQ(bufD[6], add_ptr(base, 256)); for (int i = 0; i < 7; ++i) { const intptr_t size = sizesD[i]; if (size != -1) { diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc index 47ef5f82cbc..6b098049cbd 100644 --- a/tensorflow/compiler/aot/test.cc +++ b/tensorflow/compiler/aot/test.cc @@ -35,6 +35,7 @@ limitations under the License. 
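The updated expectations in runtime_test.cc above all follow from rounding each non-negative size up to the new 64-byte boundary and skipping -1 entries. A minimal standalone sketch of that arithmetic (an illustration of what the test expects, not the actual tfcompile runtime code):

#include <cstddef>
#include <cstdint>
#include <iostream>

constexpr intptr_t kAlign = 64;

// Round each non-negative size up to a 64-byte boundary and sum the results,
// skipping entries of -1.
intptr_t AlignedBufferBytes(const intptr_t* sizes, size_t n) {
  intptr_t total = 0;
  for (size_t i = 0; i < n; ++i) {
    if (sizes[i] == -1) continue;
    total += ((sizes[i] + kAlign - 1) / kAlign) * kAlign;
  }
  return total;
}

int main() {
  const intptr_t sizes[7] = {1, -1, 32, -1, 64, 2, 3};
  std::cout << AlignedBufferBytes(sizes, 7) << "\n";  // prints 320
  return 0;
}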
// clang-format on #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index bb73cb19c57..fd2cf2b67d4 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -15,6 +15,7 @@ test_suite( ":test_graph_tfadd_with_ckpt_saver_test", ":test_graph_tfadd_with_ckpt_test", ":test_graph_tfassert_eq_test", + ":test_graph_tfcond_test", ":test_graph_tffunction_test", ":test_graph_tfgather_test", ":test_graph_tfmatmul_test", @@ -55,6 +56,7 @@ genrule( "test_graph_tfadd_with_ckpt_saver.pb", "test_graph_tfadd_with_ckpt_saver.saver", "test_graph_tfassert_eq.pb", + "test_graph_tfcond.pb", "test_graph_tffunction.pb", "test_graph_tfgather.pb", "test_graph_tfmatmul.pb", @@ -118,6 +120,17 @@ tf_library( ], ) +tf_library( + name = "test_graph_tfcond", + testonly = 1, + config = "test_graph_tfcond.config.pbtxt", + cpp_class = "CondComp", + graph = "test_graph_tfcond.pb", + tags = [ + "manual", + ], +) + tf_library( name = "test_graph_tffunction", testonly = 1, @@ -163,6 +176,15 @@ tf_library( tfcompile_flags = "--gen_name_to_index --gen_program_shape", ) +tf_library( + name = "test_graph_tfmatmulandadd_with_profiling", + testonly = 1, + config = "test_graph_tfmatmulandadd.config.pbtxt", + cpp_class = "MatMulAndAddCompWithProfiling", + enable_xla_hlo_profiling = True, + graph = "test_graph_tfmatmulandadd.pb", +) + tf_library( name = "test_graph_tfsplits", testonly = 1, @@ -185,13 +207,18 @@ tf_cc_test( ":test_graph_tfadd_with_ckpt", ":test_graph_tfadd_with_ckpt_saver", ":test_graph_tfassert_eq", + ":test_graph_tfcond", ":test_graph_tffunction", ":test_graph_tfgather", ":test_graph_tfmatmul", ":test_graph_tfmatmulandadd", + ":test_graph_tfmatmulandadd_with_profiling", ":test_graph_tfsplits", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_profile_printer", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 67767f55dae..9ec7df163b1 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -78,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir): f.write(saver.as_saver_def().SerializeToString()) +def tfassert_eq(_): + x = array_ops.placeholder(dtypes.int32, name='x_hold') + y = array_ops.placeholder(dtypes.int32, name='y_hold') + control_flow_ops.Assert( + math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq') + math_ops.add(x, math_ops.negative(y), name='x_y_diff') + + +def tfcond(_): + p = array_ops.placeholder(dtypes.bool, name='p_hold') + x = array_ops.placeholder(dtypes.int32, name='x_hold') + y = array_ops.placeholder(dtypes.int32, name='y_hold') + z = control_flow_ops.cond(p, lambda: x, lambda: y) + array_ops.identity(z, name='result') + + def tfgather(_): params = array_ops.placeholder(dtypes.float32, name='params') indices = array_ops.placeholder(dtypes.int32, name='indices') @@ -126,14 +142,6 @@ def tfsplits(_): array_ops.identity(y, name='result') -def tfassert_eq(_): - x = array_ops.placeholder(dtypes.int32, name='x_hold') - y = 
array_ops.placeholder(dtypes.int32, name='y_hold') - control_flow_ops.Assert( - math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq') - math_ops.add(x, math_ops.negative(y), name='x_y_diff') - - def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -148,12 +156,13 @@ def main(_): write_graph(tfadd, FLAGS.out_dir) write_graph(tfadd_with_ckpt, FLAGS.out_dir) write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir) + write_graph(tfassert_eq, FLAGS.out_dir) + write_graph(tfcond, FLAGS.out_dir) + write_graph(tffunction, FLAGS.out_dir) write_graph(tfgather, FLAGS.out_dir) write_graph(tfmatmul, FLAGS.out_dir) write_graph(tfmatmulandadd, FLAGS.out_dir) - write_graph(tffunction, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) - write_graph(tfassert_eq, FLAGS.out_dir) if __name__ == '__main__': diff --git a/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt new file mode 100644 index 00000000000..94a01ad4abf --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt @@ -0,0 +1,20 @@ +# Text form of tensorflow.tf2xla.Config proto. +feed { + id { node_name: "p_hold" } + shape {} +} +feed { + id { node_name: "x_hold" } + shape { + dim { size: 1 } + } +} +feed { + id { node_name: "y_hold" } + shape { + dim { size: 1 } + } +} +fetch { + id { node_name: "result" } +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 67dbd643bfc..fee46280e9a 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -21,19 +21,27 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h" #include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h" #include "tensorflow/compiler/aot/tests/test_graph_tffunction.h" #include "tensorflow/compiler/aot/tests/test_graph_tfgather.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" +#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" +#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { namespace tfcompile { namespace { +using ::testing::HasSubstr; +using ::testing::IsSupersetOf; + TEST(TFCompileTest, Add) { AddComp add; EXPECT_EQ(add.arg0_data(), add.args()[0]); @@ -143,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) { EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); } +TEST(TFCompileTest, Cond) { + CondComp cond; + EXPECT_EQ(cond.arg0_data(), cond.args()[0]); + EXPECT_EQ(cond.arg1_data(), cond.args()[1]); + EXPECT_EQ(cond.arg2_data(), cond.args()[2]); + cond.arg1() = 10; + cond.arg2() = 20; + { + cond.arg0() = true; + const int32 expected_result = cond.arg1(); + EXPECT_TRUE(cond.Run()); + EXPECT_EQ(cond.result0(), expected_result); + EXPECT_EQ(cond.result0_data()[0], expected_result); + EXPECT_EQ(cond.result0_data(), cond.results()[0]); + } + { + cond.arg0() = 
false; + const int32 expected_result = cond.arg2(); + EXPECT_TRUE(cond.Run()); + EXPECT_EQ(cond.result0(), expected_result); + EXPECT_EQ(cond.result0_data()[0], expected_result); + EXPECT_EQ(cond.result0_data(), cond.results()[0]); + } +} + TEST(TFCompileTest, Gather) { GatherComp gather; EXPECT_EQ(gather.arg0_data(), gather.args()[0]); @@ -484,6 +517,56 @@ TEST(TFCompileTest, ProgramShape) { EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2)); } +TEST(TFCompileTest, HloProfiling) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + MatMulAndAddCompWithProfiling fn; + ASSERT_TRUE(fn.hlo_profiling_enabled()); + + fn.set_thread_pool(&device); + + // x = [[1, 2], [3, 4]] + fn.arg0(0, 0) = 1; + fn.arg0(0, 1) = 2; + fn.arg0(1, 0) = 3; + fn.arg0(1, 1) = 4; + + // y = [[10, 20], [30, 40]] + fn.arg1(0, 0) = 10; + fn.arg1(0, 1) = 20; + fn.arg1(1, 0) = 30; + fn.arg1(1, 1) = 40; + + EXPECT_TRUE(fn.Run()); + + string hlo_profile_as_string = + xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(), + /*clock_rate_ghz=*/1.0); + VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string; + + std::vector hlo_profile_lines = + tensorflow::str_util::Split(hlo_profile_as_string, '\n'); + + auto header = HasSubstr("Execution profile for"); + auto total_cycles_profile_line = HasSubstr("[total]"); + auto dot_profile_line = HasSubstr( + "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%arg1.0.1)"); + auto add_profile_line = HasSubstr( + "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%arg1.0.1)"); + auto tuple_profile_line = HasSubstr( + "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " + "%dot.0.4, f32[2,2]{1,0} %add.0.6)"); + auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); + auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); + + EXPECT_THAT(hlo_profile_lines, + IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, + add_profile_line, tuple_profile_line})); +} + } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 3a877c5337f..5c57fee326c 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -25,7 +25,8 @@ def tf_library(name, graph, config, visibility=None, testonly=None, tfcompile_flags=None, tfcompile_tool="//tensorflow/compiler/aot:tfcompile", - include_standard_runtime_deps=True, deps=None, tags=None): + include_standard_runtime_deps=True, + enable_xla_hlo_profiling=False, deps=None, tags=None): """Runs tfcompile to compile a TensorFlow graph into executable code. Given an invocation of tf_library(name="foo", ...), generates the following @@ -68,6 +69,8 @@ def tf_library(name, graph, config, include_standard_runtime_deps: If True, the standard list of kernel/runtime deps is added to deps. If False, deps must contain the full set of deps needed by the generated library. + enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program, + and emit metadata that lets us pretty-print the gathered profile counters. deps: a list of deps to include on the build rules for the generated library, added to the standard deps if standard_runtime_deps is True. tags: tags to apply to subsidiary build rules. 
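The HloProfiling test above boils down to splitting the printed profile into lines and checking that every expected marker appears somewhere. A standalone sketch of that check, with an invented two-line profile standing in for the real PrintHloProfile output:

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Returns true if every expected marker appears on some line of `profile`.
bool ProfileContains(const std::string& profile,
                     const std::vector<std::string>& expected_markers) {
  std::vector<std::string> lines;
  std::stringstream ss(profile);
  std::string line;
  while (std::getline(ss, line)) lines.push_back(line);

  for (const auto& marker : expected_markers) {
    bool found = false;
    for (const auto& l : lines) {
      if (l.find(marker) != std::string::npos) {
        found = true;
        break;
      }
    }
    if (!found) return false;
  }
  return true;
}

int main() {
  const std::string fake_profile =
      "Execution profile for an invented computation\n"
      "[total] 42 cycles\n";
  std::cout << ProfileContains(fake_profile,
                               {"Execution profile for", "[total]"})
            << "\n";  // prints 1
  return 0;
}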
@@ -137,6 +140,10 @@ def tf_library(name, graph, config, flags = tfcompile_flags else: flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])]) + if enable_xla_hlo_profiling: + profiling_flag = "--xla_hlo_profile" + else: + profiling_flag = "" native.genrule( name=("gen_" + name), srcs=[ @@ -157,7 +164,7 @@ def tf_library(name, graph, config, " --out_header=$(@D)/" + header_file + " --out_metadata_object=$(@D)/" + metadata_object_file + " --out_function_object=$(@D)/" + function_object_file + - " " + flags), + " " + flags + " " + profiling_flag), tools=[tfcompile_tool], visibility=visibility, testonly=testonly, @@ -220,6 +227,8 @@ def tf_library(name, graph, config, ] + (need_xla_data_proto and [ # If we're generating the program shape, we must depend on the proto. "//tensorflow/compiler/xla:xla_data_proto", + ] or []) + (enable_xla_hlo_profiling and [ + "//tensorflow/compiler/xla/service:hlo_profile_printer_data" ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually needed. "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 8ea014c2eed..839e1588b7b 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -100,6 +100,8 @@ Status Main(const MainFlags& flags) { if (flags.cpp_class.empty()) { return errors::InvalidArgument("Must specify --cpp_class"); } + codegen_opts.gen_hlo_profile_printer_data = + xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile(); TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name, &codegen_opts.namespaces)); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 6edeb7047f9..980e0eec9e2 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -124,7 +124,6 @@ cc_library( srcs = ["xla_tensor.cc"], hdrs = ["xla_tensor.h"], deps = [ - ":common", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:shaped_buffer", @@ -176,6 +175,7 @@ cc_library( "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:sendrecv_ops", @@ -216,6 +216,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:gpu_runtime", @@ -256,23 +257,11 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "graph_to_functiondef", - srcs = ["graph_to_functiondef.cc"], - hdrs = ["graph_to_functiondef.h"], - visibility = [":friends"], - deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], -) - cc_library( name = "create_xla_launch_op", srcs = [ "create_xla_launch_op.cc", + "create_xla_launch_op.h", ], deps = [ ":common", @@ -282,6 +271,27 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], + alwayslink = 1, +) + +tf_cc_test( + name = 
"create_xla_launch_op_test", + srcs = [ + "create_xla_launch_op.h", + "create_xla_launch_op_test.cc", + ], + deps = [ + ":create_xla_launch_op", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -299,7 +309,6 @@ cc_library( ], deps = [ ":common", - ":graph_to_functiondef", ":shape_inference_helpers", ":union_find", "//tensorflow/compiler/jit/graphcycles", @@ -318,6 +327,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:bounds_check", ], ) @@ -345,28 +355,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "graph_to_functiondef_test", - size = "small", - srcs = [ - "graph_to_functiondef_test.cc", - ], - deps = [ - ":graph_to_functiondef", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:ops", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ], -) - tf_cc_test( name = "compilation_passes_test", size = "small", @@ -377,7 +365,6 @@ tf_cc_test( deps = [ ":common", ":compilation_passes", - ":graph_to_functiondef", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", @@ -395,6 +382,31 @@ tf_cc_test( ], ) +tf_cc_test( + name = "xla_launch_util_test", + size = "small", + srcs = ["xla_launch_util_test.cc"], + deps = [ + ":common", + ":xla_compilation_cache", + ":xla_launch_util", + ":xla_tensor", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core/kernels:variable_ops", + ], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc index 9a2bb000752..b17ff589e25 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc @@ -40,7 +40,7 @@ static Status BuildLaunchNode( Graph* graph, Node** node) { NodeDef def; def.set_name(graph->NewName(nodename)); - def.set_op("_XlaLaunch"); + def.set_op("XlaLaunch"); def.set_device(device_name); AddNodeAttr("Tconstants", constant_dtypes, &def); AddNodeAttr("Targs", arg_dtypes, &def); @@ -79,7 +79,7 @@ static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { node->input_types().begin() + num_constant_args, node->input_types().begin() + num_constant_args + num_nonconst_args); - // Build a _XlaLaunch operator to execute the function body. + // Build a XlaLaunch operator to execute the function body. 
Node* launch_node; TF_RETURN_IF_ERROR(BuildLaunchNode( graph->NewName(node->name()), node->type_string(), node->def().attr(), diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 18d901323f1..731b8ebfdc6 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/jit/create_xla_launch_op.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/kernels/xla_launch_op.h" @@ -21,82 +22,194 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace { -// Givens a NodeDef 'ndef' and the function library runtime 'flr', if -// 'ndef' is a call to a compilable function defined in 'flr', returns OK -// and fills in 'kernel' with a XlaLaunchOp kernel which computes the -// node. Otherwise, returns a non-OK. +// Utility which searches for values in a sorted list by scanning over it once. +// No matter how many times ScanForValue is called, the list is scanned at most +// once. However, if a call to ScanForValue skips over a value, that value is +// not revisited in future calls to ScanForValue, so callers must take +// care to order their calls. // -// This routine is here so that FunctionLibraryRuntime can jit a -// specific function call as requested. -Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef, - std::unique_ptr* kernel) { +// Useful for merging multiple sorted lists in O(n) time. +class SinglePassSearch { + public: + // Creates a SinglePassSearch object that can be used to search in `values`. + // Does not take ownership of `values`. `values` must outlive this. + // `values` must be sorted. + explicit SinglePassSearch(const std::vector* values) + : current_index_(0), values_(values) {} + + // Scans forward in the vector looking for "value", updating the internal + // position in to the vector. + // Returns true iff the vector contains the given value at or after current + // position. + // Not thread-safe. + bool ScanForValue(int value) { + while (current_index_ < values_->size() && + (*values_)[current_index_] <= value) { + if ((*values_)[current_index_] == value) { + current_index_++; + return true; + } + current_index_++; + } + return false; + } + + private: + int current_index_; + const std::vector* values_; +}; + +Status CompilationRequested(const FunctionLibraryRuntime& flr, + const NodeDef& node_def) { bool xla_compile = false; - if (!flr->GetFunctionLibraryDefinition() - ->GetAttr(ndef, kXlaCompileAttr, &xla_compile) - .ok() || - !xla_compile) { - // Not marked as _XlaCompile=true. - return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op()); - } - // Make sure that kernels have been registered on the JIT device. - XlaOpRegistry::RegisterCompilationKernels(); - if (!IsCompilable(flr, ndef)) { - // ndef is calling a function that XLA can't compile. - return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString()); + // Check if op is marked _XlaCompile=true. 
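The ordering requirement in the SinglePassSearch comment above matters in practice: callers must probe values in ascending order so the underlying list is scanned only once. A small self-contained usage sketch, with a trimmed copy of the helper for illustration:

#include <iostream>
#include <vector>

// Trimmed copy of the single-pass search helper described above.
class SinglePassSearch {
 public:
  explicit SinglePassSearch(const std::vector<int>* values)
      : current_index_(0), values_(values) {}

  // Scans forward; `value` must be non-decreasing across calls.
  bool ScanForValue(int value) {
    while (current_index_ < values_->size() &&
           (*values_)[current_index_] <= value) {
      if ((*values_)[current_index_] == value) {
        current_index_++;
        return true;
      }
      current_index_++;
    }
    return false;
  }

 private:
  size_t current_index_;
  const std::vector<int>* values_;
};

int main() {
  const std::vector<int> constant_args = {0, 3};  // sorted argument indices
  SinglePassSearch search(&constant_args);
  for (int i = 0; i < 5; ++i) {
    // Because i only increases, the list is scanned at most once in total.
    std::cout << "arg " << i
              << (search.ScanForValue(i) ? " is" : " is not")
              << " a constant\n";
  }
  return 0;
}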
+ Status status = flr.GetFunctionLibraryDefinition()->GetAttr( + node_def, kXlaCompileAttr, &xla_compile); + if (!status.ok() || !xla_compile) { + if (VLOG_IS_ON(3)) { + if (!status.ok()) { + VLOG(3) << "No " << kXlaCompileAttr << " attr defined for " + << node_def.op() << ". status=" << status.ToString(); + } else { + VLOG(3) << node_def.op() << " is explicitly marked not to be compiled"; + } + } + return Status(error::INVALID_ARGUMENT, ""); } + return Status::OK(); +} + +// Given a FunctionLibraryRuntime and a NodeDef calling a function in the +// runtime, returns this function's body in `fbody` as well as the indices +// of its constant and resource arguments. +// `fbody` is owned by `flr`. +// `constant_arg_indices` and `resource_arg_indices` should be empty vector. +// They are sorted in ascending order on this function's return. +Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, + const NodeDef& node_def, + const FunctionBody** fbody, + std::vector* constant_arg_indices, + std::vector* resource_arg_indices) { FunctionLibraryRuntime::Handle handle; - // If ndef is not instantiable, e.g., the function does not exist, + // If node_def is not instantiable, e.g., the function does not exist, // simply bail out. TF_RETURN_IF_ERROR( - flr->Instantiate(ndef.op(), AttrSlice(&ndef.attr()), &handle)); - const FunctionBody* fbody = flr->GetFunctionBody(handle); - CHECK(fbody); // Can't be nullptr since we just instantiated it. - std::vector const_args(fbody->arg_types.size()); + flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle)); + *fbody = flr->GetFunctionBody(handle); + CHECK(*fbody); // Can't be nullptr since we just instantiated it. + const DataTypeVector& arg_types = (*fbody)->arg_types; + std::vector const_args(arg_types.size()); // If we can't analyze the const args. Bail out. - TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args)); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*((*fbody)->graph), &const_args)); for (int i = 0; i < const_args.size(); ++i) { if (const_args[i]) { - // There is a const arg. Bail out. - return errors::InvalidArgument("Const arg: ", i, " in ", - DebugString(fbody->fdef)); + constant_arg_indices->push_back(i); } } - NodeDef launch_def; - launch_def.set_name(ndef.name()); - launch_def.set_op("_XlaLaunch"); - launch_def.set_device(flr->device()->name()); - AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def); - AddNodeAttr("Nresources", 0, &launch_def); - AddNodeAttr("Targs", fbody->arg_types, &launch_def); - AddNodeAttr("Tresults", fbody->ret_types, &launch_def); - NameAttrList func; - func.set_name(ndef.op()); - *(func.mutable_attr()) = ndef.attr(); - AddNodeAttr("function", func, &launch_def); + // There can be hundreds of resource variables. Reserve the space for them. + // We don't reserve for constants above as they are usually few. + resource_arg_indices->reserve(arg_types.size()); + for (int i = 0; i < arg_types.size(); ++i) { + if (arg_types[i] == DT_RESOURCE) { + resource_arg_indices->push_back(i); + } + } - // TODO(b/32387911): Handles the host memory types across function - // calls properly. For now, we assume all inputs and outputs are on - // the device memory. 
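Outside of the function-library plumbing, the index bookkeeping described above is simple: walk the arguments once and collect the positions flagged by the constness analysis and the positions typed DT_RESOURCE. A standalone sketch with made-up inputs (not the real FunctionBody or BackwardsConstAnalysis output):

#include <iostream>
#include <vector>

enum class ArgType { kFloat, kResource };

int main() {
  // Hypothetical function signature: (float, resource, float, resource).
  const std::vector<ArgType> arg_types = {ArgType::kFloat, ArgType::kResource,
                                          ArgType::kFloat, ArgType::kResource};
  // Made-up constness analysis result: argument 2 must be compile-time known.
  const std::vector<bool> const_args = {false, false, true, false};

  std::vector<int> constant_arg_indices;
  std::vector<int> resource_arg_indices;
  for (int i = 0; i < static_cast<int>(arg_types.size()); ++i) {
    if (const_args[i]) constant_arg_indices.push_back(i);
    if (arg_types[i] == ArgType::kResource) resource_arg_indices.push_back(i);
  }

  // Both vectors come out sorted because the arguments were scanned in order.
  std::cout << "constants:";
  for (int i : constant_arg_indices) std::cout << " " << i;
  std::cout << "\nresources:";
  for (int i : resource_arg_indices) std::cout << " " << i;
  std::cout << "\n";
  return 0;
}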
+ return Status::OK(); +} + +} // namespace + +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr* kernel) { + TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def)); + + VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString(); + + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterCompilationKernels(); + if (!IsCompilable(flr, node_def)) { + // node_def is calling a function that XLA can't compile. + return errors::InvalidArgument("Not compilable: ", + node_def.ShortDebugString()); + } + + // Get function body, constant args, and resource args. + const FunctionBody* fbody = nullptr; + std::vector constant_arg_indices; + std::vector resource_arg_indices; + TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( + flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); + + // Set input and output memory types. MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); + // These indices are used only for optimization purposes. They allow us + // to loop over constant_arg_indices and resource_arg_indices only once + // while iterating over all the function arguments checking if it is a + // resource or a constant. + // The reason we optimized this code is because functions can have a lot of + // captured arguments. For example, the backward pass of ResNet50 takes in all + // 214 variables and a similar number of activations. + SinglePassSearch constants_search(&constant_arg_indices); + SinglePassSearch resources_search(&resource_arg_indices); + for (int i = 0; i < fbody->arg_types.size(); ++i) { + if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { + // Compile-time constants and resource handles are expected to be in + // host memory. + input_memory_types[i] = HOST_MEMORY; + } + } + // One might wonder, about the case where a compile-time constant argument + // (which must be in host memory) is also used as an input into an op, + // e.g. Add, that expects its inputs in device memory. Here is how it + // works now. + // First, what do we mean by "op expects an input in XYZ memory"? + // There are two types of "ops" here: the tf2xla kernel and the HLO + // computation it builds. The tf2xla kernel needs to retrieve the actual + // numeric value of the compile-time constant tensors, so it really expects + // them to be on in host memory. However, for other inputs, it refers to them + // using xla::ComputationDataHandle, which is just a symbolic handle that + // xla::ComputationBuilder assigns. How does this handle gets assigned for + // constant arguments? Even constant arguments get an _Arg node in the graph + // instatiated for Function compilation. The tf2xla kernel for constant _Arg + // nodes takes the constant value, converts it to XlaLiteral, and feeds it + // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This + // constant XlaLiteral is included in the HLO graph, and subsequently, in + // the actual executable, which is copied to the device before being + // executed. Thus, when this executable runs, the constant is available in + // device memory. + + // XlaLaunch kernel keeps all outputs (including constants, which it copies), + // in device memory MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + // Create the kernel. 
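Putting the two index lists together, the placement rule above reduces to: compile-time constants and resource handles go to host memory, everything else stays in device memory. A minimal standalone sketch of that assignment (using std::binary_search for brevity where the real code uses the single-pass helper; HOST/DEVICE are stand-ins for the real memory-type values):

#include <algorithm>
#include <iostream>
#include <vector>

enum class Memory { kDevice, kHost };

int main() {
  const int num_args = 4;
  // Hypothetical indices, in the spirit of the analysis sketched earlier.
  const std::vector<int> constant_arg_indices = {2};
  const std::vector<int> resource_arg_indices = {1, 3};

  std::vector<Memory> input_memory(num_args, Memory::kDevice);
  for (int i = 0; i < num_args; ++i) {
    const bool is_constant = std::binary_search(
        constant_arg_indices.begin(), constant_arg_indices.end(), i);
    const bool is_resource = std::binary_search(
        resource_arg_indices.begin(), resource_arg_indices.end(), i);
    // Constants and resource handles are expected to live in host memory.
    if (is_constant || is_resource) input_memory[i] = Memory::kHost;
  }

  for (int i = 0; i < num_args; ++i) {
    std::cout << "arg " << i << ": "
              << (input_memory[i] == Memory::kHost ? "HOST" : "DEVICE")
              << "\n";
  }
  return 0;
}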
+ NameAttrList function; + function.set_name(node_def.op()); + *(function.mutable_attr()) = node_def.attr(); + Device* dev = flr->device(); Status s; OpKernelConstruction construction( DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), &launch_def, + dev->GetAllocator(AllocatorAttributes()), &node_def, &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, flr->graph_def_version(), &s); - kernel->reset(new XlaLocalLaunchOp(&construction)); + + *kernel = MakeUnique(&construction, constant_arg_indices, + resource_arg_indices, function); return s; } +namespace { + bool RegisterLaunchOpCreator() { RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp); return true; diff --git a/tensorflow/compiler/jit/create_xla_launch_op.h b/tensorflow/compiler/jit/create_xla_launch_op.h new file mode 100644 index 00000000000..98a22e35153 --- /dev/null +++ b/tensorflow/compiler/jit/create_xla_launch_op.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ +#define TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +class FunctionLibraryRuntime; +class OpKernel; + +// Given a NodeDef 'node_def' and the function library runtime 'flr', if +// 'node_def' is a call to a compilable function defined in 'flr', returns OK +// and fills in 'kernel' with a XlaLaunchOp kernel which computes the +// node. Otherwise, returns a non-OK. +Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def, + std::unique_ptr* kernel); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_CREATE_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/create_xla_launch_op_test.cc b/tensorflow/compiler/jit/create_xla_launch_op_test.cc new file mode 100644 index 00000000000..b75ab486b80 --- /dev/null +++ b/tensorflow/compiler/jit/create_xla_launch_op_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/create_xla_launch_op.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +NodeDef ToNodeDef(const string& text) { + NodeDef node_def; + EXPECT_TRUE(protobuf::TextFormat::MergeFromString(text, &node_def)); + return node_def; +} + +// Create a FunctionDef that takes one resource and one regular param +FunctionDef XTimesY() { + return FunctionDefHelper::Define( + // Name + "XTimesY", + // Args + {"x: float", "y: resource"}, + // Return values + {"z: float"}, + // Attr def + {}, + // Nodes + { + {{"y0"}, "ReadVariableOp", {"y"}, {{"dtype", DT_FLOAT}}}, + {{"z"}, "Mul", {"x", "y0"}, {{"T", DT_FLOAT}}}, + }); +} + +class CreateXlaLaunchOpTest : public ::testing::Test { + protected: + void Init(const std::vector& flib) { + SessionOptions options; + auto* device_count = options.config.mutable_device_count(); + device_count->insert({"CPU", 1}); + TF_CHECK_OK(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices_)); + + FunctionDefLibrary proto; + for (const auto& fdef : flib) { + *(proto.add_function()) = fdef; + } + lib_def_ = + MakeUnique(OpRegistry::Global(), proto); + OptimizerOptions opts; + device_mgr_ = MakeUnique(devices_); + pflr_ = MakeUnique( + device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), + opts, /*default_thread_pool=*/nullptr, /*cluster_flr=*/nullptr); + flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0"); + } + + FunctionLibraryRuntime* flr_; + std::vector devices_; + std::unique_ptr device_mgr_; + std::unique_ptr lib_def_; + std::unique_ptr pflr_; + + std::unique_ptr kernel_; +}; + +AttrValue BoolAttr(bool b) { + AttrValue v; + v.set_b(b); + return v; +} + +TEST_F(CreateXlaLaunchOpTest, OneFloatOneResourceArgument) { + FunctionDef fdef = XTimesY(); + (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(true); + Init({fdef}); + + Status status = CreateXlaLaunchOp( + flr_, ToNodeDef(R"pb( + name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' + )pb"), &kernel_); + ASSERT_TRUE(status.ok()) << status.ToString(); + + EXPECT_EQ("XTimesY", kernel_->name()); + EXPECT_EQ("XTimesY", kernel_->type_string()); + + EXPECT_EQ(2, kernel_->num_inputs()); + EXPECT_EQ(DT_FLOAT, kernel_->input_type(0)); + EXPECT_EQ(DT_RESOURCE, kernel_->input_type(1)); + EXPECT_EQ(DEVICE_MEMORY, kernel_->input_memory_types()[0]); + EXPECT_EQ(HOST_MEMORY, kernel_->input_memory_types()[1]); + + EXPECT_EQ(1, kernel_->num_outputs()); + EXPECT_EQ(DT_FLOAT, kernel_->output_type(0)); + EXPECT_EQ(DEVICE_MEMORY, kernel_->output_memory_types()[0]); +} + +TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrNotSet) { + FunctionDef fdef = XTimesY(); + Init({fdef}); + + Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), &kernel_); + EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); +} + +TEST_F(CreateXlaLaunchOpTest, FailsIfXlaCompileAttrIsSetToFalse) { + FunctionDef fdef = 
XTimesY(); + (*fdef.mutable_attr())["_XlaCompile"] = BoolAttr(false); + Init({fdef}); + + Status status = CreateXlaLaunchOp(flr_, ToNodeDef(R"proto( + name: 'XTimesY' + op: 'XTimesY' + input: 'a' + input: 'b' + )proto"), &kernel_); + EXPECT_TRUE(errors::IsInvalidArgument(status)) << status.ToString(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 9465385b585..6d1e3325ebd 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/jit/graph_to_functiondef.h" +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" @@ -160,6 +161,11 @@ class Encapsulator { std::move(outside_compilation_attribute)), graph_in_(graph_in) {} + // Find dependencies between subgraphs and outside_compilation clusters that + // only manifest via edges between outside_compilation clusters in the outer + // (non-compiled) graph. + Status FindClusterDependencies(); + // Find subgraphs marked with 'group_attribute', and build a new // subgraph, one for each value of 'group_attribute'. Status SplitIntoSubgraphs(); @@ -230,6 +236,19 @@ class Encapsulator { // the shapes of any ancestor RAH outputs. If it can be determined that the // shape of the SFH inputs will not be inferrable even once the shapes of the // RAH outputs are known, an error is returned by the rewriter. + // + // Once edges between compiled and outside_compilation clusters have been + // replaced by send/recv ops, some dependencies may no longer be apparent. + // A clustering pass finds all the dependencies between HC nodes that are only + // present as a result of edges between nodes in outside_compilation clusters. + // Suppose there is a path from outside_compilation cluster C in subgraph S + // to outside_compilation cluster D in subgraph T. If S != T then a control + // edge is added from the call node for S to the call node for T, which + // ensures that C will execute before D because S executes before T. If S==T + // then a control dependency is added between the HC nodes for C and D in S, + // and the HC node for C is added to an 'ancestors' attr in the HC node for D + // so that during compilation of the HC node for D, an XLA control dependency + // can be added to ensure C's SendToHost executes before D's RecvFromHost. class Subgraph { public: // Creates a graph to build the subgraph in, if it doesn't already exist, @@ -324,6 +343,18 @@ class Encapsulator { void RecordOutsideCompilationOutputOrControl( const string& outside_compilation_id, const Edge* edge); + // Records the fact that there is a path from a node in outside_compilation + // cluster ancestor to node in cluster successor that does not go through + // the subgraph. 
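The dependency bookkeeping described above is symmetric: each recorded edge updates both an ancestor map and a successor map. A stripped-down sketch of that shape, using made-up cluster names:

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

using ClusterSetMap =
    std::unordered_map<std::string, std::unordered_set<std::string>>;

// Mirrors the shape of RecordOutsideCompilationDependency: `ancestor` has a
// path to `successor` that stays outside the compiled subgraph.
void RecordDependency(const std::string& successor,
                      const std::string& ancestor, ClusterSetMap* ancestors,
                      ClusterSetMap* successors) {
  (*ancestors)[successor].insert(ancestor);
  (*successors)[ancestor].insert(successor);
}

int main() {
  ClusterSetMap ancestors, successors;
  // Hypothetical clusters: C must run before D, and C before E.
  RecordDependency("D", "C", &ancestors, &successors);
  RecordDependency("E", "C", &ancestors, &successors);

  for (const auto& entry : ancestors) {
    std::cout << entry.first << " depends on:";
    for (const auto& a : entry.second) std::cout << " " << a;
    std::cout << "\n";
  }
  return 0;
}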
+ void RecordOutsideCompilationDependency(const string& successor, + const string& ancestor); + + // Returns the mapping from outside_compilation cluster C to the set of + // outside_compilation clusters that have a path to C entirely outside + // compiled subgraphs. + const std::unordered_map> + OutsideCompilationAncestorMap() const; + // Adds the HostCompute nodes for each outside_compilation subgraph. Status AddHostComputes( const string& subgraph_name, @@ -406,6 +437,13 @@ class Encapsulator { Status AddHostComputeKeyPlaceholder(OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out); + // Get the set of outside_compilation clusters and the dependency edges + // between them. + void GetActiveClusterDependencyGraph( + std::unordered_set* clusters, + std::unordered_set* has_successor, + std::unordered_map>* ancestors_map); + // Builds a _RecvAtHost node producing all the inputs of an // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host. Status AddRecvAtHostNode(const string& group_attribute, @@ -468,6 +506,14 @@ class Encapsulator { // The outside_compilation clusters in this subgraph. std::unordered_map outside_compilation_subgraphs_; + // For each outside_compilation cluster C, the outside_compilation clusters + // that have a path to C outside the compiled graph. + std::unordered_map> + outside_compilation_ancestors_; + // For each outside_compilation cluster C, the outside_compilation clusters + // that have a path from C outside the compiled graph. + std::unordered_map> + outside_compilation_successors_; // NoOp node in the output graph that is sequenced after the call node and // used to prevent host-side outside_compilation sends and recvs from being @@ -556,6 +602,10 @@ class Encapsulator { std::unordered_set, NodeSlot::PairHasher>* edges_added); + // Adds control dependencies between subgraph call nodes that have + // dependencies via outside_compilation edges. + Status AddCallNodeDependencies(Graph* graph_out); + // Adds all edges to the output graph. Status AddEdgesToOutputGraph( const std::unordered_map& node_images, @@ -620,10 +670,65 @@ class Encapsulator { const Graph* graph_in_; std::unordered_map subgraphs_; + // For each subgraph S the subgraphs S' such that there is a path in some + // outside_compilation cluster C in S to some outside_compilation cluster C' + // in S', that goes only through the uncompiled graph. + std::unordered_map> subgraph_ancestors_; TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator); }; +namespace { + +// Return in 'sorted' a topological sort of clusters according to the +// dependencies encoded in ancestors. clusters is the list of all clusters +// including clusters that are not present in the ancestors map. has_successors +// is the set of clusters that are ancestors of some other cluster. +void TopologicalClusterSort( + const std::unordered_set& clusters, + const std::unordered_set& has_successors, + const std::unordered_map>& ancestors, + std::vector* sorted) { + // The nodes are placed in 'sorted' in topological order. + sorted->clear(); + // We don't use the standard DFS because we are not operating on Node* + // objects. + struct Work { + string cluster; + bool leave; + }; + std::set visited; + std::vector stack; + // Seed the processing list with clusters that have no successors. 
+ for (const auto& cluster : clusters) { + if (has_successors.find(cluster) == has_successors.end()) { + stack.push_back({cluster, false}); + } + } + while (!stack.empty()) { + const Work item = stack.back(); + stack.pop_back(); + if (item.leave) { + sorted->push_back(item.cluster); + continue; + } + + if (visited.find(item.cluster) != visited.end()) continue; + visited.insert(item.cluster); + + stack.push_back({item.cluster, true}); + const auto& iter = ancestors.find(item.cluster); + if (iter != ancestors.end()) { + for (const auto& ancestor : iter->second) { + stack.push_back({ancestor, false}); + } + } + } + CHECK(sorted->size() == clusters.size()); +} + +} // namespace + Node* Encapsulator::Subgraph::GetCallNodeForInputs() const { return call_node_inputs_; } @@ -786,12 +891,71 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( } } +void Encapsulator::Subgraph::RecordOutsideCompilationDependency( + const string& successor, const string& ancestor) { + outside_compilation_ancestors_[successor].insert(ancestor); + outside_compilation_successors_[ancestor].insert(successor); +} + +const std::unordered_map> +Encapsulator::Subgraph::OutsideCompilationAncestorMap() const { + return outside_compilation_ancestors_; +} + +void Encapsulator::Subgraph::GetActiveClusterDependencyGraph( + std::unordered_set* clusters, + std::unordered_set* has_successor, + std::unordered_map>* ancestors_map) { + // During initial clustering the ancestor and successor datastructures may + // have been built including oc_cluster names that never turned into subgraphs + // because they had no edges into or out of the compiled cluster. Remove them + // before proceeding to simplify the logic. Get the set of clusters that was + // actually added, then remove references to the others. + for (const auto& oc_subgraph : outside_compilation_subgraphs_) { + clusters->insert(oc_subgraph.first); + } + for (const auto& cluster : outside_compilation_successors_) { + if (clusters->find(cluster.first) != clusters->end()) { + for (const auto& successor : cluster.second) { + if (clusters->find(successor) != clusters->end()) { + has_successor->insert(cluster.first); + break; + } + } + } + } + for (const auto& cluster : outside_compilation_ancestors_) { + if (clusters->find(cluster.first) != clusters->end()) { + std::unordered_set& ancestors = (*ancestors_map)[cluster.first]; + for (const auto& ancestor : cluster.second) { + if (clusters->find(ancestor) != clusters->end()) { + ancestors.insert(ancestor); + } + } + } + } +} + Status Encapsulator::Subgraph::AddHostComputes( const string& subgraph_name, const std::unordered_map& node_images) { - for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { - const string& oc_subgraph_name = oc_subgraph_iter.first; - OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; + // Get the set of outside_compilation clusters and the dependency edges + // between them. + std::unordered_set clusters; + std::unordered_set has_successor; + std::unordered_map> ancestors_map; + GetActiveClusterDependencyGraph(&clusters, &has_successor, &ancestors_map); + // Topologically sort the outside_compilation clusters according to their + // dependency relation. 
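For intuition, here is a self-contained run of the same leaf-first traversal idea on a tiny made-up dependency graph; it is a simplified reimplementation for illustration, not the TopologicalClusterSort above:

#include <iostream>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Leaf-first topological sort over cluster names, seeded with clusters that
// have no successors, like the helper above.
void TopoSort(
    const std::unordered_set<std::string>& clusters,
    const std::unordered_set<std::string>& has_successors,
    const std::unordered_map<std::string, std::unordered_set<std::string>>&
        ancestors,
    std::vector<std::string>* sorted) {
  struct Work {
    std::string cluster;
    bool leave;
  };
  std::set<std::string> visited;
  std::vector<Work> stack;
  for (const auto& cluster : clusters) {
    if (has_successors.count(cluster) == 0) stack.push_back({cluster, false});
  }
  while (!stack.empty()) {
    const Work item = stack.back();
    stack.pop_back();
    if (item.leave) {
      sorted->push_back(item.cluster);
      continue;
    }
    if (!visited.insert(item.cluster).second) continue;
    stack.push_back({item.cluster, true});
    const auto iter = ancestors.find(item.cluster);
    if (iter != ancestors.end()) {
      for (const auto& ancestor : iter->second) {
        stack.push_back({ancestor, false});
      }
    }
  }
}

int main() {
  // Hypothetical clusters where A must run before B, and B before C.
  const std::unordered_set<std::string> clusters = {"A", "B", "C"};
  const std::unordered_set<std::string> has_successors = {"A", "B"};
  const std::unordered_map<std::string, std::unordered_set<std::string>>
      ancestors = {{"B", {"A"}}, {"C", {"B"}}};
  std::vector<std::string> sorted;
  TopoSort(clusters, has_successors, ancestors, &sorted);
  for (const auto& c : sorted) std::cout << c << " ";  // prints: A B C
  std::cout << "\n";
  return 0;
}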
+ std::vector sorted_clusters; + TopologicalClusterSort(clusters, has_successor, ancestors_map, + &sorted_clusters); + + // The host compute nodes added for each outside_compilation_cluster; + std::unordered_map host_compute_node; + for (const string& oc_subgraph_name : sorted_clusters) { + OutsideCompilationSubgraph& oc_subgraph = + outside_compilation_subgraphs_[oc_subgraph_name]; if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() || !oc_subgraph.outputs_by_src.empty() || !oc_subgraph.control_outputs.empty()) { @@ -811,13 +975,22 @@ Status Encapsulator::Subgraph::AddHostComputes( inputs[input_index].Reset(src_image->name(), src_slot, dtype); input_dtypes[input_index] = dtype; } - for (const auto& output : oc_subgraph.outputs_by_src) { DataType dtype = output.first.dtype; int output_index = output.second; output_dtypes[output_index] = dtype; } + std::vector host_compute_ancestors; + const auto iter = ancestors_map.find(oc_subgraph_name); + if (iter != ancestors_map.end()) { + for (const string& ancestor_cluster : iter->second) { + host_compute_ancestors.push_back( + outside_compilation_subgraphs_[ancestor_cluster] + .host_compute_name); + } + } + NodeDef host_compute_def; NodeDefBuilder builder(strings::StrCat("outside_compilation_", oc_subgraph_name, "_host_compute"), @@ -825,6 +998,7 @@ Status Encapsulator::Subgraph::AddHostComputes( builder.Input(inputs); builder.Attr("Tinputs", input_dtypes); builder.Attr("Toutputs", output_dtypes); + builder.Attr("ancestors", host_compute_ancestors); builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, "_", oc_subgraph_name)); @@ -834,6 +1008,7 @@ Status Encapsulator::Subgraph::AddHostComputes( Node* host_compute = graph_->AddNode(host_compute_def, &s); if (!s.ok()) return s; + host_compute_node[host_compute->name()] = host_compute; oc_subgraph.host_compute_name = host_compute->name(); // Connect the _HostCompute node to its producers in the subgraph. @@ -852,6 +1027,12 @@ Status Encapsulator::Subgraph::AddHostComputes( graph_->AddControlEdge(src_image, host_compute); } + // Connect the _HostCompute node to its ancestor host compute nodes. + for (const auto& ancestor_name : host_compute_ancestors) { + Node* ancestor = host_compute_node[ancestor_name]; + graph_->AddControlEdge(ancestor, host_compute); + } + // Connect the consumers in the subgraph to the _HostCompute node. 
for (const auto& output : oc_subgraph.outputs_by_dst) { const Node* dst_node = output.first.node; @@ -1654,6 +1835,17 @@ Status Encapsulator::CopyEdgeToOutputGraph( return Status::OK(); } +Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { + for (const auto& ancestors : subgraph_ancestors_) { + const string& subgraph = ancestors.first; + for (const string& ancestor : ancestors.second) { + graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNodeForOutputs(), + subgraphs_[subgraph].GetCallNodeForInputs()); + } + } + return Status::OK(); +} + Status Encapsulator::AddEdgesToOutputGraph( const std::unordered_map& node_images, bool parallel_checking, Graph* graph_out) { @@ -1703,6 +1895,7 @@ Status Encapsulator::AddEdgesToOutputGraph( Subgraph& subgraph = subgraph_entry.second; subgraph.ConnectSequencerToCallNode(graph_out); } + TF_RETURN_IF_ERROR(AddCallNodeDependencies(graph_out)); return Status::OK(); } @@ -1960,6 +2153,182 @@ Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( return Status::OK(); } +namespace { + +// Helper struct for building cluster dependencies and also debugging cycles in +// the dependencies. While computing dependencies we construct a mapping from +// Node* to PathDetails. +struct PathDetails { + struct SubgraphAndCluster { + string subgraph; + string outside_compilation_cluster; + bool operator==(const SubgraphAndCluster& other) const { + return subgraph == other.subgraph && + outside_compilation_cluster == other.outside_compilation_cluster; + } + }; + + struct SubgraphAndClusterHash { + inline std::size_t operator()(const SubgraphAndCluster& v) const { + return hash()( + strings::StrCat(v.subgraph, v.outside_compilation_cluster)); + } + }; + + typedef std::unordered_set + SubgraphAndClusterSet; + + // Returns the set of (subgraph, oc_cluster) pairs that should be recorded as + // ancestors for any successor of this node. If the node is in the outer + // graph, it returns the transitive union of the ancestors of the node's + // inputs. If the node is in an outside_compilation cluster, it returns just + // that cluster. If the node is compiled, it returns the empty set. + SubgraphAndClusterSet AncestorsForSuccessor() { + if (subgraph.empty()) { + return ancestor_clusters; + } else if (outside_compilation_cluster.empty()) { + return SubgraphAndClusterSet(); + } else { + SubgraphAndCluster entry; + entry.subgraph = subgraph; + entry.outside_compilation_cluster = outside_compilation_cluster; + return SubgraphAndClusterSet({entry}); + } + } + + // The transitive union of the ancestor's of this node's inputs. This is only + // saved for debugging in order to print out enough information to debug a + // discovered cycle. + SubgraphAndClusterSet ancestor_clusters; + // The subgraph attr on this node. + string subgraph; + // The outside_compilation attr on this node. + string outside_compilation_cluster; +}; + +// Adds an edge from ancestor to successor to the cycle detector, and returns an +// error if that edge causes the formation of a cycle. In the error case, logs +// the contents of the node_ancestors_map to facilitate debugging. 
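The real pass delegates this check to GraphCycles; as a rough standalone illustration of the invariant being enforced (adding a dependency edge must never close a cycle), here is a naive reachability-based stand-in over cluster names:

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using Edges = std::unordered_map<std::string, std::unordered_set<std::string>>;

// Returns true if `to` is reachable from `from` along existing edges.
bool Reachable(const Edges& edges, const std::string& from,
               const std::string& to) {
  std::vector<std::string> stack = {from};
  std::unordered_set<std::string> seen;
  while (!stack.empty()) {
    const std::string node = stack.back();
    stack.pop_back();
    if (node == to) return true;
    if (!seen.insert(node).second) continue;
    const auto it = edges.find(node);
    if (it == edges.end()) continue;
    for (const auto& next : it->second) stack.push_back(next);
  }
  return false;
}

// Naive stand-in for GraphCycles::InsertEdge: refuse edges that close a cycle.
bool InsertEdgeOrReportCycle(Edges* edges, const std::string& ancestor,
                             const std::string& successor) {
  if (Reachable(*edges, successor, ancestor)) {
    std::cout << "cycle: " << ancestor << " -> " << successor << "\n";
    return false;
  }
  (*edges)[ancestor].insert(successor);
  return true;
}

int main() {
  Edges edges;
  InsertEdgeOrReportCycle(&edges, "C", "D");  // ok
  InsertEdgeOrReportCycle(&edges, "D", "E");  // ok
  InsertEdgeOrReportCycle(&edges, "E", "C");  // rejected: would close a cycle
  return 0;
}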
+Status CheckClusterDependencyForCycles( + const string& ancestor, const string& successor, + const std::unordered_map>& ancestors, + const std::unordered_map& node_ancestors_map, + GraphCycles* cycle_detector, std::map* cycle_detector_map) { + if (cycle_detector_map->find(ancestor) == cycle_detector_map->end()) { + (*cycle_detector_map)[ancestor] = cycle_detector->NewNode(); + } + if (cycle_detector_map->find(successor) == cycle_detector_map->end()) { + (*cycle_detector_map)[successor] = cycle_detector->NewNode(); + } + + if (!cycle_detector->InsertEdge((*cycle_detector_map)[ancestor], + (*cycle_detector_map)[successor])) { + LOG(ERROR) << "Cycle in outside_compilation clusters"; + for (const auto& cluster : ancestors) { + LOG(ERROR) << "Cluster " << cluster.first << " depends on:"; + for (const auto& ancestor : cluster.second) { + LOG(ERROR) << " " << ancestor; + } + } + for (const auto& node_ancestors : node_ancestors_map) { + LOG(ERROR) << "Node " << node_ancestors.first->name() << " (" + << node_ancestors.second.subgraph << ";" + << node_ancestors.second.outside_compilation_cluster + << ") has ancestor clusters:"; + for (const auto& ancestor : node_ancestors.second.ancestor_clusters) { + LOG(ERROR) << " " << ancestor.subgraph << ";" + << ancestor.outside_compilation_cluster; + } + } + return errors::InvalidArgument( + "Can't compile outside_compilation clusters because there is a " + "dependency cycle: see error log for details."); + } + return Status::OK(); +} + +} // namespace + +Status Encapsulator::FindClusterDependencies() { + // Map from nodes to ancestor details. A node is entered into the map if it is + // in a compilation subgraph, and outside_compilation cluster, or appears on a + // path in the outer graph leading from an outside_compilation subgraph. + std::unordered_map node_ancestors_map; + // We check that clusters are acyclic using this cycle detector. + GraphCycles cycle_detector; + // Map from cluster name to cycle detector node id. + std::map cycle_detector_map; + // Process the nodes in topologically-sorted order. + std::vector nodes; + GetReversePostOrder(*graph_in_, &nodes); + for (Node* node : nodes) { + string subgraph_name; + string oc_cluster; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(node, &subgraph_name, &oc_cluster)); + // First create an entry in the ancestors map if the node is in a compiled + // subgraph or outside_compilation cluster, or if any incoming edge is from + // a node with an ancestor map entry; and find the union of all the + // ancestors. + if (!subgraph_name.empty()) { + node_ancestors_map[node].subgraph = subgraph_name; + node_ancestors_map[node].outside_compilation_cluster = oc_cluster; + } + for (Node* src : node->in_nodes()) { + const auto iter = node_ancestors_map.find(src); + if (iter != node_ancestors_map.end()) { + const auto& ancestors_to_follow = iter->second.AncestorsForSuccessor(); + for (const auto& ancestor : ancestors_to_follow) { + if (ancestor.subgraph != subgraph_name || + ancestor.outside_compilation_cluster != oc_cluster) { + node_ancestors_map[node].ancestor_clusters.insert(ancestor); + } + } + } + } + if (!subgraph_name.empty()) { + // The node is in a compiled subgraph or an outside_compilation cluster. + if (oc_cluster.empty()) { + // The node is not in an outside_compilation cluster. Record the + // subgraph's ancestor dependencies. 
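Seen abstractly, the propagation above is a single pass in topological order in which every node folds in the ancestor clusters of its inputs, with uncompiled nodes merely forwarding what they have seen. A toy standalone version over plain strings, with invented node and cluster names:

#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct NodeInfo {
  std::string cluster;                        // empty if not in any cluster
  std::unordered_set<std::string> ancestors;  // clusters seen on inputs
};

int main() {
  // Toy graph, already in topological order: c1_node -> plain_node -> c2_node.
  // plain_node is an uncompiled node that only forwards ancestry.
  const std::vector<std::string> order = {"c1_node", "plain_node", "c2_node"};
  std::unordered_map<std::string, std::vector<std::string>> inputs = {
      {"plain_node", {"c1_node"}}, {"c2_node", {"plain_node"}}};
  std::unordered_map<std::string, std::string> cluster_of = {
      {"c1_node", "C1"}, {"c2_node", "C2"}};

  std::unordered_map<std::string, NodeInfo> info;
  for (const auto& node : order) {
    info[node].cluster =
        cluster_of.count(node) ? cluster_of[node] : std::string();
    for (const auto& src : inputs[node]) {
      const NodeInfo& src_info = info[src];
      if (!src_info.cluster.empty()) {
        info[node].ancestors.insert(src_info.cluster);  // direct ancestor
      } else {
        info[node].ancestors.insert(src_info.ancestors.begin(),
                                    src_info.ancestors.end());  // forwarded
      }
    }
  }

  // C2 ends up depending on C1 even though no direct edge connects them.
  for (const auto& a : info["c2_node"].ancestors) std::cout << a << "\n";
  return 0;
}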
+ for (const auto& cluster : node_ancestors_map[node].ancestor_clusters) { + if (cluster.subgraph != subgraph_name) { + subgraph_ancestors_[subgraph_name].insert(cluster.subgraph); + TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles( + cluster.subgraph, subgraph_name, subgraph_ancestors_, + node_ancestors_map, &cycle_detector, &cycle_detector_map)); + } + } + } else { + Subgraph& subgraph = subgraphs_[subgraph_name]; + // The node is in an outside_compilation cluster. Record the cluster + // and/or subgraph ancestor dependencies. + for (const auto& cluster : node_ancestors_map[node].ancestor_clusters) { + if (cluster.subgraph == subgraph_name) { + // The ancestor is in the same subgraph. + if (cluster.outside_compilation_cluster != oc_cluster) { + // But not in the same oc_cluster, so record the dependency. + subgraph.RecordOutsideCompilationDependency( + oc_cluster, cluster.outside_compilation_cluster); + TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles( + cluster.outside_compilation_cluster, oc_cluster, + subgraph.OutsideCompilationAncestorMap(), node_ancestors_map, + &cycle_detector, &cycle_detector_map)); + } + } else { + // The ancestor is in a different subgraph, so record the + // dependency. + subgraph_ancestors_[subgraph_name].insert(cluster.subgraph); + TF_RETURN_IF_ERROR(CheckClusterDependencyForCycles( + cluster.subgraph, subgraph_name, subgraph_ancestors_, + node_ancestors_map, &cycle_detector, &cycle_detector_map)); + } + } + } + } + } + return Status::OK(); +} + Status Encapsulator::MakePrunedGraphCopyAndInline( const Graph& graph, const std::vector& sink_nodes, std::unique_ptr* pruned_graph, @@ -2166,6 +2535,7 @@ Status EncapsulateSubgraphsInFunctions( Encapsulator encapsulator(std::move(group_attribute), std::move(outside_compilation_attribute), &graph_in); + TF_RETURN_IF_ERROR(encapsulator.FindClusterDependencies()); TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs()); TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs( diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 34be4409a38..5fee36f022a 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -80,7 +80,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr* graph_out, FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate -// subgraphs pass and that should in turn be compiled via _XlaLaunch operators. +// subgraphs pass and that should in turn be compiled via XlaLaunch operators. extern const char* const kXlaCompiledKernelAttr; // Does `node` have the kXlaCompiledKernelAttr attribute? diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 8599a7038af..5ec24d39a2c 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -20,8 +20,8 @@ limitations under the License. 
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -74,7 +74,7 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, if (!compare(elt_a.first, elt_a.second, iter->second)) { if (diff) { *diff = strings::StrCat(map_name, " expected: element with key '", - key_to_string(elt_a.first), " has value '", + key_to_string(elt_a.first), "' has value '", value_to_string(elt_a.second), "' got: '", value_to_string(iter->second), "'"); } @@ -121,8 +121,22 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, } return false; } + std::unordered_set control_input_a; + std::unordered_set control_input_b; for (int i = 0; i < a.input_size(); ++i) { - if (a.input(i) != b.input(i)) { + if (str_util::StartsWith(a.input(i), "^")) { + if (!str_util::StartsWith(b.input(i), "^")) { + if (diff) { + *diff = strings::StrCat( + diff_preamble, " mismatch for node ", a.name(), " input ", i, + ", expected control input ", a.input(i), " got ", b.input(i), + " expected:\n", a.DebugString(), "\ngot:\n", b.DebugString()); + } + return false; + } + control_input_a.insert(a.input(i)); + control_input_b.insert(b.input(i)); + } else if (a.input(i) != b.input(i)) { if (diff) { *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), " input ", i, ", expected ", a.input(i), @@ -132,11 +146,29 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, return false; } } + if (control_input_a != control_input_b) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + " control inputs differ expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); + } + return false; + } return EqualProtoMap( a.attr(), b.attr(), [](const string& s) { return s; }, [](const AttrValue& v) { return v.DebugString(); }, [](const string& key, const AttrValue& av, const AttrValue& bv) { - return av.DebugString() == bv.DebugString(); + if (key == "ancestors") { + // The ancestors are added from a set so the order is unpredictable; + // just compare set equality not list equality. 
+ std::unordered_set a_set(av.list().s().begin(), + av.list().s().end()); + std::unordered_set b_set(bv.list().s().begin(), + bv.list().s().end()); + return a_set == b_set; + } else { + return av.DebugString() == bv.DebugString(); + } }, strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff); @@ -261,6 +293,7 @@ REGISTER_OP("XlaHostCompute") .Output("outputs: Toutputs") .Attr("Tinputs: list(type) >= 0") .Attr("Toutputs: list(type) >= 0") + .Attr("ancestors: list(string) >= 0") .Attr("key: string") .Attr("shape_inference_graph: string = ''") .Attr("shapes: list(shape) >= 0") @@ -899,6 +932,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"C:o:0", "c:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, @@ -1044,17 +1078,20 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"D:o:0", "F:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", + gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, {"key", "host_compute_channel_F1_O2"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O2"}, {"shapes", gtl::ArraySlice({})}, {"_outside_compilation_subgraph", "O2"}}, - {"F"}}, + {"F", "outside_compilation_O1_host_compute"}}, {{"outside_compilation_O1_host_compute"}, "XlaHostCompute", {"C:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, @@ -1193,6 +1230,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"C:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, @@ -1215,6 +1253,7 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"G:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F2_O1"}, {"shape_inference_graph", ""}, {"shapes", @@ -1279,6 +1318,179 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } +// Test with two functions to transform, each with one outside_compilation +// cluster, with the dependency between them purely from an outside_compilation +// edge. 
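+// (The only cross-function link is the control edge from E, in F1/O1, to H, in
+// F2/O1, so the pass has to add a control edge from the F1 call node to the
+// F2 call node; the expected graph below checks exactly that.)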
+TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = InputShaped(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Binary(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Node* g = + Binary(a, b, b1.opts().WithName("G").WithAttr("_encapsulate", "F2")); + Node* h = Unary(g, b1.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1") + .WithControlInput(e)); + Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2")); + Binary(f, i, b1.opts().WithName("J")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, shape.opts()); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F1_O1", &library_expected)); + } + + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape.opts().WithName("KnownShape/_0")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F2", "O1", + {DT_FLOAT}, shape.opts()); + Node* h = Unary(recv, shape.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F2", "O1", {h}, shape.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape, "F2_O1", &library_expected)); + } + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "BinaryTest", + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, + {}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"C:o:0", "D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}, + {"D"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + *library_expected.add_function() = FunctionDefHelper::Create( + "F2", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {}, + { + {{"G"}, "BinaryTest", {"a_0_arg", "b_0_arg"}}, + {{"I"}, + "UnaryTest", + {"outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"G:o:0"}, + {{"Tinputs", 
gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F2_O1"}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F2_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, + }, + {{"i_0_retval", "I:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = InputShaped(b2.opts().WithName("B")); + + Node* key_constant1 = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant1, 0), "F1", "O1", + {DT_FLOAT, DT_FLOAT}, b2.opts()); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + b2.opts() + .WithName("E") + .WithControlInputs({recv1, b}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send1 = SendFromHost(ops::NodeOut(key_constant1, 0), "F1", "O1", {e}, + b2.opts().WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}), + "F1"); + + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = + b2.opts().WithControlInput(s1).FinalizeBuilder(&node_builder1); + + Node* key_constant2 = + KeyPlaceholder("F2", b2.opts().WithName("F2_key_placeholder")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant2, 0), "F2", "O1", + {DT_FLOAT}, b2.opts()); + Node* h = Unary(recv2, b2.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1") + .WithControlInput(e)); + Node* send2 = SendFromHost(ops::NodeOut(key_constant2, 0), "F2", "O1", {h}, + b2.opts()); + + Node* s2 = Sequencer( + b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2}), + "F2"); + NodeBuilder node_builder2("F2", "F2", lib_def.get()); + node_builder2.Input(a).Input(b); + Node* call2 = b2.opts() + .WithControlInputs({s2, call1}) + .FinalizeBuilder(&node_builder2); + Binary(call1, call2, b2.opts().WithName("J")); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + // Test with one outside_compilation cluster that has no inputs from the // compiled subgraph. 
TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { @@ -1323,6 +1535,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {}, {{"Tinputs", gtl::ArraySlice({})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", @@ -1406,6 +1619,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {}, {{"Tinputs", gtl::ArraySlice({})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", @@ -1487,6 +1701,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {"D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", gtl::ArraySlice({})}, @@ -1567,6 +1782,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {"D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", ""}, {"shapes", gtl::ArraySlice({})}, @@ -1607,6 +1823,371 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } +// Test with two outside_compilation clusters that interact outside the compiled +// subgraph, where the ancestor has no HostCompute Op. +TEST(EncapsulateSubgraphsTest, + OutsideCompilationClusterDependencyNoSrcCluster) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(a, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Node* g = Unary(f, b1.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* h = Unary(g, b1.opts().WithName("H").WithAttr("_encapsulate", "F1")); + Binary(e, h, b1.opts().WithName("I")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + { + GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape2.opts().WithName("KnownShape/_0")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT}, shape2.opts()); + Node* g = Unary(ops::NodeOut(recv2, 0), shape2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, shape2.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape2, "F1_O2", &library_expected)); + } + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"D:o:0"}}, + {{"H"}, + "UnaryTest", + {"outside_compilation_O2_host_compute:outputs:0"}}, + 
{{"outside_compilation_O2_host_compute"}, + "XlaHostCompute", + {"F:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O2"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O2"}}}, + }, + {{"h_0_retval", "H:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* e = Unary(a, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT}, b2.opts()); + Node* g = Unary(recv, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* send = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O2", {g}, b2.opts()); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b).ControlInput(s1); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + + Binary(e, call1, b2.opts().WithName("I")); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with two outside_compilation clusters that interact outside the compiled +// subgraph, where the successor has no HostCompute Op. 
+TEST(EncapsulateSubgraphsTest, + OutsideCompilationClusterDependencyNoDstCluster) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(d, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + /*Node* g =*/Unary(a, b1.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1")); + Binary(e, h, b1.opts().WithName("I")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, shape1.opts()); + Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "UnaryTest", + {"outside_compilation_O1_host_compute:outputs:0"}}, + {{"H"}, "UnaryTest", {"F:o:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, + }, + {{"h_0_retval", "H:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); + /*Node* g =*/Unary(a, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* s1 = Sequencer( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}), + "F1"); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b).ControlInput(s1); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + + Binary(e, call1, b2.opts().WithName("I")); + 
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with two outside_compilation clusters that interact outside the compiled +// subgraph. +TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(d, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Node* g = Unary(d, b1.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* h = Unary(f, b1.opts().WithName("H").WithAttr("_encapsulate", "F1")); + /*Node* i =*/Binary(d, e, + b1.opts() + .WithName("I") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O3") + .WithControlInput(g)); + Binary(e, h, b1.opts().WithName("J")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* key_constant = + KeyPlaceholderShape(shape1.opts().WithName("KnownShape/_0")); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, shape1.opts()); + Node* e = Unary(ops::NodeOut(recv2, 0), shape1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, shape1.opts()); + TF_EXPECT_OK( + AddGraphDefToFunctionLibrary(shape1, "F1_O1", &library_expected)); + } + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"h_0_retval:float"}, {}, + {{{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"outside_compilation_O1_host_compute:outputs:0"}}, + {{"H"}, "UnaryTest", {"F:o:0"}}, + {{"outside_compilation_O1_host_compute"}, + "XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", + "_outside_compilation_shape_inference_F1_O1"}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O1"}}}, + {{"outside_compilation_O2_host_compute"}, + "XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"ancestors", + gtl::ArraySlice({"outside_compilation_O1_host_compute"})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", ""}, + {"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O2"}}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O3_host_compute"}, + "XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"ancestors", + gtl::ArraySlice({"outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"})}, + {"key", "host_compute_channel_F1_O3"}, + {"shape_inference_graph", ""}, + 
{"shapes", gtl::ArraySlice({})}, + {"_outside_compilation_subgraph", "O3"}}, + {"outside_compilation_O1_host_compute", + "outside_compilation_O2_host_compute"}}}, + {{"h_0_retval", "H:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* key_constant = + KeyPlaceholder("F1", b2.opts().WithName("F1_key_placeholder")); + Node* recv1 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O1", + {DT_FLOAT}, b2.opts()); + Node* e = Unary(recv1, b2.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* send = + SendFromHost(ops::NodeOut(key_constant, 0), "F1", "O1", {e}, b2.opts()); + Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", + {DT_FLOAT}, b2.opts()); + Node* g = Unary(recv2, b2.opts() + .WithName("G") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2") + .WithControlInput(e)); + Node* recv3 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O3", + {DT_FLOAT}, b2.opts()); + /*Node* i =*/Binary(recv3, e, + b2.opts() + .WithName("I") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O3") + .WithControlInput(g)); + Node* s1 = Sequencer(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send, recv2, recv3}), + "F1"); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b).ControlInput(s1); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + + Binary(e, call1, b2.opts().WithName("J")); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + // Test with one outside_compilation cluster that has no outputs from the // compiled subgraph. TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { @@ -1731,6 +2312,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { {"c:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"ancestors", gtl::ArraySlice({})}, {"key", "host_compute_channel_F1_O1"}, {"shape_inference_graph", "_outside_compilation_shape_inference_F1_O1"}, diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index bc68afb322b..805bbc62c1e 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -354,6 +354,16 @@ bool GraphCycles::IsReachableNonConst(int32 x, int32 y) { return reachable; } +bool GraphCycles::CanContractEdge(int32 a, int32 b) { + CHECK(HasEdge(a, b)) << "No edge exists from " << a << " to " << b; + RemoveEdge(a, b); + bool reachable = IsReachableNonConst(a, b); + // Restore the graph to its original state. + InsertEdge(a, b); + // If reachable, then contracting edge will cause cycle. 
+ return !reachable; +} + bool GraphCycles::ContractEdge(int32 a, int32 b) { CHECK(HasEdge(a, b)); RemoveEdge(a, b); @@ -388,4 +398,8 @@ std::unordered_set GraphCycles::Successors(int32 node) { return rep_->nodes_[node]->out; } +std::unordered_set GraphCycles::Predecessors(int32 node) { + return rep_->nodes_[node]->in; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/jit/graphcycles/graphcycles.h index d11d6e27b1b..44448fa3d78 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.h +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.h @@ -85,6 +85,9 @@ class GraphCycles { // and returns false. bool ContractEdge(int32 a, int32 b); + // Return true if can contract edge, otherwise return false. + bool CanContractEdge(int32 a, int32 b); + // Return whether dest_node is reachable from source_node // by following edges. bool IsReachable(int32 source_node, int32 dest_node) const; @@ -115,6 +118,7 @@ class GraphCycles { bool CheckInvariants() const; std::unordered_set Successors(int32 node); + std::unordered_set Predecessors(int32 node); // ---------------------------------------------------- struct Rep; diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc index e47b782207e..274f5938a12 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc @@ -494,6 +494,20 @@ TEST_F(GraphCyclesTest, ContractEdge) { EXPECT_TRUE(g_.HasEdge(1, 4)); } +TEST_F(GraphCyclesTest, CanContractEdge) { + ASSERT_TRUE(AddEdge(1, 2)); + ASSERT_TRUE(AddEdge(1, 3)); + ASSERT_TRUE(AddEdge(2, 3)); + ASSERT_TRUE(AddEdge(2, 4)); + ASSERT_TRUE(AddEdge(3, 4)); + + EXPECT_FALSE(g_.CanContractEdge(1, 3)); + EXPECT_FALSE(g_.CanContractEdge(2, 4)); + EXPECT_TRUE(g_.CanContractEdge(1, 2)); + EXPECT_TRUE(g_.CanContractEdge(2, 3)); + EXPECT_TRUE(g_.CanContractEdge(3, 4)); +} + static void BM_StressTest(int iters, int num_nodes) { while (iters > 0) { tensorflow::GraphCycles g; diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index f48941fce32..27287e0f963 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -37,30 +37,28 @@ limitations under the License. 
#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/util/stream_executor_util.h" -namespace gpu = perftools::gputools; - namespace tensorflow { -XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) - : OpKernel(ctx), device_type_(ctx->device_type()) { - const NameAttrList* func; - OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func)); - function_ = *func; - DataTypeVector constant_types; - OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types)); - num_constant_args_ = constant_types.size(); - OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_)); +XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function) + : OpKernel(ctx), + constants_(constants), + resources_(resources), + device_type_(ctx->device_type()), + function_(function) { if (device_type_ == DeviceType(DEVICE_CPU)) { - platform_id_ = gpu::host::kHostPlatformId; + platform_id_ = se::host::kHostPlatformId; } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id_ = gpu::cuda::kCudaPlatformId; + platform_id_ = se::cuda::kCudaPlatformId; } else { platform_id_ = nullptr; } } -Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache) { +Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache) { const XlaDevice::Metadata* metadata; Status s = XlaDevice::GetMetadata(ctx, &metadata); if (s.ok()) { @@ -69,9 +67,9 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, return Status::OK(); } - auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_); + auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_); if (!platform.ok()) { - return StreamExecutorUtil::ConvertStatus(platform.status()); + return platform.status(); } xla::LocalClientOptions client_options; client_options.set_platform(platform.ValueOrDie()); @@ -92,15 +90,15 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx, return Status::OK(); } -void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { - VLOG(1) << "XlaLocalLaunchOp::Compute " +void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaLocalLaunchOpBase::Compute " << Canonicalize(function_.name(), AttrSlice(&function_.attr())); // We store information about the JIT-compiled XLA computation // in the ResourceMgr. ResourceMgr* rm = ctx->resource_manager(); OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); - gpu::Stream* stream = + se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; XlaCompilationCache* cache; @@ -114,7 +112,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // this is more obviously correct.) 
core::ScopedUnref cache_ref(cache); - const XlaDevice::Metadata* metadata; + const XlaDevice::Metadata* metadata = nullptr; Status s = XlaDevice::GetMetadata(ctx, &metadata); bool allocate_xla_tensors = s.ok(); @@ -126,7 +124,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { } std::map variables = - SnapshotResourceVariables(ctx, num_resource_args_); + SnapshotResourceVariables(ctx, resources_); xla::LocalClient* client = static_cast(cache->client()); @@ -153,27 +151,29 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { options.device_type = &cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); - options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); + options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); options.device_allocator = xla_allocator; - // TODO(b/77671268): We don't set variable_representation_shape_fn here. This - // is restricted to Variables, but we need something like this to apply to - // normal Tensors too. + if (metadata) { + options.shape_representation_fn = metadata->shape_representation_fn(); + } const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; std::map constant_args; - for (int i = 0; i < num_constant_args_; ++i) { + for (int i : constants_) { constant_args.insert({i, ctx->input(i)}); } - OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, - variables, ctx, &kernel, &executable, - /*compile_options=*/nullptr)); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + OP_REQUIRES_OK( + ctx, cache->Compile(options, function_, constant_args, variables, ctx, + &kernel, &executable, &compile_options)); VLOG(1) << "Executing XLA Computation..."; - XlaComputationLaunchContext launch_context( - num_resource_args_, client, xla_allocator, allocate_xla_tensors); + XlaComputationLaunchContext launch_context(client, xla_allocator, + allocate_xla_tensors); launch_context.PopulateInputs(ctx, kernel, variables); // Execute the computation. @@ -196,14 +196,69 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { VLOG(1) << "Done"; } +namespace { + +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + +// Helper static functions to construct parameters for +// XlaLocalLaunchBase constructor from OpKernelConstruction. 
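+// The XlaLaunch op orders its inputs as [constants..., args..., resources...],
+// so the constant indices are simply 0..num_constants-1 and the resource
+// indices start at num_constants + num_args.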
+std::vector ConstantsVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + std::vector constants(constant_types.size()); + std::iota(constants.begin(), constants.end(), 0); + return constants; +} + +std::vector ResourcesVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Tconstants", &constant_types)); + + DataTypeVector arg_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Targs", &arg_types)); + + int num_resources; + OP_REQUIRES_OK_RETURN(ctx, std::vector(), + ctx->GetAttr("Nresources", &num_resources)); + + std::vector resources(num_resources); + std::iota(resources.begin(), resources.end(), + constant_types.size() + arg_types.size()); + return resources; +} + +NameAttrList FunctionAttr(OpKernelConstruction* ctx) { + const NameAttrList* func; + OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); + return *func; +} + +#undef OP_REQUIRES_OK_RETURN +} // namespace + +XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) + : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), + FunctionAttr(ctx)) {} + XlaLocalLaunchOp::~XlaLocalLaunchOp() { VLOG(1) << "XlaLocalLaunchOp destroyed"; } -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU), - XlaLocalLaunchOp); +REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); -REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") +REGISTER_KERNEL_BUILDER(Name("XlaLaunch") .Device(DEVICE_GPU) .HostMemory("constants") .HostMemory("resources"), diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h index c6cc0986af0..8dfc4b382d5 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h @@ -26,6 +26,41 @@ limitations under the License. namespace tensorflow { +// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. +// The only difference is that it does not require arguments to follow +// the "constants, then regular args, then resources" order. +// It takes vectors of constant and resource arguments explicitly. +// It does not have corresponding OpDef because it is never present +// in the GraphDef. +// Currently, it is used by eager runtime. FunctionLibraryRuntime creates +// this kernel when asked to create a kernel for an XLA-compiled function. +class XlaLocalLaunchBase : public OpKernel { + public: + XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector& constants, + const std::vector& resources, + const NameAttrList& function); + XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; + XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; + ~XlaLocalLaunchBase() override = default; + + void Compute(OpKernelContext* ctx) override; + + protected: + // Builds a XlaCompilationCache class suitable for the current device. + Status BuildCompilationCache(OpKernelContext* ctx, + XlaCompilationCache** cache); + + // Indexes of compile-time constant inputs + std::vector constants_; + // Indexes of resource inputs + std::vector resources_; + + DeviceType device_type_; + NameAttrList function_; + se::Platform::Id platform_id_; +}; + // XlaLocalLaunchOp is used to replace a region of the TensorFlow graph // which will be compiled and executed using XLA. The XlaLocalLaunchOp is // responsible for handling interactions with the TensorFlow executor. 
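The two index vectors built above are plain consecutive ranges. The following standalone snippet (ordinary C++, independent of TensorFlow, with made-up input counts) illustrates the layout that ConstantsVector and ResourcesVector assume: constants occupy the first input slots and resources the last ones.

#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical input signature: 2 compile-time constants, 3 regular
  // arguments, 2 resource variables, laid out in that order.
  const int num_constants = 2, num_args = 3, num_resources = 2;

  std::vector<int> constants(num_constants);
  std::iota(constants.begin(), constants.end(), 0);  // {0, 1}

  std::vector<int> resources(num_resources);
  std::iota(resources.begin(), resources.end(),
            num_constants + num_args);  // {5, 6}

  for (int i : constants) std::cout << "constant input index " << i << "\n";
  for (int i : resources) std::cout << "resource input index " << i << "\n";
  return 0;
}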
@@ -35,26 +70,12 @@ namespace tensorflow { // XlaLocalLaunchOp uses xla::LocalClient::Compile() and // xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device // memory. -class XlaLocalLaunchOp : public OpKernel { +class XlaLocalLaunchOp : public XlaLocalLaunchBase { public: explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); ~XlaLocalLaunchOp() override; - void Compute(OpKernelContext* ctx) override; - private: - // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** compiler); - - DeviceType device_type_; - NameAttrList function_; - int num_constant_args_; - // Number of resource variable arguments. - int num_resource_args_; - - perftools::gputools::Platform::Id platform_id_; - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); }; diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 07320b43dab..f2473d98ffd 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -17,7 +17,7 @@ limitations under the License. namespace tensorflow { -REGISTER_OP("_XlaLaunch") +REGISTER_OP("XlaLaunch") .Input("constants: Tconstants") .Attr("Tconstants: list(type) >= 0") .Input("args: Targs") @@ -28,7 +28,7 @@ REGISTER_OP("_XlaLaunch") .Attr("Tresults: list(type) >= 0") .Attr("function: func") // XLA random-number generation ops are stateful. - // TODO(phawkins): create stateful and non-stateful variants of _XlaLaunch. + // TODO(phawkins): create stateful and non-stateful variants of XlaLaunch. .SetIsStateful() .Doc("XLA Launch Op. For use by the XLA JIT only."); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 6430975335f..7ed609c4374 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -122,8 +122,7 @@ Status XlaCompilationCache::BuildSignature( namespace { -// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch -// op. +// Builds a XlaCompiler::Argument vector from the arguments to the XlaLaunch op. Status BuildArguments(const std::map& constant_args, const std::map& variable_args, OpKernelContext* ctx, diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 682d6ea8ccc..ab644ff5a61 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -48,17 +48,16 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, const XlaCompiler::CompilationResult* result, xla::LocalExecutable* executable) { std::map variables = GetVariables(ctx); - int64 num_resource_args = variables.size(); xla::LocalClient* client = metadata.client(); // Builds an XLA allocator for the device. XlaComputationLaunchContext launch_context( - num_resource_args, client, client->backend().memory_allocator(), true); + client, client->backend().memory_allocator(), true); launch_context.PopulateInputs(ctx, result, variables); - perftools::gputools::Stream* stream = + se::Stream* stream = ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; TF_RET_CHECK(stream); @@ -67,6 +66,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, run_options.set_stream(stream); run_options.set_allocator(client->backend().memory_allocator()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(ctx->step_id()); auto run_result = executable->Run(launch_context.arguments(), run_options); TF_RETURN_IF_ERROR(run_result.status()); @@ -156,11 +156,14 @@ Status XlaCompileOnDemandOp::Compile( options.client = metadata.client(); options.flib_def = new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); + options.shape_representation_fn = metadata.shape_representation_fn(); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, - /*compile_options=*/nullptr); + result, executable, &compile_options); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index 23c6f3903f8..7cc3d0e007b 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -29,11 +29,8 @@ limitations under the License. namespace tensorflow { // An OpKernel that compiles an op to an XLA computation and runs it. Unlike -// _XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a +// XlaLaunch this doesn't rely on any rewrites of the graphdef - it will run a // vanilla TensorFlow op as long as the bridge supports it. -// -// Importantly _XlaLaunch assumes all input and output tensors are on the host, -// whereas XlacompileOnDemandOp works with tensors in device memory. class XlaCompileOnDemandOp : public OpKernel { public: explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index bc07dbd7bdf..ea9e0366043 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -50,10 +50,11 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, - DEVICE_CPU_XLA_JIT, options, name_prefix, - registration, - /*transfer_as_literal=*/false, &device)); + TF_RETURN_IF_ERROR( + XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 12f471735f6..f13b46c532e 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -48,10 +48,9 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/ptr_util.h" #include "tensorflow/core/util/stream_executor_util.h" -namespace se = ::perftools::gputools; - namespace tensorflow { // Caches a XlaDeviceAllocator per pair. 
A @@ -111,7 +110,9 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( const string& jit_device_name, const SessionOptions& options, const string& name_prefix, const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, std::unique_ptr* device) { + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; @@ -121,7 +122,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( auto platform = se::MultiPlatformManager::PlatformWithName(platform_name); if (!platform.ok()) { - return StreamExecutorUtil::ConvertStatus(platform.status()); + return platform.status(); } const DeviceAttributes attrs = Device::BuildDeviceAttributes( @@ -130,17 +131,19 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), strings::StrCat("device: ", device_name, " device")); - device->reset(new XlaDevice(options, attrs, device_ordinal, - DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal)); + device->reset(new XlaDevice( + options, attrs, device_ordinal, DeviceType(jit_device_name), + platform.ValueOrDie(), transfer_as_literal, shape_representation_fn)); return Status::OK(); } -XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type) +XlaDevice::Metadata::Metadata( + int device_ordinal, se::Platform* platform, const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) : device_ordinal_(device_ordinal), device_type_(device_type), - platform_(platform) {} + platform_(platform), + shape_representation_fn_(std::move(shape_representation_fn)) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -171,19 +174,28 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return Status::OK(); } -XlaDevice::XlaDevice(const SessionOptions& options, - const DeviceAttributes& attrs, int device_ordinal, - const DeviceType& jit_device_name, se::Platform* platform, - bool transfer_as_literal) +XlaDevice::XlaDevice( + const SessionOptions& options, const DeviceAttributes& attrs, + int device_ordinal, const DeviceType& jit_device_name, + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn) : LocalDevice(options, attrs), - xla_metadata_(device_ordinal, platform, jit_device_name), + xla_metadata_(device_ordinal, platform, jit_device_name, + shape_representation_fn), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), platform_(platform), - transfer_as_literal_(transfer_as_literal) {} + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(shape_representation_fn) { + VLOG(1) << "Created XLA device " << jit_device_name; +} -XlaDevice::~XlaDevice() {} +XlaDevice::~XlaDevice() { + if (gpu_device_info_ != nullptr) { + gpu_device_info_->default_context->Unref(); + } +} xla::LocalClient* XlaDevice::client() const { // We lazily create the client because the platform commits to the @@ -191,9 +203,8 @@ xla::LocalClient* XlaDevice::client() const { // don't want to do it until we get a chance to hook the platform up // to a simulator. - // For now GetOrCreateLocalClient always returns success when passed - // a non-null platform. 
If that changes we may have to plumb in some - // way to pass Status back. + // TODO(b/78468222): This can fail, at least when the backend is GPU and + // there is no GPU on the host. return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie(); } @@ -218,15 +229,33 @@ xla::StatusOr XlaDevice::GetStream() { return stream_.get(); } +Status XlaDevice::CreateAndSetGpuDeviceInfo() { + if (gpu_device_info_ == nullptr) { + TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); + // Call GetAllocator for the side-effect of ensuring the allocator + // is created. + GetAllocator({}); + // XlaDevice owns both gpu_device_info_ and + // gpu_device_info_->default_context. + gpu_device_info_ = MakeUnique(); + gpu_device_info_->stream = stream; + gpu_device_info_->default_context = new XlaDeviceContext( + stream, client(), transfer_as_literal_, shape_representation_fn_); + set_tensorflow_gpu_device_info(gpu_device_info_.get()); + } + + return Status::OK(); +} + Status XlaDevice::FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) { VLOG(1) << "XlaDevice::FillContextMap"; device_context_map->resize(graph->num_node_ids()); TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - // Call GetAllocator for the side-effect of ensuring the allocator and - // XlaTensorInfoManager is created. - (void)GetAllocator({}); - auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_); + // Call GetAllocator for the side-effect of ensuring the allocator is created. + GetAllocator({}); + auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_, + shape_representation_fn_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); ctx->Ref(); @@ -239,11 +268,10 @@ Status XlaDevice::FillContextMap(const Graph* graph, void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); - // When TraceMe profiling is off (which is the default), the - // following TraceMe constructor is simply a conditional test of - // false value. Measurements show that its overhead is negligible. - port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), - op_kernel->IsExpensive()); + // When Xprof profiling is off (which is the default), constructing the + // activity is simple enough that its overhead is negligible. 
+ tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), + op_kernel->IsExpensive()); op_kernel->Compute(context); } @@ -251,8 +279,8 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); - port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), - op_kernel->IsExpensive()); + tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(), + op_kernel->IsExpensive()); op_kernel->ComputeAsync(context, done); } @@ -274,7 +302,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Notification n; TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream, client(), transfer_as_literal_); + XlaTransferManager manager(stream, client(), transfer_as_literal_, + shape_representation_fn_); manager.CopyCPUTensorToDevice(&parsed, this, ©, [&n, &status](const Status& s) { status = s; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 4fe7dd8c9fa..d5d345d43b1 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -17,8 +17,7 @@ limitations under the License. // runtime. // // Operators assigned to an XlaDevice are compiled into XLA computations. -// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state -// is managed by XLA. +// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers. // // XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU), // under different names (e.g., XLA_CPU or XLA_GPU). @@ -27,6 +26,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -49,20 +49,25 @@ class XlaDevice : public LocalDevice { // retrieved e.g., when lazily creating the XlaCompilationCache device. class Metadata { public: - Metadata(int device_ordinal, perftools::gputools::Platform* platform, - const DeviceType& device_type); + Metadata(int device_ordinal, se::Platform* platform, + const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); // The index of the device on this host. int device_ordinal() const; - perftools::gputools::Platform* platform() const; + se::Platform* platform() const; xla::LocalClient* client() const; const DeviceType& jit_device_type() const; + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { + return shape_representation_fn_; + } private: const int device_ordinal_; const DeviceType device_type_; - perftools::gputools::Platform* platform_; // Not owned. + se::Platform* platform_; // Not owned. + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -76,17 +81,19 @@ class XlaDevice : public LocalDevice { // 'transfer_as_literal' is true if device<->host transfers must be done using // XLA's TransferLiteral{To,From}Device interface. If false, we can use // ThenMemcpy instead. 
- static Status Create(const string& platform_name, const string& device_name, - int device_ordinal, const string& jit_device_name, - const SessionOptions& options, const string& name_prefix, - const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, - std::unique_ptr* device); + static Status Create( + const string& platform_name, const string& device_name, + int device_ordinal, const string& jit_device_name, + const SessionOptions& options, const string& name_prefix, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + std::unique_ptr* device); XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - ::perftools::gputools::Platform* platform, - bool transfer_as_literal); + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -103,7 +110,11 @@ class XlaDevice : public LocalDevice { Tensor* tensor) override; xla::LocalClient* client() const; - xla::StatusOr<::perftools::gputools::Stream*> GetStream(); + xla::StatusOr GetStream(); + + // If not already set, create and set GpuDeviceInfo. + // Not thread-safe + Status CreateAndSetGpuDeviceInfo(); private: // The metadata of this XlaDevice. @@ -113,8 +124,8 @@ class XlaDevice : public LocalDevice { // The name of the device that is used to compile Ops for this XlaDevice. DeviceType jit_device_name_; // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - ::perftools::gputools::Platform* platform_; // Not owned. + Allocator* xla_allocator_; // Not owned. + se::Platform* platform_; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and @@ -123,6 +134,11 @@ class XlaDevice : public LocalDevice { // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. bool transfer_as_literal_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // If set, holds default device context (that we must Unref) + // and its stream. + std::unique_ptr gpu_device_info_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 43eb1640126..ff30b62bad7 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -23,8 +23,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/platform/mem.h" -namespace se = ::perftools::gputools; - namespace tensorflow { // The allocator used for Tensors assigned to the XLA device. 
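A minimal sketch, for illustration only, of how the shape_representation_fn hook introduced in these hunks could be supplied. It assumes only the call shape visible above: the callable receives a host TensorShape and a DataType and returns the TensorShape to use for the on-device buffer, and factories that need no canonical representation simply pass an empty function. The lambda name below is hypothetical, not taken from the patch.

    // Hypothetical identity representation: keep the host shape unchanged.
    // Mirrors the call site shape_representation_fn_(device_tensor->shape(),
    // device_tensor->dtype()) shown in CopyCPUTensorToDevice below.
    auto identity_representation_fn = [](const TensorShape& shape,
                                         DataType dtype) {
      (void)dtype;  // available for backends that pick layouts per dtype
      return shape;
    };
    // A backend that requires a canonical on-device layout would instead
    // return a transformed TensorShape here; CopyCPUTensorToDevice then uses
    // it both to allocate the ShapedBuffer and to reshape the host tensor
    // before the literal transfer.
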
@@ -49,13 +47,14 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager(se::Stream* stream, - xla::LocalClient* client, - bool transfer_as_literal) +XlaTransferManager::XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) : stream_(stream), client_(client), transfer_manager_(client->backend().transfer_manager()), - transfer_as_literal_(transfer_as_literal) {} + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(std::move(shape_representation_fn)) {} Status XlaTransferManager::TransferLiteralToDevice( const Tensor& host_tensor, Tensor* device_tensor) const { @@ -78,7 +77,15 @@ Status XlaTransferManager::TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice( stream_->parent(), shaped_buffer)); VLOG(1) << "Transfer from device as literal: " << literal->ToString(); - return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor); + Tensor tensor; + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); + // Reshape the tensor back to its declared shape. + if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { + return errors::Internal( + "Tensor::CopyFrom failed when copying from XLA device to CPU"); + } + return Status::OK(); } void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, @@ -98,9 +105,17 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); CHECK(xla_tensor); + + TensorShape shape; + if (shape_representation_fn_) { + shape = shape_representation_fn_(device_tensor->shape(), + device_tensor->dtype()); + } else { + shape = device_tensor->shape(); + } if (!xla_tensor->has_shaped_buffer()) { Status s = xla_tensor->AllocateShapedBuffer( - device_tensor->dtype(), device_tensor->shape(), client_, + device_tensor->dtype(), shape, client_, stream_->parent()->device_ordinal()); if (!s.ok()) { done(s); @@ -108,12 +123,18 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } } - se::DeviceMemoryBase dev_dst_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); Status status; if (transfer_as_literal_) { - status = TransferLiteralToDevice(*cpu_tensor, device_tensor); + Tensor reshaped_cpu_tensor; + if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { + done(errors::Internal( + "Tensor::CopyFrom failed when copying from CPU to XLA device")); + return; + } + status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); } else { + se::DeviceMemoryBase dev_dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. 
Status block_status = stream_->BlockHostUntilDone(); @@ -173,9 +194,11 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } -XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal) - : manager_(stream, client, transfer_as_literal) {} +XlaDeviceContext::XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) + : manager_(stream, client, transfer_as_literal, + std::move(shape_representation_fn)) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index ad914a1c23b..9af96558684 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/allocator.h" @@ -45,16 +46,16 @@ class XlaDeviceAllocator : public Allocator { // Helper class for managing data transfers between host and XLA devices. class XlaTransferManager { public: - explicit XlaTransferManager(perftools::gputools::Stream* stream, - xla::LocalClient* client, - bool transfer_as_literal); + explicit XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done); - perftools::gputools::Stream* stream() const { return stream_; } + se::Stream* stream() const { return stream_; } private: Status TransferLiteralToDevice(const Tensor& host_tensor, @@ -64,13 +65,14 @@ class XlaTransferManager { // Stream obtained from a Device, used to transfer tensors between // CPU and device. - perftools::gputools::Stream* stream_; + se::Stream* stream_; // For the underlying memory allocator and XLA's TransferManager. xla::LocalClient* client_; // Transfer manager, for marshalling data to and from the device. xla::TransferManager* transfer_manager_; // True if we must use XLA's TransferManager for correct device transfers. - bool transfer_as_literal_; + const bool transfer_as_literal_; + const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -78,8 +80,9 @@ class XlaTransferManager { // wraps the methods in XlaTransferManager. 
class XlaDeviceContext : public DeviceContext { public: - explicit XlaDeviceContext(perftools::gputools::Stream* stream, - xla::LocalClient* client, bool transfer_as_literal); + explicit XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, @@ -87,9 +90,7 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; - perftools::gputools::Stream* stream() const override { - return manager_.stream(); - } + se::Stream* stream() const override { return manager_.stream(); } private: XlaTransferManager manager_; diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 498d25cf566..9c00a0682cc 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" +#include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/kernels/sendrecv_ops.h" @@ -32,7 +33,7 @@ namespace tensorflow { // Dummy OpKernel, used for kernels assigned to an XLA device that should be // compiled. Should never be called at runtime since such ops should be -// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an +// rewritten to a XlaLaunch op. If it is called, it means the placer placed an // operator on an XLA device but the compiler did not compile it. class XlaDeviceDummyOp : public OpKernel { public: @@ -41,7 +42,7 @@ class XlaDeviceDummyOp : public OpKernel { }; #define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \ - REGISTER_KERNEL_BUILDER(Name("_XlaLaunch") \ + REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \ .Device(DEVICE) \ .HostMemory("constants") \ .HostMemory("resources"), \ @@ -63,6 +64,9 @@ class XlaDeviceDummyOp : public OpKernel { ConstantOp); \ REGISTER_KERNEL_BUILDER( \ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("IdentityN").Device(DEVICE).TypeConstraint("T", TYPES), \ + IdentityNOp); \ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), PlaceholderOp); \ REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \ PlaceholderOp); \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index ac60423d959..26842fbe5cc 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -48,12 +48,22 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, Status status = XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, registration, - /*transfer_as_literal=*/false, &device); + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. 
VLOG(1) << "Failed to create XLA_GPU device: " << status; return Status::OK(); } + + // TODO(b/78468222): Uncomment after fixing this bug + // status = device->CreateAndSetGpuDeviceInfo(); + // if (!status.ok()) { + // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT, + // " device"); + // return status; + // } + devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 50b0061d692..d0c7a936512 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -32,18 +32,19 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/stream_executor_util.h" -namespace gpu = perftools::gputools; - namespace tensorflow { +namespace { +using xla::ScopedShapedBuffer; +using xla::ShapedBuffer; +} // anonymous namespace -std::map SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables) { +std::map SnapshotResourceVariables( + OpKernelContext* ctx, const std::vector& variables) { std::map snapshot; - int first_variable = ctx->num_inputs() - num_variables; - for (int i = 0; i < num_variables; ++i) { + for (int i : variables) { Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, first_variable + i); - OptionalTensor& tensor = snapshot[first_variable + i]; + ResourceHandle handle = HandleFromInput(ctx, i); + OptionalTensor& tensor = snapshot[i]; if (LookupResource(ctx, handle, &variable).ok()) { tf_shared_lock lock(*variable->mu()); tensor.name = handle.name(); @@ -54,74 +55,78 @@ std::map SnapshotResourceVariables(OpKernelContext* ctx, return snapshot; } -XlaAllocator::XlaAllocator(const gpu::Platform* platform, Allocator* wrapped) +XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped) : xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {} XlaAllocator::~XlaAllocator() {} -xla::StatusOr XlaAllocator::Allocate( +xla::StatusOr XlaAllocator::Allocate( int device_ordinal, uint64 size, bool retry_on_failure) { - void* data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size); + AllocationAttributes attrs; + attrs.no_retry_on_failure = !retry_on_failure; + void* data = + wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size, attrs); if (data == nullptr) { return errors::ResourceExhausted("Out of memory while trying to allocate ", size, " bytes."); - } else { - return gpu::DeviceMemoryBase(data, size); } + return xla::OwningDeviceMemory(se::DeviceMemoryBase(data, size), + device_ordinal, this); } -Status XlaAllocator::Deallocate(int device_ordinal, - gpu::DeviceMemoryBase* mem) { - wrapped_->DeallocateRaw(mem->opaque()); +Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { + wrapped_->DeallocateRaw(mem.opaque()); return Status::OK(); } -namespace { +namespace internal { // Return the 'index''th subtree of the given ShapedBuffer as a // ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the // subtree, and sets the input's buffer pointers to nullptr for the subtree. 
-std::unique_ptr ExtractSubShapedBuffer( - xla::ShapedBuffer* shaped_buffer, int index, +ScopedShapedBuffer ExtractSubShapedBuffer( + ShapedBuffer* shaped_buffer, int index, xla::DeviceMemoryAllocator* allocator) { - xla::Shape on_host_shape = xla::ShapeUtil::GetTupleElementShape( + const xla::Shape& on_host_shape = xla::ShapeUtil::GetTupleElementShape( shaped_buffer->on_host_shape(), index); - xla::Shape on_device_shape = xla::ShapeUtil::GetTupleElementShape( + const xla::Shape& on_device_shape = xla::ShapeUtil::GetTupleElementShape( shaped_buffer->on_device_shape(), index); - xla::ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, - shaped_buffer->platform(), - shaped_buffer->device_ordinal()); + ShapedBuffer sub_shaped_buffer(on_host_shape, on_device_shape, + shaped_buffer->platform(), + shaped_buffer->device_ordinal()); auto& shape_tree = shaped_buffer->buffers(); auto& sub_shape_tree = sub_shaped_buffer.buffers(); sub_shape_tree.CopySubtreeFrom(shape_tree, /*source_base_index=*/{index}, /*target_base_index=*/{}); - for (auto& index_to_buffer : shape_tree) { - if (!index_to_buffer.first.empty() && index_to_buffer.first[0] == index) { - index_to_buffer.second = gpu::DeviceMemoryBase(nullptr, 0); - } - } - return xla::ScopedShapedBuffer::MakeScoped(&sub_shaped_buffer, allocator) - .ValueOrDie(); + shape_tree.ForEachMutableElement( + [index](const xla::ShapeIndex& shape_index, + tensorflow::se::DeviceMemoryBase* data) { + // shape_index is empty for the root node. Ignore that. + if (!shape_index.empty() && shape_index[0] == index) { + *data = tensorflow::se::DeviceMemoryBase(nullptr, 0); + } + }); + return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator); } -} // namespace +} // namespace internal +using internal::ExtractSubShapedBuffer; XlaComputationLaunchContext::XlaComputationLaunchContext( - int64 num_resource_args, xla::LocalClient* client, - xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors) - : num_resource_args_(num_resource_args), - client_(client), + xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, + bool allocate_xla_tensors) + : client_(client), xla_allocator_(xla_allocator), allocate_xla_tensors_(allocate_xla_tensors) {} void XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, const std::map& variables) { - // Build xla::ShapedBuffers that point directly to the Tensor buffers. + // Build ShapedBuffers that point directly to the Tensor buffers. arg_buffers_.reserve(kernel->xla_input_shapes.size() + 1); arg_buffers_.resize(kernel->xla_input_shapes.size()); - arg_ptrs_ = std::vector(arg_buffers_.size()); + arg_ptrs_ = std::vector(arg_buffers_.size()); // Pass remaining parameters. 
const Tensor* t; @@ -140,16 +145,15 @@ void XlaComputationLaunchContext::PopulateInputs( if (xla::ShapeUtil::IsTuple(on_device_shape)) { const XlaTensor* xla_tensor = XlaTensor::FromTensor(t); CHECK(xla_tensor && xla_tensor->has_shaped_buffer()); - arg_ptrs_[i] = - const_cast(&xla_tensor->shaped_buffer()); + arg_ptrs_[i] = const_cast(&xla_tensor->shaped_buffer()); } else { CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) << "On-device shape " << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) << " not the same as on-host shape " << xla::ShapeUtil::HumanStringWithLayout(shape); - gpu::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); - arg_buffers_[i] = xla::MakeUnique( + se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); + arg_buffers_[i] = xla::MakeUnique( /*on_host_shape=*/shape, /*on_device_shape=*/shape, client_->platform(), client_->default_device_ordinal()); arg_buffers_[i]->set_buffer(dmem, /*index=*/{}); @@ -160,15 +164,15 @@ void XlaComputationLaunchContext::PopulateInputs( void XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - std::unique_ptr output) { - gpu::Stream* stream = + ScopedShapedBuffer output) { + se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; // Computation output should always be a tuple. if (VLOG_IS_ON(2)) { - VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString(); + VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString(); VLOG(2) << "Result tuple shape (on device): " - << output->on_device_shape().DebugString(); + << output.on_device_shape().DebugString(); } CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); @@ -191,11 +195,6 @@ void XlaComputationLaunchContext::PopulateOutputs( OP_REQUIRES_OK( ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); - if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { - OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer( - const_tensor.dtype(), const_tensor.shape(), - client_, stream->parent()->device_ordinal())); - } Device* device = dynamic_cast(ctx->device()); OP_REQUIRES(ctx, device != nullptr, @@ -226,18 +225,18 @@ void XlaComputationLaunchContext::PopulateOutputs( const TensorShape& shape = kernel->outputs[i].shape; VLOG(2) << "Retval " << i << " shape " << shape.DebugString(); - gpu::DeviceMemoryBase buffer = output->buffer({output_num}); + se::DeviceMemoryBase buffer = output.buffer({output_num}); if (allocate_xla_tensors_) { Tensor* output_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); CHECK(xla_tensor); - xla_tensor->set_shaped_buffer( - ExtractSubShapedBuffer(output.get(), output_num, xla_allocator_)); + xla_tensor->set_shaped_buffer(ScopedShapedBuffer( + ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( ctx->expected_output_dtype(i), shape, buffer, allocator); - output->set_buffer(gpu::DeviceMemoryBase(nullptr, 0), {output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); ctx->set_output(i, output_tensor); } ++output_num; @@ -257,7 +256,7 @@ void XlaComputationLaunchContext::PopulateOutputs( write.input_index >= 0 && write.input_index < ctx->num_inputs(), errors::Internal("Invalid input index for variable write.")); - gpu::DeviceMemoryBase buffer = output->buffer({output_num}); + se::DeviceMemoryBase buffer = 
output.buffer({output_num}); Var* variable = nullptr; // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, @@ -282,12 +281,12 @@ void XlaComputationLaunchContext::PopulateOutputs( XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); CHECK(xla_tensor); xla_tensor->set_shaped_buffer( - ExtractSubShapedBuffer(output.get(), output_num, xla_allocator_)); + ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); *variable->tensor() = output_tensor; } else { Tensor output_tensor = XlaTensorBuffer::MakeTensor( write.type, write.shape, buffer, allocator); - output->set_buffer(gpu::DeviceMemoryBase(nullptr, 0), {output_num}); + output.set_buffer(xla::OwningDeviceMemory(), {output_num}); *variable->tensor() = output_tensor; } ++output_num; diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 14f70fe3589..4390701ccbd 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -22,6 +22,8 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" @@ -31,28 +33,28 @@ limitations under the License. namespace tensorflow { class XlaAllocator; -// Takes a snapshot of the values of resource variable arguments, which are -// the last `num_variables` arguments. We snapshot tensors that back +// Takes a snapshot of the values of resource variable arguments, whose +// indices are specified in `variables` argument. We snapshot tensors that back // resource variables since concurrent updates may modify the shape, and it is // important that the shapes used for compilation match the true shapes of the // buffers. // -// Returns a map of TensorFlow argument index to resource variable. -std::map SnapshotResourceVariables(OpKernelContext* ctx, - int num_variables); +// Returns a map of TensorFlow argument index to resource variable. If a +// resource variable is not initialized, the corresponding OptionalTensor +// will have its `present` field set to false. +std::map SnapshotResourceVariables( + OpKernelContext* ctx, const std::vector& variables); // Adapter class that wraps a Tensorflow allocator as an XLA allocator. // Assumes that the Tensorflow allocator permits asynchronous deallocation: // see comment on `AllowsAsynchronousDeallocation()`. class XlaAllocator : public xla::DeviceMemoryAllocator { public: - XlaAllocator(const perftools::gputools::Platform* platform, - Allocator* wrapped); + XlaAllocator(const se::Platform* platform, Allocator* wrapped); ~XlaAllocator() override; - xla::StatusOr Allocate( + xla::StatusOr Allocate( int device_ordinal, uint64 size, bool retry_on_failure) override; - Status Deallocate(int device_ordinal, - perftools::gputools::DeviceMemoryBase* mem) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // The Tensorflow BFC allocator used on GPU allows host-side deallocation // before GPU execution takes place. Tensorflow uses the ordering of the main @@ -74,7 +76,7 @@ class XlaComputationLaunchContext { // Create a new launch context. 
'allocate_xla_tensors' is true if allocated // output tensors and variables are always XlaTensors. If false they are // assumed to be "normal" device pointers. - XlaComputationLaunchContext(int64 num_resource_args, xla::LocalClient* client, + XlaComputationLaunchContext(xla::LocalClient* client, xla::DeviceMemoryAllocator* xla_allocator, bool allocate_xla_tensors); @@ -87,14 +89,13 @@ class XlaComputationLaunchContext { // Given the XLA output in `output`, populate all outputs of `ctx`. void PopulateOutputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - std::unique_ptr output); + xla::ScopedShapedBuffer output); // Return the argument list. Only valid after PopulateInputs() has been // called. const std::vector& arguments() const { return arg_ptrs_; } private: - int64 num_resource_args_; xla::LocalClient* client_; xla::DeviceMemoryAllocator* xla_allocator_; bool allocate_xla_tensors_; @@ -126,8 +127,7 @@ class XlaTensorBuffer : public TensorBuffer { } static Tensor MakeTensor(DataType dtype, const TensorShape& shape, - perftools::gputools::DeviceMemoryBase buffer, - Allocator* allocator) { + se::DeviceMemoryBase buffer, Allocator* allocator) { size_t expected_size = shape.num_elements() * DataTypeSize(dtype); auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size, buffer.size(), allocator); @@ -143,6 +143,17 @@ class XlaTensorBuffer : public TensorBuffer { Allocator* allocator_; }; +// Exposed in this header file for microbenchmarking purposes, but this is an +// internal implementation detail. +namespace internal { +// Return the 'index''th subtree of the given ShapedBuffer as a +// ScopedShapedBuffer. The returned ScopedShapedBuffer takes ownership of the +// subtree, and sets the input's buffer pointers to nullptr for the subtree. +xla::ScopedShapedBuffer ExtractSubShapedBuffer( + xla::ShapedBuffer* shaped_buffer, int index, + xla::DeviceMemoryAllocator* allocator); +} // namespace internal + } // namespace tensorflow #endif diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc new file mode 100644 index 00000000000..a45932403ec --- /dev/null +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains microbenchmarks for performance critical functions in +// xla_launch_util.cc. + +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +// Test ExtractSubBuffer with different depths (depth of ShapeTree) and fan-outs +// (cardinality of each non-leaf node's children). 
+void BM_ExtractSubBuffer(int iters, int depth, int fan_out) { + tensorflow::testing::StopTiming(); + xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128}); + for (int i = 0; i < depth; ++i) { + std::vector shapes(fan_out, shape); + shape = xla::ShapeUtil::MakeTupleShape(shapes); + } + xla::ShapedBuffer shaped_buffer(shape, shape, /*platform=*/nullptr, + /*device_ordinal=*/0); + tensorflow::testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + // Extract a buffer from approximately the middle of the first level of the + // tree. + (void)tensorflow::internal::ExtractSubShapedBuffer(&shaped_buffer, + /*index=*/fan_out / 2, + /*allocator=*/nullptr) + .release(); + } +} + +BENCHMARK(BM_ExtractSubBuffer) + ->ArgPair(1, 4) + ->ArgPair(1, 8) + ->ArgPair(1, 32) + ->ArgPair(1, 64) + ->ArgPair(1, 128) + ->ArgPair(1, 256) + ->ArgPair(1, 512) + ->ArgPair(2, 4) + ->ArgPair(2, 8) + ->ArgPair(2, 32) + ->ArgPair(2, 64) + ->ArgPair(2, 128); + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + tensorflow::testing::RunBenchmarks(); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index 956328e6757..a7211c9c7e2 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -31,16 +31,15 @@ namespace tensorflow { return FromTensor(const_cast(tensor)); } -/*static*/ perftools::gputools::DeviceMemoryBase -XlaTensor::DeviceMemoryFromTensor(const Tensor& tensor) { +/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor( + const Tensor& tensor) { const XlaTensor* xla_tensor = FromTensor(&tensor); if (xla_tensor) { CHECK(xla_tensor->has_shaped_buffer()); return xla_tensor->shaped_buffer().root_buffer(); } else { - return perftools::gputools::DeviceMemoryBase( - const_cast(tensor.tensor_data().data()), - tensor.tensor_data().size()); + return se::DeviceMemoryBase(const_cast(tensor.tensor_data().data()), + tensor.tensor_data().size()); } } @@ -53,22 +52,22 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, client->backend().transfer_manager()->HostShapeToDeviceShape( on_host_shape); - xla::ShapedBuffer buffer(on_host_shape, on_device_shape, client->platform(), - device_ordinal); - for (auto& index_to_buffer : buffer.buffers()) { + xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, + client->backend().memory_allocator(), + device_ordinal); + for (auto& index_to_buffer : shaped_buffer.buffers()) { xla::Shape subshape = xla::ShapeUtil::GetSubshape(on_device_shape, index_to_buffer.first); uint64 size = client->backend().transfer_manager()->GetByteSizeRequirement(subshape); - TF_ASSIGN_OR_RETURN(index_to_buffer.second, + TF_ASSIGN_OR_RETURN(xla::OwningDeviceMemory buffer, client->backend().memory_allocator()->Allocate( device_ordinal, size, /*retry_on_failure=*/false)); + // Move our buffer into shaped_buffer, which takes ownership of it. 
+ index_to_buffer.second = buffer.Forget(); } - TF_ASSIGN_OR_RETURN(auto scoped_buffer, - xla::ScopedShapedBuffer::MakeScoped( - &buffer, client->backend().memory_allocator())); - set_shaped_buffer(std::move(scoped_buffer)); + set_shaped_buffer(std::move(shaped_buffer)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 5ff2fb08f03..6b29c82ec11 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -43,8 +43,7 @@ class XlaTensor { // which case the returned value is shaped_buffer()->root_buffer(), or a // normal Tensor in which case the returned value is // {tensor.tensor_data().data(), tensor.tensor_data().size}. - static perftools::gputools::DeviceMemoryBase DeviceMemoryFromTensor( - const Tensor& tensor); + static se::DeviceMemoryBase DeviceMemoryFromTensor(const Tensor& tensor); // Assign the internal ShapedBuffer to new memory for the given dtype and // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it @@ -55,7 +54,7 @@ class XlaTensor { // Some Tensors can have complex on-device shapes, including tuple shapes. To // manage the memory for these tensors a ShapedBuffer may be required. - // Return true if this TensorInfo contains a ShapedBuffer. + // Return true if this XlaTensor contains a ShapedBuffer. bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; } // Return the contained ShapedBuffer. // REQUIRES: has_shaped_buffer() @@ -63,17 +62,17 @@ class XlaTensor { CHECK(has_shaped_buffer()); return *shaped_buffer_; } - // Mutates the TensorInfo to set the ShapedBuffer. - void set_shaped_buffer( - std::unique_ptr shaped_buffer) { - shaped_buffer_ = std::move(shaped_buffer); + // Mutates the XlaTensor to set the ShapedBuffer. + void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { + shaped_buffer_ = + xla::MakeUnique(std::move(shaped_buffer)); } // Some tensors on the device may have known values on the host. We use these // in on-demand mode to avoid re-copying values from the device if we know the // host value already. - // Return true if this TensorInfo contains a host tensor. + // Return true if this XlaTensor contains a host tensor. bool has_host_tensor() const { return host_tensor_ != nullptr; } // Return the contained host tensor. 
// REQUIRES: has_host_tensor() diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index b9e42ca677c..2a88743c805 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -42,7 +42,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform", "//tensorflow/python:random_seed", "//tensorflow/python:session", @@ -58,7 +58,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -72,7 +72,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -93,7 +93,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -111,7 +111,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:bitwise_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -127,7 +127,7 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -141,7 +141,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -156,7 +156,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -170,7 +170,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -184,7 +184,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -209,7 +209,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -225,7 +225,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + 
"//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -241,7 +241,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -263,7 +263,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -291,7 +291,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -300,14 +300,41 @@ tf_xla_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], + tags = [ + "manual", + "notap", + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) +tf_xla_py_test( + name = "eager_test", + size = "small", + srcs = ["eager_test.py"], + disabled_backends = [ + # TODO(b/78199195) Support XLA CPU devices in eager runtime + "cpu", + "cpu_ondemand", + # TODO(b/78468222) Enable GPU backend + "gpu", + ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:platform_test", + "//tensorflow/python/eager:function", + ], +) + tf_xla_py_test( name = "fft_test", size = "medium", @@ -319,7 +346,7 @@ tf_xla_py_test( "//tensorflow/contrib/signal:signal_py", "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:spectral_ops", ], @@ -333,19 +360,19 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) tf_xla_py_test( name = "ftrl_test", - size = "small", + size = "medium", srcs = ["ftrl_test.py"], deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -361,7 +388,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -376,12 +403,27 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:image_ops", "//tensorflow/python:platform_test", ], ) +tf_xla_py_test( + name = "listdiff_op_test", + size = "small", + srcs = ["listdiff_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + 
"//tensorflow/python:framework_ops", + "//tensorflow/python:platform_test", + "@six_archive//:six", + ], +) + tf_xla_py_test( name = "lrn_ops_test", size = "medium", @@ -389,7 +431,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -404,7 +446,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -416,7 +458,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -430,7 +472,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -443,7 +485,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -456,7 +498,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -471,7 +513,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -488,7 +530,7 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -503,7 +545,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -519,7 +561,7 @@ tf_xla_py_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -532,7 +574,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", ], ) @@ -544,7 +586,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -556,7 +598,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -571,7 +613,7 @@ 
tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -584,7 +626,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:platform_test", @@ -599,7 +641,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -615,7 +657,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -628,7 +670,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/contrib/stateless", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -642,7 +684,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -661,7 +703,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -674,7 +716,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -688,7 +730,7 @@ tf_xla_py_test( srcs = ["fused_batchnorm_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn", @@ -707,7 +749,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -726,7 +768,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:training", ], @@ -736,11 +778,12 @@ tf_xla_py_test( name = "gather_test", size = "medium", srcs = ["gather_test.py"], + tags = ["noasan"], # times out, http://b/78599043 deps = [ ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -752,7 +795,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -765,21 
+808,34 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "xla_device_test", + size = "small", + srcs = ["xla_device_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) cuda_py_test( - name = "xla_device_test", + name = "xla_device_gpu_test", size = "small", - srcs = ["xla_device_test.py"], + srcs = ["xla_device_gpu_test.py"], additional_deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", ], ) @@ -796,15 +852,23 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", ], - # TODO(b/62961789): Test fails with SIGABRT - tags = [ - "manual", - "notap", +) + +cuda_py_test( + name = "dense_layer_test", + size = "small", + srcs = ["dense_layer_test.py"], + additional_deps = [ + "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:layers", + "//tensorflow/python:variables", ], ) @@ -847,7 +911,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:variables", @@ -862,7 +926,7 @@ cuda_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -900,7 +964,19 @@ tf_xla_py_test( srcs = ["fake_quant_ops_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "placeholder_test", + size = "small", + srcs = ["placeholder_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index ec547e16cd9..9d3a889b1f5 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -29,51 +29,70 @@ from tensorflow.python.platform import test class ArgMinMaxTest(xla_test.XLATestCase): - def _assertOpOutputMatchesExpected(self, op, inp, expected): - """Verifies that 'op' produces 'expected' when fed input 'inp' . + def _assertOpOutputMatchesExpected(self, op, axis, output_type, op_input, + expected): + """Verifies that 'op' produces 'expected' when fed input 'op_input' . Args: - op: operator to test - inp: numpy input array to use as input to 'op'. + op: argmin or argmax operator to test. + axis: integer axis to reduce across. 
+ output_type: numpy datatype of the output to produce. + op_input: numpy input array to use as input to 'op'. expected: numpy array representing the expected output of 'op'. """ with self.test_session() as session: with self.test_scope(): pinp = array_ops.placeholder( - dtypes.as_dtype(inp.dtype), inp.shape, name="a") - output = op(pinp) - result = session.run(output, {pinp: inp}) + dtypes.as_dtype(op_input.dtype), op_input.shape, name="a") + output = op(pinp, axis=axis, output_type=output_type) + result = session.run(output, {pinp: op_input}) self.assertAllEqual(result, expected) def testArgMinMax(self): # Complex numbers do not support argmin/argmax. minmax_types = set(self.numeric_types) - set(self.complex_types) for dtype in minmax_types: - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), - np.array([1, 10, 27, 3, 3, 4], dtype=dtype), - expected=np.int32(2)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), - np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), - expected=np.array([0, 1, 0], dtype=np.int32)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmax(x, axis=1, output_type=dtypes.int32), - np.array([[4, 1], [3, 2]], dtype=dtype), - expected=np.array([0, 0], dtype=np.int32)) + # output_type is a numpy data type that is used to specify the desired + # output type of the op as well as to convert the Python number to the + # array scalar of the type. + for output_type in self.int_types: + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=0, + output_type=output_type, + op_input=np.array([1, 10, 27, 3, 3, 4], dtype=dtype), + expected=output_type(2)) + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=0, + output_type=output_type, + op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([0, 1, 0], dtype=output_type)) + self._assertOpOutputMatchesExpected( + math_ops.argmax, + axis=1, + output_type=output_type, + op_input=np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([0, 0], dtype=output_type)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), - np.array([3, 10, 27, 3, 2, 4], dtype=dtype), - expected=np.int32(4)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), - np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), - expected=np.array([1, 0, 1], dtype=np.int32)) - self._assertOpOutputMatchesExpected( - lambda x: math_ops.argmin(x, axis=1, output_type=dtypes.int32), - np.array([[4, 1], [3, 2]], dtype=dtype), - expected=np.array([1, 1], dtype=np.int32)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=0, + output_type=output_type, + op_input=np.array([3, 10, 27, 3, 2, 4], dtype=dtype), + expected=output_type(4)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=0, + output_type=output_type, + op_input=np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([1, 0, 1], dtype=output_type)) + self._assertOpOutputMatchesExpected( + math_ops.argmin, + axis=1, + output_type=output_type, + op_input=np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([1, 1], dtype=output_type)) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py new file mode 100644 index 00000000000..865f60ccab4 --- /dev/null +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -0,0 +1,135 @@ +# Copyright 2018 The 
TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for DenseLayer JIT compilation on the CPU and GPU devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.contrib.compiler import jit +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.layers import layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + +jit_scope = jit.experimental_jit_scope + + +def GetRunMetadataLabels(run_metadata): + """Returns all labels in run_metadata.""" + labels = [] + for dev_stats in run_metadata.step_stats.dev_stats: + for node_stats in dev_stats.node_stats: + labels.append(node_stats.timeline_label) + return labels + + +def InLabels(labels, substr): + """Returns true iff one of the labels contains substr.""" + return any([substr in x for x in labels]) + + +def XlaLaunchOpCount(labels): + """Count how many XlaLaunch labels are present.""" + return sum("XlaLaunch(" in x for x in labels) + + +class DenseLayerTest(test.TestCase): + + def testDenseLayerAutoJit(self): + """Tests dense layer compilation in auto-jit mode. + + Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. + """ + + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") + config = config_pb2.ConfigProto() + config.graph_options.optimizer_options.global_jit_level = ( + config_pb2.OptimizerOptions.ON_1) + + with self.test_session(config=config) as sess: + x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(1, XlaLaunchOpCount(labels)) + self.assertFalse(InLabels(labels, "ListDiff")) + + def testDenseLayerJitScopeDefinedShape(self): + """Tests that the dense layer node is properly compiled in jit scope. + + Dense layer with static shape input tensor should be compiled into a single + XlaLaunch op by XLA. 
+ """ + + with self.test_session() as sess: + x = array_ops.placeholder(shape=[2, 2, 3], dtype=np.float32) + with jit_scope(): + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(1, XlaLaunchOpCount(labels)) + # No need to check whether ListDiff is compiled or not because ListDiff op + # is not used when input tensor shape is fully defined. + + def testDenseLayerJitScopeUndefinedShape(self): + """Tests that the dense layer node is properly compiled in jit scope. + + Dense layer uses shape op to get shape of input tensor if its shape is not + fully defined. XLA does not cluster shape op with other operators. But in + experimental_jit_scope, XLA is forced to compile shape op into its own + cluster, causing dense layer to be split into TWO XlaLaunch ops. + """ + + with self.test_session() as sess: + x = array_ops.placeholder(shape=[None, None, 3], dtype=np.float32) + with jit_scope(): + y = layers.dense(x, 3) + + sess.run(variables.initialize_all_variables()) + run_metadata = config_pb2.RunMetadata() + sess.run( + y, {x: np.array([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])}, + run_metadata=run_metadata, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + + labels = GetRunMetadataLabels(run_metadata) + self.assertEqual(2, XlaLaunchOpCount(labels)) + self.assertFalse(InLabels(labels, "ListDiff")) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py new file mode 100644 index 00000000000..52d8d6d295c --- /dev/null +++ b/tensorflow/compiler/tests/eager_test.py @@ -0,0 +1,309 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test cases for eager execution using XLA.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import convolutional +from tensorflow.python.layers import pooling +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import googletest + + +class EagerTest(XLATestCase): + + def testBasic(self): + with self.test_scope(): + three = constant_op.constant(3) + five = constant_op.constant(5) + product = three * five + self.assertAllEqual(15, product) + + def testExecuteListOutputLen0(self): + with self.test_scope(): + empty = constant_op.constant([], dtype=dtypes.float32) + result = array_ops.unstack(empty, 0) + self.assertTrue(isinstance(result, list)) + self.assertEqual(0, len(result)) + + def testExecuteListOutputLen1(self): + with self.test_scope(): + split_dim = constant_op.constant(1) + value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) + result = array_ops.split(value, 1, axis=split_dim) + self.assertTrue(isinstance(result, list)) + self.assertEqual(1, len(result)) + self.assertAllEqual([[0, 1, 2], [3, 4, 5]], result[0]) + + def testExecuteListOutputLen3(self): + with self.test_scope(): + split_dim = constant_op.constant(1) + value = constant_op.constant([[0., 1., 2.], [3., 4., 5.]]) + result = array_ops.split(value, 3, axis=split_dim) + self.assertTrue(isinstance(result, list)) + self.assertEqual(3, len(result)) + self.assertAllEqual([[0], [3]], result[0]) + self.assertAllEqual([[1], [4]], result[1]) + self.assertAllEqual([[2], [5]], result[2]) + + def testBasicGraph(self): + # Run some ops eagerly + with self.test_scope(): + three = constant_op.constant(3) + five = constant_op.constant(5) + product = three * five + self.assertAllEqual(15, product) + + # Run some ops graphly + with context.graph_mode(), self.test_session() as sess: + with self.test_scope(): + three = constant_op.constant(3) + five = constant_op.constant(5) + product = three * five + self.assertAllEqual(15, sess.run(product)) + + def testDegenerateSlices(self): + with self.test_scope(): + npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3) + t = constant_op.constant(npt) + # degenerate by offering a forward interval with a negative stride + self.assertAllEqual(npt[0:-1:-1, :, :], t[0:-1:-1, :, :]) + # degenerate with a reverse interval with a positive stride + self.assertAllEqual(npt[-1:0, :, :], t[-1:0, :, :]) + # empty interval in every dimension + self.assertAllEqual(npt[-1:0, 2:2, 2:3:-1], t[-1:0, 2:2, 2:3:-1]) + + def testIdentity(self): + with self.test_scope(): + self.assertAllEqual(2, array_ops.identity(2)) + + def testIdentityOnVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(True) + i = array_ops.identity(v) + self.assertAllEqual(True, i.numpy()) + + def 
testAssignAddVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + v.assign_add(2.0) + self.assertEqual(3.0, v.numpy()) + + def testGradient(self): + def f(x): + return x + + with self.test_scope(): + grad_fn = backprop.gradients_function(f) + self.assertAllEqual(2., grad_fn(1., dy=2.)[0]) + + def testVariableGradient(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(1.0) + + def f(): + x = v0 * v0 + return x + + grads = backprop.implicit_grad(f)() + self.assertEqual(2., grads[0][0].numpy()) + + +class EagerFunctionTest(XLATestCase): + + def testBasic(self): + with self.test_scope(): + matmul = function.defun(math_ops.matmul, compiled=True) + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + sq = matmul(t, t, transpose_a=True) + self.assertAllEqual(sq.numpy().reshape(-1), [10, 14, 14, 20]) + + def testConv(self): + if 'GPU' in self.device: + # TODO(b/32333178) + self.skipTest('Current implementation of RandomStandardNormal kernel ' + 'is very slow on GPU, and has been blacklisted.') + with self.test_scope(): + data_format = 'channels_last' + conv = convolutional.Conv2D( + filters=1, kernel_size=2, padding='VALID', + data_format=data_format, activation=nn_ops.relu, + kernel_initializer=init_ops.ones_initializer(), + bias_initializer=init_ops.zeros_initializer()) + pool = pooling.MaxPooling2D(2, 2, data_format=data_format) + + def model(x): + x = conv(x) + return pool(x) + model = function.defun(model, compiled=True) + + x = array_ops.ones([1, 4, 4, 1]) + y = model(x) + self.assertAllEqual(y.numpy(), [[[[4.]]]]) + + def testReadVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun(compiled=True) + def f(): + return v.read_value() + + var = f() + self.assertEqual(1.0, var.numpy()) + + def testUpdateVariable(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + + def f(v): + v.assign_add(1.0) + return v + + f = function.defun(f, compiled=True) + + var = f(v) + self.assertEqual(2.0, var.numpy()) + + def testAllArgumentKinds(self): + """Test a complex function that takes different argument kinds. + + tf2xla machinery that translates, compiles, and runs defuns + classifies arguments into: compile-time constants, regular tensors, + and resources. This test creates a function with a mix of all these + kinds. Moreover, the order of function arguments is intentionally mixed up. + + This also tests the case when the same argument is a compile-time constant + as well as used in an operation that normally expects its inputs to be + in device memory - addition in this case. 
+ """ + with self.test_scope(): + def foo(c1, r1, v1, c2, v2, r2): + # c1 and c2 are compile-time constants + # r1 and r2 are regular tensors + # v1 and v2 are resource variables + a = c1 + r1 + b = math_ops.cast(c2, dtypes.float32) + v2 + c = array_ops.slice(v1, c1, c2) + d = r2 * v2 + return a, b, c, d + + foo = function.defun(foo, compiled=True) + + c1 = [0, 0] + c2 = array_ops.ones([2], dtype=dtypes.int32) + + r1 = array_ops.ones([2]) + r2 = [[2., 2.], [3., 3.]] + + v1 = resource_variable_ops.ResourceVariable([[1., 2.], [3., 4.]]) + v2 = resource_variable_ops.ResourceVariable([[10., 20.], [30., 40.]]) + + a, b, c, d = foo(c1, r1, v1, c2, v2, r2) + + self.assertAllEqual([1, 1], a.numpy()) + self.assertAllEqual([[11., 21.], [31., 41.]], b.numpy()) + self.assertAllEqual([[1.]], c.numpy()) + self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy()) + + def testDefunInGradientTape(self): + with self.test_scope(): + v0 = resource_variable_ops.ResourceVariable(5.0) + + @function.defun(compiled=True) + def f(x): + x = v0 * v0 * x + return x + + x = constant_op.constant(3.0) + with backprop.GradientTape() as tape: + y = f(x) + dy = tape.gradient(y, v0) + + self.assertEqual(75, y.numpy()) + self.assertEqual(30, dy.numpy()) + + +class ExcessivePaddingTest(XLATestCase): + """Test that eager execution works with TPU flattened tensors. + + Tensors that would normally be excessively padded when written + to TPU memory are reshaped to 1-D flat tensors. + + This test case verifies that such tensors work with eager execution. + + The flattening currently only happens on TPU, but tests should work + fine with all backends as flattening is transparent. + """ + + def testFromConstant(self): + with self.test_scope(): + # Create constant of shape [100, 2, 1]. This tensor would be + # excessively padded on TPU. + tensor = constant_op.constant(100 * [[[10.0], [2.0]]]) + # Use reduce_sum since it requires correctly working with + # a particular dimension. 
+ reduced = math_ops.reduce_sum(tensor, axis=1) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testFromOperation(self): + with self.test_scope(): + tensor = array_ops.ones([3, 100, 2, 2]) + reduced = math_ops.reduce_sum(tensor, axis=[0, 2, 3]) + self.assertAllEqual(100 * [12.0], reduced) + + def testAsFunctionInput(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x): + return math_ops.reduce_sum(x, axis=2) + + tensor = constant_op.constant(100 * [[[10.0, 2.0]]]) + reduced = f(tensor) + self.assertAllEqual(100 * [[12.0]], reduced) + + def testAsFunctionOutput(self): + with self.test_scope(): + + @function.defun(compiled=True) + def f(x): + return x * constant_op.constant(100 * [[[10.0, 2.0]]]) + + y = f(3) + reduced = math_ops.reduce_sum(y, axis=2) + self.assertAllEqual(100 * [[36.0]], reduced) + + +if __name__ == '__main__': + ops.enable_eager_execution( + config=config_pb2.ConfigProto(log_device_placement=True)) + googletest.main() diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index fbc3c994d16..8a3f4b0bdc7 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -24,12 +24,10 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest -@test_util.with_c_api class FunctionTest(XLATestCase): def testFunction(self): diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 12791ef8ac1..42e637734c5 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -34,6 +34,13 @@ from tensorflow.python.ops import image_ops from tensorflow.python.platform import test +def GenerateNumpyRandomRGB(shape): + # Only generate floating points that are fractions like n / 256, since they + # are RGB pixels. Some low-precision floating point types in this test can't + # handle arbitrary precision floating points well. + return np.random.randint(0, 256, shape) / 256. 
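The helper above restricts pixel values to multiples of 1/256 so they remain exactly representable in the low-precision float types this test exercises. A minimal check of that property (illustrative only, not part of the patch; bfloat16 is omitted because stock NumPy does not provide it):

import numpy as np

rgb = np.random.randint(0, 256, (4, 4, 3)) / 256.
# k/256 with k < 256 needs at most 8 significand bits, so float16 holds it exactly.
assert np.array_equal(rgb.astype(np.float16).astype(np.float64), rgb)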
+ + class RGBToHSVTest(XLATestCase): def testBatch(self): @@ -43,7 +50,7 @@ class RGBToHSVTest(XLATestCase): shape = (batch_size, 2, 7, 3) for nptype in self.float_types: - inp = np.random.rand(*shape).astype(nptype) + inp = GenerateNumpyRandomRGB(shape).astype(nptype) # Convert to HSV and back, as a batch and individually with self.test_session() as sess: @@ -83,7 +90,7 @@ class RGBToHSVTest(XLATestCase): def testRGBToHSVNumpy(self): """Tests the RGB to HSV conversion matches a reference implementation.""" for nptype in self.float_types: - rgb_flat = np.random.random(64 * 3).reshape((64, 3)).astype(nptype) + rgb_flat = GenerateNumpyRandomRGB((64, 3)).astype(nptype) rgb_np = rgb_flat.reshape(4, 4, 4, 3) hsv_np = np.array([ colorsys.rgb_to_hsv( diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 1f7da659e55..4b0043b6b4c 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -78,10 +78,10 @@ def InLabels(labels, substr): def MetadataHasXlaLaunch(run_metadata): - """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline.""" + """Returns true if there is a XlaLaunch kernel in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. - return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch") + return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch") class JitLaunchTest(test.TestCase): @@ -90,8 +90,8 @@ class JitLaunchTest(test.TestCase): # Verifies that the outputs match and that XLA was invoked. 'fn' must take # the same number of tensors as arguments that are in 'args', and must return # a tuple of output tensors. - # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node - # actually ran. However, it is sometimes possible for _XlaLaunch ops to be + # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node + # actually ran. However, it is sometimes possible for XlaLaunch ops to be # constant-folded away, so the check is optional. def _compare(self, fn, args, require_kernel_launch=True, noinline=None): with session_lib.Session(config=NoRewriteSessionConfig()) as sess: @@ -441,14 +441,14 @@ class XlaCompilationTest(test.TestCase): self.assertFalse(InLabels(labels, "Log")) self.assertTrue(InLabels(labels, "Reciprocal")) self.assertTrue(InLabels(labels, "Mul")) - self.assertFalse(InLabels(labels, "_XlaLaunch")) + self.assertFalse(InLabels(labels, "XlaLaunch")) - # Compile the backprop. One _XlaLaunch. + # Compile the backprop. One XlaLaunch. 
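For intuition about what these label assertions check, here is the same substring logic applied to a made-up set of timeline labels (the label strings below are invented for illustration and do not come from a real trace):

labels = ["XlaLaunch(_arg0)", "retval = XlaLaunch(...)", "Mul = Mul(...)"]
print(sum("XlaLaunch(" in x for x in labels))  # 2 clusters launched
print(any("Log" in x for x in labels))         # False: no Log kernel label present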
labels = _Run(compiled=True) self.assertFalse(InLabels(labels, "Log")) self.assertFalse(InLabels(labels, "Reciprocal")) self.assertFalse(InLabels(labels, "Mul")) - self.assertTrue(InLabels(labels, "_XlaLaunch")) + self.assertTrue(InLabels(labels, "XlaLaunch")) class ElementWiseFusionTest(test.TestCase): @@ -482,14 +482,15 @@ class ElementWiseFusionTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = RunMetadataLabels(run_metadata) - count = sum("_XlaLaunch(" in x for x in labels) + count = sum("XlaLaunch(" in x for x in labels) return output, count def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) arg1 = np.random.rand(2, 2).astype(np.float32) - os.environ["TF_XLA_FLAGS"] = "--tf_xla_fusion_only=true" + os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true " + "--tf_xla_cpu_global_jit") tf_op, tf_count = self.simpleTest(arg0, arg1, config_pb2.OptimizerOptions.OFF) self.assertEqual(0, tf_count) diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py new file mode 100644 index 00000000000..45a04f0cf56 --- /dev/null +++ b/tensorflow/compiler/tests/listdiff_op_test.py @@ -0,0 +1,101 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for XLA listdiff operator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ListDiffTest(xla_test.XLATestCase): + + def _testListDiff(self, x, y, out, idx): + for dtype in [dtypes.int32, dtypes.int64]: + for index_dtype in [dtypes.int32, dtypes.int64]: + with self.test_session() as sess: + x_tensor = ops.convert_to_tensor(x, dtype=dtype) + y_tensor = ops.convert_to_tensor(y, dtype=dtype) + with self.test_scope(): + out_tensor, idx_tensor = array_ops.listdiff( + x_tensor, y_tensor, out_idx=index_dtype) + tf_out, tf_idx = sess.run([out_tensor, idx_tensor]) + self.assertAllEqual(out, tf_out) + self.assertAllEqual(idx, tf_idx) + self.assertEqual(1, out_tensor.get_shape().ndims) + self.assertEqual(1, idx_tensor.get_shape().ndims) + + def testBasic1(self): + self._testListDiff(x=[1, 2, 3, 4], y=[1, 2], out=[3, 4], idx=[2, 3]) + + def testBasic2(self): + self._testListDiff(x=[1, 2, 3, 4], y=[2], out=[1, 3, 4], idx=[0, 2, 3]) + + def testBasic3(self): + self._testListDiff(x=[1, 4, 3, 2], y=[4, 2], out=[1, 3], idx=[0, 2]) + + def testDuplicates(self): + self._testListDiff(x=[1, 2, 4, 3, 2, 3, 3, 1], + y=[4, 2], + out=[1, 3, 3, 3, 1], + idx=[0, 3, 5, 6, 7]) + + def testRandom(self): + num_random_tests = 10 + int_low = -7 + int_high = 8 + max_size = 50 + for _ in xrange(num_random_tests): + x_size = np.random.randint(max_size + 1) + x = np.random.randint(int_low, int_high, size=x_size) + y_size = np.random.randint(max_size + 1) + y = np.random.randint(int_low, int_high, size=y_size) + out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y] + if out_idx: + out, idx = map(list, zip(*out_idx)) + else: + out = [] + idx = [] + self._testListDiff(list(x), list(y), out, idx) + + def testFullyOverlapping(self): + self._testListDiff(x=[1, 2, 3, 4], y=[1, 2, 3, 4], out=[], idx=[]) + + def testNonOverlapping(self): + self._testListDiff(x=[1, 2, 3, 4], + y=[5, 6], + out=[1, 2, 3, 4], + idx=[0, 1, 2, 3]) + + def testEmptyX(self): + self._testListDiff(x=[], y=[1, 2], out=[], idx=[]) + + def testEmptyY(self): + self._testListDiff(x=[1, 2, 3, 4], y=[], out=[1, 2, 3, 4], idx=[0, 1, 2, 3]) + + def testEmptyXY(self): + self._testListDiff(x=[], y=[], out=[], idx=[]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/oom_test.py b/tensorflow/compiler/tests/oom_test.py index 1434e965e3d..d68d32057a3 100644 --- a/tensorflow/compiler/tests/oom_test.py +++ b/tensorflow/compiler/tests/oom_test.py @@ -22,6 +22,8 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest @@ -42,20 +44,33 @@ class OutOfMemoryTest(xla_test.XLATestCase): """ def test_loop(): - size = 2e8 + size = int(2e8) while True: with self.test_session(): - # Force the compiled code to not be constant by feeding in an addend. 
- p = array_ops.placeholder(dtypes.float32, shape=[]) + # Force the compiled code to not be constant by feeding in a + # parameter. + p = array_ops.placeholder(dtypes.float32, shape=[2, 1, 1]) with self.test_scope(): - # Create a large R1 tensor. - c = array_ops.zeros([size, 1]) + p + # Create a computation that produces a large R1 tensor as an + # intermediate result. Reduce it down so that if this file was + # compiled without --config=cuda, we don't force a D2H copy of a + # large tensor and potentially OOM the host. + # + # This is a bit tricky because XLA:GPU doesn't currently support RNG + # ops. Here we rely on the fact that XLA doesn't do algebraic + # simplifications on conv(, ). + c = math_ops.reduce_sum( + nn_ops.convolution( + array_ops.ones([1, size, 1]), + p, + padding='SAME', + data_format='NWC')) - c.eval(feed_dict={p: 1.0}) + c.eval(feed_dict={p: [[[1.0]], [[2.0]]]}) size *= 2 self.assertRaises(errors.ResourceExhaustedError, test_loop) -if __name__ == "__main__": +if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py new file mode 100644 index 00000000000..5e6d1313bd0 --- /dev/null +++ b/tensorflow/compiler/tests/placeholder_test.py @@ -0,0 +1,48 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for xla handling of placeholder_with_default.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +class PlaceholderTest(XLATestCase): + + def test_placeholder_with_default_default(self): + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(4.0) + ph = array_ops.placeholder_with_default(v, shape=[]) + out = ph * 2 + sess.run(variables.variables_initializer([v])) + self.assertEqual(8.0, sess.run(out)) + + def test_placeholder_with_default_fed(self): + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(4.0) + ph = array_ops.placeholder_with_default(v, shape=[]) + out = ph * 2 + sess.run(variables.variables_initializer([v])) + self.assertEqual(2.0, sess.run(out, {ph: 1.0})) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index 2c084b04fa2..7420724bdbe 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import functools +import itertools import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase @@ -155,5 +156,68 @@ class ReduceOpsTest(XLATestCase): self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA) +class ReduceOpPrecisionTest(XLATestCase): + + def _testReduceSum(self, + expected_result, + dtype, + test_inputs, + rtol=1e-3, + atol=1e-4): + """Tests reduce sum on a list of input arrays. + + For each array in test_inputs, check that performing reduce sum on the array + produces a value that is close to the expected result. + + Args: + expected_result: the expected result. + dtype: the data type of the reduce sum operation. + test_inputs: a list of input arrays for the reduce sum operation. + rtol: the relative error. + atol: the absolute error. + """ + + for test_input in test_inputs: + with self.test_session() as sess: + with self.test_scope(): + a = array_ops.placeholder(dtype) + index = array_ops.placeholder(dtypes.int32) + out = math_ops.reduce_sum(a, index) + result = sess.run(out, { + a: np.array(test_input, dtype=dtype), + index: [0] + }) + # Compare the results using float32 type. 
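For context on why the float16/bfloat16 cases added below need care at all: accumulating [max, max, -max] step by step in float16 overflows as soon as the two maxima are added, while a wider accumulator recovers the exact answer. A plain-NumPy illustration (not part of the patch):

import numpy as np

m = float(np.finfo(np.float16).max)          # 65504.0
naive = np.float16(0.0)
for v in (m, m, -m):                         # one of the orderings tested below
    naive = np.float16(float(naive) + v)     # round back to float16 after each step
print(naive)                                 # inf: the partial sum m + m overflows float16
print(np.float32(m) + np.float32(m) - np.float32(m))  # 65504.0 with float32 accumulation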
+ self.assertAllClose( + np.float32(result), + np.float32(expected_result), + rtol=rtol, + atol=atol) + + def testReduceSumF16(self): + """Tests the reduce sum of float16 doesn't lose too much precision.""" + + if np.float16 not in self.all_types: + return + + f16_max = np.finfo(np.float16).max + self._testReduceSum( + f16_max, np.float16, + itertools.permutations([f16_max, f16_max, f16_max * (-1.0)], 3)) + + def testReduceSumBF16(self): + """Tests the reduce sum of bfloat16 doesn't lose too much precision.""" + + if dtypes.bfloat16.as_numpy_dtype not in self.all_types: + return + + bf16_max = np.float32(dtypes.bfloat16.max) + f32_max = dtypes.float32.max + value = min(bf16_max, f32_max - bf16_max) + self._testReduceSum( + dtypes.bfloat16.as_numpy_dtype(value), dtypes.bfloat16.as_numpy_dtype, + itertools.permutations([bf16_max, value, bf16_max * (-1.0)], 3)) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 4336ebdbd18..b6f8390a45d 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -86,6 +86,15 @@ class StatelessRandomOpsTest(XLATestCase): # seed were not fixed. self.assertTrue(self._chi_squared(y, 10) < 16.92) + def testRandomNormalIsFinite(self): + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless.stateless_random_uniform( + shape=[10000], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue(np.all(np.isfinite(y))) + def _normal_cdf(self, x): """Cumulative distribution function for a standard normal distribution.""" return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 7624d6e4b2e..f332aa2e9b9 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -472,7 +472,9 @@ class TensorArrayTest(xla_test.XLATestCase): self.assertAllEqual(c([[-2.0, -10.0]]), grad_vals[1]) def testTensorArrayGradientWriteRead(self): - for dtype in self.numeric_types: + for dtype in self.float_types: + self._testTensorArrayGradientWriteReadType(dtype) + for dtype in self.complex_types: self._testTensorArrayGradientWriteReadType(dtype) def _testTensorArrayGradientWritePackConcatAndRead(self): diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index ba5f829936f..ef047005b60 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest @@ -68,40 +69,41 @@ class TernaryOpsTest(XLATestCase): expected=np.array([1, 3, 5], dtype=np.int32)) def testSelect(self): - self._testTernary( - array_ops.where, - np.array(0, dtype=np.bool), - np.array(2, dtype=np.float32), - np.array(7, dtype=np.float32), - expected=np.array(7, dtype=np.float32)) + for dtype in self.numeric_types: + self._testTernary( + array_ops.where, + np.array(0, dtype=np.bool), + 
np.array(2, dtype=dtype), + np.array(7, dtype=dtype), + expected=np.array(7, dtype=dtype)) - self._testTernary( - array_ops.where, - np.array(1, dtype=np.bool), - np.array([1, 2, 3, 4], dtype=np.float32), - np.array([5, 6, 7, 8], dtype=np.float32), - expected=np.array([1, 2, 3, 4], dtype=np.float32)) + self._testTernary( + array_ops.where, + np.array(1, dtype=np.bool), + np.array([1, 2, 3, 4], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array([1, 2, 3, 4], dtype=dtype)) - self._testTernary( - array_ops.where, - np.array(0, dtype=np.bool), - np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), - np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32), - expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32)) + self._testTernary( + array_ops.where, + np.array(0, dtype=np.bool), + np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), + np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), + expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype)) - self._testTernary( - array_ops.where, - np.array([0, 1, 1, 0], dtype=np.bool), - np.array([1, 2, 3, 4], dtype=np.float32), - np.array([5, 6, 7, 8], dtype=np.float32), - expected=np.array([5, 2, 3, 8], dtype=np.float32)) + self._testTernary( + array_ops.where, + np.array([0, 1, 1, 0], dtype=np.bool), + np.array([1, 2, 3, 4], dtype=dtype), + np.array([5, 6, 7, 8], dtype=dtype), + expected=np.array([5, 2, 3, 8], dtype=dtype)) - self._testTernary( - array_ops.where, - np.array([0, 1, 0], dtype=np.bool), - np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), - np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32), - expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=np.float32)) + self._testTernary( + array_ops.where, + np.array([0, 1, 0], dtype=np.bool), + np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype), + np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype), + expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype)) def testSlice(self): for dtype in self.numeric_types: @@ -119,6 +121,23 @@ class TernaryOpsTest(XLATestCase): np.array([2, 1], dtype=np.int32), expected=np.array([[2], [5]], dtype=dtype)) + def testClipByValue(self): + # TODO(b/78258593): enable integer types here too. + for dtype in self.float_types: + test_cases = [ + (np.array([2, 4, 5], dtype=dtype), dtype(7)), # + (dtype(1), np.array([2, 4, 5], dtype=dtype)), # + (np.array([-2, 7, 7], dtype=dtype), np.array([-2, 9, 8], dtype=dtype)) + ] + x = np.array([-2, 10, 6], dtype=dtype) + for lower, upper in test_cases: + self._testTernary( + gen_math_ops._clip_by_value, + x, + lower, + upper, + expected=np.minimum(np.maximum(x, lower), upper)) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index ba79f393a8f..689a4a1f4e0 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -209,7 +209,8 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.expm1, np.array([[-1, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype)) + expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype), + rtol=1e-5) self._assertOpOutputMatchesExpected( math_ops.floor, @@ -251,12 +252,12 @@ class UnaryOpsTest(XLATestCase): np.array([[1, 2]], dtype=dtype), expected=np.array([[0.540297, -0.41614]], dtype=dtype)) - # TODO(b/34703906): improve log1p implementation and make tolerance - # tighter. 
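The tolerance changes below replace the old TODO; for context, the reason log1p exists as a separate op is the cancellation that occurs when a tiny value is first added to 1. A plain-NumPy illustration (not part of the patch), using the same 1e-15 input the test feeds in:

import numpy as np

x = 1e-15
print(np.log(1.0 + x))   # ~1.11e-15: 1.0 + x has already been rounded, ~11% relative error
print(np.log1p(x))       # ~1.00e-15: computed without ever forming 1.0 + x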
self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15, 0.6]], dtype=dtype), - expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype))) + expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]], dtype=dtype)), + rtol=1e-4, + atol=1e-6) self._assertOpOutputMatchesExpected( math_ops.rint, @@ -333,13 +334,19 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( nn_ops.elu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-0.63212056, 0, 1]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-6]], dtype=dtype), + expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.selu, - np.array([[-1, 0, 1]], dtype=dtype), - expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype)) + np.array([[-1, 0, 1, -1e-5]], dtype=dtype), + expected=np.array( + [[-1.11133074, 0., 1.05070099, -1.758090550379974e-05]], + dtype=dtype), + rtol=1e-5, + atol=1e-6) self._assertOpOutputMatchesExpected( nn_ops.relu, @@ -419,7 +426,9 @@ class UnaryOpsTest(XLATestCase): self._assertOpOutputMatchesExpected( math_ops.expm1, np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype), - expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype))) + expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype)), + rtol=1e-6, + atol=1e-6) self._assertOpOutputMatchesExpected( math_ops.reciprocal, @@ -441,13 +450,13 @@ class UnaryOpsTest(XLATestCase): np.array([[5j, 3 - 2j]], dtype=dtype), expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype))) - # TODO(b/34703906): improve log1p implementation and make tolerance - # tighter. self._assertOpOutputMatchesExpected( math_ops.log1p, np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), expected=np.log1p( - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)), + rtol=1e-4, + atol=1e-6) val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) self._assertOpOutputMatchesExpected( @@ -789,7 +798,9 @@ class UnaryOpsTest(XLATestCase): zero = np.asarray(0).astype(dtype) expected = np.logaddexp(zero, features) self._assertOpOutputMatchesExpected( - nn_ops.softplus, features, expected=expected) + nn_ops.softplus, features, expected=expected, + rtol=1e-6, + atol=9.1e-6) def testSoftplus(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/xla_device_gpu_test.py b/tensorflow/compiler/tests/xla_device_gpu_test.py new file mode 100644 index 00000000000..1e30ebd55d0 --- /dev/null +++ b/tensorflow/compiler/tests/xla_device_gpu_test.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Test cases for XLA devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class XlaDeviceGpuTest(test.TestCase): + + def testCopiesToAndFromGpuWork(self): + """Tests that copies between GPU and XLA devices work.""" + if not test.is_gpu_available(): + return + + with session_lib.Session() as sess: + x = array_ops.placeholder(dtypes.float32, [2]) + with ops.device("GPU"): + y = x * 2 + with ops.device("device:XLA_CPU:0"): + z = y * y + with ops.device("GPU"): + w = y + z + result = sess.run(w, {x: [1.5, 0.5]}) + self.assertAllClose(result, [12., 2.], rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index f5c228f8305..b707bd0963d 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,30 +18,33 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import dtypes +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class XlaDeviceTest(test.TestCase): +class XlaDeviceTest(XLATestCase): def testCopies(self): - """Tests that copies between GPU and XLA devices work.""" - if not test.is_gpu_available(): - return + """Tests that copies onto and off XLA devices work.""" + shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3], + [16384, 1], [1, 16384], [1, 20000, 1, 1]] + for dtype in self.numeric_types: + for shape in shapes: + with self.test_session() as sess: + with ops.device("CPU"): + x = array_ops.placeholder(dtype, shape) + with self.test_scope(): + y = x + x + with ops.device("CPU"): + z = array_ops.identity(y) - with session_lib.Session() as sess: - x = array_ops.placeholder(dtypes.float32, [2]) - with ops.device("GPU"): - y = x * 2 - with ops.device("device:XLA_CPU:0"): - z = y * y - with ops.device("GPU"): - w = y + z - result = sess.run(w, {x: [1.5, 0.5]}) - self.assertAllClose(result, [12., 2.], rtol=1e-3) + inputs = np.random.randint(-100, 100, shape).astype(dtype) + result = sess.run(z, {x: inputs}) + self.assertAllCloseAccordingToType(result, inputs + inputs) if __name__ == "__main__": diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index ba5c3a14849..cd57452302f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -81,7 +81,7 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:computation", + 
"//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -168,9 +168,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -215,7 +215,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -326,6 +325,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:cpu_plugin", @@ -412,7 +412,6 @@ cc_library( hdrs = ["functionalize_control_flow.h"], deps = [ ":tf2xla_util", - "//tensorflow/compiler/jit:graph_to_functiondef", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:xla_ops", diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 16b9142cbf7..42585ad4d8a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -21,13 +21,13 @@ limitations under the License. #include #include -#include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" @@ -282,7 +282,58 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, return Status::OK(); } -Status FunctionalizeLoop(Graph* graph, Frame* frame, +// Copy the FunctionDef of given function from lookup_library to library, if +// it can be found in lookup_library but is missing from library. +Status AddMissingFunctionByName(const string& function_name, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + if (!library->Find(function_name) && lookup_library->Find(function_name)) { + return library->AddFunctionDef(*lookup_library->Find(function_name)); + } + return Status::OK(); +} + +// Iterate over all functions that the given fdef refers to. Copy the missing +// FunctionDefs from lookup_library to library. 
+Status AddMissingFunctionDef(const FunctionDef& fdef, + const FunctionLibraryDefinition* lookup_library, + FunctionLibraryDefinition* library) { + TF_RET_CHECK(lookup_library); + for (const NodeDef& node : fdef.node_def()) { + if (library->Find(node.op())) { + continue; + } + // The function refered by 'SymbolicGradient' node is specified in its + // attribute 'f'. + if (node.op() == FunctionLibraryDefinition::kGradientOp) { + const AttrValue* attr = + AttrSlice(&node.attr()).Find(FunctionLibraryDefinition::kFuncAttr); + if (!attr) { + return errors::InvalidArgument("SymbolicGradient is missing attr: f"); + } + const string& func_name = attr->func().name(); + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(func_name, lookup_library, library)); + // Copy the user-defined gradient function if it exists. + const string grad_name = lookup_library->FindGradient(func_name); + if (!grad_name.empty() && library->FindGradient(func_name).empty()) { + TF_RETURN_IF_ERROR( + AddMissingFunctionByName(grad_name, lookup_library, library)); + GradientDef grad_def; + grad_def.set_function_name(func_name); + grad_def.set_gradient_func(grad_name); + TF_RETURN_IF_ERROR(library->AddGradientDef(grad_def)); + } + } else if (lookup_library->Find(node.op())) { + TF_RETURN_IF_ERROR( + library->AddFunctionDef(*lookup_library->Find(node.op()))); + } + } + return Status::OK(); +} + +Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, + Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " << dump_graph::DumpGraphToFile("functionalize_before", *graph, @@ -489,6 +540,14 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef)); TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef)); + if (lookup_library) { + // Copy missing FunctionDefs from lookup_library to library to make library + // self-contained. + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(cond_fdef, lookup_library, library)); + TF_RETURN_IF_ERROR( + AddMissingFunctionDef(body_fdef, lookup_library, library)); + } // Builds a While operator. NodeDef while_def; @@ -870,6 +929,9 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() { // Merge the inputs of the switch node with one another. This results in // predicates and control input residing in the same cluster. for (const Edge* e : n->in_edges()) { + // Only consider the data inputs to the Switch node. + if (e->IsControlEdge()) continue; + Node* src = e->src(); UnionFind* src_cluster = find_output_cluster(src); int src_cluster_depth = switch_depth[src_cluster->Get().representative]; @@ -1362,6 +1424,12 @@ Status FunctionalizeCond::Functionalize(Graph* graph, // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { + return FunctionalizeControlFlow(/*lookup_library=*/nullptr, graph, library); +} + +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " << dump_graph::DumpGraphToFile("functionalize_initial", *graph, library); @@ -1431,7 +1499,8 @@ Status FunctionalizeControlFlow(Graph* graph, continue; } - TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library)); + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); // If the parent has no remaining children, add it to the worklist. 
--frame->parent->num_children; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 4d4ee3054c2..d941041d155 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -22,9 +22,13 @@ limitations under the License. namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While -// operators, suitable for XLA compilation. +// operators, suitable for XLA compilation. If lookup_library is provided, use +// it to make the library for control flow self-contained. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); +Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, + Graph* graph, + FunctionLibraryDefinition* library); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index e494f42e8ed..14977a908ae 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -299,6 +299,131 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } } +// @function.Defun(noinline=True) +// def increment_fn(x): +// return [x + 1] +// Define the above function, and add it to the given graph. It's used as the +// while loop body in NoinlineLoopBody test. +Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { + FunctionDef fdef = FunctionDefHelper::Create( + "increment_fn", {"x:int32"}, {"add:int32"}, {}, + { + {{"add/y"}, "Const", {}, {{"dtype", DT_INT32}}}, + {{"add_0"}, "Add", {"x", "add/y:output:0"}, {{"T", DT_INT32}}}, + }, + {{"add", "add_0:z:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); + NodeDef increment_fn; + increment_fn.set_name(node_name); + increment_fn.set_op("increment_fn"); + *increment_fn.add_input() = "while/Identity"; + *increment_fn.add_input() = "^while/Identity"; + Status status; + graph->AddNode(increment_fn, &status); + return status; +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = control_flow_ops.while_loop(lambda i: i < 10, increment_fn, [x]) +TEST(FunctionalizeControlFlow, NoinlineLoopBody) { + const string& noinline_node_name = "while/increment_fn"; + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = ops::internal::Enter(scope.WithOpName("while/Enter"), source, + "while/while_context"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), + switch_.output_false); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, 
scope.graph())); + + NodeDef next_iter; + next_iter.set_name("while/NextIteration"); + next_iter.set_op("NextIteration"); + *next_iter.add_input() = noinline_node_name; + (*next_iter.mutable_attr())["T"].set_type(DT_INT32); + + Status status; + Node* n = scope.graph()->AddNode(next_iter, &status); + TF_ASSERT_OK(status); + + // Remove the dummy node and add the loop backedge. + scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(n, 0, merge.output.node(), 1); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition lookup_lib(graph.flib_def()); + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + // Function increment_fn will be copied from lookup_lib to library. + TF_ASSERT_OK(FunctionalizeControlFlow(&lookup_lib, &graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + TF_ASSERT_OK(AddNoinlineFunctionToGraph(noinline_node_name, scope.graph())); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + NodeDef retval; + retval.set_name("_retval0_RetVal"); + retval.set_op(FunctionLibraryDefinition::kRetOp); + *retval.add_input() = noinline_node_name; + (*retval.mutable_attr())["T"].set_type(DT_INT32); + (*retval.mutable_attr())["index"].set_i(0); + Status status; + scope.graph()->AddNode(retval, &status); + TF_ASSERT_OK(status); + + GraphDef expected; + TF_ASSERT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + // Verify that increment_fn has been copied to library. + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + // Ignore the function library when comparing the graphs. + expected.clear_library(); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Tests functionalizing OneLoopVar where the loop value is not used post the // loop. 
// Graph: diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index b20c1ffc7d8..212f6f39661 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -51,6 +51,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, const std::vector& expressions, std::vector* args) { auto builder = ctx->builder(); + auto client = ctx->compiler()->client(); std::vector compile_time_constant_flags(expressions.size()); TF_RETURN_IF_ERROR( @@ -72,8 +73,10 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, arg.kind = XlaCompiler::Argument::kConstant; TF_RET_CHECK(expressions[i]->resource() == nullptr) << "Input with resource is not yet implemented."; + TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph( + expressions[i]->handle())); TF_ASSIGN_OR_RETURN(auto literal, - builder->ComputeConstant(expressions[i]->handle())); + client->ComputeConstant(constant_graph)); TF_RETURN_IF_ERROR( LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); } else { @@ -205,14 +208,15 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RETURN_IF_ERROR( PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; XlaCompiler::CompilationResult result; - - TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(), - func, arguments, &result)); + TF_RETURN_IF_ERROR( + compiler->CompileFunction(compile_options, func, arguments, &result)); TF_RET_CHECK(arguments.size() == expressions.size()); - std::vector handles; + std::vector handles; for (int64 i = 0; i < expressions.size(); ++i) { if (arguments[i].kind == XlaCompiler::Argument::kConstant) { continue; @@ -226,11 +230,14 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, auto output_handle = b->Call(*result.computation, handles); // The output handle of `Call` computation is a tuple type. Unzip it so // that it can fit into future computations. 
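A rough Python sketch of the bookkeeping the change below introduces (values are made up): outputs that were folded to compile-time constants are not elements of the tuple returned by the Call, which is what the new counter accounts for, so tuple elements are consumed with their own index rather than the node's output index.

outputs = [("constant", 7), ("tensor", None), ("constant", 3), ("tensor", None)]
call_tuple = ["elem0", "elem1"]              # what the compiled computation returns
computation_output = 0
for i, (kind, value) in enumerate(outputs):
    if kind == "constant":
        print(i, "-> constant output", value)
    else:
        print(i, "-> tuple element", call_tuple[computation_output])
        computation_output += 1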
+ int computation_output = 0; for (int64 i = 0; i < n->num_outputs(); ++i) { if (result.outputs[i].is_constant) { xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); } else { - xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i)); + xla_op_context.SetOutput( + i, b->GetTupleElement(output_handle, computation_output)); + ++computation_output; } } return b->first_error(); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 579b6696999..e6da157c111 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -21,6 +21,7 @@ tf_kernel_library( "cast_op.cc", "categorical_op.cc", "cholesky_op.cc", + "clip_by_value_op.cc", "concat_op.cc", "const_op.cc", "conv_ops.cc", @@ -44,6 +45,7 @@ tf_kernel_library( "image_resize_ops.cc", "index_ops.cc", "l2loss_op.cc", + "listdiff_op.cc", "lrn_ops.cc", "matmul_op.cc", "matrix_band_part_op.cc", @@ -113,8 +115,8 @@ tf_kernel_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", @@ -150,7 +152,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -166,7 +168,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -202,8 +204,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/kernels:argmax_op", diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index 5c9f66df101..1e598686214 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -29,7 +29,7 @@ class AddNOp : public XlaOpKernel { OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("AddN requires at least one argument")); - xla::ComputationDataHandle sum = ctx->Input(0); + xla::XlaOp sum = ctx->Input(0); for (int i = 1; i < ctx->num_inputs(); ++i) { sum = ctx->builder()->Add(sum, ctx->Input(i)); } diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 931175be111..15e1815a4cf 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -48,9 +48,9 @@ class FusedBatchNormOp : public XlaOpKernel { 
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); TensorShape input_shape = ctx->InputShape(0); int feature_index = @@ -62,7 +62,7 @@ class FusedBatchNormOp : public XlaOpKernel { input = builder->ConvertElementType(input, scale_type); if (is_training_) { - xla::ComputationDataHandle output = builder->BatchNormTraining( + xla::XlaOp output = builder->BatchNormTraining( input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the @@ -79,7 +79,7 @@ class FusedBatchNormOp : public XlaOpKernel { ctx->SetOutput(3, builder->GetTupleElement(output, 1)); ctx->SetOutput(4, builder->GetTupleElement(output, 2)); } else { - xla::ComputationDataHandle output = builder->BatchNormInference( + xla::XlaOp output = builder->BatchNormInference( input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), epsilon_, feature_index); ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); @@ -118,7 +118,7 @@ class FusedBatchNormGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); DataType input_dtype = ctx->input_type(0); DataType scale_dtype = ctx->input_type(2); @@ -137,11 +137,11 @@ class FusedBatchNormGradOp : public XlaOpKernel { const int feature_index = GetTensorFeatureDimIndex(input_dims, data_format_); - xla::ComputationDataHandle x_backprop; - xla::ComputationDataHandle scale_backprop; - xla::ComputationDataHandle offset_backprop; + xla::XlaOp x_backprop; + xla::XlaOp scale_backprop; + xla::XlaOp offset_backprop; if (is_training_) { - xla::ComputationDataHandle output = + xla::XlaOp output = b->BatchNormGrad(activations, scale, mean, var, grad_backprop, epsilon_, feature_index); diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 569950c2dfa..642278ab994 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -20,9 +20,8 @@ limitations under the License. namespace tensorflow { namespace { -void BatchToSpace(XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, DataType input_dtype, - const TensorShape& input_tensor_shape, +void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, + DataType input_dtype, const TensorShape& input_tensor_shape, gtl::ArraySlice block_shape, const xla::Literal& crops) { const int input_rank = input_tensor_shape.dims(); @@ -46,7 +45,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, ", 2] instead of ", xla::ShapeUtil::HumanString(crops.shape()))); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const int64 batch_size = input_shape[0]; // Compute the product of the block_shape values. @@ -73,7 +72,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, reshaped_shape[block_rank] = batch_size / block_num_elems; std::copy(input_shape.begin() + 1, input_shape.end(), reshaped_shape.begin() + block_rank + 1); - xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); // 2. 
Permute dimensions of `reshaped` to produce `permuted` of shape // [batch / prod(block_shape), @@ -91,7 +90,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, } std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); - xla::ComputationDataHandle permuted = b->Transpose(reshaped, permutation); + xla::XlaOp permuted = b->Transpose(reshaped, permutation); // 3. Reshape `permuted` to produce `reshaped_permuted` of shape // [batch / prod(block_shape), @@ -111,8 +110,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_permuted_shape.begin() + 1 + block_rank); - xla::ComputationDataHandle reshaped_permuted = - b->Reshape(permuted, reshaped_permuted_shape); + xla::XlaOp reshaped_permuted = b->Reshape(permuted, reshaped_permuted_shape); // 4. Crop the start and end of dimensions `[1, ..., M]` of // `reshaped_permuted` according to `crops` to produce the output of shape: @@ -139,7 +137,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, "Cropped size must be non-negative: start: ", crop_start, " end: ", crop_end, " size ", reshaped_permuted_shape[1 + i])); } - xla::ComputationDataHandle output = + xla::XlaOp output = b->Slice(reshaped_permuted, start_indices, end_indices, strides); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index ed33b8ed2e8..9d677f42665 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -60,7 +60,7 @@ class BiasOp : public XlaOpKernel { "of the input tensor: ", bias_shape.DebugString(), " vs. ", input_shape.DebugString())); - xla::ComputationDataHandle result = + xla::XlaOp result = ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim}); ctx->SetOutput(0, result); } @@ -103,7 +103,7 @@ class BiasAddGradOp : public XlaOpKernel { std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0); std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(), feature_dim + 1); - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); const DataType accumulation_type = XlaHelpers::SumAccumulationType(input_type(0)); auto converted = diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 2436a6074a1..f04cde878e9 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -34,14 +34,13 @@ namespace { class NAME##Op : public XlaBinaryOp { \ public: \ explicit NAME##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \ - xla::ComputationDataHandle Computation( \ - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, \ - const gtl::ArraySlice& lhs_shape, \ - const xla::ComputationDataHandle& rhs, \ + xla::XlaOp Computation( \ + XlaOpKernelContext* ctx, const xla::XlaOp& lhs, \ + const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, \ const gtl::ArraySlice& rhs_shape, \ const BCast& broadcast_helper, \ const std::vector& extend_dimensions) override { \ - xla::ComputationBuilder* b = ctx->builder(); \ + xla::XlaBuilder* b = ctx->builder(); \ return HLO; \ } \ }; \ @@ -63,11 +62,8 @@ XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions)); // } else { // return x / y; // } -static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b, - DataType dtype, - xla::ComputationDataHandle x, - xla::ComputationDataHandle y, - const BCast& broadcast_helper) { +static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); @@ -87,11 +83,8 @@ XLA_MAKE_BINARY(FloorDiv, // Implementation of FloorMod. Pseudo-code: // T trunc_mod = std::fmod(x, y); // return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y); -static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b, - DataType dtype, - xla::ComputationDataHandle x, - xla::ComputationDataHandle y, - const BCast& broadcast_helper) { +static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero)); @@ -127,8 +120,7 @@ XLA_MAKE_BINARY(SqrtGrad, XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), lhs, extend_dimensions)); -static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& x) { +static xla::XlaOp Square(xla::XlaBuilder* builder, const xla::XlaOp& x) { return builder->Mul(x, x); } @@ -175,11 +167,11 @@ class ApproximateEqualOp : public XlaOpKernel { // Computes the max of the scalar input x and 0. 
void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))); auto abs_shape = b->GetShape(abs); OP_REQUIRES_OK(ctx, abs_shape.status()); - auto abs_type = abs_shape.ValueOrDie()->element_type(); + auto abs_type = abs_shape.ValueOrDie().element_type(); auto result = b->Lt( abs, b->ConvertElementType(b->ConstantR0(tolerance_), abs_type)); ctx->SetOutput(0, result); diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index c52b2dcb7e9..e9d98c76857 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -33,9 +33,9 @@ class CastOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); - xla::ComputationDataHandle output; + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + xla::XlaOp output; if (src_dtype_ == dst_dtype_) { output = input; @@ -72,9 +72,9 @@ class BitcastOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); - xla::ComputationDataHandle output; + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + xla::XlaOp output; if (src_dtype_ == dst_dtype_) { output = input; diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 545aa364f93..835a7f56894 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -34,7 +34,7 @@ class CategoricalOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { // Get the logits - const xla::ComputationDataHandle& logits = ctx->Input(0); + const xla::XlaOp& logits = ctx->Input(0); TensorShape logits_shape = ctx->InputShape(0); int64 num_samples; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_samples)); @@ -56,7 +56,7 @@ class CategoricalOp : public XlaOpKernel { const int64 batch_size = logits_shape.dim_size(0); const int64 num_classes = logits_shape.dim_size(1); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); std::array uniform_shape_array = { {batch_size, num_samples, num_classes}}; @@ -78,7 +78,7 @@ class CategoricalOp : public XlaOpKernel { /*broadcast_dimensions=*/{0, 2}); TensorShape softmax_shape(uniform_shape_array); - xla::ComputationDataHandle argmax; + xla::XlaOp argmax; OP_REQUIRES_OK( ctx, XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape, diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc new file mode 100644 index 00000000000..a00bc912f9f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class ClipByValueOp : public XlaOpKernel { + public: + explicit ClipByValueOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape shape = ctx->InputShape(0); + const TensorShape min_shape = ctx->InputShape(1); + const TensorShape max_shape = ctx->InputShape(2); + + xla::XlaBuilder* builder = ctx->builder(); + auto input = ctx->Input(0); + auto min = ctx->Input(1); + auto max = ctx->Input(2); + + auto shape_error = [&]() -> tensorflow::Status { + return errors::InvalidArgument( + "clip_value_min and clip_value_max must be either of " + "the same shape as input, or a scalar. ", + "Input shape: ", shape.DebugString(), + " clip_value_min shape: ", min_shape.DebugString(), + " clip_value_max shape: ", max_shape.DebugString()); + }; + + if (shape != min_shape) { + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(min_shape), shape_error()); + min = builder->Broadcast(min, shape.dim_sizes()); + } + if (shape != max_shape) { + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(max_shape), shape_error()); + max = builder->Broadcast(max, shape.dim_sizes()); + } + ctx->SetOutput(0, builder->Clamp(min, input, max)); + } +}; + +REGISTER_XLA_OP(Name("ClipByValue"), ClipByValueOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 1a246e8df9b..78285affa1c 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -54,7 +54,7 @@ class ConcatBaseOp : public XlaOpKernel { // TODO(annarev): add a helper to support int64 input. const int32 concat_dim = literal.Get({}); - std::vector values; + std::vector values; std::vector shapes; OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes)); const int N = values.size(); @@ -70,13 +70,13 @@ class ConcatBaseOp : public XlaOpKernel { "[", -input_dims, ", ", input_dims, "), but got ", concat_dim)); - // Make a vector holding the ComputationDataHandles for each of - // the inputs that has non-zero elements. - std::vector input_data; + // Make a vector holding the XlaOp for each of the inputs that has non-zero + // elements. 
+ std::vector input_data; int output_concat_dim = 0; const bool input_is_scalar = IsLegacyScalar(input_shape); for (int i = 0; i < N; ++i) { - xla::ComputationDataHandle handle = values[i]; + xla::XlaOp handle = values[i]; const TensorShape& in_shape = shapes[i]; const bool in_is_scalar = IsLegacyScalar(in_shape); OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 8f78b4c8f90..59d06c654de 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -45,7 +45,7 @@ class ConstOp : public XlaOpKernel { ctx->SetInvalidOutput(0); return; } - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, // recognize that case and emit a scalar broadcast instead. diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index c0ee0c9c2ea..627bad12f33 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -47,9 +47,8 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution( } // Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. -xla::ComputationDataHandle CreateExpandedZero( - const TensorShape& filter_shape, DataType dtype, - xla::ComputationBuilder* builder) { +xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, + xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); return builder->Broadcast(XlaHelpers::Zero(builder, dtype), @@ -87,8 +86,8 @@ xla::ComputationDataHandle CreateExpandedZero( // // Finally compare A and broadcasted B in dimension 2 amd return the result at // the beginning of the comment. -xla::ComputationDataHandle CreateExpandedFilterMask( - const TensorShape& filter_shape, xla::ComputationBuilder* builder) { +xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, + xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); @@ -96,11 +95,11 @@ xla::ComputationDataHandle CreateExpandedFilterMask( // Create a M sized linspace and an M*N sized linspace that will be // broadcasted into perpendicular dimensions and compared. - xla::ComputationDataHandle input_feature_iota; + xla::XlaOp input_feature_iota; // DT_INT32 Iota will always return status::OK(). TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, &input_feature_iota)); - xla::ComputationDataHandle expanded_feature_iota; + xla::XlaOp expanded_feature_iota; TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature * depthwise_multiplier, &expanded_feature_iota)); @@ -126,10 +125,10 @@ xla::ComputationDataHandle CreateExpandedFilterMask( // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding // zeros for the cross-depth filters. Used to build a depthwise convolution. 
-xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( - const TensorShape& filter_shape, DataType dtype, - const xla::ComputationDataHandle& filter, - xla::ComputationBuilder* builder) { +xla::XlaOp ExpandFilterForDepthwiseConvolution(const TensorShape& filter_shape, + DataType dtype, + const xla::XlaOp& filter, + xla::XlaBuilder* builder) { int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); TensorShape expanded_filter_shape = @@ -156,10 +155,11 @@ xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( } // Inverse of ExpandFilterForDepthwiseConvolution. -xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( - XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, - const xla::ComputationDataHandle& filter_backprop, - xla::ComputationBuilder* builder) { +xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, + const TensorShape& filter_shape, + DataType dtype, + const xla::XlaOp& filter_backprop, + xla::XlaBuilder* builder) { TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); auto masked_expanded_filter = builder->Select( @@ -248,9 +248,9 @@ class ConvOp : public XlaOpKernel { "input and filter must have the same depth: ", in_depth, " vs ", input_shape.dim_size(feature_dim))); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); - xla::ComputationDataHandle filter = ctx->Input(1); + xla::XlaOp filter = ctx->Input(1); TensorShape expanded_filter_shape = filter_shape; if (depthwise_) { filter = ExpandFilterForDepthwiseConvolution( @@ -288,7 +288,7 @@ class ConvOp : public XlaOpKernel { &unused_output_size, &padding[i].first, &padding[i].second)); } - xla::ComputationDataHandle conv = + xla::XlaOp conv = b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); @@ -391,7 +391,7 @@ class ConvBackpropInputOp : public XlaOpKernel { expanded_filter_shape, out_backprop_shape, dilations_, strides_, padding_, data_format_, &dims)); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto filter = ctx->Input(1); auto out_backprop = ctx->Input(2); @@ -435,12 +435,11 @@ class ConvBackpropInputOp : public XlaOpKernel { } // Mirror the filter in the spatial dimensions. - xla::ComputationDataHandle mirrored_weights = - b->Rev(filter, kernel_spatial_dims); + xla::XlaOp mirrored_weights = b->Rev(filter, kernel_spatial_dims); // activation gradients // = gradients (with padding and dilation) mirrored_weights - xla::ComputationDataHandle in_backprop = b->ConvGeneralDilated( + xla::XlaOp in_backprop = b->ConvGeneralDilated( out_backprop, mirrored_weights, /*window_strides=*/ones, padding, lhs_dilation, rhs_dilation, dnums); @@ -546,9 +545,9 @@ class ConvBackpropFilterOp : public XlaOpKernel { expanded_filter_shape, out_backprop_shape, dilations_, strides_, padding_, data_format_, &dims)); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle activations = ctx->Input(0); - xla::ComputationDataHandle gradients = ctx->Input(2); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp activations = ctx->Input(0); + xla::XlaOp gradients = ctx->Input(2); // The filter gradients are computed by a convolution of the input // activations and the output gradients, with some appropriate padding. 
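The conv_ops.cc hunks above keep the depthwise trick intact while only switching handle types: a depthwise filter of shape [H, W, ..., M, N] is expanded to a dense filter of shape [H, W, ..., M, M*N] whose cross-depth taps are zero, so an ordinary convolution reproduces the depthwise result. The plain C++ sketch below shows only that expansion semantics for a 2-D spatial filter; it is not the XLA lowering, and the helper name ExpandDepthwiseFilter and the row-major layout are assumptions made for this illustration.

#include <cassert>
#include <cstdint>
#include <vector>

// `filter` is row-major [H, W, M, N]; the result is row-major [H, W, M, M*N].
std::vector<float> ExpandDepthwiseFilter(const std::vector<float>& filter,
                                         int64_t H, int64_t W, int64_t M,
                                         int64_t N) {
  std::vector<float> expanded(H * W * M * M * N, 0.0f);
  for (int64_t h = 0; h < H; ++h)
    for (int64_t w = 0; w < W; ++w)
      for (int64_t m = 0; m < M; ++m)
        for (int64_t n = 0; n < N; ++n) {
          // Input feature m only feeds output features m*N + n; every other
          // (input feature, output feature) pair stays zero, which is what
          // the iota-based filter mask in the diff selects for.
          const int64_t src = ((h * W + w) * M + m) * N + n;
          const int64_t dst = ((h * W + w) * M + m) * (M * N) + m * N + n;
          expanded[dst] = filter[src];
        }
  return expanded;
}

int main() {
  // 1x1 spatial filter, M=2 input features, depth multiplier N=1:
  // [a, b] expands to the 2x2 block-diagonal [[a, 0], [0, b]].
  const std::vector<float> dense =
      ExpandDepthwiseFilter({2.0f, 3.0f}, 1, 1, 2, 1);
  assert(dense == std::vector<float>({2.0f, 0.0f, 0.0f, 3.0f}));
  return 0;
}

In the diff itself the same block-diagonal selection is built on-device: CreateExpandedFilterMask compares broadcast iotas to mark the block-diagonal positions, and a Select against CreateExpandedZero fills everything else with zeros.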
diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index 3df8c00f1b8..7fcd4170fb7 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -53,7 +53,7 @@ class CrossOp : public XlaOpKernel { } std::vector strides(in0_shape.dims(), 1); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto in0 = ctx->Input(0); auto in1 = ctx->Input(1); starts.back() = 0; diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 0cf03ceb948..01aa1a83e79 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/util/bcast.h" @@ -75,7 +75,7 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { } // Call virtual method to emit the computation. - xla::ComputationDataHandle output = + xla::XlaOp output = Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle, rhs_shape.dim_sizes(), bcast, extend_dimension); @@ -85,11 +85,9 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { ctx->SetOutput(0, output); } -/* static */ std::pair -XlaBinaryOp::Broadcast(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& lhs, - const xla::ComputationDataHandle& rhs, - const BCast& broadcast_helper) { +/* static */ std::pair XlaBinaryOp::Broadcast( + xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, + const BCast& broadcast_helper) { // Manually construct the broadcasting since MapN does not do // automatic broadcasting. The bcast helper ensures that // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 5bc1d5fb1f0..4f92dbc8740 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" @@ -30,7 +30,7 @@ namespace tensorflow { // inputs that can be broadcast to the same shape. The base class // contains pure virtual methods to override: description is a textual // description of the operation; and Computation adds the -// implementation of the operation to a xla::ComputationBuilder. For most +// implementation of the operation to a xla::XlaBuilder. For most // arithmetic Ops XLA handles the broadcasting automatically given the input // tensors. class XlaBinaryOp : public XlaOpKernel { @@ -55,10 +55,9 @@ class XlaBinaryOp : public XlaOpKernel { // higher-rank input should be matched when broadcasting the // lower-rank input. 
See comment below and the documentation on broadcasting // in the XLA documentation. - virtual xla::ComputationDataHandle Computation( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, - const gtl::ArraySlice& lhs_shape, - const xla::ComputationDataHandle& rhs, + virtual xla::XlaOp Computation( + XlaOpKernelContext* ctx, const xla::XlaOp& lhs, + const gtl::ArraySlice& lhs_shape, const xla::XlaOp& rhs, const gtl::ArraySlice& rhs_shape, const BCast& broadcast_helper, const std::vector& extend_dimensions) = 0; @@ -67,11 +66,9 @@ class XlaBinaryOp : public XlaOpKernel { // Helper function that performs the broadcasting described by // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same // shape. - static std::pair - Broadcast(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& lhs, - const xla::ComputationDataHandle& rhs, - const BCast& broadcast_helper); + static std::pair Broadcast( + xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, + const BCast& broadcast_helper); }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 96d7809f799..23243f62462 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -50,8 +50,8 @@ class DepthToSpaceOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); @@ -130,7 +130,7 @@ class DepthToSpaceOp : public XlaOpKernel { ") is not divisible by square of the block size (", block_size_, ")")); - xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -141,8 +141,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2], // block_size_, // depth / (block_size_ * block_size_)] - xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -152,8 +151,7 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::ComputationDataHandle output = - b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index 765ea922a53..931705ba837 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -25,10 +25,10 @@ namespace tensorflow { namespace { // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. 
-xla::StatusOr CreateDiagonal( - const xla::ComputationDataHandle& input, int64 last_dim_size, +xla::StatusOr CreateDiagonal( + const xla::XlaOp& input, int64 last_dim_size, tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx, - xla::ComputationBuilder* builder) { + xla::XlaBuilder* builder) { // Create two matrices that have the following forms, and compare them: // // [[0, 0, 0, 0] [[0, 1, 2, 3] @@ -38,12 +38,11 @@ xla::StatusOr CreateDiagonal( // // This produces a predicate matrix of the right size, with "true" on the // diagonal. - xla::ComputationDataHandle iota; + xla::XlaOp iota; TF_RETURN_IF_ERROR( XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); - xla::ComputationDataHandle iota_broadcast = - builder->Broadcast(iota, {last_dim_size}); - xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0}); + xla::XlaOp iota_broadcast = builder->Broadcast(iota, {last_dim_size}); + xla::XlaOp mask = builder->Eq(iota_broadcast, iota, {0}); // If this is a batched diagonal, broadcast the mask across the other // dimensions. @@ -65,8 +64,7 @@ xla::StatusOr CreateDiagonal( std::vector broadcast_dims(other_dims.begin(), other_dims.end()); broadcast_dims.push_back(1LL); broadcast_dims.push_back(last_dim_size); - xla::ComputationDataHandle input_broadcast = - builder->Reshape(input, broadcast_dims); + xla::XlaOp input_broadcast = builder->Reshape(input, broadcast_dims); broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; xla::PrimitiveType element_type; @@ -74,7 +72,7 @@ xla::StatusOr CreateDiagonal( DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); auto broadcast_shape = xla::ShapeUtil::MakeShape(element_type, broadcast_dims); - xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape); + xla::XlaOp zeros = Zeros(builder, broadcast_shape); input_broadcast = builder->Add(input_broadcast, zeros); return builder->Select(mask, input_broadcast, zeros); @@ -85,7 +83,7 @@ class DiagOp : public XlaOpKernel { explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("Diag op must have at an input")); @@ -96,7 +94,7 @@ class DiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); // Picture: // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0] @@ -112,7 +110,7 @@ class DiagOp : public XlaOpKernel { auto diag_or_status = CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); OP_REQUIRES_OK(ctx, diag_or_status.status()); - xla::ComputationDataHandle diag = diag_or_status.ValueOrDie(); + xla::XlaOp diag = diag_or_status.ValueOrDie(); // Reshapes to the final shape. 
std::vector new_dims(dims.size() * 2); @@ -131,7 +129,7 @@ class DiagPartOp : public XlaOpKernel { explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -158,7 +156,7 @@ class DiagPartOp : public XlaOpKernel { new_dims.push_back(dims[i]); } - xla::ComputationDataHandle diag = ctx->Input(0); + xla::XlaOp diag = ctx->Input(0); // TODO(b/30878775): use Slice with strides when supported, in place of // the Pad -> Reshape -> Slice. @@ -199,7 +197,7 @@ class MatrixDiagOp : public XlaOpKernel { explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); OP_REQUIRES(ctx, ctx->num_inputs() >= 1, errors::InvalidArgument("MatrixDiag op must have at an input")); @@ -210,7 +208,7 @@ class MatrixDiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle diag = ctx->Input(0); + xla::XlaOp diag = ctx->Input(0); int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); @@ -232,7 +230,7 @@ class MatrixDiagPartOp : public XlaOpKernel { explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -241,7 +239,7 @@ class MatrixDiagPartOp : public XlaOpKernel { errors::InvalidArgument("Expected 2 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle diag = ctx->Input(0); + xla::XlaOp diag = ctx->Input(0); int last_dim = dims.size() - 1; int64 last_dim_size = dims[last_dim]; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index 800ef5ab98d..0419de78b2e 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -57,7 +57,7 @@ class DynamicUpdateSliceOp : public XlaOpKernel { input_shape.DebugString(), "; update shape is ", update_shape.DebugString())); - xla::ComputationDataHandle result = ctx->builder()->DynamicUpdateSlice( + xla::XlaOp result = ctx->builder()->DynamicUpdateSlice( ctx->Input(0), ctx->Input(1), ctx->Input(2)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index f2cd21ffb9c..dd4a1690877 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -56,7 +56,7 @@ class DynamicStitchOp : public XlaOpKernel { std::vector indices_input; OP_REQUIRES_OK(ctx, ctx->ConstantInputList("indices", &indices_input)); - std::vector data; + std::vector data; std::vector data_shapes; OP_REQUIRES_OK(ctx, ctx->InputList("data", &data, &data_shapes)); @@ -136,7 +136,7 @@ class DynamicStitchOp : public XlaOpKernel { // Look up all the children expressions that represent the data // inputs. - std::vector input(indices.size()); + std::vector input(indices.size()); for (int input_num = 0; input_num < indices.size(); input_num++) { TensorShape new_shape; // first reshaped dimension is the number of indices for this input. @@ -166,7 +166,7 @@ class DynamicStitchOp : public XlaOpKernel { for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } - std::vector to_concat(number_of_indices); + std::vector to_concat(number_of_indices); for (int index_num = 0; index_num < number_of_indices; index_num++) { const auto& expression = input[src_input_vector[index_num]]; // Take the appropriate slice of data. diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc index 2fd27c5ca7e..493781a1e68 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -32,11 +32,10 @@ class EluOp : public XlaOpKernel { explicit EluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Computes the max of the scalar input x and 0. 
void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + const auto expm1 = b->Expm1(ctx->Input(0)); ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1)); } }; @@ -47,7 +46,7 @@ class EluGradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return lhs * (1 + rhs). void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); const auto one = XlaHelpers::One(b, input_type(0)); const auto grad = ctx->Input(0); @@ -66,15 +65,14 @@ class SeluOp : public XlaOpKernel { explicit SeluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), 1.7580993408473768599402175208123); const auto pred = b->Gt(ctx->Input(0), zero); - const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one); + const auto expm1 = b->Expm1(ctx->Input(0)); ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)), b->Mul(scale_alpha, expm1))); } @@ -86,9 +84,8 @@ class SeluGradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return lhs * (1 + rhs). void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const auto zero = XlaHelpers::Zero(b, input_type(0)); - const auto one = XlaHelpers::One(b, input_type(0)); const auto scale = XlaHelpers::FloatLiteral(b, input_type(0), 1.0507009873554804934193349852946); const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0), diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index b2970eae20a..6df01cabbf1 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -93,7 +93,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { input_shape.DebugString())); const int64 depth = input_shape.dim_size(feature_dim); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); // The following code is equivalent to: // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD]) @@ -110,7 +110,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { // Builds an identity matrix as a broadcast equality of iotas. 
// iota = np.arange(np.prod(ksize), depth) // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) - xla::ComputationDataHandle iota; + xla::XlaOp iota; TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, kernel_size * depth, &iota)); @@ -147,7 +147,7 @@ class ExtractImagePatchesOp : public XlaOpKernel { &padding[i].first, &padding[i].second)); } - xla::ComputationDataHandle conv = + xla::XlaOp conv = builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 99470d70e70..8f0de0a524c 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -44,23 +44,20 @@ void CpuNudge(const float min, const float max, const float quant_min, } // An XLA version of CpuNudge(). -void XlaNudge(xla::ComputationBuilder* b, const DataType data_type, - const xla::ComputationDataHandle& min, - const xla::ComputationDataHandle& max, +void XlaNudge(xla::XlaBuilder* b, const DataType data_type, + const xla::XlaOp& min, const xla::XlaOp& max, const float quant_min_value, const float quant_max_value, - xla::ComputationDataHandle* nudged_min, - xla::ComputationDataHandle* nudged_max, - xla::ComputationDataHandle* scale) { + xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, + xla::XlaOp* scale) { *scale = b->Div(b->Sub(max, min), XlaHelpers::FloatLiteral(b, data_type, quant_max_value - quant_min_value)); - xla::ComputationDataHandle quant_min = + xla::XlaOp quant_min = XlaHelpers::FloatLiteral(b, data_type, quant_min_value); - xla::ComputationDataHandle zero_point_from_min = - b->Sub(quant_min, b->Div(min, *scale)); - xla::ComputationDataHandle quant_max = + xla::XlaOp zero_point_from_min = b->Sub(quant_min, b->Div(min, *scale)); + xla::XlaOp quant_max = XlaHelpers::FloatLiteral(b, data_type, quant_max_value); - xla::ComputationDataHandle nudged_zero_point = + xla::XlaOp nudged_zero_point = b->Select(b->Le(zero_point_from_min, quant_min), quant_min, b->Select(b->Ge(zero_point_from_min, quant_max), quant_max, b->Round(zero_point_from_min))); @@ -68,22 +65,18 @@ void XlaNudge(xla::ComputationBuilder* b, const DataType data_type, *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale); } -xla::ComputationDataHandle Quantize( - xla::ComputationBuilder* b, const xla::ComputationDataHandle& input, - const DataType data_type, - const xla::ComputationDataHandle& nudged_input_min, - const xla::ComputationDataHandle& nudged_input_max, - const xla::ComputationDataHandle& input_scale) { - xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); - xla::ComputationDataHandle inv_scale = b->Div(one, input_scale); - xla::ComputationDataHandle half = - XlaHelpers::FloatLiteral(b, data_type, 0.5f); +xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, + const DataType data_type, + const xla::XlaOp& nudged_input_min, + const xla::XlaOp& nudged_input_max, + const xla::XlaOp& input_scale) { + xla::XlaOp one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); + xla::XlaOp inv_scale = b->Div(one, input_scale); + xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5f); - xla::ComputationDataHandle clamped = - b->Clamp(nudged_input_min, input, nudged_input_max); - xla::ComputationDataHandle clamped_shifted = - b->Sub(clamped, nudged_input_min); - xla::ComputationDataHandle rounded = + xla::XlaOp 
clamped = b->Clamp(nudged_input_min, input, nudged_input_max); + xla::XlaOp clamped_shifted = b->Sub(clamped, nudged_input_min); + xla::XlaOp rounded = b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half)); return b->Add(b->Mul(rounded, input_scale), nudged_input_min); } @@ -111,18 +104,18 @@ class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min = + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min = XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); - xla::ComputationDataHandle nudged_input_max = + xla::XlaOp nudged_input_max = XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); - xla::ComputationDataHandle input_scale = + xla::XlaOp input_scale = XlaHelpers::FloatLiteral(b, data_type, input_scale_); - xla::ComputationDataHandle output = Quantize( - b, input, data_type, nudged_input_min, nudged_input_max, input_scale); + xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min, + nudged_input_max, input_scale); ctx->SetOutput(0, output); } @@ -159,23 +152,22 @@ class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle gradient = ctx->Input(0); + xla::XlaOp gradient = ctx->Input(0); const TensorShape gradient_shape = ctx->InputShape(0); - xla::ComputationDataHandle input = ctx->Input(1); + xla::XlaOp input = ctx->Input(1); const DataType data_type = ctx->input_type(1); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min = + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min = XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); - xla::ComputationDataHandle nudged_input_max = + xla::XlaOp nudged_input_max = XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); - xla::ComputationDataHandle between_nudged_min_max = + xla::XlaOp between_nudged_min_max = b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); - xla::ComputationDataHandle zeroes = b->Broadcast( - XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes()); - xla::ComputationDataHandle output = - b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp zeroes = b->Broadcast(XlaHelpers::Zero(b, data_type), + gradient_shape.dim_sizes()); + xla::XlaOp output = b->Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output); } @@ -204,18 +196,18 @@ class FakeQuantWithMinMaxVarsOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); - xla::ComputationDataHandle input_min = ctx->Input(1); - xla::ComputationDataHandle input_max = ctx->Input(2); + xla::XlaOp input_min = ctx->Input(1); + xla::XlaOp input_max = ctx->Input(2); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min, nudged_input_max, input_scale; XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, &nudged_input_min, &nudged_input_max, &input_scale); - xla::ComputationDataHandle output = Quantize( - b, input, data_type, nudged_input_min, nudged_input_max, 
input_scale); + xla::XlaOp output = Quantize(b, input, data_type, nudged_input_min, + nudged_input_max, input_scale); ctx->SetOutput(0, output); } @@ -243,47 +235,43 @@ class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle gradient = ctx->Input(0); + xla::XlaOp gradient = ctx->Input(0); const TensorShape gradient_shape = ctx->InputShape(0); - xla::ComputationDataHandle input = ctx->Input(1); + xla::XlaOp input = ctx->Input(1); const DataType data_type = ctx->input_type(1); const DataType accumulation_type = XlaHelpers::SumAccumulationType(data_type); - xla::ComputationDataHandle input_min = ctx->Input(2); - xla::ComputationDataHandle input_max = ctx->Input(3); + xla::XlaOp input_min = ctx->Input(2); + xla::XlaOp input_max = ctx->Input(3); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp nudged_input_min, nudged_input_max, input_scale; XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, &nudged_input_min, &nudged_input_max, &input_scale); - xla::ComputationDataHandle between_nudged_min_max = + xla::XlaOp between_nudged_min_max = b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type); - xla::ComputationDataHandle zeroes = - b->Broadcast(zero, gradient_shape.dim_sizes()); - xla::ComputationDataHandle output0 = - b->Select(between_nudged_min_max, gradient, zeroes); + xla::XlaOp zero = XlaHelpers::Zero(b, data_type); + xla::XlaOp zeroes = b->Broadcast(zero, gradient_shape.dim_sizes()); + xla::XlaOp output0 = b->Select(between_nudged_min_max, gradient, zeroes); ctx->SetOutput(0, output0); - xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min); - xla::ComputationDataHandle select1 = b->Select(below_min, gradient, zeroes); - xla::ComputationDataHandle reduce1 = b->ReduceAll( + xla::XlaOp below_min = b->Lt(input, nudged_input_min); + xla::XlaOp select1 = b->Select(below_min, gradient, zeroes); + xla::XlaOp reduce1 = b->ReduceAll( XlaHelpers::ConvertElementType(b, select1, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); - xla::ComputationDataHandle output1 = - XlaHelpers::ConvertElementType(b, reduce1, data_type); + xla::XlaOp output1 = XlaHelpers::ConvertElementType(b, reduce1, data_type); ctx->SetOutput(1, output1); - xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max); - xla::ComputationDataHandle select2 = b->Select(above_max, gradient, zeroes); - xla::ComputationDataHandle reduce2 = b->ReduceAll( + xla::XlaOp above_max = b->Gt(input, nudged_input_max); + xla::XlaOp select2 = b->Select(above_max, gradient, zeroes); + xla::XlaOp reduce2 = b->ReduceAll( XlaHelpers::ConvertElementType(b, select2, accumulation_type), XlaHelpers::Zero(b, accumulation_type), *ctx->GetOrCreateAdd(accumulation_type)); - xla::ComputationDataHandle output2 = - XlaHelpers::ConvertElementType(b, reduce2, data_type); + xla::XlaOp output2 = XlaHelpers::ConvertElementType(b, reduce2, data_type); ctx->SetOutput(2, output2); } diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index a4f3c1c3ad9..933924cad1c 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -62,9 +62,8 @@ class GenericFftOp : public XlaOpKernel { } } 
- xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle fft = - b->Fft(ctx->Input(0), fft_type_, fft_length); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp fft = b->Fft(ctx->Input(0), fft_type_, fft_length); ctx->SetOutput(0, fft); } @@ -82,9 +81,11 @@ class FFTOp : public GenericFftOp { explicit FFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("FFT"), FFTOp<1>); -REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>); -REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>); +REGISTER_XLA_OP(Name("FFT").TypeConstraint("Tcomplex", DT_COMPLEX64), FFTOp<1>); +REGISTER_XLA_OP(Name("FFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64), + FFTOp<2>); +REGISTER_XLA_OP(Name("FFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64), + FFTOp<3>); template class IFFTOp : public GenericFftOp { @@ -92,9 +93,12 @@ class IFFTOp : public GenericFftOp { explicit IFFTOp(OpKernelConstruction* ctx) : GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {} }; -REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>); -REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>); -REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>); +REGISTER_XLA_OP(Name("IFFT").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<1>); +REGISTER_XLA_OP(Name("IFFT2D").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<2>); +REGISTER_XLA_OP(Name("IFFT3D").TypeConstraint("Tcomplex", DT_COMPLEX64), + IFFTOp<3>); template class RFFTOp : public GenericFftOp { diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index eaa13b8dfac..e4467a0fb13 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -48,7 +48,7 @@ class FillOp : public XlaOpKernel { 0, {dims_shape.num_elements()}, &dims_literal)); // Convert the dims literal into a vector that we can pass to - // ComputationBuilder. + // XlaBuilder. std::vector broadcast; broadcast.reserve(dims_literal.shape().dimensions(0)); for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { @@ -56,7 +56,7 @@ class FillOp : public XlaOpKernel { } // Look up the value input, reshaping to a scalar if it was a // 'legacy' scalar (secretly a vector). - xla::ComputationDataHandle data = ctx->Input(1); + xla::XlaOp data = ctx->Input(1); if (value_shape.dims() > 0) { CHECK_EQ(value_shape.dims(), 1); data = ctx->builder()->Reshape(data, {}); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 7945c05af40..d13e25bcdda 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -26,55 +26,55 @@ limitations under the License. namespace tensorflow { -Status XlaGather(const xla::ComputationDataHandle& input, - const TensorShape& input_shape, - const xla::ComputationDataHandle& indices, - TensorShape indices_shape, int64 axis, bool indices_are_nd, - DataType dtype, DataType index_type, - xla::ComputationBuilder* builder, - xla::ComputationDataHandle* gather_output) { +Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, + const xla::XlaOp& indices, const TensorShape& indices_shape, + int64 axis, bool indices_are_nd, DataType dtype, + DataType index_type, xla::XlaBuilder* builder, + xla::XlaOp* gather_output) { + // There is no deep reason why we need this precondition, but this is the only + // combination that is used and tested today. 
+ CHECK(!indices_are_nd || axis == 0); + + // num_index_dims is the number of components in each index in the indices + // tensor. + // + // num_indices is the total number of (n dimensional or scalar) indices in the + // indices tensor. + // // If the indices are N-dimensional, then the minor dimension of indices // should be of size N and correspond to the N indices. - int64 num_index_dims = 1; + int64 num_index_dims; + int64 num_indices = 1; if (indices_are_nd) { CHECK_GE(indices_shape.dims(), 1); num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); - indices_shape.RemoveLastDims(1); + for (int64 i = 0, e = indices_shape.dims() - 1; i < e; i++) { + num_indices *= indices_shape.dim_size(i); + } + } else { + num_index_dims = 1; + for (int64 i = 0, e = indices_shape.dims(); i < e; i++) { + num_indices *= indices_shape.dim_size(i); + } } - // Although the indices Tensor is flattened into rank 1 during the lookup, - // and each scalar entry is used as an index into the first dimension of the - // input, the output is returned with shape: - // input.shape[:axis] + indices.shape + input.shape[axis+1:] - - const int64 num_indices = indices_shape.num_elements(); - TensorShape input_shape_pre_axis(input_shape); - input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims()); - TensorShape input_shape_post_axis(input_shape); - input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims); - // Each slice of the input tensor has shape: - // [, 1, ..., 1, ] - TensorShape slice_shape(input_shape); - for (int64 i = 0; i < num_index_dims; ++i) { - slice_shape.set_dim(axis + i, 1); - } - - TensorShape loop_out_shape; - loop_out_shape.AppendShape(input_shape_pre_axis); - loop_out_shape.AddDim(num_indices); - loop_out_shape.AppendShape(input_shape_post_axis); - TensorShape loop_out_slice_shape; - loop_out_slice_shape.AppendShape(input_shape_pre_axis); - loop_out_slice_shape.AddDim(1); - loop_out_slice_shape.AppendShape(input_shape_post_axis); - - TensorShape out_shape; - out_shape.AppendShape(input_shape_pre_axis); - out_shape.AppendShape(indices_shape); - out_shape.AppendShape(input_shape_post_axis); - // Degenerate case: empty indices. if (num_indices == 0) { + TensorShape input_shape_pre_axis{input_shape}; + input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims()); + TensorShape input_shape_post_axis{input_shape}; + input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims); + + TensorShape indices_shape_no_index_vectors{indices_shape}; + if (indices_are_nd) { + indices_shape_no_index_vectors.RemoveLastDims(1); + } + + TensorShape out_shape; + out_shape.AppendShape(input_shape_pre_axis); + out_shape.AppendShape(indices_shape_no_index_vectors); + out_shape.AppendShape(input_shape_post_axis); + *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype), out_shape.dim_sizes()); return Status::OK(); @@ -88,76 +88,61 @@ Status XlaGather(const xla::ComputationDataHandle& input, } } - // Flatten the major dimensions of indices into a single dimension for ease of - // iteration. If there is an axis dimension, we must leave it alone. - std::vector flat_indices_shape = {num_indices}; - if (indices_are_nd) { - flat_indices_shape.push_back(num_index_dims); - } + // Example of a 1-D gather with axis=1, pulling two [3,1] tensors out of a + // tensor of shape [3,3]. 
+ // + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // gather = s32[3,2] gather(operand, indices), + // output_window_dims={0}, + // elided_window_dims={1}, + // gather_dims_to_operand_dims={1}, + // index_vector_dim=1, + // window_bounds={3, 1} + // + // + // Example of an N-D gather pulling out slices of shape [1,1,2] out of a + // tensor of shape [3,3,2]. + // + // operand = s32[3,3,2] parameter(0) + // indices = s32[2,2] parameter(1) + // gather = s32[2,2] gather(operand, indices), + // output_window_dims={1}, + // elided_window_dims={0,1}, + // gather_dims_to_operand_dims={0,1}, + // index_vector_dim=0, + // window_bounds={1,1,2} - // Specify the shape of the loop-carried Tensor tuple. - - // Construct the initial values of the loop-carried Tensors. - auto flat_indices = builder->Reshape(indices, flat_indices_shape); - auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - loop_out_shape.dim_sizes()); - auto init = {input, flat_indices, init_out}; - - // Construct the while loop body's function. The implementation of gather is: - // for i in range(num_indices): - // index = dynamic-slice(indices, i) - // xi = dynamic-slice(input, index) - // output = dynamic-update-slice(output, xi, i) - auto body_fn = [&](xla::ComputationDataHandle i, - gtl::ArraySlice loop_vars, - xla::ComputationBuilder* bodyb) { - auto input = loop_vars[0]; - auto indices = loop_vars[1]; - auto output = loop_vars[2]; - - auto zero_index = XlaHelpers::Zero(bodyb, index_type); - - // Slice the i-th index from the indices array. - xla::ComputationDataHandle index; - auto indices_offset = bodyb->Reshape(i, {1}); - if (indices_are_nd) { - // Slice out the entire nd index, if applicable. - indices_offset = bodyb->Pad(indices_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, 1}})); - index = bodyb->DynamicSlice(indices, indices_offset, {1, num_index_dims}); - index = bodyb->Collapse(index, {0, 1}); + xla::GatherDimensionNumbers dim_numbers; + std::vector window_bounds; + window_bounds.reserve(input_shape.dims()); + for (int64 i = 0; i < input_shape.dims(); i++) { + int64 window_bound; + if (axis <= i && i < (axis + num_index_dims)) { + dim_numbers.add_elided_window_dims(i); + window_bound = 1; } else { - index = bodyb->DynamicSlice(indices, indices_offset, {1}); + window_bound = input_shape.dim_size(i); } - // Slice the corresponding data from the input array. - auto start_indices = bodyb->Pad( - index, zero_index, - xla::MakeEdgePaddingConfig( - {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}})); - auto slice_i = bodyb->Reshape( - bodyb->DynamicSlice(input, start_indices, slice_shape.dim_sizes()), - loop_out_slice_shape.dim_sizes()); + window_bounds.push_back(window_bound); - // Construct the index into the output Tensor 0, ..., , 0, ... - std::vector out_index_vals( - loop_out_shape.dims(), bodyb->Reshape(zero_index, {1})); - out_index_vals[input_shape_pre_axis.dims()] = bodyb->Reshape(i, {1}); - auto out_index = bodyb->ConcatInDim(out_index_vals, 0); + if (i < axis) { + dim_numbers.add_output_window_dims(i); + } else if (i >= (axis + num_index_dims)) { + int64 indices_rank = + indices_are_nd ? (indices_shape.dims() - 1) : indices_shape.dims(); + dim_numbers.add_output_window_dims(i + indices_rank - num_index_dims); + } + } - // Update the output Tensor - auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index); + dim_numbers.set_index_vector_dim(indices_are_nd ? 
(indices_shape.dims() - 1) + : indices_shape.dims()); + for (int64 i = axis; i < axis + num_index_dims; i++) { + dim_numbers.add_gather_dims_to_operand_dims(i); + } - return std::vector{input, indices, - updated_output}; - }; - - // Construct the While loop, extract and reshape the output. - xla::PrimitiveType ptype; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype)); - TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn, - init, "gather", builder)); - *gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes()); + *gather_output = builder->Gather(input, indices, dim_numbers, window_bounds); return Status::OK(); } @@ -166,7 +151,7 @@ class GatherOp : public XlaOpKernel { explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); auto input = context->Input(0); auto input_shape = context->InputShape(0); auto indices = context->Input(1); @@ -195,7 +180,7 @@ class GatherOp : public XlaOpKernel { OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, errors::InvalidArgument("indices must be int32 or int64")); - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK( context, XlaGather(input, input_shape, indices, indices_shape, axis, /*indices_are_nd=*/false, input_type(0), index_type, @@ -233,10 +218,10 @@ class GatherNdOp : public XlaOpKernel { indices_shape.dim_size(indices_shape.dims() - 1), " vs. ", params_shape.dims())); - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); auto params = context->Input(0); auto indices = context->Input(1); - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices, indices_shape, /*axis=*/0, /*indices_are_nd=*/true, params_type, diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index bd8b92c22d7..d898e43b858 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" @@ -33,13 +33,11 @@ namespace tensorflow { // If `indices_are_nd` is true, the last dimension of `indices` are treated as // a multidimensional index values. Otherwise, `indices` is treated as a tensor // of scalar indices. 
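For reference, a minimal call-site sketch of the helper with its new XlaOp-based signature, assuming an XlaOpKernelContext named ctx (this mirrors how GatherOp::Compile invokes it for a plain gather over axis 0; the variable names are illustrative only, not part of this change):

  xla::XlaBuilder* builder = ctx->builder();
  xla::XlaOp gathered;
  OP_REQUIRES_OK(ctx, XlaGather(/*input=*/ctx->Input(0), /*input_shape=*/ctx->InputShape(0),
                                /*indices=*/ctx->Input(1), /*indices_shape=*/ctx->InputShape(1),
                                /*axis=*/0, /*indices_are_nd=*/false,
                                /*dtype=*/ctx->input_type(0), /*index_type=*/ctx->input_type(1),
                                builder, &gathered));
  ctx->SetOutput(0, gathered);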
-Status XlaGather(const xla::ComputationDataHandle& input, - const TensorShape& input_shape, - const xla::ComputationDataHandle& indices, - TensorShape indices_shape, int64 axis, bool indices_are_nd, - DataType dtype, DataType index_type, - xla::ComputationBuilder* builder, - xla::ComputationDataHandle* gather_output); +Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, + const xla::XlaOp& indices, const TensorShape& indices_shape, + int64 axis, bool indices_are_nd, DataType dtype, + DataType index_type, xla::XlaBuilder* builder, + xla::XlaOp* gather_output); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 39af662b638..e72200bfbcf 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -38,6 +38,7 @@ class IdentityOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp); REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp); +REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); REGISTER_XLA_OP(Name("Snapshot"), IdentityOp); diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index eefbe55c815..8b9b026643c 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -37,7 +37,7 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { // TODO(b/35949885): There is duplication here with the handling of the // while_op. Refactor the common code out/rework. void XlaIfOp::Compile(XlaOpKernelContext* ctx) { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); OP_REQUIRES(ctx, cond_type_ == DT_BOOL, errors::InvalidArgument( @@ -48,7 +48,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; - std::vector inputs(input_types_.size()); + std::vector inputs(input_types_.size()); std::vector arguments(input_types_.size()); for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; @@ -175,19 +175,19 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { "Mismatch in resource of then and else branch for resource ", i)); } - xla::ComputationDataHandle outputs = + xla::XlaOp outputs = b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, b->Tuple(inputs), *else_result.computation); // Sets non-variable outputs. for (int i = 0; i < output_types_.size(); ++i) { if (ctx->input_type(i) != DT_RESOURCE) { - xla::ComputationDataHandle output_handle = b->GetTupleElement(outputs, i); + xla::XlaOp output_handle = b->GetTupleElement(outputs, i); if (VLOG_IS_ON(2)) { LOG(INFO) << "Setting output " << i; auto shape_or = b->GetShape(output_handle); if (shape_or.ok()) { LOG(INFO) << "Shape for output " << i << ": " - << xla::ShapeUtil::HumanString(*shape_or.ValueOrDie()); + << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); } else { LOG(INFO) << "Shape unknown for output " << i; } diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 5eeda79a935..1568b336799 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -23,10 +23,9 @@ namespace { // Converts 'input' from RGB format to HSV format. 
// 'shape' is the shape of the red/green/blue tensors. -std::array RGBToHSV( - XlaOpKernelContext* ctx, xla::ComputationBuilder* b, - const std::array& rgb, DataType dtype, - const TensorShape& shape) { +std::array RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, + const std::array& rgb, + DataType dtype, const TensorShape& shape) { auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::One(b, dtype); @@ -54,12 +53,12 @@ std::array RGBToHSV( } // Converts 'input' from HSV format to RGB format. -std::array HSVToRGB( - xla::ComputationBuilder* b, - const std::array& hsv, DataType dtype) { - xla::ComputationDataHandle hue = hsv[0]; - xla::ComputationDataHandle saturation = hsv[1]; - xla::ComputationDataHandle value = hsv[2]; +std::array HSVToRGB(xla::XlaBuilder* b, + const std::array& hsv, + DataType dtype) { + xla::XlaOp hue = hsv[0]; + xla::XlaOp saturation = hsv[1]; + xla::XlaOp value = hsv[2]; auto zero = XlaHelpers::Zero(b, dtype); auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); @@ -95,16 +94,16 @@ class RGBToHSVOp : public XlaOpKernel { errors::FailedPrecondition("input must have 3 channels but input has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); - xla::ComputationDataHandle red = + xla::XlaOp red = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle green = + xla::XlaOp green = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle blue = + xla::XlaOp blue = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; @@ -133,15 +132,15 @@ class HSVToRGBOp : public XlaOpKernel { errors::FailedPrecondition("input must have 3 channels but input has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle hue = + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp hue = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle saturation = + xla::XlaOp saturation = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle value = + xla::XlaOp value = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); @@ -174,9 +173,9 @@ class AdjustContrastOpV2 : public XlaOpKernel { errors::InvalidArgument("contrast_factor must be scalar: ", factor_shape.DebugString())); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle factor = context->Input(1); + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp factor = context->Input(1); DataType type = context->input_type(0); @@ -221,19 +220,19 @@ class AdjustSaturationOp : public XlaOpKernel { errors::InvalidArgument("input must have 3 channels but instead has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle scale = context->Input(1); + 
xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp scale = context->Input(1); DataType type = context->input_type(0); - xla::ComputationDataHandle red = + xla::XlaOp red = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle green = + xla::XlaOp green = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle blue = + xla::XlaOp blue = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; @@ -271,19 +270,19 @@ class AdjustHueOp : public XlaOpKernel { errors::InvalidArgument("input must have 3 channels but instead has ", channels, " channels.")); - xla::ComputationBuilder* b = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle delta = context->Input(1); + xla::XlaBuilder* b = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp delta = context->Input(1); DataType type = context->input_type(0); - xla::ComputationDataHandle red = + xla::XlaOp red = b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle green = + xla::XlaOp green = b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, /*dimno=*/channel_dim); - xla::ComputationDataHandle blue = + xla::XlaOp blue = b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, /*dimno=*/channel_dim); TensorShape channel_shape = input_shape; diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index f36b3f59482..9058cbc7476 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -99,9 +99,9 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( return dims; } -xla::ComputationDataHandle MakeBilinearResizeKernel( - xla::ComputationBuilder* builder, gtl::ArraySlice kernel_size, - int64 channels) { +xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, + gtl::ArraySlice kernel_size, + int64 channels) { // Form a 2D convolution kernel like: // 1 2 3 2 1 // 2 4 6 4 2 @@ -120,7 +120,7 @@ xla::ComputationDataHandle MakeBilinearResizeKernel( return kernel; }; - xla::ComputationDataHandle channels_iota; + xla::XlaOp channels_iota; // DT_INT32 Iota will always return status::OK(). 
TF_CHECK_OK( XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); @@ -139,10 +139,12 @@ xla::ComputationDataHandle MakeBilinearResizeKernel( /*broadcast_dimensions=*/{0}); } -xla::ComputationDataHandle ResizeUsingDilationAndConvolution( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& input, - const int num_spatial_dims, std::vector in_size, - std::vector out_size, const int64 channels) { +xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, + const xla::XlaOp& input, + const int num_spatial_dims, + std::vector in_size, + std::vector out_size, + const int64 channels) { // Picture for a 1x3 to 1x4 resize: // stride = 2, kernel size = 3 // Input: @@ -168,9 +170,9 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolution( ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, out_size); - xla::ComputationDataHandle kernel = + xla::XlaOp kernel = MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - xla::ComputationDataHandle output = builder->ConvGeneralDilated( + xla::XlaOp output = builder->ConvGeneralDilated( input, kernel, dims.stride, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, @@ -189,10 +191,12 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolution( return output; } -xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& grad, - const int num_spatial_dims, std::vector in_size, - std::vector grad_size, const int64 channels) { +xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, + const xla::XlaOp& grad, + const int num_spatial_dims, + std::vector in_size, + std::vector grad_size, + const int64 channels) { ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, grad_size); @@ -210,7 +214,7 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp( } dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); - xla::ComputationDataHandle kernel = + xla::XlaOp kernel = MakeBilinearResizeKernel(builder, dims.kernel_size, channels); // Broadcast the input kernel where the forward op expanded from a size == 1 @@ -223,7 +227,7 @@ xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp( } } - xla::ComputationDataHandle output = builder->ConvGeneralDilated( + xla::XlaOp output = builder->ConvGeneralDilated( grad, kernel, /*window_strides=*/dims.kernel_size, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, @@ -258,7 +262,7 @@ class ResizeBilinearOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape input_shape = ctx->InputShape(0); OP_REQUIRES(ctx, input_shape.dims() == 4, @@ -283,7 +287,7 @@ class ResizeBilinearOp : public XlaOpKernel { const int num_spatial_dims = 2; - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in // dimension i. @@ -318,7 +322,7 @@ class ResizeBilinearOp : public XlaOpKernel { // from image of size axb -> cxd is same as resizing axb -> exf -> cxd. // // This makes the convolutions kernels smaller and the operation faster. 
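As a concrete illustration of the kernel pattern referenced in MakeBilinearResizeKernel above: the integer pattern shown (1 2 3 2 1, 2 4 6 4 2, ...) is the outer product of the 1-D ramp [1, 2, ..., k, ..., 2, 1], where k is the per-dimension kernel size; normalizing each tap by 1/k is an assumption made for this sketch. A standalone, plain-C++ sketch of those per-tap weights (not part of the kernel):

  #include <cstdint>
  #include <vector>

  // 1-D bilinear taps whose outer product gives the 2-D pattern above.
  // The 1/k normalization is an assumption for this illustration.
  std::vector<float> Bilinear1DTaps(int64_t kernel_size) {
    std::vector<float> taps(2 * kernel_size - 1);
    for (int64_t i = 0; i < static_cast<int64_t>(taps.size()); ++i) {
      const int64_t ramp = (i < kernel_size) ? (i + 1) : (2 * kernel_size - 1 - i);
      taps[i] = static_cast<float>(ramp) / static_cast<float>(kernel_size);
    }
    return taps;  // kernel_size = 3 -> {1/3, 2/3, 1, 2/3, 1/3}
  }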
- xla::ComputationDataHandle output = input; + xla::XlaOp output = input; while (in_size != out_size) { if (in_size[0] != 1 && in_size[1] != 1) { std::vector k = { @@ -369,7 +373,7 @@ class ResizeBilinearGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape input_shape = ctx->InputShape(1); OP_REQUIRES(ctx, input_shape.dims() == 4, @@ -406,9 +410,9 @@ class ResizeBilinearGradOp : public XlaOpKernel { const int num_spatial_dims = 2; - xla::ComputationDataHandle grad = ctx->Input(0); + xla::XlaOp grad = ctx->Input(0); - xla::ComputationDataHandle output = grad; + xla::XlaOp output = grad; while (in_size != grad_size) { if (in_size[0] != 1 && in_size[1] != 1) { std::vector k = { diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 7bf4b435f52..36eb4c75454 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -61,10 +61,10 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { DataType index_type = output_type(0); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input = ctx->Input(0); - xla::ComputationDataHandle output; + xla::XlaOp output; if (is_min_) { OP_REQUIRES_OK(ctx, XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0), diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index b1f3c3c298c..2c2d88486fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -71,10 +71,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel { OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(), errors::InvalidArgument( "ArgMax implementation requires a CustomCall on CPU")); - xla::ComputationBuilder& b = *ctx->builder(); + xla::XlaBuilder& b = *ctx->builder(); // XLA passes to the function, so it is not included here. - std::vector args; + std::vector args; args.push_back(ctx->Input(0)); args.push_back(b.ConstantLiteral( *xla::Literal::CreateR1(input_shape.dim_sizes()))); @@ -91,7 +91,7 @@ class ArgMaxCustomCallOp : public XlaOpKernel { // Tell XLA to call the custom code, defined in // index_ops_kernel_argmax_float_1d.cc. - xla::ComputationDataHandle output; + xla::XlaOp output; switch (input_shape.dims()) { case 1: output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape); diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index c177f08d9c4..1decf7d72d7 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/no_op.h" @@ -33,7 +33,7 @@ class L2LossOp : public XlaOpKernel { std::iota(dims.begin(), dims.end(), 0); DataType dtype = ctx->input_type(0); - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); // output = sum(t ** 2) / 2 const DataType accumulation_type = XlaHelpers::SumAccumulationType(dtype); diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc new file mode 100644 index 00000000000..0388b4c8307 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific ListDiff Op. This only supports constant DT_INT32 and DT_INT64 +// input. + +#include + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { + +constexpr std::array kListDiffTypes = {DT_INT32, DT_INT64}; + +// ListDiffOp is an XLA kernel that supports constant-only x and y input. +class ListDiffOp : public XlaOpKernel { + public: + explicit ListDiffOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(0)), + errors::InvalidArgument("ListDiff expects x as a vector, not ", + context->InputShape(0).DebugString())); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(1)), + errors::InvalidArgument("ListDiff expects y as a vector, not ", + context->InputShape(1).DebugString())); + + DataType val_type = context->expected_output_dtype(0); + DataType idx_type = context->expected_output_dtype(1); + + Status status; + switch (val_type) { + case DT_INT32: + status = ListDiffWithIndexType(context, idx_type); + break; + case DT_INT64: + status = ListDiffWithIndexType(context, idx_type); + break; + default: + // This should never happen since we restrict this kernel to only match + // inputs with supported Tensor datatype. 
+ status = errors::InvalidArgument("ListDiff expects x and y as either ", + "int32 or int64, not ", + DataTypeString(val_type)); + } + OP_REQUIRES_OK(context, status); + } + + private: + template + Status ListDiff(XlaOpKernelContext* context) { + std::vector x_input, y_input; + TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(0, &x_input)); + TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(1, &y_input)); + + std::unordered_set y_input_set; + y_input_set.reserve(y_input.size()); + for (auto y : y_input) { + y_input_set.insert(y); + } + + std::vector val_output; + std::vector idx_output; + auto x_size = x_input.size(); + for (Tidx i = 0; i < x_size; ++i) { + if (y_input_set.count(x_input[i]) > 0) { + continue; + } + val_output.push_back(x_input[i]); + idx_output.push_back(i); + } + + context->SetOutput(0, context->builder()->ConstantR1(val_output)); + context->SetOutput(1, context->builder()->ConstantR1(idx_output)); + return Status::OK(); + } + + template + Status ListDiffWithIndexType(XlaOpKernelContext* context, DataType idx_type) { + switch (idx_type) { + case DT_INT32: + return ListDiff(context); + case DT_INT64: + return ListDiff(context); + default: + return errors::InvalidArgument( + "ListDiff expects idx_out as either int32 or int64, not ", + DataTypeString(idx_type)); + } + } +}; + +REGISTER_XLA_OP(Name("ListDiff") + .TypeConstraint("T", kListDiffTypes) + .CompileTimeConstInput("x") + .CompileTimeConstInput("y"), + ListDiffOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 1cfee3070f3..39fbf98a627 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -38,8 +38,8 @@ class LRNOp : public XlaOpKernel { OP_REQUIRES(ctx, in_shape.dims() == 4, errors::InvalidArgument("in must be 4-dimensional")); - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); // sqr_sum[a, b, c, d] = // sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2) @@ -111,10 +111,10 @@ class LRNGradOp : public XlaOpKernel { "input_grads, input_image, and out_image should have the same " "shape")); - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle in_grads = ctx->Input(0); - xla::ComputationDataHandle in_image = ctx->Input(1); - xla::ComputationDataHandle out_image = ctx->Input(2); + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp in_grads = ctx->Input(0); + xla::XlaOp in_image = ctx->Input(1); + xla::XlaOp out_image = ctx->Input(2); // This code is ported from tensorflow/core/kernels/lrn_op.cc. 
In Python // pseudo-code, the Eigen code does this for each spatial position: @@ -166,7 +166,7 @@ class LRNGradOp : public XlaOpKernel { auto dy_reduced = XlaHelpers::ConvertElementType(builder, dy_reduce, input_type(0)); - xla::ComputationDataHandle gradients = builder->Add( + xla::XlaOp gradients = builder->Add( builder->Mul(in_image, dy_reduced), builder->Mul(in_grads, builder->Pow(norm, builder->ConstantR0(-beta_)))); diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 886baf81152..6949b296f4b 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -66,8 +66,8 @@ class MatMulOp : public XlaOpKernel { a_shape.DebugString(), ", In[1]: ", b_shape.DebugString())); - xla::ComputationDataHandle a = ctx->Input(0); - xla::ComputationDataHandle b = ctx->Input(1); + xla::XlaOp a = ctx->Input(0); + xla::XlaOp b = ctx->Input(1); if (is_sparse_) { if (a_type_ == DT_BFLOAT16) { a = ctx->builder()->ConvertElementType(a, xla::F32); diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc index faa415a97b0..fbd5dc0fdad 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -44,10 +44,10 @@ class MatrixBandPartOp : public XlaOpKernel { errors::InvalidArgument("num_upper must be scalar, got shape ", num_upper_in_shape.DebugString())); - xla::ComputationBuilder* builder = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle num_lower = context->Input(1); - xla::ComputationDataHandle num_upper = context->Input(2); + xla::XlaBuilder* builder = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp num_lower = context->Input(1); + xla::XlaOp num_upper = context->Input(2); DataType input_type = context->input_type(0); DataType index_type = context->input_type(1); @@ -58,10 +58,10 @@ class MatrixBandPartOp : public XlaOpKernel { // Compute 'offset', which is how many diagonals we are above/below the // diagonal. - xla::ComputationDataHandle iota_m; + xla::XlaOp iota_m; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m)); - xla::ComputationDataHandle iota_n; + xla::XlaOp iota_n; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m, diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc index b2940bdcff7..db53f6fef8d 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -54,16 +54,16 @@ class MatrixSetDiagOp : public XlaOpKernel { input_shape.DebugString(), " and diagonal shape: ", diag_shape.DebugString())); - xla::ComputationBuilder* builder = context->builder(); - xla::ComputationDataHandle input = context->Input(0); - xla::ComputationDataHandle diag = context->Input(1); + xla::XlaBuilder* builder = context->builder(); + xla::XlaOp input = context->Input(0); + xla::XlaOp diag = context->Input(1); auto zero = XlaHelpers::Zero(builder, context->input_type(0)); // Create an indicator tensor that is true only on the diagonal. 
- xla::ComputationDataHandle iota_m; + xla::XlaOp iota_m; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); - xla::ComputationDataHandle iota_n; + xla::XlaOp iota_n; OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); auto indicator = builder->Eq(iota_m, builder->Broadcast(iota_n, {m}), diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 05a36a031ad..7e9de3ef9b2 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -25,10 +25,11 @@ class MirrorPadOp : public XlaOpKernel { public: explicit MirrorPadOp(OpKernelConstruction* context) : XlaOpKernel(context) {} - xla::StatusOr DoMirrorPad( - const xla::ComputationDataHandle& t, const xla::Shape& original_shape, - const xla::Literal& pad_literal, xla::ComputationBuilder* b) { - xla::ComputationDataHandle accum = t; + xla::StatusOr DoMirrorPad(const xla::XlaOp& t, + const xla::Shape& original_shape, + const xla::Literal& pad_literal, + xla::XlaBuilder* b) { + xla::XlaOp accum = t; for (int64 dimno = xla::ShapeUtil::Rank(original_shape) - 1; dimno >= 0; --dimno) { auto t_rev = b->Rev(accum, {dimno}); @@ -76,12 +77,12 @@ class MirrorPadOp : public XlaOpKernel { OP_REQUIRES_OK( ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal)); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); auto in0 = ctx->Input(0); - xla::StatusOr> in0_shape = b->GetShape(in0); + xla::StatusOr in0_shape = b->GetShape(in0); OP_REQUIRES(ctx, in0_shape.ok(), in0_shape.status()); - xla::StatusOr accum_status = - DoMirrorPad(in0, *in0_shape.ValueOrDie(), pad_literal, b); + xla::StatusOr accum_status = + DoMirrorPad(in0, in0_shape.ValueOrDie(), pad_literal, b); OP_REQUIRES_OK(ctx, accum_status.status()); diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index 9f7c9913802..cac2eea96ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -62,7 +62,7 @@ class OneHotOp : public XlaOpKernel { ctx, depth >= 0, errors::InvalidArgument("depth must be non-negative, got: ", depth)); - xla::ComputationDataHandle one_hot; + xla::XlaOp one_hot; OP_REQUIRES_OK( ctx, XlaHelpers::OneHot(ctx->builder(), depth, axis, input_type(0), indices_shape, ctx->Input(0), ctx->Input(2), diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index a4318e29d25..aecaabb6dcf 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -43,7 +43,7 @@ class PackOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - std::vector values; + std::vector values; std::vector shapes; OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes)); const int num = values.size(); @@ -69,7 +69,7 @@ class PackOp : public XlaOpKernel { -expanded_num_dims, ", ", expanded_num_dims, ")")); - std::vector reshaped_inputs(num); + std::vector reshaped_inputs(num); TensorShape child_shape(shapes[0]); child_shape.InsertDim(axis, 1); diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 791351637ae..7c95475e7b1 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -70,7 +70,7 @@ class PadOp : public XlaOpKernel { } // PadV2 added a 
"constant_values" input that indicates the pad value. - xla::ComputationDataHandle constant_values; + xla::XlaOp constant_values; if (ctx->num_inputs() == 3) { OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(2)), errors::InvalidArgument("constant_values must be a scalar.")); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 5f635dd1bc6..f8e7b48a0fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -66,15 +66,15 @@ class PoolingOp : public XlaOpKernel { int num_dims() const { return num_spatial_dims_ + 2; } // Method that builds an initial value to use in reductions. - virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) = 0; + virtual xla::XlaOp InitValue(xla::XlaBuilder* b) = 0; // The reduction operation to apply to each window. - virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx) = 0; + virtual const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) = 0; // A post-processing operation to apply on the outputs of the ReduceWindow. - virtual xla::ComputationDataHandle PostProcessOutput( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape) = 0; + virtual xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, + const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape) = 0; void Compile(XlaOpKernelContext* ctx) override { std::vector ksize = ksize_; @@ -110,7 +110,7 @@ class PoolingOp : public XlaOpKernel { " operator must have ", num_dims(), " dimensions")); - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); auto input = XlaHelpers::ConvertElementType(b, ctx->Input(0), reduction_type_); auto reduce = ctx->builder()->ReduceWindow( @@ -135,17 +135,17 @@ class MaxPoolOp : public PoolingOp { : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims, /*reduction_type=*/ctx->input_type(0)) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override { + xla::XlaOp InitValue(xla::XlaBuilder* b) override { return XlaHelpers::MinValue(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx) override { + const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { return ctx->GetOrCreateMax(reduction_type_); } - xla::ComputationDataHandle PostProcessOutput( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape) override { + xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, + const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape) override { return output; } }; @@ -176,9 +176,9 @@ REGISTER_XLA_OP(Name("MaxPool3D"), MaxPool3DOp); // Common computation shared between AvgPool and AvgPoolGrad. Divide each // element of an image by the count of elements that contributed to that // element during pooling. 
-static xla::ComputationDataHandle AvgPoolDivideByCount( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape, xla::Padding padding, +static xla::XlaOp AvgPoolDivideByCount( + XlaOpKernelContext* ctx, const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape, xla::Padding padding, const std::vector& ksize, const std::vector& stride, int num_spatial_dims, TensorFormat data_format) { if (padding == xla::Padding::kValid) { @@ -234,17 +234,17 @@ class AvgPoolOp : public PoolingOp { /*reduction_type=*/ XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b) override { + xla::XlaOp InitValue(xla::XlaBuilder* b) override { return XlaHelpers::Zero(b, reduction_type_); } - const xla::Computation* Reduction(XlaOpKernelContext* ctx) override { + const xla::XlaComputation* Reduction(XlaOpKernelContext* ctx) override { return ctx->GetOrCreateAdd(reduction_type_); } - xla::ComputationDataHandle PostProcessOutput( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output, - DataType dtype, const TensorShape& input_shape) override { + xla::XlaOp PostProcessOutput(XlaOpKernelContext* ctx, + const xla::XlaOp& output, DataType dtype, + const TensorShape& input_shape) override { return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_, ksize_, stride_, num_spatial_dims_, data_format_); @@ -344,11 +344,10 @@ class MaxPoolGradOp : public XlaOpKernel { xla::PrimitiveType element_type; OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type)); - xla::ComputationDataHandle init_value = - XlaHelpers::Zero(ctx->builder(), input_type(2)); + xla::XlaOp init_value = XlaHelpers::Zero(ctx->builder(), input_type(2)); auto select = CreateScalarGeComputation(element_type, ctx->builder()); auto scatter = CreateScalarAddComputation(element_type, ctx->builder()); - xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter( + xla::XlaOp gradients = ctx->builder()->SelectAndScatter( input, select, ksize_, stride_, xla_padding, out_backprop, init_value, scatter); @@ -462,7 +461,7 @@ class AvgPoolGradOp : public XlaOpKernel { // The input gradients are computed by a convolution of the output gradients // and the filter, with some appropriate padding. See the comment at the top // of conv_grad_ops.h for details. - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); auto out_backprop = ctx->Input(1); auto dtype = input_type(1); xla::Padding xla_padding = diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 4171e076ff6..661cd5923e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -35,7 +35,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const DataType data_type = ctx->input_type(0); // Comments taken from semantics description at @@ -46,8 +46,8 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // m = max(abs(input_min), abs(input_max)) if range_given is true, // m = max(abs(min_elem(input)), // abs(max_elem(input))) otherwise. 
- xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input_min, input_max; + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input_min, input_max; if (range_given_) { double input_min_value, input_max_value; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value)); @@ -55,14 +55,14 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value); input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value); } else { - const xla::Computation* fmax = ctx->GetOrCreateMax(data_type); - const xla::Computation* fmin = ctx->GetOrCreateMin(data_type); + const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type); + const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type); input_min = b->ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); input_max = b->ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); } - xla::ComputationDataHandle m = b->Max(b->Abs(input_min), b->Abs(input_max)); + xla::XlaOp m = b->Max(b->Abs(input_min), b->Abs(input_max)); // Next, we choose our fixed-point quantization buckets, [min_fixed, // max_fixed]. If signed_input is true, this is @@ -85,7 +85,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // From this we compute our scaling factor, s: // // s = (max_fixed - min_fixed) / (2 * m). - xla::ComputationDataHandle s = + xla::XlaOp s = b->Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed), b->Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m)); @@ -93,7 +93,7 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // e is transformed into e': // // e' = (e * s).round_to_nearest() / s. - xla::ComputationDataHandle result = b->Div(b->Round(b->Mul(input, s)), s); + xla::XlaOp result = b->Div(b->Round(b->Mul(input, s)), s); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index c0994c434bc..5f5bd586376 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -41,9 +41,9 @@ class RandomUniformOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle result = b->RngUniform( - XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp result = b->RngUniform(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -100,11 +100,11 @@ class RandomStandardNormalOp : public XlaOpKernel { xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // Normal distribution with a mean of 0 and a standard deviation of 1: - xla::ComputationDataHandle result = b->RngNormal( - XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape); + xla::XlaOp result = b->RngNormal(XlaHelpers::Zero(b, dtype), + XlaHelpers::One(b, dtype), xla_shape); ctx->SetOutput(0, result); } @@ -130,19 +130,18 @@ class TruncatedNormalOp : public XlaOpKernel { xla::Shape xla_element_shape = xla::ShapeUtil::MakeShape(xla_shape.element_type(), {}); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype); - xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype); - xla::ComputationDataHandle 
candidate = - b->RngNormal(mean, stddev, xla_shape); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp mean = XlaHelpers::Zero(b, dtype); + xla::XlaOp stddev = XlaHelpers::One(b, dtype); + xla::XlaOp candidate = b->RngNormal(mean, stddev, xla_shape); - auto two_sd = [dtype](bool negate, xla::ComputationBuilder* b) { + auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) { return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0); }; - auto out_of_range_mask = [two_sd](xla::ComputationDataHandle candidate, - xla::ComputationBuilder* b) { - xla::ComputationDataHandle too_large = b->Gt(candidate, two_sd(false, b)); - xla::ComputationDataHandle too_small = b->Lt(candidate, two_sd(true, b)); + auto out_of_range_mask = [two_sd](xla::XlaOp candidate, + xla::XlaBuilder* b) { + xla::XlaOp too_large = b->Gt(candidate, two_sd(false, b)); + xla::XlaOp too_small = b->Lt(candidate, two_sd(true, b)); return b->Or(too_large, too_small); }; @@ -152,35 +151,32 @@ class TruncatedNormalOp : public XlaOpKernel { // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd // candidate = select(out_of_range_mask, rng_normal(), candidate) // } - std::unique_ptr test_builder = + std::unique_ptr test_builder = b->CreateSubBuilder("truncated_normal_test"); { auto* b = test_builder.get(); - xla::ComputationDataHandle candidate = - b->Parameter(0, xla_shape, "candidate"); - xla::ComputationDataHandle oor_mask = out_of_range_mask(candidate, b); + xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate"); + out_of_range_mask(candidate, b); OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status()); } - std::unique_ptr body_builder = + std::unique_ptr body_builder = b->CreateSubBuilder("truncated_normal_body"); { auto* b = body_builder.get(); - xla::ComputationDataHandle candidate = - b->Parameter(0, xla_shape, "candidate"); - xla::ComputationDataHandle to_resample = out_of_range_mask(candidate, b); - xla::ComputationDataHandle mean = XlaHelpers::Zero(b, dtype); - xla::ComputationDataHandle stddev = XlaHelpers::One(b, dtype); + xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate"); + xla::XlaOp to_resample = out_of_range_mask(candidate, b); + xla::XlaOp mean = XlaHelpers::Zero(b, dtype); + xla::XlaOp stddev = XlaHelpers::One(b, dtype); b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate); } - xla::StatusOr test_computation = test_builder->Build(); + xla::StatusOr test_computation = test_builder->Build(); OP_REQUIRES_OK(ctx, test_computation.status()); - xla::StatusOr body_computation = body_builder->Build(); + xla::StatusOr body_computation = body_builder->Build(); OP_REQUIRES_OK(ctx, body_computation.status()); - xla::ComputationDataHandle result = - b->While(test_computation.ValueOrDie(), body_computation.ValueOrDie(), - candidate); + xla::XlaOp result = b->While(test_computation.ValueOrDie(), + body_computation.ValueOrDie(), candidate); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index cb144bea9e4..08894489ac7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -65,7 +64,7 @@ class ReduceWindowOp : public XlaOpKernel { "rank (", padding_high_.size(), " vs. ", rank, ")")); - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); // Build the reducer function. XlaCompiler::Argument reducer_arg; @@ -95,15 +94,15 @@ class ReduceWindowOp : public XlaOpKernel { xla::ShapeUtil::HumanString(reducer.xla_output_shape))); // Wraps the reducer in a computation that unpacks the output tuple. - xla::Computation wrapper; + xla::XlaComputation wrapper; { - std::unique_ptr cb = + std::unique_ptr cb = builder->CreateSubBuilder("wrapper"); auto x = cb->Parameter(0, scalar_shape, "x"); auto y = cb->Parameter(1, scalar_shape, "y"); auto outputs = cb->Call(*reducer.computation, {x, y}); cb->GetTupleElement(outputs, 0); - xla::StatusOr result = cb->Build(); + xla::StatusOr result = cb->Build(); OP_REQUIRES_OK(context, result.status()); wrapper = std::move(result.ValueOrDie()); } @@ -113,7 +112,7 @@ class ReduceWindowOp : public XlaOpKernel { padding[i] = {padding_low_[i], padding_high_[i]}; } - xla::ComputationDataHandle output = builder->ReduceWindowWithGeneralPadding( + xla::XlaOp output = builder->ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), wrapper, window_dimensions_, window_strides_, padding); context->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 812d258cd16..0f425637795 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -30,13 +30,11 @@ class SumOp : public XlaReductionOp { explicit SumOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx, XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return XlaHelpers::Zero(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Add(scalar_lhs, scalar_rhs); } }; @@ -49,14 +47,12 @@ class ProdOp : public XlaReductionOp { : XlaReductionOp(ctx, XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return XlaHelpers::One(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Mul(scalar_lhs, scalar_rhs); } }; @@ -69,14 +65,12 @@ class MinOp : public XlaReductionOp { explicit MinOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + 
xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return XlaHelpers::MaxValue(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Min(scalar_lhs, scalar_rhs); } }; @@ -88,14 +82,12 @@ class MaxOp : public XlaReductionOp { explicit MaxOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return XlaHelpers::MinValue(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Max(scalar_lhs, scalar_rhs); } }; @@ -108,20 +100,17 @@ class MeanOp : public XlaReductionOp { : XlaReductionOp(ctx, XlaHelpers::SumAccumulationType(ctx->input_type(0))) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return XlaHelpers::Zero(builder, reduction_type_); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Add(scalar_lhs, scalar_rhs); } - xla::ComputationDataHandle BuildFinalizer( - xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& reduce_output, - int64 num_elements_reduced) override { + xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, + const xla::XlaOp& reduce_output, + int64 num_elements_reduced) override { auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0), num_elements_reduced); return builder->Div(reduce_output, divisor); @@ -136,14 +125,12 @@ class AllOp : public XlaReductionOp { explicit AllOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return builder->ConstantR0(true); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->And(scalar_lhs, scalar_rhs); } }; @@ -155,14 +142,12 @@ class AnyOp : public XlaReductionOp { explicit AnyOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx, ctx->input_type(0)) {} - xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) override { + xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { return builder->ConstantR0(false); } - void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) override { + void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) override { builder->Or(scalar_lhs, scalar_rhs); } }; diff --git 
a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index f3181f0dadc..2ecfb854a1c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_ #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { @@ -28,35 +28,33 @@ namespace tensorflow { // to override: description is a textual description of the mapped // function; InitialValue constructs the base case for the reduction; // BuildReducer adds the implementation of the reduction lambda to a -// xla::ComputationBuilder and BuildFinalizer adds the +// xla::XlaBuilder and BuildFinalizer adds the // implementation of the finalizer lambda (if there is one) to a -// xla::ComputationBuilder. +// xla::XlaBuilder. class XlaReductionOp : public XlaOpKernel { public: XlaReductionOp(OpKernelConstruction* ctx, DataType reduction_type); ~XlaReductionOp() override {} // Return the base case for the reduction. - virtual xla::ComputationDataHandle InitialValue( - xla::ComputationBuilder* builder) = 0; + virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; // Implement the (scalar,scalar)->scalar lambda that should be // applied to each pair of elements to be reduced. The desired // computation should be added to 'builder' and // '(scalar_lhs,scalar_rhs)' are the function's inputs. - virtual void BuildReducer(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& scalar_lhs, - const xla::ComputationDataHandle& scalar_rhs) = 0; + virtual void BuildReducer(xla::XlaBuilder* builder, + const xla::XlaOp& scalar_lhs, + const xla::XlaOp& scalar_rhs) = 0; // Applies a transformation to the output of the reduction. The desired // computation should be added to 'builder'. Argument 'reduce_output' is the // output of the reduction. 'num_elements_reduced' is the number of elements // that contributed to the reduction. Returns the transformed reduction // output, Defaults to returning 'reduce_output' unchanged. - virtual xla::ComputationDataHandle BuildFinalizer( - xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& reduce_output, - int64 num_elements_reduced); + virtual xla::XlaOp BuildFinalizer(xla::XlaBuilder* builder, + const xla::XlaOp& reduce_output, + int64 num_elements_reduced); void Compile(XlaOpKernelContext* ctx) override; diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 64fe765ae9a..4fd5bfd0399 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -35,10 +35,9 @@ XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx, // Unless BuildFinalizer is overridden the reduction has no // finalizer. 
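One API difference worth noting in Compile() below: an xla::XlaBuilder is constructed from a name alone (it no longer takes a Client), and Build() yields an xla::XlaComputation. A minimal sketch of that sub-computation pattern for a scalar add reducer, assuming F32 scalars (illustrative only):

  xla::XlaBuilder reducer_builder("sum-reduction");
  auto lhs = reducer_builder.Parameter(
      0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
  auto rhs = reducer_builder.Parameter(
      1, xla::ShapeUtil::MakeShape(xla::F32, {}), "y");
  reducer_builder.Add(lhs, rhs);
  xla::XlaComputation add_reducer =
      reducer_builder.Build().ConsumeValueOrDie();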
-xla::ComputationDataHandle XlaReductionOp::BuildFinalizer( - xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& reduce_output, - int64 num_elements_reduced) { +xla::XlaOp XlaReductionOp::BuildFinalizer(xla::XlaBuilder* builder, + const xla::XlaOp& reduce_output, + int64 num_elements_reduced) { return reduce_output; } @@ -96,9 +95,9 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { string desc = ctx->op_kernel().name(); - xla::ComputationBuilder* const b = ctx->builder(); + xla::XlaBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. - xla::ComputationBuilder r(b->client(), strings::StrCat(desc, "-reduction")); + xla::XlaBuilder r(strings::StrCat(desc, "-reduction")); xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); @@ -110,7 +109,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { auto ry = r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y"); // Call virtual method to build the reduction lambda. BuildReducer(&r, rx, ry); - xla::Computation reduction_computation = r.Build().ConsumeValueOrDie(); + xla::XlaComputation reduction_computation = r.Build().ConsumeValueOrDie(); auto reduce = b->Reduce(data, initial, reduction_computation, xla_axes); auto deconverted = XlaHelpers::ConvertElementType(b, reduce, input_type(0)); diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index 12a35529992..ba7d484d53d 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -32,7 +32,7 @@ class ReluOp : public XlaOpKernel { explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); ctx->SetOutput(0, builder->Max(zero, ctx->Input(0))); } @@ -43,7 +43,7 @@ class Relu6Op : public XlaOpKernel { explicit Relu6Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} // Clamp the scalar input between 0 and 6. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); auto zero = XlaHelpers::Zero(builder, input_type(0)); auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6); ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six)); @@ -56,7 +56,7 @@ class ReluGradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. 
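The Compile() hunk above captures the core of the migration: the reduction lambda is now built with an xla::XlaBuilder constructed from just a name (no client argument), and Build() yields an xla::XlaComputation. A minimal sketch of that pattern, assuming the xla_builder.h API used throughout this patch:

    // Build a tiny (f32, f32) -> f32 subcomputation and hand it to Reduce.
    xla::XlaBuilder r("example-reduction");
    auto x = r.Parameter(0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
    auto y = r.Parameter(1, xla::ShapeUtil::MakeShape(xla::F32, {}), "y");
    r.Add(x, y);
    xla::XlaComputation add_f32 = r.Build().ConsumeValueOrDie();
    // add_f32 can then be passed to b->Reduce(data, init, add_f32, dims).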
void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); @@ -71,7 +71,7 @@ class Relu6GradOp : public XlaOpKernel { // Return the lhs (incoming gradient) if the rhs (input feature) > 0, // otherwise return 0. void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const TensorShape shape = ctx->InputShape(0); const auto zero = b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index c283e3b02c2..a7112786384 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -45,7 +45,7 @@ class RetvalOp : public XlaOpKernel { // compilation. OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); } else { - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); auto is_constant = ctx->builder()->IsConstant(input); @@ -55,18 +55,33 @@ class RetvalOp : public XlaOpKernel { } XlaContext& tc = XlaContext::Get(ctx); - if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) { + if (tc.resolve_compile_time_constants() && + (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { - // The core from which a return value is returned depends on the core - // assignment of the input to the retval .Since we can't change the core - // assignment of as this point, create a tuple/get-tuple-element - // combination so that the core will be set on them. - auto tuple_elem = - ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0); - tc.AddRetval(index_, dtype_, tuple_elem); + TensorShape shape = ctx->InputShape(0); + TensorShape representation_shape = + tc.is_entry_computation() + ? tc.RepresentationShape(shape, ctx->input_type(0)) + : shape; + + xla::XlaOp output = input; + if (tc.is_entry_computation()) { + output = + ctx->builder()->Reshape(input, representation_shape.dim_sizes()); + } else { + // The core from which a return value is returned depends on the + // device assignment of the input to the retval. Since we can't change + // the device assignment of "input" at this point, we must always + // introduce an operator here, even if the shape does not change. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. 
+ output = ctx->builder()->GetTupleElement( + ctx->builder()->Tuple({output}), 0); + } + tc.AddRetval(index_, dtype_, shape, output); } } } diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index e51d3869267..2872a3c4d49 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -48,7 +48,7 @@ class ReverseOp : public XlaOpKernel { ctx->SetOutput(0, ctx->Input(0)); return; } - // ComputationBuilder::Rev() requires concrete values for dimensions arg. + // XlaBuilder::Rev() requires concrete values for dimensions arg. xla::Literal lax; OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax)); std::vector revdims(x_shape.dims()); @@ -90,7 +90,7 @@ class ReverseV2Op : public XlaOpKernel { ctx->SetOutput(0, ctx->Input(0)); return; } - // ComputationBuilder::Rev() requires concrete values for dimensions arg. + // XlaBuilder::Rev() requires concrete values for dimensions arg. std::vector axes; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &axes)); diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 6bc5d3adb09..5d1c0526849 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -54,7 +54,7 @@ class ReverseSequenceOp : public XlaOpKernel { "), ", "(", seq_lens_shape.num_elements(), " vs. ", input_shape.dim_size(batch_dim_))); - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); const auto input = context->Input(0); const auto seq_lens = context->Input(1); @@ -106,20 +106,40 @@ class ReverseSequenceOp : public XlaOpKernel { seq_lens, body_builder->Reshape(i, {1}), {1}); // Indices is the offset of the batch element in the input. - auto indices = body_builder->Broadcast( + auto batch_element_indices = body_builder->Broadcast( XlaHelpers::Zero(body_builder.get(), seq_lens_type), {input_shape.dims()}); - indices = body_builder->DynamicUpdateSlice( - indices, body_builder->Reshape(i, {1}), + batch_element_indices = body_builder->DynamicUpdateSlice( + batch_element_indices, body_builder->Reshape(i, {1}), body_builder->Reshape( XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, batch_dim_), {1})); - // slice_indices is the offset of the start of the reversed sequence in - // the input. - auto slice_indices = body_builder->DynamicUpdateSlice( - indices, + // Slice out the current batch element and pad it out in the sequence + // dimension. + TensorShape slice_shape = input_shape; + slice_shape.set_dim(batch_dim_, 1); + slice_shape.set_dim(seq_dim_, max_seq_len); + auto slice = body_builder->DynamicSlice(output, batch_element_indices, + slice_shape.dim_sizes()); + auto padding_config = xla::MakeNoPaddingConfig(slice_shape.dims()); + padding_config.mutable_dimensions(seq_dim_)->set_edge_padding_high( + slice_shape.dim_size(seq_dim_)); + slice = body_builder->Pad( + slice, XlaHelpers::Zero(body_builder.get(), input_type), + padding_config); + + // Now slice out the reversed sequence from its actual start. + // sequence_start_indices is the offset of the start of the reversed + // sequence in the input. The slice will go into the padding, however, we + // will mask off these elements and replace them with elements from the + // original input so their values do not matter. 
+ auto sequence_start_indices = body_builder->Broadcast( + XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {slice_shape.dims()}); + sequence_start_indices = body_builder->DynamicUpdateSlice( + sequence_start_indices, body_builder->Sub(XlaHelpers::IntegerLiteral( body_builder.get(), seq_lens_type, max_seq_len), seq_len), @@ -127,18 +147,12 @@ class ReverseSequenceOp : public XlaOpKernel { XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, seq_dim_), {1})); - - // Slice out the reversed sequence. The slice will overflow the end of the - // sequence, and the contents of the overflow are implementation-defined. - // However, we will mask off these elements and replace them with elements - // from the original input so their values do not matter. - TensorShape slice_shape = input_shape; - slice_shape.set_dim(batch_dim_, 1); - auto slice = body_builder->DynamicSlice(output, slice_indices, - slice_shape.dim_sizes()); + slice = body_builder->DynamicSlice(slice, sequence_start_indices, + slice_shape.dim_sizes()); // Shift the reversed sequence to the left. - output = body_builder->DynamicUpdateSlice(output, slice, indices); + output = body_builder->DynamicUpdateSlice(output, slice, + batch_element_indices); body_builder->Tuple( {body_builder->Add( @@ -155,7 +169,7 @@ class ReverseSequenceOp : public XlaOpKernel { auto output = builder->GetTupleElement(loop_output, 2); // Mask out elements after the sequence length. - xla::ComputationDataHandle iota; + xla::XlaOp iota; OP_REQUIRES_OK( context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); std::vector dims(input_shape.dims(), 1); diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index 4cfa28a0ce3..1819fb54331 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -74,7 +74,7 @@ class ScanOp : public XlaOpKernel { return; } - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); std::vector window_strides(input_shape.dims(), 1); std::vector window_dims(input_shape.dims(), 1); @@ -91,8 +91,8 @@ class ScanOp : public XlaOpKernel { std::swap(padding[axis].first, padding[axis].second); } - xla::ComputationDataHandle init; - const xla::Computation* reducer; + xla::XlaOp init; + const xla::XlaComputation* reducer; if (sum_) { init = XlaHelpers::Zero(builder, dtype); reducer = ctx->GetOrCreateAdd(dtype); diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index 8433a29c4e2..f2c63b4f908 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -102,7 +102,7 @@ class ScatterNdOp : public XlaOpKernel { OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape, updates_shape)); - xla::ComputationBuilder* builder = context->builder(); + xla::XlaBuilder* builder = context->builder(); auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype), buffer_shape.dim_sizes()); auto indices = context->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 498342a9888..664078ca16c 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" namespace tensorflow { namespace { @@ -62,16 +62,16 @@ class UnsortedSegmentSum : public XlaOpKernel { d, " differs ", data_shape.dim_size(d), " vs. ", indices_shape.dim_size(d))); } - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); TensorShape buffer_shape = data_shape; buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), buffer_shape.dim_sizes()); - auto combiner = - [](xla::ComputationDataHandle a, xla::ComputationDataHandle b, - xla::ComputationBuilder* builder) { return builder->Add(a, b); }; + auto combiner = [](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { + return builder->Add(a, b); + }; auto result = XlaScatter(buffer, /*updates=*/data, indices, /*indices_are_vectors=*/false, combiner, builder); diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 8081d3c41c4..f9f48164d63 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -40,7 +40,7 @@ class SelectOp : public XlaOpKernel { "'then' and 'else' must have the same size. but received: ", then_shape.DebugString(), " vs. ", else_shape.DebugString())); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); auto cond_handle = ctx->Input(0); auto then_handle = ctx->Input(1); diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index d079b898618..9ce01d0d445 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index 463788b8b46..bbf5ee8b121 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -43,8 +43,8 @@ class SoftmaxOp : public XlaOpKernel { const DataType type = input_type(0); auto logits = ctx->Input(0); - xla::ComputationBuilder* const b = ctx->builder(); - const xla::Computation& max_func = *ctx->GetOrCreateMax(type); + xla::XlaBuilder* const b = ctx->builder(); + const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = @@ -76,16 +76,15 @@ class SoftmaxOp : public XlaOpKernel { REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp); REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp); -std::pair -CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, - const xla::ComputationDataHandle& logits, - const xla::ComputationDataHandle& labels) { - const xla::Computation& max_func = *ctx->GetOrCreateMax(type); +std::pair CrossEntropyWithLogits( + XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits, + const xla::XlaOp& labels) { + const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); const int kBatchDim = 0; const int kClassDim = 1; - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); @@ -123,7 +122,7 @@ CrossEntropyWithLogits(XlaOpKernelContext* ctx, DataType type, // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) // (where the division broadcasts along the batch dimension) - xla::ComputationDataHandle backprop = + xla::XlaOp backprop = b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels); return {loss, backprop}; } @@ -150,7 +149,7 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel { auto logits = ctx->Input(0); auto labels = ctx->Input(1); - xla::ComputationDataHandle loss, backprop; + xla::XlaOp loss, backprop; std::tie(loss, backprop) = CrossEntropyWithLogits(ctx, type, logits, labels); ctx->SetOutput(0, loss); @@ -191,10 +190,10 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { DataType logits_type = input_type(0); DataType indices_type = input_type(1); - xla::ComputationDataHandle indices = ctx->Input(1); + xla::XlaOp indices = ctx->Input(1); - xla::ComputationBuilder* builder = ctx->builder(); - xla::ComputationDataHandle labels; + xla::XlaBuilder* builder = ctx->builder(); + xla::XlaOp labels; OP_REQUIRES_OK(ctx, XlaHelpers::OneHot( builder, depth, /*axis=*/1, input_type(1), labels_shape, @@ -207,7 +206,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { // Builds a vector of {batch_size} that is 0 if the index is in range, or // NaN otherwise; then add that vector to the labels to force out-of-range // values to NaNs. 
- xla::ComputationDataHandle nan_or_zero = builder->Select( + xla::XlaOp nan_or_zero = builder->Select( builder->And( builder->Le(XlaHelpers::Zero(builder, indices_type), indices), builder->Lt(indices, XlaHelpers::IntegerLiteral( @@ -218,7 +217,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { {batch_size})); labels = builder->Add(labels, nan_or_zero, {0}); - xla::ComputationDataHandle loss, backprop; + xla::XlaOp loss, backprop; std::tie(loss, backprop) = CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels); ctx->SetOutput(0, loss); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 01b46e160d1..ec077924b5b 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -20,9 +20,8 @@ limitations under the License. namespace tensorflow { namespace { -void SpaceToBatch(XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, DataType input_dtype, - const TensorShape& input_tensor_shape, +void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, + DataType input_dtype, const TensorShape& input_tensor_shape, gtl::ArraySlice block_shape, const xla::Literal& paddings) { const int input_rank = input_tensor_shape.dims(); @@ -46,7 +45,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, ", 2] instead of ", xla::ShapeUtil::HumanString(paddings.shape()))); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the // input according to `paddings` to produce `padded` of shape `padded_shape`. @@ -73,7 +72,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, errors::InvalidArgument( "The product of the block dimensions must be positive")); - xla::ComputationDataHandle padded = + xla::XlaOp padded = b->Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config); // 2. Reshape `padded` to `reshaped_padded` of shape: @@ -101,8 +100,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, std::copy(remainder_shape.begin(), remainder_shape.end(), reshaped_padded_shape.begin() + 1 + 2 * block_rank); - xla::ComputationDataHandle reshaped_padded = - b->Reshape(padded, reshaped_padded_shape); + xla::XlaOp reshaped_padded = b->Reshape(padded, reshaped_padded_shape); // 3. Permute dimensions of `reshaped_padded` to produce // `permuted_reshaped_padded` of shape: @@ -121,7 +119,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, permutation[block_rank] = 0; std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(), 1 + block_rank * 2); - xla::ComputationDataHandle permuted_reshaped_padded = + xla::XlaOp permuted_reshaped_padded = b->Transpose(reshaped_padded, permutation); // 4. 
Reshape `permuted_reshaped_padded` to flatten `block_shape` into the @@ -142,8 +140,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, std::copy(remainder_shape.begin(), remainder_shape.end(), output_shape.begin() + 1 + block_rank); - xla::ComputationDataHandle output = - b->Reshape(permuted_reshaped_padded, output_shape); + xla::XlaOp output = b->Reshape(permuted_reshaped_padded, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 806fda632cd..4c5886ee2a0 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -50,8 +50,8 @@ class SpaceToDepthOp : public XlaOpKernel { const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp input = ctx->Input(0); int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); @@ -135,7 +135,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); + xla::XlaOp reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -145,8 +145,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_, block_size_, // depth] - xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, transpose_order); + xla::XlaOp permuted_reshaped = b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -156,8 +155,7 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::ComputationDataHandle output = - b->Reshape(permuted_reshaped, output_shape); + xla::XlaOp output = b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 43c15e75380..8958b2e7701 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -124,7 +124,7 @@ class SplitVOp : public XlaOpKernel { input_shape.dims(), "), but got ", split_dim_orig)); - xla::ComputationDataHandle input = ctx->Input(0); + xla::XlaOp input = ctx->Input(0); OP_REQUIRES(ctx, input_shape.dims() > 0, errors::InvalidArgument("Can't split a 0 dimensional input")); diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index 1a78c7ab9be..0fb05a2be7b 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -38,13 +38,13 @@ limitations under the License. 
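As a concrete trace of the SpaceToDepthOp reshuffle described in the comments above (hypothetical NHWC input with block_size_ = 2; shapes only):

    input                [batch, H,   W,   depth]
    1. reshaped          [batch, H/2, 2, W/2, 2, depth]
    2. permuted_reshaped [batch, H/2, W/2, 2, 2, depth]
    3. output            [batch, H/2, W/2, 4 * depth]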
namespace tensorflow { namespace { -Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource, +Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource, TensorShape* stack_shape) { auto shape_or_status = builder->GetShape(resource->value()); if (!shape_or_status.ok()) { return shape_or_status.status(); } - xla::Shape shape = *shape_or_status.ValueOrDie(); + xla::Shape shape = shape_or_status.ValueOrDie(); TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), stack_shape); @@ -60,9 +60,8 @@ Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource, // // TODO(phawkins): consider changing the API of the stack operators to // allow an optional element shape at stack construction time. -Status MaybeInitializeStack(xla::ComputationBuilder* builder, - XlaResource* resource, DataType dtype, - const TensorShape& elem_shape) { +Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource, + DataType dtype, const TensorShape& elem_shape) { if (resource->type() != dtype) { return errors::InvalidArgument( "Stack dtype is ", DataTypeString(resource->type()), @@ -75,8 +74,6 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder, if (!resource->initialized()) { // Stack has not been initialized. - xla::ComputationDataHandle zero = - XlaHelpers::Zero(builder, resource->type()); TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { @@ -111,7 +108,7 @@ class StackOp : public XlaOpKernel { // We defer initializing the Stack resource until we see the first push. // Otherwise we do not know the shape of the stack elements. - xla::ComputationDataHandle value; + xla::XlaOp value; XlaContext& xc = XlaContext::Get(ctx); XlaResource* resource; string name = strings::StrCat("Stack: ", stack_name_); @@ -138,7 +135,7 @@ class StackPushOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape elem_shape = ctx->InputShape(1); XlaResource* resource; @@ -147,9 +144,9 @@ class StackPushOp : public XlaOpKernel { // Initializes the Stack, if the element shape was not already known. OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = b->GetTupleElement(resource->value(), 0); - xla::ComputationDataHandle index = b->GetTupleElement(resource->value(), 1); - xla::ComputationDataHandle value = ctx->Input(1); + xla::XlaOp ta = b->GetTupleElement(resource->value(), 0); + xla::XlaOp index = b->GetTupleElement(resource->value(), 1); + xla::XlaOp value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
auto start_indices = @@ -184,7 +181,7 @@ class StackPopOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -199,9 +196,9 @@ class StackPopOp : public XlaOpKernel { TensorShape stack_shape; OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape)); - xla::ComputationDataHandle state = resource->value(); - xla::ComputationDataHandle ta = b->GetTupleElement(state, 0); - xla::ComputationDataHandle index = b->GetTupleElement(state, 1); + xla::XlaOp state = resource->value(); + xla::XlaOp ta = b->GetTupleElement(state, 0); + xla::XlaOp index = b->GetTupleElement(state, 1); index = b->Sub(index, b->ConstantR0(1)); OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index}))); @@ -216,8 +213,7 @@ class StackPopOp : public XlaOpKernel { // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - xla::ComputationDataHandle read = - b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 5bb773d97fc..a99d4ddc7c4 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -30,9 +30,8 @@ namespace tensorflow { namespace { // Rotates a 32-bit integer 'v' left by 'distance' bits. -xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& v, - int distance) { +xla::XlaOp RotateLeftS32(xla::XlaBuilder* builder, const xla::XlaOp& v, + int distance) { return builder->Or( builder->ShiftLeft(v, builder->ConstantR0(distance)), builder->ShiftRightLogical(v, builder->ConstantR0(32 - distance))); @@ -40,25 +39,24 @@ xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder, // TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than // building XOR out of other bitwise operators. -xla::ComputationDataHandle BitwiseXor(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& y) { +xla::XlaOp BitwiseXor(xla::XlaBuilder* builder, const xla::XlaOp& x, + const xla::XlaOp& y) { return builder->Or(builder->And(x, builder->Not(y)), builder->And(builder->Not(x), y)); } -using ThreeFry2x32State = std::array; +using ThreeFry2x32State = std::array; // Implements the ThreeFry counter-based PRNG algorithm. // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder, +ThreeFry2x32State ThreeFry2x32(xla::XlaBuilder* builder, ThreeFry2x32State input, ThreeFry2x32State key) { // Rotation distances specified by the Threefry2x32 algorithm. constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24}; ThreeFry2x32State x; - std::array ks; + std::array ks; // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. 
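As a plain-integer sketch of what the RotateLeftS32 and BitwiseXor helpers above express with builder primitives (the XLA versions operate elementwise on tensors of 32-bit values; this scalar form is illustrative only):

    #include <cstdint>

    // Rotate left by 'distance' bits, assuming 0 < distance < 32:
    // (v << d) | (v >> (32 - d)).
    uint32_t RotateLeft32(uint32_t v, int distance) {
      return (v << distance) | (v >> (32 - distance));
    }

    // XOR assembled from AND/OR/NOT, mirroring BitwiseXor above:
    // (x & ~y) | (~x & y).
    uint32_t XorFromAndOrNot(uint32_t x, uint32_t y) {
      return (x & ~y) | (~x & y);
    }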
ks[2] = builder->ConstantR0(0x1BD11BDA); for (int i = 0; i < 2; ++i) { @@ -121,10 +119,9 @@ ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder, // Returns a tensor of 'shape' random values uniformly distributed in the range // [minval, maxval) -xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder, - const xla::ComputationDataHandle& seed, - const TensorShape& shape, - double minval, double maxval) { +xla::XlaOp RandomUniform(xla::XlaBuilder* builder, const xla::XlaOp& seed, + const TensorShape& shape, double minval, + double maxval) { // Split the seed into two 32-bit scalars to form a key. auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {}); auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {}); @@ -178,9 +175,8 @@ xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder, // p = sum_{i=1}^n gq[i]*w^i // } // return p*x -xla::ComputationDataHandle ErfInvF32(xla::ComputationBuilder* b, - const xla::ComputationDataHandle& x, - const TensorShape& shape) { +xla::XlaOp ErfInvF32(xla::XlaBuilder* b, const xla::XlaOp& x, + const TensorShape& shape) { constexpr int kDegree = 9; constexpr std::array w_less_than_5_constants = { 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, @@ -220,7 +216,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); TensorShape shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); @@ -229,7 +225,7 @@ class StatelessRandomUniformOp : public XlaOpKernel { OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); - xla::ComputationDataHandle seed = ctx->Input(1); + xla::XlaOp seed = ctx->Input(1); ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0)); } @@ -257,9 +253,10 @@ class StatelessRandomNormalOp : public XlaOpKernel { OP_REQUIRES(ctx, seed_shape == TensorShape({2}), errors::InvalidArgument("seed must have shape [2], not ", seed_shape.DebugString())); - xla::ComputationDataHandle seed = ctx->Input(1); - xla::ComputationBuilder* builder = ctx->builder(); - auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0); + xla::XlaOp seed = ctx->Input(1); + xla::XlaBuilder* builder = ctx->builder(); + auto uniform = + RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 6204aa4e270..55254c746e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -90,7 +90,7 @@ class StridedSliceOp : public XlaOpKernel { } } - xla::ComputationDataHandle slice = ctx->Input(0); + xla::XlaOp slice = ctx->Input(0); if (!dimensions_to_reverse.empty()) { slice = ctx->builder()->Rev(slice, dimensions_to_reverse); } @@ -168,7 +168,7 @@ class StridedSliceGradOp : public XlaOpKernel { auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0)); - xla::ComputationDataHandle grad = ctx->Input(4); + xla::XlaOp grad = ctx->Input(4); // Undo any new/shrink axes. 
grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes()); @@ -255,7 +255,7 @@ class StridedSliceAssignOp : public XlaOpKernel { &strides_tensor)); TensorShape lhs_shape; - xla::ComputationDataHandle lhs; + xla::XlaOp lhs; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); const TensorShape rhs_shape = ctx->InputShape(4); @@ -284,7 +284,7 @@ class StridedSliceAssignOp : public XlaOpKernel { " does not match r-value shape ", rhs_shape.DebugString(), ". Automatic broadcasting not yet implemented.")); - xla::ComputationDataHandle rhs = ctx->Input(4); + xla::XlaOp rhs = ctx->Input(4); gtl::InlinedVector dimensions_to_reverse; gtl::InlinedVector slice_begin, slice_dims; diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 000b50af6bd..9adee78a1fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -47,7 +47,7 @@ namespace { // the TensorArray with elements of `elem_shape`. For both initialized and // uninitialized TensorArrays, checks that the tensor has a type compatible with // 'dtype' and shape compatible with 'elem_shape'. -Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, +Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { if (resource->kind() != XlaResource::kTensorArray) { @@ -64,9 +64,6 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, << resource->name() << " size " << resource->tensor_array_size(); if (!resource->initialized()) { - xla::ComputationDataHandle zero = - XlaHelpers::Zero(builder, resource->type()); - TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { @@ -77,7 +74,7 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, } TensorShape shape; TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape)); TensorShape ta_shape; ta_shape.AddDim(resource->tensor_array_size()); @@ -114,23 +111,21 @@ Status CheckTensorArrayIsInitialized(const string& op_name, } Status GetTensorArrayShape(const XlaResource* resource, - xla::ComputationBuilder* builder, - TensorShape* shape) { + xla::XlaBuilder* builder, TensorShape* shape) { *shape = resource->shape(); shape->InsertDim(0, resource->tensor_array_size()); return Status::OK(); } -// Like ComputationBuilder::DynamicUpdateSlice, but adds 'update' to the +// Like XlaBuilder::DynamicUpdateSlice, but adds 'update' to the // relevant slice of 'operand'. 
-xla::ComputationDataHandle DynamicAddSlice( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& operand, - const xla::ComputationDataHandle& update, - const gtl::ArraySlice& update_dims, - const xla::ComputationDataHandle& start_indices) { - xla::ComputationDataHandle current = +xla::XlaOp DynamicAddSlice(xla::XlaBuilder* builder, const xla::XlaOp& operand, + const xla::XlaOp& update, + const gtl::ArraySlice& update_dims, + const xla::XlaOp& start_indices) { + xla::XlaOp current = builder->DynamicSlice(operand, start_indices, update_dims); - xla::ComputationDataHandle sum = builder->Add(current, update); + xla::XlaOp sum = builder->Add(current, update); return builder->DynamicUpdateSlice(operand, sum, start_indices); } @@ -155,18 +150,18 @@ class TensorArrayOp : public XlaOpKernel { OP_REQUIRES(ctx, size >= 0, errors::InvalidArgument("TensorArray size must be >= 0")); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); // Initializes the TensorArray value if we know the element shape. // Otherwise, defer initialization to the first write. - xla::ComputationDataHandle value; + xla::XlaOp value; TensorShape shape; if (element_shape_.IsFullyDefined()) { CHECK(element_shape_.AsTensorShape(&shape)); TensorShape ta_shape; ta_shape.AddDim(size); ta_shape.AppendShape(shape); - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, dtype_); + xla::XlaOp zero = XlaHelpers::Zero(b, dtype_); value = b->Broadcast(zero, ta_shape.dim_sizes()); } @@ -202,7 +197,7 @@ class TensorArrayWriteOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape elem_shape = ctx->InputShape(2); @@ -213,10 +208,10 @@ class TensorArrayWriteOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value(); - xla::ComputationDataHandle index = ctx->Input(1); - xla::ComputationDataHandle value = ctx->Input(2); - xla::ComputationDataHandle flow = ctx->Input(3); + xla::XlaOp ta = resource->value(); + xla::XlaOp index = ctx->Input(1); + xla::XlaOp value = ctx->Input(2); + xla::XlaOp flow = ctx->Input(3); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. auto start_indices = @@ -227,7 +222,7 @@ class TensorArrayWriteOp : public XlaOpKernel { slice_shape.InsertDim(0, 1LL); auto update = b->Reshape(value, slice_shape.dim_sizes()); - xla::ComputationDataHandle written = + xla::XlaOp written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); OP_REQUIRES_OK(ctx, resource->SetValue(written)); @@ -249,7 +244,7 @@ class TensorArrayReadOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -259,8 +254,8 @@ class TensorArrayReadOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value(); - xla::ComputationDataHandle index = ctx->Input(1); + xla::XlaOp ta = resource->value(); + xla::XlaOp index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
auto start_indices = @@ -270,8 +265,7 @@ class TensorArrayReadOp : public XlaOpKernel { auto slice_shape = ta_shape.dim_sizes(); slice_shape[0] = 1LL; - xla::ComputationDataHandle read = - b->DynamicSlice(ta, start_indices, slice_shape); + xla::XlaOp read = b->DynamicSlice(ta, start_indices, slice_shape); // Remove the leading '1' dimension. std::vector value_shape(slice_shape.begin() + 1, slice_shape.end()); @@ -293,7 +287,7 @@ class TensorArrayGatherOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -309,7 +303,7 @@ class TensorArrayGatherOp : public XlaOpKernel { auto indices = ctx->Input(1); DataType index_type = ctx->input_type(1); - xla::ComputationDataHandle ta = resource->value(); + xla::XlaOp ta = resource->value(); // Look for the case where the gather takes a simple slice from the // tensor array (0, 1, 2, 3, 4, ..., N) @@ -337,7 +331,7 @@ class TensorArrayGatherOp : public XlaOpKernel { } } - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK( ctx, XlaGather(ta, ta_shape, indices, indices_shape, /*axis=*/0, @@ -360,7 +354,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); const TensorShape value_shape = ctx->InputShape(2); @@ -375,11 +369,11 @@ class TensorArrayScatterOp : public XlaOpKernel { OP_REQUIRES(ctx, indices_shape.dims() >= 1, errors::InvalidArgument("indices must be rank 1")); const int num_indices = indices_shape.dim_size(0); - const xla::ComputationDataHandle indices = ctx->Input(1); + const xla::XlaOp indices = ctx->Input(1); - xla::ComputationDataHandle ta = resource->value(); - const xla::ComputationDataHandle value = ctx->Input(2); - const xla::ComputationDataHandle flow = ctx->Input(3); + xla::XlaOp ta = resource->value(); + const xla::XlaOp value = ctx->Input(2); + const xla::XlaOp flow = ctx->Input(3); // Look for the case where the scatter is for each sub-tensor in order. The // tensor array implementation allows for this to be a straight addition. 
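The TensorArray write and scatter paths above accumulate through DynamicAddSlice ("like DynamicUpdateSlice, but adds 'update' to the relevant slice of 'operand'"). A plain-array analogue of what that helper computes, 1-D and hypothetical:

    #include <vector>

    // operand[start + i] += update[i] for each i: slice out, add, slice back in.
    void DynamicAddSlice1D(std::vector<float>& operand,
                           const std::vector<float>& update, int start) {
      for (size_t i = 0; i < update.size(); ++i) {
        operand[start + i] += update[i];
      }
    }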
@@ -443,7 +437,7 @@ class TensorArrayConcatOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); @@ -453,7 +447,7 @@ class TensorArrayConcatOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value(); + xla::XlaOp ta = resource->value(); auto ta_dims = ta_shape.dim_sizes(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); @@ -503,12 +497,12 @@ class TensorArraySplitOp : public XlaOpKernel { TensorShape elem_shape = value_shape; elem_shape.set_dim(0, length); - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value(); + xla::XlaOp ta = resource->value(); TensorShape ta_shape; ta_shape.AddDim(resource->tensor_array_size()); @@ -520,8 +514,8 @@ class TensorArraySplitOp : public XlaOpKernel { "TensorArray's size is not equal to the size of lengths (", lengths.size(), " vs. ", resource->tensor_array_size(), ")")); - const xla::ComputationDataHandle value = ctx->Input(1); - const xla::ComputationDataHandle flow = ctx->Input(3); + const xla::XlaOp value = ctx->Input(1); + const xla::XlaOp flow = ctx->Input(3); OP_REQUIRES(ctx, value_shape.num_elements() == ta_shape.num_elements(), errors::InvalidArgument("mismatched element count ", @@ -569,7 +563,7 @@ class TensorArrayGradOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 9aefcd4fc7f..e91075196bd 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -112,7 +112,7 @@ class TileOp : public XlaOpKernel { flattened.push_back(i); flattened.push_back(i + output_shape.size()); } - xla::ComputationDataHandle output = + xla::XlaOp output = ctx->builder()->Reshape(broadcasted, flattened, output_shape); ctx->SetOutput(0, output); diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index f750f7003be..34caefa050c 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -30,8 +30,8 @@ class ResourceApplyGradientDescent : public XlaOpKernel { explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle handle; - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaOp handle; + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(1); TensorShape var_shape; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle)); @@ -63,12 +63,12 @@ class ResourceApplyMomentum : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; - xla::ComputationDataHandle var, accum; + xla::XlaOp var, accum; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); @@ -93,9 +93,9 @@ class ResourceApplyMomentum : public XlaOpKernel { errors::InvalidArgument("momentum is not a scalar: ", momentum_shape.DebugString())); - xla::ComputationDataHandle lr = ctx->Input(2); - xla::ComputationDataHandle grad = ctx->Input(3); - xla::ComputationDataHandle momentum = ctx->Input(4); + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp momentum = ctx->Input(4); accum = b->Add(b->Mul(accum, momentum), grad); if (use_nesterov_) { @@ -121,12 +121,12 @@ class ResourceApplyAdagrad : public XlaOpKernel { explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(2); TensorShape var_shape, accum_shape; - xla::ComputationDataHandle var, accum; + xla::XlaOp var, accum; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); @@ -146,8 +146,8 @@ class ResourceApplyAdagrad : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle lr = ctx->Input(2); - xla::ComputationDataHandle grad = ctx->Input(3); + xla::XlaOp lr = ctx->Input(2); + xla::XlaOp grad = ctx->Input(3); accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0))); var = b->Sub( @@ -168,7 +168,7 @@ class ResourceApplyAdam : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape var_shape, m_shape, v_shape; - xla::ComputationDataHandle var, m, v; + xla::XlaOp var, m, v; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v)); @@ -213,25 +213,25 @@ class ResourceApplyAdam : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - 
xla::ComputationDataHandle beta1_power = ctx->Input(3); - xla::ComputationDataHandle beta2_power = ctx->Input(4); - xla::ComputationDataHandle lr = ctx->Input(5); - xla::ComputationDataHandle beta1 = ctx->Input(6); - xla::ComputationDataHandle beta2 = ctx->Input(7); - xla::ComputationDataHandle epsilon = ctx->Input(8); - xla::ComputationDataHandle grad = ctx->Input(9); + xla::XlaOp beta1_power = ctx->Input(3); + xla::XlaOp beta2_power = ctx->Input(4); + xla::XlaOp lr = ctx->Input(5); + xla::XlaOp beta1 = ctx->Input(6); + xla::XlaOp beta2 = ctx->Input(7); + xla::XlaOp epsilon = ctx->Input(8); + xla::XlaOp grad = ctx->Input(9); // alpha <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) // m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t // v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t // variable <- variable - alpha * m_t / (sqrt(v_t) + epsilon) - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); - xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); - xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp half = XlaHelpers::FloatLiteral(b, dtype_, 0.5); + xla::XlaOp one = XlaHelpers::FloatLiteral(b, dtype_, 1.0); + xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype_, 2.0); - xla::ComputationDataHandle alpha = + xla::XlaOp alpha = b->Div(b->Mul(lr, b->Pow(b->Sub(one, beta2_power), half)), b->Sub(one, beta1_power)); m = b->Add(m, b->Mul(b->Sub(grad, m), b->Sub(one, beta1))); @@ -255,12 +255,12 @@ class ResourceApplyRMSProp : public XlaOpKernel { explicit ResourceApplyRMSProp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); DataType type = ctx->input_type(3); TensorShape var_shape, ms_shape, mom_shape; - xla::ComputationDataHandle var, ms, mom; + xla::XlaOp var, ms, mom; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom)); @@ -297,11 +297,11 @@ class ResourceApplyRMSProp : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle lr = ctx->Input(3); - xla::ComputationDataHandle rho = ctx->Input(4); - xla::ComputationDataHandle momentum = ctx->Input(5); - xla::ComputationDataHandle epsilon = ctx->Input(6); - xla::ComputationDataHandle grad = ctx->Input(7); + xla::XlaOp lr = ctx->Input(3); + xla::XlaOp rho = ctx->Input(4); + xla::XlaOp momentum = ctx->Input(5); + xla::XlaOp epsilon = ctx->Input(6); + xla::XlaOp grad = ctx->Input(7); // ms <- rho * ms_{t-1} + (1-rho) * grad * grad // mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon) @@ -320,16 +320,16 @@ class ResourceApplyRMSProp : public XlaOpKernel { // ms <- grad**2 (1 - rho) + ms * rho // // Which is the equation listed above. 
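For reference, a scalar sketch of the update that the ResourceApplyAdam hunk above encodes (floats stand in for the tensor values; illustrative only, not part of the patch):

    #include <cmath>

    // alpha = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
    // m   <- m + (g - m) * (1 - beta1)
    // v   <- v + (g*g - v) * (1 - beta2)
    // var <- var - alpha * m / (sqrt(v) + epsilon)
    void AdamStep(float& var, float& m, float& v, float grad, float lr,
                  float beta1, float beta2, float beta1_power,
                  float beta2_power, float epsilon) {
      const float alpha =
          lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
      m += (grad - m) * (1.0f - beta1);
      v += (grad * grad - v) * (1.0f - beta2);
      var -= alpha * m / (std::sqrt(v) + epsilon);
    }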
- xla::ComputationDataHandle new_ms = b->Add( + xla::XlaOp new_ms = b->Add( ms, b->Mul(b->Sub(b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)), ms), b->Sub(XlaHelpers::FloatLiteral(b, type, 1.0), rho))); - xla::ComputationDataHandle new_mom = + xla::XlaOp new_mom = b->Add(b->Mul(mom, momentum), b->Mul(b->Mul(grad, lr), b->Pow(b->Add(new_ms, epsilon), XlaHelpers::FloatLiteral(b, type, -0.5)))); - xla::ComputationDataHandle new_var = b->Sub(var, new_mom); + xla::XlaOp new_var = b->Sub(var, new_mom); OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, new_var)); OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, new_ms)); @@ -341,10 +341,10 @@ REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes), void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { - xla::ComputationBuilder* b = ctx->builder(); + xla::XlaBuilder* b = ctx->builder(); TensorShape var_shape, accum_shape, linear_shape; - xla::ComputationDataHandle var, accum, linear; + xla::XlaOp var, accum, linear; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum)); OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear)); @@ -399,12 +399,12 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, errors::InvalidArgument("lr_power is not a scalar: ", lr_power_shape.DebugString())); - xla::ComputationDataHandle grad = ctx->Input(3); - xla::ComputationDataHandle lr = ctx->Input(4); - xla::ComputationDataHandle l1 = ctx->Input(5); - xla::ComputationDataHandle l2 = ctx->Input(6); - xla::ComputationDataHandle l2_shrinkage; - xla::ComputationDataHandle lr_power; + xla::XlaOp grad = ctx->Input(3); + xla::XlaOp lr = ctx->Input(4); + xla::XlaOp l1 = ctx->Input(5); + xla::XlaOp l2 = ctx->Input(6); + xla::XlaOp l2_shrinkage; + xla::XlaOp lr_power; if (has_l2_shrinkage) { l2_shrinkage = ctx->Input(7); lr_power = ctx->Input(8); @@ -421,26 +421,23 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, // var = (linear_clipped - linear) / quadratic // accum = new_accum - xla::ComputationDataHandle two = XlaHelpers::FloatLiteral(b, dtype, 2.0); - xla::ComputationDataHandle grad_to_use; + xla::XlaOp two = XlaHelpers::FloatLiteral(b, dtype, 2.0); + xla::XlaOp grad_to_use; if (has_l2_shrinkage) { grad_to_use = b->Add(grad, b->Mul(two, b->Mul(l2_shrinkage, var))); } else { grad_to_use = grad; } - xla::ComputationDataHandle new_accum = - b->Add(accum, b->Pow(grad_to_use, two)); - xla::ComputationDataHandle new_accum_lr_pow = - b->Pow(new_accum, b->Neg(lr_power)); - xla::ComputationDataHandle accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); + xla::XlaOp new_accum = b->Add(accum, b->Pow(grad_to_use, two)); + xla::XlaOp new_accum_lr_pow = b->Pow(new_accum, b->Neg(lr_power)); + xla::XlaOp accum_lr_pow = b->Pow(accum, b->Neg(lr_power)); linear = b->Add( linear, b->Sub(grad_to_use, b->Mul(b->Div(b->Sub(new_accum_lr_pow, accum_lr_pow), lr), var))); - xla::ComputationDataHandle linear_clipped = b->Clamp(b->Neg(l1), linear, l1); - xla::ComputationDataHandle quadratic = - b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); + xla::XlaOp linear_clipped = b->Clamp(b->Neg(l1), linear, l1); + xla::XlaOp quadratic = b->Add(b->Div(new_accum_lr_pow, lr), b->Mul(two, l2)); var = b->Div(b->Sub(linear_clipped, linear), quadratic); accum = new_accum; diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 7cb47f908d4..71a9fd051bf 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -33,9 +33,9 @@ namespace { public: \ explicit NAME##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \ void Compile(XlaOpKernelContext* ctx) { \ - xla::ComputationBuilder* b = ctx->builder(); \ - xla::ComputationDataHandle x = ctx->Input(0); \ - xla::ComputationDataHandle y = COMPUTATION; \ + xla::XlaBuilder* b = ctx->builder(); \ + xla::XlaOp x = ctx->Input(0); \ + xla::XlaOp y = COMPUTATION; \ ctx->SetOutput(0, y); \ } \ }; \ @@ -100,8 +100,7 @@ XLAJIT_MAKE_UNARY(Cosh, XLAJIT_MAKE_UNARY(Sin, b->Sin(x)); XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); -// TODO(b/34703906): use a more accurate implementation of expm1. -XLAJIT_MAKE_UNARY(Expm1, b->Sub(b->Exp(x), XlaHelpers::One(b, input_type(0)))); +XLAJIT_MAKE_UNARY(Expm1, b->Expm1(x)); XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); XLAJIT_MAKE_UNARY(IsFinite, b->IsFinite(x)); @@ -115,8 +114,7 @@ XLAJIT_MAKE_UNARY(Inv, b->Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Reciprocal, b->Div(XlaHelpers::One(b, input_type(0)), x)); XLAJIT_MAKE_UNARY(Log, b->Log(x)); -// TODO(b/34703906): use a more accurate implementation of log1p. -XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x))); +XLAJIT_MAKE_UNARY(Log1p, b->Log1p(x)); XLAJIT_MAKE_UNARY(Invert, b->Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); @@ -124,9 +122,8 @@ XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); // Implements Banker's rounding: numbers that are equidistant between two // integers are rounded towards even. -static xla::ComputationDataHandle Round(xla::ComputationBuilder* b, - DataType dtype, - const xla::ComputationDataHandle& x) { +static xla::XlaOp Round(xla::XlaBuilder* b, DataType dtype, + const xla::XlaOp& x) { auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); @@ -148,9 +145,8 @@ XLAJIT_MAKE_UNARY(Rsqrt, b->Pow(x, XlaHelpers::FloatLiteral(b, input_type(0), -0.5))); // Expresses sigmoid as a rescaled tanh: sigmoid(x) == (tanh(x/2) + 1) / 2. -static xla::ComputationDataHandle Sigmoid(xla::ComputationBuilder* b, - DataType dtype, - const xla::ComputationDataHandle& x) { +static xla::XlaOp Sigmoid(xla::XlaBuilder* b, DataType dtype, + const xla::XlaOp& x) { auto half = XlaHelpers::FloatLiteral(b, dtype, 0.5); return b->Add(half, b->Mul(half, b->Tanh(b->Mul(half, x)))); } @@ -162,26 +158,17 @@ XLAJIT_MAKE_UNARY(Sinh, b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -static xla::ComputationDataHandle Softplus( - xla::ComputationBuilder* b, DataType dtype, - const xla::ComputationDataHandle& features) { - xla::ComputationDataHandle threshold = - b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)), - XlaHelpers::FloatLiteral(b, dtype, 2.0)); - // Value above which exp(x) may overflow, but softplus(x) == x - // is within machine epsilon. 
- xla::ComputationDataHandle too_large = b->Gt(features, b->Neg(threshold)); - // Value below which exp(x) may underflow, but softplus(x) == exp(x) - // is within machine epsilon. - xla::ComputationDataHandle too_small = b->Lt(features, threshold); - xla::ComputationDataHandle features_exp = b->Exp(features); - xla::ComputationDataHandle output = b->Select( - too_large, features, - b->Select(too_small, features_exp, - b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype))))); - return output; -} -XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x)); +// softplus(x) = log(1 + exp(x)) +// +// This is not numerically stable when x is large, it can easily overflow. +// However, we can compute it as LogSumExp(x, 0): +// max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0))) +// +// This is equivalent to: +// max(x, 0) + log1p(exp(-abs(x))) +XLAJIT_MAKE_UNARY(Softplus, + b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))), + b->Log1p(b->Exp(b->Neg(b->Abs(x)))))); // softsign(x) = x / (abs(x) + 1) XLAJIT_MAKE_UNARY(Softsign, diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 71173f5aead..6109db8e89e 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" @@ -48,7 +48,7 @@ class ReadVariableOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle handle; + xla::XlaOp handle; OP_REQUIRES_OK( ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle)); ctx->SetOutput(0, handle); @@ -74,7 +74,7 @@ class AssignAddVariableOp : public XlaOpKernel { explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { DataType type = ctx->input_type(1); - xla::ComputationDataHandle handle; + xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Add(handle, ctx->Input(1)); @@ -90,7 +90,7 @@ class AssignSubVariableOp : public XlaOpKernel { explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { DataType type = ctx->input_type(1); - xla::ComputationDataHandle handle; + xla::XlaOp handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Sub(handle, ctx->Input(1)); @@ -105,19 +105,19 @@ class ResourceGatherOp : public XlaOpKernel { public: explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); DataType type = ctx->expected_output_dtype(0); TensorShape resource_shape; - xla::ComputationDataHandle resource_handle; + xla::XlaOp resource_handle; OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape, &resource_handle)); auto indices = ctx->Input(1); auto indices_shape = ctx->InputShape(1); DataType 
index_type = ctx->input_type(1); - xla::ComputationDataHandle gather; + xla::XlaOp gather; OP_REQUIRES_OK( ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape, /*axis=*/0, /*indices_are_nd=*/false, type, index_type, diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 0ff1b65ae91..5467c5d9946 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" @@ -101,7 +101,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { ctx, MakeXlaCompilerArgumentsFromInputs( ctx, &arguments, &has_uninitialized_vars, &has_tensor_arrays)); - xla::ComputationBuilder* builder = ctx->builder(); + xla::XlaBuilder* builder = ctx->builder(); XlaCompiler* compiler = ctx->compiler(); VLOG(1) << "Compiling body"; @@ -234,7 +234,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { xla::ShapeUtil::HumanString(cond.xla_output_shape))); int num_inputs = body.input_mapping.size(); - std::vector inputs(num_inputs); + std::vector inputs(num_inputs); for (int i = 0; i < num_inputs; ++i) { int input_num = body.input_mapping[i]; if (ctx->input_type(input_num) == DT_RESOURCE) { @@ -246,24 +246,24 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { } } - xla::ComputationDataHandle init = builder->Tuple(inputs); + xla::XlaOp init = builder->Tuple(inputs); VLOG(1) << "Building while loop"; // Wraps the condition in a computation that unpacks the output tuple. - xla::Computation cond_wrapper; + xla::XlaComputation cond_wrapper; { - std::unique_ptr cb = + std::unique_ptr cb = builder->CreateSubBuilder("cond_wrapper"); auto inputs = cb->Parameter(0, cond_input_shape, "inputs"); auto outputs = cb->Call(*cond.computation, {inputs}); cb->GetTupleElement(outputs, 0); - xla::StatusOr result = cb->Build(); + xla::StatusOr result = cb->Build(); OP_REQUIRES_OK(ctx, result.status()); cond_wrapper = std::move(result.ValueOrDie()); } - xla::ComputationDataHandle while_result = + xla::XlaOp while_result = builder->While(cond_wrapper, *body.computation, init); // Sets non-variable outputs. 
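[Editor's note, not part of the patch] The Softplus rewrite in unary_ops.cc above relies on the identity softplus(x) = log(1 + exp(x)) = max(x, 0) + log1p(exp(-|x|)). The following standalone C++ sketch (plain <cmath>, no XLA; names are illustrative only) shows why the rewritten form stays finite where the naive form overflows.

#include <cmath>
#include <cstdio>

// Naive softplus: exp(x) overflows for large x, so the result becomes inf.
double softplus_naive(double x) { return std::log(1.0 + std::exp(x)); }

// Rewritten form used by the new kernel: max(x, 0) + log1p(exp(-|x|)).
// exp(-|x|) lies in (0, 1], so it never overflows, and log1p keeps
// precision when its argument is tiny.
double softplus_stable(double x) {
  return std::fmax(x, 0.0) + std::log1p(std::exp(-std::fabs(x)));
}

int main() {
  for (double x : {-1000.0, -1.0, 0.0, 1.0, 1000.0}) {
    std::printf("x=%8.1f  naive=%g  stable=%g\n", x, softplus_naive(x),
                softplus_stable(x));
  }
  // For x = 1000 the naive form prints inf; the stable form prints 1000.
  return 0;
}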
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 344773c8c5f..ee7f5d510ab 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -25,8 +25,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -39,12 +39,13 @@ cc_library( ":batch_dot", ":triangular_solve", ":util", + ":while_loop", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -61,9 +62,9 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -79,10 +80,9 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -90,6 +90,7 @@ cc_library( xla_test( name = "triangular_solve_test", srcs = ["triangular_solve_test.cc"], + tags = ["noasan"], # sometimes times out, http://b/78650012 deps = [ ":triangular_solve", "//tensorflow/compiler/xla:array2d", @@ -99,9 +100,9 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -120,12 +121,35 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) +xla_test( + name = "util_test", + srcs = ["util_test.cc"], + deps = [ + ":batch_dot", + ":util", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + 
"//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "while_loop", srcs = ["while_loop.cc"], @@ -135,8 +159,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 798f0fa7805..526694d5a0c 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -25,24 +25,22 @@ limitations under the License. namespace tensorflow { -xla::StatusOr BatchDot( - xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, - bool conjugate_x, bool conjugate_y) { - TF_ASSIGN_OR_RETURN(std::unique_ptr x_shape, - builder->GetShape(x)); - TF_ASSIGN_OR_RETURN(std::unique_ptr y_shape, - builder->GetShape(y)); +xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, + xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x, + bool conjugate_y) { + TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape y_shape, builder->GetShape(y)); // Check that both tensors have the same number of dimensions. There must be // at least two (the batch dimensions can be empty). - if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) { + if (xla::ShapeUtil::Rank(x_shape) != xla::ShapeUtil::Rank(y_shape)) { return errors::InvalidArgument( "Arguments to BatchedDot have different ranks: ", - xla::ShapeUtil::HumanString(*x_shape), " vs. ", - xla::ShapeUtil::HumanString(*y_shape)); + xla::ShapeUtil::HumanString(x_shape), " vs. ", + xla::ShapeUtil::HumanString(y_shape)); } - const int ndims = xla::ShapeUtil::Rank(*x_shape); + const int ndims = xla::ShapeUtil::Rank(x_shape); if (ndims < 2) { return errors::InvalidArgument( "Arguments to BatchedDot must have rank >= 2: ", ndims); @@ -52,46 +50,46 @@ xla::StatusOr BatchDot( // valid. std::vector batch_dimension_numbers; for (int i = 0; i < ndims - 2; ++i) { - if (x_shape->dimensions(i) != y_shape->dimensions(i)) { + if (x_shape.dimensions(i) != y_shape.dimensions(i)) { return errors::InvalidArgument( "Dimension ", i, " of inputs to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(*x_shape), " vs ", - xla::ShapeUtil::HumanString(*y_shape)); + xla::ShapeUtil::HumanString(x_shape), " vs ", + xla::ShapeUtil::HumanString(y_shape)); } batch_dimension_numbers.push_back(i); } int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); int y_inner_dim = transpose_y ? 
(ndims - 1) : (ndims - 2); - if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) { + if (x_shape.dimensions(x_inner_dim) != y_shape.dimensions(y_inner_dim)) { return errors::InvalidArgument( "Dimensions ", x_inner_dim, " and ", y_inner_dim, " of arguments to BatchedDot must be equal: ", - xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x, - " vs. ", xla::ShapeUtil::HumanString(*y_shape), + xla::ShapeUtil::HumanString(x_shape), " transpose: ", transpose_x, + " vs. ", xla::ShapeUtil::HumanString(y_shape), " transpose: ", transpose_y); } // Check for zero lhs/rhs dim size. - if (xla::ShapeUtil::HasZeroElements(*x_shape) || - xla::ShapeUtil::HasZeroElements(*y_shape)) { + if (xla::ShapeUtil::HasZeroElements(x_shape) || + xla::ShapeUtil::HasZeroElements(y_shape)) { std::vector dimensions(batch_dimension_numbers.size()); for (int i = 0; i < batch_dimension_numbers.size(); ++i) { - dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]); + dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]); } int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape->dimensions(x_outer_dim)); - dimensions.push_back(y_shape->dimensions(y_outer_dim)); + dimensions.push_back(x_shape.dimensions(x_outer_dim)); + dimensions.push_back(y_shape.dimensions(y_outer_dim)); return builder->Broadcast( - builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())), + builder->ConstantLiteral(xla::Literal::Zero(x_shape.element_type())), dimensions); } - if (x_shape->element_type() == xla::C64 && conjugate_x) { + if (x_shape.element_type() == xla::C64 && conjugate_x) { x = builder->Conj(x); } - if (y_shape->element_type() == xla::C64 && conjugate_y) { + if (y_shape.element_type() == xla::C64 && conjugate_y) { y = builder->Conj(y); } diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index b230e885f10..1acc72033b0 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" namespace tensorflow { @@ -43,10 +43,10 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::StatusOr BatchDot( - xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, - bool conjugate_x = false, bool conjugate_y = false); +xla::StatusOr BatchDot(xla::XlaBuilder* builder, xla::XlaOp x, + xla::XlaOp y, bool transpose_x, + bool transpose_y, bool conjugate_x = false, + bool conjugate_y = false); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index e795701181d..3f1384bc864 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" #include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -31,88 +32,138 @@ namespace tensorflow { namespace { +// The Cholesky–Banachiewicz algorithm. See +// https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms +// for a description. +// // def cholesky_unblocked(a): // assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1] // n = a.shape[-2] // l = np.zeros_like(a) // for j in xrange(n): -// r = l[..., j, :j] -// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(r, r)) -// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], -// np.transpose(r))) / l[..., j, j] +// row = l[..., j, :j] +// row_t = np.swapaxes(row, -1, -2) +// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(row, row_t)) +// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) / +// l[..., j, j] // return l -xla::StatusOr CholeskyUnblocked( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) { - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(a)); - xla::ComputationDataHandle l = Zeros(builder, *shape); - const int64 n = xla::ShapeUtil::GetDimension(*shape, -2); - for (int j = 0; j < n; ++j) { - // Picture of block structure: - // ... \ - // \ - // -- r -- d - // |\ - // B c \ - // | \ - // | ... - // - // ^ - // column j - TF_ASSIGN_OR_RETURN(auto d, - SliceInMinorDims(builder, a, {j, j}, {j + 1, j + 1})); - TF_ASSIGN_OR_RETURN(auto c, - SliceInMinorDims(builder, a, {j + 1, j}, {n, j + 1})); - xla::ComputationDataHandle new_d_squared = d; - xla::ComputationDataHandle br; - if (j > 0) { - TF_ASSIGN_OR_RETURN(auto r, - SliceInMinorDims(builder, l, {j, 0}, {j + 1, j})); - TF_ASSIGN_OR_RETURN(auto b, - SliceInMinorDims(builder, l, {j + 1, 0}, {n, j})); - TF_ASSIGN_OR_RETURN(auto r_squared, - BatchDot(builder, r, r, /*transpose_x=*/false, - /*transpose_y=*/true, /*conjugate_x=*/false, - /*conjugate_y=*/false)); - new_d_squared = builder->Sub(new_d_squared, r_squared); +xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder, + const xla::XlaOp& a) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int n_dims = xla::ShapeUtil::Rank(a_shape); + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + gtl::ArraySlice major_dims(xla::AsInt64Slice(a_shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - 2); - TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false, - /*transpose_y=*/true, - /*conjugate_x=*/false, - /*conjugate_y=*/false)); - } - auto new_d_inv = builder->Pow( - new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5)); - auto new_d = builder->Mul(new_d_inv, new_d_squared); - TF_ASSIGN_OR_RETURN(l, UpdateSliceInMinorDims(builder, l, new_d, {j, j})); + xla::XlaOp l = Zeros(builder, a_shape); - if (j > 0) { - c = builder->Sub(c, br); + // Construct the for loop body to iterate over rows. 
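[Editor's note, not part of the patch] The recurrence that the new CholeskyUnblocked loop body implements is the one spelled out in the pseudocode above: for each column j, fill in the diagonal entry and then the entries below it. A minimal plain-C++ reference version (illustration only; dense n x n row-major storage, no batching or masking) is:

#include <cmath>
#include <vector>

// Unblocked Cholesky: returns lower-triangular L with A = L * L^T.
// A is a symmetric positive-definite n x n matrix, stored row-major.
std::vector<double> CholeskyUnblockedRef(const std::vector<double>& a, int n) {
  std::vector<double> l(n * n, 0.0);
  for (int j = 0; j < n; ++j) {
    // diag = a[j, j] - dot(l[j, :j], l[j, :j])
    double diag = a[j * n + j];
    for (int k = 0; k < j; ++k) diag -= l[j * n + k] * l[j * n + k];
    l[j * n + j] = std::sqrt(diag);
    // l[i, j] = (a[i, j] - dot(l[i, :j], l[j, :j])) / l[j, j]   for i > j
    for (int i = j + 1; i < n; ++i) {
      double s = a[i * n + j];
      for (int k = 0; k < j; ++k) s -= l[i * n + k] * l[j * n + k];
      l[i * n + j] = s / l[j * n + j];
    }
  }
  return l;
}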
+ auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* body_builder) + -> xla::StatusOr> { + xla::Shape col_shape; + xla::Shape row_shape; + for (int64 d : major_dims) { + row_shape.add_dimensions(d); + col_shape.add_dimensions(d); } - auto new_c = builder->Mul(c, new_d_inv); - TF_ASSIGN_OR_RETURN(l, - UpdateSliceInMinorDims(builder, l, new_c, {j + 1, j})); - } - return l; + row_shape.add_dimensions(1); + row_shape.add_dimensions(n); + row_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_row = Zeros(body_builder, row_shape); + + col_shape.add_dimensions(n); + col_shape.add_dimensions(1); + col_shape.set_element_type(a_shape.element_type()); + auto mask_zeros_col = Zeros(body_builder, col_shape); + + std::vector mask_vector(n); + std::iota(mask_vector.begin(), mask_vector.end(), 0); + auto mask_range = body_builder->ConstantR1(mask_vector); + auto mask_range_row = body_builder->Broadcast( + body_builder->Reshape(mask_range, {0}, {1, n}), major_dims); + auto mask_range_col = body_builder->Broadcast( + body_builder->Reshape(mask_range, {0}, {n, 1}), major_dims); + auto body_a = loop_vars[0]; + auto body_l = loop_vars[1]; + + // row = l[..., i, :i] + // select the whole i-th row, then mask out all columns past i-1 + auto zero = body_builder->ConstantR0(0); + TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l, + {i, zero}, {1, n})); + auto row = body_builder->Select(body_builder->Ge(mask_range_row, i), + mask_zeros_row, l_i); + // a[..., i, i] + TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a, + {i, i}, {1, 1})); + // np.dot(row, np.swapaxes(row, -1, -2)) + xla::XlaOp diag_dot; + TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row, + /*transpose_x=*/false, + /*transpose_y=*/true)); + // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row, + // np.swapaxes(row, -1, -2))) + auto l_ii = body_builder->Pow( + body_builder->Sub(a_ii, diag_dot), + FloatLiteral(body_builder, a_shape.element_type(), 0.5)); + + // a[..., i+1:, i] + auto ip1 = body_builder->Add(i, body_builder->ConstantR0(1)); + // select the whole i-th column, then mask out all rows above i+1 + TF_ASSIGN_OR_RETURN( + auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); + auto a_ip1i = body_builder->Select(body_builder->Le(mask_range_col, i), + mask_zeros_col, a_0i); + + // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], r.T)) / + // l[..., i, i] + // The columns in [i, n] are zeroed out in `row`, so we just have to + // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i], + // r.T) + TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row, + /*transpose_x=*/false, + /*transpose_y=*/true)); + // np.dot(l[..., i+1:, :i], r.T) + auto dot_ip1 = body_builder->Select(body_builder->Le(mask_range_col, i), + mask_zeros_col, dot); + + auto col_update = + body_builder->Div(body_builder->Sub(a_ip1i, dot_ip1), l_ii); + TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( + body_builder, body_l, col_update, {i})); + // Assign the diagonal after the rest of the column because otherwise the + // column assign will wrap around and overwrite the diagonal assign. 
+ TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims( + body_builder, body_l, l_ii, {i, i})); + + return std::vector{body_a, body_l}; + }; + + TF_ASSIGN_OR_RETURN( + auto cholesky_while, + XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder)); + + return cholesky_while[1]; } } // namespace -xla::StatusOr Cholesky( - xla::ComputationBuilder* builder, xla::ComputationDataHandle a, - int64 block_size) { - TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, - builder->GetShape(a)); - const int ndims = xla::ShapeUtil::Rank(*a_shape); +xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, + int64 block_size) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + const int ndims = xla::ShapeUtil::Rank(a_shape); if (ndims < 2) { return errors::InvalidArgument( "Arguments to Cholesky must have rank >= 2: ", ndims); } - const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); - if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); + if (n != xla::ShapeUtil::GetDimension(a_shape, -2)) { return errors::InvalidArgument( "Arguments to Cholesky must be square matrices: ", - xla::ShapeUtil::HumanString(*a_shape)); + xla::ShapeUtil::HumanString(a_shape)); } if (block_size < 1) { @@ -124,7 +175,7 @@ xla::StatusOr Cholesky( // Algorithm 1 from // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only // execution." Proceedings of General Purpose GPUs. ACM, 2017. - xla::ComputationDataHandle l = Zeros(builder, *a_shape); + xla::XlaOp l = Zeros(builder, a_shape); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); if (i > 0) { @@ -163,7 +214,7 @@ xla::StatusOr Cholesky( /*lower=*/true, /*transpose_a=*/true, /*conjugate_a=*/false, - /*block_size=*/8)); + /*block_size=*/block_size)); TF_ASSIGN_OR_RETURN( l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); } diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index e083a383be4..20fca7969ec 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" namespace tensorflow { @@ -29,10 +29,9 @@ namespace tensorflow { // the block size to use. // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. 
-// TODO(mattjj): handle the complex Hermitian case -xla::StatusOr Cholesky( - xla::ComputationBuilder* builder, xla::ComputationDataHandle a, - int64 block_size = 256); +// TODO(znado): handle the complex Hermitian case +xla::StatusOr Cholesky(xla::XlaBuilder* builder, xla::XlaOp a, + int64 block_size = 256); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 45699233ea8..d5a27abb258 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -30,24 +30,19 @@ limitations under the License. namespace tensorflow { -xla::StatusOr XlaScatter( - const xla::ComputationDataHandle& buffer, - const xla::ComputationDataHandle& updates, - const xla::ComputationDataHandle& indices, bool indices_are_vectors, - const std::function& combiner, - xla::ComputationBuilder* builder) { - TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_shape, - builder->GetShape(buffer)); - TF_ASSIGN_OR_RETURN(std::unique_ptr updates_shape, - builder->GetShape(updates)); - TF_ASSIGN_OR_RETURN(std::unique_ptr indices_shape, - builder->GetShape(indices)); +xla::StatusOr XlaScatter( + const xla::XlaOp& buffer, const xla::XlaOp& updates, + const xla::XlaOp& indices, bool indices_are_vectors, + const std::function& + combiner, + xla::XlaBuilder* builder) { + TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); + TF_RETURN_IF_ERROR(builder->GetShape(updates).status()); + TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices)); gtl::ArraySlice indices_dims = - xla::AsInt64Slice(indices_shape->dimensions()); + xla::AsInt64Slice(indices_shape.dimensions()); gtl::ArraySlice buffer_dims = - xla::AsInt64Slice(buffer_shape->dimensions()); + xla::AsInt64Slice(buffer_shape.dimensions()); // If the indices are N-dimensional, the minor dimension of indices contains // the indices to update. Otherwise the indices are all scalars. @@ -55,12 +50,12 @@ xla::StatusOr XlaScatter( if (indices_are_vectors) { TF_RET_CHECK(!indices_dims.empty()); num_index_dims = indices_dims.back(); - if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) { + if (num_index_dims > xla::ShapeUtil::Rank(buffer_shape)) { return errors::InvalidArgument( "The size of the minor dimension of the indices (shape: ", - xla::ShapeUtil::HumanString(*indices_shape), + xla::ShapeUtil::HumanString(indices_shape), ") must be <= the rank of the buffer (shape: ", - xla::ShapeUtil::HumanString(*buffer_shape), ")"); + xla::ShapeUtil::HumanString(buffer_shape), ")"); } indices_dims.pop_back(); } @@ -78,10 +73,10 @@ xla::StatusOr XlaScatter( // If any of the indexed dimensions are zero in the buffer, the update cannot // succeed since it updates a slice of size 1. 
for (int64 i = 0; i < num_index_dims; ++i) { - if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) { - return errors::InvalidArgument( - "Scatter dimension ", i, " is of size zero in tensor with shape ", - xla::ShapeUtil::HumanString(*buffer_shape)); + if (xla::ShapeUtil::GetDimension(buffer_shape, i) == 0) { + return errors::InvalidArgument("Scatter dimension ", i, + " is of size zero in tensor with shape ", + xla::ShapeUtil::HumanString(buffer_shape)); } } @@ -111,18 +106,17 @@ xla::StatusOr XlaScatter( // index = dynamic-slice(indices, i) // update = dynamic-slice(updates, i) // buffer = dynamic-update-slice(buffer, update, index) - auto body_fn = [&](xla::ComputationDataHandle i, - gtl::ArraySlice loop_vars, - xla::ComputationBuilder* body_builder) { + auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice loop_vars, + xla::XlaBuilder* body_builder) { auto indices = loop_vars[0]; auto updates = loop_vars[1]; auto buffer = loop_vars[2]; auto zero_index = body_builder->ConstantLiteral( - xla::Literal::Zero(indices_shape->element_type())); + xla::Literal::Zero(indices_shape.element_type())); // Slice the i-th index from the indices array. - xla::ComputationDataHandle index; + xla::XlaOp index; auto indices_offset = body_builder->Reshape(i, {1}); if (indices_are_vectors) { indices_offset = body_builder->Pad(indices_offset, zero_index, @@ -180,12 +174,12 @@ xla::StatusOr XlaScatter( // Apply the update. buffer = body_builder->DynamicUpdateSlice(buffer, update, index); - return std::vector{indices, updates, buffer}; + return std::vector{indices, updates, buffer}; }; - TF_ASSIGN_OR_RETURN( - auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(), - body_fn, init, "scatter", builder)); + TF_ASSIGN_OR_RETURN(auto outputs, + XlaForEachIndex(num_indices, indices_shape.element_type(), + body_fn, init, "scatter", builder)); return outputs[2]; } diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 41e6d3b195e..87309e10ede 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" namespace tensorflow { @@ -39,14 +39,12 @@ namespace tensorflow { // If a `combiner` is provided, updates are combined with the existing values in // the buffer using the combiner function. Otherwise, the updates replace the // existing values. The order of updates is implementation-defined. 
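[Editor's note, not part of the patch] To make the scatter contract documented above concrete: each index selects a slot (or slice) of the buffer, and the matching update either overwrites it or is combined with it. A hedged, scalar-index C++ sketch of that semantics (illustration only; the real op handles N-dimensional slices and builds the loop with XlaForEachIndex as shown in scatter.cc):

#include <functional>
#include <vector>

// Reference semantics for a 1-D scatter with scalar indices: for each i,
// buffer[indices[i]] is replaced by updates[i], or combined with it when a
// combiner is supplied. The iteration order is unspecified, which matters
// when indices repeat. Indices are assumed to be in range.
void ScatterRef(std::vector<float>& buffer, const std::vector<int>& indices,
                const std::vector<float>& updates,
                const std::function<float(float, float)>& combiner) {
  for (size_t i = 0; i < indices.size(); ++i) {
    float& slot = buffer[indices[i]];
    slot = combiner ? combiner(slot, updates[i]) : updates[i];
  }
}

// Example: an "add" combiner gives scatter-add behavior:
//   ScatterRef(buf, idx, upd, [](float a, float b) { return a + b; });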
-xla::StatusOr XlaScatter( - const xla::ComputationDataHandle& buffer, - const xla::ComputationDataHandle& updates, - const xla::ComputationDataHandle& indices, bool indices_are_vectors, - const std::function& combiner, - xla::ComputationBuilder* builder); +xla::StatusOr XlaScatter( + const xla::XlaOp& buffer, const xla::XlaOp& updates, + const xla::XlaOp& indices, bool indices_are_vectors, + const std::function& + combiner, + xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 7f72a6073df..b4503601f94 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -29,21 +29,20 @@ limitations under the License. namespace tensorflow { -xla::StatusOr TriangularSolve( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, - bool conjugate_a, int64 block_size) { - TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, - builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, - builder->GetShape(b)); - if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) { +xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, + const xla::XlaOp& a, xla::XlaOp b, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + int64 block_size) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + if (xla::ShapeUtil::Rank(a_shape) != xla::ShapeUtil::Rank(b_shape)) { return errors::InvalidArgument( "Arguments to TriangularSolve have different ranks: ", - xla::ShapeUtil::HumanString(*a_shape), " vs. ", - xla::ShapeUtil::HumanString(*b_shape)); + xla::ShapeUtil::HumanString(a_shape), " vs. ", + xla::ShapeUtil::HumanString(b_shape)); } - const int ndims = xla::ShapeUtil::Rank(*a_shape); + const int ndims = xla::ShapeUtil::Rank(a_shape); if (ndims < 2) { return errors::InvalidArgument( "Arguments to TriangularSolve must have rank >= 2: ", ndims); @@ -51,30 +50,30 @@ xla::StatusOr TriangularSolve( // The batch dimensions must be equal. 
std::vector batch_dimensions; for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape->dimensions(i); - int64 b_size = b_shape->dimensions(i); + int64 a_size = a_shape.dimensions(i); + int64 b_size = b_shape.dimensions(i); if (a_size != b_size) { return errors::InvalidArgument( "Batch dimensions of arguments to TriangularSolve must be equal: ", - xla::ShapeUtil::HumanString(*a_shape), " vs ", - xla::ShapeUtil::HumanString(*b_shape)); + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } batch_dimensions.push_back(a_size); } - if (xla::ShapeUtil::GetDimension(*a_shape, -1) != - xla::ShapeUtil::GetDimension(*a_shape, -2)) { + if (xla::ShapeUtil::GetDimension(a_shape, -1) != + xla::ShapeUtil::GetDimension(a_shape, -2)) { return errors::InvalidArgument( "The 'a' arguments to TriangularSolve must be square matrices: ", - xla::ShapeUtil::HumanString(*a_shape)); + xla::ShapeUtil::HumanString(a_shape)); } - const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); - if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) { + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(a_shape, -1)) { return errors::InvalidArgument( "Arguments to TriangularSolve have incompatible matrix shapes: ", - xla::ShapeUtil::HumanString(*a_shape), " vs ", - xla::ShapeUtil::HumanString(*b_shape)); + xla::ShapeUtil::HumanString(a_shape), " vs ", + xla::ShapeUtil::HumanString(b_shape)); } if (block_size < 1) { @@ -83,36 +82,20 @@ xla::StatusOr TriangularSolve( block_size); } - // Returns [b1, b2, ... , bn, indices[0], indices[1]]. - auto prepend_batch_dims = [&](std::array indices) { - std::vector output(ndims); - std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); - std::copy(indices.begin(), indices.end(), - output.begin() + batch_dimensions.size()); - return output; - }; - - // Applies a complex conjugation operation if `a` is complex and `conjugate_a` - // is true, otherwise returns its argument. - auto maybe_conj = [&](xla::ComputationBuilder* builder, - xla::ComputationDataHandle x) { - auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; - return perform_conj ? 
builder->Conj(x) : x; - }; - - std::map base_computations; + std::map base_computations; auto get_base_triangular_solve = - [&](int k) -> xla::StatusOr { - xla::Computation& computation = base_computations[k]; + [&](int k) -> xla::StatusOr { + xla::XlaComputation& computation = base_computations[k]; if (computation.IsNull()) { - std::unique_ptr sub = builder->CreateSubBuilder( + std::unique_ptr sub = builder->CreateSubBuilder( tensorflow::strings::StrCat("trsm_base_", k)); - auto a_param = - sub->Parameter(0, - xla::ShapeUtil::MakeShape(b_shape->element_type(), - prepend_batch_dims({k, k})), - "a"); + auto a_param = sub->Parameter( + 0, + xla::ShapeUtil::MakeShape( + b_shape.element_type(), + PrependMajorDims(sub.get(), batch_dimensions, {k, k})), + "a"); std::array b_lastd; if (left_side) { @@ -120,22 +103,28 @@ xla::StatusOr TriangularSolve( } else { b_lastd = {m, k}; } - auto b_param = - sub->Parameter(1, - xla::ShapeUtil::MakeShape(b_shape->element_type(), - prepend_batch_dims(b_lastd)), - "b"); + auto b_param = sub->Parameter( + 1, + xla::ShapeUtil::MakeShape( + b_shape.element_type(), + PrependMajorDims(sub.get(), batch_dimensions, b_lastd)), + "b"); - // We use a left-looking subroutine on the block diagonal in some common - // cases, while falling back to a recursive call in unsupported cases. The - // left-looking subroutine is written with a While loop and so yields much - // faster compile times. Moreover, the left-looking variant can give - // higher performance on smaller (sub)problems. + // We use a left-looking or right-looking subroutine on the block diagonal + // in the lower=true cases, while falling back to a recursive call in + // others. The left-looking and right-looking subroutines are written with + // a While loop and so yields much faster compile times. Moreover, they + // can give higher performance on smaller (sub)problems. if (left_side && lower) { TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, b_param, transpose_a, conjugate_a) .status()); + } else if (!left_side && lower) { + TF_RETURN_IF_ERROR(TriangularSolveRightLooking(sub.get(), a_param, + b_param, transpose_a, + conjugate_a) + .status()); } else { TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, left_side, lower, transpose_a, @@ -149,7 +138,7 @@ xla::StatusOr TriangularSolve( return &computation; }; - xla::ComputationDataHandle output = Zeros(builder, *b_shape); + xla::XlaOp output = Zeros(builder, b_shape); // Right-looking blocked triangular solve. 
// For an explanation of the algorithm, see the TRSM discussion in: @@ -172,13 +161,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {0, i})); @@ -188,7 +179,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) if (i + k < n) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN( a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); @@ -222,13 +213,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); @@ -238,7 +231,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) if (i + k < m) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN( a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); @@ -271,13 +264,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {0, i})); @@ -287,7 +282,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) if (i - k >= 0) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN(a_slice_2, SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); @@ -321,13 +316,15 @@ xla::StatusOr TriangularSolve( SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {i, 0}, {i + k, 
n})); - xla::ComputationDataHandle update; + xla::XlaOp update; if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, + TF_ASSIGN_OR_RETURN(xla::XlaComputation * solve, get_base_triangular_solve(k)); update = builder->Call(*solve, {a_slice, b_slice}); } else { - update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + update = builder->Div(b_slice, a_slice_conj); } TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); @@ -337,7 +334,7 @@ xla::StatusOr TriangularSolve( // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) if (i - k >= 0) { - xla::ComputationDataHandle a_slice_2; + xla::XlaOp a_slice_2; if (lower) { TF_ASSIGN_OR_RETURN(a_slice_2, SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); @@ -363,37 +360,23 @@ xla::StatusOr TriangularSolve( return output; } -xla::StatusOr TriangularSolveLeftLooking( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) { - TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, - builder->GetShape(a)); - TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, - builder->GetShape(b)); - const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); - const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); - const int64 ndims = xla::ShapeUtil::Rank(*a_shape); +xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); std::vector batch_dimensions; for (int i = 0; i < ndims - 2; ++i) { - int64 a_size = a_shape->dimensions(i); + int64 a_size = a_shape.dimensions(i); batch_dimensions.push_back(a_size); } - auto prepend_batch_dims = [&](std::array indices) { - std::vector output(ndims); - std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); - std::copy(indices.begin(), indices.end(), - output.begin() + batch_dimensions.size()); - return output; - }; - - auto maybe_conj = [&](xla::ComputationBuilder* builder, - xla::ComputationDataHandle x) { - auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; - return perform_conj ? builder->Conj(x) : x; - }; - // The main computation is performed in a While loop. // Allocate the output and set its first or last row, @@ -402,14 +385,16 @@ xla::StatusOr TriangularSolveLeftLooking( // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] // else: // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] - xla::ComputationDataHandle output = Zeros(builder, *b_shape); + xla::XlaOp output = Zeros(builder, b_shape); { auto i = transpose_a ? 
m - 1 : 0; TF_ASSIGN_OR_RETURN(auto a_slice, SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); TF_ASSIGN_OR_RETURN(auto b_slice, SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); - auto update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN(auto a_slice_conj, + MaybeConjugate(builder, a_slice, conjugate_a)); + auto update = builder->Div(b_slice, a_slice_conj); TF_ASSIGN_OR_RETURN( output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); } @@ -423,11 +408,11 @@ xla::StatusOr TriangularSolveLeftLooking( // The loop iteration counter is a scalar, incremented each iteration. xla::ShapeUtil::MakeShape(xla::S32, {}), // The output has the shape of b, with one row updated each iteration. - *b_shape, + b_shape, // The coefficient matrix a is a loop invariant. - *a_shape, + a_shape, // The right-hand-side matrix b is a loop invariant. - *b_shape}; + b_shape}; xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); auto init_i = builder->ConstantR0(transpose_a ? m - 2 : 1); auto init = builder->Tuple({init_i, output, a, b}); @@ -436,7 +421,7 @@ xla::StatusOr TriangularSolveLeftLooking( // def cond_fun(loop_carry): // i, output, a, b = loop_carry // return i >= 0 if transpose_a else i < m - std::unique_ptr condb = + std::unique_ptr condb = builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); { auto i = condb->GetTupleElement( @@ -466,7 +451,7 @@ xla::StatusOr TriangularSolveLeftLooking( // return (i + 1, output, a, b) // We have to do some extra FLOPs propagating zeros in the matrix multiply // because we can't have the size of its arguments depend on the loop counter. - std::unique_ptr bodyb = + std::unique_ptr bodyb = builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); { auto input_tuple = bodyb->Parameter(0, tuple_shape, @@ -479,30 +464,6 @@ xla::StatusOr TriangularSolveLeftLooking( auto body_b = bodyb->GetTupleElement(input_tuple, 3); auto zero = bodyb->ConstantR0(0); - // Set up some helper functions. - auto prepend_zeros = [&](std::array starts) { - auto zero = bodyb->Reshape(bodyb->ConstantR0(0), {1}); - std::vector padded_starts(ndims, zero); - padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1}); - padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1}); - return bodyb->ConcatInDim(padded_starts, 0); - }; - - auto dynamic_slice = [&](xla::ComputationDataHandle x, - std::array starts, - std::array sizes) { - auto padded_starts = prepend_zeros(starts); - auto padded_sizes = prepend_batch_dims(sizes); - return bodyb->DynamicSlice(x, padded_starts, padded_sizes); - }; - - auto update = [&](xla::ComputationDataHandle x, - xla::ComputationDataHandle update, - std::array starts) { - auto padded_starts = prepend_zeros(starts); - return bodyb->DynamicUpdateSlice(x, update, padded_starts); - }; - // We'd like to implement this: // if transpose_a: // a_row = T(a[..., i+1:, i:i+1]) @@ -514,24 +475,33 @@ xla::StatusOr TriangularSolveLeftLooking( // But since we can't have intermediate array sizes depend on the loop // counter, we instead exploit the fact that we initialized the output to // all zeros and use that as zero-padding (doing unnecessary FLOPs). 
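[Editor's note, not part of the patch] For orientation, the non-transposed case of this left-looking loop is ordinary forward substitution: each output row depends only on the rows already computed, which is why the loop can walk i from 0 to m-1 while padding the not-yet-computed rows with zeros. A minimal plain-C++ sketch for a lower-triangular a and a single right-hand-side column (illustration only; the XLA version is batched and keeps slice sizes independent of the loop counter):

#include <vector>

// Forward substitution: solve a * x = b, where a is m x m lower triangular
// (row-major) and b has m entries. Row i uses only rows 0..i-1 of x, which is
// exactly the dependency structure of the left-looking while loop above.
std::vector<double> ForwardSubstitution(const std::vector<double>& a,
                                        const std::vector<double>& b, int m) {
  std::vector<double> x(m, 0.0);
  for (int i = 0; i < m; ++i) {
    double acc = b[i];
    for (int k = 0; k < i; ++k) acc -= a[i * m + k] * x[k];  // a[i, :i] . x[:i]
    x[i] = acc / a[i * m + i];                               // divide by a[i, i]
  }
  return x;
}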
- xla::ComputationDataHandle a_row; + xla::XlaOp a_row; if (transpose_a) { - a_row = dynamic_slice(body_a, {zero, i}, {m, 1}); + TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, + {zero, i}, {m, 1})); } else { - a_row = dynamic_slice(body_a, {i, zero}, {1, m}); + TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a, + {i, zero}, {1, m})); } TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, /*transpose_x=*/transpose_a, /*transpose_y=*/false, /*conjugate_x=*/conjugate_a, /*conjugate_y=*/false)); - auto result_row = - bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update); + TF_ASSIGN_OR_RETURN( + auto result_row_slice, + DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n})); + auto result_row = bodyb->Sub(result_row_slice, b_update); // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] - auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1}); - auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt)); - body_out = update(body_out, div_result, {i, zero}); + TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a, + {i, i}, {1, 1})); + TF_ASSIGN_OR_RETURN(auto a_elt_conj, + MaybeConjugate(bodyb.get(), a_elt, conjugate_a)); + auto div_result = bodyb->Div(result_row, a_elt_conj); + TF_ASSIGN_OR_RETURN(body_out, + DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, + div_result, {i, zero})); // if transpose_a: // return (i - 1, body_out, a, b) @@ -548,4 +518,130 @@ xla::StatusOr TriangularSolveLeftLooking( return builder->GetTupleElement(triangular_solve_left_looking_while, 1); } +xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a) { + TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(xla::Shape b_shape, builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(a_shape); + + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape.dimensions(i); + batch_dimensions.push_back(a_size); + } + + // The main computation is performed in a While loop. + xla::XlaOp output = Zeros(builder, b_shape); + + // Construct the initial loop carry tuple, + // if transpose_a: + // init = (0, output, a, b) + // else: + // init = (n-1, output, a, b) + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + b_shape, + // The coefficient matrix a is a loop invariant. + a_shape, + // The right-hand-side matrix b is a loop invariant. + b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = builder->ConstantR0(transpose_a ? 
0 : n - 1); + auto init = builder->Tuple({init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i < n if transpose_a else i >= 0 + std::unique_ptr condb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileCond"); + { + auto i = condb->GetTupleElement( + condb->Parameter(0, tuple_shape, + "TriangularSolveRightLookingWhileTuple"), + 0); + if (transpose_a) { + condb->Lt(i, condb->ConstantR0(n)); + } else { + condb->Ge(i, condb->ConstantR0(0)); + } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); + + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., :, i:i+1], -1 -2) + // else: + // a_row = a[..., :, i:i+1] + // result_row = b[..., :, i:i+1] - np.matmul(output, a_row) + // output[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the size of its arguments depend on the loop counter. + std::unique_ptr bodyb = + builder->CreateSubBuilder("TriangularSolveRightLookingWhileBody"); + { + auto input_tuple = bodyb->Parameter( + 0, tuple_shape, "TriangularSolveRightLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = bodyb->GetTupleElement(input_tuple, 0); + auto body_out = bodyb->GetTupleElement(input_tuple, 1); + auto body_a = bodyb->GetTupleElement(input_tuple, 2); + auto body_b = bodyb->GetTupleElement(input_tuple, 3); + auto zero = bodyb->ConstantR0(0); + + // We'd like to implement b[..., :, i:i+1] - np.matmul(output, a[..., :, + // i:i+1]) But since we can't have intermediate array sizes depend on the + // loop counter, we instead exploit the fact that we initialized the output + // to all zeros and use that as zero-padding (doing unnecessary FLOPs). + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), body_out, body_a, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a)); + // result = b - np.matmul(output, a) + auto result = bodyb->Sub(body_b, b_update); + // result_row = result[..., :, i:i+1] + TF_ASSIGN_OR_RETURN( + auto result_row, + DynamicSliceInMinorDims(bodyb.get(), result, {zero, i}, {m, 1})); + + // body_out[..., :, i:i+1] = result_row / a[..., i:i+1, i:i+1] + TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(bodyb.get(), body_a, + {i, i}, {1, 1})); + TF_ASSIGN_OR_RETURN(auto a_ii_conj, + MaybeConjugate(bodyb.get(), a_ii, conjugate_a)); + auto div_result = bodyb->Div(result_row, a_ii_conj); + TF_ASSIGN_OR_RETURN(body_out, + DynamicUpdateSliceInMinorDims(bodyb.get(), body_out, + div_result, {zero, i})); + + // if transpose_a: + // return (i + 1, body_out, a, b) + // else: + // return (i - 1, body_out, a, b) + auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? 
1 : -1)); + bodyb->Tuple({next_i, body_out, body_a, body_b}); + } + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = builder->While(cond, body, init); + return builder->GetTupleElement(triangular_solve_left_looking_while, 1); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index e32223bfddd..540c26b2473 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" namespace tensorflow { @@ -57,14 +57,23 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::StatusOr TriangularSolve( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, - bool conjugate_a, int64 block_size = 256); +xla::StatusOr TriangularSolve(xla::XlaBuilder* builder, + const xla::XlaOp& a, xla::XlaOp b, + bool left_side, bool lower, + bool transpose_a, bool conjugate_a, + int64 block_size = 256); -xla::StatusOr TriangularSolveLeftLooking( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a); +xla::StatusOr TriangularSolveLeftLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a); + +xla::StatusOr TriangularSolveRightLooking(xla::XlaBuilder* builder, + const xla::XlaOp& a, + const xla::XlaOp& b, + bool transpose_a, + bool conjugate_a); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index 66170706291..87ea4763f7c 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -80,9 +80,9 @@ xla::Array2D AValsFull() { } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -102,9 +102,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -124,9 +124,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -146,9 +146,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -168,9 +168,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -191,9 +191,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -214,9 +214,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = 
CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -237,9 +237,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolve(&builder, a, b, @@ -260,9 +260,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { } XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); auto b_data = @@ -288,9 +288,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { } XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); auto b_data = @@ -318,9 +318,9 @@ XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { } XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolveLeftLooking(&builder, a, b, @@ -340,9 +340,9 @@ XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { } XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { - xla::ComputationBuilder builder(client_, TestName()); + xla::XlaBuilder builder(TestName()); - xla::ComputationDataHandle a, b; + xla::XlaOp a, b; auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); auto result = TriangularSolveLeftLooking(&builder, a, b, diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index f579669bbd8..d9ff7e6259f 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,15 +27,14 @@ limitations under the License. 
namespace tensorflow { -xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - const xla::Shape& shape) { +xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape) { return builder->Broadcast( builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), xla::AsInt64Slice(shape.dimensions())); } -xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, double value) { +xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + double value) { switch (type) { case xla::F16: return builder->ConstantR0(static_cast(value)); @@ -57,9 +56,8 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, } } -xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, - int64 value) { +xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + int64 value) { xla::Literal literal; switch (type) { case xla::U8: @@ -112,17 +110,18 @@ xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, return builder->ConstantLiteral(literal); } -xla::StatusOr SliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - gtl::ArraySlice start, gtl::ArraySlice end) { +xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + gtl::ArraySlice start, + gtl::ArraySlice end) { TF_RET_CHECK(start.size() == end.size()); int64 n_minor_dims = start.size(); - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); + const int64 n_dims = xla::ShapeUtil::Rank(shape); TF_RET_CHECK(n_minor_dims <= n_dims); - gtl::ArraySlice major_dims(xla::AsInt64Slice(shape->dimensions()), + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), /*pos=*/0, /*len=*/n_dims - n_minor_dims); @@ -140,20 +139,56 @@ xla::StatusOr SliceInMinorDims( return builder->Slice(x, padded_start, padded_end, strides); } -xla::StatusOr UpdateSlice( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start) { - // TODO(phawkins): make int64 work on all backends, remove the int32 cast. 
- std::vector start_as_int32(start.begin(), start.end()); - return builder->DynamicUpdateSlice( - x, update, builder->ConstantR1(start_as_int32)); +std::vector PrependMajorDims(xla::XlaBuilder* builder, + const gtl::ArraySlice& major_dims, + const gtl::ArraySlice& indices) { + std::vector output(indices.size() + major_dims.size()); + std::copy(major_dims.begin(), major_dims.end(), output.begin()); + std::copy(indices.begin(), indices.end(), output.begin() + major_dims.size()); + return output; } -xla::StatusOr UpdateSliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start) { - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); +xla::StatusOr DynamicSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts, + const gtl::ArraySlice& sizes) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + int64 n_minor_dims = starts.size(); + TF_RET_CHECK(n_minor_dims == sizes.size()); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape.dimensions()), + /*pos=*/0, + /*len=*/n_dims - sizes.size()); + TF_ASSIGN_OR_RETURN(auto padded_starts, + PrependZerosInMajorDims(builder, x, starts)); + auto padded_sizes = PrependMajorDims(builder, major_dims, sizes); + return builder->DynamicSlice(x, padded_starts, padded_sizes); +} + +xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start) { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + auto start_constant = builder->ConstantR1(start_as_int32); + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + TF_ASSIGN_OR_RETURN(xla::Shape start_constant_shape, + builder->GetShape(start_constant)); + const int64 start_length = + xla::ShapeUtil::GetDimension(start_constant_shape, -1); + TF_RET_CHECK(start_length == n_dims); + return builder->DynamicUpdateSlice(x, update, start_constant); +} + +xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); const int64 n_minor_dims = start.size(); TF_RET_CHECK(n_minor_dims <= n_dims); std::vector padded_start(n_dims, 0); @@ -162,10 +197,32 @@ xla::StatusOr UpdateSliceInMinorDims( return UpdateSlice(builder, x, update, padded_start); } -xla::StatusOr TransposeInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) { - TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); - const int64 n_dims = xla::ShapeUtil::Rank(*shape); +xla::StatusOr DynamicUpdateSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, + const std::vector& starts) { + TF_ASSIGN_OR_RETURN(auto padded_starts, + PrependZerosInMajorDims(builder, x, starts)); + return builder->DynamicUpdateSlice(x, update, padded_starts); +} + +xla::StatusOr PrependZerosInMajorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); + auto zero = 
builder->Reshape(builder->ConstantR0(0), {1}); + std::vector padded_starts(n_dims, zero); + for (int i = 0; i < starts.size(); ++i) { + padded_starts[n_dims - starts.size() + i] = + builder->Reshape(starts[i], {1}); + } + return builder->ConcatInDim(padded_starts, 0); +} + +xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(shape); TF_RET_CHECK(n_dims >= 2); std::vector permutation(n_dims); std::iota(permutation.begin(), permutation.end(), 0); @@ -173,4 +230,11 @@ xla::StatusOr TransposeInMinorDims( return builder->Transpose(x, permutation); } +xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, + const xla::XlaOp& x, bool conjugate) { + TF_ASSIGN_OR_RETURN(xla::Shape shape, builder->GetShape(x)); + auto perform_conj = shape.element_type() == xla::C64 && conjugate; + return perform_conj ? builder->Conj(x) : x; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 51f8baaf00b..3c120a25485 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,47 +16,79 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { // Returns a zero-filled tensor with shape `shape`. -xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - const xla::Shape& shape); +xla::XlaOp Zeros(xla::XlaBuilder* builder, const xla::Shape& shape); // Returns a floating point scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. -xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, double value); +xla::XlaOp FloatLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + double value); + +// Makes a 1D tensor [0, ..., x, y] from two tensors x and y with zeros +// prepended until the array is length n_dims. +xla::XlaOp PrependZerosInMajorDims(xla::XlaBuilder* builder, + gtl::ArraySlice starts); // Returns a integer scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. -xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, - xla::PrimitiveType type, int64 value); +xla::XlaOp IntegerLiteral(xla::XlaBuilder* builder, xla::PrimitiveType type, + int64 value); + +// Builds a vector of zeros of length rank(x) with the last two values being +// those in `starts`. +xla::StatusOr PrependZerosInMajorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts); // Performs a slice in the minor dimensions of a Tensor. 
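// Illustrative sketch (hypothetical): PrependZerosInMajorDims and
// PrependMajorDims simply pad minor-dimension indices out to the operand's
// full rank. The same padding on plain host-side integers, rather than on
// XlaOps, looks like this: for rank 4 and minor starts {5, 7} the result is
// {0, 0, 5, 7}.
#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> PadMinorStartsSketch(int64_t rank,
                                          const std::vector<int64_t>& starts) {
  std::vector<int64_t> padded(rank, 0);  // zeros for every major dimension
  std::copy(starts.begin(), starts.end(), padded.end() - starts.size());
  return padded;
}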
-xla::StatusOr SliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - gtl::ArraySlice start, gtl::ArraySlice end); +xla::StatusOr SliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + gtl::ArraySlice start, + gtl::ArraySlice end); + +// Builds a 1-d vector out of a concatenation of `major_dims` and `starts`. +std::vector PrependMajorDims(xla::XlaBuilder* builder, + const gtl::ArraySlice& major_dims, + const gtl::ArraySlice& indices); + +// Performs a dynamic slice in the minor dimensions of a Tensor. +xla::StatusOr DynamicSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, + const std::vector& starts, const gtl::ArraySlice& sizes); // Updates a slice of 'x', i.e., // x[start[0], ..., start[n]] = update -xla::StatusOr UpdateSlice( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start); +xla::StatusOr UpdateSlice(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start); // Updates a slice of 'x', where 'start' contains a list of minor dimensions: // x[..., start[0], ..., start[n]] = update -xla::StatusOr UpdateSliceInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, - const xla::ComputationDataHandle& update, gtl::ArraySlice start); +xla::StatusOr UpdateSliceInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x, + const xla::XlaOp& update, + gtl::ArraySlice start); + +xla::StatusOr DynamicUpdateSliceInMinorDims( + xla::XlaBuilder* builder, const xla::XlaOp& x, const xla::XlaOp& update, + const std::vector& starts); // Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::StatusOr TransposeInMinorDims( - xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x); +xla::StatusOr TransposeInMinorDims(xla::XlaBuilder* builder, + const xla::XlaOp& x); + +// Applies a complex conjugation operation if `a` is complex and `conjugate_a` +// is true, otherwise returns its argument. +xla::StatusOr MaybeConjugate(xla::XlaBuilder* builder, + const xla::XlaOp& x, bool conjugate); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util_test.cc b/tensorflow/compiler/tf2xla/lib/util_test.cc new file mode 100644 index 00000000000..265b39402c8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/util_test.cc @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/util.h" + +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using UtilTest = xla::ClientLibraryTestBase; +using UtilLeftLookingTest = xla::ClientLibraryTestBase; + +xla::Array2D BValsRight() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +xla::Array2D BValsLeft() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +xla::Array2D AValsFull() { + return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +} + +xla::Array3D BatchedAValsFull() { + return {{ + {2, 0, 1, 2}, + {3, 6, 0, 1}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }}; +} + +XLA_TEST_F(UtilTest, Simple2dLookup) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, x, y; + auto a_data = CreateR2Parameter(BValsRight(), 0, "a", &builder, &a); + auto x_data = CreateR0Parameter(2, 1, "x", &builder, &x); + auto y_data = CreateR0Parameter(1, 2, "y", &builder, &y); + auto result = DynamicSliceInMinorDims(&builder, a, {x, y}, {1, 1}); + TF_ASSERT_OK(result.status()); + + ComputeAndCompareR2(&builder, {{10}}, + {a_data.get(), x_data.get(), y_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(UtilTest, Simple3dLookup) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto index_data = CreateR0Parameter(1, 1, "index", &builder, &index); + + TF_ASSERT_OK_AND_ASSIGN( + auto l_index, + DynamicSliceInMinorDims(&builder, a, + {index, builder.ConstantR0(0)}, {1, 4})); + + ComputeAndCompareR3(&builder, {{{3, 6, 0, 1}}, {{24, 61, 82, 48}}}, + {a_data.get(), index_data.get()}); +} + +XLA_TEST_F(UtilTest, SimpleSliceUpdate) { + xla::XlaBuilder builder(TestName()); + + xla::XlaOp a, b, x, y; + auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter({{9, 1, -10}}, 1, "b", &builder, &b); + auto x_data = CreateR0Parameter(2, 2, "x", &builder, &x); + auto y_data = CreateR0Parameter(1, 3, "y", &builder, &y); + + auto result = DynamicUpdateSliceInMinorDims(&builder, a, b, {x, y}); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected( + {{{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 9, 1, -10}, {5, 8, 10, 11}}}); + + ComputeAndCompareR2( + &builder, expected, + {a_data.get(), b_data.get(), x_data.get(), y_data.get()}); +} + +XLA_TEST_F(UtilTest, RowBatchDot) { + xla::XlaBuilder builder(TestName()); + + int n = 4; + + xla::XlaOp a, row, index; + auto a_data = + CreateR3Parameter(BatchedAValsFull(), 0, "a", &builder, &a); + auto row_data = CreateR3Parameter({{{9, 1, 0, 0}}, {{2, 4, 0, 0}}}, 1, + "row", &builder, &row); + // Select {{3, 6, 0, 1}, {24, 61, 82, 48}} out of BatchedAValsFull(). 
+ auto index_data = CreateR0Parameter(1, 2, "index", &builder, &index); + + TF_ASSERT_OK_AND_ASSIGN( + auto l_index, + DynamicSliceInMinorDims(&builder, a, + {index, builder.ConstantR0(0)}, {1, n})); + TF_ASSERT_OK_AND_ASSIGN( + auto dot, BatchDot(&builder, l_index, row, + /*transpose_x=*/false, /*transpose_y=*/true)); + + ComputeAndCompareR3(&builder, {{{33}}, {{292}}}, + {a_data.get(), row_data.get(), index_data.get()}); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 495d9c60780..09ce594930e 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -20,24 +20,24 @@ limitations under the License. namespace tensorflow { -xla::StatusOr> XlaWhileLoop( +xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder) { + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder) { int arity = initial_values.size(); std::vector var_shapes; var_shapes.reserve(arity); - for (const xla::ComputationDataHandle& input : initial_values) { + for (const xla::XlaOp& input : initial_values) { TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); - var_shapes.push_back(std::move(*shape)); + var_shapes.push_back(std::move(shape)); } xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); // Unpacks a tuple into its component parts. - auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity, - xla::ComputationBuilder* builder) { - std::vector elements(arity); + auto unpack_tuple = [](xla::XlaOp tuple, int arity, + xla::XlaBuilder* builder) { + std::vector elements(arity); for (int i = 0; i < arity; ++i) { elements[i] = builder->GetTupleElement(tuple, i); } @@ -45,20 +45,20 @@ xla::StatusOr> XlaWhileLoop( }; // Build the condition. - std::unique_ptr cond_builder = + std::unique_ptr cond_builder = builder->CreateSubBuilder(strings::StrCat(name, "_condition")); { auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); - TF_ASSIGN_OR_RETURN( - auto result, + TF_RETURN_IF_ERROR( condition_function(unpack_tuple(parameter, arity, cond_builder.get()), - cond_builder.get())); + cond_builder.get()) + .status()); } TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); // Build the body. 
- std::unique_ptr body_builder = + std::unique_ptr body_builder = builder->CreateSubBuilder(strings::StrCat(name, "_body")); { auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); @@ -78,38 +78,38 @@ xla::StatusOr> XlaWhileLoop( return unpack_tuple(outputs, arity, builder); } -xla::StatusOr> XlaForEachIndex( +xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder) { - auto while_cond_fn = [&](gtl::ArraySlice values, - xla::ComputationBuilder* cond_builder) - -> xla::StatusOr { + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder) { + auto while_cond_fn = + [&](gtl::ArraySlice values, + xla::XlaBuilder* cond_builder) -> xla::StatusOr { return cond_builder->Lt( values[0], IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); }; - auto while_body_fn = [&](gtl::ArraySlice values, - xla::ComputationBuilder* body_builder) - -> xla::StatusOr> { - xla::ComputationDataHandle iteration = values[0]; + auto while_body_fn = [&](gtl::ArraySlice values, + xla::XlaBuilder* body_builder) + -> xla::StatusOr> { + xla::XlaOp iteration = values[0]; - std::vector updated_values; + std::vector updated_values; updated_values.reserve(values.size()); updated_values.push_back(body_builder->Add( iteration, body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); values.remove_prefix(1); - TF_ASSIGN_OR_RETURN(std::vector body_outputs, + TF_ASSIGN_OR_RETURN(std::vector body_outputs, body_function(iteration, values, body_builder)); updated_values.insert(updated_values.end(), body_outputs.begin(), body_outputs.end()); return updated_values; }; - std::vector values; + std::vector values; values.reserve(initial_values.size() + 1); values.push_back( builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 2e67a0c99b6..5b6684c9958 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,8 +19,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -29,14 +29,14 @@ namespace tensorflow { // Function that builds a loop condition. Takes as input a sequence of input // values, and returns a boolean value representing if the condition succeeds. -typedef std::function( - gtl::ArraySlice, xla::ComputationBuilder*)> +typedef std::function(gtl::ArraySlice, + xla::XlaBuilder*)> LoopConditionFunction; // Function that builds a loop body. Takes as input a sequence of input values // and returns a sequence of output values. -typedef std::function>( - gtl::ArraySlice, xla::ComputationBuilder*)> +typedef std::function>( + gtl::ArraySlice, xla::XlaBuilder*)> LoopBodyFunction; // Helper function for building an XLA while loop, where the values carried by @@ -47,27 +47,26 @@ typedef std::function>( // init: (a, b, c) // ) // 'name' is a descriptive name for the loop. 
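// Hypothetical usage sketch (as it might appear in a client .cc file): build a
// loop carrying a single int32 counter that runs until the counter reaches 10,
// using the LoopConditionFunction / LoopBodyFunction signatures declared
// above. The function name is made up for illustration.
#include <vector>

#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

namespace tensorflow {

xla::StatusOr<std::vector<xla::XlaOp>> BuildCounterLoopSketch(
    xla::XlaBuilder* builder) {
  auto cond = [](gtl::ArraySlice<xla::XlaOp> values,
                 xla::XlaBuilder* b) -> xla::StatusOr<xla::XlaOp> {
    // Keep looping while the carried counter is < 10.
    return b->Lt(values[0], b->ConstantR0<int32>(10));
  };
  auto body = [](gtl::ArraySlice<xla::XlaOp> values,
                 xla::XlaBuilder* b) -> xla::StatusOr<std::vector<xla::XlaOp>> {
    // Return the updated loop-carried values: here, just the incremented
    // counter.
    return std::vector<xla::XlaOp>{b->Add(values[0], b->ConstantR0<int32>(1))};
  };
  return XlaWhileLoop(cond, body, {builder->ConstantR0<int32>(0)},
                      "counter_loop", builder);
}

}  // namespace tensorflow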
-xla::StatusOr> XlaWhileLoop( +xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder); + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. // // The body function (ForEachIndexBodyFunction) takes as input a pair of // (current iteration number, loop-carried values), and returns an updated // vector of the loop-carried values. -typedef std::function>( - xla::ComputationDataHandle, gtl::ArraySlice, - xla::ComputationBuilder*)> +typedef std::function>( + xla::XlaOp, gtl::ArraySlice, xla::XlaBuilder*)> ForEachIndexBodyFunction; -xla::StatusOr> XlaForEachIndex( +xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - gtl::ArraySlice initial_values, - StringPiece name, xla::ComputationBuilder* builder); + gtl::ArraySlice initial_values, StringPiece name, + xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 2c3cd658e04..43e1c1e9fec 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,7 +40,7 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } -Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && xla::ShapeUtil::ElementsIn(literal.shape()) == @@ -63,8 +63,8 @@ Status CopyLiteralToHostTensor(const xla::Literal& literal, return Status::OK(); } -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor) { +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor) { TensorShape shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); *host_tensor = Tensor(target_type, shape); diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index f283b023681..220bec15538 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -36,13 +36,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); // derivable from the type of , because multiple tensorflow types map // to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in // XLA). -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor); +Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor); // Copies the contents of 'literal' to a previously allocated tensor // 'host_tensor'. The tensor and the literal must have the same number of // elements and the same type. 
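// Hypothetical usage sketch: materializing an XLA literal as a host Tensor via
// LiteralToHostTensor, declared above. The literal contents and the helper
// name are made up; the call assumes a Literal converts to the LiteralSlice
// view taken by the API.
#include <memory>

#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

Status LiteralToTensorSketch(Tensor* host_tensor) {
  std::unique_ptr<xla::Literal> literal =
      xla::Literal::CreateR1<int32>({4, 143});
  // Copies the literal's contents into a freshly allocated int32 host tensor.
  return LiteralToHostTensor(*literal, DT_INT32, host_tensor);
}

}  // namespace tensorflow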
-Status CopyLiteralToHostTensor(const xla::Literal& literal, +Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 6051d7dffd7..3a08aa8cf4f 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -251,7 +251,7 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, - xla::Computation* computation) { + xla::XlaComputation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( @@ -303,7 +303,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, } // InitGraph creates a graph based on the graph_def, that may then be converted -// to an xla::Computation via ConvertGraphToXla. +// to an xla::XlaComputation via ConvertGraphToXla. // // The graph is rewritten with _Arg and _Retval nodes, representing the inputs // and outputs of the function that will be compiled. Each feed id causes a new @@ -348,7 +348,7 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation) { + xla::XlaComputation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation)); diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index 473c431b12d..d02fc56c5b8 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -18,21 +18,21 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/core/framework/graph.pb.h" namespace tensorflow { -// Converts a tensorflow::GraphDef into an xla::Computation. The given `config` -// specifies the portion of the graph to convert, via feeds and fetches. Each -// feed is a positional input argument for the generated computation, while each -// fetch is a positional output argument. +// Converts a tensorflow::GraphDef into an xla::XlaComputation. The given +// `config` specifies the portion of the graph to convert, via feeds and +// fetches. Each feed is a positional input argument for the generated +// computation, while each fetch is a positional output argument. // // The computation is built in the context of the given `client`, which may // subsequently be used to compile or execute the computation. 
Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation); + xla::XlaComputation* computation); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index b813668a9ed..84c133ffabe 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -69,7 +69,7 @@ TEST(ConvertGraphDefToXla, Sum) { tf2xla::Config config = SumConfig(); xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); - xla::Computation computation; + xla::XlaComputation computation; TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 7ec85aa3cde..9203e8d9e60 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -232,7 +232,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, // Push input nodes of the currently visited node to name_queue. for (const string& in_edge : map_entry.second->input()) { auto id = ParseTensorName(in_edge); - const string node_name = id.first.ToString(); + const string node_name = std::string(id.first); if (feed_tensors.find(std::make_pair(node_name, id.second)) == feed_tensors.end()) { name_queue.push(node_name); diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index fcb0a4e6381..fe7ec633eca 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/platform/mem.h" @@ -108,7 +109,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, // If no sharding metadata is found, XLA is free to use whatever device it // wants. In practice this usually has the effect of placing things on device // 0. - xla::ScopedShardingAssignment assign_sharding(b, op_sharding); + xla::XlaScopedShardingAssignment assign_sharding(b, op_sharding); op_kernel->Compute(context); b->ClearOpMetadata(); @@ -126,9 +127,7 @@ Status XlaCompilationDevice::MakeTensorFromProto( XlaExpression::XlaExpression() = default; -void XlaExpression::set_handle(const xla::ComputationDataHandle& h) { - handle_ = h; -} +void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; } void XlaExpression::set_constant_value(Tensor value) { has_constant_value_ = true; diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index 0243ee332fb..d0b9e34e162 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -19,7 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" @@ -69,7 +69,7 @@ class XlaCompilationDevice : public LocalDevice { // A XlaExpression wraps an XLA computation. Each Tensor on an // XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor -// matches the shape of the subcomputation in the ComputationDataHandle. Each +// matches the shape of the subcomputation in the XlaOp. Each // expression is either a constant, or a function of previously-compiled // expressions. class XlaExpression { @@ -78,8 +78,8 @@ class XlaExpression { // handle() stores the XLA handle of the computation that the // expression represents. - void set_handle(const xla::ComputationDataHandle& h); - const xla::ComputationDataHandle& handle() const { return handle_; } + void set_handle(const xla::XlaOp& h); + const xla::XlaOp& handle() const { return handle_; } void set_constant_value(Tensor value); bool has_constant_value() const { return has_constant_value_; } @@ -90,7 +90,7 @@ class XlaExpression { private: // The XLA handle of the expression's computation. - xla::ComputationDataHandle handle_; + xla::XlaOp handle_; // If this expression is a constant with a known value, 'constant_value' is a // host-memory Tensor containing the value. Used to avoid invoking XLA for diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 86263d847ae..f7098917b19 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -15,10 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include #include +#include -#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -28,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" @@ -40,7 +38,6 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -110,10 +107,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name()); - // The default variable representation shape is the identity function. - if (!options_.variable_representation_shape_fn) { - options_.variable_representation_shape_fn = - [](const TensorShape& shape, DataType type) { return shape; }; + // The default shape representation function is the identity. 
+ if (!options_.shape_representation_fn) { + options_.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return shape; }; } } @@ -230,20 +227,25 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // Computes the XLA shape for argument 'arg'. Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, + bool is_entry_computation, xla::Shape* xla_shape) { switch (arg.kind) { case XlaCompiler::Argument::kConstant: - return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), - xla_shape); - case XlaCompiler::Argument::kParameter: - return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + LOG(FATAL) << "Unreachable case"; + case XlaCompiler::Argument::kParameter: { + TensorShape shape = + is_entry_computation + ? options_.shape_representation_fn(arg.shape, arg.type) + : arg.shape; + return TensorShapeToXLAShape(arg.type, shape, xla_shape); + } case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { case XlaResource::kVariable: { TensorShape representation_shape = - options_.variable_representation_shape_fn(arg.shape, arg.type); + options_.shape_representation_fn(arg.shape, arg.type); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } @@ -337,16 +339,25 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, Status BuildComputation( const std::vector& args, const std::vector& arg_cores, - const std::vector& retvals, + const std::vector& retvals, const std::vector>& resources, - bool return_updated_values_for_all_resources, - xla::ComputationBuilder* builder, xla::Computation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, + bool return_updated_values_for_all_resources, xla::XlaBuilder* builder, + xla::XlaComputation* computation, int* num_computation_outputs, + int* num_nonconst_outputs, + std::vector* outputs, std::vector* resource_updates) { - std::vector elems; + std::vector elems; elems.reserve(retvals.size()); - for (const XlaExpression& retval : retvals) { - if (!retval.has_constant_value()) { + for (int i = 0; i < retvals.size(); ++i) { + XlaCompiler::OutputDescription& output = (*outputs)[i]; + output.type = retvals[i].type; + output.shape = retvals[i].shape; + const XlaExpression& retval = retvals[i].expression; + if (retval.has_constant_value()) { + output.is_constant = true; + output.constant_value = retval.constant_value(); + } else { + output.is_constant = false; elems.push_back(retval.handle()); } } @@ -376,14 +387,12 @@ Status BuildComputation( const XlaCompiler::Argument& arg = args[resource->arg_num()]; const int core = arg_cores[resource->arg_num()]; DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = - resource->value().handle() != resource->initial_value().handle(); + bool modified = resource->value() != resource->initial_value(); // TensorArray gradients were modified if their values changed or there are // any newly created gradients. for (const auto& grad : resource->tensor_array_gradients()) { modified = modified || - grad.second->value().handle() != - grad.second->initial_value().handle() || + grad.second->value() != grad.second->initial_value() || arg.tensor_array_gradients.count(grad.first) == 0; } if (return_updated_values_for_all_resources || modified) { @@ -398,11 +407,11 @@ Status BuildComputation( } // Request that the value be returned on a specific core. 
- xla::ScopedShardingAssignment assign_sharding( + xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); - xla::ComputationDataHandle handle; + xla::XlaOp handle; TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); // Since we can't change the sharding metadata of as this point, @@ -421,7 +430,7 @@ Status BuildComputation( builder->Tuple(elems); builder->ClearOpMetadata(); - xla::StatusOr computation_status = builder->Build(); + xla::StatusOr computation_status = builder->Build(); if (!computation_status.ok()) { return computation_status.status(); } @@ -435,7 +444,7 @@ Status BuildComputation( // `args` are the arguments to the computation. Status XlaCompiler::BuildArguments( const Graph& graph, const std::vector& args, - bool use_tuple_arg, xla::ComputationBuilder* builder, XlaContext* context, + bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, std::vector* arg_cores, std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, bool is_entry_computation) { @@ -461,8 +470,7 @@ Status XlaCompiler::BuildArguments( // alias. XlaResource* resource; TF_RETURN_IF_ERROR(context->CreateResource( - arg.resource_kind, i, arg.name, arg.type, arg.shape, - xla::ComputationDataHandle(), + arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(), /*tensor_array_size=*/arg.tensor_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); arg_expression.set_resource(resource); @@ -493,8 +501,8 @@ Status XlaCompiler::BuildArguments( std::vector arg_shapes(input_mapping->size()); for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - TF_RETURN_IF_ERROR( - XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i])); + TF_RETURN_IF_ERROR(XLAShapeForArgument( + args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i])); } if (use_tuple_arg) { @@ -531,9 +539,9 @@ Status XlaCompiler::BuildArguments( builder->SetOpMetadata(arg_metadata); // Build parameter handles for non-constant arguments. - std::vector arg_handles(input_mapping->size()); + std::vector arg_handles(input_mapping->size()); if (use_tuple_arg) { - xla::ComputationDataHandle tuple; + xla::XlaOp tuple; if (is_entry_computation) { xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); @@ -544,15 +552,15 @@ Status XlaCompiler::BuildArguments( core == -1 ? xla::sharding_builder::AssignDevice(root_device) : xla::sharding_builder::AssignDevice(core); } - xla::ScopedShardingAssignment assign_tuple_sharding(builder, - tuple_sharding); + xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, + tuple_sharding); tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); } else { tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); } for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; - xla::ScopedShardingAssignment assign_sharding( + xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? 
tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = builder->GetTupleElement(tuple, i); @@ -560,7 +568,7 @@ Status XlaCompiler::BuildArguments( } else { for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const int core = (*arg_cores)[input_mapping->at(i)]; - xla::ScopedShardingAssignment assign_sharding( + xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = @@ -570,7 +578,8 @@ Status XlaCompiler::BuildArguments( builder->ClearOpMetadata(); - // Fill in the handles in non-constant arguments. + // Fill in the handles in non-constant arguments, and reshape parameters + // back to their correct shapes. VLOG(2) << "XLA computation inputs:"; for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; @@ -589,7 +598,15 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kParameter: - arg_expression.set_handle(arg_handles[i]); + // Reshape parameters back to their correct shapes. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. + if (is_entry_computation) { + arg_expression.set_handle( + builder->Reshape(arg_handles[i], arg.shape.dim_sizes())); + } else { + arg_expression.set_handle(arg_handles[i]); + } break; case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: @@ -647,7 +664,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, std::unique_ptr graph, const std::vector& args, CompilationResult* result) { - VLOG(1) << "Executing graph symbolically to populate ComputationBuilder."; + VLOG(1) << "Executing graph symbolically to populate XlaBuilder."; if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " @@ -661,13 +678,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Converts Tensorflow's graph control-flow constructs into functional // control-flow that can be compiled into XLA code. 
TF_RETURN_IF_ERROR( - FunctionalizeControlFlow(graph.get(), local_flib_def_.get())); + FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), + graph.get(), local_flib_def_.get())); - xla::ComputationBuilder builder(client(), name); - XlaContext* context = - new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, - &options_.variable_representation_shape_fn); + xla::XlaBuilder builder(name); + XlaContext* context = new XlaContext( + this, &builder, options_.allow_cpu_custom_calls, + options.resolve_compile_time_constants, options.is_entry_computation, + &options_.shape_representation_fn); core::ScopedUnref context_unref(context); std::vector arg_expressions; @@ -683,36 +701,23 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_nonconst_outputs; int num_computation_outputs; - result->computation = std::make_shared(); + result->computation = std::make_shared(); + result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_computation_outputs, - &num_nonconst_outputs, &result->resource_updates)); + &num_nonconst_outputs, &result->outputs, &result->resource_updates)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; - result->outputs.resize(context->retvals().size()); - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (retval.has_constant_value()) { - OutputDescription& output = result->outputs[i]; - output.shape = retval.constant_value().shape(); - output.is_constant = true; - output.constant_value = retval.constant_value(); - } - } - // Compute the output shapes, if there is a computation with non-constant + // Compute the XLA output shape, if there is a computation with non-constant // outputs. - auto computation_shape = client()->GetComputationShape(*result->computation); - if (!computation_shape.ok()) { - return computation_shape.status(); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr computation_shape, + client()->GetComputationShape(*result->computation)); - result->xla_output_shape.Swap( - computation_shape.ValueOrDie()->mutable_result()); + result->xla_output_shape.Swap(computation_shape->mutable_result()); VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); @@ -727,23 +732,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Tensorflow expects a major-to-minor order of results. xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); - // Converts the output shapes to TensorShapes. 
- int computation_output = 0; - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (!retval.has_constant_value()) { - TF_RET_CHECK(computation_output < num_computation_outputs) - << "Computation has more outputs than expected"; - OutputDescription& output = result->outputs[i]; - output.is_constant = false; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape( - xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape, - computation_output), - &output.shape)); - ++computation_output; - } - } return Status::OK(); } @@ -813,4 +801,29 @@ Status XlaCompiler::SetHostToDeviceMetadata( return Status::OK(); } +Status XlaCompiler::GetHostComputeControlDependency( + const string& host_compute_name, xla::XlaOp* handle) { + const auto iter = host_compute_control_output_.find(host_compute_name); + if (iter == host_compute_control_output_.end()) { + return errors::InvalidArgument( + "No registered control handle for host compute Op '", host_compute_name, + "'"); + } else { + *handle = iter->second; + } + return Status::OK(); +} + +Status XlaCompiler::SetHostComputeControlDependency( + const string& host_compute_name, const xla::XlaOp& handle) { + if (host_compute_control_output_.find(host_compute_name) != + host_compute_control_output_.end()) { + return errors::InvalidArgument( + "Duplicate control handles registered for for host compute Op ", + host_compute_name); + } + host_compute_control_output_[host_compute_name] = handle; + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index a6747bbe72e..bf496bd8bc8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -38,7 +38,7 @@ class XlaContext; // It does a symbolic execution of the graph starting from specific input // shapes, using a JIT device to convert operators into XLA computations. // -// XlaCompiler is typically invoked from an `_XlaLaunch` operator once the +// XlaCompiler is typically invoked from an `XlaLaunch` operator once the // shapes of all input parameters to the computation are known. This is // because the symbolic execution requires known shapes for all operations. // @@ -67,6 +67,15 @@ class XlaContext; // _Retval values are ordered by _Retval index, whereas kResource values are // ordered by the original _Arg position of the variable. // +// If a shape representation function is provided as part of +// XlaCompiler::CompileOptions, kParameter arguments and return values to an +// entry computation will be reshaped in accordance to the shape function. +// Arguments and return values to a non-entry computation are not reshaped. +// Variable resource arguments are passed and returned in reshaped form, even +// for non-entry computations. This feature allows TensorFlow to keep on-device +// tensors with a different shape to their representation inside the XLA +// computation. +// // In both inputs and outputs, kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has // identical input and output signatures. By moving variable values @@ -171,7 +180,7 @@ class XlaCompiler { }; struct OutputDescription { - // Type and shape of the output. + // Type and shape of the output. The shape is the unflattened shape. DataType type; TensorShape shape; @@ -206,10 +215,12 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) 
std::vector input_mapping; - // Input shapes of the computation. + // Input shapes of the computation. If we are flattening inputs, these are + // the flattened shapes. std::vector xla_input_shapes; - // Output shape in XLA format. The output shape is always a tuple. + // Output shape in XLA format. The output shape is always a tuple. If we + // are flattening outputs, these are the flattened shapes. xla::Shape xla_output_shape; // TensorFlow shapes of outputs, together with the values of any @@ -227,9 +238,11 @@ class XlaCompiler { std::vector resource_updates; // The XLA computation built from the tensorflow subgraph. - std::shared_ptr computation; + std::shared_ptr computation; }; + typedef std::function + ShapeRepresentationFn; struct Options { // Name of the compilation device to use. Needs to be live only during // XlaCompiler's constructor. @@ -250,8 +263,7 @@ class XlaCompiler { // If set, the XLA representation of variables represented to XLA as the // shape given by this shape function. Variables are reshaped to this shape // on write, and reshaped to their original shape on read. - std::function - variable_representation_shape_fn; + ShapeRepresentationFn shape_representation_fn; // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation @@ -281,7 +293,7 @@ class XlaCompiler { const NameAttrList& fn_name_attrs, std::vector args, CompilationResult* result); - // Compiles a tensorflow::Graph into an xla::Computation. + // Compiles a tensorflow::Graph into an xla::XlaComputation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. Status CompileGraph(const CompileOptions& options, string const& name, @@ -290,7 +302,7 @@ class XlaCompiler { CompilationResult* result); // Compiles a single Op, given by an OpKernelContext, into an - // xla::Computation. Similar to CompileFunction but takes a single Op as + // xla::XlaComputation. Similar to CompileFunction but takes a single Op as // input. Status CompileSingleOp(const CompileOptions& options, string const& name, OpKernelContext* ctx, @@ -300,7 +312,8 @@ class XlaCompiler { // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. - Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); + Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation, + xla::Shape* xla_shape); // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. @@ -325,6 +338,22 @@ class XlaCompiler { gtl::ArraySlice types, gtl::ArraySlice shapes); + // In order to avoid deadlocks from dependencies in host computations, it can + // be necessary to enforce a partial order on the execution of HostCompute + // Ops. In particular it may be necessary to constrain the SendToHost for one + // HostCompute to run before blocking on the RecvAtHost for another + // HostCompute. The compiler maintains a mapping from 'host_compute_name' to + // handle, where the handle is an 'output' of the HostCompute Op corresponding + // to 'host_compute_name'. Another HostCompute Op that needs to be sequenced + // later can add the handle as an 'input' to enforce the constraints. + // 'host_compute_name' can be any string the client wishes to use to identify + // a given HostCompute Op as long as the names are unique within the + // compilation. 
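// Hypothetical example of a ShapeRepresentationFn (names made up for
// illustration): keep every variable as a flat 1-D tensor on device while
// TensorFlow continues to see the original shape at the graph level. Any such
// function must preserve the element count.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

XlaCompiler::Options FlattenedVariableOptionsSketch() {
  XlaCompiler::Options options;
  options.shape_representation_fn = [](const TensorShape& shape,
                                       DataType /*type*/) {
    // Same number of elements, collapsed into a single dimension.
    return TensorShape({shape.num_elements()});
  };
  return options;
}

}  // namespace tensorflow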
+ Status GetHostComputeControlDependency(const string& host_compute_name, + xla::XlaOp* handle); + Status SetHostComputeControlDependency(const string& host_compute_name, + const xla::XlaOp& handle); + const Options& options() const { return options_; } xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } @@ -341,7 +370,7 @@ class XlaCompiler { // `args` are the arguments to the computation. Status BuildArguments(const Graph& graph, const std::vector& args, - bool use_tuple_arg, xla::ComputationBuilder* builder, + bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, std::vector* arg_cores, std::vector* arg_expressions, std::vector* input_mapping, @@ -391,6 +420,8 @@ class XlaCompiler { std::unordered_map host_compute_sends_; std::unordered_map host_compute_recvs_; + std::unordered_map host_compute_control_output_; + TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 096dc7160bf..55772ca3248 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -25,12 +25,14 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -164,7 +166,6 @@ REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT), REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT), DummyDuplicateOp); - // Tests compilation and execution of an empty graph. TEST_F(XlaCompilerTest, EmptyReturnValues) { XlaCompiler compiler(DefaultOptions()); @@ -226,7 +227,7 @@ TEST_F(XlaCompilerTest, Simple) { xla::Literal::CreateR1({4, 143}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { @@ -321,7 +322,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE( + xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } { @@ -356,10 +358,80 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal)); } } +TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) { + // Define a function with one compile-time constant output and one + // data-dependent output. 
+ // @function.Defun(noinline=True) + // foo(a) {b=7; return b, a; } + const Tensor seven = test::AsScalar(7); + FunctionDef fdef = FunctionDefHelper::Create( + "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {}, + { + {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}}, + }, + {{"a", "a_0"}, {"const", "Const:output:0"}}); + (*fdef.mutable_attr())["_noinline"].set_b(true); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = fdef; + std::unique_ptr graph(new Graph(OpRegistry::Global())); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib)); + auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0); + NodeDef foo; + foo.set_name("foo"); + foo.set_op("foo"); + *foo.add_input() = "input_arg"; + Status status; + scope.graph()->AddNode(foo, &status); + TF_ASSERT_OK(status); + NodeDef retval_1; + retval_1.set_name("retval_0"); + retval_1.set_op(FunctionLibraryDefinition::kRetOp); + *retval_1.add_input() = "foo"; + (*retval_1.mutable_attr())["T"].set_type(DT_INT32); + (*retval_1.mutable_attr())["index"].set_i(0); + scope.graph()->AddNode(retval_1, &status); + TF_ASSERT_OK(status); + NodeDef retval_2; + retval_2.set_name("retval_1"); + retval_2.set_op(FunctionLibraryDefinition::kRetOp); + *retval_2.add_input() = "foo:1"; + (*retval_2.mutable_attr())["T"].set_type(DT_INT32); + (*retval_2.mutable_attr())["index"].set_i(1); + scope.graph()->AddNode(retval_2, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + } + + // Builds a description of the arguments. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({1}); + + XlaCompiler::Options options = DefaultOptions(); + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + options.flib_def = &flib_def; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.resolve_compile_time_constants = true; + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants", + std::move(graph), args, &result)); + + ASSERT_EQ(2, result.outputs.size()); + EXPECT_TRUE(result.outputs[0].is_constant); + test::ExpectTensorEqual(result.outputs[0].constant_value, + test::AsScalar(7)); + EXPECT_FALSE(result.outputs[1].is_constant); +} + // Tests compilation and execution of a graph that adds two tensors. TEST_F(XlaCompilerTest, ResourceManager) { // Builds a graph that calls the dummy resource Op. @@ -433,21 +505,26 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { } for (int64 i = 1; i < test_count; ++i) { - auto m1 = - results[i - 1].computation->Snapshot().ValueOrDie()->entry().requests(); - auto m2 = - results[i].computation->Snapshot().ValueOrDie()->entry().requests(); - // Check if every entry is the same. - for (auto& entry1 : m1) { - int64 key = entry1.first; - auto value1 = entry1.second; - auto entry2 = m2.find(key); - auto value2 = entry2->second; - EXPECT_TRUE(entry2 != m2.end()); - string str1, str2; - value1.AppendToString(&str1); - value2.AppendToString(&str2); - EXPECT_EQ(str1, str2); + const auto& m1 = results[i - 1].computation->proto(); + const auto& m2 = results[i].computation->proto(); + ASSERT_EQ(m1.computations_size(), m2.computations_size()); + // Check if every hlo computation is the same. 
+ for (int k = 0; k < m1.computations_size(); k++) { + const auto& c1 = m1.computations(k); + const auto& c2 = m2.computations(k); + ASSERT_EQ(c1.instructions_size(), c2.instructions_size()); + for (int j = 0; j < c1.instructions_size(); j++) { + auto instr1 = c1.instructions(j); + auto instr2 = c2.instructions(j); + instr1.clear_name(); + instr2.clear_name(); + // The names of instructions were uniquified by the XlaBuilder, the rest + // of the fields should be identical. + string str1, str2; + instr1.AppendPartialToString(&str1); + instr2.AppendPartialToString(&str2); + EXPECT_EQ(str1, str2); + } } } } @@ -519,7 +596,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { {output_base.get(), output_grad1.get(), output_grad2.get()}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } // Tests compilation and execution of a graph that adds two tensors. @@ -742,13 +819,10 @@ TEST_F(XlaCompilerTest, Variables) { xla::Literal::CreateR1({4, 143}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } -// Tests a simple graph that reads and writes a variable, with a -// variable_representation_shape_fn passed to the compiler that flattens all -// variable tensors to vectors. -TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { +xla::StatusOr> BuildTestGraph() { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); @@ -759,7 +833,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_ASSERT_OK(scope.ToGraph(graph.get())); + TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); + return std::move(graph); +} + +// Tests a simple graph that reads and writes a variable, with a +// shape_representation_fn passed to the compiler that flattens all +// variable tensors to vectors. +TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); // Builds a description of the arguments. std::vector args(2); @@ -774,15 +856,33 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); - options.variable_representation_shape_fn = [](const TensorShape& shape, - DataType type) { + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return TensorShape({shape.num_elements()}); }; XlaCompiler compiler(options); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; // Only reshape variables. 
+ XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE( + xla::ShapeUtil::Compatible(program_shape->parameters(0), + xla::ShapeUtil::MakeShape(xla::S32, {2, 2}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. std::unique_ptr param0_literal = @@ -807,7 +907,76 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { xla::Literal::CreateR1({26, 66, 34, 401}); std::unique_ptr expected_literal = xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); +} + +TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 2}); + + // Compiles the graph. + XlaCompiler::Options options = DefaultOptions(); + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { + return TensorShape({shape.num_elements()}); + }; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; // Reshape args and retvals. + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {4}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); + + // Tests that the generated computation works. 
+ std::unique_ptr param0_literal = + xla::Literal::CreateR1({4, 55, 1, -3}); + std::unique_ptr param1_literal = + xla::Literal::CreateR1({22, 11, 33, 404}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::Literal::CreateR1({27, 67, 35, 402}); + std::unique_ptr expected1 = + xla::Literal::CreateR1({26, 66, 34, 401}); + std::unique_ptr expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } // namespace diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 8423921086f..098072d33cd 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -63,28 +63,32 @@ void XlaContext::set_args(std::vector args) { } XlaContext::XlaContext( - XlaCompiler* compiler, xla::ComputationBuilder* builder, + XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn) + shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants), - variable_representation_shape_fn_(variable_representation_shape_fn) {} + is_entry_computation_(is_entry_computation), + shape_representation_fn_(shape_representation_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. void XlaContext::AddRetval(int retval_index, DataType type, - const xla::ComputationDataHandle& handle) { + const TensorShape& shape, const xla::XlaOp& handle) { VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; // Add the return value to the list being built up. 
if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - retvals_[retval_index].set_handle(handle); + XlaExpression e; + e.set_handle(handle); + retvals_[retval_index] = Retval{type, shape, e}; } Status XlaContext::AddConstRetval(int retval_index, DataType dtype, @@ -94,23 +98,20 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - if (resolve_compile_time_constants_) { - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - retvals_[retval_index].set_constant_value(std::move(value)); - } else { - retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal)); - } + Tensor value; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); + XlaExpression e; + e.set_constant_value(value); + retvals_[retval_index] = Retval{dtype, value.shape(), e}; return Status::OK(); } -xla::ComputationBuilder* XlaContext::builder() { return builder_; } +xla::XlaBuilder* XlaContext::builder() { return builder_; } Status XlaContext::CreateResource( XlaResource::Kind kind, int arg_num, string name, DataType type, - TensorShape shape, const xla::ComputationDataHandle& handle, - int64 tensor_array_size, const std::set& tensor_array_gradients, - XlaResource** resource) { + TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size, + const std::set& tensor_array_gradients, XlaResource** resource) { resources_.emplace_back( new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), handle, tensor_array_size, tensor_array_gradients)); @@ -118,16 +119,16 @@ Status XlaContext::CreateResource( return Status::OK(); } -TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, - DataType type) const { - return (*variable_representation_shape_fn_)(shape, type); +TensorShape XlaContext::RepresentationShape(const TensorShape& shape, + DataType type) const { + return (*shape_representation_fn_)(shape, type); } -const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { return LookupOrCreate(type, &max_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Max() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "max<" + type_string + ">"); + xla::XlaBuilder b("max<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -137,11 +138,11 @@ const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { }); } -const xla::Computation* XlaContext::GetOrCreateMin(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateMin(const DataType type) { return LookupOrCreate(type, &min_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Min() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "min<" + type_string + ">"); + xla::XlaBuilder b("min<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -151,11 +152,11 @@ const xla::Computation* XlaContext::GetOrCreateMin(const DataType type) { }); } -const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateAdd(const DataType type) { return 
LookupOrCreate(type, &add_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Add() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "add<" + type_string + ">"); + xla::XlaBuilder b("add<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -165,11 +166,11 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { }); } -const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { +const xla::XlaComputation* XlaContext::GetOrCreateMul(const DataType type) { return LookupOrCreate(type, &mul_func_, [this, type] { const string type_string = DataTypeString(type); VLOG(1) << "Building Mul() for " << type_string; - xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">"); + xla::XlaBuilder b("mul<" + type_string + ">"); xla::PrimitiveType xla_type; TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); @@ -179,9 +180,9 @@ const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { }); } -const xla::Computation* XlaContext::LookupOrCreate( +const xla::XlaComputation* XlaContext::LookupOrCreate( DataType type, ComputationMap* out, - const std::function& create) { + const std::function& create) { { const auto& entry = (*out)[type]; if (!entry.IsNull()) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 00fbaba37c5..341bf6ff1f3 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -42,32 +42,44 @@ class XlaContext : public ResourceBase { static XlaContext& Get(const OpKernelContext* ctx); static XlaContext& Get(const XlaOpKernelContext* ctx); - // Creates a new XlaContext. - XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, + // Creates a new XlaContext. See the documentation on the class data fields + // for descriptions of the arguments. + XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn); + shape_representation_fn); // Virtual method defined by ResourceBase. string DebugString() override; XlaCompiler* compiler() const { return compiler_; } - // Returns the ComputationBuilder that Ops use for compiling new - // expressions. - xla::ComputationBuilder* builder(); + // Returns the XlaBuilder that Ops use for compiling new expressions. 
+ xla::XlaBuilder* builder(); bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } + bool resolve_compile_time_constants() const { + return resolve_compile_time_constants_; + } + bool is_entry_computation() const { return is_entry_computation_; } + const std::vector& args() const { return args_; } void set_args(std::vector args); - const std::vector& retvals() { return retvals_; } + struct Retval { + DataType type; + TensorShape shape; + // An XlaExpression representing the Retval's value. + XlaExpression expression; + }; + const std::vector& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. - void AddRetval(int retval_index, DataType type, - const xla::ComputationDataHandle& handle); + void AddRetval(int retval_index, DataType type, const TensorShape& shape, + const xla::XlaOp& handle); // As for Retval, but for return values that are compile-time constants. Status AddConstRetval(int retval_index, DataType dtype, @@ -79,8 +91,7 @@ class XlaContext : public ResourceBase { // Fails if the resource already exists. Status CreateResource(XlaResource::Kind kind, int arg_num, string name, DataType type, TensorShape shape, - const xla::ComputationDataHandle& handle, - int64 tensor_array_size, + const xla::XlaOp& handle, int64 tensor_array_size, const std::set& tensor_array_gradients, XlaResource** resource); @@ -89,29 +100,29 @@ class XlaContext : public ResourceBase { } // Returns the XLA shape to be used to represent a variable of TF `shape` - // and `type`. - TensorShape VariableRepresentationShape(const TensorShape& shape, - DataType type) const; + // and `type`, or of an argument or return value of a top-level computation. + TensorShape RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMax(const DataType type); + const xla::XlaComputation* GetOrCreateMax(const DataType type); // Get an XLA lambda to compute Min. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMin(const DataType type); + const xla::XlaComputation* GetOrCreateMin(const DataType type); // Get an XLA lambda to compute Add. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateAdd(const DataType type); + const xla::XlaComputation* GetOrCreateAdd(const DataType type); // Get an XLA lambda to compute Mul. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMul(const DataType type); + const xla::XlaComputation* GetOrCreateMul(const DataType type); // The name of the XlaContext resource during symbolic graph execution. static const char kXlaContextResourceName[]; @@ -119,9 +130,8 @@ class XlaContext : public ResourceBase { private: XlaCompiler* const compiler_; - // The ComputationBuilder used to construct the subgraph's compiled - // representation. - xla::ComputationBuilder* builder_; + // The XlaBuilder used to construct the subgraph's compiled representation. 
+ xla::XlaBuilder* builder_; // Allow ops to emit CustomCall operations for CPU. const bool allow_cpu_custom_calls_; @@ -135,25 +145,33 @@ class XlaContext : public ResourceBase { std::vector args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector retvals_; + std::vector retvals_; // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // A function that describes how variable shapes should be represented - // in XLA. Variable values will be reshaped to this shape. Must be non-null. + // Is this a top-level computation, or an inner computation (e.g., a while + // body)? + const bool is_entry_computation_; + + // A function that describes how the shapes of + // a) argument and return value, for entry computations + // b) variables, for all computations, + // should be represented in XLA. Parameters/return values will be shaped + // according to this function, and reshaped back to/from their declared shapes + // for computations. Must be non-null. const std::function* - variable_representation_shape_fn_; + shape_representation_fn_; // Cache of prebuilt computations indexed by their type. - using ComputationMap = std::map; + using ComputationMap = std::map; // Finds the value for the given type in out map if it already // exists or makes a new value with create function and keeps it the // map. The returned value != nullptr and is owned by the map. - const xla::Computation* LookupOrCreate( + const xla::XlaComputation* LookupOrCreate( DataType type, ComputationMap* out, - const std::function& create); + const std::function& create); // Cached computation to compute Max of two elements, specialized by type. ComputationMap max_func_; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 62a5114837e..f1594193af0 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" @@ -32,13 +32,12 @@ namespace tensorflow { namespace { -Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, bool is_min, - xla::ComputationDataHandle* argminmax) { - xla::ComputationDataHandle init_value; - const xla::Computation* reducer; +Status ArgMinMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, + DataType input_type, DataType output_type, int axis, + bool is_min, xla::XlaOp* argminmax) { + xla::XlaOp init_value; + const xla::XlaComputation* reducer; if (is_min) { init_value = XlaHelpers::MaxValue(builder, input_type); reducer = ctx->GetOrCreateMin(input_type); @@ -50,13 +49,13 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, xla::PrimitiveType xla_output_type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type)); - xla::ComputationDataHandle input_max = builder->Reduce( - input, init_value, *reducer, /*dimensions_to_reduce=*/{axis}); + xla::XlaOp input_max = builder->Reduce(input, init_value, *reducer, + /*dimensions_to_reduce=*/{axis}); std::vector broadcast_dims(input_shape.dims() - 1); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); // Compute a mask that has 1s for elements equal to the maximum. - xla::ComputationDataHandle partial_mask = builder->ConvertElementType( + xla::XlaOp partial_mask = builder->ConvertElementType( builder->Eq(input, input_max, broadcast_dims), xla_output_type); // In order to make identity elements for a bitwise And, we: @@ -65,23 +64,23 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, // 0xFF...F int32 bits_in_type = xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1; - xla::ComputationDataHandle shift_amount = + xla::XlaOp shift_amount = XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type); - xla::ComputationDataHandle full_mask = builder->ShiftRightArithmetic( + xla::XlaOp full_mask = builder->ShiftRightArithmetic( builder->ShiftLeft(partial_mask, shift_amount), shift_amount); // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its // index. - xla::ComputationDataHandle iota; + xla::XlaOp iota; const int64 axis_size = input_shape.dim_size(axis); TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); - xla::ComputationDataHandle product = + xla::XlaOp product = builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); // If there are multiple maximum elements, choose the one with the highest // index. 
- xla::ComputationDataHandle output = + xla::XlaOp output = builder->Reduce(product, XlaHelpers::MinValue(builder, output_type), *ctx->GetOrCreateMax(output_type), /*dimensions_to_reduce=*/{axis}); @@ -91,36 +90,31 @@ Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, } // namespace -xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::MinValue(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::MinValue(type)); } -xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::MaxValue(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::MaxValue(type)); } -xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::Zero(type)); } -xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return b->ConstantLiteral(xla::Literal::One(type)); } -xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, - DataType data_type) { +xla::XlaOp XlaHelpers::Epsilon(xla::XlaBuilder* b, DataType data_type) { switch (data_type) { case DT_HALF: return b->ConstantR0( @@ -137,16 +131,15 @@ xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, } } -xla::ComputationDataHandle XlaHelpers::IntegerLiteral( - xla::ComputationBuilder* b, DataType data_type, int64 value) { +xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type, + int64 value) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return ::tensorflow::IntegerLiteral(b, type, value); } -xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, - DataType data_type, - double value) { +xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, + double value) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); return ::tensorflow::FloatLiteral(b, type, value); @@ -183,28 +176,24 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { return linspace; } -Status XlaHelpers::ArgMax(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, +Status XlaHelpers::ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - xla::ComputationDataHandle* argmax) { + DataType output_type, int axis, xla::XlaOp* argmax) { return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, axis, /*is_min=*/false, argmax); } -Status XlaHelpers::ArgMin(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, +Status XlaHelpers::ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - 
xla::ComputationDataHandle* argmin) { + DataType output_type, int axis, xla::XlaOp* argmin) { return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, axis, /*is_min=*/true, argmin); } -Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype, - int64 size, xla::ComputationDataHandle* iota) { +Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, + xla::XlaOp* iota) { TensorShape linspace_shape({size}); Tensor linspace; switch (dtype) { @@ -227,13 +216,10 @@ Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype, return Status::OK(); } -Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, - int axis, DataType index_type, - const TensorShape& indices_shape, - const xla::ComputationDataHandle& indices, - const xla::ComputationDataHandle& on_value, - const xla::ComputationDataHandle& off_value, - xla::ComputationDataHandle* one_hot) { +Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, + DataType index_type, const TensorShape& indices_shape, + const xla::XlaOp& indices, const xla::XlaOp& on_value, + const xla::XlaOp& off_value, xla::XlaOp* one_hot) { const int indices_dims = indices_shape.dims(); const int output_dims = indices_dims + 1; @@ -267,7 +253,7 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, std::vector broadcast_dims(indices_shape.dims()); std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - xla::ComputationDataHandle one_hot_bool = builder->Eq( + xla::XlaOp one_hot_bool = builder->Eq( indices, builder->ConstantLiteral(linspace_literal), broadcast_dims); // Selects the user-provided off_value and on_value values. @@ -278,16 +264,15 @@ Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, } DataType XlaHelpers::SumAccumulationType(const DataType& dtype) { - if (dtype == DT_BFLOAT16) { + if (dtype == DT_BFLOAT16 || dtype == DT_HALF) { return DT_FLOAT; } return dtype; } -xla::ComputationDataHandle XlaHelpers::ConvertElementType( - xla::ComputationBuilder* const builder, - const xla::ComputationDataHandle& operand, - const DataType new_element_type) { +xla::XlaOp XlaHelpers::ConvertElementType(xla::XlaBuilder* const builder, + const xla::XlaOp& operand, + const DataType new_element_type) { xla::PrimitiveType convert_to; TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to)); return builder->ConvertElementType(operand, convert_to); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 68ab93b64a5..c3fdc5252e7 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -19,7 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -30,41 +30,34 @@ class XlaHelpers { public: // Returns a handle representing the minimum value of a scalar // element of data_type. - static xla::ComputationDataHandle MinValue(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp MinValue(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the maximum value of a scalar // element of data_type. 
- static xla::ComputationDataHandle MaxValue(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp MaxValue(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the zero value of a scalar // element of data_type. - static xla::ComputationDataHandle Zero(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp Zero(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the one value of a scalar // element of data_type. - static xla::ComputationDataHandle One(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp One(xla::XlaBuilder* b, DataType data_type); // Returns the machine epsilon for floating-point type `data_type`, i.e., // the difference between 1.0 and the next representable value. - static xla::ComputationDataHandle Epsilon(xla::ComputationBuilder* b, - DataType data_type); + static xla::XlaOp Epsilon(xla::XlaBuilder* b, DataType data_type); // Returns a handle representing the given value of an integer scalar // element of data_type. // Note that unlike One and Zero, does not work on boolean types. - static xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* b, - DataType data_type, - int64 value); + static xla::XlaOp IntegerLiteral(xla::XlaBuilder* b, DataType data_type, + int64 value); // Returns a handle representing the given value of a floating-point scalar // element of data_type. - static xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* b, - DataType data_type, - double value); + static xla::XlaOp FloatLiteral(xla::XlaBuilder* b, DataType data_type, + double value); // Reshapes literal 'input' to have 'shape'. Both the original shape and // 'shape' must contain the same number of elements. @@ -75,38 +68,32 @@ class XlaHelpers { // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and // `input_dtype` are the shape and dtype of `input` respectively, and // `output_type` is the dtype to use for `argmax`. - static Status ArgMax(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - xla::ComputationDataHandle* argmax); + static Status ArgMax(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, + DataType input_type, DataType output_type, int axis, + xla::XlaOp* argmax); // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and // `input_dtype` are the shape and dtype of `input` respectively, and // `output_type` is the dtype to use for `argmin`. - static Status ArgMin(xla::ComputationBuilder* builder, - XlaOpKernelContext* ctx, - const xla::ComputationDataHandle& input, - const TensorShape& input_shape, DataType input_type, - DataType output_type, int axis, - xla::ComputationDataHandle* argmin); + static Status ArgMin(xla::XlaBuilder* builder, XlaOpKernelContext* ctx, + const xla::XlaOp& input, const TensorShape& input_shape, + DataType input_type, DataType output_type, int axis, + xla::XlaOp* argmin); // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`. - static Status Iota(xla::ComputationBuilder* builder, DataType dtype, - int64 size, xla::ComputationDataHandle* iota); + static Status Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, + xla::XlaOp* iota); // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. 
`axis` is the position at which to add the new // axis. `indices_shape` is the shape of `indices`. `on_value` and // `off_value` represent the values to use for the on and off positions, // respectively. - static Status OneHot(xla::ComputationBuilder* builder, int64 depth, int axis, + static Status OneHot(xla::XlaBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, - const xla::ComputationDataHandle& indices, - const xla::ComputationDataHandle& on_value, - const xla::ComputationDataHandle& off_value, - xla::ComputationDataHandle* one_hot); + const xla::XlaOp& indices, const xla::XlaOp& on_value, + const xla::XlaOp& off_value, xla::XlaOp* one_hot); // Certain DataTypes should use increased precision DataTypes when performing // reductions. This function remaps a given DataType to a higher precision @@ -115,10 +102,9 @@ class XlaHelpers { // A helper for creating a ConvertElementType xla op given a DataType rather // than the xla::PrimitiveType. - static xla::ComputationDataHandle ConvertElementType( - xla::ComputationBuilder* const builder, - const xla::ComputationDataHandle& operand, - const DataType new_element_type); + static xla::XlaOp ConvertElementType(xla::XlaBuilder* const builder, + const xla::XlaOp& operand, + const DataType new_element_type); }; } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 1fe6e69ff2d..9e17756b277 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -112,10 +112,10 @@ void CollectNames(const T& entries, std::vector* nonempty_names, XlaJitCompiledCpuFunction::Compile( const GraphDef& graph_def, const tf2xla::Config& config, const xla::ExecutableBuildOptions& build_options) { - // Convert the graph_def into an xla::Computation. + // Convert the graph_def into an xla::XlaComputation. 
TF_ASSIGN_OR_RETURN(xla::LocalClient * client, xla::ClientLibrary::GetOrCreateLocalClient()); - xla::Computation computation; + xla::XlaComputation computation; TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(graph_def, config, client, &computation)); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index c4bb90d5875..76c68d81af4 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -30,7 +30,7 @@ bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { return context_->ValidateInputsAreSameShape(op); } -xla::ComputationBuilder* XlaOpKernelContext::builder() const { +xla::XlaBuilder* XlaOpKernelContext::builder() const { return XlaContext::Get(this).builder(); } @@ -38,9 +38,9 @@ xla::ComputationBuilder* XlaOpKernelContext::builder() const { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().handle() != 0 || + CHECK(expression->handle().builder() != nullptr || expression->resource() != nullptr); - VLOG(1) << "Fetched T" << expression->handle().handle(); + VLOG(1) << "Fetched T" << expression->handle(); return expression; } @@ -48,20 +48,18 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); - CHECK_EQ(expression->handle().handle(), 0); + CHECK_EQ(expression->handle().builder(), nullptr); return const_cast(expression); } -// Retrieves the ComputationDataHandle from an input Tensor to an Op. This -// computation was constructed by an Op that executed previously and -// created the output Tensor using CreateOutputTensorFromComputation -// or CreateConstantOutputTensor. -static const xla::ComputationDataHandle& GetComputationFromTensor( - const Tensor& tensor) { +// Retrieves the XlaOp from an input Tensor to an Op. This computation was +// constructed by an Op that executed previously and created the output Tensor +// using CreateOutputTensorFromComputation or CreateConstantOutputTensor. +static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) { return CastExpressionFromTensor(tensor)->handle(); } -const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) { +const xla::XlaOp& XlaOpKernelContext::Input(int index) { return GetComputationFromTensor(context_->input(index)); } @@ -106,7 +104,7 @@ Status XlaOpKernelContext::ConstantInputReshaped( return HostTensorToLiteral(temp, constant_literal); } - xla::ComputationDataHandle handle = expression->handle(); + xla::XlaOp handle = expression->handle(); if (new_shape != tensor.shape()) { // Reshape the handle to the desired shape. handle = builder()->Reshape(handle, new_shape.dim_sizes()); @@ -141,8 +139,17 @@ Status XlaOpKernelContext::ConstantInputReshaped( } // Ask the XLA compiler to evaluate the data handle to a literal. 
+ xla::StatusOr constant_graph = + builder()->BuildConstantSubGraph(handle); + if (!constant_graph.ok()) { + return errors::Internal( + "Error getting a compile-time constant graph for ", + context_->op_kernel().name(), " input ", index, + ".\nError: ", constant_graph.status().error_message()); + } xla::StatusOr> computed = - builder()->ComputeConstant(handle, &layout); + compiler()->client()->ComputeConstant(constant_graph.ValueOrDie(), + &layout); if (!computed.ok()) { return errors::Internal("Error evaluating ", context_->op_kernel().name(), " input ", index, @@ -260,9 +267,9 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { return Status::OK(); } -Status XlaOpKernelContext::InputList( - StringPiece name, std::vector* handles, - std::vector* shapes) { +Status XlaOpKernelContext::InputList(StringPiece name, + std::vector* handles, + std::vector* shapes) { OpInputList inputs; TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); handles->clear(); @@ -285,9 +292,9 @@ Status XlaOpKernelContext::ConstantInputList( return Status::OK(); } -Status XlaOpKernelContext::ReadVariableInput( - int index, DataType type, TensorShape* shape, - xla::ComputationDataHandle* value) { +Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, + TensorShape* shape, + xla::XlaOp* value) { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); @@ -307,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput( } XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = xla_context.VariableRepresentationShape( - variable->shape(), variable->type()); + TensorShape representation_shape = + xla_context.RepresentationShape(variable->shape(), variable->type()); if (representation_shape == variable->shape()) { *value = variable->value(); } else { @@ -334,8 +341,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } -void XlaOpKernelContext::SetOutput(int index, - const xla::ComputationDataHandle& handle) { +void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { // Makes the host Tensor that will refer to the expression. Tensor* output = nullptr; auto shape = builder()->GetShape(handle); @@ -349,7 +355,7 @@ void XlaOpKernelContext::SetOutput(int index, // corresponds. TensorShape tensor_shape; OP_REQUIRES_OK(context_, - XLAShapeToTensorShape(*shape.ValueOrDie(), &tensor_shape)); + XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape)); OP_REQUIRES_OK(context_, context_->allocate_output(index, tensor_shape, &output)); @@ -364,8 +370,8 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { xla::Literal literal; OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal)); - xla::ComputationDataHandle handle = builder()->ConstantLiteral(literal); - CHECK_NE(handle.handle(), 0); + xla::XlaOp handle = builder()->ConstantLiteral(literal); + CHECK_NE(handle.builder(), nullptr); // Make the Tensor that will refer to the expression. 
Tensor* output = nullptr; @@ -386,8 +392,7 @@ void XlaOpKernelContext::SetInvalidOutput(int index) { OP_REQUIRES_OK(context_, context_->allocate_output(index, TensorShape({}), &output)); XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - xla::ComputationDataHandle handle; - handle.set_handle(0); + xla::XlaOp handle; expression->set_handle(handle); } @@ -410,8 +415,8 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { } Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, - xla::ComputationDataHandle handle) { - TF_RET_CHECK(handle.handle() != 0); + xla::XlaOp handle) { + TF_RET_CHECK(handle.builder() != nullptr); const XlaExpression* expression = CastExpressionFromTensor(context_->input(input_index)); @@ -425,13 +430,13 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, } TensorShape shape; TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + XLAShapeToTensorShape(shape_or_status.ValueOrDie(), &shape)); TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); XlaContext& xla_context = XlaContext::Get(context_); TensorShape representation_shape = - xla_context.VariableRepresentationShape(shape, type); + xla_context.RepresentationShape(shape, type); if (shape != representation_shape) { handle = builder()->Reshape(handle, representation_shape.dim_sizes()); } @@ -457,22 +462,22 @@ void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line, context_->CtxFailureWithWarning(file, line, s); } -const xla::Computation* XlaOpKernelContext::GetOrCreateMax( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax( const DataType type) { return XlaContext::Get(context_).GetOrCreateMax(type); } -const xla::Computation* XlaOpKernelContext::GetOrCreateMin( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin( const DataType type) { return XlaContext::Get(context_).GetOrCreateMin(type); } -const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd( const DataType type) { return XlaContext::Get(context_).GetOrCreateAdd(type); } -const xla::Computation* XlaOpKernelContext::GetOrCreateMul( +const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( const DataType type) { return XlaContext::Get(context_).GetOrCreateMul(type); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 4e4b97e0cec..667dc262ca0 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/macros.h" @@ -58,8 +58,8 @@ class XlaOpKernelContext { public: explicit XlaOpKernelContext(OpKernelContext* context); - // Returns the XLA ComputationBuilder containing the output of compilation. - xla::ComputationBuilder* builder() const; + // Returns the XLA XlaBuilder containing the output of compilation. + xla::XlaBuilder* builder() const; // Inputs @@ -72,10 +72,10 @@ class XlaOpKernelContext { // Returns the shape of input 'index'. TensorShape InputShape(int index); - // Returns input 'index' as a ComputationDataHandle. 
Unlike + // Returns input 'index' as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. - const xla::ComputationDataHandle& Input(int index); + const xla::XlaOp& Input(int index); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -85,8 +85,7 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status InputList(StringPiece name, - std::vector* handles, + Status InputList(StringPiece name, std::vector* handles, std::vector* shapes); // Helper methods for constant inputs. @@ -132,10 +131,10 @@ class XlaOpKernelContext { return context_->expected_output_dtype(index); } - // Sets output 'index' to the ComputationDataHandle 'handle'. + // Sets output 'index' to the XlaOp 'handle'. // All outputs should be set using SetOutput and SetConstantOutput, not // via the underlying OpKernelContext. - void SetOutput(int index, const xla::ComputationDataHandle& handle); + void SetOutput(int index, const xla::XlaOp& handle); // Sets output 'index' to compile-time constant 'host_tensor', where // 'host_tensor' is a tensor in host memory. It is preferable to use @@ -168,14 +167,13 @@ class XlaOpKernelContext { // variable. Returns an error if the variable has not been initialized, or if // its type does not match `type`. Status ReadVariableInput(int index, DataType type, TensorShape* shape, - xla::ComputationDataHandle* value); + xla::XlaOp* value); // Assigns the value `handle` to the variable referenced by input // `input_index`. The variable must be of `type`. Returns an error if the // variable has been initialized with a different type or with a // different shape. - Status AssignVariable(int input_index, DataType type, - xla::ComputationDataHandle handle); + Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); @@ -205,22 +203,22 @@ class XlaOpKernelContext { // Gets an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMax(const DataType type); + const xla::XlaComputation* GetOrCreateMax(const DataType type); // Gets an XLA lambda to compute Min. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateMin(const DataType type); + const xla::XlaComputation* GetOrCreateMin(const DataType type); // Gets an XLA lambda to compute Add. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. - const xla::Computation* GetOrCreateAdd(const DataType type); + const xla::XlaComputation* GetOrCreateAdd(const DataType type); // Gets an XLA lambda to compute Mul. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. 
- const xla::Computation* GetOrCreateMul(const DataType type); + const xla::XlaComputation* GetOrCreateMul(const DataType type); private: OpKernelContext* const context_; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index bbe808595d9..4692038b61f 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -39,10 +39,10 @@ const char* const DEVICE_XLA_GPU = "XLA_GPU"; static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { const OpDef* op_def; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def)); + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def)); NodeDef node_def; node_def.set_name("_XlaLaunch-op"); - node_def.set_op("_XlaLaunch"); + node_def.set_op("XlaLaunch"); string kernel_class_name; TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr, &kernel_class_name)); @@ -311,7 +311,7 @@ XlaOpRegistry& XlaOpRegistry::Instance() { XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { registration_.reset(new XlaOpRegistry::OpRegistration); - registration_->name = name.ToString(); + registration_->name = std::string(name); } XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { @@ -323,14 +323,14 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( gtl::ArraySlice devices) { registration_->has_device_whitelist = true; for (StringPiece device : devices) { - registration_->device_whitelist.insert(device.ToString()); + registration_->device_whitelist.insert(std::string(device)); } return *this; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { registration_->has_device_whitelist = true; - registration_->device_whitelist.insert(device.ToString()); + registration_->device_whitelist.insert(std::string(device)); return *this; } @@ -347,7 +347,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( StringPiece attr_name, DataType allowed) { std::set& types = - registration_->type_constraints[attr_name.ToString()]; + registration_->type_constraints[std::string(attr_name)]; types.insert(allowed); return *this; } @@ -355,7 +355,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( StringPiece attr_name, gtl::ArraySlice allowed) { std::set& types = - registration_->type_constraints[attr_name.ToString()]; + registration_->type_constraints[std::string(attr_name)]; for (DataType t : allowed) { types.insert(t); } @@ -364,7 +364,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( StringPiece input_name) { - registration_->compile_time_constant_inputs.insert(input_name.ToString()); + registration_->compile_time_constant_inputs.insert(std::string(input_name)); return *this; } @@ -394,7 +394,7 @@ XlaBackendRegistrar::XlaBackendRegistrar( StringPiece name, gtl::ArraySlice types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); - registry.RegisterBackend(name.ToString(), types, op_filter); + registry.RegisterBackend(std::string(name), types, op_filter); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index c2075b44b82..540c65c597f 
100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -26,8 +26,7 @@ limitations under the License. namespace tensorflow { XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, - TensorShape shape, - const xla::ComputationDataHandle& initial_value, + TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, const std::set& tensor_array_gradients) : kind_(kind), @@ -41,11 +40,10 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, CHECK(kind_ != kInvalid); for (const string& gradient : tensor_array_gradients) { - tensor_array_gradients_[gradient].reset( - new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, - /*name=*/strings::StrCat("TensorArrayGrad: ", name_), - type_, shape_, xla::ComputationDataHandle(), - tensor_array_size_, /*tensor_array_gradients=*/{})); + tensor_array_gradients_[gradient].reset(new XlaResource( + /*kind=*/kTensorArray, /*arg_num=*/-1, + /*name=*/strings::StrCat("TensorArrayGrad: ", name_), type_, shape_, + xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{})); } } @@ -73,7 +71,7 @@ Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { return Status::OK(); } -Status XlaResource::SetValue(const xla::ComputationDataHandle& value) { +Status XlaResource::SetValue(const xla::XlaOp& value) { if (type_ == DT_INVALID) { return errors::InvalidArgument( "Resource '", name_, @@ -83,7 +81,7 @@ Status XlaResource::SetValue(const xla::ComputationDataHandle& value) { return Status::OK(); } -Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) { +Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { if (type_ == DT_INVALID) { return errors::InvalidArgument( "Resource '", name_, @@ -121,9 +119,9 @@ Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) { return Status::OK(); } -Status XlaResource::GetOrCreateTensorArrayGradient( - const string& source, xla::ComputationBuilder* builder, - XlaResource** gradient_out) { +Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, + xla::XlaBuilder* builder, + XlaResource** gradient_out) { VLOG(2) << "Gradient lookup for resource: " << name_ << " gradient: " << source; TF_RET_CHECK(kind_ == kTensorArray); @@ -132,7 +130,7 @@ Status XlaResource::GetOrCreateTensorArrayGradient( TensorShape ta_shape; ta_shape.AddDim(tensor_array_size_); ta_shape.AppendShape(shape_); - xla::ComputationDataHandle gradient_value = builder->Broadcast( + xla::XlaOp gradient_value = builder->Broadcast( XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, @@ -144,13 +142,12 @@ Status XlaResource::GetOrCreateTensorArrayGradient( return Status::OK(); } -Status XlaResource::Pack(xla::ComputationDataHandle* pack, - xla::ComputationBuilder* builder) const { +Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { if (tensor_array_gradients_.empty()) { *pack = value_; } else { TF_RET_CHECK(kind_ == kTensorArray); - std::vector elems; + std::vector elems; elems.push_back(value_); for (const auto& gradient : tensor_array_gradients_) { elems.push_back(gradient.second->value_); @@ -161,8 +158,8 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack, } Status XlaResource::SetFromPack(const std::set& gradient_sources, - const xla::ComputationDataHandle& pack, - xla::ComputationBuilder* builder) { + const xla::XlaOp& pack, + xla::XlaBuilder* builder) { if 
(gradient_sources.empty()) { if (!initialized()) { initial_value_ = pack; diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 1bb2c7274ec..9ce36d1aa76 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" @@ -37,8 +37,7 @@ class XlaResource { }; XlaResource(Kind kind, int arg_num, string name, DataType type, - TensorShape shape, - const xla::ComputationDataHandle& initial_value, + TensorShape shape, const xla::XlaOp& initial_value, int64 tensor_array_size, const std::set& tensor_array_gradients); @@ -69,16 +68,14 @@ class XlaResource { // this is the shape of each entry in the TensorArray/Stack. const TensorShape& shape() const { return shape_; } - const xla::ComputationDataHandle& value() const { return value_; } + const xla::XlaOp& value() const { return value_; } // Value of the resource at computation entry. Used to detect which // variables have new values that need to be written back. - const xla::ComputationDataHandle& initial_value() const { - return initial_value_; - } + const xla::XlaOp& initial_value() const { return initial_value_; } // A variable is initialized if it has a value. - bool initialized() const { return value_.handle() > 0; } + bool initialized() const { return value_.builder() != nullptr; } // Sets the type and shape of the resource. The type and shape of a resource // must not change once the variable has been initialized. @@ -86,17 +83,17 @@ class XlaResource { // Sets the current value of the resource. Returns an error if the type is not // set to a valid value. - Status SetValue(const xla::ComputationDataHandle& value); + Status SetValue(const xla::XlaOp& value); // Sets the current value of the resource to an all-zero value. - Status SetZeroValue(xla::ComputationBuilder* builder); + Status SetZeroValue(xla::XlaBuilder* builder); // Looks up the gradient for `source`, or creates it if it does not already // exist. The call target must be an initialized TensorArray resource. A // TensorArray can have multiple named gradients; see the operator // documentation for TensorArrayGradV3 for details. Status GetOrCreateTensorArrayGradient(const string& source, - xla::ComputationBuilder* builder, + xla::XlaBuilder* builder, XlaResource** gradient_out); // Packs a resource into a single XLA value `pack`, suitable for use as @@ -104,8 +101,7 @@ class XlaResource { // gradients, sets `*pack` to `value`. // For TensorArrays with gradients, packs the value and its gradient values in // a tuple; the gradients values are packed in order by source name. - Status Pack(xla::ComputationDataHandle* pack, - xla::ComputationBuilder* builder) const; + Status Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const; // Updates the resource with values from `pack`. If `gradient_sources` is // non-empty, treats `pack` as a tuple that represents a TensorArray and @@ -114,8 +110,7 @@ class XlaResource { // values. // Opposite of Pack(). 
Status SetFromPack(const std::set& gradient_sources, - const xla::ComputationDataHandle& pack, - xla::ComputationBuilder* builder); + const xla::XlaOp& pack, xla::XlaBuilder* builder); // TensorArray and Stack specific fields @@ -144,8 +139,8 @@ class XlaResource { DataType type_; TensorShape shape_; - xla::ComputationDataHandle value_; - xla::ComputationDataHandle initial_value_; + xla::XlaOp value_; + xla::XlaOp initial_value_; int64 tensor_array_size_ = -1; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 751777222fc..fb1991e9ec2 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -99,8 +99,9 @@ cc_library( hdrs = ["service_interface.h"], visibility = [":friends"], deps = [ + ":status", + ":xla_data_proto", ":xla_proto", - "//tensorflow/core:lib", ], ) @@ -244,6 +245,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protobuf_util", + ":status", ":status_macros", ":statusor", ":types", @@ -302,13 +304,13 @@ cc_library( ":array2d", ":array3d", ":array4d", - ":shape_tree", ":shape_util", ":sparse_index_array", ":status_macros", ":types", ":util", ":xla_data_proto", + "//tensorflow/core:framework", "//tensorflow/core:lib", ], ) @@ -323,12 +325,30 @@ tf_cc_test( ":shape_util", ":test", ":types", + "//tensorflow/compiler/tf2xla:common", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], ) +cc_library( + name = "error_spec", + hdrs = ["error_spec.h"], +) + +cc_library( + name = "literal_comparison", + srcs = ["literal_comparison.cc"], + hdrs = ["literal_comparison.h"], + deps = [ + ":error_spec", + ":literal_util", + ":util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "metric_table_report", srcs = ["metric_table_report.cc"], @@ -443,6 +463,9 @@ cc_library( srcs = ["executable_run_options.cc"], hdrs = ["executable_run_options.h"], visibility = ["//visibility:public"], + deps = [ + ":types", + ], ) cc_library( @@ -602,8 +625,8 @@ cc_library( ":util", ":window_util", ":xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", diff --git a/tensorflow/compiler/xla/README.md b/tensorflow/compiler/xla/README.md index c93c39e1806..39f8caaa961 100644 --- a/tensorflow/compiler/xla/README.md +++ b/tensorflow/compiler/xla/README.md @@ -1 +1,7 @@ -This is the home of XLA. +
+ +XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear +algebra that optimizes TensorFlow computations. See the +[documentation](https://www.tensorflow.org/performance/xla/) for more details. diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index a299c2afd45..aacb394ae5f 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -63,7 +63,6 @@ cc_library( srcs = ["client.cc"], hdrs = ["client.h"], deps = [ - ":computation", ":global_data", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", @@ -76,7 +75,7 @@ cc_library( "//tensorflow/compiler/xla:xla_proto", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -99,13 +98,13 @@ cc_library( hdrs = ["local_client.h"], deps = [ ":client", - ":computation", ":executable_build_options", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", @@ -125,11 +124,11 @@ cc_library( hdrs = ["compile_only_client.h"], deps = [ ":client", - ":computation", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/core:stream_executor_no_cuda", @@ -160,47 +159,6 @@ cc_library( ], ) -cc_library( - name = "computation", - srcs = ["computation.cc"], - hdrs = ["computation.h"], - deps = [ - "//tensorflow/compiler/xla:service_interface", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/service:session_proto", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "computation_builder", - srcs = ["computation_builder.cc"], - hdrs = ["computation_builder.h"], - deps = [ - ":client", - ":computation", - ":global_data", - ":padding", - "//tensorflow/compiler/xla:array", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "sharding_builder", srcs = ["sharding_builder.cc"], diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index f0f94298a05..c9d275a77b5 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -161,22 +161,6 @@ Status Client::ResetDevice() { 
return Status::OK(); } -StatusOr> Client::ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr data, - Execute(computation, arguments, execution_options, execution_profile)); - - const Shape* shape_with_output_layout = nullptr; - if (execution_options && execution_options->has_shape_with_output_layout()) { - shape_with_output_layout = &execution_options->shape_with_output_layout(); - } - return Transfer(*data, shape_with_output_layout); -} - StatusOr> Client::ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -221,58 +205,9 @@ StatusOr> Client::ComputeConstant( return Literal::CreateFromProto(response.literal()); } -StatusOr Client::LoadSnapshot(const SessionModule& module) { - LoadComputationSnapshotRequest request; - *request.mutable_module() = module; - LoadComputationSnapshotResponse response; - - Status s = stub_->LoadComputationSnapshot(&request, &response); - if (!s.ok()) { - return s; - } - - VLOG(1) << "load snapshot response: " << response.ShortDebugString(); - return Computation(stub_, response.computation()); -} - -StatusOr> Client::Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options, - ExecutionProfile* execution_profile) { - ExecuteRequest request; - *request.mutable_computation() = computation.handle(); - - if (execution_options == nullptr) { - *request.mutable_execution_options() = CreateDefaultExecutionOptions(); - } else { - *request.mutable_execution_options() = *execution_options; - } - for (GlobalData* argument : arguments) { - CHECK(argument != nullptr) << "Argument pointers must not be null."; - *request.add_arguments() = argument->handle(); - } - - ExecuteResponse response; - VLOG(1) << "making execute request: " << request.ShortDebugString(); - Status s = stub_->Execute(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - if (execution_profile != nullptr) { - *execution_profile = response.profile(); - if (VLOG_IS_ON(1)) { - TF_ASSIGN_OR_RETURN( - auto execution_stats, - ExecutionStatsAsString(computation, response.profile())); - VLOG(1) << execution_stats; - } - } - - return MakeUnique(stub_, response.output()); +StatusOr Client::LoadSnapshot(const HloSnapshot& module) { + TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module()); + return XlaComputation(module.hlo().hlo_module()); } StatusOr> Client::Execute( @@ -315,41 +250,6 @@ StatusOr> Client::Execute( return MakeUnique(stub_, response.output()); } -StatusOr>> Client::ExecuteParallel( - tensorflow::gtl::ArraySlice computations) { - ExecuteParallelRequest request; - - for (const ComputationInstance& computation : computations) { - ExecuteRequest single_request; - *single_request.mutable_computation() = computation.computation.handle(); - for (GlobalData* argument : computation.arguments) { - *single_request.add_arguments() = argument->handle(); - } - *single_request.mutable_execution_options() = computation.execution_options; - *request.add_requests() = single_request; - } - - ExecuteParallelResponse response; - VLOG(1) << "making execute-parallel request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->ExecuteParallel(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - std::vector> outputs; - for (size_t i = 0; i < 
computations.size(); ++i) { - outputs.push_back( - MakeUnique(stub_, response.responses(i).output())); - if (computations[i].execution_profile != nullptr) { - *computations[i].execution_profile = response.responses(i).profile(); - } - } - - return std::move(outputs); -} - StatusOr>> Client::ExecuteParallel( tensorflow::gtl::ArraySlice computations) { ExecuteGraphParallelRequest request; @@ -367,7 +267,7 @@ StatusOr>> Client::ExecuteParallel( ExecuteParallelResponse response; VLOG(1) << "making execute-graph-parallel request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->ExecuteGraphParallel(&request, &response); + Status s = stub_->ExecuteGraphParallel(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -396,7 +296,7 @@ StatusOr> Client::GetDeviceHandles( GetDeviceHandlesResponse response; VLOG(1) << "making get device request: " << request.ShortDebugString(); - tensorflow::Status s = stub_->GetDeviceHandles(&request, &response); + Status s = stub_->GetDeviceHandles(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { @@ -444,24 +344,6 @@ StatusOr>> Client::DeconstructTuple( return std::move(handles); } -StatusOr Client::GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const { - ComputationStatsRequest request; - *request.mutable_computation() = computation.handle(); - *request.mutable_debug_options() = debug_options; - ComputationStatsResponse response; - - VLOG(1) << "making computation stats request"; - Status s = stub_->GetComputationStats(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - CHECK(response.has_stats()); - return response.stats(); -} - StatusOr Client::GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const { @@ -483,23 +365,6 @@ StatusOr Client::GetComputationStats( return response.stats(); } -StatusOr> Client::GetComputationShape( - const Computation& computation) { - GetComputationShapeRequest request; - *request.mutable_computation() = computation.handle(); - GetComputationShapeResponse response; - - VLOG(1) << "making get-computation-shape request"; - Status s = stub_->GetComputationShape(&request, &response); - VLOG(1) << "done with request"; - - if (!s.ok()) { - return s; - } - - return WrapUnique(response.release_program_shape()); -} - StatusOr> Client::GetComputationShape( const XlaComputation& computation) { TF_ASSIGN_OR_RETURN(const auto& result, computation.GetProgramShape()); @@ -522,28 +387,6 @@ StatusOr Client::GetShape(const GlobalData& data) { return response.shape(); } -StatusOr Client::ExecutionStatsAsString( - const Computation& computation, const ExecutionProfile& profile) { - TF_ASSIGN_OR_RETURN( - auto computation_stats, - GetComputationStats(computation, - legacy_flags::GetDebugOptionsFromFlags())); - int64 total_flops = - computation_stats.flop_count() + computation_stats.transcendental_count(); - if (profile.compute_time_ns() > 0) { - int64 nanoseconds = profile.compute_time_ns(); - int64 cycle_count = profile.compute_cycle_count(); - double gflops = total_flops / nanoseconds; - return tensorflow::strings::StrCat( - "[Execution Statistics] flop count: ", computation_stats.flop_count(), - ", transcendental count: ", computation_stats.transcendental_count(), - ", compute execution time: ", nanoseconds, " nsec", - ", compute cycles: ", cycle_count, ", performance: ", gflops, - "gflop/s"); - } - return string("[Execution Statistics] not available."); -} - StatusOr 
Client::ExecutionStatsAsString( const XlaComputation& computation, const ExecutionProfile& profile) { TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index 14c685d94ea..d57e2536d0b 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -19,11 +19,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -52,21 +51,6 @@ class Client { // device is chosen by the service. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. - StatusOr> Execute( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and returns the global - // data that was produced from the execution. - // * If execution_options is not nullptr, these options are passed to the - // service to affect how it compiles our computation. (The pointer does not - // need to live beyond this call.) - // * If execution_profile is not nullptr then the pointed-to ExecutionProfile - // will be filled with profile data from the execution. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> Execute( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -78,34 +62,6 @@ class Client { // executed on the devices associated with the handles by partitioning the // computation based on the attached sharding attributes. Otherwise, a // device is chosen by the service. - struct ComputationInstance { - const Computation& computation; - std::vector arguments; - ExecutionOptions execution_options; - ExecutionProfile* execution_profile; - - ComputationInstance(const Computation& computation, - std::vector arguments, - ExecutionOptions execution_options, - ExecutionProfile* execution_profile) - : computation(computation), - arguments(std::move(arguments)), - execution_options(execution_options), - execution_profile(execution_profile) {} - }; - - // Executes a list ComputationInstances and returns global data produced from - // each computation. - StatusOr>> ExecuteParallel( - tensorflow::gtl::ArraySlice computations); - - // A struct to represent a computation instance to be executed. - // * If execution_options.device_handles is not empty, the computation is - // executed on the devices associated with the handles by partitioning the - // computation based on the attached sharding attributes. Otherwise, a - // device is chosen by the service. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct XlaComputationInstance { const XlaComputation& computation; std::vector arguments; @@ -125,7 +81,6 @@ class Client { // Executes a list XlaComputationInstances and returns global data produced // from each computation. // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. 
StatusOr>> ExecuteParallel( tensorflow::gtl::ArraySlice computations); @@ -177,17 +132,6 @@ class Client { // Executes the computation with the given arguments and transfers the result // to the client as a literal. Parameters are defined the same as for // Execute() and Transfer(). - StatusOr> ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions* execution_options = nullptr, - ExecutionProfile* execution_profile = nullptr); - - // Executes the computation with the given arguments and transfers the result - // to the client as a literal. Parameters are defined the same as for - // Execute() and Transfer(). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -223,12 +167,6 @@ class Client { const GlobalData& data); // Retrieves the statistics of the given computation. - StatusOr GetComputationStats( - const Computation& computation, const DebugOptions& debug_options) const; - - // Retrieves the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr GetComputationStats( const XlaComputation& computation, const DebugOptions& debug_options) const; @@ -239,13 +177,6 @@ class Client { // As above, but returns the shape of the provided computation (parameter // types/names and return type). - StatusOr> GetComputationShape( - const Computation& computation); - - // As above, but returns the shape of the provided computation (parameter - // types/names and return type). - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> GetComputationShape( const XlaComputation& computation); @@ -253,15 +184,13 @@ class Client { // two computations via a pair of Send and Recv instructions. StatusOr CreateChannelHandle(); - StatusOr LoadSnapshot(const SessionModule& module); + StatusOr LoadSnapshot(const HloSnapshot& module); ServiceInterface* stub() { return stub_; } private: // Returns the execution statistics (e.g., gflop/s) as a string from the // ExecutionProfile returned from an execution of the computation. - StatusOr ExecutionStatsAsString(const Computation& computation, - const ExecutionProfile& profile); StatusOr ExecutionStatsAsString(const XlaComputation& computation, const ExecutionProfile& profile); diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc index b1663bc8157..803a9e40094 100644 --- a/tensorflow/compiler/xla/client/client_library.cc +++ b/tensorflow/compiler/xla/client/client_library.cc @@ -23,22 +23,19 @@ limitations under the License. 
namespace xla { -LocalClientOptions::LocalClientOptions(perftools::gputools::Platform* platform, +LocalClientOptions::LocalClientOptions(se::Platform* platform, int number_of_replicas, int intra_op_parallelism_threads) : platform_(platform), number_of_replicas_(number_of_replicas), intra_op_parallelism_threads_(intra_op_parallelism_threads) {} -LocalClientOptions& LocalClientOptions::set_platform( - perftools::gputools::Platform* platform) { +LocalClientOptions& LocalClientOptions::set_platform(se::Platform* platform) { platform_ = platform; return *this; } -perftools::gputools::Platform* LocalClientOptions::platform() const { - return platform_; -} +se::Platform* LocalClientOptions::platform() const { return platform_; } LocalClientOptions& LocalClientOptions::set_number_of_replicas( int number_of_replicas) { @@ -69,7 +66,7 @@ ClientLibrary::ClientLibrary() = default; ClientLibrary::~ClientLibrary() = default; /* static */ StatusOr ClientLibrary::GetOrCreateLocalClient( - perftools::gputools::Platform* platform) { + se::Platform* platform) { LocalClientOptions default_options; default_options.set_platform(platform); return GetOrCreateLocalClient(default_options); @@ -77,7 +74,7 @@ ClientLibrary::~ClientLibrary() = default; /* static */ StatusOr ClientLibrary::GetOrCreateLocalClient( const LocalClientOptions& options) { - perftools::gputools::Platform* platform = options.platform(); + se::Platform* platform = options.platform(); int replica_count = options.number_of_replicas(); ClientLibrary& client_library = Singleton(); tensorflow::mutex_lock lock(client_library.service_mutex_); @@ -115,7 +112,7 @@ ClientLibrary::~ClientLibrary() = default; } /* static */ LocalService* ClientLibrary::GetXlaService( - perftools::gputools::Platform* platform) { + se::Platform* platform) { ClientLibrary& client_library = Singleton(); tensorflow::mutex_lock lock(client_library.service_mutex_); auto it = client_library.local_instances_.find(platform->id()); @@ -124,8 +121,7 @@ ClientLibrary::~ClientLibrary() = default; } /* static */ StatusOr -ClientLibrary::GetOrCreateCompileOnlyClient( - perftools::gputools::Platform* platform) { +ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) { ClientLibrary& client_library = Singleton(); tensorflow::mutex_lock lock(client_library.service_mutex_); diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h index a6f30d82e43..3ad558fa532 100644 --- a/tensorflow/compiler/xla/client/client_library.h +++ b/tensorflow/compiler/xla/client/client_library.h @@ -43,13 +43,13 @@ namespace xla { // Options to configure the local client when it is created. class LocalClientOptions { public: - LocalClientOptions(perftools::gputools::Platform* platform = nullptr, + LocalClientOptions(se::Platform* platform = nullptr, int number_of_replicas = 1, int intra_op_parallelism_threads = -1); // Set the platform backing the service, or nullptr for the default platform. - LocalClientOptions& set_platform(perftools::gputools::Platform* platform); - perftools::gputools::Platform* platform() const; + LocalClientOptions& set_platform(se::Platform* platform); + se::Platform* platform() const; // Set the number of replicas to use when compiling replicated // programs. 
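The client_library hunks above and below replace perftools::gputools::Platform with the shorter se::Platform spelling throughout the options and singleton accessors. A minimal usage sketch of the updated API follows; it is an illustration only, assuming se is TensorFlow's alias for the stream_executor namespace and that a null platform selects the default one, as the header comments state.

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"

// Sketch: configure LocalClientOptions against the renamed se::Platform and
// fetch (or create) the process-wide LocalClient singleton.
xla::StatusOr<xla::LocalClient*> GetLocalClient(se::Platform* platform) {
  xla::LocalClientOptions options;
  options.set_platform(platform);     // nullptr selects the default platform
  options.set_number_of_replicas(1);  // compile for a single replica
  return xla::ClientLibrary::GetOrCreateLocalClient(options);
}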
@@ -61,7 +61,7 @@ class LocalClientOptions { int intra_op_parallelism_threads() const; private: - perftools::gputools::Platform* platform_; + se::Platform* platform_; int number_of_replicas_; int intra_op_parallelism_threads_; }; @@ -74,7 +74,7 @@ class ClientLibrary { // platform : The platform the underlying XLA service should target. If // null then default platform is used. static StatusOr GetOrCreateLocalClient( - perftools::gputools::Platform* platform = nullptr); + se::Platform* platform = nullptr); static StatusOr GetOrCreateLocalClient( const LocalClientOptions& options); @@ -84,14 +84,14 @@ class ClientLibrary { // Returns the service from the service thread. Only used in unit tests to // access user computations from client. - static LocalService* GetXlaService(perftools::gputools::Platform* platform); + static LocalService* GetXlaService(se::Platform* platform); // Singleton constructor-or-accessor for compile-only clients. Arguments: // // platform : The platform the underlying XLA service should target. If // null then default platform is used. static StatusOr GetOrCreateCompileOnlyClient( - perftools::gputools::Platform* platform = nullptr); + se::Platform* platform = nullptr); // Clears the local instance and compile only instance caches. The client // pointers returned by the previous GetOrCreateLocalClient() or @@ -120,12 +120,10 @@ class ClientLibrary { }; tensorflow::mutex service_mutex_; // Guards the singleton creation state. - std::unordered_map> + std::unordered_map> local_instances_ GUARDED_BY(service_mutex_); - std::unordered_map> + std::unordered_map> compile_only_instances_ GUARDED_BY(service_mutex_); TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary); diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index 59662c95ac1..dc69d2097eb 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -23,16 +23,16 @@ namespace xla { StatusOr>> CompileOnlyClient::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& options) { - std::vector service_instances; + std::vector service_instances; service_instances.reserve(computations.size()); - for (const AotComputationInstance& instance : computations) { - service_instances.push_back({}); - CompileOnlyService::AotComputationInstance& service_instance = + for (const AotXlaComputationInstance& instance : computations) { + service_instances.emplace_back(); + CompileOnlyService::AotXlaComputationInstance& service_instance = service_instances.back(); TF_RET_CHECK(instance.computation != nullptr); - service_instance.computation = instance.computation->handle(); + service_instance.computation = instance.computation->proto(); service_instance.argument_layouts = instance.argument_layouts; service_instance.result_layout = instance.result_layout; } diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index 59000487113..f9a7c31270c 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -17,7 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPILE_ONLY_CLIENT_H_ #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/compile_only_service.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/statusor.h" @@ -37,21 +37,21 @@ class CompileOnlyClient : public Client { CompileOnlyClient(const CompileOnlyClient&) = delete; void operator=(const CompileOnlyClient&) = delete; - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - const Computation* computation; + // A description of an xla computation to compile using CompileAheadOfTime. + struct AotXlaComputationInstance { + const XlaComputation* computation; // Inform the compiler of the expected layout for arguments. std::vector argument_layouts; // Specifies the expected result layout. const Shape* result_layout; }; - // Compiles a list of computations for ahead-of-time execution. This is + // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. The |options| parameter describes // the target for which the compiler should emit code. StatusOr>> CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, + const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& options); // Returns the size of a pointer in bytes for a given triple. diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc deleted file mode 100644 index e6c57bda0f0..00000000000 --- a/tensorflow/compiler/xla/client/computation.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation.h" - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { - -Computation::Computation() : parent_(nullptr) {} - -Computation::Computation(ServiceInterface* parent, - const ComputationHandle& handle) - : handle_(handle), parent_(parent) {} - -Computation::Computation(Computation&& computation) - : handle_(std::move(computation.handle_)), parent_(computation.parent_) { - computation.ResetWithoutFreeing(); -} - -void Computation::Reset() { - // TODO(b/34469253) deallocate any owned computation. 
- ResetWithoutFreeing(); -} - -StatusOr> Computation::Snapshot() const { - SnapshotComputationRequest request; - *request.mutable_computation() = handle_; - SnapshotComputationResponse response; - - TF_RETURN_IF_ERROR(parent_->SnapshotComputation(&request, &response)); - - return WrapUnique(response.release_module()); -} - -Computation::~Computation() { Reset(); } - -Computation& Computation::operator=(Computation&& computation) { - if (&computation != this) { - Reset(); - handle_ = computation.handle_; - parent_ = computation.parent_; - computation.ResetWithoutFreeing(); - } - return *this; -} - -void Computation::ResetWithoutFreeing() { - handle_.Clear(); - parent_ = nullptr; -} - -StatusOr Computation::GetProgramShape() const { - GetComputationShapeRequest request; - *request.mutable_computation() = handle_; - GetComputationShapeResponse response; - - TF_RETURN_IF_ERROR(parent_->GetComputationShape(&request, &response)); - - return std::move(*response.mutable_program_shape()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h deleted file mode 100644 index a53fc9e9cf3..00000000000 --- a/tensorflow/compiler/xla/client/computation.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ - -#include - -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service_interface.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" - -namespace xla { - -// Wraps a ComputationHandle protobuf with a lifetime. Computation is -// movable and not copyable to capture the same kind of unique -// ownership that std::unique_ptr represents. -class Computation { - public: - // Creates a null Computation. - Computation(); - - // parent: stub for the service on which we will deallocate the computation - // when it is no longer needed. - // handle: the computation handle protobuf from the service. - Computation(ServiceInterface* parent, const ComputationHandle& handle); - - Computation(Computation&& computation); - - // Deallocates the computation. - ~Computation(); - - Computation& operator=(Computation&& computation); - - // Returns the underlying handle. - const ComputationHandle& handle() const { return handle_; } - - // Sets handle to a null state and clears any owned computation. - void Reset(); - - // Requests that we snapshot the computation into a serializable protocol - // buffer form. - StatusOr> Snapshot() const; - - // Returns true if this object is a null Computation. 
- bool IsNull() const { return parent_ == nullptr; } - - // Returns the "program shape" (parameter and return shapes) for this - // computation. - StatusOr GetProgramShape() const; - - private: - void ResetWithoutFreeing(); - - ComputationHandle handle_; // Handle that is wrapped by this class. - - // Stub that the handle is deallocated on when this object's lifetime ends. - ServiceInterface* parent_; - - TF_DISALLOW_COPY_AND_ASSIGN(Computation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc deleted file mode 100644 index 4d3b0ee0d6e..00000000000 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ /dev/null @@ -1,1569 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/client/computation_builder.h" - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace xla { - -ComputationBuilder::ComputationBuilder(Client* client, - const string& computation_name) - : name_(computation_name), client_(client) {} - -ComputationBuilder::~ComputationBuilder() {} - -void ComputationBuilder::NoteError(const Status& error) { - if (die_immediately_on_error_) { - LOG(FATAL) << "error building computation: " << error; - } - - if (first_error_.ok()) { - first_error_ = error; - first_error_backtrace_.CreateCurrent(/*skip_count=*/1); - } -} - -std::unique_ptr ComputationBuilder::CreateSubBuilder( - const string& computation_name) { - auto sub_builder = MakeUnique(client_, computation_name); - sub_builder->parent_builder_ = this; - sub_builder->die_immediately_on_error_ = die_immediately_on_error_; - return sub_builder; -} - -Status ComputationBuilder::PrepareComputation() { - TF_RETURN_IF_ERROR(first_error_); - - if (!computation_.IsNull()) { - return Status::OK(); - } - - ComputationRequest request; - request.set_name(name_); - ComputationResponse response; - - VLOG(2) << "making computation request"; - Status s = client_->stub()->Computation(&request, &response); - VLOG(2) << "done with computation request"; - - if (!s.ok()) { - NoteError(s); - return first_error_; - } - - computation_ = Computation(client_->stub(), response.computation()); - return Status::OK(); -} - -Status ComputationBuilder::RunOp(OpRequest* op_request, - OpResponse* op_response) { - TF_RETURN_IF_ERROR(first_error_); - 
TF_RETURN_IF_ERROR(PrepareComputation()); - - // Fill in fields that are set on every OpRequest. - *op_request->mutable_computation() = computation_.handle(); - *op_request->mutable_metadata() = metadata_; - if (sharding_) { - *op_request->mutable_sharding() = *sharding_; - } - - const string& op_name = - OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name(); - VLOG(2) << "running op request: " << op_name; - Status status = client_->stub()->Op(op_request, op_response); - VLOG(2) << "done with op request: " << op_name; - return status; -} - -void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) { - OpResponse op_response; - Status status = RunOp(op_request, &op_response); - if (!status.ok()) { - NoteError(status); - } -} - -ComputationDataHandle ComputationBuilder::RunOpAndParseResponse( - OpRequest* op_request) { - OpResponse op_response; - Status status = RunOp(op_request, &op_response); - if (!status.ok()) { - NoteError(status); - return ComputationDataHandle(); - } - if (op_response.output().handle() == 0) { - NoteError(InternalError("No output handle")); - return ComputationDataHandle(); - } - return op_response.output(); -} - -bool ComputationBuilder::MakeWindow( - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, Window* window) { - const auto verify_size = [&](const size_t x, const char* x_name) { - if (x == 0 || x == window_dimensions.size()) { - return true; - } else { - NoteError(InvalidArgument( - "%s", tensorflow::strings::StrCat( - "Window has different number of window dimensions than of ", - x_name, "\nNumber of window dimensions: ", - window_dimensions.size(), "\nNumber of ", x_name, ": ", x, - "\n") - .c_str())); // - return false; - } - }; - if (!verify_size(window_strides.size(), "window strides") || - !verify_size(padding.size(), "padding entries") || - !verify_size(lhs_dilation.size(), "lhs dilation factors") || - !verify_size(rhs_dilation.size(), "rhs dilation factors")) { - return false; - } - - window->Clear(); - for (size_t i = 0; i < window_dimensions.size(); i++) { - auto dim = window->add_dimensions(); - dim->set_size(window_dimensions[i]); - if (!window_strides.empty()) { - dim->set_stride(window_strides[i]); - } else { - dim->set_stride(1); - } - if (!padding.empty()) { - dim->set_padding_low(padding[i].first); - dim->set_padding_high(padding[i].second); - } else { - dim->set_padding_low(0); - dim->set_padding_high(0); - } - if (!lhs_dilation.empty()) { - dim->set_base_dilation(lhs_dilation[i]); - } else { - dim->set_base_dilation(1); - } - if (!rhs_dilation.empty()) { - dim->set_window_dilation(rhs_dilation[i]); - } else { - dim->set_window_dilation(1); - } - dim->set_window_reversal(false); - } - return true; -} - -ComputationDataHandle ComputationBuilder::ConstantLiteral( - const Literal& literal) { - OpRequest op_request; - ConstantRequest* request = op_request.mutable_constant_request(); - *request->mutable_literal() = literal.ToProto(); - VLOG(3) << "created constant: " << request->literal().ShortDebugString(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { - OpRequest op_request; - ParameterRequest* request = op_request.mutable_parameter_request(); - *request->mutable_shape() = shape; - request->set_parameter(parameter_number); - 
request->set_name(name); - return RunOpAndParseResponse(&op_request); -} - -StatusOr> ComputationBuilder::GetShapeWithoutNoteError( - const ComputationDataHandle& operand) { - GetLocalShapeRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - GetLocalShapeResponse response; - - VLOG(2) << "making get-shape request"; - TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response)); - VLOG(2) << "done with request"; - - TF_RET_CHECK(response.has_shape()); - std::unique_ptr shape = WrapUnique(response.release_shape()); - TF_RET_CHECK(shape != nullptr); - return std::move(shape); -} - -StatusOr> ComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - TF_RETURN_IF_ERROR(first_error_); - - auto status_or_shape = GetShapeWithoutNoteError(operand); - if (!status_or_shape.ok()) { - NoteError(status_or_shape.status()); - return first_error_; - } - return status_or_shape; -} - -StatusOr ComputationBuilder::GetProgramShape() { - TF_RETURN_IF_ERROR(first_error_); - - GetComputationShapeRequest request; - *request.mutable_computation() = computation_.handle(); - GetComputationShapeResponse response; - - VLOG(2) << "making get-program-shape-request"; - Status status = client_->stub()->GetComputationShape(&request, &response); - VLOG(2) << "done with get-program-shape-request"; - - if (!status.ok()) { - first_error_ = status; - return status; - } - - TF_RET_CHECK(response.has_program_shape()); - return std::move(*response.mutable_program_shape()); -} - -ComputationDataHandle ComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides) { - OpRequest op_request; - SliceRequest* request = op_request.mutable_slice_request(); - *request->mutable_operand() = operand; - for (int64 index : start_indices) { - request->add_start_indices(index); - } - for (int64 index : limit_indices) { - request->add_limit_indices(index); - } - for (int64 index : strides) { - request->add_strides(index); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - NoteError(shape_status.status()); - return ComputationDataHandle{}; - } - const Shape& shape = *shape_status.ValueOrDie(); - std::vector starts(ShapeUtil::Rank(shape), 0); - std::vector limits(shape.dimensions().begin(), - shape.dimensions().end()); - std::vector strides(ShapeUtil::Rank(shape), 1); - starts[dimno] = start_index; - limits[dimno] = limit_index; - strides[dimno] = stride; - return Slice(operand, starts, limits, strides); -} - -ComputationDataHandle ComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes) { - OpRequest op_request; - DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request(); - *request->mutable_operand() = operand; - *request->mutable_start_indices() = start_indices; - for (int64 index : slice_sizes) { - request->add_slice_sizes(index); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { 
- OpRequest op_request; - DynamicUpdateSliceRequest* request = - op_request.mutable_dynamic_update_slice_request(); - *request->mutable_operand() = operand; - *request->mutable_update() = update; - *request->mutable_start_indices() = start_indices; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - OpRequest op_request; - ConcatenateRequest* request = op_request.mutable_concatenate_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - request->set_dimension(dimension); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes) { - OpRequest op_request; - BroadcastRequest* request = op_request.mutable_broadcast_request(); - *request->mutable_operand() = operand; - for (int64 size : broadcast_sizes) { - request->add_broadcast_sizes(size); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - OpRequest op_request; - PadRequest* request = op_request.mutable_pad_request(); - *request->mutable_operand() = operand; - *request->mutable_padding_value() = padding_value; - *request->mutable_padding_config() = padding_config; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes) { - OpRequest op_request; - ReshapeRequest* request = op_request.mutable_reshape_request(); - *request->mutable_operand() = operand; - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - for (int64 new_size : new_sizes) { - request->add_new_sizes(new_size); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice new_sizes) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - std::vector dimensions(shape.ValueOrDie()->dimensions().size()); - std::iota(dimensions.begin(), dimensions.end(), 0); - return Reshape(operand, dimensions, new_sizes); -} - -ComputationDataHandle ComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - // Don't support out-of-order collapse here. - // Checks that the collapsed dimensions are in order and consecutive. - for (tensorflow::gtl::ArraySlice::size_type i = 1; - i < dimensions.size(); ++i) { - if (dimensions[i] - 1 != dimensions[i - 1]) { - NoteError(InvalidArgument( - "Collapsed dimensions are not in order and consecutive.")); - return ComputationDataHandle(); - } - } - - // Create a new sizes vector from the old shape, replacing the collapsed - // dimensions by the product of their sizes. 
- StatusOr> shape_or_status = GetShape(operand); - if (!shape_or_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original_shape = shape_or_status.ConsumeValueOrDie(); - - VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); - VLOG(3) << "dims to collapse: " - << tensorflow::str_util::Join(dimensions, ","); - - if (dimensions.size() <= 1) { - // Not collapsing anything, trivially we can return the operand versus - // enqueueing a trivial reshape. - return operand; - } - - std::vector new_sizes; - for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { - if (i <= dimensions.front() || i > dimensions.back()) { - new_sizes.push_back(original_shape->dimensions(i)); - } else { - new_sizes.back() *= original_shape->dimensions(i); - } - } - - VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") - << "]"; - - return Reshape(operand, new_sizes); -} - -void ComputationBuilder::Trace(const string& tag, - const ComputationDataHandle& operand) { - OpRequest op_request; - TraceRequest* request = op_request.mutable_trace_request(); - request->set_tag(tag); - *request->mutable_operand() = operand; - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Select( - const ComputationDataHandle& pred, const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false) { - return TernaryOp(TRIOP_SELECT, pred, on_true, on_false); -} - -ComputationDataHandle ComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - OpRequest op_request; - VariadicOpRequest* request = op_request.mutable_variadic_op_request(); - request->set_varop(VAROP_TUPLE); - for (const ComputationDataHandle& operand : elements) { - *request->add_operands() = operand; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - OpRequest op_request; - GetTupleElementRequest* request = - op_request.mutable_get_tuple_element_request(); - *request->mutable_operand() = tuple_data; - request->set_index(index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Eq( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Ne( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Ge( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Gt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Le( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Lt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LT, lhs, 
rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - - DotDimensionNumbers dimension_numbers; - dimension_numbers.add_lhs_contracting_dimensions( - lhs_shape->dimensions_size() == 1 ? 0 : 1); - dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers) { - OpRequest op_request; - DotRequest* request = op_request.mutable_dot_request(); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_dimension_numbers() = dimension_numbers; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Conv( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding) { - return ConvWithGeneralDimensions( - lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); -} - -ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { - return ConvGeneral(lhs, rhs, window_strides, padding, - CreateDefaultConvDimensionNumbers(window_strides.size())); -} - -bool ComputationBuilder::VerifyConvolution( - const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { - NoteError( - InvalidArgument("Convolution arguments must have same number of " - "dimensions. Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str())); - return false; - } - int num_dims = ShapeUtil::Rank(lhs_shape); - if (num_dims < 2) { - NoteError(InvalidArgument( - "Convolution expects argument arrays with >= 3 dimensions. 
" - "Got: %s and %s", - ShapeUtil::HumanString(lhs_shape).c_str(), - ShapeUtil::HumanString(rhs_shape).c_str())); - return false; - } - int num_spatial_dims = num_dims - 2; - - const auto check_spatial_dimensions = - [&](const char* const field_name, - const tensorflow::protobuf::RepeatedField& - numbers) { - if (numbers.size() != num_spatial_dims) { - NoteError(InvalidArgument("Expected %d elements for %s, but got %d.", - num_spatial_dims, field_name, - numbers.size())); - return false; - } - for (int i = 0; i < numbers.size(); ++i) { - if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { - NoteError( - InvalidArgument("Convolution %s[%d] is out of bounds: %lld", - field_name, i, numbers.Get(i))); - return false; - } - } - return true; - }; - return check_spatial_dimensions( - "input_spatial_dimensions", - dimension_numbers.input_spatial_dimensions()) && - check_spatial_dimensions( - "kernel_spatial_dimensions", - dimension_numbers.kernel_spatial_dimensions()) && - check_spatial_dimensions( - "output_spatial_dimensions", - dimension_numbers.output_spatial_dimensions()); -} - -ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - StatusOr> rhs_shape_or_status = GetShape(rhs); - if (!rhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); - - if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { - NoteError(InternalError("failed to verify convolution")); - return ComputationDataHandle(); - } - - std::vector base_area_dimensions( - dimension_numbers.input_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < base_area_dimensions.size(); - ++i) { - base_area_dimensions[i] = - lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i)); - } - - std::vector window_dimensions( - dimension_numbers.kernel_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { - window_dimensions[i] = - rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); - } - - return ConvGeneral(lhs, rhs, window_strides, - MakePadding(base_area_dimensions, window_dimensions, - window_strides, padding), - dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::ConvGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers) { - return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, - dimension_numbers); -} - -ComputationDataHandle ComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - 
StatusOr> lhs_shape_or_status = GetShape(lhs); - if (!lhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - StatusOr> rhs_shape_or_status = GetShape(rhs); - if (!rhs_shape_or_status.ok()) { - return ComputationDataHandle(); - } - - std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); - std::unique_ptr rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); - if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { - // Error is recorded in VerifyConvolution. - return ComputationDataHandle(); - } - - std::vector window_dimensions( - dimension_numbers.kernel_spatial_dimensions_size()); - for (std::vector::size_type i = 0; i < window_dimensions.size(); ++i) { - window_dimensions[i] = - rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); - } - - OpRequest op_request; - ConvolveRequest* request = op_request.mutable_convolve_request(); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_dimension_numbers() = dimension_numbers; - - if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, - rhs_dilation, request->mutable_window())) { - // Error is recorded in MakeWindow. - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Fft( - const ComputationDataHandle& operand, const FftType fft_type, - const tensorflow::gtl::ArraySlice fft_length) { - OpRequest op_request; - FftRequest* request = op_request.mutable_fft_request(); - *request->mutable_operand() = operand; - request->set_fft_type(fft_type); - for (int64 dim_len : fft_length) { - request->add_fft_length(dim_len); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, - const string& config) { - OpRequest op_request; - InfeedRequest* request = op_request.mutable_infeed_request(); - *request->mutable_shape() = shape; - *request->mutable_config() = config; - return RunOpAndParseResponse(&op_request); -} - -void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, - const Shape& shape_with_layout, - const string& outfeed_config) { - OpRequest op_request; - OutfeedRequest* request = op_request.mutable_outfeed_request(); - request->set_outfeed_config(outfeed_config); - *request->mutable_operand() = operand; - *request->mutable_shape() = shape_with_layout; - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Call( - const Computation& computation, - tensorflow::gtl::ArraySlice operands) { - OpRequest op_request; - CallRequest* request = op_request.mutable_call_request(); - *request->mutable_to_apply() = computation.handle(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::CustomCall( - const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape) { - OpRequest op_request; - CustomCallRequest* request = op_request.mutable_custom_call_request(); - request->set_call_target_name(call_target_name); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_shape() = shape; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::HostCompute( - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { - OpRequest op_request; - 
HostComputeRequest* request = op_request.mutable_host_compute_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_shape() = shape; - request->set_channel_name(channel_name); - request->set_cost_estimate_ns(cost_estimate_ns); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Complex( - const ComputationDataHandle& real, const ComputationDataHandle& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Conj( - const ComputationDataHandle& operand) { - return Complex(Real(operand), Neg(Imag(operand))); -} - -ComputationDataHandle ComputationBuilder::Add( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Sub( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Mul( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Div( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Rem( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Max( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Min( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::And( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Or( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions); -} - -// TODO(b/65209188): Create a dedicated lowering for Xor -ComputationDataHandle ComputationBuilder::Xor( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return Or(And(Not(lhs), rhs, broadcast_dimensions), - And(lhs, Not(rhs), broadcast_dimensions)); -} - -ComputationDataHandle ComputationBuilder::Not( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_NOT, operand); -} - -ComputationDataHandle ComputationBuilder::ShiftLeft( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - 
return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ShiftRightArithmetic( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ShiftRightLogical( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Abs( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_ABS, operand); -} - -ComputationDataHandle ComputationBuilder::Atan2( - const ComputationDataHandle& y, const ComputationDataHandle& x, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::Exp( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_EXP, operand); -} - -ComputationDataHandle ComputationBuilder::Floor( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_FLOOR, operand); -} - -ComputationDataHandle ComputationBuilder::Ceil( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_CEIL, operand); -} - -ComputationDataHandle ComputationBuilder::Round( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_ROUND_NEAREST_AFZ, operand); -} - -ComputationDataHandle ComputationBuilder::Log( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_LOG, operand); -} - -ComputationDataHandle ComputationBuilder::Sign( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SIGN, operand); -} - -ComputationDataHandle ComputationBuilder::Cos( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_COS, operand); -} - -ComputationDataHandle ComputationBuilder::Sin( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SIN, operand); -} - -ComputationDataHandle ComputationBuilder::Tanh( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_TANH, operand); -} - -ComputationDataHandle ComputationBuilder::Real( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_REAL, operand); -} - -ComputationDataHandle ComputationBuilder::Imag( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_IMAG, operand); -} - -ComputationDataHandle ComputationBuilder::IsFinite( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_IS_FINITE, operand); -} - -ComputationDataHandle ComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - OpRequest op_request; - TransposeRequest* request = op_request.mutable_transpose_request(); - *request->mutable_operand() = operand; - for (int64 dimension : permutation) { - request->add_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - OpRequest op_request; - ReverseRequest* request = op_request.mutable_reverse_request(); - *request->mutable_operand() = operand; - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Sort( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_SORT, operand); -} - 
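// The three helpers that follow (SqrtF32, SquareF32, ReciprocalF32) have no
// dedicated operation in this API; each lowers to Pow against a scalar
// constant. A minimal usage sketch, assuming a Client* client is available and
// using a hypothetical parameter x of shape f32[16]:
//
//   ComputationBuilder b(client, "pow_lowerings");
//   auto x       = b.Parameter(0, ShapeUtil::MakeShape(F32, {16}), "x");
//   auto sqrt_x  = b.Pow(x, b.ConstantR0<float>(0.5f));   // what SqrtF32(x) builds
//   auto sq_x    = b.Pow(x, b.ConstantR0<float>(2.0f));   // what SquareF32(x) builds
//   auto recip_x = b.Pow(x, b.ConstantR0<float>(-1.0f));  // what ReciprocalF32(x) builds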
-ComputationDataHandle ComputationBuilder::SqrtF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(0.5), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::Pow( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions); -} - -ComputationDataHandle ComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original = shape_status.ConsumeValueOrDie(); - - OpRequest op_request; - ConvertRequest* request = op_request.mutable_convert_request(); - *request->mutable_operand() = operand; - request->set_new_element_type(new_element_type); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BitcastConvertType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape_status = GetShape(operand); - if (!shape_status.ok()) { - return ComputationDataHandle(); - } - std::unique_ptr original = shape_status.ConsumeValueOrDie(); - - OpRequest op_request; - ConvertRequest* request = op_request.mutable_bitcast_convert_request(); - *request->mutable_operand() = operand; - request->set_new_element_type(new_element_type); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SquareF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(2.0), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::ReciprocalF32( - const ComputationDataHandle& operand) { - return BinaryOp(BINOP_POW, operand, ConstantR0(-1.0), - /*broadcast_dimensions=*/{}); -} - -ComputationDataHandle ComputationBuilder::Neg( - const ComputationDataHandle& operand) { - return UnaryOp(UNOP_NEGATE, operand); -} - -ComputationDataHandle ComputationBuilder::Clamp( - const ComputationDataHandle& min, const ComputationDataHandle& operand, - const ComputationDataHandle& max) { - return TernaryOp(TRIOP_CLAMP, min, operand, max); -} - -ComputationDataHandle ComputationBuilder::UnaryOp( - UnaryOperation unop, const ComputationDataHandle& operand) { - OpRequest op_request; - UnaryOpRequest* request = op_request.mutable_unary_op_request(); - request->set_unop(unop); - *request->mutable_operand() = operand; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BinaryOp( - BinaryOperation binop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { - OpRequest op_request; - BinaryOpRequest* request = op_request.mutable_binary_op_request(); - request->set_binop(binop); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - for (int64 dimension : broadcast_dimensions) { - request->add_broadcast_dimensions(dimension); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::RngOp( - RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape) { - OpRequest op_request; - RngRequest* request = op_request.mutable_rng_request(); 
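  // Descriptive note: RngOp is the shared lowering used by the RngNormal and
  // RngUniform wrappers defined below. For instance (illustrative sketch,
  // assuming a builder b), b.RngUniform(b.ConstantR0<float>(0.0f),
  // b.ConstantR0<float>(1.0f), ShapeUtil::MakeShape(F32, {128})) reaches this
  // point with distribution == RandomDistribution::RNG_UNIFORM and the two
  // constant handles as parameters.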
- request->set_distribution(distribution); - for (const ComputationDataHandle& param : parameters) { - *request->add_parameter() = param; - } - *request->mutable_shape() = shape; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::TernaryOp( - TernaryOperation triop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) { - OpRequest op_request; - TernaryOpRequest* request = op_request.mutable_ternary_op_request(); - request->set_triop(triop); - *request->mutable_lhs() = lhs; - *request->mutable_rhs() = rhs; - *request->mutable_ehs() = ehs; - return RunOpAndParseResponse(&op_request); -} - -Status ComputationBuilder::SetReturnValue( - const ComputationDataHandle& operand) { - TF_RETURN_IF_ERROR(first_error_); - - SetReturnValueRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - - SetReturnValueResponse response; - - VLOG(2) << "making set-handle-to-execute request"; - Status s = client_->stub()->SetReturnValue(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - NoteError(s); - return first_error_; - } - - return Status::OK(); -} - -StatusOr ComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - TF_RETURN_IF_ERROR(first_error_); - - IsConstantRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - request.set_num_parameters(num_parameters); - IsConstantResponse response; - - VLOG(2) << "making IsConstant request"; - Status s = client_->stub()->IsConstant(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - return response.is_constant(); -} - -StatusOr> ComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters) { - TF_RETURN_IF_ERROR(first_error_); - - ComputeConstantRequest request; - *request.mutable_computation() = computation_.handle(); - *request.mutable_operand() = operand; - if (output_layout != nullptr) { - *request.mutable_output_layout() = *output_layout; - } - for (const auto& param : parameters) { - *request.add_parameters() = param.ToProto(); - } - - ComputeConstantResponse response; - - VLOG(2) << "making compute-constant request"; - Status s = client_->stub()->ComputeConstant(&request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - return s; - } - - VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}"; - - if (!response.has_literal()) { - return InternalError( - "no computed literal in the provided response in ComputeConstant " - "request"); - } - return Literal::CreateFromProto(response.literal()); -} - -ComputationDataHandle ComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, - const Computation& computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - OpRequest op_request; - MapRequest* request = op_request.mutable_map_request(); - for (const ComputationDataHandle& operand : operands) { - *request->add_operands() = operand; - } - *request->mutable_to_apply() = computation.handle(); - for (int64 dimension : dimensions) { - request->add_dimensions(dimension); - } - for (const ComputationDataHandle& sop : static_operands) { - *request->add_static_operands() = sop; - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::RngNormal( - const 
ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape); -} - -ComputationDataHandle ComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); -} - -ComputationDataHandle ComputationBuilder::While( - const Computation& condition, const Computation& body, - const ComputationDataHandle& init) { - OpRequest op_request; - WhileRequest* request = op_request.mutable_while_request(); - *request->mutable_condition() = condition.handle(); - *request->mutable_body() = body.handle(); - *request->mutable_init() = init; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Gather( - const ComputationDataHandle& input, - const ComputationDataHandle& gather_indices, - const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds) { - OpRequest op_request; - GatherRequest* gather_request = op_request.mutable_gather_request(); - *gather_request->mutable_input() = input; - *gather_request->mutable_gather_indices() = gather_indices; - *gather_request->mutable_dimension_numbers() = dimension_numbers; - for (int64 window_bound : window_bounds) { - gather_request->add_window_bounds(window_bound); - } - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const Computation& true_computation, - const ComputationDataHandle& false_operand, - const Computation& false_computation) { - OpRequest op_request; - ConditionalRequest* request = op_request.mutable_conditional_request(); - *request->mutable_predicate() = predicate; - *request->mutable_true_operand() = true_operand; - *request->mutable_true_computation() = true_computation.handle(); - *request->mutable_false_operand() = false_operand; - *request->mutable_false_computation() = false_computation.handle(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce) { - OpRequest op_request; - ReduceRequest* request = op_request.mutable_reduce_request(); - *request->mutable_operand() = operand; - *request->mutable_init_value() = init_value; - for (int64 dimension : dimensions_to_reduce) { - request->add_dimensions(dimension); - } - *request->mutable_to_apply() = computation.handle(); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ReduceAll( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - - std::vector all_dimnos(ShapeUtil::Rank(*shape.ValueOrDie())); - std::iota(all_dimnos.begin(), all_dimnos.end(), 0); - return Reduce(operand, init_value, computation, all_dimnos); -} - -ComputationDataHandle ComputationBuilder::ReduceWindow( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice 
window_strides, Padding padding) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - - Status padding_valid = - ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides); - if (!padding_valid.ok()) { - first_error_ = padding_valid; - return ComputationDataHandle(); - } - - std::vector> padding_values = - MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding); - return ReduceWindowWithGeneralPadding(operand, init_value, computation, - window_dimensions, window_strides, - padding_values); -} - -ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding) { - OpRequest op_request; - ReduceWindowRequest* request = op_request.mutable_reduce_window_request(); - *request->mutable_operand() = operand; - *request->mutable_to_apply() = computation.handle(); - *request->mutable_init_value() = init_value; - - if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request->mutable_window())) { - NoteError(InternalError("failed to make window")); - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormTraining( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, float epsilon, int64 feature_index) { - OpRequest op_request; - BatchNormTrainingRequest* request = - op_request.mutable_batch_norm_training_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_offset() = offset; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormInference( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, const ComputationDataHandle& mean, - const ComputationDataHandle& variance, float epsilon, int64 feature_index) { - OpRequest op_request; - BatchNormInferenceRequest* request = - op_request.mutable_batch_norm_inference_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_offset() = offset; - *request->mutable_mean() = mean; - *request->mutable_variance() = variance; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::BatchNormGrad( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& batch_mean, - const ComputationDataHandle& batch_var, - const ComputationDataHandle& grad_output, float epsilon, - int64 feature_index) { - OpRequest op_request; - BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request(); - *request->mutable_operand() = operand; - *request->mutable_scale() = scale; - *request->mutable_mean() = batch_mean; - *request->mutable_variance() = batch_var; - *request->mutable_grad_output() = grad_output; - request->set_epsilon(epsilon); - request->set_feature_index(feature_index); - return 
RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - OpRequest op_request; - CrossReplicaSumRequest* request = - op_request.mutable_cross_replica_sum_request(); - *request->mutable_operand() = operand; - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::SelectAndScatter( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter) { - if (!first_error_.ok()) { - return ComputationDataHandle(); - } - - StatusOr> shape = GetShape(operand); - if (!shape.ok()) { - return ComputationDataHandle(); - } - return SelectAndScatterWithGeneralPadding( - operand, select, window_dimensions, window_strides, - MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding), - source, init_value, scatter); -} - -ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const Computation& select, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const Computation& scatter) { - OpRequest op_request; - SelectAndScatterRequest* request = - op_request.mutable_select_and_scatter_request(); - *request->mutable_operand() = operand; - *request->mutable_select() = select.handle(); - *request->mutable_source() = source; - *request->mutable_init_value() = init_value; - *request->mutable_scatter() = scatter.handle(); - - if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request->mutable_window())) { - NoteError(InternalError("failed to make window")); - return ComputationDataHandle(); - } - - return RunOpAndParseResponse(&op_request); -} - -ComputationDataHandle ComputationBuilder::ReducePrecision( - const ComputationDataHandle& operand, const int exponent_bits, - const int mantissa_bits) { - OpRequest op_request; - ReducePrecisionRequest* request = - op_request.mutable_reduce_precision_request(); - *request->mutable_operand() = operand; - request->set_exponent_bits(exponent_bits); - request->set_mantissa_bits(mantissa_bits); - return RunOpAndParseResponse(&op_request); -} - -void ComputationBuilder::Send(const ComputationDataHandle& operand, - const ChannelHandle& handle) { - OpRequest op_request; - SendRequest* request = op_request.mutable_send_request(); - *request->mutable_operand() = operand; - *request->mutable_channel_handle() = handle; - *op_request.mutable_computation() = computation_.handle(); - RunOpAndNoteError(&op_request); -} - -ComputationDataHandle ComputationBuilder::Recv(const Shape& shape, - const ChannelHandle& handle) { - OpRequest op_request; - RecvRequest* request = op_request.mutable_recv_request(); - *request->mutable_shape() = shape; - *request->mutable_channel_handle() = handle; - return RunOpAndParseResponse(&op_request); -} - -Computation ComputationBuilder::BuildAndNoteError() { - DCHECK(parent_builder_ != nullptr); - auto build_status = Build(); - if (!build_status.ok()) { - parent_builder_->NoteError( - AddStatus(build_status.status(), - tensorflow::strings::StrCat("error from: ", name_))); - return Computation(); - } - return 
build_status.ConsumeValueOrDie(); -} - -StatusOr ComputationBuilder::Build() { - if (!first_error_.ok()) { - string backtrace; - first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace); - return AppendStatus(first_error_, backtrace); - } - - if (computation_.IsNull()) { - return FailedPrecondition("no computation was built"); - } - - return {std::move(computation_)}; -} - -/* static */ ConvolutionDimensionNumbers -ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { - ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_input_batch_dimension(kConvBatchDimension); - dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); - dimension_numbers.set_output_batch_dimension(kConvBatchDimension); - dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); - dimension_numbers.set_kernel_output_feature_dimension( - kConvKernelOutputDimension); - dimension_numbers.set_kernel_input_feature_dimension( - kConvKernelInputDimension); - for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_input_spatial_dimensions(i + 2); - dimension_numbers.add_kernel_spatial_dimensions(i + 2); - dimension_numbers.add_output_spatial_dimensions(i + 2); - } - return dimension_numbers; -} - -/* static */ StatusOr -ComputationBuilder::CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 input_first_spatial, - int64 input_second_spatial, int64 output_batch, int64 output_feature, - int64 output_first_spatial, int64 output_second_spatial, - int64 kernel_output_feature, int64 kernel_input_feature, - int64 kernel_first_spatial, int64 kernel_second_spatial) { - if (std::set({input_batch, input_feature, input_first_spatial, - input_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the input are not unique: (%lld, %lld, %lld, " - "%lld)", - input_batch, input_feature, input_first_spatial, input_second_spatial); - } - if (std::set({kernel_output_feature, kernel_input_feature, - kernel_first_spatial, kernel_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the weight are not unique: (%lld, %lld, %lld, " - "%lld)", - kernel_output_feature, kernel_input_feature, kernel_first_spatial, - kernel_second_spatial); - } - if (std::set({output_batch, output_feature, output_first_spatial, - output_second_spatial}) - .size() != 4) { - return FailedPrecondition( - "dimension numbers for the output are not unique: (%lld, %lld, %lld, " - "%lld)", - output_batch, output_feature, output_first_spatial, - output_second_spatial); - } - ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_input_batch_dimension(input_batch); - dimension_numbers.set_input_feature_dimension(input_feature); - dimension_numbers.add_input_spatial_dimensions(input_first_spatial); - dimension_numbers.add_input_spatial_dimensions(input_second_spatial); - dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); - dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); - dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); - dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); - dimension_numbers.set_output_batch_dimension(output_batch); - dimension_numbers.set_output_feature_dimension(output_feature); - dimension_numbers.add_output_spatial_dimensions(output_first_spatial); - dimension_numbers.add_output_spatial_dimensions(output_second_spatial); - return dimension_numbers; -} - -} // 
namespace xla diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h deleted file mode 100644 index 019c6f3afb5..00000000000 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ /dev/null @@ -1,1062 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/array.h" -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/global_data.h" -#include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/bitmap.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/stacktrace.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Wraps an XLA client with a convenient interface for building up -// computations. Any errors encountered in building up the computation are -// deferred from being handled until Build() is called. -// -// Thread-compatible. -class ComputationBuilder { - public: - // client: client in which to build the computation. - // computation_name: name to use for the built computation. - ComputationBuilder(Client* client, const string& computation_name); - - ~ComputationBuilder(); - - // Returns the client the builder was initialized with. - Client* client() const { return client_; } - - // Returns the computation name. - const string& name() const { return name_; } - - // Sets OpMetadata that will be added to all instructions until cleared. - // - // OpMetadata is often applied to a series of XLA HLO instructions. As a - // result, OpMetadata is set on the Computation Builder. All subsequent - // instructions generated via this Computation Builder will have the same - // OpMetadata attached until a call to ClearOpMetadata. - void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } - - // Clears the HloMetadata state. - void ClearOpMetadata() { metadata_.Clear(); } - - // Sets an OpSharding that will be attached to all instructions until cleared. - void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } - - // Clears the sharding. Ops will be sharded according to the default placement - // policy. 
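  // A typical usage pattern (illustrative sketch; 'sharding' stands for an
  // OpSharding value assumed to be constructed elsewhere):
  //   builder.SetSharding(sharding);
  //   ... enqueue the instructions that should carry this sharding ...
  //   builder.ClearSharding();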
- void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } - - // Returns the OpSharding that will be attached to all instructions. - const tensorflow::gtl::optional& sharding() const { - return sharding_; - } - - // Sets the builder to a mode where it will die immediately when an error is - // encountered, rather than producing it in a deferred fashion when Build() is - // called (which is the default). - void set_die_immediately_on_error(bool enabled) { - die_immediately_on_error_ = enabled; - } - - // Enqueues a "retrieve parameter value" instruction for a parameter that was - // passed to the computation. - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); - - // Retrieves the (inferred) shape of the operand in the computation. - StatusOr> GetShape( - const ComputationDataHandle& operand); - - // Retrieves the (inferred) result for the current computation's shape. - StatusOr GetProgramShape(); - - // Enqueues a constant with the value of the given literal onto the - // computation. - ComputationDataHandle ConstantLiteral(const Literal& literal); - - // Enqueues a constant onto the computation. Methods are templated on the - // native host type (NativeT) which corresponds to a specific XLA - // PrimitiveType as given in the following table: - // - // Native Type PrimitiveType - // ----------------------------- - // bool PRED - // int32 S32 - // int64 S64 - // uint32 U32 - // uint64 U64 - // float F32 - // double F64 - // - // Note: not all primitive types defined in xla_data.proto have a - // corresponding native type yet. - template - ComputationDataHandle ConstantR0(NativeT value); - template - ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice values); - ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values); - template - ComputationDataHandle ConstantR2( - std::initializer_list> values); - template - ComputationDataHandle ConstantFromArrayWithLayout( - const Array& values, const Layout& layout); - template - ComputationDataHandle ConstantFromArray(const Array& values); - template - ComputationDataHandle ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout); - template - ComputationDataHandle ConstantR2FromArray2D(const Array2D& values); - template - ComputationDataHandle ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout); - template - ComputationDataHandle ConstantR3FromArray3D(const Array3D& values); - template - ComputationDataHandle ConstantR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout); - template - ComputationDataHandle ConstantR4FromArray4D(const Array4D& values); - - // Enqueues a rank one constant (vector) onto the computation. The vector has - // size 'length' and every element has the value 'value'. - template - ComputationDataHandle ConstantR1(int64 length, NativeT value); - - // Adds dimensions to an array by duplicating the data in the array. - // - // The new dimensions are inserted on the left, i.e. if - // broadcast_sizes has values {a0, ..., aN} and the operand shape - // has dimensions {b0, ..., bM} then the shape of the output has - // dimensions {a0, ..., aN, b0, ..., bM}. - // - // The new dimensions index into copies of the operand, i.e. 
- // - // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); - - // Enqueues a pad operation onto the computation that pads the given value on - // the edges as well as between the elements of the input. padding_config - // specifies the padding amount for each dimension. - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); - - // Enqueues an operation onto the computation that flattens the operand based - // on the dimension order (major/slowest-varying to minor/fastest-varying) - // given, followed by reshaping it into the shape with the given dimension - // sizes (also major to minor). Conceptually, this is a limited form of - // "shape casting". - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); - - // Enqueues an operation onto the computation that collapses the operand, from - // first to last dimension (C order), then reshapes it to the given dimension - // sizes. Conceptually, this is a limited form of "shape casting". - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice new_sizes); - - // Wrapper for Reshape. - // Enqueues an operation to collapse the provided dimensions; e.g. an - // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to - // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must - // be a consecutive, in-order subsequence of the operand dimensions. - // - // Note that collapsing a single dimension does nothing: - // - // {256} collapsing {0} => {256} - // {1} collapsing {0} => {1} - // - // Collapsing multiple dimensions produces a single result dimension: - // - // {256, 2} collapsing {0,1} => {512} - // {256, 2, 3} collapsing {0,1} => {512, 3} - // - // This could potentially cause data to be moved -- it provides a more - // structured form of reshaping than an arbitrary Reshape operation. - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); - - // Enqueues a slice operation onto the computation that slices the operand - // from the start indices to the limit indices; e.g. - // - // x - // [ 0 1 2 3 ] - // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] - // [ 8 9 a b ] - // - // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D - // range notation. - // The strides parameter determines the stride over the slice - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); - - // Enqueues a slice operation in a given dimension, taking all other - // dimensions as they are; e.g. if dimno is 1 from start_index 2 to - // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand - // for: - // - // array[:, 2:4:1, :] - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); - - // Enqueues a slice operation onto the computation that slices the 'operand' - // from dynamic start indices which are passed in 'start_indices'. 
- // The size of the slice in each dimension is passed in 'slice_sizes', - // which specify the end point of exclusive slice intervals in each - // dimension [start, start + size). - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo input dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice slice_sizes); - - // Enqueues a dynamic update slice operation onto the computation, which - // updates a slice of 'operand' with 'update' at dynamic 'start_indices'. - // The shape of 'update' determines the shape of the slice of 'operand' - // which is updated. - // The indices specified in 'start_indices' specify the offset of the slice - // of 'operand' which is updated. - // - // update = {10, 11} // calculated at runtime. - // [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] - // [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] - // [7 8 9] [7 8 9 ] - // - // The shape of 'start_indices' must be rank == 1, with dimension size - // equal to the rank of the 'operand'. - // Slice index calculations are computed modulo update dimension sizes to - // prevent dynamic start indices from generating out-of-bound array accesses. - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); - - // Enqueues a concatenate instruction onto the computation. 'operands' must - // have >= 1 entry. - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); - - // Enqueue a tracing operation onto the computation; the computation will emit - // a logging message with the operand. - void Trace(const string& tag, const ComputationDataHandle& operand); - - // Enqueues a conditional-move-like select operation onto the computation; - // predicated on pred, selects between on_true and on_false. - ComputationDataHandle Select(const ComputationDataHandle& pred, - const ComputationDataHandle& on_true, - const ComputationDataHandle& on_false); - - // Enqueues a tuple-creation instruction onto the computation. - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); - - // Enqueues a tuple-element-get instruction onto the computation. - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); - - // Enqueues an equal-to comparison instruction onto the computation. - ComputationDataHandle Eq( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a not-equal comparison instruction onto the computation. - ComputationDataHandle Ne( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-or-equal comparison instruction onto the computation. - ComputationDataHandle Ge( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a greater-than comparison instruction onto the computation. 
- ComputationDataHandle Gt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-than comparison instruction onto the computation. - ComputationDataHandle Lt( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a less-or-equal comparison instruction onto the computation. - ComputationDataHandle Le( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a dot instruction onto the computation. - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); - - // Enqueues a general dot instruction onto the computation. - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); - - // Default dimension numbers used for a 2D convolution. - static constexpr int64 kConvBatchDimension = 0; - static constexpr int64 kConvFeatureDimension = 1; - static constexpr int64 kConvFirstSpatialDimension = 2; - static constexpr int64 kConvSecondSpatialDimension = 3; - static constexpr int64 kConvKernelOutputDimension = 0; - static constexpr int64 kConvKernelInputDimension = 1; - static constexpr int64 kConvKernelFirstSpatialDimension = 2; - static constexpr int64 kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an - // error if either the input or the weight dimension numbers have conflicts. - static StatusOr CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 input_first_spatial, - int64 input_second_spatial, int64 output_batch, int64 output_feature, - int64 output_first_spatial, int64 output_second_spatial, - int64 kernel_output_feature, int64 kernel_input_feature, - int64 kernel_first_spatial, int64 kernel_second_spatial); - - // Enqueues a convolution instruction onto the computation, which uses the - // default convolution dimension numbers. - ComputationDataHandle Conv(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - Padding padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration in the format returned by MakePadding(). - ComputationDataHandle ConvWithGeneralPadding( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided dimension numbers configuration. - ComputationDataHandle ConvWithGeneralDimensions( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, Padding padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration as well as the dimension numbers. 
- ComputationDataHandle ConvGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues a convolution instruction onto the computation, with the caller - // provided padding configuration, dilation factors and dimension numbers. - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); - - // Enqueues an FFT instruction onto the computation, of the given type and - // with the given FFT length. - ComputationDataHandle Fft(const ComputationDataHandle& operand, - FftType fft_type, - tensorflow::gtl::ArraySlice fft_length); - - // Enqueues an infeed instruction onto the computation, which writes data of - // the given shape to the infeed buffer of the device. - ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); - - // Enqueues an outfeed instruction onto the computation. This instruction - // generates outgoing data transfers for the given data. - // - // shape_with_layout communicates the laid out shape that we want to outfeed - // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error - // will occur. - void Outfeed(const ComputationDataHandle& operand, - const Shape& shape_with_layout, const string& outfeed_config); - - // Enqueues a call instruction onto the computation. - ComputationDataHandle Call( - const Computation& computation, - tensorflow::gtl::ArraySlice operands); - - // Enqueues a custom call instruction onto the computation. - // During code generation, a call instruction is emitted which targets a - // symbol with the name |call_target_name|. The |operands| are passed to the - // call instruction. |shape| is the resultant shape. - ComputationDataHandle CustomCall( - const string& call_target_name, - tensorflow::gtl::ArraySlice operands, - const Shape& shape); - - // Enqueues a pseudo-op to represent host-side computation data-dependencies. - // During code generation, host send and receive operations will be generated - // to transfer |operands| to the host and a single result of |shape| back to - // the device. Host send/recv operations are emitted using |channel_name|. - // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO - // instruction scheduling. - ComputationDataHandle HostCompute( - tensorflow::gtl::ArraySlice operands, - const string& channel_name, int64 cost_estimate_ns, const Shape& shape); - - // The following methods enqueue element-wise binary arithmetic operations - // onto the computation. The shapes of the operands have to match unless one - // of the operands is a scalar, or an explicit broadcast dimension is given - // (see g3doc for more details). - - // Enqueues a complex compose instruction onto the computation. - ComputationDataHandle Complex( - const ComputationDataHandle& real, const ComputationDataHandle& imag, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a complex conjugate instruction onto the computation. - ComputationDataHandle Conj(const ComputationDataHandle& operand); - - // Enqueues an add instruction onto the computation. 
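// For example (hypothetical builder `b`, a [5, 10] matrix `m`, and a
// length-10 vector `v`): the vector can be added to every row of the matrix
// by mapping its single dimension onto dimension 1 of the matrix:
//   b.Add(m, v, /*broadcast_dimensions=*/{1});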
- ComputationDataHandle Add( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a subtract instruction onto the computation. - ComputationDataHandle Sub( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a multiply instruction onto the computation. - ComputationDataHandle Mul( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a divide instruction onto the computation. - ComputationDataHandle Div( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a remainder instruction onto the computation. - ComputationDataHandle Rem( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a max instruction onto the computation. - ComputationDataHandle Max( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Enqueues a min instruction onto the computation. - ComputationDataHandle Min( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Element-wise logical operators - ComputationDataHandle And( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Or( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Xor( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - ComputationDataHandle Not(const ComputationDataHandle& operand); - - ComputationDataHandle ShiftLeft( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle ShiftRightArithmetic( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle ShiftRightLogical( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - - // Reduces an array among the provided dimensions, given "computation" as a - // reduction operator. - ComputationDataHandle Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); - - // Convenience wrapper around the above that reduces all the dimensions in the - // operand shape. - ComputationDataHandle ReduceAll(const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, - const Computation& computation); - - // Enqueues a windowed reduce instruction onto the computation. - ComputationDataHandle ReduceWindow( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, const Computation& computation, - tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, Padding padding); - - // As ReduceWindow(), but the padding is given in the format - // returned by MakePadding(). 
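// For example (hypothetical builder `b`, operand `x` of shape [N, C, H, W],
// and a scalar F32 max computation `max`): a 2x2 max-pool with stride 2 and
// no padding could be written as
//   b.ReduceWindowWithGeneralPadding(
//       x, b.ConstantR0<float>(-std::numeric_limits<float>::infinity()), max,
//       /*window_dimensions=*/{1, 1, 2, 2}, /*window_strides=*/{1, 1, 2, 2},
//       /*padding=*/{{0, 0}, {0, 0}, {0, 0}, {0, 0}});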
-  ComputationDataHandle ReduceWindowWithGeneralPadding(
-      const ComputationDataHandle& operand,
-      const ComputationDataHandle& init_value, const Computation& computation,
-      tensorflow::gtl::ArraySlice<int64> window_dimensions,
-      tensorflow::gtl::ArraySlice<int64> window_strides,
-      tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
-
-  // Returns the sum of the operand value across all replicas. All replicas
-  // supply one input to the sum and all replicas receive the resulting sum.
-  ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand);
-
-  // Enqueues an operation that scatters the `source` array to the selected
-  // indices of each window.
-  ComputationDataHandle SelectAndScatter(
-      const ComputationDataHandle& operand, const Computation& select,
-      tensorflow::gtl::ArraySlice<int64> window_dimensions,
-      tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
-      const ComputationDataHandle& source,
-      const ComputationDataHandle& init_value, const Computation& scatter);
-
-  // As SelectAndScatter(), but the padding is given in the format
-  // returned by MakePadding().
-  ComputationDataHandle SelectAndScatterWithGeneralPadding(
-      const ComputationDataHandle& operand, const Computation& select,
-      tensorflow::gtl::ArraySlice<int64> window_dimensions,
-      tensorflow::gtl::ArraySlice<int64> window_strides,
-      tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
-      const ComputationDataHandle& source,
-      const ComputationDataHandle& init_value, const Computation& scatter);
-
-  // Enqueues an abs instruction onto the computation.
-  ComputationDataHandle Abs(const ComputationDataHandle& operand);
-
-  // Enqueues an atan2 instruction onto the computation.
-  ComputationDataHandle Atan2(
-      const ComputationDataHandle& y, const ComputationDataHandle& x,
-      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-  // Enqueues an exp instruction onto the computation.
-  ComputationDataHandle Exp(const ComputationDataHandle& operand);
-
-  // Enqueues a floor instruction onto the computation.
-  ComputationDataHandle Floor(const ComputationDataHandle& operand);
-
-  // Enqueues a ceil instruction onto the computation.
-  ComputationDataHandle Ceil(const ComputationDataHandle& operand);
-
-  // Enqueues a round instruction onto the computation, rounding to the
-  // nearest integer, with half-way cases rounding away from zero.
-  ComputationDataHandle Round(const ComputationDataHandle& operand);
-
-  // Enqueues a log instruction (natural logarithm) onto the computation.
-  ComputationDataHandle Log(const ComputationDataHandle& operand);
-
-  // Enqueues a sign instruction onto the computation.
-  ComputationDataHandle Sign(const ComputationDataHandle& operand);
-
-  // Enqueues a cosine instruction onto the computation.
-  ComputationDataHandle Cos(const ComputationDataHandle& operand);
-
-  // Enqueues a sine instruction onto the computation.
-  ComputationDataHandle Sin(const ComputationDataHandle& operand);
-
-  // Enqueues a tanh instruction onto the computation.
-  ComputationDataHandle Tanh(const ComputationDataHandle& operand);
-
-  // Enqueues a real-part instruction onto the computation.
-  ComputationDataHandle Real(const ComputationDataHandle& operand);
-
-  // Enqueues an imaginary-part instruction onto the computation.
-  ComputationDataHandle Imag(const ComputationDataHandle& operand);
-
-  // Enqueues a float32 sqrt instruction onto the computation.
-  // (float32 is specified as there is an implicit float32 0.5f constant
-  // exponent).
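// For instance (illustrative equivalence): for a float32 operand x,
// SqrtF32(x) computes the same values as Pow(x, ConstantR0<float>(0.5f)).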
-  ComputationDataHandle SqrtF32(const ComputationDataHandle& operand);
-
-  // Enqueues a float32 square instruction onto the computation.
-  // (float32 is specified as there is an implicit float32 2.0f constant
-  // exponent).
-  ComputationDataHandle SquareF32(const ComputationDataHandle& operand);
-
-  // Enqueues a lhs^rhs computation onto the computation.
-  ComputationDataHandle Pow(
-      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
-      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
-
-  // Enqueues an operator that tests if the operand's values are finite, i.e.,
-  // not Inf or NaN. Defined only for floating-point types. Returns an array
-  // of booleans with the same shape where entries are true iff the
-  // corresponding entry is finite.
-  ComputationDataHandle IsFinite(const ComputationDataHandle& operand);
-
-  // Enqueues a convert instruction onto the computation that changes the
-  // element type of the operand array to primitive_type.
-  ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
-                                           PrimitiveType new_element_type);
-
-  // Enqueues a no-op instruction onto the computation that changes
-  // the element type of the operand array to primitive_type. The
-  // bit-widths of the source and destination element types must be
-  // identical.
-  ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand,
-                                           PrimitiveType new_element_type);
-
-  // Enqueues a float32 reciprocal instruction onto the computation.
-  // (float32 is specified as there is an implicit float32 -1.0f constant
-  // exponent).
-  //
-  // TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
-  // shape of the operand.
-  ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand);
-
-  // Enqueues a negate instruction onto the computation.
-  ComputationDataHandle Neg(const ComputationDataHandle& operand);
-
-  // Enqueues a transpose instruction onto the computation.
-  ComputationDataHandle Transpose(
-      const ComputationDataHandle& operand,
-      tensorflow::gtl::ArraySlice<int64> permutation);
-
-  // Enqueues a reverse instruction onto the computation. The order of the
-  // elements in the given dimensions is reversed (i.e., the element at index i
-  // is moved to index dimension_size - 1 - i).
-  ComputationDataHandle Rev(const ComputationDataHandle& operand,
-                            tensorflow::gtl::ArraySlice<int64> dimensions);
-
-  // Enqueues a sort (as increasing order) instruction onto the computation.
-  ComputationDataHandle Sort(const ComputationDataHandle& operand);
-
-  // Enqueues a clamp instruction onto the computation.
-  ComputationDataHandle Clamp(const ComputationDataHandle& min,
-                              const ComputationDataHandle& operand,
-                              const ComputationDataHandle& max);
-
-  // Enqueues a map instruction onto the computation.
-  ComputationDataHandle Map(
-      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
-      const Computation& computation,
-      tensorflow::gtl::ArraySlice<int64> dimensions,
-      tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands = {});
-
-  // Enqueues an N(mu, sigma) random number generation instruction onto the
-  // computation.
-  ComputationDataHandle RngNormal(const ComputationDataHandle& mu,
-                                  const ComputationDataHandle& sigma,
-                                  const Shape& shape);
-
-  // Enqueues a U(a, b) random number generation instruction onto the
-  // computation. Returns values in the semi-open interval [a, b).
-  ComputationDataHandle RngUniform(const ComputationDataHandle& a,
-                                   const ComputationDataHandle& b,
-                                   const Shape& shape);
-
-  // Enqueues a while node onto the computation.
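// For example (hypothetical usage): a loop that counts a scalar S32 value up
// to 10 could be assembled from two sub-computations built with
// CreateSubBuilder():
//   const Shape scalar_s32 = ShapeUtil::MakeShape(S32, {});
//   auto cond = builder.CreateSubBuilder("cond");
//   cond->Lt(cond->Parameter(0, scalar_s32, "x"), cond->ConstantR0<int32>(10));
//   auto body = builder.CreateSubBuilder("body");
//   body->Add(body->Parameter(0, scalar_s32, "x"), body->ConstantR0<int32>(1));
//   builder.While(cond->BuildAndNoteError(), body->BuildAndNoteError(),
//                 builder.ConstantR0<int32>(0));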
- ComputationDataHandle While(const Computation& condition, - const Computation& body, - const ComputationDataHandle& init); - - // Enqueues a conditional node onto the computation. - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const Computation& true_computation, - const ComputationDataHandle& false_operand, - const Computation& false_computation); - - // Enqueues a ReducePrecision node onto the computation. - ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, - const int exponent_bits, - const int mantissa_bits); - - // Enqueues a Gather node onto the computation. - ComputationDataHandle Gather( - const ComputationDataHandle& input, - const ComputationDataHandle& gather_indices, - const GatherDimensionNumbers& dimension_numbers, - tensorflow::gtl::ArraySlice window_bounds); - - // Enqueues a Send node onto the computation, to send the given operand to - // a Recv instruction that shares the same channel handle. - void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); - - // Enqueues a Recv node onto the computation. The data comes from a Send - // instruction that shares the same channel handle and its shape must - // be the same as the given shape. - ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); - - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on parameters with index greater than or equal to - // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. - // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a - // compile-time constant without evaluating the computation. - StatusOr IsConstant(const ComputationDataHandle& operand, - int64 num_parameters = 0); - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // Returns a tuple (normalized, batch_mean, batch_var) where `normalized` - // is the normalized result and batch_mean and batch_var are the mean and - // variance, respectively, across batch for the operand. - ComputationDataHandle BatchNormTraining(const ComputationDataHandle& operand, - const ComputationDataHandle& scale, - const ComputationDataHandle& offset, - float epsilon, int64 feature_index); - - // Normalizes operand across spatial and batch dimensions for each feature. - // - // `BatchNormInference` is equivalent to calling `BatchNormTraining` without - // computing `mean` and `variance` for each batch inside the operation. It - // uses the input `mean` and `variance` instead as estimated values. The - // purpose of this op is to reduce latency in inference, hence the name - // `BatchNormInference`. - // - // The output has the same shape as `operand`, and contains the normalized - // values for each batch. - ComputationDataHandle BatchNormInference( - const ComputationDataHandle& operand, const ComputationDataHandle& scale, - const ComputationDataHandle& offset, const ComputationDataHandle& mean, - const ComputationDataHandle& variance, float epsilon, - int64 feature_index); - - // Calculates the gradients of a batch norm op. - // - // The inputs `batch_mean` and `batch_var` represent the mean and variance - // across the batch. 
-  //
-  // Returns a tuple of three elements:
-  //   - grad_operand: Gradient with respect to input `operand`
-  //   - grad_offset: Gradient with respect to input `offset`
-  //   - grad_scale: Gradient with respect to input `scale`
-  ComputationDataHandle BatchNormGrad(const ComputationDataHandle& operand,
-                                      const ComputationDataHandle& scale,
-                                      const ComputationDataHandle& batch_mean,
-                                      const ComputationDataHandle& batch_var,
-                                      const ComputationDataHandle& grad_output,
-                                      float epsilon, int64 feature_index);
-
-  // Computes the value of a constant indicated by a
-  // ComputationDataHandle using a non-optimized interpreter on the host.
-  //
-  // The operand must be from the computation currently being built -
-  // i.e., returned from this builder with no intervening call to
-  // Build(). This happens to currently work regardless of that, but
-  // that may stop working at any time.
-  //
-  // The operand must represent a constant value, which in this case
-  // means that it must not statically depend on any parameter of the
-  // computation that is being built other than the ones specified on the
-  // parameter list. The parameters in the list will be indexed by their
-  // parameter id property so the number of parameters specified should be at
-  // least as many as the largest used parameter index.
-  //
-  // `IsConstant` can be used to test whether a computation is a compile-time
-  // constant without evaluating it. `ComputeConstant` only succeeds for
-  // computations where `IsConstant` returns true.
-  //
-  // This functionality can be useful when translating a computation
-  // into XLA where something that looked dynamic is required by
-  // XLA to be specified as a constant. E.g. the source
-  // computation (outside of XLA) may include a dynamic
-  // computation of the shape of something and ComputeConstant lets
-  // you determine what the value of that computation is in the case
-  // where the value can be determined at compile time.
-  //
-  // If output_layout is non-null, then the output of the computation
-  // will be stored using that layout.
-  StatusOr<std::unique_ptr<Literal>> ComputeConstant(
-      const ComputationDataHandle& operand,
-      const Layout* output_layout = nullptr,
-      tensorflow::gtl::ArraySlice<Literal> parameters = {});
-
-  // Returns a new ComputationBuilder whose resultant Computation is used only
-  // by this ComputationBuilder. The sub-ComputationBuilder has the same
-  // die_immediately_on_error behavior as the parent.
-  std::unique_ptr<ComputationBuilder> CreateSubBuilder(
-      const string& computation_name);
-
-  // Modifies the computation being built so that executions of it
-  // will return the value associated with operand, rather than the
-  // last expression enqueued on the ComputationBuilder. Any subsequent
-  // operations added to the ComputationBuilder will not have any effect unless
-  // SetReturnValue is called again.
-  Status SetReturnValue(const ComputationDataHandle& operand);
-
-  // Builds the computation with the requested operations, or returns a non-ok
-  // status.
-  StatusOr<Computation> Build();
-
-  // Builds the computation with the requested operations, or notes an error in
-  // the parent ComputationBuilder and returns an empty computation if building
-  // failed. This function is intended to be used where the returned
-  // Computation is only used by the parent ComputationBuilder and hence further
-  // operation on the returned Computation will simply be error'ed out if an
-  // error occurred while building this computation.
If the built computation is - // to be used by a ComputationBuilder other than the parent ComputationBuilder - // then Build() should be used instead. - Computation BuildAndNoteError(); - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // ComputationDataHandle and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - Status first_error() const { return first_error_; } - - private: - // Limited checking of convolution parameters. Returns false on - // error. - bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers); - - // The parent ComputationBuilder of a sub-ComputationBuilder. The - // parent_builder_ will be the nullptr if not a sub-ComputationBuilder. - ComputationBuilder* parent_builder_{nullptr}; - - // Helper function for creating a Window proto from user-supplied - // data. Returns true if the user-supplied data was valid. - bool MakeWindow(tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice window_strides, - tensorflow::gtl::ArraySlice> padding, - tensorflow::gtl::ArraySlice lhs_dilation, - tensorflow::gtl::ArraySlice rhs_dilation, - Window* window); - - // Internal helper method that does the building for an arbitrary unary op. - ComputationDataHandle UnaryOp(UnaryOperation unop, - const ComputationDataHandle& operand); - - // Internal helper method that does the building for an arbitrary binary op. - // broadcast_dimensions specifies which dimensions to use for broadcasting - // when the operation is between tensors of different ranks. - ComputationDataHandle BinaryOp( - BinaryOperation binop, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); - - // Internal helper method that does the building for an arbitrary ternary op. - ComputationDataHandle TernaryOp(TernaryOperation triop, - const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - const ComputationDataHandle& ehs); - - // Internal helper method that does the building for a random number generator - // of a given distribution with an explicitly specified shape. - ComputationDataHandle RngOp( - RandomDistribution distribution, - tensorflow::gtl::ArraySlice parameters, - const Shape& shape); - - // Populates computation_ with a valid object or returns a failing status. - // This is used before any given operation is enqueued. - Status PrepareComputation(); - - // Notes that the error occurred by: - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to Build()) - // * dying if die_immediately_on_error_ is true - void NoteError(const Status& error); - - // Helper function that runs the given op_request, filling in op_response. - // Before the op is run, PrepareComputation is called, and common fields in - // the op_request are filled in. - Status RunOp(OpRequest* op_request, OpResponse* op_response); - - // Helper function that calls RunOp and calls NoteError on failures. - void RunOpAndNoteError(OpRequest* op_request); - - // Helper function that calls RunOp and either returns the output computation - // data handle (on success) or a vacuous computation data handle (on failure). 
- ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request); - - // Helper function that implements GetShape without noting errors. This makes - // it easier to ensure the real GetShape will note errors on every error path. - StatusOr> GetShapeWithoutNoteError( - const ComputationDataHandle& operand); - - string name_; // Name to use for the built computation. - - // The first error encountered while building the computation. - // This is OK until the first error is encountered. - Status first_error_; - - // The saved stack trace from the point at which the first error occurred. - tensorflow::SavedStackTrace first_error_backtrace_; - - // The computation that operations are enqueued onto. - Computation computation_; - - // The client that the computation is created in. Not owned. - Client* client_; - - // Mode bit that indicates whether to die when a first error is encountered. - bool die_immediately_on_error_ = false; - - // The metadata to attach to each op. This is structured as a "modal"-like - // operation, in order to simplify client code (and not sprinkle this metadata - // throughout the TensorFlow op kernel implementations). - OpMetadata metadata_; - - // Sharding for this operator. This is structured as a "model"-like operation, - // in order to simplify client code, similar to metadata_. - tensorflow::gtl::optional sharding_; - - TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder); -}; - -template -ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { - return ConstantLiteral(*Literal::CreateR0(value)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR1( - tensorflow::gtl::ArraySlice values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, - NativeT value) { - Literal literal(ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {length})); - literal.PopulateWithValue(value); - return ConstantLiteral(literal); -} - -inline ComputationDataHandle ComputationBuilder::ConstantR1( - const tensorflow::core::Bitmap& values) { - return ConstantLiteral(*Literal::CreateR1(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2( - std::initializer_list> values) { - return ConstantLiteral(*Literal::CreateR2(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( - const Array& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantFromArray( - const Array& values) { - return ConstantLiteral(*Literal::CreateFromArray(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateFromArrayWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( - const Array2D& values) { - return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { - return ConstantLiteral( - *Literal::CreateR3FromArray3DWithLayout(values, layout)); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( - const Array3D& values) { - return ConstantFromArray(values); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( - 
const Array4D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); -} - -template -ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( - const Array4D& values) { - return ConstantFromArray(values); -} - -// RAII-style object: sets the current sharding assignment in builder on -// construction, and sets back to the previous assignment on destruction. -class ScopedShardingAssignment { - public: - ScopedShardingAssignment(xla::ComputationBuilder* builder, - tensorflow::gtl::optional sharding) - : builder_(builder), prev_sharding_(builder->sharding()) { - SetSharding(sharding); - } - - ~ScopedShardingAssignment() { SetSharding(prev_sharding_); } - - private: - void SetSharding(const tensorflow::gtl::optional& sharding) { - if (sharding.has_value()) { - builder_->SetSharding(sharding.value()); - } else { - builder_->ClearSharding(); - } - } - - xla::ComputationBuilder* const builder_; - tensorflow::gtl::optional prev_sharding_; - - TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc index 40f59eaa68e..2986d406001 100644 --- a/tensorflow/compiler/xla/client/global_data.cc +++ b/tensorflow/compiler/xla/client/global_data.cc @@ -31,7 +31,7 @@ GlobalData::~GlobalData() { *request.mutable_data() = handle_; UnregisterResponse response; VLOG(1) << "requesting to unregister " << handle_.ShortDebugString(); - tensorflow::Status s = parent_->Unregister(&request, &response); + Status s = parent_->Unregister(&request, &response); VLOG(1) << "done with request"; if (!s.ok()) { diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 59c4a53c05a..d49d959a6c8 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -22,8 +22,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", @@ -43,9 +41,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 63df449e0b3..a1d34796ccf 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -17,7 +17,8 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -27,28 +28,6 @@ limitations under the License. namespace xla { namespace { -using InstructionGenerator = - ComputationDataHandle (*)(ComputationBuilder*, const ComputationDataHandle&, - const ComputationDataHandle&); - -Computation CreateScalarComputation(const string& name, PrimitiveType type, - ComputationBuilder* builder, - InstructionGenerator generator) { - std::unique_ptr b; - if (type == PRED) { - b = builder->CreateSubBuilder(name); - } else { - b = builder->CreateSubBuilder( - tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type))); - } - - const Shape scalar = ShapeUtil::MakeShape(type, {}); - auto lhs = b->Parameter(0, scalar, "lhs"); - auto rhs = b->Parameter(1, scalar, "rhs"); - generator(b.get(), lhs, rhs); - return b->BuildAndNoteError(); -} - using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&); XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, @@ -71,71 +50,6 @@ XlaComputation CreateScalarComputation(const string& name, PrimitiveType type, } // namespace -Computation CreateScalarAddComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "add", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Add(lhs, rhs); }); -} - -Computation CreateScalarMultiplyComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "mul", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Mul(lhs, rhs); }); -} - -Computation CreateScalarGeComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "ge", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Ge(lhs, rhs); }); -} - -Computation CreateScalarMaxComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "max", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Max(lhs, rhs); }); -} - -Computation CreateScalarMinComputation(PrimitiveType type, - ComputationBuilder* builder) { - return CreateScalarComputation( - "min", type, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Min(lhs, rhs); }); -} - -Computation CreateScalarAndComputation(ComputationBuilder* builder) { - return CreateScalarComputation( - "and", PRED, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->And(lhs, rhs); }); -} - -Computation CreateScalarOrComputation(ComputationBuilder* builder) { - return CreateScalarComputation( - "or", PRED, builder, - [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->Or(lhs, rhs); }); -} - -StatusOr Any(const ComputationDataHandle& predicates, - ComputationBuilder* builder) { - auto f = builder->ConstantR0(false); - Computation logical_or = CreateScalarOrComputation(builder); - 
TF_ASSIGN_OR_RETURN(std::unique_ptr predicates_shape, - builder->GetShape(predicates)); - std::vector all_dimensions(ShapeUtil::Rank(*predicates_shape)); - std::iota(all_dimensions.begin(), all_dimensions.end(), 0); - return builder->Reduce(predicates, f, logical_or, all_dimensions); -} - XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder) { return CreateScalarComputation( diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index f4d3fc80159..64b6b7d6335 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -18,83 +18,38 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -// Creates a scalar add computation and returns it. -Computation CreateScalarAddComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar multiply computation and returns it. -Computation CreateScalarMultiplyComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar ge computation and returns it. -Computation CreateScalarGeComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar max computation and returns it. -Computation CreateScalarMaxComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar min computation and returns it. -Computation CreateScalarMinComputation(PrimitiveType type, - ComputationBuilder* builder); - -// Creates a scalar logical AND computation and returns it. -Computation CreateScalarAndComputation(ComputationBuilder* builder); - -// Creates a scalar logical OR computation and returns it. -Computation CreateScalarOrComputation(ComputationBuilder* builder); - -// Returns whether any predicate in "predicates" is set. -// -// Note: if predicates is zero-sized, Any() vacuously returns false. -StatusOr Any(const ComputationDataHandle& predicates, - ComputationBuilder* builder); - -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// // Creates a scalar add computation and returns it. XlaComputation CreateScalarAddComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar multiply computation and returns it. XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar ge computation and returns it. XlaComputation CreateScalarGeComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar max computation and returns it. XlaComputation CreateScalarMaxComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar min computation and returns it. XlaComputation CreateScalarMinComputation(PrimitiveType type, XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// + // Creates a scalar logical AND computation and returns it. 
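// For example (hypothetical usage with an XlaBuilder `b` and a PRED array
// `preds`): an "all elements true" reduction can combine this helper with
// Reduce, e.g.
//   b.Reduce(preds, b.ConstantR0<bool>(true), CreateScalarAndComputation(&b),
//            /*dimensions_to_reduce=*/{0});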
XlaComputation CreateScalarAndComputation(XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// // Creates a scalar logical OR computation and returns it. XlaComputation CreateScalarOrComputation(XlaBuilder* builder); -// TODO(b/74197823): This is a part of a NOT YET ready refactor. -// // Returns whether any predicate in "predicates" is set. // // Note: if predicates is zero-sized, Any() vacuously returns false. diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 311dc4bdd72..3380af9f303 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -15,8 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/testing.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -46,16 +45,14 @@ int64 DataSizeOfShape(const Shape& shape) { return total_size; } -// Create a ComputationDataHandle for an op what generates fake data with the -// given shape. -ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape, - ComputationBuilder* builder) { +// Creates a XlaOp for an op what generates fake data with the given shape. +XlaOp BuildFakeDataOpOnDevice(const Shape& shape, XlaBuilder* builder) { if (ShapeUtil::IsArray(shape)) { return builder->Broadcast( builder->ConstantLiteral(Literal::One(shape.element_type())), AsInt64Slice(shape.dimensions())); } - std::vector parts; + std::vector parts; for (const Shape& s : shape.tuple_shapes()) { parts.push_back(BuildFakeDataOpOnDevice(s, builder)); } @@ -64,11 +61,10 @@ ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape, std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, Client* client) { - ComputationBuilder b( - client, + XlaBuilder b( tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); BuildFakeDataOpOnDevice(shape, &b); - Computation computation = b.Build().ConsumeValueOrDie(); + XlaComputation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape; @@ -96,21 +92,6 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, return MakeFakeDataViaDeviceOrDie(shape, client); } -std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client) { - auto program_shape = - client->GetComputationShape(computation).ConsumeValueOrDie(); - - // For every (unbound) parameter that the computation wants, we manufacture - // some arbitrary data so that we can invoke the computation. - std::vector> fake_arguments; - for (const Shape& parameter : program_shape->parameters()) { - fake_arguments.push_back(MakeFakeDataOrDie(parameter, client)); - } - - return fake_arguments; -} - std::vector> MakeFakeArgumentsOrDie( const XlaComputation& computation, Client* client) { CHECK(computation.proto().has_program_shape()) diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index 1dc2622972d..dc613099e2b 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -20,7 +20,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -33,12 +32,6 @@ namespace xla { std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client); -// Returns vector of GlobalData handles of fake data (created using -// MakeFakeDataOrDie) that are correctly shaped arguments for the given -// computation. -std::vector> MakeFakeArgumentsOrDie( - const Computation& computation, Client* client); - // Returns vector of GlobalData handles of fake data (created using // MakeFakeDataOrDie) that are correctly shaped arguments for the given // xla computation. diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 30594243dcf..a7c55c6b2b7 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -24,8 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/status_macros.h" -namespace se = ::perftools::gputools; - using xla::source_map_util::InvalidParameterArgument; namespace xla { @@ -50,30 +48,52 @@ LocalExecutable::LocalExecutable(std::unique_ptr executable, << "Must have a valid device ordinal that the executable was built for."; } -tensorflow::Status LocalExecutable::ValidateExecutionOptions( +Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend) { - const ComputationLayout& computation_layout = - executable_->module_config().entry_computation_layout(); + const ComputationLayout& host_computation_layout = + executable_->module_config().host_entry_computation_layout(); + const ComputationLayout& device_computation_layout = + executable_->module_config().device_entry_computation_layout(); // Check argument number, shapes, and layouts. 
- if (arguments.size() != computation_layout.parameter_count()) { + if (arguments.size() != host_computation_layout.parameter_count()) { return InvalidArgument( "invalid number of arguments for computation: expected %d, got %zu", - computation_layout.parameter_count(), arguments.size()); + host_computation_layout.parameter_count(), arguments.size()); + } + if (arguments.size() != device_computation_layout.parameter_count()) { + return InvalidArgument( + "invalid number of arguments for computation: expected %d, got %zu", + device_computation_layout.parameter_count(), arguments.size()); } for (int i = 0; i < arguments.size(); ++i) { - if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( + if (!host_computation_layout.parameter_layout(i).MatchesLayoutInShape( arguments[i]->on_host_shape())) { return InvalidParameterArgument( executable_.get(), i, - "Argument does not match shape or layout of computation parameter " + "Argument does not match host shape or layout of computation " + "parameter " "%d: want %s, got %s", i, - ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) + ShapeUtil::HumanString( + host_computation_layout.parameter_layout(i).shape()) .c_str(), ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); } + if (!device_computation_layout.parameter_layout(i).MatchesLayoutInShape( + arguments[i]->on_device_shape())) { + return InvalidParameterArgument( + executable_.get(), i, + "Argument does not match device shape or layout of computation " + "parameter " + "%d: want %s, got %s", + i, + ShapeUtil::HumanString( + device_computation_layout.parameter_layout(i).shape()) + .c_str(), + ShapeUtil::HumanString(arguments[i]->on_device_shape()).c_str()); + } } if (run_options.stream() != nullptr) { @@ -136,7 +156,7 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( return Status::OK(); } -StatusOr> LocalExecutable::Run( +StatusOr LocalExecutable::Run( const tensorflow::gtl::ArraySlice arguments, ExecutableRunOptions run_options) { TF_RETURN_IF_ERROR( @@ -168,31 +188,26 @@ StatusOr> LocalExecutable::Run( if (executable_->dumping()) { return ExecuteAndDump(&service_options, arguments); } - TF_ASSIGN_OR_RETURN( - std::unique_ptr result, - executable_->ExecuteOnStreamWrapper( - &service_options, run_options.execution_profile(), arguments)); - - return MakeUnique(std::move(*result), - run_options.allocator()); + return executable_->ExecuteOnStreamWrapper( + &service_options, run_options.execution_profile(), arguments); } -StatusOr> LocalExecutable::ExecuteAndDump( +StatusOr LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments) { executable_->session_module()->set_execution_platform( backend_->platform()->Name()); TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module())); TF_ASSIGN_OR_RETURN( - std::unique_ptr result, + ScopedShapedBuffer result, executable_->ExecuteOnStream(run_options, arguments, /*hlo_execution_profile=*/nullptr)); - TF_RETURN_IF_ERROR(RecordResult(result.get(), executable_->session_module())); + TF_RETURN_IF_ERROR(RecordResult(&result, executable_->session_module())); TF_RETURN_IF_ERROR(executable_->DumpSessionModule()); - return ScopedShapedBuffer::MakeScoped(result.get(), run_options->allocator()); + return std::move(result); } -tensorflow::Status LocalExecutable::RecordArguments( +Status LocalExecutable::RecordArguments( const tensorflow::gtl::ArraySlice arguments, SessionModule* session_module) { 
session_module->clear_arguments(); @@ -204,8 +219,8 @@ tensorflow::Status LocalExecutable::RecordArguments( return Status::OK(); } -tensorflow::Status LocalExecutable::RecordResult( - const ShapedBuffer* result, SessionModule* session_module) { +Status LocalExecutable::RecordResult(const ShapedBuffer* result, + SessionModule* session_module) { session_module->clear_result(); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*result)); @@ -246,25 +261,6 @@ Backend* LocalClient::mutable_backend() { return local_service_->mutable_backend(); } -StatusOr> LocalClient::Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options) { - ExecutableBuildOptions updated_options = options; - if (options.device_ordinal() == -1) { - updated_options.set_device_ordinal(default_device_ordinal()); - VLOG(3) << "Set device ordinal to default value of: " - << updated_options.device_ordinal(); - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - local_service_->CompileExecutable(computation.handle(), argument_layouts, - updated_options)); - return WrapUnique(new LocalExecutable(std::move(executable), - local_service_->mutable_backend(), - updated_options)); -} - StatusOr> LocalClient::Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -283,9 +279,9 @@ StatusOr> LocalClient::Compile( updated_options)); } -StatusOr> -LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, - DeviceMemoryAllocator* allocator) { +StatusOr LocalClient::LiteralToShapedBuffer( + const Literal& literal, int device_ordinal, + DeviceMemoryAllocator* allocator) { if (allocator == nullptr) { allocator = backend().memory_allocator(); } @@ -295,7 +291,7 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - executor, literal, *scoped_buffer)); + executor, literal, scoped_buffer)); return std::move(scoped_buffer); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 98ee7c62c94..d63d4ec7f37 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -19,8 +19,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -38,19 +38,10 @@ class LocalExecutable { public: // Run the compiled computation with the given arguments and options and // return the result. - StatusOr> Run( + StatusOr Run( const tensorflow::gtl::ArraySlice arguments, ExecutableRunOptions run_options); - // Return the layout (contained in a shape) of the result produced by the - // computation. - const Shape& result_layout() const { - return executable_->module_config() - .entry_computation_layout() - .result_layout() - .shape(); - } - // Return the options used to build the executable. 
const ExecutableBuildOptions& build_options() const { return build_options_; } @@ -67,25 +58,25 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. - tensorflow::Status ValidateExecutionOptions( + Status ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. - StatusOr> ExecuteAndDump( + StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments); // Records the arguments used to invoke the computation in a SessionModule // proto. - tensorflow::Status RecordArguments( + Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, SessionModule* session_module); // Records the result of the computation in a SessionModule proto. - tensorflow::Status RecordResult(const ShapedBuffer* result, - SessionModule* session_module); + Status RecordResult(const ShapedBuffer* result, + SessionModule* session_module); // Returns a literal containing the contents of the given ShapedBuffer. StatusOr> LiteralFromShapedBuffer( @@ -116,17 +107,8 @@ class LocalClient : public Client { LocalClient(const LocalClient&) = delete; void operator=(const LocalClient&) = delete; - // Build and return a LocalExecutable object. The executable is compiled using - // the given argument layouts and options. - StatusOr> Compile( - const Computation& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -136,7 +118,7 @@ class LocalClient : public Client { // ScopedShapedBuffer. If non-null the given memory allocator is used for // device memory allocation. If null, the default memory allocator for the // device is used. - StatusOr> LiteralToShapedBuffer( + StatusOr LiteralToShapedBuffer( const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator = nullptr); @@ -167,7 +149,7 @@ class LocalClient : public Client { StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); // Returns the platform that the underlying service targets. - perftools::gputools::Platform* platform() const; + se::Platform* platform() const; // Returns the number of devices on the system of the service platform // type. 
Not all devices may be supported by the service (see diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index 31fa1241ee4..0d6e207971e 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -31,9 +31,9 @@ cc_library( hdrs = ["xla_computation.h"], deps = [ "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 7ccdc2ded2c..ae506317c2e 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -57,16 +57,6 @@ bool CanBeRoot(HloOpcode opcode) { } } -StatusOr> GetOperandShapes( - tensorflow::gtl::ArraySlice operands) { - std::vector operand_shapes; - for (const XlaOp& operand : operands) { - TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape()); - operand_shapes.push_back(shape); - } - return operand_shapes; -} - } // namespace StatusOr XlaBuilder::GetShape(const XlaOp& op) const { @@ -76,12 +66,14 @@ StatusOr XlaBuilder::GetShape(const XlaOp& op) const { return instr->shape(); } -StatusOr XlaOp::GetShape() const { - if (builder_ == nullptr) { - return InvalidArgument( - "cannot GetShape for an invalid XlaOp with handle %lld", handle()); +StatusOr> XlaBuilder::GetOperandShapes( + tensorflow::gtl::ArraySlice operands) const { + std::vector operand_shapes; + for (const XlaOp& operand : operands) { + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + operand_shapes.push_back(shape); } - return builder_->GetShape(*this); + return operand_shapes; } XlaBuilder::XlaBuilder(const string& computation_name) @@ -286,7 +278,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, const XlaOp& operand) { TF_RETURN_IF_ERROR(first_error_); - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); CHECK(ShapeUtil::IsScalar(operand_shape) || ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)); @@ -325,7 +317,7 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferUnaryOpShape(unop, operand_shape)); return AddInstruction(std::move(instr), unop, {operand}); @@ -337,8 +329,8 @@ XlaOp XlaBuilder::BinaryOp( tensorflow::gtl::ArraySlice broadcast_dimensions) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferBinaryOpShape( binop, lhs_shape, rhs_shape, broadcast_dimensions)); @@ -374,12 +366,12 @@ XlaOp XlaBuilder::BinaryOp( updated_rhs = !should_broadcast_lhs ? 
broadcasted_operand : rhs; } - TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, updated_lhs.GetShape()); + TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs)); if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) { TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(instr.shape(), updated_lhs)); } - TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, updated_rhs.GetShape()); + TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs)); if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) { TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(instr.shape(), updated_rhs)); @@ -393,9 +385,9 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs, const XlaOp& ehs) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape()); - TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, ehs.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); + TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); + TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs)); TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferTernaryOpShape( triop, lhs_shape, rhs_shape, ehs_shape)); @@ -437,7 +429,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs, return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions); } -XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) { +XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; *instr.mutable_shape() = literal.shape(); @@ -485,7 +477,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape, XlaOp XlaBuilder::Broadcast( const XlaOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( const Shape& shape, ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes)); @@ -633,7 +625,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN(const Shape& shape, ShapeInference::InferReshapeShape( operand_shape, dimensions, new_sizes)); @@ -647,7 +639,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand, XlaOp XlaBuilder::Reshape(const XlaOp& operand, tensorflow::gtl::ArraySlice new_sizes) { return NoteErrorOrReturn([&]() -> StatusOr { - TF_ASSIGN_OR_RETURN(auto shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand)); std::vector dimensions(shape.dimensions_size()); std::iota(dimensions.begin(), dimensions.end(), 0); return Reshape(operand, dimensions, new_sizes); @@ -1002,7 +994,7 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type, const tensorflow::gtl::ArraySlice fft_length) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferFftShape(operand_shape, fft_type, 
fft_length)); @@ -1173,6 +1165,10 @@ XlaOp XlaBuilder::Exp(const XlaOp& operand) { return UnaryOp(HloOpcode::kExp, operand); } +XlaOp XlaBuilder::Expm1(const XlaOp& operand) { + return UnaryOp(HloOpcode::kExpm1, operand); +} + XlaOp XlaBuilder::Floor(const XlaOp& operand) { return UnaryOp(HloOpcode::kFloor, operand); } @@ -1189,10 +1185,18 @@ XlaOp XlaBuilder::Log(const XlaOp& operand) { return UnaryOp(HloOpcode::kLog, operand); } +XlaOp XlaBuilder::Log1p(const XlaOp& operand) { + return UnaryOp(HloOpcode::kLog1p, operand); +} + XlaOp XlaBuilder::Sign(const XlaOp& operand) { return UnaryOp(HloOpcode::kSign, operand); } +XlaOp XlaBuilder::Clz(const XlaOp& operand) { + return UnaryOp(HloOpcode::kClz, operand); +} + XlaOp XlaBuilder::Cos(const XlaOp& operand) { return UnaryOp(HloOpcode::kCos, operand); } @@ -1221,7 +1225,7 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand, tensorflow::gtl::ArraySlice permutation) { return NoteErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape& operand_shape, operand.GetShape()); + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferTransposeShape(operand_shape, permutation)); @@ -1944,11 +1948,18 @@ StatusOr XlaBuilder::LookUpInstruction( const XlaOp& op) const { TF_RETURN_IF_ERROR(first_error_); + if (op.builder_ == nullptr) { + return InvalidArgument( + "invalid XlaOp with handle %lld; the builder of this op is freed", + op.handle()); + } if (op.builder_ != this) { - return InvalidArgument("invalid XlaOp with handle %lld", op.handle()); + return InvalidArgument( + "XlaOp with handle %lld is built by builder '%s', but is trying to use " + "it in builder '%s'", + op.handle(), op.builder_->name().c_str(), this->name().c_str()); } - TF_RET_CHECK(op.builder_ == this); if (op.handle() >= instructions_.size() || op.handle() < 0) { return InvalidArgument("no XlaOp value %lld", op.handle()); } diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 1f7c731064d..2b3013a91c4 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -13,10 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO(b/74197823): Replace computation_builder.h with this file. -// -// This is NOT YET ready to use. - #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_CLIENT_XLA_BUILDER_H_ @@ -48,20 +44,32 @@ class XlaBuilder; // This represents an instruction that has been enqueued using the XlaBuilder. // This is used to pass to subsequent computations that depends upon the // instruction as an operand. -// -// TODO(b/74197823): Replace xla::ComputationDataHandle with this one. 
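+// Two XlaOps compare equal only when they carry the same handle and were
+// created by the same XlaBuilder (see operator== below).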
class XlaOp { public: XlaOp() : handle_(0), builder_(nullptr) {} ~XlaOp() {} - StatusOr GetShape() const; + const XlaBuilder* builder() const { return builder_; } + + bool operator==(const XlaOp& rhs) const { + return handle_ == rhs.handle_ && builder_ == rhs.builder_; + } + + bool operator!=(const XlaOp& rhs) const { + return handle_ != rhs.handle_ || builder_ != rhs.builder_; + } + + friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) { + out << op.handle(); + return out; + } private: XlaOp(int64 handle, XlaBuilder* builder) : handle_(handle), builder_(builder) {} int64 handle() const { return handle_; } + friend class XlaBuilder; int64 handle_; @@ -71,8 +79,6 @@ class XlaOp { // A convenient interface for building up computations. // // Thread-compatible. -// -// TODO(b/74197823): Replace xla::ComputationBuilder with this one. class XlaBuilder { public: // computation_name: name to use for the built computation. @@ -123,7 +129,7 @@ class XlaBuilder { // Enqueues a constant with the value of the given literal onto the // computation. - XlaOp ConstantLiteral(const Literal& literal); + XlaOp ConstantLiteral(const LiteralSlice& literal); // Enqueues a constant onto the computation. Methods are templated on the // native host type (NativeT) which corresponds to a specific XLA @@ -555,6 +561,9 @@ class XlaBuilder { // Enqueues an exp instruction onto the computation. XlaOp Exp(const XlaOp& operand); + // Enqueues an expm1 instruction onto the computation. + XlaOp Expm1(const XlaOp& operand); + // Enqueues a floor instruction onto the computation. XlaOp Floor(const XlaOp& operand); @@ -568,9 +577,15 @@ class XlaBuilder { // Enqueues an log instruction (natural logarithm) onto the computation. XlaOp Log(const XlaOp& operand); + // Enqueues an log1p instruction (log(x+1)) onto the computation. + XlaOp Log1p(const XlaOp& operand); + // Enqueues a sign instruction onto the computation. XlaOp Sign(const XlaOp& operand); + // Enqueues a count leading zeros instruction onto the computation. + XlaOp Clz(const XlaOp& operand); + // Enqueues a cosine instruction onto the computation. XlaOp Cos(const XlaOp& operand); @@ -828,6 +843,10 @@ class XlaBuilder { // computation and fills the root_id in the pointer. StatusOr GetProgramShape(int64* root_id) const; + // Returns shapes for the operands. + StatusOr> GetOperandShapes( + tensorflow::gtl::ArraySlice operands) const; + // A visitor which checks whether an operation is a compile-time constant, // meaning that it doesn't depend on any parameters, or on any stateful // operation such as `RngNormal` or `Infeed`. The visitor walks the @@ -962,8 +981,6 @@ XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D& values) { // RAII-style object: sets the current sharding assignment in builder on // construction, and sets back to the previous assignment on destruction. -// -// TODO(b/74197823): This is a part of a NOT YET ready refactor. 
class XlaScopedShardingAssignment { public: XlaScopedShardingAssignment(xla::XlaBuilder* builder, diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc index ce984564d01..2df3ea3af0d 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder_test.cc @@ -76,7 +76,7 @@ TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) { auto y = b.Parameter(1, y_shape, "y"); auto add = b.Add(x, y, /*broadcast_dimensions=*/{0, 1}); - TF_ASSERT_OK_AND_ASSIGN(auto add_shape, add.GetShape()); + TF_ASSERT_OK_AND_ASSIGN(auto add_shape, b.GetShape(add)); EXPECT_TRUE(ShapeUtil::Equal(add_shape, x_shape)); TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b)); @@ -188,8 +188,10 @@ TEST_F(XlaBuilderTest, OperandFromWrongBuilder) { builder.Add(p0, p0); auto statusor = builder.Build(); ASSERT_FALSE(statusor.ok()); - EXPECT_THAT(statusor.status().error_message(), - HasSubstr("Do not add XlaOp from builder b1 to builder main")); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "built by builder 'b1', but is trying to use it in builder 'main'")); } TEST_F(XlaBuilderTest, ReshapeDefaultOrder) { diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc index a6752c60102..72e3935696e 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc @@ -17,7 +17,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { @@ -26,4 +28,13 @@ StatusOr XlaComputation::GetProgramShape() const { return proto_.program_shape(); } +StatusOr> XlaComputation::Snapshot() const { + if (IsNull()) { + return InvalidArgument("Computation is invalid."); + } + auto session = MakeUnique(); + *session->mutable_hlo()->mutable_hlo_module() = proto_; + return std::move(session); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h index 7ad212aa24c..0ffba208b1f 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h @@ -25,8 +25,6 @@ limitations under the License. namespace xla { // The computation graph that the user builds up with the XlaBuilder. -// -// TODO(b/74197823): Replace xla::Computation with this one. class XlaComputation { public: XlaComputation() : unique_id_(-1) {} @@ -48,6 +46,10 @@ class XlaComputation { const HloModuleProto& proto() const { return proto_; } + // Requests that we snapshot the computation into a serializable protocol + // buffer form. + StatusOr> Snapshot() const; + // Returns true if this object is a null Computation. bool IsNull() const { return unique_id_ == -1; } diff --git a/tensorflow/compiler/xla/device_util.h b/tensorflow/compiler/xla/device_util.h index 23a622b1ad0..1a51fdee680 100644 --- a/tensorflow/compiler/xla/device_util.h +++ b/tensorflow/compiler/xla/device_util.h @@ -29,7 +29,7 @@ namespace xla { // Returns a string that represents the device in terms of platform and ordinal; // e.g. 
the first CUDA device will be "cuda:0" -string DeviceIdentifier(perftools::gputools::StreamExecutor* stream_exec) { +string DeviceIdentifier(se::StreamExecutor* stream_exec) { return tensorflow::strings::StrCat(stream_exec->platform()->Name(), ":", stream_exec->device_ordinal()); } diff --git a/tensorflow/compiler/xla/error_spec.h b/tensorflow/compiler/xla/error_spec.h new file mode 100644 index 00000000000..a1463aa1594 --- /dev/null +++ b/tensorflow/compiler/xla/error_spec.h @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ +#define TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ + +namespace xla { + +// Structure describing permissible absolute and relative error bounds. +struct ErrorSpec { + explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) + : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} + + float abs; // Absolute error bound. + float rel; // Relative error bound. + + // If relaxed_nans is true then any result is valid if we are expecting NaNs. + // In effect, this allows the tested operation to produce incorrect results + // for inputs outside its mathematical domain. + bool relaxed_nans; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_ERROR_SPEC_H_ diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 1700c977189..a472747bd17 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -36,26 +36,15 @@ DeviceMemoryAllocator* ExecutableRunOptions::allocator() const { } ExecutableRunOptions& ExecutableRunOptions::set_stream( - perftools::gputools::Stream* stream) { + stream_executor::Stream* stream) { stream_ = stream; return *this; } -perftools::gputools::Stream* ExecutableRunOptions::stream() const { +stream_executor::Stream* ExecutableRunOptions::stream() const { return stream_; } -ExecutableRunOptions& ExecutableRunOptions::set_inter_op_thread_pool( - tensorflow::thread::ThreadPool* inter_op_thread_pool) { - inter_op_thread_pool_ = inter_op_thread_pool; - return *this; -} - -tensorflow::thread::ThreadPool* ExecutableRunOptions::inter_op_thread_pool() - const { - return inter_op_thread_pool_; -} - ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool( const Eigen::ThreadPoolDevice* intra_op_thread_pool) { intra_op_thread_pool_ = intra_op_thread_pool; diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 2c1d9ffff10..416131be006 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -16,26 +16,27 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ -// Intentionally forward declared so that ExecutableRunOptions can be linked +// Pulls in the ::stream_executor -> ::xla::se namespace alias. +#include "tensorflow/compiler/xla/types.h" + +// These classes are forward declared so that ExecutableRunOptions can be linked // into an XLA-compiled binary without having to link all of the pointed-to // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't // need to be linked). -namespace perftools { -namespace gputools { +namespace stream_executor { class Stream; class Platform; -} -} +} // namespace stream_executor namespace tensorflow { namespace thread { class ThreadPool; -} -} +} // namespace thread +} // namespace tensorflow namespace Eigen { struct ThreadPoolDevice; -} +} // namespace Eigen namespace xla { @@ -61,14 +62,8 @@ class ExecutableRunOptions { // If set, this is the stream to run the computation on. The platform of the // stream must match the platform the executable was built for. A value of // nullptr indicates the option has not been set. - ExecutableRunOptions& set_stream(perftools::gputools::Stream* stream); - perftools::gputools::Stream* stream() const; - - // Sets the thread pool on which to run parallel CPU backend - // computations. Does not take ownership. - ExecutableRunOptions& set_inter_op_thread_pool( - tensorflow::thread::ThreadPool* inter_op_thread_pool); - tensorflow::thread::ThreadPool* inter_op_thread_pool() const; + ExecutableRunOptions& set_stream(stream_executor::Stream* stream); + stream_executor::Stream* stream() const; // Sets the thread pool device on which to run Eigen subcomputations. // Does not take ownership. @@ -91,8 +86,7 @@ class ExecutableRunOptions { DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; DeviceAssignment* device_assignment_ = nullptr; - perftools::gputools::Stream* stream_ = nullptr; - tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr; + stream_executor::Stream* stream_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index fdc4bbdd8b1..a76fdcda250 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -139,8 +140,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::SetToDefaultLayout(program_shape->mutable_result()); } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutInShape( - const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutInShape(const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape. 
if (shape.has_layout()) { @@ -149,12 +149,12 @@ Layout CreateDefaultLayoutForRank(int64 rank) { for (auto& element_shape : shape.tuple_shapes()) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } - return tensorflow::Status::OK(); + return Status::OK(); } else if (ShapeUtil::IsOpaque(shape)) { if (shape.has_layout()) { return InvalidArgument("opaque should not have a layout field"); } - return tensorflow::Status::OK(); + return Status::OK(); } else { // Array shape. if (!shape.has_layout()) { @@ -165,14 +165,14 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } -/* static */ tensorflow::Status LayoutUtil::ValidateLayoutForShape( - const Layout& layout, const Shape& shape) { +/* static */ Status LayoutUtil::ValidateLayoutForShape(const Layout& layout, + const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } if (ShapeUtil::IsOpaque(shape)) { - return tensorflow::Status::OK(); + return Status::OK(); } if (layout.format() == INVALID_FORMAT) { @@ -224,7 +224,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } } - return tensorflow::Status::OK(); + return Status::OK(); } /* static */ void LayoutUtil::ClearLayout(Shape* shape) { @@ -383,7 +383,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { namespace { // Internal helper for recursively copying layouts. -tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { +Status CopyLayoutInternal(const Shape& src, Shape* dst) { if (ShapeUtil::IsTuple(src) != ShapeUtil::IsTuple(*dst)) { return InvalidArgument( "cannot copy layout from shape: shape structure differs"); @@ -410,14 +410,13 @@ tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) { dst->clear_layout(); } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace /* static */ -tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, - Shape* dst) { +Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return CopyLayoutInternal(src, dst); } @@ -465,4 +464,25 @@ std::ostream& operator<<(std::ostream& out, const Layout& layout) { return out; } +/*static*/ size_t LayoutUtil::Hash(const Layout& layout) { + using tensorflow::hash; + using tensorflow::Hash64Combine; + + size_t hash_value = hash()(layout.format()); + + for (int64 minor_to_major : layout.minor_to_major()) { + hash_value = Hash64Combine(hash_value, hash()(minor_to_major)); + } + + for (int64 padded_dim : layout.padded_dimensions()) { + hash_value = Hash64Combine(hash_value, hash()(padded_dim)); + } + + hash_value = + Hash64Combine(hash_value, hash()(layout.padding_value())); + hash_value = Hash64Combine(hash_value, layout.max_sparse_elements()); + + return hash_value; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6c54eb2201b..d3d6a2cc940 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -20,9 +20,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -61,12 +61,12 @@ class LayoutUtil { static void SetToDefaultLayout(ProgramShape* program_shape); // Validates that the layout within the given shape is correct. 
- static tensorflow::Status ValidateLayoutInShape(const Shape& shape); + static Status ValidateLayoutInShape(const Shape& shape); // Validates that the provided layout satisfies invariants for the given // shape. - static tensorflow::Status ValidateLayoutForShape(const Layout& layout, - const Shape& shape); + static Status ValidateLayoutForShape(const Layout& layout, + const Shape& shape); // Clears the layout in the given Shape. After this function is called, // HasLayout will return false for the shape. @@ -179,8 +179,7 @@ class LayoutUtil { // tuples. 'src' and 'dst' need not be compatible but the two shapes must // have the same tuple structure (if any) and arrays must have the same // rank. within the shapes must have the same number of dimensions. - static tensorflow::Status CopyLayoutBetweenShapes(const Shape& src, - Shape* dst); + static Status CopyLayoutBetweenShapes(const Shape& src, Shape* dst); // Returns true if the layouts of lhs and rhs are equal, false // otherwise. Recursively compares layouts of tuples. @@ -195,6 +194,9 @@ class LayoutUtil { static bool AreDimensionsConsecutive(const Layout& layout, tensorflow::gtl::ArraySlice dims); + // Compute a hash for `layout`. + static size_t Hash(const Layout& layout); + private: TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 70ae95bf473..f42fb92359f 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -43,10 +43,16 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { #ifdef INTEL_MKL flags->set_xla_cpu_use_mkl_dnn(true); #endif // INTEL_MKL - flags->set_xla_gpu_max_kernel_unroll_factor(1); + flags->set_xla_gpu_max_kernel_unroll_factor(4); // Set cudnn batchnorm off by default; it does not provide a performance win // on average. flags->set_xla_gpu_use_cudnn_batchnorm(false); + + // Run all GPU work on one stream by default. Using multiple streams + // increases memory usage and we lack strong motivating benchmarks for tuning + // the heuristics needed to decide when to run on multiple streams. See + // b/77879207. + flags->set_xla_gpu_disable_multi_streaming(true); } // Allocates flag_values and flag_objects; this function must not be called more diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc new file mode 100644 index 00000000000..3696fdbe12e --- /dev/null +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -0,0 +1,739 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/literal_comparison.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" + +using tensorflow::strings::Appendf; +using tensorflow::strings::Printf; +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + +namespace xla { +namespace literal_comparison { +namespace { + +// Helper function for comparing a floating point type, FloatT, bitwise equal +// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT +// -- on miscompare, a nice error message is given in the AssertionFailure. +template +Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { + auto ulhs = tensorflow::bit_cast(lhs); + auto urhs = tensorflow::bit_cast(rhs); + auto lhs_double = static_cast(lhs); + auto rhs_double = static_cast(rhs); + if (ulhs != urhs) { + return InvalidArgument( + "floating values are not bitwise-equal; and equality testing " + "was requested: %s=%g=%a vs %s=%g=%a", + StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double, + StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double); + } + return Status::OK(); +} + +// Templated comparator that specializes for float equality comparison with the +// bitwise helper above (this is the un-specialized fallback, to just use the +// default gunit implementation). +template +Status CompareEqual(NativeT lhs, NativeT rhs) { + if (lhs == rhs) { + return Status::OK(); + } + return InvalidArgument("Expected equality of these values:\n %s\n %s", + StrCat(lhs).c_str(), StrCat(rhs).c_str()); +} + +// Specializations for floating types that do bitwise comparisons when equality +// comparison is requested. +template <> +Status CompareEqual(bfloat16 lhs, bfloat16 rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(Eigen::half lhs, Eigen::half rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(float lhs, float rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(double lhs, double rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> +Status CompareEqual(complex64 lhs, complex64 rhs) { + auto res = CompareEqual(lhs.real(), rhs.real()); + if (!res.ok()) { + return res; + } + return CompareEqual(lhs.imag(), rhs.imag()); +} + +// A recursive function which iterates through every index of expected and +// actual literal and compares their values elementwise. Returns true if all +// elements are equal. +template +Status Equal(LiteralSlice expected, LiteralSlice actual, + tensorflow::gtl::MutableArraySlice multi_index, + int64 dimension) { + if (dimension == expected.shape().dimensions_size()) { + NativeT expected_value = expected.Get(multi_index); + NativeT actual_value = actual.Get(multi_index); + return CompareEqual(expected_value, actual_value); + } + + Status result; + for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { + multi_index[dimension] = i; + result.Update(Equal(expected, actual, multi_index, dimension + 1)); + } + return result; +} + +// Gets the total element count. For tuples, this is not the count of tuple +// elements, but the sum of elements of each tuple element. 
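+// For example, a tuple shape ((F32[2,3], S32[4]), PRED[2]) yields 6 + 4 + 2 = 12.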
+int64 RecursiveElementCount(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); + int64 total = 0; + for (int64 i = 0; i < tuple_elements; ++i) { + total += RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); + } + return total; + } else { + return ShapeUtil::ElementsIn(shape); + } +} + +// Returns whether the actual and expected values are mismatched with respect to +// nans. 'relaxed_nans' is interpreted as in xla::ErrorSpec. +template +bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { + if (relaxed_nans) { + return !std::isnan(expected) && std::isnan(actual); + } else { + return std::isnan(expected) != std::isnan(actual); + } +} + +template <> +bool NanMismatch(complex64 expected, complex64 actual, + bool relaxed_nans) { + return NanMismatch(expected.real(), actual.real(), relaxed_nans) || + NanMismatch(expected.imag(), actual.imag(), relaxed_nans); +} + +template <> +bool NanMismatch(half expected, half actual, bool relaxed_nans) { + return NanMismatch(static_cast(expected), + static_cast(actual), relaxed_nans); +} + +// Converts the given floating-point value to a string. +template +string FpValueToString(NativeT value) { + return Printf("%8.4g", static_cast(value)); +} + +template <> +string FpValueToString(complex64 value) { + return Printf("%8.4g + %8.4fi", value.real(), value.imag()); +} + +// Returns the absolute value of the given floating point value. This function +// is used instead of std::abs directly in order to allow type-dependent +// implementations for NearComparator. +template +float FpAbsoluteValue(NativeT value) { + return std::abs(value); +} + +template <> +float FpAbsoluteValue(bfloat16 value) { + return FpAbsoluteValue(static_cast(value)); +} + +template <> +float FpAbsoluteValue(half value) { + return FpAbsoluteValue(static_cast(value)); +} + +// Helper class for comparing floating-point literals within an error bound. +template +class NearComparator { + public: + // Compares the two array literals elementwise and returns a comparison + // result. The comparison is ok() if all actual and expected elements are + // within the given error bound. In case of error, the status contains a + // detailed message about the discrepancy. + static Status Compare(const LiteralSlice& expected, + const LiteralSlice& actual, ErrorSpec error, + bool detailed_message, + const MiscompareCallback& miscompare_callback) { + NearComparator comparator(expected, actual, error, + detailed_message, miscompare_callback); + return comparator.Run(); + } + + private: + // Data structure encapsulating metadata about a single element mismatch. + struct Mismatch { + NativeT actual; + NativeT expected; + float rel_error; + float abs_error; + + // The linear index of the failure within the shape. This linear index is + // from the 'actual' literal. 
+ int64 linear_index; + + bool operator<(const Mismatch& other) const { + return rel_error < other.rel_error; + } + + string ToString(const Shape& shape) const { + return Printf( + "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g", + FpValueToString(actual).c_str(), FpValueToString(expected).c_str(), + Literal::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex(shape, + linear_index)) + .c_str(), + rel_error, abs_error); + } + }; + + NearComparator(const LiteralSlice& expected, const LiteralSlice& actual, + ErrorSpec error, bool detailed_message, + const MiscompareCallback& miscompare_callback) + : expected_(expected), + actual_(actual), + error_(error), + detailed_message_(detailed_message), + miscompare_callback_(miscompare_callback), + abs_value_buckets_(kAbsValueBucketBounds.size() - 1, {0, 0}), + abs_error_buckets_(kErrorBucketBounds.size(), 0), + rel_error_buckets_(kErrorBucketBounds.size(), 0) {} + + // Runs the comparison between expected and actual literals. + Status Run() { + VLOG(1) << "expected:"; + XLA_VLOG_LINES(1, ToStringTruncated(expected_)); + VLOG(1) << "actual:"; + XLA_VLOG_LINES(1, ToStringTruncated(actual_)); + + // If the shapes mismatch, we simply fail the expectation instead of + // printing out data, as it's a type error rather than a value error. + TF_RETURN_IF_ERROR(EqualShapes(expected_.shape(), actual_.shape())); + if (!ShapeUtil::IsArray(expected_.shape())) { + return InvalidArgument("Expected array shape; got %s.", + ShapeUtil::HumanString(expected_.shape()).c_str()); + } + + mismatches_ = Literal(ShapeUtil::ChangeElementType(actual_.shape(), PRED)); + mismatches_.PopulateWithValue(false); + + CompareLiterals(); + + if (num_mismatches_ == 0) { + return Status::OK(); + } else if (!VLOG_IS_ON(1) && miscompare_callback_ != nullptr) { + miscompare_callback_(expected_, actual_, mismatches_); + } + return InvalidArgument("%s", ErrorMessage().c_str()); + } + + // Insert the given absolute value into the absolute value bucket vector. The + // bounds of the buckets are given by kAbsValueBucketBounds. + void UpdateAbsValueBucket(NativeT value, bool is_mismatch) { + // Adjust the bucket containing the absolute values of the 'actual' + // elements. + const float abs_value = FpAbsoluteValue(value); + for (int i = 0; i < abs_value_buckets_.size(); ++i) { + if (i == abs_value_buckets_.size() - 1 || + (abs_value >= kAbsValueBucketBounds[i] && + abs_value < kAbsValueBucketBounds[i + 1])) { + // The first value of the pair is the count of elements in the bucket, + // the second is the count of mismatches in the bucket. + abs_value_buckets_[i].first++; + if (is_mismatch) { + abs_value_buckets_[i].second++; + } + return; + } + } + } + + // Insert the given error into the given error bucket vector. + void UpdateErrorBucket( + float error, tensorflow::gtl::MutableArraySlice error_buckets) { + CHECK_EQ(error_buckets.size(), kErrorBucketBounds.size()); + for (int i = 0; i < error_buckets.size(); ++i) { + if (error >= kErrorBucketBounds[i]) { + error_buckets[i]++; + } + } + } + + // Compares the two given elements from the expected and actual literals at + // the given literal_index and keeps track of various mismatch statistics. 
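+ // For example, with ErrorSpec(/*aabs=*/0.01, /*arel=*/0.05), expected=1.02 and
+ // actual=1.00 give abs_error=0.02 and rel_error~=0.0196; the absolute bound is
+ // exceeded but the relative bound is not, so the element is not recorded as a
+ // mismatch (a mismatch requires both bounds to be exceeded, or a NaN mismatch).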
+ void CompareValues(NativeT expected, NativeT actual, int64 linear_index) { + const bool is_nan_mismatch = + NanMismatch(expected, actual, error_.relaxed_nans); + float abs_error; + float rel_error; + if (actual == expected) { + abs_error = 0; + rel_error = 0; + } else if (is_nan_mismatch) { + num_nan_mismatches_++; + // A nan mismatch is considered to have infinite error. rel_error is used + // for sorting a std::set of the top mismatchs, and a nan value here will + // result in undefined behavior because nan's do not satisfy the strict + // weak ordering requirement of std containers. + abs_error = std::numeric_limits::infinity(); + rel_error = std::numeric_limits::infinity(); + } else { + abs_error = FpAbsoluteValue(actual - expected); + rel_error = abs_error / FpAbsoluteValue(expected); + } + const bool is_abs_mismatch = abs_error > error_.abs; + const bool is_rel_mismatch = rel_error > error_.rel; + const bool is_mismatch = + is_nan_mismatch || (is_abs_mismatch && is_rel_mismatch); + + // Update the error of the relative bucket only if the *absolute* error + // bound is exceeded and vice versa. + if (is_abs_mismatch) { + num_abs_mismatches_++; + UpdateErrorBucket(rel_error, &rel_error_buckets_); + } + if (is_rel_mismatch) { + num_rel_mismatches_++; + UpdateErrorBucket(abs_error, &abs_error_buckets_); + } + + UpdateAbsValueBucket(actual, is_mismatch); + + if (!is_mismatch) { + return; + } + + num_mismatches_++; + + // Keep track of the kTopRelativeErrorCount relative error mismatches. + if (top_rel_mismatches_.size() < kTopRelativeErrorCount || + rel_error > top_rel_mismatches_.begin()->rel_error) { + Mismatch mismatch = {actual, expected, rel_error, abs_error, + linear_index}; + top_rel_mismatches_.insert(mismatch); + if (top_rel_mismatches_.size() > kTopRelativeErrorCount) { + top_rel_mismatches_.erase(top_rel_mismatches_.begin()); + } + } + + mismatches_.data()[linear_index] = true; + } + + // Compares the two literals elementwise. + void CompareLiterals() { + // Fast path optimization for the case were layouts match. + if (LayoutUtil::Equal(actual_.shape().layout(), + expected_.shape().layout())) { + tensorflow::gtl::ArraySlice expected_data = + expected_.data(); + tensorflow::gtl::ArraySlice actual_data = + actual_.data(); + const int64 len = expected_data.size(); + for (int64 i = 0; i < len; ++i) { + CompareValues(expected_data[i], actual_data[i], i); + } + return; + } + std::vector multi_index(ShapeUtil::Rank(actual_.shape()), 0); + CompareLiteralsSlow(0, &multi_index); + } + + // Slow path for CompareLiterals when 'actual' and 'expected' literals have + // different layouts. In this case, multidimensional indices are constructed + // and indexed for each element. + void CompareLiteralsSlow(int64 dimension, std::vector* multi_index) { + if (dimension == multi_index->size()) { + CompareValues(expected_.Get(*multi_index), + actual_.Get(*multi_index), + IndexUtil::MultidimensionalIndexToLinearIndex( + actual_.shape(), *multi_index)); + } else { + for (int64 i = 0; i < expected_.shape().dimensions(dimension); ++i) { + (*multi_index)[dimension] = i; + CompareLiteralsSlow(dimension + 1, multi_index); + } + } + } + + // Returns an error message string with a detailed breakdown of the + // mismatches. Called after calling Run(). + string ErrorMessage() { + string out; + int64 element_count = ShapeUtil::ElementsIn(actual_.shape()); + + auto percent_string = [](float a, float b) { + float pct = b == 0.0 ? 
0.0 : 100.0 * a / b; + return Printf("%0.4f%%", pct); + }; + + Appendf(&out, + "\nMismatch count %lld (%s) in shape %s (%lld elements), abs bound " + "%g, rel bound %g\n", + num_mismatches_, + percent_string(num_mismatches_, element_count).c_str(), + ShapeUtil::HumanString(actual_.shape()).c_str(), + ShapeUtil::ElementsIn(actual_.shape()), error_.abs, error_.rel); + if (num_nan_mismatches_ > 0) { + StrAppend(&out, "nan mismatches ", num_nan_mismatches_, "\n"); + } + Appendf(&out, "Top relative error mismatches:\n"); + for (auto it = top_rel_mismatches_.rbegin(); + it != top_rel_mismatches_.rend(); ++it) { + StrAppend(&out, " ", it->ToString(actual_.shape()).c_str(), "\n"); + } + + if (!detailed_message_) { + return out; + } + + StrAppend(&out, "Absolute magnitude breakdown of actual values:\n"); + CHECK_EQ(abs_value_buckets_.size() + 1, kAbsValueBucketBounds.size()); + for (int i = 0; i < abs_value_buckets_.size(); ++i) { + const int64 bucket_size = abs_value_buckets_[i].first; + const int64 bucket_mismatches = abs_value_buckets_[i].second; + string mismatch_str = bucket_mismatches > 0 + ? Printf(", mismatches %lld", bucket_mismatches) + : ""; + Appendf(&out, " %-6g <= x < %-6g : %7lld (%9s)%s\n", + kAbsValueBucketBounds[i], kAbsValueBucketBounds[i + 1], + bucket_size, percent_string(bucket_size, element_count).c_str(), + mismatch_str.c_str()); + } + + auto print_accum_buckets = [&](const string& header, int64 total, + tensorflow::gtl::ArraySlice buckets) { + StrAppend(&out, header, ":\n"); + Appendf(&out, " < %-6g : %7lld (%s)\n", kErrorBucketBounds[0], + total - buckets[0], + percent_string(total - buckets[0], total).c_str()); + CHECK_EQ(buckets.size(), kErrorBucketBounds.size()); + for (int i = 0; i < kErrorBucketBounds.size(); ++i) { + Appendf(&out, " >= %-6g : %7lld (%s)\n", kErrorBucketBounds[i], + buckets[i], percent_string(buckets[i], total).c_str()); + } + }; + Appendf(&out, "Elements exceeding abs error bound %g: %lld (%s)\n", + error_.abs, num_abs_mismatches_, + percent_string(num_abs_mismatches_, element_count).c_str()); + print_accum_buckets( + "Relative error breakdown of elements exceeding abs error bound", + num_abs_mismatches_, rel_error_buckets_); + Appendf(&out, "Elements exceeding rel error bound %g: %lld (%s)\n", + error_.rel, num_rel_mismatches_, + percent_string(num_rel_mismatches_, element_count).c_str()); + print_accum_buckets( + "Absolute error breakdown of elements exceeding rel error bound", + num_rel_mismatches_, abs_error_buckets_); + return out; + } + + // 'actual' and 'expected' literals being compared. + LiteralSlice expected_; + LiteralSlice actual_; + + // The error bounds of the comparison. + ErrorSpec error_; + + // Whether to include detailed breakdown of mismatches in the error message. + bool detailed_message_; + + // Callback to invoke on miscompare. + MiscompareCallback miscompare_callback_; + + // Number of element element mismatches encountered so far. + int64 num_mismatches_ = 0; + + // Number of elements with a nan mismatch. + int64 num_nan_mismatches_ = 0; + + // Number of elements which exceed the absolute/relative error bound. + int64 num_abs_mismatches_ = 0; + int64 num_rel_mismatches_ = 0; + + // A Literal containing which elements did not match in the expected and + // actual literals. mismatches_ contains PREDs and is of the same sizes as + // the comparison literals. + Literal mismatches_; + + // The number of mismatches to report in the output, sorted by relative error + // magnitude. 
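+ // Only the kTopRelativeErrorCount worst relative-error mismatches are kept in
+ // top_rel_mismatches_ and printed by ErrorMessage().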
+ static constexpr int64 kTopRelativeErrorCount = 5; + + // The set of mismatches with the largest relative error. The size of this set + // is bounded by kTopRelativeErrorCount. + std::multiset top_rel_mismatches_; + + // Actual values are bucketed by absolute value. kAbsValueBucketBounds is the + // bounds of these buckets. abs_value_buckets_ contains a pair for each + // bucket: the element count and failure count. + static constexpr std::array kAbsValueBucketBounds = { + 0.0, 0.0001, 0.001, 0.01, 0.1, 1, std::numeric_limits::infinity()}; + std::vector> abs_value_buckets_; + + // Buckets for relative and absolute errors. The relative error buckets only + // contains those elements which exceed the *absolute* error bound, and vice + // versa. This makes it easy to see the effect of adjusting the relative (or + // absolute) error bound on the success of the comparison. kErrorBucketBounds + // are the lower bounds of the buckets in both vectors. The error buckets are + // a cumulative distribution so an error value may appear in more than one + // bucket. For example an error value of 0.003 may appear in the buckets + // bounded by 0.01, 0.1, and 1.0. + static constexpr std::array kErrorBucketBounds = {0.0001, 0.001, + 0.01, 0.1, 1}; + std::vector abs_error_buckets_; + std::vector rel_error_buckets_; +}; + +template +constexpr std::array NearComparator::kAbsValueBucketBounds; +template +constexpr std::array NearComparator::kErrorBucketBounds; + +// Helper function for comparing two literals for nearness. Handles tuple-shapes +// via recursion. shape_index is the ShapeIndex of expected (or actual) +// currently being compared. +Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback, + const ShapeIndex& shape_index) { + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + + if (ShapeUtil::IsTuple(expected.shape())) { + Status return_status; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + const auto expected_element = LiteralSlice(expected, {i}); + const auto actual_element = LiteralSlice(actual, {i}); + ShapeIndex element_index = shape_index; + element_index.push_back(i); + Status res = + NearHelper(expected_element, actual_element, error, detailed_message, + miscompare_callback, element_index); + if (!res.ok()) { + string err_message = Printf("\nArray at shape index %s%s", + element_index.ToString().c_str(), + res.error_message().c_str()); + if (return_status.ok()) { + return_status = res; + } else { + return_status = AppendStatus(return_status, res.error_message()); + } + } + } + if (!return_status.ok() && shape_index.empty()) { + // Emit a top-level error message containing the top-level shape in case + // of mismatch. 
+ int64 total_elements = RecursiveElementCount(actual.shape()); + return_status = InvalidArgument( + "\nMismatches in shape %s (%lld elements):\n%s", + ShapeUtil::HumanString(actual.shape()).c_str(), total_elements, + return_status.error_message().c_str()); + } + return return_status; + } + + if (ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())) { + switch (expected.shape().element_type()) { + case BF16: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F16: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F32: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case F64: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + case C64: + return NearComparator::Compare( + expected, actual, error, detailed_message, miscompare_callback); + break; + default: + LOG(FATAL) << "Unsupported primitive type in near comparator: " + << PrimitiveType_Name(expected.shape().element_type()) + << ". Must be floating-point type."; + } + } + + // Non-floating point literal. + return literal_comparison::Equal(expected, actual); +} + +} // namespace + +Status EqualShapes(const Shape& expected, const Shape& actual) { + if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { + return InvalidArgument("tupleness-mismatch! want: %s got %s", + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + if (ShapeUtil::IsTuple(expected)) { + if (ShapeUtil::TupleElementCount(expected) != + ShapeUtil::TupleElementCount(actual)) { + return InvalidArgument( + "want tuple element count: %lld got tuple element count: %lld", + ShapeUtil::TupleElementCount(expected), + ShapeUtil::TupleElementCount(actual)); + } + for (int i = 0; i < expected.tuple_shapes_size(); ++i) { + Status result = + EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + if (!result.ok()) { + return AppendStatus(result, StrCat("mismatch in tuple index", i)); + } + } + } else { + if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + return InvalidArgument("want rank of %s got rank of %s", + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + if (expected.element_type() != actual.element_type()) { + return InvalidArgument( + "mismatch in primitive type %s vs %s", + PrimitiveType_Name(expected.element_type()).c_str(), + PrimitiveType_Name(actual.element_type()).c_str()); + } + if (expected.dimensions_size() != actual.dimensions_size()) { + return InvalidArgument("want dimensions_size %d got dimensions_size %d", + expected.dimensions_size(), + actual.dimensions_size()); + } + for (int i = 0; i < expected.dimensions_size(); ++i) { + if (expected.dimensions(i) != actual.dimensions(i)) { + return InvalidArgument( + "mismatch in dimension #%d expected: %s actual: %s", i, + ShapeUtil::HumanString(expected).c_str(), + ShapeUtil::HumanString(actual).c_str()); + } + } + } + return Status::OK(); +} + +Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { + VLOG(1) << "expected:"; + XLA_VLOG_LINES(1, expected.ToString()); + VLOG(1) << "actual:"; + XLA_VLOG_LINES(1, actual.ToString()); + + TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape())); + std::vector multi_index(expected.shape().dimensions_size(), 0); + Status result; + switch 
(expected.shape().element_type()) { + case PRED: + result = Equal(expected, actual, &multi_index, 0); + break; + case U8: + result = Equal(expected, actual, &multi_index, 0); + break; + case S32: + result = Equal(expected, actual, &multi_index, 0); + break; + case S64: + result = Equal(expected, actual, &multi_index, 0); + break; + case U32: + result = Equal(expected, actual, &multi_index, 0); + break; + case U64: + result = Equal(expected, actual, &multi_index, 0); + break; + case BF16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F16: + result = Equal(expected, actual, &multi_index, 0); + break; + case F32: + result = Equal(expected, actual, &multi_index, 0); + break; + case F64: + result = Equal(expected, actual, &multi_index, 0); + break; + case C64: + result = Equal(expected, actual, &multi_index, 0); + break; + case TUPLE: { + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + result.Update( + Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}))); + } + break; + } + default: + LOG(FATAL) + << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " + << PrimitiveType_Name(expected.shape().element_type()); + } + + if (result.ok()) { + return Status::OK(); + } + + return AppendStatus(result, + tensorflow::strings::Printf("expected: %s\nactual: %s", + expected.ToString().c_str(), + actual.ToString().c_str())); +} + +Status Near(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback) { + return NearHelper(expected, actual, error, detailed_message, + miscompare_callback, + /*shape_index=*/{}); +} + +string ToStringTruncated(const LiteralSlice& literal) { + return RecursiveElementCount(literal.shape()) < 1000 + ? literal.ToString() + : "[TRUNCATED, Literal with more than 1000 values]"; +} + +} // namespace literal_comparison +} // namespace xla diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h new file mode 100644 index 00000000000..00a13e36193 --- /dev/null +++ b/tensorflow/compiler/xla/literal_comparison.h @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Library for comparing literals without taking a dependency on testing +// libraries. + +#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ +#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ + +#include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { +namespace literal_comparison { + +// Returns ok if the given shapes have the same rank, dimension sizes, and +// primitive types. +Status EqualShapes(const Shape& expected, const Shape& actual); + +// Returns ok if the expected and actual literals are (bitwise) equal for all +// elements in the literal. 
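+// Floating-point elements are compared bit-for-bit, so +0.0 vs. -0.0 and NaNs
+// with differing bit patterns count as mismatches.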
Also, asserts that the rank, dimensions sizes, and +// primitive type are equal. +Status Equal(const LiteralSlice& expected, const LiteralSlice& actual); + +using MiscompareCallback = + std::function; + +// Inspects whether the expected and actual literals are within the given error +// bound for all elements. Also, inspects whether the rank, dimensions sizes, +// and dimension bounds are equivalent. +// +// Tuples are matched recursively. +// +// When comparing tensors of non-floating-point type, this inspects for exact +// equality, ignoring the ErrorSpec. +// +// If the shape of the literals is neither a complex/floating-point tensor nor a +// tuple which contains a complex/floating-point tensor, Near() is equivalent to +// Equal(). We don't raise an error in this case, because we want to allow +// callers to call Near() even if they have no preconceptions about the shapes +// being compared. +// +// If detailed_message is true, then the error message in the assertion result +// will contain a more detailed breakdown of mismatches. +Status Near(const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error, bool detailed_message, + const MiscompareCallback& miscompare_callback); + +// Calling ToString on a literal with over 100 million elements takes around +// 3 minutes. The utility of printing a literal with >1000 elements is +// questionable, especially when writing the Literal proto to disk is orders +// of magnitude faster. +string ToStringTruncated(const LiteralSlice& literal); + +} // namespace literal_comparison +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index c315b4ff300..4c560767dc6 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -44,8 +45,16 @@ namespace { constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; -// Converts between little and big endian, assuming elements in the array are 16 -// bits long. +// Converts between little and big endian. +// +// Precondition: size % 2 == 0 (elements in the array are 16 bits long) +void ConvertEndianShort(string* bytes) { + CHECK_EQ(bytes->size() / 2, 0); + for (int64 i = 0; i < bytes->size(); i += 2) { + std::swap((*bytes)[i], (*bytes)[i + 1]); + } +} + void ConvertEndianShort(char* bytes, int64 size) { CHECK_EQ(size / 2, 0); for (int64 i = 0; i < size; i += 2) { @@ -53,8 +62,49 @@ void ConvertEndianShort(char* bytes, int64 size) { } } +// Return a literal with all arrays of type FromNativeT converted to type +// ToNativeT in the given literal. +template +std::unique_ptr ConvertType(LiteralSlice literal) { + // First construct shape of the result. 
+ Shape result_shape(literal.shape()); + ShapeUtil::ForEachMutableSubshape( + &result_shape, [](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == + primitive_util::NativeToPrimitiveType()) { + subshape->set_element_type( + primitive_util::NativeToPrimitiveType()); + } + }); + auto result = MakeUnique(result_shape); + + // Then copy over the data from 'literal' converting FromNativeT values to + // ToNativeT values as necessary. + ShapeUtil::ForEachSubshape( + literal.shape(), + [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + if (subshape.element_type() == + primitive_util::NativeToPrimitiveType()) { + auto src = literal.data(shape_index); + auto dest = result->data(shape_index); + for (int64 i = 0; i < src.size(); ++i) { + dest[i] = static_cast(src[i]); + } + } else { + TF_CHECK_OK(result->CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); + } + } + }); + return result; +} + } // namespace +LiteralBase::~LiteralBase() {} + std::ostream& operator<<(std::ostream& out, const Literal& literal) { out << literal.ToString(); return out; @@ -86,99 +136,89 @@ Literal::StrideConfig::StrideConfig( Literal::Literal(const Shape& shape) : Literal(shape, /*allocate_arrays=*/true) {} -Literal::Literal(const Shape& shape, bool allocate_arrays) - : shape_(shape), pieces_(shape), owns_buffers_(true) { - CHECK(LayoutUtil::HasLayout(shape)); - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; +void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { + if (ShapeUtil::IsTuple(shape)) { + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - const Shape& subshape = piece.subshape(); - if (ShapeUtil::IsArray(subshape)) { - if (allocate_arrays) { - if (LayoutUtil::IsSparseArray(subshape)) { - // For sparse arrays, the buffer must be of the size of the maximum - // number of sparse elements possible. - const int64 max_sparse_elements = - LayoutUtil::MaxSparseElements(subshape.layout()); - piece.set_buffer( - new char[max_sparse_elements * ShapeUtil::ByteSizeOfPrimitiveType( - subshape.element_type())]); - piece.set_sparse_indices(new SparseIndexArray( - max_sparse_elements, ShapeUtil::Rank(subshape))); - } else { - piece.set_buffer(new char[piece.size_bytes()]); - } + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); + + SetPiece(subshape, &child_piece, allocate_arrays); + + piece->emplace_back(std::move(child_piece)); + } + } else { + CHECK(ShapeUtil::IsArray(shape)); + if (allocate_arrays) { + if (LayoutUtil::IsSparseArray(shape)) { + // For sparse arrays, the buffer must be of the size of the maximum + // number of sparse elements possible. 
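+ // The element buffer is sized max_sparse_elements * bytes-per-element, and a
+ // parallel SparseIndexArray is allocated for the element indices.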
+ const int64 max_sparse_elements = + LayoutUtil::MaxSparseElements(shape.layout()); + piece->set_buffer( + new char[max_sparse_elements * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type())]); + piece->set_sparse_indices( + new SparseIndexArray(max_sparse_elements, ShapeUtil::Rank(shape))); } else { - piece.set_buffer(nullptr); + piece->set_buffer(new char[piece->size_bytes()]); } } } } -Literal::~Literal() { DeallocateBuffers(); } +Literal::Literal(const Shape& shape, bool allocate_arrays) + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(LayoutUtil::HasLayout(*shape_)); + root_piece_ = new Piece(); + root_piece_->set_subshape(shape_.get()); + CHECK(&root_piece_->subshape() == shape_.get()); + + SetPiece(*shape_, root_piece_, allocate_arrays); +} + +Literal::~Literal() { + if (root_piece_ != nullptr) { + DeallocateBuffers(); + delete root_piece_; + } +} void Literal::DeallocateBuffers() { - if (owns_buffers_) { - for (auto& pair : pieces_) { - Piece& piece = pair.second; - if (piece.buffer() != nullptr) { - delete[] piece.buffer(); - delete piece.sparse_indices(); - } - } - } + root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (piece->buffer() != nullptr) { + delete[] piece->buffer(); + delete piece->sparse_indices(); + } + }); } -Literal::Literal(Literal&& other) { - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } - owns_buffers_ = other.owns_buffers_; - - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); -} +Literal::Literal(Literal&& other) : LiteralBase() { *this = std::move(other); } Literal& Literal::operator=(Literal&& other) { - DeallocateBuffers(); - shape_ = std::move(other.shape_); - pieces_ = std::move(other.pieces_); - // We need to iterate through the pieces to set the subshape pointer - // properly. It must refer to subshapes within shape_. 
- for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); - } - owns_buffers_ = other.owns_buffers_; + DCHECK(&other.root_piece_->subshape() == other.shape_.get()); + using std::swap; + swap(shape_, other.shape_); + swap(root_piece_, other.root_piece_); + DCHECK(&root_piece_->subshape() == shape_.get()); - other.shape_ = ShapeUtil::MakeNil(); - other.pieces_ = ShapeTree(other.shape_); - other.piece({}).set_subshape(&other.shape_); return *this; } -std::unique_ptr Literal::CreateFromShape(const Shape& shape) { +std::unique_ptr LiteralBase::CreateFromShape(const Shape& shape) { auto literal = MakeUnique(shape); - for (auto& pair : literal->pieces_) { - Piece& piece = pair.second; - if (ShapeUtil::IsArray(piece.subshape())) { - memset(piece.untyped_data(), 0, piece.size_bytes()); - } - } + literal->root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* piece) { + if (ShapeUtil::IsArray(piece->subshape())) { + memset(piece->untyped_data(), 0, piece->size_bytes()); + } + }); return literal; } -const SparseIndexArray* Literal::sparse_indices( +const SparseIndexArray* LiteralBase::sparse_indices( const ShapeIndex& shape_index) const { return piece(shape_index).sparse_indices(); } @@ -193,9 +233,19 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); } +/* static */ std::unique_ptr Literal::ConvertBF16ToF32( + const LiteralSlice& bf16_literal) { + return ConvertType(bf16_literal); +} + +/* static */ std::unique_ptr Literal::ConvertF32ToBF16( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + template Status Literal::CopySliceFromInternal( - const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); @@ -255,7 +305,7 @@ Status Literal::CopySliceFromInternal( return Status::OK(); } -Status Literal::CopyElementFrom(const Literal& src_literal, +Status Literal::CopyElementFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice src_index, tensorflow::gtl::ArraySlice dest_index) { DCHECK_EQ(shape().element_type(), src_literal.shape().element_type()); @@ -284,22 +334,21 @@ std::vector Literal::DecomposeTuple() { elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), /*allocate_arrays=*/false)); Literal& element = elements.back(); - for (auto& pair : element.pieces_) { - const ShapeIndex& index = pair.first; - Piece& dest_piece = pair.second; - ShapeIndex src_index = {i}; - for (int64 j : index) { - src_index.push_back(j); - } - Piece& src_piece = piece(src_index); + element.root_piece_->ForEachMutableSubpiece( + [&](const ShapeIndex& index, Piece* dest_piece) { + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); - // Move the respective buffer and sparse indices over to the element - // Literal. - dest_piece.set_buffer(src_piece.buffer()); - src_piece.set_buffer(nullptr); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - src_piece.set_sparse_indices(nullptr); - } + // Move the respective buffer and sparse indices over to the element + // Literal. 
+ dest_piece->set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece->set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + }); } // Set this literal to be nil-shaped. *this = Literal(); @@ -342,7 +391,9 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFrom(const Literal::Piece& src) { +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { + CHECK(subshape_ != nullptr); + CHECK(src.subshape_ != nullptr); if (ShapeUtil::Equal(subshape(), src.subshape())) { // If the layouts are equal it's faster just to memcpy. memcpy(buffer(), src.buffer(), src.size_bytes()); @@ -379,7 +430,7 @@ Status Literal::Piece::CopyFrom(const Literal::Piece& src) { return Status::OK(); } -Status Literal::CopyFrom(const Literal& src_literal, +Status Literal::CopyFrom(const LiteralSlice& src_literal, const ShapeIndex& dest_shape_index, const ShapeIndex& src_shape_index) { const Shape& dest_subshape = @@ -392,36 +443,32 @@ Status Literal::CopyFrom(const Literal& src_literal, ShapeUtil::HumanString(dest_subshape).c_str(), ShapeUtil::HumanString(src_subshape).c_str()); } + return root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + if (!ShapeUtil::IsArray(piece->subshape())) { + return Status::OK(); + } - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - - // Determine if this index is in the part of this literal that we want to - // copy over from src_literal. - bool in_subtree_to_copy = true; - for (int i = 0; i < dest_shape_index.size(); ++i) { - if (index[i] != dest_shape_index[i]) { - in_subtree_to_copy = false; - break; - } - } - if (!in_subtree_to_copy) { - continue; - } - - // Construct the index of the corresponding piece in the source literal. - ShapeIndex src_piece_index = src_shape_index; - for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { - src_piece_index.push_back(index[i]); - } - - TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index))); - } - return Status::OK(); + // Determine if this index is in the part of this literal that we want + // to copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + return Status::OK(); + } + // Construct the index of the corresponding piece in the source literal. 
+ ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + return Status::OK(); + }); } Status Literal::MoveFrom(Literal&& src_literal, @@ -435,37 +482,32 @@ Status Literal::MoveFrom(Literal&& src_literal, ShapeUtil::HumanString(src_literal.shape()).c_str()); } - if (!(owns_buffers_ && src_literal.owns_buffers_)) { - return InvalidArgument( - "Source and destination literals must both own their buffers (ie, not " - "be views)"); - } + src_literal.root_piece_->ForEachSubpiece( + [&](const ShapeIndex& src_index, const Piece& src_piece) { + if (!ShapeUtil::IsArray(src_piece.subshape())) { + return; + } - for (auto& pair : src_literal.pieces_) { - const ShapeIndex& src_index = pair.first; - Piece& src_piece = pair.second; - if (!ShapeUtil::IsArray(src_piece.subshape())) { - continue; - } + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + }); - ShapeIndex dest_index = dest_shape_index; - for (int64 i : src_index) { - dest_index.push_back(i); - } - Piece& dest_piece = piece(dest_index); - delete[] dest_piece.buffer(); - dest_piece.set_buffer(src_piece.buffer()); - delete dest_piece.sparse_indices(); - dest_piece.set_sparse_indices(src_piece.sparse_indices()); - } + src_literal.shape_ = MakeUnique(ShapeUtil::MakeNil()); + delete src_literal.root_piece_; + src_literal.root_piece_ = new LiteralBase::Piece(); + src_literal.root_piece_->set_subshape(src_literal.shape_.get()); - src_literal.shape_ = ShapeUtil::MakeNil(); - src_literal.pieces_ = ShapeTree(src_literal.shape_); - src_literal.piece({}).set_subshape(&src_literal.shape_); return Status::OK(); } -Status Literal::CopySliceFrom(const Literal& src_literal, +Status Literal::CopySliceFrom(const LiteralSlice& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size) { @@ -734,7 +776,7 @@ void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { return CreateR2FromArray2D(*value); } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Layout& new_layout, const ShapeIndex& shape_index) const { // Create new shape with 'new_layout' set at the given shape index. Shape new_shape = shape(); @@ -746,7 +788,7 @@ std::unique_ptr Literal::Relayout( return result; } -std::unique_ptr Literal::Relayout( +std::unique_ptr LiteralBase::Relayout( const Shape& shape_with_layout) const { CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) @@ -765,7 +807,7 @@ std::unique_ptr Literal::Relayout( return result; } -StatusOr> Literal::Reshape( +StatusOr> LiteralBase::Reshape( tensorflow::gtl::ArraySlice dimensions) const { if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); @@ -779,7 +821,8 @@ StatusOr> Literal::Reshape( } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. 
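
Editor's note: CopyFrom now visits every piece of the destination, keeps only pieces whose index starts with dest_shape_index, and maps each one to the matching source piece by appending the remaining index components to src_shape_index. A small standalone sketch of that prefix test and index rewriting, with ShapeIndex modeled as a plain vector:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

using ShapeIndex = std::vector<int>;  // stand-in for xla::ShapeIndex

// Returns true if `index` lies inside the subtree rooted at `dest_prefix`.
bool InSubtreeToCopy(const ShapeIndex& index, const ShapeIndex& dest_prefix) {
  if (index.size() < dest_prefix.size()) return false;
  for (size_t i = 0; i < dest_prefix.size(); ++i) {
    if (index[i] != dest_prefix[i]) return false;
  }
  return true;
}

// Maps a destination piece index to the corresponding source piece index.
ShapeIndex SourceIndexFor(const ShapeIndex& index, const ShapeIndex& dest_prefix,
                          const ShapeIndex& src_prefix) {
  ShapeIndex src = src_prefix;
  for (size_t i = dest_prefix.size(); i < index.size(); ++i) src.push_back(index[i]);
  return src;
}

int main() {
  ShapeIndex dest_prefix = {1}, src_prefix = {0, 2};
  ShapeIndex piece_index = {1, 3};
  if (InSubtreeToCopy(piece_index, dest_prefix)) {
    ShapeIndex src = SourceIndexFor(piece_index, dest_prefix, src_prefix);
    std::printf("copy piece from source index {%d,%d,%d}\n", src[0], src[1], src[2]);
  }
}
```
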
- output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions); + *output->mutable_shape_do_not_use() = + ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); @@ -793,7 +836,79 @@ StatusOr> Literal::Reshape( return std::move(output); } -std::unique_ptr Literal::Transpose( +/* static */ std::unique_ptr Literal::ReshapeSlice( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, + const LiteralSlice& literal) { + int64 new_num_elements = 1; + for (int64 i = 0; i < new_dimensions.size(); ++i) { + new_num_elements *= new_dimensions[i]; + } + CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); + CHECK_EQ(new_dimensions.size(), minor_to_major.size()); + + auto new_literal = MakeUnique( + ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); + + // Create a new shape with the given minor-to-major layout. This shape is used + // solely for converting linear address to multi-dimensional addresses when + // writing elements to the new literal. + Shape shape_with_layout = new_literal->shape(); + *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); + + // Copy data into new literal, element-by-element. + for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { + std::vector from_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); + std::vector to_multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); + switch (literal.shape().element_type()) { + case PRED: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U8: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case S32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case U64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case S64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case F32: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case F64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + case C64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; + default: + LOG(FATAL) << "Unhandled primitive element type: " + << PrimitiveType_Name(literal.shape().element_type()); + } + } + + return new_literal; +} + +std::unique_ptr LiteralBase::Transpose( tensorflow::gtl::ArraySlice permutation) const { CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) @@ -824,15 +939,14 @@ std::unique_ptr Literal::Transpose( for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } - std::unique_ptr new_literal = CreateFromShape(permuted_shape); - DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), + auto new_literal = MakeUnique(permuted_shape); + DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(), - root_piece().size_bytes()); + std::memcpy(new_literal->untyped_data(), untyped_data(), size_bytes()); return new_literal; } -std::unique_ptr Literal::Slice( +std::unique_ptr 
LiteralBase::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const { CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; @@ -900,20 +1014,20 @@ std::unique_ptr Literal::Slice( } } -Literal Literal::Clone() const { +Literal LiteralBase::Clone() const { Literal result(shape()); TF_CHECK_OK(result.CopyFrom(*this)); return result; } -std::unique_ptr Literal::CloneToUnique() const { +std::unique_ptr LiteralBase::CloneToUnique() const { auto result = MakeUnique(shape()); TF_CHECK_OK(result->CopyFrom(*this)); return result; } -string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const { +string LiteralBase::GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsDenseArray(subshape)); switch (subshape.element_type()) { @@ -953,8 +1067,8 @@ string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, } } -string Literal::GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index) const { +string LiteralBase::GetSparseElementAsString( + int64 sparse_element_number, const ShapeIndex& shape_index) const { const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); CHECK(LayoutUtil::IsSparseArray(subshape)); switch (subshape.element_type()) { @@ -1008,7 +1122,7 @@ string Literal::GetSparseElementAsString(int64 sparse_element_number, } } -StatusOr Literal::GetIntegralAsS64( +StatusOr LiteralBase::GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const { CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { @@ -1031,6 +1145,27 @@ StatusOr Literal::GetIntegralAsS64( } } +size_t LiteralBase::Hash() const { + using tensorflow::Hash64; + using tensorflow::Hash64Combine; + + size_t hash_value = ShapeUtil::Hash(shape()); + + ShapeUtil::ForEachSubshape( + shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsTuple(subshape)) { + return; + } + + CHECK(LayoutUtil::IsDense(subshape.layout())); + hash_value = Hash64Combine( + hash_value, Hash64(static_cast(untyped_data(index)), + size_bytes(index))); + }); + + return hash_value; +} + Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, int64 value) { CHECK(LayoutUtil::IsDenseArray(shape())); @@ -1061,7 +1196,7 @@ Status Literal::SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, return Status::OK(); } -tensorflow::gtl::ArraySlice Literal::GetSparseIndex( +tensorflow::gtl::ArraySlice LiteralBase::GetSparseIndex( int64 sparse_element_number, const ShapeIndex& shape_index) const { const Piece& p = piece(shape_index); CHECK_GE(sparse_element_number, 0); @@ -1073,10 +1208,10 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } -Literal Literal::GetFirstScalarLiteral() const { - CHECK(ShapeUtil::IsArray(shape_)); - CHECK_GT(ShapeUtil::ElementsIn(shape_), 0); - switch (shape_.element_type()) { +Literal LiteralBase::GetFirstScalarLiteral() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_GT(ShapeUtil::ElementsIn(shape()), 0); + switch (shape().element_type()) { case PRED: return std::move(*Literal::CreateR0(GetFirstElement())); // 8 bit types. 
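
Editor's note: the new LiteralBase::Hash above starts from a hash of the shape and then folds in a hash of each dense array buffer's raw bytes. A standalone sketch of the same fold, with std::hash over the bytes standing in for tensorflow::Hash64 and a hypothetical HashCombine in place of Hash64Combine:

```cpp
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

// Stand-in for one array piece: just its raw bytes.
using Buffer = std::string;

// Combine two hashes, in the spirit of tensorflow::Hash64Combine.
size_t HashCombine(size_t a, size_t b) {
  return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
}

// Fold a shape hash together with a hash of every piece's bytes, mirroring
// how LiteralBase::Hash visits each dense array subshape.
size_t HashLiteralLike(size_t shape_hash, const std::vector<Buffer>& pieces) {
  size_t h = shape_hash;
  for (const Buffer& bytes : pieces) {
    h = HashCombine(h, std::hash<std::string>{}(bytes));
  }
  return h;
}

int main() {
  std::vector<Buffer> pieces = {"abcd", "xy"};
  std::printf("%zu\n", HashLiteralLike(/*shape_hash=*/42, pieces));
}
```
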
@@ -1112,11 +1247,11 @@ Literal Literal::GetFirstScalarLiteral() const { case U64: return std::move(*Literal::CreateR0(GetFirstElement())); default: - LOG(FATAL) << "Unhandled primitive type " << shape_.element_type(); + LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); } } -void Literal::Piece::SortSparseElements() { +void LiteralBase::Piece::SortSparseElements() { switch (subshape().element_type()) { case PRED: SortSparseElementsInternal(); @@ -1167,7 +1302,7 @@ void Literal::Piece::SortSparseElements() { } template -void Literal::Piece::SortSparseElementsInternal() { +void LiteralBase::Piece::SortSparseElementsInternal() { CHECK(LayoutUtil::IsSparseArray(subshape())); int64 num_elements = sparse_indices()->index_count(); auto values = data(); @@ -1178,9 +1313,11 @@ void Literal::Piece::SortSparseElementsInternal() { namespace { -void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, +void ToStringHelper(const LiteralBase& literal, const ShapeIndex& shape_index, bool print_layout, std::vector* pieces) { const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), shape_index); + CHECK(LayoutUtil::HasLayout(literal.shape())); + CHECK(LayoutUtil::HasLayout(subshape)); auto shape_to_string = [print_layout](const Shape& shape) { if (print_layout) { @@ -1339,13 +1476,14 @@ void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, } // namespace -int64 Literal::sparse_element_count() const { +int64 LiteralBase::sparse_element_count() const { CHECK(LayoutUtil::IsSparseArray(shape())); return sparse_indices()->index_count(); } -string Literal::ToString(bool print_layout) const { +string LiteralBase::ToString(bool print_layout) const { std::vector pieces; + CHECK(LayoutUtil::HasLayout(this->shape())); ToStringHelper(*this, {}, print_layout, &pieces); return tensorflow::str_util::Join(pieces, ""); } @@ -1353,7 +1491,7 @@ string Literal::ToString(bool print_layout) const { /* static */ std::unique_ptr Literal::MakeTuple( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; - for (const Literal* element : elements) { + for (const auto* element : elements) { element_shapes.push_back(element->shape()); } auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); @@ -1363,6 +1501,19 @@ string Literal::ToString(bool print_layout) const { return literal; } +/* static */ std::unique_ptr Literal::MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements) { + std::vector element_shapes; + for (const auto& element : elements) { + element_shapes.push_back(element.shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK(literal->CopyFrom(elements[i], /*dest_shape_index=*/{i})); + } + return literal; +} + /* static */ std::unique_ptr Literal::MakeTupleOwned( std::vector> elements) { std::vector element_shapes; @@ -1378,7 +1529,7 @@ string Literal::ToString(bool print_layout) const { return literal; } -void Literal::EachCellAsString( +void LiteralBase::EachCellAsString( const std::function indices, const string& value)>& per_cell) const { if (ShapeUtil::HasZeroElements(shape())) { @@ -1394,7 +1545,7 @@ void Literal::EachCellAsString( namespace { template std::unique_ptr ConvertBetweenNativeTypesWithConverter( - const Literal& src_literal, const ConverterType& converter) { + const LiteralBase& src_literal, const ConverterType& converter) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = 
MakeUnique(ShapeUtil::ChangeElementType( src_literal.shape(), @@ -1410,7 +1561,8 @@ std::unique_ptr ConvertBetweenNativeTypesWithConverter( } template -std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { +std::unique_ptr ConvertBetweenNativeTypes( + const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return static_cast(src); }; return ConvertBetweenNativeTypesWithConverter( src_literal, converter); @@ -1419,7 +1571,7 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { template typename std::enable_if<(sizeof(NativeSrcT) == sizeof(NativeDestT)), std::unique_ptr>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { auto converter = [](NativeSrcT src) { return tensorflow::bit_cast(src); }; @@ -1434,12 +1586,12 @@ BitcastBetweenNativeTypes(const Literal& src_literal) { template typename std::enable_if<(sizeof(NativeSrcT) != sizeof(NativeDestT)), std::unique_ptr>::type -BitcastBetweenNativeTypes(const Literal& src_literal) { +BitcastBetweenNativeTypes(const LiteralBase& src_literal) { LOG(FATAL) << "Invalid bitcast between types of different sizes."; } template -std::unique_ptr ConvertToC64(const Literal& src_literal) { +std::unique_ptr ConvertToC64(const LiteralBase& src_literal) { CHECK(ShapeUtil::IsArray(src_literal.shape())); auto result_literal = MakeUnique( ShapeUtil::ChangeElementType(src_literal.shape(), C64)); @@ -1457,7 +1609,7 @@ std::unique_ptr ConvertToC64(const Literal& src_literal) { } template -std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal, +std::unique_ptr ConvertIfTypesMatch(const LiteralBase& src_literal, bool bitcast) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); if (bitcast) { @@ -1477,7 +1629,7 @@ std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal, template StatusOr> ConvertIfDestTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type, + const LiteralBase& src_literal, PrimitiveType primitive_dest_type, bool bitcast) { switch (primitive_dest_type) { #define CONVERT_IF_TYPES_MATCH(type) \ @@ -1512,7 +1664,8 @@ StatusOr> ConvertIfDestTypeMatches( } StatusOr> ConvertSwitch( - const Literal& literal, PrimitiveType primitive_dest_type, bool bitcast) { + const LiteralBase& literal, PrimitiveType primitive_dest_type, + bool bitcast) { TF_RET_CHECK(ShapeUtil::IsArray(literal.shape())); if (literal.shape().element_type() == primitive_dest_type) { return literal.CloneToUnique(); @@ -1546,12 +1699,12 @@ StatusOr> ConvertSwitch( } // namespace -StatusOr> Literal::Convert( +StatusOr> LiteralBase::Convert( PrimitiveType primitive_dest_type) const { return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/false); } -StatusOr> Literal::BitcastConvert( +StatusOr> LiteralBase::BitcastConvert( PrimitiveType primitive_dest_type) const { if (primitive_util::BitWidth(shape().element_type()) != primitive_util::BitWidth(primitive_dest_type)) { @@ -1566,7 +1719,7 @@ StatusOr> Literal::BitcastConvert( return ConvertSwitch(*this, primitive_dest_type, /*bitcast=*/true); } -StatusOr> Literal::ConvertToShape( +StatusOr> LiteralBase::ConvertToShape( const Shape& dest_shape, bool round_f32_to_bf16) const { if (!ShapeUtil::IsTuple(dest_shape)) { if (round_f32_to_bf16 && shape().element_type() == F32 && @@ -1581,7 +1734,7 @@ StatusOr> Literal::ConvertToShape( } std::vector elements; for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { - auto element = LiteralView::Create(*this, {i}); + 
auto element = LiteralSlice(*this, {i}); TF_ASSIGN_OR_RETURN( auto new_element, element.ConvertToShape(ShapeUtil::GetSubshape(dest_shape, {i}))); @@ -1593,8 +1746,8 @@ StatusOr> Literal::ConvertToShape( } template -bool Literal::Piece::EqualElementsInternal( - const Literal::Piece& other, std::vector* multi_index) const { +bool LiteralBase::Piece::EqualElementsInternal( + const LiteralBase::Piece& other, std::vector* multi_index) const { if (multi_index->size() == ShapeUtil::Rank(subshape())) { return (Get(*multi_index) == other.Get(*multi_index)); } @@ -1608,7 +1761,7 @@ bool Literal::Piece::EqualElementsInternal( return true; } -bool Literal::Piece::EqualElements(const Literal::Piece& other) const { +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); std::vector multi_index; @@ -1636,28 +1789,28 @@ bool Literal::Piece::EqualElements(const Literal::Piece& other) const { case C64: return EqualElementsInternal(other, &multi_index); default: - LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type " + LOG(FATAL) << "Unimplemented: LiteralBase::Piece::EqualElements for type " << PrimitiveType_Name(subshape().element_type()); } } -bool Literal::operator==(const Literal& other) const { +bool LiteralBase::operator==(const LiteralBase& other) const { if (!ShapeUtil::Compatible(shape(), other.shape())) { return false; } - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } - const Piece& other_piece = other.piece(index); - if (!piece.EqualElements(other_piece)) { - return false; - } - } - return true; + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; + } + return true; + }); } namespace { @@ -1675,11 +1828,11 @@ static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, } // namespace -bool Literal::IsAll(int8 value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; +bool LiteralBase::IsAll(int8 value) const { + return root_piece().ForEachSubpieceWithBool([&](const ShapeIndex& index, + const Piece& piece) { if (!ShapeUtil::IsArray(piece.subshape())) { - continue; + return true; } auto piece_is_all = [&]() { @@ -1732,41 +1885,41 @@ bool Literal::IsAll(int8 value) const { if (!piece_is_all()) { return false; } - } - return true; + return true; + }); } -bool Literal::IsAllFloat(float value) const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } +bool LiteralBase::IsAllFloat(float value) const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } - auto piece_is_all = [&]() { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(piece.data(), value); - case F64: - return AllElementsEqualValue(piece.data(), value); - case F16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - case BF16: - return AllElementsEqualValue(piece.data(), - static_cast(value)); - default: + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return 
AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue( + piece.data(), static_cast(value)); + default: + return false; + } + }; + if (!piece_is_all()) { return false; - } - }; - if (!piece_is_all()) { - return false; - } - } - return true; + } + return true; + }); } -bool Literal::IsAllComplex(complex64 value) const { +bool LiteralBase::IsAllComplex(complex64 value) const { switch (shape().element_type()) { case C64: return AllElementsEqualValue(root_piece().data(), @@ -1776,93 +1929,93 @@ bool Literal::IsAllComplex(complex64 value) const { } } -bool Literal::IsAllFirst() const { - for (const auto& pair : pieces_) { - const Piece& piece = pair.second; - if (!ShapeUtil::IsArray(piece.subshape())) { - continue; - } +bool LiteralBase::IsAllFirst() const { + return root_piece().ForEachSubpieceWithBool( + [&](const ShapeIndex& index, const Piece& piece) { + if (!ShapeUtil::IsArray(piece.subshape())) { + return true; + } - // Empty shapes are not all the first element since there is no first - // element. - if (ShapeUtil::HasZeroElements(piece.subshape())) { - return false; - } - auto piece_is_all = [&]() { - switch (piece.subshape().element_type()) { - case PRED: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 8 bit types - case S8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U8: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 16 bit types - case BF16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U16: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 32 bit types - case F32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S32: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - // 64 bit types - case C64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case F64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case S64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - case U64: { - auto data = piece.data(); - return AllElementsEqualValue(data, data[0]); - } - default: + // Empty shapes are not all the first element since there is no first + // element. 
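
Editor's note: IsAll, IsAllFloat and IsAllComplex all reduce to the same check, namely whether every element of a typed buffer equals a query value after that value is cast to the buffer's native type. A simplified sketch of the helper and the per-type dispatch; ElemType and the names below are stand-ins, not the XLA primitive-type machinery:

```cpp
#include <cstddef>
#include <vector>

// Simplified analogue of AllElementsEqualValue: every element equals `value`.
template <typename NativeT>
bool AllElementsEqualValue(const std::vector<NativeT>& data, NativeT value) {
  for (const NativeT& v : data) {
    if (v != value) return false;
  }
  return true;
}

// Sketch of the per-element-type dispatch done by IsAllFloat: the query value
// is cast to the buffer's native type before comparing.
enum class ElemType { kF32, kF64 };

bool IsAllFloatLike(ElemType type, const void* data, size_t count, float value) {
  switch (type) {
    case ElemType::kF32: {
      const float* p = static_cast<const float*>(data);
      return AllElementsEqualValue(std::vector<float>(p, p + count), value);
    }
    case ElemType::kF64: {
      const double* p = static_cast<const double*>(data);
      return AllElementsEqualValue(std::vector<double>(p, p + count),
                                   static_cast<double>(value));
    }
  }
  return false;
}

int main() {
  float xs[] = {0.5f, 0.5f, 0.5f};
  return IsAllFloatLike(ElemType::kF32, xs, 3, 0.5f) ? 0 : 1;
}
```
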
+ if (ShapeUtil::HasZeroElements(piece.subshape())) { return false; - } - }; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + default: + return false; + } + }; - if (!piece_is_all()) { - return false; - } - } - return true; + if (!piece_is_all()) { + return false; + } + return true; + }); } -bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { +bool LiteralBase::IsZero(tensorflow::gtl::ArraySlice indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { case U8: @@ -1904,7 +2057,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest, } // namespace -void Literal::Piece::WriteToProto(LiteralProto* proto) const { +void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { *proto->mutable_shape() = subshape(); switch (subshape().element_type()) { case PRED: @@ -1930,16 +2083,14 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const { *proto->mutable_f16s() = string( reinterpret_cast(data().data()), size_bytes()); if (!kLittleEndian) { - ConvertEndianShort(const_cast(proto->mutable_f16s()->data()), - proto->f16s().size()); + ConvertEndianShort(proto->mutable_f16s()); } break; case BF16: *proto->mutable_bf16s() = string( reinterpret_cast(data().data()), size_bytes()); if (!kLittleEndian) { - ConvertEndianShort(const_cast(proto->mutable_bf16s()->data()), - proto->bf16s().size()); + ConvertEndianShort(proto->mutable_bf16s()); } break; case F32: @@ -1962,12 +2113,12 @@ void Literal::Piece::WriteToProto(LiteralProto* proto) const { } } -const void* Literal::Piece::untyped_data() const { +const void* LiteralBase::Piece::untyped_data() const { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } -void* Literal::Piece::untyped_data() { +void* LiteralBase::Piece::untyped_data() { CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); return buffer(); } @@ -1988,7 +2139,7 @@ Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, } // namespace -Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { +Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { // These 
conditions should have been checked in Literal::CreateFromProto. TF_RET_CHECK(proto.has_shape()); TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); @@ -2055,21 +2206,19 @@ Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { return Status::OK(); } -LiteralProto Literal::ToProto() const { +LiteralProto LiteralBase::ToProto() const { LiteralProto proto; - for (const auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - const Piece& piece = pair.second; - - LiteralProto* proto_piece = &proto; - for (int64 i : index) { - while (proto_piece->tuple_literals_size() <= i) { - proto_piece->add_tuple_literals(); - } - proto_piece = proto_piece->mutable_tuple_literals(i); - } - piece.WriteToProto(proto_piece); - } + root_piece().ForEachSubpiece( + [&](const ShapeIndex& index, const Piece& piece) { + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + }); if (LayoutUtil::IsSparseArray(shape())) { CopyToRepeatedField(proto.mutable_sparse_indices(), @@ -2091,33 +2240,40 @@ StatusOr> Literal::CreateFromProto( auto literal = MakeUnique(proto.shape()); - for (auto& pair : literal->pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - const LiteralProto* proto_element = &proto; - for (int64 i : index) { - TF_RET_CHECK(i < proto_element->tuple_literals_size()); - proto_element = &proto_element->tuple_literals(i); - } + TF_RETURN_IF_ERROR(literal->root_piece_->ForEachMutableSubpieceWithStatus( + [&](const ShapeIndex& index, Piece* piece) { + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } - if (ShapeUtil::IsTuple(piece.subshape())) { - if (proto_element->tuple_literals_size() != - ShapeUtil::TupleElementCount(piece.subshape())) { - return InvalidArgument( - "Expected %lld tuple elements in LiteralProto, has %d", - ShapeUtil::TupleElementCount(piece.subshape()), - proto_element->tuple_literals_size()); - } - continue; - } + if (ShapeUtil::IsTuple(piece->subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece->subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece->subshape()), + proto_element->tuple_literals_size()); + } + return Status::OK(); + } + + CHECK(ShapeUtil::IsArray(piece->subshape())); + TF_RETURN_IF_ERROR(piece->CopyFromProto(*proto_element)); + + return Status::OK(); + })); - TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape())); - TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element)); - } return std::move(literal); } -const void* Literal::untyped_data(const ShapeIndex& shape_index) const { +/* static */ string Literal::MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index) { + return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}"); +} + +const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const { return piece(shape_index).untyped_data(); } @@ -2125,11 +2281,11 @@ void* Literal::untyped_data(const ShapeIndex& shape_index) { return piece(shape_index).untyped_data(); } -int64 Literal::size_bytes(const ShapeIndex& shape_index) const { +int64 LiteralBase::size_bytes(const ShapeIndex& shape_index) const { return piece(shape_index).size_bytes(); } -string 
Literal::GetR1U8AsString() const { +string LiteralBase::GetR1U8AsString() const { CHECK(ShapeUtil::IsArray(shape())); CHECK_EQ(ShapeUtil::Rank(shape()), 1); CHECK_EQ(shape().element_type(), U8); @@ -2137,51 +2293,55 @@ string Literal::GetR1U8AsString() const { ShapeUtil::ElementsIn(shape())); } -/* static */ const LiteralView LiteralView::Create( - const Literal& literal, const ShapeIndex& view_root) { - return LiteralView(literal, view_root); -} +void BorrowingLiteral::BuildPieceSubtree(const Shape& shape, Piece* piece) { + CHECK(ShapeUtil::IsTuple(shape)); + for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& subshape = shape.tuple_shapes(i); -LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) { - shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root); - pieces_ = ShapeTree(shape_); - owns_buffers_ = false; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; + auto child_piece = Piece(); + child_piece.set_subshape(&subshape); - ShapeIndex src_index = view_root; - for (int64 i : index) { - src_index.push_back(i); + if (ShapeUtil::IsTuple(subshape)) { + BuildPieceSubtree(subshape, &child_piece); } - const Piece& src_piece = literal.piece(src_index); - piece.set_buffer(src_piece.buffer()); - piece.set_sparse_indices(src_piece.sparse_indices()); - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + + piece->emplace_back(std::move(child_piece)); } } -LiteralView::~LiteralView() {} +LiteralSlice::LiteralSlice(const LiteralBase& literal) + : LiteralBase(), root_piece_(&literal.root_piece()) {} -LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); } +LiteralSlice::LiteralSlice(const LiteralBase& literal, + const ShapeIndex& view_root) + : LiteralBase(), root_piece_(&literal.piece(view_root)) {} -LiteralView& LiteralView::operator=(const LiteralView& other) { - CopyFrom(other); - return *this; +BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) + : LiteralBase(), shape_(shape) { + CHECK(ShapeUtil::IsArray(shape_)); + CHECK_NE(src_buf_ptr, nullptr); + CHECK(LayoutUtil::HasLayout(shape_)); + + root_piece_ = Piece(); + root_piece_.set_buffer(const_cast(src_buf_ptr)); + root_piece_.set_subshape(&shape_); } -void LiteralView::CopyFrom(const LiteralView& other) { - // We can't use the default copy-constructor/copy-assignment because - // Piece::subshape_ points to subshapes within the Shape of the owning - // Literal/LiteralView. 
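
Editor's note: LiteralSlice and BorrowingLiteral replace LiteralView. Both are non-owning: LiteralSlice points at another literal's existing pieces, and BorrowingLiteral wraps caller-supplied buffers, while Literal alone owns its storage. A toy sketch of the owning-class versus borrowing-view split; the class names below are illustrative, not the XLA types:

```cpp
#include <cstddef>
#include <vector>

// Owning container: allocates and frees its own element storage.
class OwningArray {
 public:
  explicit OwningArray(size_t n) : data_(n, 0.0f) {}
  float* data() { return data_.data(); }
  const float* data() const { return data_.data(); }
  size_t size() const { return data_.size(); }

 private:
  std::vector<float> data_;
};

// Borrowing view: wraps a caller-owned buffer and never frees it, in the
// spirit of BorrowingLiteral(const char* src_buf_ptr, const Shape& shape).
class BorrowingArray {
 public:
  BorrowingArray(const float* buf, size_t n) : buf_(buf), size_(n) {}
  const float* data() const { return buf_; }
  size_t size() const { return size_; }

 private:
  const float* buf_;  // not owned; must outlive the view
  size_t size_;
};

int main() {
  OwningArray owner(4);
  owner.data()[0] = 1.0f;
  BorrowingArray view(owner.data(), owner.size());  // no copy, no ownership
  return view.data()[0] == 1.0f ? 0 : 1;
}
```
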
- shape_ = other.shape(); - pieces_ = other.pieces_; - for (auto& pair : pieces_) { - const ShapeIndex& index = pair.first; - Piece& piece = pair.second; - piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); +BorrowingLiteral::BorrowingLiteral( + tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) + : LiteralBase(), shape_(shape) { + CHECK(ShapeUtil::IsTuple(shape_)); + CHECK(!ShapeUtil::IsNestedTuple(shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_)); + root_piece_ = Piece(); + root_piece_.set_subshape(&shape_); + BuildPieceSubtree(shape_, &root_piece_); + + for (int i = 0; i < src_buf_ptrs.size(); ++i) { + const auto& src_shape = shape_.tuple_shapes(i); + CHECK(ShapeUtil::IsArray(src_shape)); + root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } - owns_buffers_ = false; } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 8aa19222dc4..609dc7a3aca 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -52,237 +51,209 @@ limitations under the License. namespace xla { -// Class representing literal values in XLA. -// -// TODO(b/67651157): The methods in this class should be reduced to a minimal -// set of methods which construct Literals and accessors methods. Other methods -// which perform computation on Literals (Reshape, Slice, etc) should be moved -// elsewhere, and perhaps combined with evaluator code which operates on -// Literals. -class Literal { +// Forward declare Literal and LiteralSlice class to be used by the creation +// methods in the base class. +class Literal; +class LiteralSlice; + +// Abstract base class for literals. +class LiteralBase { public: - Literal() : Literal(ShapeUtil::MakeNil()) {} - - // Create a literal of the given shape. The literal is allocated sufficient - // memory to hold the shape. Memory is uninitialized. - explicit Literal(const Shape& shape); - virtual ~Literal(); - - // Literals are moveable, but not copyable. To copy a literal use - // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies - // of literals which can be expensive. - Literal(const Literal& other) = delete; - Literal& operator=(const Literal& other) = delete; - Literal(Literal&& other); - Literal& operator=(Literal&& other); + virtual ~LiteralBase() = 0; // Literals are equal if they have compatible shapes and the same data // values. Layout is not compared. - bool operator==(const Literal& other) const; - bool operator!=(const Literal& other) const { return !(*this == other); } + bool operator==(const LiteralBase& other) const; + bool operator!=(const LiteralBase& other) const { return !(*this == other); } - // Serialize to and from a proto. - static StatusOr> CreateFromProto( - const LiteralProto& proto); + // Returns the shape of the literal. + const Shape& shape() const { return root_piece().subshape(); } + + // Serialize to proto. LiteralProto ToProto() const; - // Return the shape of the literal. - const Shape& shape() const { return shape_; } - - // TODO(b/67651157): Remove this accessor. 
Literal users should not be able to - // mutate the shape as this can produce malformed Literals. - Shape* mutable_shape_do_not_use() { return &shape_; } - - // Returns a (Mutable)ArraySlice view of the array for this literal for the - // given NativeT (e.g., float). CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. See primitive_util.h for the mapping from - // XLA type to native type. + // Returns an ArraySlice of the array for this literal for the given NativeT + // (e.g., float). CHECKs if the subshape of the literal at the given + // ShapeIndex is not array. See primitive_util.h for the mapping from XLA type + // to native type. template tensorflow::gtl::ArraySlice data( const ShapeIndex& shape_index = {}) const; - template - tensorflow::gtl::MutableArraySlice data( - const ShapeIndex& shape_index = {}); - // Returns a pointer to the sparse index array. Returns nullptr if the literal - // is not a sparse array. + // Returns a const pointer to the sparse index array. Returns nullptr if the + // literal is not a sparse array. const SparseIndexArray* sparse_indices( const ShapeIndex& shape_index = {}) const; - SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); - // Returns a pointer to (or size of) the underlying buffer holding the array - // at the given shape index. CHECKs if the subshape of the literal at the - // given ShapeIndex is not array. + // Returns a const pointer to (or size of) the underlying buffer holding the + // array at the given shape index. CHECKs if the subshape of the literal at + // the given ShapeIndex is not array. const void* untyped_data(const ShapeIndex& shape_index = {}) const; - void* untyped_data(const ShapeIndex& shape_index = {}); int64 size_bytes(const ShapeIndex& shape_index = {}) const; - // Creates a new literal of a given rank. To minimize ambiguity (for users - // and the compiler) these CreateR[0-2] methods should explicitly specify the - // native type. For example: - // - // CreateR1({1.0, 42.0}); - // CreateR2({{1, 2}, {3, 4}}); - // - // The variants not ending with WithLayout use the default XLA layout for the - // literal's linear representation in memory. - template - static std::unique_ptr CreateR0(NativeT value); - template - static std::unique_ptr CreateR1( - tensorflow::gtl::ArraySlice values); - static std::unique_ptr CreateR1( - const tensorflow::core::Bitmap& values); - template - static std::unique_ptr CreateR2( - std::initializer_list> values); - template - static std::unique_ptr CreateR2WithLayout( - std::initializer_list> values, - const Layout& layout); - template - static std::unique_ptr CreateR3( - std::initializer_list< - std::initializer_list>> - values); - template - static std::unique_ptr CreateR3WithLayout( - std::initializer_list< - std::initializer_list>> - values, - const Layout& layout); - template - static std::unique_ptr CreateR4( - std::initializer_list>>> - values); - template - static std::unique_ptr CreateR4WithLayout( - std::initializer_list>>> - values, - const Layout& layout); - // Returns this literal's data as a string. This literal must be a rank-1 U8 // array. string GetR1U8AsString() const; - // Creates a literal with a sparse layout and the given indices and values. - // The shape is initialized from the given dimensions. The minor dimension of - // the indices array must equal the rank of the shape (i.e. size of the - // dimensions array). The major dimension of the indices array must equal the - // number of elements in the values array. 
The maximum number of elements in - // the array is taken from the max_indices() value of the index array. - // - // XLA assumes that sparse literals are in sorted order for all operations. If - // the `sort` argument is true, then the indices and values will be sorted - // while copying them into the literal. If you have ensured that the indices - // and values are already sorted, then you may set the `sort` argument to - // false to skip the sorting step. - // - // For example: - // - // CreateSparse( - // {12, 12, 12}, - // SparseIndexArray(10, 3, - // Array2D{ - // {0, 1, 2}, - // {3, 4, 5}, - // {6, 7, 8}, - // {9, 10, 11}, - // }), - // {1.0, 2.0 3.0, 4.0}) - // - // This creates an array with shape F64[12,12,12]sparse{10}, that has the - // following non-zero values: - // - // [0, 1, 2]: 1.0 - // [3, 4, 5]: 2.0 - // [6, 7, 8]: 3.0 - // [9, 10, 11]: 4.0 - // + // Returns a string representation of the literal value. + // Warning: this function can take minutes for multi-million element Literals. + string ToString(bool print_layout = false) const; + + // Gets an element in the literal at the given index. The multi_index is + // CHECKed against the dimension sizes. template - static std::unique_ptr CreateSparse( - tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, bool sort = true); - - // Populates a literal with a sparse layout with the given indices and values. - // Each index in the indices array is CHECKed against the dimensions in the - // literal's shape. If sort is true, then the indices and values will be - // sorted. If sort is false, then the indices and values are assumed to - // already be in sorted order. See CreateSparse for an example of how data - // are populated. + NativeT Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const; + // Overloads of Get for array literals. CHECKs if the literal is not + // array-shaped and dense. template - void PopulateSparse(SparseIndexArray indices, - tensorflow::gtl::ArraySlice values, - bool sort = true); + NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; - // Creates a new Literal object with the shape specified as parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromShape(const Shape& shape); + // Returns the element value at index (0, ..., 0), however many zeroes are + // required for that index. + template + NativeT GetFirstElement() const; - // Creates a new Literal object with its values havings the primitive_type - // type, and with dimensions defined by the dimensions parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions); + // As Get(), but determines the correct type and converts the value + // into text. + string GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index = {}) const; + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + // As Get(), but determines the correct type and converts the value into + // int64. This literal must be an array. 
+ StatusOr GetIntegralAsS64( + tensorflow::gtl::ArraySlice multi_index) const; - // Copy values from 'src_literal' rooted at 'src_shape_index' into this - // literal rooted at 'dest_shape_index'. The subshape of this literal rooted - // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' - // rooted at 'src_shape_index', but need not be arrays. - Status CopyFrom(const Literal& src_literal, - const ShapeIndex& dest_shape_index = {}, - const ShapeIndex& src_shape_index = {}); + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + tensorflow::gtl::ArraySlice GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; - // Similar to CopyFrom, but with move semantincs. The subshape of this literal - // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' - // (layouts and shapes must match), but need not be arrays. The memory - // allocated in this literal for the subshape at dest_shape_index is - // deallocated, and the respective buffers are replaced with those in - // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). - Status MoveFrom(Literal&& src_literal, - const ShapeIndex& dest_shape_index = {}); + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. + template + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; - // Copies the values from src_literal, starting at src_base shape indexes, - // to this literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. - // The src_literal and this literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - // Note: if either src_literal or this literal contains dimensions with zero - // element, then copy_size must be 0 in these dimensions while the - // corresponding base indices being 0. - // This literal and 'src_literal' must be arrays. - Status CopySliceFrom(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + // Invokes the "per cell" callback for each element in the provided + // literal with the element's indices and a string representation of + // the element's value. + // + // This function is useful if you want a polymorphic representation + // of the tensor's elements (turning it to a string for something + // like representation in a protobuf). + // + // This literal must have a dense layout. + void EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const; + template + void EachCell(std::function indices, + NativeT value)> + per_cell) const; - // Copies one element from src_literal[src_index] to (*this)[dest_index]. 
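
Editor's note: the GetSparseIndex and GetSparseElement comments above describe a sparse literal as an ordered list of (index, value) pairs addressed by sparse element number. A toy sketch of that model as parallel arrays, purely to illustrate the lookup semantics; the real storage lives in SparseIndexArray and the names here are hypothetical:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

// Toy model of a sparse array as parallel lists of multi-indices and values.
struct SparseArray {
  std::vector<std::vector<int>> indices;  // one multi-index per stored element
  std::vector<double> values;             // value for the matching index
};

const std::vector<int>& GetSparseIndex(const SparseArray& a, size_t n) {
  assert(n < a.indices.size());  // checked against the number of pairs
  return a.indices[n];
}

double GetSparseElement(const SparseArray& a, size_t n) {
  assert(n < a.values.size());
  return a.values[n];
}

int main() {
  SparseArray a;
  a.indices = {{0, 1, 2}, {3, 4, 5}};
  a.values = {1.0, 2.0};
  std::printf("element #1 at {%d,%d,%d} = %g\n", GetSparseIndex(a, 1)[0],
              GetSparseIndex(a, 1)[1], GetSparseIndex(a, 1)[2],
              GetSparseElement(a, 1));
}
```
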
- Status CopyElementFrom(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_index, - tensorflow::gtl::ArraySlice dest_index); + // Returns whether every element in this literal is equal to value. + // + // value is an int8 because we expect this to be called with small + // compile-time constants (0, -1, etc.) and so that whatever value you pass + // can be represented exactly by floating-point types as small as 16 bits. + // + // If value doesn't fit in this literal's type, returns false. Values of 1/0 + // are considered equal to true/false; other values are not considered equal + // to true. Also if this literal is not array-shaped false is returned. + bool IsAll(int8 value) const; - // Returns a vector containing the tuple elements of this Literal as separate - // Literals. This Literal must be tuple-shaped and can be a nested tuple. The - // elements are moved into the new Literals; no data is copied. Upon return - // this Literal is set to a nil shape (empty tuple) - std::vector DecomposeTuple(); + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular floating-point number. + // + // If the literal is not a floating-point value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for values that can be expressed precisely as a float, + // e.g. -0.5. Also if this literal is not array-shaped false is returned. + bool IsAllFloat(float value) const; - // This operation is the inverse of DecomposeTuple. The given elements are - // moved into the tuple elements of a new tuple-shaped Literal which is - // returned. Upon return, each of the Literals in 'elements' is set to a nil - // shape (empty tuple). - static Literal MoveIntoTuple( - tensorflow::gtl::MutableArraySlice elements); + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. + bool IsAllComplex(complex64 value) const; - // Creates a new value that has the equivalent value as this literal, but - // conforms to new_layout; e.g. a literal matrix that was in {0, 1} - // minor-to-major dimension layout can be re-layed-out as {1, 0} + // Literal consists entirely of the first element of the literal. + bool IsAllFirst() const; + + // Returns whether this literal is zero at the specified index. This literal + // must be an array with a dense layout. + bool IsZero(tensorflow::gtl::ArraySlice indices) const; + + // Returns the count of the elements in the array at the given shape index in + // this literal. + int64 element_count(const ShapeIndex& index = {}) const { + return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + } + + // Returns the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + + // Compute a hash for this literal. 
This literal must not be a sparse tensor + // or a tuple containing a sparse tensor. + size_t Hash() const; + + // Converts this literal to the given shape. Returns an error is the + // conversion is not possible. + // + // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding + // instead of truncation; otherwise, truncation is used. + // + // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes + // the default behavior. + StatusOr> ConvertToShape( + const Shape& dest_shape, bool round_f32_to_bf16 = false) const; + + // Converts this literal to another primitive type using a bitcast + // conversion. The to and from primitive types must have the same bit + // width. Returns an error if the conversion is not possible. This literal + // must be array-shaped. + StatusOr> BitcastConvert( + PrimitiveType primitive_dest_type) const; + + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. This literal must be array-shaped. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; + + // Returns a literal scalar representing the first element. + Literal GetFirstScalarLiteral() const; + + // Clones the underlying buffers into a new Literal, or new + // std::unique_ptr. + Literal Clone() const; + std::unique_ptr CloneToUnique() const; + + // TODO(b/67651157): The methods below which perform computation on Literals + // (Reshape, Slice, etc) should be moved elsewhere, and perhaps combined with + // evaluator code which operates on Literals. + // + // Creates a new value that has the equivalent value as this + // literal, but conforms to new_layout; e.g. a literal matrix that was in {0, + // 1} minor-to-major dimension layout can be re-layed-out as {1, 0} // minor-to-major dimension layout and the value in the cell at any given // logical index (i0, i1) will be the same. // @@ -333,44 +304,510 @@ class Literal { template std::unique_ptr Replicate(int64 times) const; - // Converts this literal to another primitive type using - // static_cast<>. Returns an error if the conversion is not possible. This - // literal must be array-shaped. - StatusOr> Convert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to another primitive type using a bitcast - // conversion. The to and from primitive types must have the same bit - // width. Returns an error if the conversion is not possible. This literal - // must be array-shaped. - StatusOr> BitcastConvert( - PrimitiveType primitive_dest_type) const; - - // Converts this literal to the given shape. Returns an error is the - // conversion is not possible. + // Creates a new Literal object with the shape specified as parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). // - // round_f32_to_bf16: if true, converting F32 elements to BF16 uses rounding - // instead of truncation; otherwise, truncation is used. + // Note: It's an antipattern to use this method then immediately call + // Literal::Populate on the result (since that results in zero initialization, + // then reinitialization. Conside if a call to MakeUnique(shape), + // followed by the call to Literal::Populate can be used instead. + static std::unique_ptr CreateFromShape(const Shape& shape); + + protected: + // A data structure representing a subshape at a particular ShapeIndex within + // the literal. 
For array-shaped ShapeIndexes, this data structure holds the + // pointer to the memory allocated for the array data. + class Piece { + public: + // Returns the buffer holding the array data for this piece as an array + // slice. This piece must be array-shaped. + template + tensorflow::gtl::ArraySlice data() const; + template + tensorflow::gtl::MutableArraySlice data(); + + // Returns the buffer holding the array data for this piece as a void*. This + // piece must be array-shaped. + void* untyped_data(); + const void* untyped_data() const; + + // Gets or sets an element in the array at the given index. The multi_index + // is CHECKed against the dimension sizes of the array. This piece must be + // array-shaped. + template + NativeT Get(tensorflow::gtl::ArraySlice index) const; + template + void Set(tensorflow::gtl::ArraySlice index, NativeT value); + + // Gets/sets the buffer holding the array data. + char* buffer() const { return buffer_; } + void set_buffer(char* buffer) { buffer_ = buffer; } + + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } + + // Gets or sets the subshape of this piece. This reference points to a + // subshape within the shape in the containing Literal (Literal::shape_). + const Shape& subshape() const { return *subshape_; } + void set_subshape(const Shape* subshape) { subshape_ = subshape; } + + // Returns the size in bytes of the buffer holding the array data. + int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } + + // Returns the number of elements in this piece's array. + int64 element_count() const { + // If this is a sparse array, use the number of elements represented by + // the indices in the associated SparseIndexArray. + return LayoutUtil::IsSparseArray(subshape()) + ? sparse_indices()->index_count() + : ShapeUtil::ElementsIn(subshape()); + } + + // Returns the child piece at 'index' of this piece. + Piece& child(int64 index) { return children_[index]; } + + // Adds a child piece to this piece's children. + void emplace_back(Piece child_piece) { + children_.emplace_back(std::move(child_piece)); + } + + // Returns the size of children pieces of this piece. + int64 children_size() { return children_.size(); } + + // Visitor functions that recursively traverses the piece and calls the + // given function at each child piece. The function has the type: + // void (const ShapeIndex& index, const Piece& piece) + template + void ForEachSubpiece(const Fn& func) const { + ShapeIndex index; + return ForEachHelper( + [&func](const ShapeIndex& index, const Piece& piece) { + func(index, piece); + return Status::OK(); + }, + *this, &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, const Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachSubpieceWithStatus(const Fn& func) const { + ShapeIndex index; + return ForEachHelper(func, *this, &index); + } + // Same as above, but the function has the type: + // Bool (const ShapeIndex& index, const Piece& piece) + // The first non-true return value is returned by the function. 
+ template + bool ForEachSubpieceWithBool(const Fn& func) const { + ShapeIndex index; + return ForEachHelperBool(func, *this, &index); + } + // Same as above, but the function has the type: + // Void (const ShapeIndex& index, Piece& piece) + template + void ForEachMutableSubpiece(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + [&func](const ShapeIndex& index, Piece* piece) { + func(index, piece); + return Status::OK(); + }, + const_cast(this), &index) + .IgnoreError(); + } + // Same as above, but the function has the type: + // Status (const ShapeIndex& index, Piece& piece) + // The first non-OK return value is returned by the function. + template + Status ForEachMutableSubpieceWithStatus(const Fn& func) { + ShapeIndex index; + return ForEachMutableHelper( + func, const_cast(this), &index); + } + + // Returns true if this piece and 'other' contain the same data. This piece + // and 'other' must be array-shaped and compatible. + bool EqualElements(const Piece& other) const; + + // Writes the shape and data (if array-shaped) into the given proto. + void WriteToProto(LiteralProto* proto) const; + + // Copy the data from 'src' into this piece's buffer. Shapes of this piece + // and src must be compatible. + Status CopyFrom(const Piece& src); + + // Copies the data from the given proto into this piece. The shape of this + // piece must be equal (not just compatible) to the shape of the proto. + Status CopyFromProto(const LiteralProto& proto); + + // Sorts the elements in a sparse array. + void SortSparseElements(); + + private: + // Helpers for traversing the piece via ForEachSubpiece rooted at 'index'. + // The first non-OK (or non-true) value is returned by the function. + // The callable 'func' has the same signature as described above in + // ForEachSubpiece*. + template + Status ForEachHelper(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR(ForEachHelper(func, piece.children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + template + bool ForEachHelperBool(const Fn& func, const Piece& piece, + ShapeIndex* index) const { + if (!func(*index, piece)) { + return false; + } + for (int64 i = 0; i < piece.children_.size(); ++i) { + index->push_back(i); + if (!ForEachHelperBool(func, piece.children_[i], index)) { + return false; + } + index->pop_back(); + } + return true; + } + template + Status ForEachMutableHelper(const Fn& func, Piece* piece, + ShapeIndex* index) { + TF_RETURN_IF_ERROR(func(*index, piece)); + for (int64 i = 0; i < piece->children_.size(); ++i) { + index->push_back(i); + TF_RETURN_IF_ERROR( + ForEachMutableHelper(func, &piece->children_[i], index)); + index->pop_back(); + } + return Status::OK(); + } + + // Recursive helper for EqualElements. + template + bool EqualElementsInternal(const Piece& other, + std::vector* multi_index) const; + + // Helper for SortSparseElements that has the element type as a template + // parameter. + template + void SortSparseElementsInternal(); + + // For array-shaped pieces, this is the buffer holding the literal data. + char* buffer_ = nullptr; + + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; + + // The shape of piece. This points into the shape of the containing Literal + // (Literal::shape_). + const Shape* subshape_ = nullptr; + + // Children pieces for tuple shaped pieces. 
+ std::vector children_ = {}; + }; // class Piece + + const Piece& piece(const ShapeIndex& shape_index) const { + Piece* piece = &const_cast(root_piece()); + for (const auto i : shape_index) { + DCHECK_GE(i, 0); + DCHECK_LT(i, piece->children_size()); + piece = &piece->child(i); + } + return *piece; + } + + // Returns the piece at the root of the shape. + virtual const Piece& root_piece() const = 0; + + // LiteralSlice and Literal must access Pieces of other Literals. + friend class Literal; + friend class LiteralSlice; + friend class BorrowingLiteral; +}; + +// Class representing literal values in XLA. +// +// The underlying buffer and shape is always owned by this class. +class Literal : public LiteralBase { + public: + Literal() : Literal(ShapeUtil::MakeNil()) {} + + // Create a literal of the given shape. The literal is allocated sufficient + // memory to hold the shape. Memory is uninitialized. + explicit Literal(const Shape& shape); + virtual ~Literal(); + + // Literals are moveable, but not copyable. To copy a literal use + // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies + // of literals which can be expensive. + Literal(const Literal& other) = delete; + Literal& operator=(const Literal& other) = delete; + Literal(Literal&& other); + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); + Literal& operator=(Literal&& other); + + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return shape_.get(); } + + // Returns a MutableArraySlice view of the array for this literal for the + // given NativeT (e.g., float). CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. See primitive_util.h for the mapping from + // XLA type to native type. + template + tensorflow::gtl::MutableArraySlice data( + const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::data; + + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + + // Returns a pointer to the underlying buffer holding the array at the given + // shape index. CHECKs if the subshape of the literal at the given ShapeIndex + // is not array. + void* untyped_data(const ShapeIndex& shape_index = {}); + // Unhide const method from parent class. + using LiteralBase::untyped_data; + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + + // Copy values from 'src_literal' rooted at 'src_shape_index' into this + // literal rooted at 'dest_shape_index'. The subshape of this literal rooted + // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' + // rooted at 'src_shape_index', but need not be arrays. 
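Aside (illustrative sketch, not part of this change): the new Literal class owns its buffers; the shape constructor allocates uninitialized storage which the caller then fills, e.g. through the data<NativeT>() accessor declared above.

// Hypothetical helper, not in the diff: build an F32[2,2] literal and zero it.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::Literal MakeZeroMatrix() {
  // Memory for the shape is allocated but uninitialized, so fill it first.
  xla::Literal lit(xla::ShapeUtil::MakeShape(xla::F32, {2, 2}));
  for (float& f : lit.data<float>()) {  // MutableArraySlice over the buffer.
    f = 0.0f;
  }
  CHECK_EQ(lit.Get<float>({1, 1}), 0.0f);
  return lit;  // Literals are move-only; returning moves the value.
}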
+  Status CopyFrom(const LiteralSlice& src_literal,
+                  const ShapeIndex& dest_shape_index = {},
+                  const ShapeIndex& src_shape_index = {});
+
+  // Similar to CopyFrom, but with move semantics. The subshape of this literal
+  // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal'
+  // (layouts and shapes must match), but need not be arrays. The memory
+  // allocated in this literal for the subshape at dest_shape_index is
+  // deallocated, and the respective buffers are replaced with those in
+  // src_literal. Upon return, src_literal is set to a nil shape (empty tuple).
+  Status MoveFrom(Literal&& src_literal,
+                  const ShapeIndex& dest_shape_index = {});
+
+  // Copies the values from src_literal, starting at src_base shape indexes,
+  // to this literal, starting at dest_base, where the copy size in each
+  // dimension is specified by copy_size.
+  // The src_literal and this literal must have the same primitive type;
+  // src_base+copy_size must fit the source literal dimensions, and
+  // dest_base+copy_size must fit the destination literal dimensions.
+  // Note: if either src_literal or this literal contains dimensions with zero
+  // elements, then copy_size must be 0 in these dimensions and the
+  // corresponding base indices must be 0.
+  // This literal and 'src_literal' must be arrays.
+  Status CopySliceFrom(const LiteralSlice& src_literal,
+                       tensorflow::gtl::ArraySlice src_base,
+                       tensorflow::gtl::ArraySlice dest_base,
+                       tensorflow::gtl::ArraySlice copy_size);
+
+  // Copies one element from src_literal[src_index] to (*this)[dest_index].
+  Status CopyElementFrom(const LiteralSlice& src_literal,
+                         tensorflow::gtl::ArraySlice src_index,
+                         tensorflow::gtl::ArraySlice dest_index);
+
+  // Sets an element in the literal at the given index. The multi_index is
+  // CHECKed against the dimension sizes.
+  template
+  void Set(tensorflow::gtl::ArraySlice multi_index,
+           const ShapeIndex& shape_index, NativeT value);
+  // Overloads of Set for array literals. CHECKs if the literal is not
+  // array-shaped and dense.
+  template
+  void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value);
+
+  // Appends the given element to the literal. If the elements are not appended
+  // in sorted order, then SortSparseElements should be called before calling
+  // other methods. This literal must have a sparse layout.
+  template
+  void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index,
+                           NativeT value, const ShapeIndex& shape_index = {});
+
+  // Sorts the elements in a sparse array.
+  void SortSparseElements(const ShapeIndex& shape_index = {});
+
+  // As Set(), but truncates `value` to the literal element type before storing.
+  // This literal must be an array.
+  Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index,
+                          int64 value);
+
+  // Populate this literal with the given values. Examples:
   //
-  // TODO(b/69266521): remove the round_to_bfloat16 flag when rounding becomes
-  // the default behavior.
-  StatusOr> ConvertToShape(
-      const Shape& dest_shape, bool round_f32_to_bf16 = false) const;
+  // // Populate with floats.
+  // Array2D float_values = ...
+  // literal.PopulateR2FromArray2D(values);
+  //
+  // // Populate with int32s.
+  // literal.PopulateR2({{1, 2}, {3, 4}});
+  //
+  // The shape and element type of this literal must match given values. For
+  // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2
+  // array of S32.
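Aside (illustrative sketch, not part of this change): the generator-based Populate() documented above and declared just below, filling an S32[4] literal with its own indices.

// Hypothetical helper, not in the diff: Populate() with a per-cell generator.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::Literal MakeIota4() {
  xla::Literal lit(xla::ShapeUtil::MakeShape(xla::S32, {4}));
  // The generator receives the cell's multi-index and returns its value.
  CHECK(lit.Populate<xla::int32>(
               [](tensorflow::gtl::ArraySlice<xla::int64> indexes) -> xla::int32 {
                 return indexes[0];
               })
            .ok());
  return lit;
}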
+ template + void PopulateR1(tensorflow::gtl::ArraySlice values); + void PopulateR1(const tensorflow::core::Bitmap& values); + template + void PopulateR2(std::initializer_list> values); + template + void PopulateFromArray(const Array& values); + template + void PopulateR2FromArray2D(const Array2D& values); + template + void PopulateR3FromArray3D(const Array3D& values); + template + void PopulateR4FromArray4D(const Array4D& values); + + // Populates literal values by calling the generator function for every cell + // in this literal object. + // + // generator must be a callable of the type + // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. + // + // This literal must have a dense layout. + template + Status Populate(const FnType& generator); + + // A parallel version of Populate(). This can be used if the generator is + // thread-safe and the values for the shape's different elements are + // independent. + template + Status PopulateParallel(const FnType& generator); + + // Fills this literal with the given value. + template + void PopulateWithValue(NativeT value); + + // Factory methods below. + // + + // Serialize from a proto. + static StatusOr> CreateFromProto( + const LiteralProto& proto); + + // Creates a new literal of a given rank. To minimize ambiguity (for users + // and the compiler) these CreateR[0-2] methods should explicitly specify the + // native type. For example: + // + // CreateR1({1.0, 42.0}); + // CreateR2({{1, 2}, {3, 4}}); + // + // The variants not ending with WithLayout use the default XLA layout for the + // literal's linear representation in memory. + template + static std::unique_ptr CreateR0(NativeT value); + template + static std::unique_ptr CreateR1( + tensorflow::gtl::ArraySlice values); + static std::unique_ptr CreateR1( + const tensorflow::core::Bitmap& values); + template + static std::unique_ptr CreateR2( + std::initializer_list> values); + template + static std::unique_ptr CreateR2WithLayout( + std::initializer_list> values, + const Layout& layout); + template + static std::unique_ptr CreateR3( + std::initializer_list< + std::initializer_list>> + values); + template + static std::unique_ptr CreateR3WithLayout( + std::initializer_list< + std::initializer_list>> + values, + const Layout& layout); + template + static std::unique_ptr CreateR4( + std::initializer_list>>> + values); + template + static std::unique_ptr CreateR4WithLayout( + std::initializer_list>>> + values, + const Layout& layout); + + // Creates a literal with a sparse layout and the given indices and values. + // The shape is initialized from the given dimensions. The minor dimension of + // the indices array must equal the rank of the shape (i.e. size of the + // dimensions array). The major dimension of the indices array must equal the + // number of elements in the values array. The maximum number of elements in + // the array is taken from the max_indices() value of the index array. + // + // XLA assumes that sparse literals are in sorted order for all operations. If + // the `sort` argument is true, then the indices and values will be sorted + // while copying them into the literal. If you have ensured that the indices + // and values are already sorted, then you may set the `sort` argument to + // false to skip the sorting step. 
+ // + // For example: + // + // CreateSparse( + // {12, 12, 12}, + // SparseIndexArray(10, 3, + // Array2D{ + // {0, 1, 2}, + // {3, 4, 5}, + // {6, 7, 8}, + // {9, 10, 11}, + // }), + // {1.0, 2.0 3.0, 4.0}) + // + // This creates an array with shape F64[12,12,12]sparse{10}, that has the + // following non-zero values: + // + // [0, 1, 2]: 1.0 + // [3, 4, 5]: 2.0 + // [6, 7, 8]: 3.0 + // [9, 10, 11]: 4.0 + // + template + static std::unique_ptr CreateSparse( + tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, bool sort = true); // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); - // Creates a scalar literal value one of the given primitive type. static Literal One(PrimitiveType primitive_type); - // Creates a scalar literal value containing the minimum value of the given // primitive type. For floating-point types, returns -inf. static Literal MinValue(PrimitiveType primitive_type); - // Creates a scalar literal value containing the maximum value of the given // primitive type. For floating-point types, returns inf. static Literal MaxValue(PrimitiveType primitive_type); - // Creates a literal of the given shape where each element is `value`. template static std::unique_ptr CreateFullWithDescendingLayout( @@ -425,79 +862,6 @@ class Literal { std::initializer_list> values, int64 projection_p, int64 projection_z); - // Clones this literal into a new Literal, or new std::unique_ptr. - Literal Clone() const; - std::unique_ptr CloneToUnique() const; - - // Gets or sets an element in the literal at the given index. The multi_index - // is CHECKed against the dimension sizes. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index, NativeT value); - - // Overloads of Get and Set for array literals. CHECKs if the literal is not - // array-shaped and dense. - template - NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; - template - void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - - // Returns the multi-index of the element in a sparse literal at the given - // sparse element number. The sparse element number is the position with in - // the sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - tensorflow::gtl::ArraySlice GetSparseIndex( - int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; - - // Returns the value of the element in a sparse literal at the given sparse - // element number. The sparse element number is the position with in the - // sparse array's list of (index, value) pairs, and is checked against the - // total number of (index, value) pairs in the sparse array. - template - NativeT GetSparseElement(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // Appends the given element to the literal. If the elements are not appended - // in sorted order, then SortSparseElements should be called before calling - // other methods. This literal must have a sparse layout. - template - void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, - NativeT value, const ShapeIndex& shape_index = {}); - - // Sorts the elements in a sparse array. 
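Aside (illustrative fragment, not part of this change): a sketch mirroring the CreateSparse example in the comment block above; the SparseIndexArray constructor form is copied from that comment and is not independently verified here.

// Fragment: the F64[12,12,12]sparse{10} literal described in the comment above.
auto sparse_literal = xla::Literal::CreateSparse<double>(
    /*dimensions=*/{12, 12, 12},
    xla::SparseIndexArray(/*max_indices=*/10, /*rank=*/3,
                          xla::Array2D<xla::int64>{{0, 1, 2},
                                                   {3, 4, 5},
                                                   {6, 7, 8},
                                                   {9, 10, 11}}),
    /*values=*/{1.0, 2.0, 3.0, 4.0});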
- void SortSparseElements(const ShapeIndex& shape_index = {}); - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - NativeT GetFirstElement() const; - - // Returns a literal scalar representing the first element. - Literal GetFirstScalarLiteral() const; - - // As Get(), but determines the correct type and converts the value - // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index, - const ShapeIndex& shape_index = {}) const; - - // As GetSparseElement(), but determines the correct type and converts the - // value into text. - string GetSparseElementAsString(int64 sparse_element_number, - const ShapeIndex& shape_index = {}) const; - - // As Get(), but determines the correct type and converts the value into - // int64. This literal must be an array. - StatusOr GetIntegralAsS64( - tensorflow::gtl::ArraySlice multi_index) const; - - // As Set(), but truncates `value` to the literal element type before storing. - // This literal must be an array. - Status SetIntegralAsS64(tensorflow::gtl::ArraySlice multi_index, - int64 value); - // Returns an identity matrix (rank 2) with the given row and column count. template static std::unique_ptr MakeIdentityR2(int64 size); @@ -507,6 +871,9 @@ class Literal { static std::unique_ptr MakeTuple( tensorflow::gtl::ArraySlice elements); + static std::unique_ptr MakeTupleFromSlices( + tensorflow::gtl::ArraySlice elements); + // As above, but intended to be invoked with move semantics; i.e. // // std::vector> elements = ...; @@ -538,136 +905,104 @@ class Literal { return MakeTupleOwned(std::move(v)); } - // Returns a string representation of the literal value. - // Warning: this function can take minutes for multi-million element Literals. - string ToString(bool print_layout = false) const; + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - // - // This literal must have a dense layout. - void EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const; - template - void EachCell(std::function indices, - NativeT value)> - per_cell) const; + // This operation is the inverse of DecomposeTuple. The given elements are + // moved into the tuple elements of a new tuple-shaped Literal which is + // returned. Upon return, each of the Literals in 'elements' is set to a nil + // shape (empty tuple). + static Literal MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements); - // Populate this literal with the given values. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // literal.PopulateR2FromArray2D(values); - // - // // Populate with int32s. - // literal.PopulateR2({{1, 2}, {3, 4}}); - // - // The shape and element type of this literal must match given values. For - // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 - // array of S32. 
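Aside (illustrative sketch, not part of this change): tuple assembly with the MakeTuple helpers retained above, as exercised by the updated tests later in this diff.

// Hypothetical helper, not in the diff: build a (f32[], f32[2,2]) tuple literal.
#include <memory>
#include "tensorflow/compiler/xla/literal_util.h"

std::unique_ptr<xla::Literal> MakeExampleTuple() {
  auto scalar = xla::Literal::CreateR0<float>(1.0f);
  auto matrix = xla::Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
  // MakeTuple copies its elements; MakeTupleOwned/MoveIntoTuple are the
  // move-semantics counterparts described above.
  return xla::Literal::MakeTuple({scalar.get(), matrix.get()});
}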
- template - void PopulateR1(tensorflow::gtl::ArraySlice values); - void PopulateR1(const tensorflow::core::Bitmap& values); - template - void PopulateR2(std::initializer_list> values); - template - void PopulateFromArray(const Array& values); - template - void PopulateR2FromArray2D(const Array2D& values); - template - void PopulateR3FromArray3D(const Array3D& values); - template - void PopulateR4FromArray4D(const Array4D& values); + // Creates a new Literal object with its values havings the primitive_type + // type, and with dimensions defined by the dimensions parameter. + // The content of the literal values is the default value of the primitive + // type of literal itself (0 for numeric types, and false for predicates). + static std::unique_ptr CreateFromDimensions( + PrimitiveType primitive_type, + tensorflow::gtl::ArraySlice dimensions); - // Populates literal values by calling the generator function for every cell - // in this literal object. - // - // generator must be a callable of the type - // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. - // - // This literal must have a dense layout. - template - Status Populate(const FnType& generator); + // If the given literal's data type is bfloat16, converts it to a float + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertBF16ToF32( + const LiteralSlice& bf16_literal); - // A parallel version of Populate(). This can be used if the generator is - // thread-safe and the values for the shape's different elements are - // independent. - template - Status PopulateParallel(const FnType& generator); + // If the given literal's data type is float, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertF32ToBF16( + const LiteralSlice& f32_literal); - // Fills this literal with the given value. - template - void PopulateWithValue(NativeT value); + // Creates a literal with a new shape with the given new dimensions using the + // data in the given input literal. For reshaping purposes the (flat) data + // buffer of the input literal is assumed to have the given minor_to_major + // layout order. + static std::unique_ptr ReshapeSlice( + tensorflow::gtl::ArraySlice new_dimensions, + tensorflow::gtl::ArraySlice minor_to_major, + const LiteralSlice& literal); - // Returns whether every element in this literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in this literal's type, returns false. Values of 1/0 - // are considered equal to true/false; other values are not considered equal - // to true. Also if this literal is not array-shaped false is returned. - bool IsAll(int8 value) const; + // Creates a literal with the supplied shape, and uses the provided value + // generator to populate the literal's values. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator); - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. 
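Aside (illustrative fragment, not part of this change): a hedged usage sketch for the CreateRandomLiteral overloads declared above; the shape and distribution parameters are made up for the example.

// Fragment: draw an F32[2,2] literal from N(0, 1) using the overload above.
auto random_literal = xla::Literal::CreateRandomLiteral<xla::F32>(
                          xla::ShapeUtil::MakeShape(xla::F32, {2, 2}),
                          /*mean=*/0.0f, /*stddev=*/1.0f)
                          .ConsumeValueOrDie();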
- // - // If the literal is not a floating-point value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. Also if this literal is not array-shaped false is returned. - bool IsAllFloat(float value) const; + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation, and using the engine as entropy generator. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, typename E, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev); - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular complex number. - // - // If the literal is not a complex value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for complex values that can be expressed precisely as - // float pairs e.g. (-0.5, 1.0). - // - // This literal must have a dense layout. - bool IsAllComplex(complex64 value) const; + // Creates a literal with the supplied shape, and initializes the literal + // values using a normal distribution with given mean and stddev standard + // deviation. + // Returns the new literal object, or an error Status if failed. + template < + PrimitiveType type, + typename T = typename primitive_util::PrimitiveTypeToNative::type> + static StatusOr> CreateRandomLiteral( + const Shape& shape, T mean, T stddev); - // Literal consists entirely of the first element of the literal. - bool IsAllFirst() const; + // + // End of factory methods. - // Returns whether this literal is zero at the specified index. This literal - // must be an array with a dense layout. - bool IsZero(tensorflow::gtl::ArraySlice indices) const; + // Returns a multi-dimensional index as a string. For example: '{7, 8}' will + // be returned for a 2-dimensional index with dimension 0 index equal to 7, + // dimension 1 equal to 8. + static string MultiIndexAsString( + tensorflow::gtl::ArraySlice multi_index); - // Return the count of the elements in the array at the given shape index in - // this literal. - int64 element_count(const ShapeIndex& index = {}) const { - return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + private: + // Recursively sets the subshapes and buffers of all subpieces rooted at + // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in + // the shape. + void SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays); + + // Returns the piece at the given ShapeIndex. + Piece& piece(const ShapeIndex& shape_index) { + return const_cast(LiteralBase::piece(shape_index)); } - // Return the count of the elements in the sparse array at the given shape - // index in this literal, which will be no larger than - // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). - int64 sparse_element_count() const; - - protected: - // 'allocate_arrays' indicates whether to allocate memory for the arrays in - // the shape. 
If false, buffer pointers inside of the Literal::Pieces are set - // to nullptr. - Literal(const Shape& shape, bool allocate_arrays); + Piece& root_piece() const override { return *root_piece_; }; // Internal template helper for the Literal::CopySliceFrom(), matching its // arguments one by one. template - Status CopySliceFromInternal(const Literal& src_literal, + Status CopySliceFromInternal(const LiteralBase& src_literal, tensorflow::gtl::ArraySlice src_base, tensorflow::gtl::ArraySlice dest_base, tensorflow::gtl::ArraySlice copy_size); @@ -695,162 +1030,69 @@ class Literal { int64 minor_loop_size = 1; }; - // A data structure representing a subshape at a particular ShapeIndex within - // the literal. For array-shaped ShapeIndexes, this data structure holds the - // pointer to the memory allocated for the array data. - class Piece { - public: - // Return the buffer holding the array data for this piece as an array - // slice. This piece must be array-shaped. - template - tensorflow::gtl::ArraySlice data() const; - template - tensorflow::gtl::MutableArraySlice data(); + // Literal class always owns the shape. The parent class borrows this shape. + std::unique_ptr shape_; - // Return the buffer holding the array data for this piece as a void*. This - // piece must be array-shaped. - void* untyped_data(); - const void* untyped_data() const; - - // Gets or sets an element in the array at the given index. The multi_index - // is CHECKed against the dimension sizes of the array. This piece must be - // array-shaped. - template - NativeT Get(tensorflow::gtl::ArraySlice index) const; - template - void Set(tensorflow::gtl::ArraySlice index, NativeT value); - - // Gets/sets the buffer holding the array data. - char* buffer() const { return buffer_; } - void set_buffer(char* buffer) { buffer_ = buffer; } - - // The array of multi-indices that provide the locations of non-zero - // elements in a sparse array. Only used if - // LayoutUtil::IsSparseArray(shape()) is true. - SparseIndexArray* sparse_indices() const { return sparse_indices_; } - void set_sparse_indices(SparseIndexArray* sparse_indices) { - sparse_indices_ = sparse_indices; - } - - // Gets or sets the subshape of this piece. This reference points to a - // subshape within the shape in the containing Literal (Literal::shape_). - const Shape& subshape() const { return *subshape_; } - void set_subshape(const Shape* subshape) { subshape_ = subshape; } - - // Returns the size in bytes of the buffer holding the array data. - int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } - - // Returns the number of elements in this piece's array. - int64 element_count() const { - // If this is a sparse array, use the number of elements represented by - // the indices in the associated SparseIndexArray. - return LayoutUtil::IsSparseArray(subshape()) - ? sparse_indices()->index_count() - : ShapeUtil::ElementsIn(subshape()); - } - - // Copy the data from 'src' into this piece's buffer. Shapes of this piece - // and src must be compatible. - Status CopyFrom(const Piece& src); - - // Returns true if this piece and 'other' contain the same data. This piece - // and 'other' must be array-shaped and compatible. - bool EqualElements(const Piece& other) const; - - // Writes the shape and data (if array-shaped) into the given proto. - void WriteToProto(LiteralProto* proto) const; - - // Copies the data from the given proto into this piece. The shape of this - // piece must be equal (not just compatible) to the shape of the proto. 
- Status CopyFromProto(const LiteralProto& proto); - - // Sorts the elements in a sparse array. - void SortSparseElements(); - - private: - // Recursive helper for EqualElements. - template - bool EqualElementsInternal(const Piece& other, - std::vector* multi_index) const; - - // Helper for SortSparseElements that has the element type as a template - // parameter. - template - void SortSparseElementsInternal(); - - // For array-shaped pieces, this is the buffer holding the literal data. - char* buffer_ = nullptr; - - // For sparse arrays, this is the array of indices. - SparseIndexArray* sparse_indices_ = nullptr; - - // The shape of piece. This points into the shape of the containing Literal - // (Literal::shape_). - const Shape* subshape_ = nullptr; - }; - - // Returns the piece at the given ShapeIndex. - Piece& piece(const ShapeIndex& shape_index) { - return *pieces_.mutable_element(shape_index); - } - const Piece& piece(const ShapeIndex& shape_index) const { - return pieces_.element(shape_index); - } - - // Returns the piece at the root of the shape (empty ShapeIndex). - Piece& root_piece() { return piece({}); } - const Piece& root_piece() const { return piece({}); } - - // Deallocate the buffers held by this literal (if the literal owns the - // buffer). - void DeallocateBuffers(); + Piece* root_piece_ = nullptr; // Implementation details shared between Populate() and PopulateParallel() template Status PopulateInternal(const FnType& generator, bool parallel); - Shape shape_; - ShapeTree pieces_; - - // Whether the buffers held in pieces_ are owned by this Literal. - bool owns_buffers_; - - // LiteralView must access and manipulate Pieces of other Literals. - friend class LiteralView; -}; // namespace xla + // Deallocate the buffers held by this literal. + void DeallocateBuffers(); + friend class LiteralBase; +}; std::ostream& operator<<(std::ostream& out, const Literal& literal); -// A read-only view of a Literal. A LiteralView contains pointers to buffers -// owned by the viewed Literal. -// -// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable -// and mutable) similar to (Mutable)ArraySlice. -class LiteralView : public Literal { +// A read-only view of a Literal. A LiteralSlice contains pointers to shape and +// literal buffers always owned by others. +class LiteralSlice : public LiteralBase { public: - // Create and return a view of the given literal rooted at the given shape - // index within the given literal. A factory is used rather than a public - // constructor because only const LiteralViews are supported. It's still - // possible to create non-const LiteralViews via the copy constructors, but - // the factory method makes it a bit less likely. Implementing literal slices - // will fix this undesirable situation (b/71550060). - static const LiteralView Create(const Literal& literal, - const ShapeIndex& view_root = {}); + LiteralSlice() : LiteralBase() {} - LiteralView(const LiteralView& other); - LiteralView& operator=(const LiteralView& other); - - virtual ~LiteralView(); + // Implicit conversion constructors. + LiteralSlice(const LiteralBase& literal); + LiteralSlice(const LiteralBase& literal, const ShapeIndex& view_root); private: - LiteralView(const Literal& literal, const ShapeIndex& view_root); + const Piece& root_piece() const override { return *root_piece_; }; - // Helper for the copy constructor and copy assignment operator. - void CopyFrom(const LiteralView& other); + const Piece* root_piece_; // Not owned. 
+};
+
+// A read-only Literal where the underlying buffers are never owned by this
+// class.
+class BorrowingLiteral : public LiteralBase {
+ public:
+  BorrowingLiteral() : LiteralBase() {}
+
+  // 'src_buf_ptr' is not owned by this class and must outlive the
+  // lifetime of this class. It points to an appropriately sized buffer with
+  // data interpreted as indicated by 'shape'.
+  // This constructor is only used for array shapes.
+  BorrowingLiteral(const char* src_buf_ptr, const Shape& shape);
+  // Similar to above, except to be used for constructing non-nested tuples.
+  BorrowingLiteral(tensorflow::gtl::ArraySlice src_buf_ptrs,
+                   const Shape& shape);
+  // TODO(b/79707221): adding constructors for nested tuples as well.
+
+ private:
+  // Recursively builds the subtree for the given piece and sets the subshapes
+  // of the given piece with the given shape.
+  void BuildPieceSubtree(const Shape& shape, Piece* piece);
+
+  // Accessor for the root piece of this literal.
+  const Piece& root_piece() const override { return root_piece_; };
+  Piece root_piece_;
+
+  // Shape of this literal.
+  const Shape shape_;
+};
+
 template
-tensorflow::gtl::ArraySlice Literal::Piece::data() const {
+tensorflow::gtl::ArraySlice LiteralBase::Piece::data() const {
   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
   CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType())
@@ -863,7 +1105,7 @@ tensorflow::gtl::ArraySlice Literal::Piece::data() const {
 }

 template
-tensorflow::gtl::MutableArraySlice Literal::Piece::data() {
+tensorflow::gtl::MutableArraySlice LiteralBase::Piece::data() {
   CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape());
   CHECK_EQ(subshape().element_type(), primitive_util::NativeToPrimitiveType())
@@ -876,7 +1118,7 @@ tensorflow::gtl::MutableArraySlice Literal::Piece::data() {
 }

 template
-NativeT Literal::Piece::Get(
+NativeT LiteralBase::Piece::Get(
     tensorflow::gtl::ArraySlice multi_index) const {
   CHECK(LayoutUtil::IsDenseArray(subshape()));
   return data()[IndexUtil::MultidimensionalIndexToLinearIndex(
@@ -884,15 +1126,15 @@ NativeT Literal::Piece::Get(
 }

 template
-void Literal::Piece::Set(tensorflow::gtl::ArraySlice multi_index,
-                         NativeT value) {
+void LiteralBase::Piece::Set(tensorflow::gtl::ArraySlice multi_index,
+                             NativeT value) {
   CHECK(LayoutUtil::IsDenseArray(subshape()));
   data()[IndexUtil::MultidimensionalIndexToLinearIndex(
       subshape(), multi_index)] = value;
 }

 template
-tensorflow::gtl::ArraySlice Literal::data(
+tensorflow::gtl::ArraySlice LiteralBase::data(
     const ShapeIndex& shape_index) const {
   return piece(shape_index).data();
 }
@@ -904,13 +1146,13 @@ tensorflow::gtl::MutableArraySlice Literal::data(
 }

 template
-inline NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index,
-                            const ShapeIndex& shape_index) const {
+inline NativeT LiteralBase::Get(tensorflow::gtl::ArraySlice multi_index,
+                                const ShapeIndex& shape_index) const {
   return piece(shape_index).Get(multi_index);
 }

 template
-inline NativeT Literal::Get(
+inline NativeT LiteralBase::Get(
     tensorflow::gtl::ArraySlice multi_index) const {
   return root_piece().Get(multi_index);
 }
@@ -1157,13 +1399,13 @@ template
 }

 template
-NativeT Literal::GetFirstElement() const {
+NativeT LiteralBase::GetFirstElement() const {
   return data().at(0);
 }

 template
-NativeT Literal::GetSparseElement(int64 sparse_element_number,
-                                  const ShapeIndex& shape_index) const {
+NativeT LiteralBase::GetSparseElement(int64 sparse_element_number,
+                                      const ShapeIndex& shape_index) const {
CHECK( LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); return data(shape_index)[sparse_element_number]; @@ -1196,7 +1438,7 @@ template } template -void Literal::EachCell( +void LiteralBase::EachCell( std::function indices, NativeT value)> per_cell) const { @@ -1372,7 +1614,7 @@ template } template -std::unique_ptr Literal::Replicate(int64 times) const { +std::unique_ptr LiteralBase::Replicate(int64 times) const { DimensionVector bounds = {times}; bounds.reserve(shape().dimensions_size() + 1); for (int64 bound : shape().dimensions()) { @@ -1407,6 +1649,38 @@ std::unique_ptr Literal::Replicate(int64 times) const { return literal; } +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, + const std::function)>& generator) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + TF_RET_CHECK(shape.element_type() == type); + auto literal = MakeUnique(shape); + TF_RETURN_IF_ERROR(literal.get()->Populate( + [&](tensorflow::gtl::ArraySlice indexes) { + return generator(indexes); + })); + return std::move(literal); +} + +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, E* engine, T mean, T stddev) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + std::normal_distribution generator(mean, stddev); + return CreateRandomLiteral( + shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { + return generator(*engine); + }); +} + +template +/* static */ StatusOr> Literal::CreateRandomLiteral( + const Shape& shape, T mean, T stddev) { + std::minstd_rand0 engine; + return CreateRandomLiteral(shape, &engine, mean, stddev); +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_ diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 61046784e05..77f979a0d70 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/layout_util.h" @@ -974,7 +975,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { Literal::CreateR1({2.0, 4.0}).get(), &nil_literal}); - EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); @@ -985,7 +986,7 @@ TEST_F(LiteralUtilTest, CopyFromTuples) { /*src_shape_index=*/{})); // The matrix element should be unchanged. - EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(*matrix, LiteralSlice(*nested_tuple, {0})); // The tuple element should have been copied from 'tuple'. EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); @@ -1065,7 +1066,7 @@ TEST_F(LiteralUtilTest, Populate) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. 
@@ -1107,7 +1108,7 @@ TEST_F(LiteralUtilTest, PopulateParallel) { Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); auto generator = [&](ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. @@ -1373,36 +1374,36 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } -TEST_F(LiteralUtilTest, LiteralViewTest) { +TEST_F(LiteralUtilTest, LiteralSliceTest) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); Literal nil(ShapeUtil::MakeNil()); - EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar); - EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix); - EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple); - EXPECT_EQ(LiteralView::Create(nil, {}), nil); + EXPECT_EQ(LiteralSlice(*scalar, {}), *scalar); + EXPECT_EQ(LiteralSlice(*matrix, {}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralSlice(nil, {}), nil); - EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar); - EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix); + EXPECT_EQ(LiteralSlice(*tuple, {0}), *scalar); + EXPECT_EQ(LiteralSlice(*tuple, {1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix); - EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0}), *tuple); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 0}), *scalar); + EXPECT_EQ(LiteralSlice(*nested_tuple, {0, 1}), *matrix); + EXPECT_EQ(LiteralSlice(*nested_tuple, {1}), *scalar); } -TEST_F(LiteralUtilTest, MutatingLiteralView) { +TEST_F(LiteralUtilTest, MutatingLiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); // Verify that changing the underlying data beneath the view changes the // data of the view itself. 
- const auto nested_tuple_view = LiteralView::Create(*nested_tuple); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); EXPECT_EQ( nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), 1.0f); @@ -1418,19 +1419,57 @@ TEST_F(LiteralUtilTest, MutatingLiteralView) { 555.0f); } -TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) { +TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); - const auto nested_tuple_view = LiteralView::Create(*nested_tuple); - const auto tuple_view = - LiteralView::Create(nested_tuple_view, /*view_root=*/{0}); - const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1}); + const auto nested_tuple_view = LiteralSlice(*nested_tuple); + const auto tuple_view = LiteralSlice(nested_tuple_view, /*view_root=*/{0}); + const auto matrix_view = LiteralSlice(tuple_view, /*view_root=*/{1}); EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } +TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { + std::vector int64_values = {1, 2, 3}; + const Shape literal_shape = ShapeUtil::MakeShape(S64, {3}); + + BorrowingLiteral literal(reinterpret_cast(int64_values.data()), + literal_shape); + + EXPECT_EQ(literal.Get({0}), 1); + EXPECT_EQ(literal.Get({1}), 2); + EXPECT_EQ(literal.Get({2}), 3); +} + +TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) { + std::vector one_two_three = {1, 2, 3}; + const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3}); + + std::vector hundred = {100}; + const Shape hundred_shape = ShapeUtil::MakeShape(S64, {1}); + + std::vector src_buf_ptrs; + src_buf_ptrs.emplace_back( + reinterpret_cast(one_two_three.data())); + src_buf_ptrs.emplace_back(reinterpret_cast(hundred.data())); + auto literal_tuple = BorrowingLiteral( + src_buf_ptrs, + ShapeUtil::MakeTupleShape({one_two_three_shape, hundred_shape})); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{0}, /*shape_index=*/{0}), + 1); + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{0}, /*shape_index=*/{1}), + 100); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{1}, /*shape_index=*/{0}), + 2); + + EXPECT_EQ(literal_tuple.Get(/*multi_index=*/{2}, /*shape_index=*/{0}), + 3); +} + TEST_F(LiteralUtilTest, LiteralMove) { std::unique_ptr matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); @@ -1533,11 +1572,11 @@ TEST_F(LiteralUtilTest, LiteralMoveAssignment) { EXPECT_EQ(literal.Get({1, 1}), 4.0); } -TEST_F(LiteralUtilTest, LiteralViewCopy) { +TEST_F(LiteralUtilTest, LiteralSliceCopy) { std::unique_ptr matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - const auto matrix_view = LiteralView::Create(*matrix); - LiteralView matrix_view_copy(matrix_view); + const auto matrix_view = LiteralSlice(*matrix); + LiteralSlice matrix_view_copy(matrix_view); EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); EXPECT_EQ(matrix_view_copy.Get({0, 1}), 2.0); diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 8db8c6f3de8..3c74e070da5 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -86,11 +86,10 @@ const typename Collection::value_type::second_type& FindOrDefault( // Inserts the key-value pair into the collection. Dies if key was already // present. 
-template -void InsertOrDie(Collection* const collection, - const typename Collection::value_type::first_type& key, - const typename Collection::value_type::second_type& data) { - auto p = collection->insert(std::make_pair(key, data)); +template +void InsertOrDie(Collection* const collection, Key&& key, Value&& value) { + auto p = collection->insert( + std::make_pair(std::forward(key), std::forward(value))); CHECK(p.second) << "duplicate key: " << key; } @@ -101,9 +100,10 @@ bool ContainsKey(const Collection& collection, const Key& key) { } // Inserts `value` into `set`. Dies if it was already present. -template -void InsertOrDie(Set* const set, const typename Set::value_type& value) { - CHECK(set->insert(value).second) << "duplicate value: " << value; +template +void InsertOrDie(Set* const set, Value&& value) { + CHECK(set->insert(std::forward(value)).second) + << "duplicate value: " << value; } } // namespace xla diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/ptr_util.h index c58c19db2ca..bfcdfc62f95 100644 --- a/tensorflow/compiler/xla/ptr_util.h +++ b/tensorflow/compiler/xla/ptr_util.h @@ -28,26 +28,8 @@ limitations under the License. #include "tensorflow/core/util/ptr_util.h" namespace xla { - -template -std::unique_ptr WrapUnique(T* ptr) { - return tensorflow::WrapUnique(ptr); -} - -template -typename tensorflow::helper::MakeUniqueResult::scalar MakeUnique( - Args&&... args) { - return tensorflow::MakeUnique(std::forward(args)...); -} - -// Overload for array of unknown bound. -// The allocation of arrays needs to use the array form of new, -// and cannot take element constructor arguments. -template -typename tensorflow::helper::MakeUniqueResult::array MakeUnique(size_t n) { - return tensorflow::MakeUnique(n); -} - +using tensorflow::MakeUnique; +using tensorflow::WrapUnique; } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 0517a5502e6..932cce943f7 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -20,6 +20,7 @@ py_test( srcs = ["xla_client_test.py"], main = "xla_client_test.py", srcs_version = "PY2AND3", + tags = ["no_oss"], deps = [ ":xla_client", "//tensorflow/python:platform_test", @@ -48,9 +49,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 2bacc6a9142..cb4dc1782b6 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -15,8 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/python/local_computation_builder.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/platform/default/thread_annotations.h" +#include "tensorflow/core/platform/thread_annotations.h" namespace xla { @@ -89,40 +90,40 @@ StatusOr> TransferFromOutfeedLocalReplica( return client->TransferFromOutfeedLocal(shape, device_ordinal); } -LocalShapedBuffer::LocalShapedBuffer( - std::unique_ptr shaped_buffer) +LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer) : shaped_buffer_(std::move(shaped_buffer)) {} -const std::unique_ptr& LocalShapedBuffer::shaped_buffer() - const { - return shaped_buffer_; +const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { + return &shaped_buffer_; } -static StatusOr> ToBuffer( - LocalClient* client, int device_ordinal, const Literal& arg) { +static StatusOr ToBuffer(LocalClient* client, + int device_ordinal, + const Literal& arg) { return client->LiteralToShapedBuffer(arg, device_ordinal, client->backend().memory_allocator()); } /* static */ -LocalShapedBuffer* LocalShapedBuffer::FromLiteral( +StatusOr LocalShapedBuffer::FromLiteral( const Literal& argument, const tensorflow::gtl::optional& shape_with_layout) { LocalClient* client = GetOrCreateLocalClient(); - std::unique_ptr buf; - if (shape_with_layout) { - std::unique_ptr relaid = - argument.Relayout(shape_with_layout.value()); - buf = ToBuffer(client, /*device_ordinal=*/0, *relaid).ConsumeValueOrDie(); - } else { - buf = ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie(); - } - return new LocalShapedBuffer(std::move(buf)); + StatusOr buf = [&] { + if (shape_with_layout) { + std::unique_ptr relaid = + argument.Relayout(shape_with_layout.value()); + return ToBuffer(client, /*device_ordinal=*/0, *relaid); + } + return ToBuffer(client, /*device_ordinal=*/0, argument); + }(); + TF_RETURN_IF_ERROR(buf.status()); + return new LocalShapedBuffer(std::move(buf).ValueOrDie()); } -std::unique_ptr LocalShapedBuffer::ToLiteral() const { +StatusOr> LocalShapedBuffer::ToLiteral() const { LocalClient* client = GetOrCreateLocalClient(); - return client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie(); + return client->ShapedBufferToLiteral(*shaped_buffer()); } CompiledLocalComputation::CompiledLocalComputation( @@ -158,14 +159,14 @@ StatusOr> CompiledLocalComputation::Execute( << device_ordinal; // Transfer arguments in - std::vector> scoped_buffers; + std::vector scoped_buffers; scoped_buffers.reserve(arguments.size()); for (int i = 0; i < arguments.size(); ++i) { const Literal& argument = arguments[i]; const tensorflow::gtl::optional& shape_with_layout = shapes_with_layout[i]; - StatusOr> pushed; + StatusOr pushed; if (shape_with_layout) { std::unique_ptr relaid = argument.Relayout(shape_with_layout.value()); @@ -185,7 +186,7 @@ StatusOr> CompiledLocalComputation::Execute( std::vector argument_buffers; argument_buffers.reserve(scoped_buffers.size()); for (auto& buffer : scoped_buffers) { - argument_buffers.push_back(buffer.get()); + argument_buffers.push_back(&buffer); } DeviceAssignment device_assignment = @@ -197,12 +198,10 @@ StatusOr> CompiledLocalComputation::Execute( ExecutableRunOptions options; options.set_device_ordinal(device_ordinal); options.set_allocator(client->backend().memory_allocator()); - options.set_inter_op_thread_pool( - client->backend().inter_op_thread_pool()); 
options.set_intra_op_thread_pool( client->backend().eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); - StatusOr> result_buffer_status = + StatusOr result_buffer_status = executable_->Run(argument_buffers, options); if (!result_buffer_status.ok()) { results[replica] = result_buffer_status.status(); @@ -210,8 +209,8 @@ StatusOr> CompiledLocalComputation::Execute( } // Transfer result out - results[replica] = - client->ShapedBufferToLiteral(*result_buffer_status.ValueOrDie()); + results[replica] = client->ShapedBufferToLiteral( + std::move(result_buffer_status).ValueOrDie()); }); } } @@ -236,22 +235,21 @@ LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( std::vector argument_buffers; argument_buffers.reserve(argument_handles.size()); for (auto& handle : argument_handles) { - argument_buffers.push_back(handle->shaped_buffer().get()); + argument_buffers.push_back(handle->shaped_buffer()); } // Execute ExecutableRunOptions options; options.set_allocator(client->backend().memory_allocator()); - options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool()); options.set_intra_op_thread_pool( client->backend().eigen_intra_op_thread_pool_device()); - std::unique_ptr result_buffer = + ScopedShapedBuffer result_buffer = executable_->Run(argument_buffers, options).ConsumeValueOrDie(); return new LocalShapedBuffer(std::move(result_buffer)); } -LocalComputation::LocalComputation(Computation computation) +LocalComputation::LocalComputation(XlaComputation computation) : computation_(std::move(computation)) {} StatusOr LocalComputation::Compile( @@ -274,7 +272,7 @@ StatusOr LocalComputation::Compile( return new CompiledLocalComputation(std::move(local_executable)); } -const Computation& LocalComputation::computation() const { +const XlaComputation& LocalComputation::computation() const { return computation_; } @@ -284,8 +282,12 @@ StatusOr LocalComputation::GetReturnValueShape() const { return std::move(*program_shape.mutable_result()); } +LocalOp::LocalOp(const XlaOp& op) : op_(op) {} + +const XlaOp& LocalOp::op() const { return op_; } + LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) - : builder_(GetOrCreateLocalClient(), computation_name) {} + : builder_(computation_name) {} void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { builder_.SetOpMetadata(metadata); @@ -294,19 +296,21 @@ void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } StatusOr LocalComputationBuilder::Build() { - TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build()); + TF_ASSIGN_OR_RETURN(XlaComputation computation, builder_.Build()); return new LocalComputation(std::move(computation)); } -ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, - const Shape& shape, - const string& name) { +LocalOp LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { return builder_.Parameter(parameter_number, shape, name); } std::unique_ptr LocalComputationBuilder::GetShape( - const ComputationDataHandle& operand) { - return builder_.GetShape(operand).ConsumeValueOrDie(); + const LocalOp& operand) { + auto result = MakeUnique(); + *result = builder_.GetShape(operand.op()).ValueOrDie(); + return result; } StatusOr LocalComputationBuilder::GetReturnValueShape() { @@ -314,222 +318,236 @@ StatusOr LocalComputationBuilder::GetReturnValueShape() { 
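`LocalOp` is a thin wrapper around `XlaOp`, so SWIG can hand Python an opaque object instead of the integer handles used by the old `ComputationDataHandle` protocol; each builder method unwraps it with `.op()` and relies on `LocalOp`'s converting constructor to re-wrap results. A hedged end-to-end sketch of the new builder flow, using only methods defined in this file (the `Add` forwarder is assumed to come from the `_FORWARD_BINOP` list further down; the helper function itself is hypothetical):

```cpp
namespace xla {
namespace swig {

// Hypothetical helper that builds the computation x + 1.0f.
StatusOr<LocalComputation*> BuildAddOne() {
  LocalComputationBuilder builder("add_one");

  LocalOp x = builder.Parameter(/*parameter_number=*/0,
                                ShapeUtil::MakeShape(F32, {}), "x");
  LocalOp one = builder.ConstantLiteral(*Literal::CreateR0<float>(1.0f));
  builder.Add(x, one, /*broadcast_dimensions=*/{});

  // Build() now yields an XlaComputation wrapped in a LocalComputation.
  return builder.Build();
}

}  // namespace swig
}  // namespace xla
```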
return program_shape.result(); } -ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { +LocalOp LocalComputationBuilder::Infeed(const Shape& shape) { return builder_.Infeed(shape); } -void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand, +void LocalComputationBuilder::Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config) { - builder_.Outfeed(operand, shape, outfeed_config); + builder_.Outfeed(operand.op(), shape, outfeed_config); } -ComputationDataHandle LocalComputationBuilder::ConstantLiteral( - const Literal& literal) { +LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) { return builder_.ConstantLiteral(literal); } -ComputationDataHandle LocalComputationBuilder::Broadcast( - const ComputationDataHandle& operand, +LocalOp LocalComputationBuilder::Broadcast( + const LocalOp& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - return builder_.Broadcast(operand, broadcast_sizes); + return builder_.Broadcast(operand.op(), broadcast_sizes); } -ComputationDataHandle LocalComputationBuilder::Pad( - const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config) { - return builder_.Pad(operand, padding_value, padding_config); +LocalOp LocalComputationBuilder::Pad(const LocalOp& operand, + const LocalOp& padding_value, + const PaddingConfig& padding_config) { + return builder_.Pad(operand.op(), padding_value.op(), padding_config); } -ComputationDataHandle LocalComputationBuilder::Reshape( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, +LocalOp LocalComputationBuilder::Reshape( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - return builder_.Reshape(operand, dimensions, new_sizes); + return builder_.Reshape(operand.op(), dimensions, new_sizes); } -ComputationDataHandle LocalComputationBuilder::Collapse( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Collapse(operand, dimensions); +LocalOp LocalComputationBuilder::Collapse( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Collapse(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::CrossReplicaSum( - const ComputationDataHandle& operand) { - return builder_.CrossReplicaSum(operand); +LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) { + return builder_.CrossReplicaSum(operand.op()); } -ComputationDataHandle LocalComputationBuilder::Slice( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, +LocalOp LocalComputationBuilder::Slice( + const LocalOp& operand, tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - return builder_.Slice(operand, start_indices, limit_indices, strides); + return builder_.Slice(operand.op(), start_indices, limit_indices, strides); } -ComputationDataHandle LocalComputationBuilder::SliceInDim( - const ComputationDataHandle& operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno) { - return builder_.SliceInDim(operand, start_index, limit_index, stride, dimno); +LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand, + int64 start_index, + int64 limit_index, int64 stride, + int64 dimno) { + return builder_.SliceInDim(operand.op(), start_index, limit_index, stride, + dimno); } -ComputationDataHandle 
LocalComputationBuilder::DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, +LocalOp LocalComputationBuilder::DynamicSlice( + const LocalOp& operand, const LocalOp& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - return builder_.DynamicSlice(operand, start_indices, slice_sizes); + return builder_.DynamicSlice(operand.op(), start_indices.op(), slice_sizes); } -ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices) { - return builder_.DynamicUpdateSlice(operand, update, start_indices); +LocalOp LocalComputationBuilder::DynamicUpdateSlice( + const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices) { + return builder_.DynamicUpdateSlice(operand.op(), update.op(), + start_indices.op()); } -ComputationDataHandle LocalComputationBuilder::ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension) { - return builder_.ConcatInDim(operands, dimension); +LocalOp LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice operands, int64 dimension) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.ConcatInDim(xla_ops, dimension); } -ComputationDataHandle -LocalComputationBuilder::SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, +LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter) { + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter) { return builder_.SelectAndScatterWithGeneralPadding( - operand, select.computation(), window_dimensions, window_strides, padding, - source, init_value, scatter.computation()); + operand.op(), select.computation(), window_dimensions, window_strides, + padding, source.op(), init_value.op(), scatter.computation()); } -ComputationDataHandle LocalComputationBuilder::Tuple( - tensorflow::gtl::ArraySlice elements) { - return builder_.Tuple(elements); +LocalOp LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice elements) { + std::vector xla_ops; + xla_ops.reserve(elements.size()); + for (const auto& op : elements) { + xla_ops.push_back(op.op()); + } + + return builder_.Tuple(xla_ops); } -ComputationDataHandle LocalComputationBuilder::GetTupleElement( - const ComputationDataHandle& tuple_data, int64 index) { - return builder_.GetTupleElement(tuple_data, index); +LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data, + int64 index) { + return builder_.GetTupleElement(tuple_data.op(), index); } -ComputationDataHandle LocalComputationBuilder::Dot( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return builder_.Dot(lhs, rhs); +LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) { + return builder_.Dot(lhs.op(), rhs.op()); } -ComputationDataHandle LocalComputationBuilder::DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::DotGeneral( + const LocalOp& lhs, const LocalOp& rhs, const DotDimensionNumbers& 
dimension_numbers) { - return builder_.DotGeneral(lhs, rhs, dimension_numbers); + return builder_.DotGeneral(lhs.op(), rhs.op(), dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, +LocalOp LocalComputationBuilder::ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers) { - return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation, + return builder_.ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, + padding, lhs_dilation, rhs_dilation, dimension_numbers); } -ComputationDataHandle LocalComputationBuilder::ConvertElementType( - const ComputationDataHandle& operand, PrimitiveType new_element_type) { - return builder_.ConvertElementType(operand, new_element_type); +LocalOp LocalComputationBuilder::ConvertElementType( + const LocalOp& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand.op(), new_element_type); } -ComputationDataHandle LocalComputationBuilder::Call( +LocalOp LocalComputationBuilder::Call( const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands) { - return builder_.Call(local_computation.computation(), operands); + tensorflow::gtl::ArraySlice operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + return builder_.Call(local_computation.computation(), xla_ops); } -ComputationDataHandle LocalComputationBuilder::Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation) { - return builder_.Transpose(operand, permutation); +LocalOp LocalComputationBuilder::Transpose( + const LocalOp& operand, tensorflow::gtl::ArraySlice permutation) { + return builder_.Transpose(operand.op(), permutation); } -ComputationDataHandle LocalComputationBuilder::Rev( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions) { - return builder_.Rev(operand, dimensions); +LocalOp LocalComputationBuilder::Rev( + const LocalOp& operand, tensorflow::gtl::ArraySlice dimensions) { + return builder_.Rev(operand.op(), dimensions); } -ComputationDataHandle LocalComputationBuilder::Map( - tensorflow::gtl::ArraySlice operands, +LocalOp LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice operands, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands) { - return builder_.Map(operands, local_computation.computation(), dimensions, - static_operands); + tensorflow::gtl::ArraySlice static_operands) { + std::vector xla_ops; + xla_ops.reserve(operands.size()); + for (const auto& op : operands) { + xla_ops.push_back(op.op()); + } + + std::vector static_xla_ops; + static_xla_ops.reserve(static_operands.size()); + for (const auto& op : static_operands) { + static_xla_ops.push_back(op.op()); + } + + return builder_.Map(xla_ops, local_computation.computation(), dimensions, + static_xla_ops); } -ComputationDataHandle LocalComputationBuilder::Reduce( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::Reduce( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, 
tensorflow::gtl::ArraySlice dimensions_to_reduce) { - return builder_.Reduce(operand, init_value, local_computation.computation(), - dimensions_to_reduce); + return builder_.Reduce(operand.op(), init_value.op(), + local_computation.computation(), dimensions_to_reduce); } -ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, +LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { return builder_.ReduceWindowWithGeneralPadding( - operand, init_value, local_computation.computation(), window_dimensions, - window_strides, padding); + operand.op(), init_value.op(), local_computation.computation(), + window_dimensions, window_strides, padding); } -ComputationDataHandle LocalComputationBuilder::RngNormal( - const ComputationDataHandle& mu, const ComputationDataHandle& sigma, - const Shape& shape) { - return builder_.RngNormal(mu, sigma, shape); +LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, + const LocalOp& sigma, + const Shape& shape) { + return builder_.RngNormal(mu.op(), sigma.op(), shape); } -ComputationDataHandle LocalComputationBuilder::RngUniform( - const ComputationDataHandle& a, const ComputationDataHandle& b, - const Shape& shape) { - return builder_.RngUniform(a, b, shape); +LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b, + const Shape& shape) { + return builder_.RngUniform(a.op(), b.op(), shape); } -ComputationDataHandle LocalComputationBuilder::While( - const LocalComputation& condition, const LocalComputation& body, - const ComputationDataHandle& init) { - return builder_.While(condition.computation(), body.computation(), init); +LocalOp LocalComputationBuilder::While(const LocalComputation& condition, + const LocalComputation& body, + const LocalOp& init) { + return builder_.While(condition.computation(), body.computation(), init.op()); } -ComputationDataHandle LocalComputationBuilder::Conditional( - const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, +LocalOp LocalComputationBuilder::Conditional( + const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, const LocalOp& false_operand, const LocalComputation& false_computation) { - return builder_.Conditional(predicate, true_operand, - true_computation.computation(), false_operand, - false_computation.computation()); + return builder_.Conditional( + predicate.op(), true_operand.op(), true_computation.computation(), + false_operand.op(), false_computation.computation()); } -StatusOr LocalComputationBuilder::IsConstant( - const ComputationDataHandle& operand, int64 num_parameters) { - return builder_.IsConstant(operand, num_parameters); +StatusOr LocalComputationBuilder::IsConstant(const LocalOp& operand) { + return builder_.IsConstant(operand.op()); } -StatusOr> LocalComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters) { - return builder_.ComputeConstant(operand, output_layout, parameters); +StatusOr LocalComputationBuilder::BuildConstantSubGraph( + const LocalOp& operand) { + 
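`ConcatInDim`, `Tuple`, `Call`, and `Map` above all repeat the same short loop, because `XlaBuilder` expects a span of `XlaOp` while the SWIG layer passes `LocalOp` objects. The change leaves that loop inline at each call site; an equivalent hypothetical helper, assuming the includes and `xla::swig` namespace of this file, would be:

```cpp
// Hypothetical convenience function (not part of this change); each variadic
// wrapper above writes this unwrap loop out inline instead.
static std::vector<XlaOp> UnwrapLocalOps(
    tensorflow::gtl::ArraySlice<LocalOp> ops) {
  std::vector<XlaOp> xla_ops;
  xla_ops.reserve(ops.size());
  for (const auto& op : ops) {
    xla_ops.push_back(op.op());
  }
  return xla_ops;
}
```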
TF_ASSIGN_OR_RETURN(XlaComputation computation, + builder_.BuildConstantSubGraph(operand.op())); + return new LocalComputation(std::move(computation)); } #define _FORWARD(method_name, return_sig, args_sig, args) \ @@ -537,23 +555,19 @@ StatusOr> LocalComputationBuilder::ComputeConstant( return builder_.method_name args; \ } -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand), (operand)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op())) -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions), \ - (lhs, rhs, broadcast_dimensions)) +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions), \ + (lhs.op(), rhs.op(), broadcast_dimensions)) -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs), \ - (lhs, rhs, ehs)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \ + (lhs.op(), rhs.op(), ehs.op())) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 31046e60f11..a06b85b4ea2 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -17,9 +17,10 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -59,15 +60,17 @@ StatusOr > TransferFromOutfeedLocalReplica( // client. class LocalShapedBuffer { public: - static LocalShapedBuffer* FromLiteral( + static StatusOr FromLiteral( const Literal& argument, const tensorflow::gtl::optional& shape_with_layout); - LocalShapedBuffer(std::unique_ptr shaped_buffer); - const std::unique_ptr& shaped_buffer() const; - std::unique_ptr ToLiteral() const; + + LocalShapedBuffer(ScopedShapedBuffer shaped_buffer); + const ScopedShapedBuffer* shaped_buffer() const; + + StatusOr > ToLiteral() const; private: - std::unique_ptr shaped_buffer_; + ScopedShapedBuffer shaped_buffer_; }; // Wraps a LocalExecutable produced by compiling a @@ -95,25 +98,37 @@ class CompiledLocalComputation { std::unique_ptr executable_; }; -// Wraps a Computation produced by a LocalComputationBuilder. The +// Wraps a XlaComputation produced by a LocalComputationBuilder. The // Compile method compiles the computation to a (local) executable via // the client library's local client. This class is intended to be // made available to Python via SWIG. 
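`ComputeConstant` is replaced above by `BuildConstantSubGraph`, which no longer evaluates anything itself: it extracts the constant-foldable sub-graph rooted at an op and returns it as an ordinary `LocalComputation`, leaving compilation and execution to the caller. A hedged sketch of one plausible way to consume the new entry point, assuming the `xla::swig` namespace; the helper and its arguments are hypothetical, while the signatures are the ones shown above.

```cpp
// Hypothetical caller: turn the constant sub-graph rooted at `operand` into a
// compiled computation. The sub-graph takes no parameters, so no argument
// shapes are supplied.
StatusOr<CompiledLocalComputation*> CompileConstantSubGraph(
    LocalComputationBuilder* builder, const LocalOp& operand) {
  TF_ASSIGN_OR_RETURN(LocalComputation* constant_graph,
                      builder->BuildConstantSubGraph(operand));
  return constant_graph->Compile(/*argument_shapes=*/{},
                                 /*build_options=*/nullptr);
}
```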
class LocalComputation { public: - LocalComputation(Computation computation); + LocalComputation(XlaComputation computation); StatusOr Compile( const std::vector& argument_shapes, const ExecutableBuildOptions* build_options); - const Computation& computation() const; + const XlaComputation& computation() const; // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; private: - Computation computation_; + XlaComputation computation_; +}; + +// Wraps a XlaOp produced by a LocalComputationBuilder. This class is intended +// to be made available to Python via SWIG. +class LocalOp { + public: + LocalOp(const XlaOp& op); + + const XlaOp& op() const; + + private: + XlaOp op_; }; // Wraps the ComputationBuilder API in order to: @@ -133,166 +148,137 @@ class LocalComputationBuilder { // Returns an owned LocalComputation to the caller on success. StatusOr Build(); - ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, - const string& name); + LocalOp Parameter(int64 parameter_number, const Shape& shape, + const string& name); - std::unique_ptr GetShape(const ComputationDataHandle& operand); + std::unique_ptr GetShape(const LocalOp& operand); // Returns the shape of the current return value for the computation. StatusOr GetReturnValueShape(); - ComputationDataHandle Infeed(const Shape& shape); + LocalOp Infeed(const Shape& shape); - void Outfeed(const ComputationDataHandle& operand, const Shape& shape, + void Outfeed(const LocalOp& operand, const Shape& shape, const string& outfeed_config); - ComputationDataHandle ConstantLiteral(const Literal& literal); + LocalOp ConstantLiteral(const Literal& literal); - ComputationDataHandle Broadcast( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice broadcast_sizes); + LocalOp Broadcast(const LocalOp& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); - ComputationDataHandle Pad(const ComputationDataHandle& operand, - const ComputationDataHandle& padding_value, - const PaddingConfig& padding_config); + LocalOp Pad(const LocalOp& operand, const LocalOp& padding_value, + const PaddingConfig& padding_config); - ComputationDataHandle Reshape(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice new_sizes); + LocalOp Reshape(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); - ComputationDataHandle Collapse(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Collapse(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); + LocalOp CrossReplicaSum(const LocalOp& operand); - ComputationDataHandle Slice(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices, - tensorflow::gtl::ArraySlice strides); + LocalOp Slice(const LocalOp& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); - ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, - int64 start_index, int64 limit_index, - int64 stride, int64 dimno); + LocalOp SliceInDim(const LocalOp& operand, int64 start_index, + int64 limit_index, int64 stride, int64 dimno); - ComputationDataHandle DynamicSlice( - const ComputationDataHandle& operand, - const ComputationDataHandle& start_indices, - tensorflow::gtl::ArraySlice 
slice_sizes); + LocalOp DynamicSlice(const LocalOp& operand, const LocalOp& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); - ComputationDataHandle DynamicUpdateSlice( - const ComputationDataHandle& operand, const ComputationDataHandle& update, - const ComputationDataHandle& start_indices); + LocalOp DynamicUpdateSlice(const LocalOp& operand, const LocalOp& update, + const LocalOp& start_indices); - ComputationDataHandle ConcatInDim( - tensorflow::gtl::ArraySlice operands, - int64 dimension); + LocalOp ConcatInDim(tensorflow::gtl::ArraySlice operands, + int64 dimension); - ComputationDataHandle SelectAndScatterWithGeneralPadding( - const ComputationDataHandle& operand, const LocalComputation& select, + LocalOp SelectAndScatterWithGeneralPadding( + const LocalOp& operand, const LocalComputation& select, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, - const ComputationDataHandle& source, - const ComputationDataHandle& init_value, const LocalComputation& scatter); + const LocalOp& source, const LocalOp& init_value, + const LocalComputation& scatter); - ComputationDataHandle Tuple( - tensorflow::gtl::ArraySlice elements); + LocalOp Tuple(tensorflow::gtl::ArraySlice elements); - ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, - int64 index); + LocalOp GetTupleElement(const LocalOp& tuple_data, int64 index); - ComputationDataHandle Dot(const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs); + LocalOp Dot(const LocalOp& lhs, const LocalOp& rhs); - ComputationDataHandle DotGeneral( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, - const DotDimensionNumbers& dimension_numbers); + LocalOp DotGeneral(const LocalOp& lhs, const LocalOp& rhs, + const DotDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvGeneralDilated( - const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + LocalOp ConvGeneralDilated( + const LocalOp& lhs, const LocalOp& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding, tensorflow::gtl::ArraySlice lhs_dilation, tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); - ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, - PrimitiveType new_element_type); + LocalOp ConvertElementType(const LocalOp& operand, + PrimitiveType new_element_type); - ComputationDataHandle Call( - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice operands); + LocalOp Call(const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands); - ComputationDataHandle Transpose( - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice permutation); + LocalOp Transpose(const LocalOp& operand, + tensorflow::gtl::ArraySlice permutation); - ComputationDataHandle Rev(const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice dimensions); + LocalOp Rev(const LocalOp& operand, + tensorflow::gtl::ArraySlice dimensions); - ComputationDataHandle Map( - tensorflow::gtl::ArraySlice operands, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions, - tensorflow::gtl::ArraySlice static_operands); + LocalOp Map(tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); - ComputationDataHandle Reduce( - const ComputationDataHandle& 
operand, - const ComputationDataHandle& init_value, - const LocalComputation& local_computation, - tensorflow::gtl::ArraySlice dimensions_to_reduce); + LocalOp Reduce(const LocalOp& operand, const LocalOp& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); - ComputationDataHandle ReduceWindowWithGeneralPadding( - const ComputationDataHandle& operand, - const ComputationDataHandle& init_value, + LocalOp ReduceWindowWithGeneralPadding( + const LocalOp& operand, const LocalOp& init_value, const LocalComputation& local_computation, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice > padding); - ComputationDataHandle RngNormal(const ComputationDataHandle& mu, - const ComputationDataHandle& sigma, - const Shape& shape); + LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, + const Shape& shape); - ComputationDataHandle RngUniform(const ComputationDataHandle& a, - const ComputationDataHandle& b, - const Shape& shape); + LocalOp RngUniform(const LocalOp& a, const LocalOp& b, const Shape& shape); - ComputationDataHandle While(const LocalComputation& condition, - const LocalComputation& body, - const ComputationDataHandle& init); + LocalOp While(const LocalComputation& condition, const LocalComputation& body, + const LocalOp& init); - ComputationDataHandle Conditional(const ComputationDataHandle& predicate, - const ComputationDataHandle& true_operand, - const LocalComputation& true_computation, - const ComputationDataHandle& false_operand, - const LocalComputation& false_computation); + LocalOp Conditional(const LocalOp& predicate, const LocalOp& true_operand, + const LocalComputation& true_computation, + const LocalOp& false_operand, + const LocalComputation& false_computation); - StatusOr IsConstant(const ComputationDataHandle& operand, - int64 num_parameters); + StatusOr IsConstant(const LocalOp& operand); - StatusOr > ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout, - tensorflow::gtl::ArraySlice parameters); + StatusOr BuildConstantSubGraph(const LocalOp& operand); #define _FORWARD(method_name, return_sig, args_sig) \ return_sig method_name args_sig; -#define _FORWARD_UNOP(method_name) \ - _FORWARD(method_name, ComputationDataHandle, \ - (const ComputationDataHandle& operand)) +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, LocalOp, (const LocalOp& operand)) -#define _FORWARD_BINOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - tensorflow::gtl::ArraySlice broadcast_dimensions)) +#define _FORWARD_BINOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions)) -#define _FORWARD_TRIOP(method_name) \ - _FORWARD( \ - method_name, ComputationDataHandle, \ - (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ - const ComputationDataHandle& ehs)) +#define _FORWARD_TRIOP(method_name) \ + _FORWARD(method_name, LocalOp, \ + (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs)) _FORWARD_TRIOP(Select) _FORWARD_TRIOP(Clamp) @@ -336,7 +322,7 @@ class LocalComputationBuilder { #undef _FORWARD_TRIOP private: - ComputationBuilder builder_; + XlaBuilder builder_; }; // Functions for freeing resources from the Python side. 
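The `_FORWARD_*` macros in the .cc and .h files are updated in lockstep so the many element-wise ops keep their one-line definitions. For reference, this is what a single binary-op expansion amounts to after the change, using `Add` as an assumed representative of the `_FORWARD_BINOP` list:

```cpp
// Header side: _FORWARD_BINOP(Add) declares
LocalOp Add(const LocalOp& lhs, const LocalOp& rhs,
            tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);

// .cc side: the matching _FORWARD_BINOP(Add) expands to
LocalOp LocalComputationBuilder::Add(
    const LocalOp& lhs, const LocalOp& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return builder_.Add(lhs.op(), rhs.op(), broadcast_dimensions);
}
```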
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index ac792e8189b..04c56bbba95 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -22,9 +22,8 @@ limitations under the License. // // C++ Python // -------------------------------------+--------------------------------------- -// ComputationDataHandle <-> int // ArraySlice <- sequence of int -// ArraySlice <- sequence of int +// ArraySlice <- sequence of LocalOp // Literal <-> (nested tuple of) numpy ndarray // std::vector <- sequence of (nested tuple of) ndarray // Shape -> pair holding (dtype, dimensions) @@ -91,12 +90,9 @@ limitations under the License. // One central reason for the Python-side indirection is that the // Python-side objects produced by the typemaps in this file are // further packaged up by xla_client before being passed on. For -// instance, xla_client wraps the long produced for a C++ -// ComputationDataHandle in a Python ComputationDataHandle proto, -// rather than exposing a raw long outside of the client. Similarly, -// the Python pair produced for a C++ Shape is further wrapped in a -// Python class (xla_client.Shape) so as not to expose the raw pair -// externally. +// instance, the Python pair produced for a C++ Shape is further +// wrapped in a Python class (xla_client.Shape) so as not to expose +// the raw pair externally. // // Other SWIG object wrappers (e.g. of LocalComputation) are further // wrapped by xla_client in order to set up a custom destructor that @@ -124,6 +120,7 @@ using namespace xla; using namespace xla::swig; namespace xla { + namespace swig { bool GetIntAttr(PyObject* o, const char* field, int64* result) { @@ -177,21 +174,6 @@ bool HandleStringAttribute(PyObject* o, tensorflow::ImportNumpy(); %} -// ComputationDataHandle - -%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { - const int64 handle = numpy::PyIntOrPyLongToLong($input); - if (handle == -1 && PyErr_Occurred()) { - SWIG_fail; - } - temp.set_handle(handle); - $1 = &temp; -} - -%typemap(out) ComputationDataHandle { - $result = numpy::LongToPyIntOrPyLong($1.handle()); -} - %typemap(out) StatusOr { if ($1.ok()) { auto* value = $1.ValueOrDie(); @@ -205,6 +187,19 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::LocalShapedBuffer*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + %typemap(out) StatusOr< std::unique_ptr > { if ($1.ok()) { std::unique_ptr value = $1.ConsumeValueOrDie(); @@ -288,33 +283,23 @@ tensorflow::ImportNumpy(); $1 = temps; } -// ComputationDataHandle +// ArraySlice -%typemap(in) tensorflow::gtl::ArraySlice - (std::vector temps) { +%typemap(in) tensorflow::gtl::ArraySlice( + std::vector temps) { if (!PySequence_Check($input)) { PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); SWIG_fail; } const int size = PySequence_Size($input); - temps.resize(size); for (int i = 0; i < size; ++i) { PyObject* o = PySequence_GetItem($input, i); - PyObject* py_int = numpy::PyNumberToPyInt(o); - if (!py_int) { - PyErr_SetString( - PyExc_TypeError, - "Argument sequence element cannot be converted to int"); + LocalOp* op; + if ((SWIG_ConvertPtr(o, (void**)&op, $descriptor(xla::swig::LocalOp*), + SWIG_POINTER_EXCEPTION)) == -1) { SWIG_fail; } - const int64 handle = 
numpy::PyIntOrPyLongToLong(py_int); - if (handle == -1 && PyErr_Occurred()) { - Py_DECREF(py_int); - Py_DECREF(o); - SWIG_fail; - } - temps[i].set_handle(handle); - Py_DECREF(py_int); + temps.push_back(*op); Py_DECREF(o); } $1 = temps; @@ -921,6 +906,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::Build; diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc index eec48479c92..68648a3a176 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.cc +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -181,16 +181,6 @@ StatusOr XlaShapeFromPyShape(PyObject* o) { PyObjectCppRepr(o).c_str()); }; - auto get_attr = [o, &error](const string& field) -> StatusOr { - PyObject* result = - PyObject_GetAttrString(o, const_cast(field.c_str())); - if (result == nullptr) { - return error(tensorflow::strings::StrCat( - "Failed to get attribute of Shape object:", field)); - } - return result; - }; - auto call_method = [o, &error](const string& method) -> StatusOr { PyObject* result = PyObject_CallMethod(o, const_cast(method.c_str()), nullptr); @@ -202,12 +192,16 @@ StatusOr XlaShapeFromPyShape(PyObject* o) { }; PyObject* np_type; - TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype")); + TF_ASSIGN_OR_RETURN(np_type, call_method("numpy_dtype")); if (np_type->ob_type != &PyArrayDescr_Type) { - return error("Shape attribute np_dtype is not an integer numpy dtype"); + return error( + "Return value of shape method numpy_dtype " + "is not an integer numpy dtype"); } if (!NumpyTypeIsValid(NumpyTypenum(np_type))) { - return error("Shape attribute np_dtype is not a valid integer numpy dtype"); + return error( + "Return value of shape method numpy_dtype " + "is not a valid integer numpy dtype"); } const PrimitiveType element_type = NumpyTypeToPrimitiveType(NumpyTypenum(np_type)); @@ -346,13 +340,13 @@ StatusOr OpMetadataFromPyObject(PyObject* o) { return result; } -PyObject* PyObjectFromXlaLiteral(const Literal& literal) { +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { int num_elements = ShapeUtil::TupleElementCount(literal.shape()); PyObject* tuple = PyTuple_New(num_elements); for (int i = 0; i < num_elements; i++) { - PyTuple_SET_ITEM( - tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i}))); + PyTuple_SET_ITEM(tuple, i, + PyObjectFromXlaLiteral(LiteralSlice(literal, {i}))); } return tuple; } else { @@ -437,7 +431,7 @@ Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, return Status::OK(); } -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array) { switch (np_type) { case NPY_BOOL: diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h index 9656cb1c31c..64f0aae0f97 100644 --- a/tensorflow/compiler/xla/python/numpy_bridge.h +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -74,7 +74,7 @@ StatusOr OpMetadataFromPyObject(PyObject* o); // array data. // // The return value is a new reference. 
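In `numpy_bridge`, the conversion entry points now take `LiteralSlice`, so a tuple element can be handed down as a cheap non-owning view rather than a materialized sub-literal. A small sketch of the pattern, assuming the includes and namespace of `numpy_bridge.cc` (the helper itself is hypothetical):

```cpp
// Hypothetical helper: convert element `i` of a tuple-shaped literal to a
// Python object without first copying that element's data.
PyObject* TupleElementToPyObject(const Literal& tuple_literal, int i) {
  // LiteralSlice is a view; `tuple_literal` must outlive it.
  LiteralSlice element(tuple_literal, /*view_root=*/{i});
  return PyObjectFromXlaLiteral(element);
}
```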
-PyObject* PyObjectFromXlaLiteral(const Literal& literal); +PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal); // Converts a Numpy ndarray or a nested Python tuple thereof to a // corresponding XLA literal. @@ -90,7 +90,7 @@ StatusOr > XlaLiteralFromPyObject(PyObject* o); Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, Literal* literal); -void CopyLiteralToNumpyArray(int np_type, const Literal& literal, +void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal, PyArrayObject* py_array); template @@ -101,7 +101,8 @@ void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { } template -void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { +void CopyLiteralToNumpyArray(const LiteralSlice& literal, + PyArrayObject* py_array) { NativeT* dest = static_cast(PyArray_DATA(py_array)); auto source = literal.data(); std::copy(source.begin(), source.end(), dest); diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 9c81f6439d0..1d5b75d1bee 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -166,14 +166,14 @@ class LocalBuffer(object): self._delete = c_api.DeleteLocalShapedBuffer @staticmethod - def from_py(npval, layout_fn=None): - npval = require_numpy_array_layout(npval) + def from_pyval(pyval, layout_fn=None): + pyval = require_numpy_array_layout(pyval) if layout_fn: - shape = Shape.from_numpy(npval) + shape = Shape.from_pyval(pyval) shape = shape.map_leaves(layout_fn) else: shape = None - return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape)) + return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(pyval, shape)) def to_py(self): return self.c_local_shaped_buffer.ToLiteral() @@ -191,53 +191,104 @@ class LocalBuffer(object): class Shape(object): - """XLA shape. + """Represents an XLA shape. - Represents an XLA shape by a corresponding Python/Numpy type and a - list of dimensions, which are themselves Shapes in case this one - represents an XLA tuple. + A shape is either an array shape, having rank-many integer + dimensions and an element type (represented by a Numpy dtype), or it + is a tuple shape, having a shape for every tuple component: + + type shape = + TupleShape of shape list + | ArrayShape of { dimensions: int list; element_type: dtype } + + Callers are expected to instantiate this class only via the static + constructors: tuple_shape, array_shape, and from_pyval. 
""" - def __init__(self, np_dtype, dimensions, minor_to_major=None): + @staticmethod + def tuple_shape(tuple_shapes): + """Construct a tuple shape.""" + if (not isinstance(tuple_shapes, (tuple, list)) or + not all(isinstance(t, Shape) for t in tuple_shapes)): + raise TypeError('tuple_shapes must be a tuple of Shapes') + return Shape(tuple_shapes, tuple) + + @staticmethod + def array_shape(element_type, dimensions, minor_to_major=None): + """Construct an array shape.""" + if (not isinstance(dimensions, tuple) or + not all(isinstance(i, int) for i in dimensions)): + dimensions = tuple(int(i) for i in dimensions) + return Shape(dimensions, np.dtype(element_type), + minor_to_major=minor_to_major) + + @staticmethod + def from_pyval(pyval): + def convert(pyval): + if isinstance(pyval, tuple): + return Shape.tuple_shape(tuple(convert(elt) for elt in pyval)) + else: + pyval = require_numpy_array_layout(pyval) + return Shape.array_shape(pyval.dtype, np.shape(pyval)) + return convert(pyval) + + def __init__(self, dimensions, dtype, minor_to_major=None): assert isinstance(dimensions, tuple) - self.np_dtype = np_dtype self._dimensions = dimensions + self._dtype = dtype + self._is_tuple = dtype == tuple self._minor_to_major = minor_to_major self._check_minor_to_major() def __eq__(self, other): # pylint: disable=protected-access - return (self.np_dtype == other.np_dtype and + return (self._dtype == other._dtype and self._dimensions == other._dimensions and self._minor_to_major == other._minor_to_major) def __repr__(self): - return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, ' - 'minor_to_major={!r})').format(self.np_dtype, self._dimensions, - self._minor_to_major) - - def element_type(self): - return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)] + return ('xla_client.Shape(_dtype={!r}, _dimensions={!r}, ' + '_is_tuple={!r}), _minor_to_major={!r}').format( + self._dtype, self._dimensions, self._is_tuple, + self._minor_to_major) def is_tuple(self): - return self.element_type() == xla_data_pb2.TUPLE + return self._is_tuple - def dimensions(self): - if self.is_tuple(): - raise ValueError('Tuple shape has no dimensions') - return self._dimensions - - def minor_to_major(self): - return self._minor_to_major + def is_array(self): + return not self._is_tuple def tuple_shapes(self): if not self.is_tuple(): - raise ValueError('Shape is not a tuple shape') + raise ValueError('not a tuple shape') + return self._dimensions + + def numpy_dtype(self): + """Like element_type(), but returns dtype('O') in case of a tuple shape.""" + if self.is_tuple(): + return np.dtype(np.object) + else: + return self.element_type() + + def xla_element_type(self): + return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.numpy_dtype())] + + def element_type(self): + if not self.is_array(): + raise ValueError('not an array shape') + return self._dtype + + def dimensions(self): + if not self.is_array(): + raise ValueError('not an array shape') return self._dimensions def rank(self): return len(self.dimensions()) + def minor_to_major(self): + return self._minor_to_major + def map_leaves(self, f): """Map f over each leaf-level array subshape. 
@@ -250,7 +301,7 @@ class Shape(object): """ if self.is_tuple(): children = tuple(child.map_leaves(f) for child in self.tuple_shapes()) - return Shape(np.dtype('O'), children) + return Shape.tuple_shape(children) else: mapped = f(self) return self if mapped is None else mapped @@ -264,44 +315,24 @@ class Shape(object): assert sorted(mtm) == range(len(mtm)), self def update_minor_to_major(self, minor_to_major): + if not self.is_array(): + raise ValueError('not an array shape') if not isinstance(minor_to_major, tuple): raise TypeError('minor_to_major must be a tuple') - updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major) + updated = Shape.array_shape( + self.element_type(), self.dimensions(), minor_to_major) updated._check_minor_to_major() # pylint: disable=protected-access return updated - @staticmethod - def from_numpy(npval): - - def convert(npval): - if isinstance(npval, tuple): - return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval)) - else: - return Shape(npval.dtype, np.shape(npval)) - - return convert(require_numpy_array_layout(npval)) - def _wrap_shape(shape_info): dtype, dims = shape_info element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] if element_type == xla_data_pb2.TUPLE: - dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims) - return Shape(dtype, dims) - - -def _wrap_data_handle(handle): - cdh = xla_data_pb2.ComputationDataHandle() - cdh.handle = handle - return cdh - - -def _unwrap_data_handle(handle_proto): - return handle_proto.handle - - -def _unwrap_data_handles(handle_protos): - return [_unwrap_data_handle(cdh) for cdh in handle_protos] + shapes = tuple(_wrap_shape(subshape_info) for subshape_info in dims) + return Shape.tuple_shape(shapes) + else: + return Shape.array_shape(dtype, dims) def require_numpy_array_layout(value): @@ -420,7 +451,7 @@ class LocalComputation(object): compile_options=None, layout_fn=None): return self.Compile( - argument_shapes=[Shape.from_numpy(arg) for arg in arguments], + argument_shapes=[Shape.from_pyval(arg) for arg in arguments], compile_options=compile_options, layout_fn=layout_fn) @@ -428,7 +459,7 @@ class LocalComputation(object): """Execute with Python values as arguments and return value.""" if not self.is_compiled: raise ValueError('Cannot execute an uncompiled local XLA computation.') - argument_shapes = [Shape.from_numpy(arg) for arg in arguments] + argument_shapes = [Shape.from_pyval(arg) for arg in arguments] if layout_fn: argument_shapes = [ shape.map_leaves(layout_fn) for shape in argument_shapes @@ -490,9 +521,9 @@ class ComputationBuilder(object): queue for subsequent use in the computation. Returns: - A ComputationDataHandle message. + A LocalOp. """ - return _wrap_data_handle(self._client.Infeed(shape)) + return self._client.Infeed(shape) def Outfeed(self, operand): """Enqueues an outfeed op onto the computation. @@ -500,9 +531,7 @@ class ComputationBuilder(object): Outfeed operations enqueue data, using the given operand, onto the XLA outfeed queue for subsequent dequeue via the client API. """ - self._client.Outfeed( - _unwrap_data_handle(operand), self.GetShape(operand), - ''.encode('utf-8')) + self._client.Outfeed(operand, self.GetShape(operand), ''.encode('utf-8')) def Constant(self, value): """Enqueues a constant op onto the computation. @@ -512,10 +541,10 @@ class ComputationBuilder(object): to one of the supported types. Returns: - A ComputationDataHandle message. + A LocalOp. 
""" value = require_numpy_array_layout(value) - return _wrap_data_handle(self._client.ConstantLiteral(value)) + return self._client.ConstantLiteral(value) def ConstantF32Scalar(self, value): """Convenience method to enqueue a scalar F32 constant op. @@ -524,7 +553,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float32)) @@ -535,7 +564,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.float64)) @@ -546,7 +575,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int32)) @@ -557,7 +586,7 @@ class ComputationBuilder(object): value: a floating-point number. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.int64)) @@ -568,7 +597,7 @@ class ComputationBuilder(object): value: a boolean value. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.Constant(np.array(value, dtype=np.bool)) @@ -584,15 +613,14 @@ class ComputationBuilder(object): parameters, use it for *all* parameters to avoid clashes. Returns: - A ComputationDataHandle message. + A LocalOp. """ if name is None: name = '' if parameter_num is None: parameter_num = next(self._parameter_numbering) - return _wrap_data_handle( - self._client.Parameter(parameter_num, shape, name.encode('utf8'))) + return self._client.Parameter(parameter_num, shape, name.encode('utf8')) def ParameterFromNumpy(self, value, name=None, parameter_num=None): """Enqueues a Parameter op onto the computation. @@ -604,23 +632,22 @@ class ComputationBuilder(object): parameter_num: as in ParameterWithShape. Returns: - A ComputationDataHandle message. + A LocalOp. """ return self.ParameterWithShape( - Shape.from_numpy(value), name=name, parameter_num=parameter_num) + Shape.from_pyval(value), name=name, parameter_num=parameter_num) def Broadcast(self, operand, sizes): """Enqueues a broadcast operation onto the computation. Args: - operand: the operand ComputationDataHandle to broadcast. + operand: the operand LocalOp to broadcast. sizes: an iterable of broadcast sizes. Returns: - A ComputationDataHandle representing the added broadcast op. + A LocalOp representing the added broadcast op. """ - return _wrap_data_handle( - self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + return self._client.Broadcast(operand, sizes) def Concatenate(self, operands, dimension): """Enqueues a concatenate operation onto the computation. @@ -630,10 +657,9 @@ class ComputationBuilder(object): dimension: the dimension in which to perform the concatenation. Returns: - A ComputationDataHandle representing the added concatenate op. + A LocalOp representing the added concatenate op. """ - return _wrap_data_handle( - self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + return self._client.ConcatInDim(operands, dimension) def ConvertElementType(self, operand, new_element_type): """Enqueues an element type conversion operation onto the computation. @@ -643,14 +669,12 @@ class ComputationBuilder(object): new_element_type: the target primitive type. Returns: - A ComputationDataHandle representing the added conversion op. + A LocalOp representing the added conversion op. 
""" - return _wrap_data_handle( - self._client.ConvertElementType( - _unwrap_data_handle(operand), new_element_type)) + return self._client.ConvertElementType(operand, new_element_type) def GetShape(self, operand): - return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + return _wrap_shape(self._client.GetShape(operand)) def GetReturnValueShape(self): return _wrap_shape(self._client.GetReturnValueShape()) @@ -662,40 +686,35 @@ class ComputationBuilder(object): """Enqueues a Pad operation onto the computation. Args: - operand: ComputationDataHandle representing the array to pad. - padding_value: ComputationDataHandle representing the scalar pad value. + operand: LocalOp representing the array to pad. + padding_value: LocalOp representing the scalar pad value. padding_config: either an xla_data_pb2.PaddingConfig or a list of integer triples (edge_padding_low, edge_padding_high, interior_padding) representing the configuration of the padding operation. Returns: - A ComputationDataHandle representing the added Pad op. + A LocalOp representing the added Pad op. """ if not isinstance(padding_config, xla_data_pb2.PaddingConfig): padding_config = GetPaddingConfigFromTriples(padding_config) - return _wrap_data_handle( - self._client.Pad(_unwrap_data_handle(operand), - _unwrap_data_handle(padding_value), - padding_config)) + return self._client.Pad(operand, padding_value, padding_config) def Reshape(self, operand, dimensions, new_sizes): """Enqueues a reshape op onto the computation. Args: - operand: ComputationDataHandle representing the array to be reshaped. + operand: LocalOp representing the array to be reshaped. dimensions: sequence of integers encoding the order in which dimensions are collapsed or None, in which case dimensions are flattened in order. new_sizes: sequence of integers encoding the new dimension sizes (shape). Returns: - A ComputationDataHandle representing the added Reshape op. + A LocalOp representing the added Reshape op. """ if dimensions is None: ndim = len(self.GetShape(operand).dimensions()) dimensions = tuple(range(ndim)) - return _wrap_data_handle( - self._client.Reshape( - _unwrap_data_handle(operand), dimensions, new_sizes)) + return self._client.Reshape(operand, dimensions, new_sizes) def CrossReplicaSum(self, operand): """CrossReplicaSum op. @@ -704,67 +723,56 @@ class ComputationBuilder(object): operand: the operand to sum across replica instances. Returns: - A ComputationDataHandle that has the sum of the value among all replicas. + A LocalOp that has the sum of the value among all replicas. 
""" - return _wrap_data_handle( - self._client.CrossReplicaSum(_unwrap_data_handle(operand))) + return self._client.CrossReplicaSum(operand) def Collapse(self, operand, dimensions): """Collapse op.""" - return _wrap_data_handle( - self._client.Collapse(_unwrap_data_handle(operand), dimensions)) + return self._client.Collapse(operand, dimensions) def Trans(self, operand): """Specialized matrix transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + return self._client.Transpose(operand, [1, 0]) def Transpose(self, operand, permutation): """Transpose op.""" - return _wrap_data_handle( - self._client.Transpose(_unwrap_data_handle(operand), permutation)) + return self._client.Transpose(operand, permutation) def Rev(self, operand, dimensions): """Rev op.""" - return _wrap_data_handle( - self._client.Rev(_unwrap_data_handle(operand), dimensions)) + return self._client.Rev(operand, dimensions) def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin """Clamp op.""" - return _wrap_data_handle( - self._client.Clamp(_unwrap_data_handle(min), - _unwrap_data_handle(operand), - _unwrap_data_handle(max))) + return self._client.Clamp(min, operand, max) def SelectAndScatter(self, operand, select, window_dimensions, window_strides, padding, source, init_value, scatter): """Select and scatter op, used by the gradient of ReduceWindow. Args: - operand: ComputationDataHandle for array of dimension N and type T over + operand: LocalOp for array of dimension N and type T over which the windows slide. select: Computation of type (T, T) -> Pred to apply to the elements of each window to indicate which element is selected. window_dimensions: sequence of N integers for dimensions of the window. window_strides: sequence of N integers for the strides of the window. padding: PaddingType representing either 'SAME' or 'VALID ' padding. - source: ComputationDataHandle for array of type T with values to scatter. - init_value: ComputationDataHandle of scalar type T for initial out value. + source: LocalOp for array of type T with values to scatter. + init_value: LocalOp of scalar type T for initial out value. scatter: Computation of type (T, T) -> T to apply to each scatter source element with its destination element. Returns: - A ComputationDataHandle representing the added SelectAndScatter op. + A LocalOp representing the added SelectAndScatter op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.SelectAndScatterWithGeneralPadding( - _unwrap_data_handle(operand), select.c_local_computation, - window_dimensions, window_strides, pads, - _unwrap_data_handle(source), _unwrap_data_handle(init_value), - scatter.c_local_computation)) + return self._client.SelectAndScatterWithGeneralPadding( + operand, select.c_local_computation, window_dimensions, window_strides, + pads, source, init_value, scatter.c_local_computation) def Select(self, pred, on_true, on_false): """Element-wise selection op. @@ -772,17 +780,13 @@ class ComputationBuilder(object): Constructs an output array from elements of two input arrays, based on the values of a predicate array. 
""" - return _wrap_data_handle( - self._client.Select( - _unwrap_data_handle(pred), - _unwrap_data_handle(on_true), - _unwrap_data_handle(on_false))) + return self._client.Select(pred, on_true, on_false) def Slice(self, operand, start_indices, limit_indices, strides=None): """Enqueues a slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_indices: iterable of N integers containing the starting indices of the slice for each dimension. limit_indices: iterable of N integers containing the ending indices @@ -791,207 +795,177 @@ class ComputationBuilder(object): each dimension. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ if strides is None: start_indices = list(start_indices) strides = [1] * len(start_indices) - return _wrap_data_handle( - self._client.Slice( - _unwrap_data_handle(operand), start_indices, limit_indices, - strides)) + return self._client.Slice(operand, start_indices, limit_indices, strides) def SliceInDim(self, operand, start_index, limit_index, stride, dimno): """Enqueues a slice-in-dimension operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. + operand: LocalOp for the N dimensional array to be sliced. start_index: an integer containing the start index of the slice. limit_index: an integer containing the end index of the slice. stride: an integer containing the stride size for the slice. dimno: an integer indicating the dimension along which to slice. Returns: - A ComputationDataHandle representing the added Slice op. + A LocalOp representing the added Slice op. """ - return _wrap_data_handle( - self._client.SliceInDim( - _unwrap_data_handle(operand), start_index, limit_index, stride, - dimno)) + return self._client.SliceInDim(operand, start_index, limit_index, stride, + dimno) def DynamicSlice(self, operand, start_indices, slice_sizes): """Enqueues a slice op with dynamic start indices onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be sliced. - start_indices: ComputationDataHandle for the 1D array of N integers + operand: LocalOp for the N dimensional array to be sliced. + start_indices: LocalOp for the 1D array of N integers containing the starting indices of the slice. slice_sizes: iterable of N integers containing the slice sizes in each dimension. Returns: - A ComputationDataHandle representing the added DynamicSlice op. + A LocalOp representing the added DynamicSlice op. """ - return _wrap_data_handle( - self._client.DynamicSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(start_indices), - slice_sizes)) + return self._client.DynamicSlice(operand, start_indices, slice_sizes) def DynamicUpdateSlice(self, operand, update, start_indices): """Enqueues a dynamic update slice operation onto the computation. Args: - operand: ComputationDataHandle for the N dimensional array to be updated. + operand: LocalOp for the N dimensional array to be updated. update: N dimensional array comprising the slice update. start_indices: Rank-1 array of N integers comprising the starting indices of the slice along each dimension. Returns: - A ComputationDataHandle representing the added DynamicUpdateSlice op. + A LocalOp representing the added DynamicUpdateSlice op. 
""" - return _wrap_data_handle( - self._client.DynamicUpdateSlice( - _unwrap_data_handle(operand), - _unwrap_data_handle(update), - _unwrap_data_handle(start_indices))) + return self._client.DynamicUpdateSlice(operand, update, start_indices) def Tuple(self, *ops): """Enqueues a tuple operation onto the computation. Args: - ops: a sequence of tuple operands (each a ComputationDataHandle). + ops: a sequence of tuple operands (each a LocalOp). Returns: - A ComputationDataHandle representing the added Tuple op. + A LocalOp representing the added Tuple op. """ - return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + return self._client.Tuple(ops) def GetTupleElement(self, tup, index): """Enqueues a 'get tuple element' operation onto the computation. Args: - tup: the tuple operand (a ComputationDataHandle). + tup: the tuple operand (a LocalOp). index: numeric index to select from the tuple. Returns: - A ComputationDataHandle representing the added GetTupleElement op. + A LocalOp representing the added GetTupleElement op. """ - return _wrap_data_handle( - self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + return self._client.GetTupleElement(tup, index) def Call(self, computation_to_apply, operands): """Enqueues a call operation onto the computation. Args: computation_to_apply: a Computation object. - operands: an iterable of ComputationDataHandle. The number and types of + operands: an iterable of LocalOp. The number and types of operands must match the arity of computation_to_apply. Returns: - A ComputationDataHandle representing the added call op. + A LocalOp representing the added call op. """ - return _wrap_data_handle( - self._client.Call(computation_to_apply.c_local_computation, - _unwrap_data_handles(operands))) + return self._client.Call(computation_to_apply.c_local_computation, operands) def Map(self, operands, computation_to_apply, dimensions, static_operands=()): """Enqueues a map operation onto the computation. Args: - operands: an iterable of ComputationDataHandle. + operands: an iterable of LocalOp. computation_to_apply: a Computation object. dimensions: dimensions over which to apply map the function. static_operands: auxiliary arguments passed to the applied computation. Returns: - A ComputationDataHandle representing the added Map op. + A LocalOp representing the added Map op. """ - return _wrap_data_handle( - self._client.Map( - _unwrap_data_handles(operands), - computation_to_apply.c_local_computation, - dimensions, - _unwrap_data_handles(static_operands))) + return self._client.Map(operands, computation_to_apply.c_local_computation, + dimensions, static_operands) def Reduce(self, operand, init_value, computation_to_apply, dimensions): """Enqueues a reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a Computation object - binary reduction function. dimensions: sequence of dimensions (integers) to reduce on. Returns: - A ComputationDataHandle representing the added Reduce op. + A LocalOp representing the added Reduce op. 
""" - return _wrap_data_handle( - self._client.Reduce( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - dimensions)) + return self._client.Reduce(operand, init_value, + computation_to_apply.c_local_computation, + dimensions) def ReduceWindow(self, operand, init_value, computation_to_apply, window_dimensions, window_strides, padding): """Enqueues a windowed reduction operation onto the computation. Args: - operand: reduction operand (ComputationDataHandle). - init_value: reduction initial value (ComputationDataHandle). + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). computation_to_apply: a binary reduction function (Computation). window_dimensions: dimensions of window (sequence of integers). window_strides: strides for window (sequence of integers). padding: PaddingType representing either 'SAME' or 'VALID' padding. Returns: - A ComputationDataHandle representing the added ReduceWindow op. + A LocalOp representing the added ReduceWindow op. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(operand).dimensions(), window_dimensions, window_strides) - return _wrap_data_handle( - self._client.ReduceWindowWithGeneralPadding( - _unwrap_data_handle(operand), - _unwrap_data_handle(init_value), - computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads)) + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, pads) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. Args: - mu: A ComputationDataHandle to an F32 scalar specifying the mean. - sigma: A ComputationDataHandle to an F32 scalar specifying the standard + mu: A LocalOp to an F32 scalar specifying the mean. + sigma: A LocalOp to an F32 scalar specifying the standard deviation. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of F32 values. + Returns: a LocalOp to the generated array of F32 values. """ - shape = Shape(self.GetShape(mu).np_dtype, dims) - return _wrap_data_handle( - self._client.RngNormal( - _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) + shape = Shape.array_shape(self.GetShape(mu).element_type(), dims) + return self._client.RngNormal(mu, sigma, shape) def RngUniform(self, a, b, dims): """Enqueues an RngUniform operation onto the computation. Args: - a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + a: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of b) specifying the low end of the interval [a, b) over which values are generated. - b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + b: a LocalOp to an F32, S32, or U32 scalar (consistent with the type of a) specifying the high end of the interval [a, b) over which values are generated. dims: A 1D array-like of nonnegative integers specifying the dimensions. - Returns: a ComputationDataHandle to the generated array of values with the + Returns: a LocalOp to the generated array of values with the same numeric type (F32, S32, or U32) as the arguments a and b. 
""" - shape = Shape(self.GetShape(a).np_dtype, dims) - return _wrap_data_handle( - self._client.RngUniform( - _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) + shape = Shape.array_shape(self.GetShape(a).element_type(), dims) + return self._client.RngUniform(a, b, shape) def While(self, cond, body, init): """Enqueues a While operation onto the computation. @@ -999,112 +973,105 @@ class ComputationBuilder(object): Args: cond: a Computation for the loop condition, which has type T -> PRED body: a Computation for the loop body, which has type T -> T - init: a ComputationDataHandle for the initial parameter, which has type T + init: a LocalOp for the initial parameter, which has type T - Returns: a ComputationDataHandle representing the While operation. + Returns: a LocalOp representing the While operation. """ - return _wrap_data_handle( - self._client.While(cond.c_local_computation, - body.c_local_computation, - _unwrap_data_handle(init))) + return self._client.While(cond.c_local_computation, + body.c_local_computation, init) def Conditional(self, pred, true_operand, true_computation, false_operand, false_computation): """Enqueues a Conditional operation onto the computation. Args: - predicate: a ComputationDataHandle to test, which has scalar type PRED - true_operand: a ComputationDataHandle of type T_0 + predicate: a LocalOp to test, which has scalar type PRED + true_operand: a LocalOp of type T_0 true_computation: a Computation to apply to true_operand, type T_0 -> S false_operand: a ComputationDatahandle of type T_1 false_computation: a Computation to apply to false_operand, type T_1 -> S - Returns: a ComputationDataHandle representing the Conditional operation. + Returns: a LocalOp representing the Conditional operation. """ - return _wrap_data_handle( - self._client.Conditional( - _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), - true_computation.c_local_computation, - _unwrap_data_handle(false_operand), - false_computation.c_local_computation)) + return self._client.Conditional( + pred, true_operand, true_computation.c_local_computation, false_operand, + false_computation.c_local_computation) - def IsConstant(self, operand, num_parameters=0): - """Enqueues an IsConstant operation onto the computation. + def IsConstant(self, operand): + """Checks whether the given operand is a compile-time constant. Args: operand: a ComputationDataHandle to test. - num_parameters: optional int, number of computation parameters to treat as - constant (default 0). Returns: bool indicating whether `operand` is a compile-time constant, - meaning its value does not depend on parameters with index greater than or - equal to `num_parameters`. + meaning its value does not depend on any parametersor, or on stateful + operators such as `RngNormal` or `Infeed`. """ - return self._client.IsConstant(_unwrap_data_handle(operand), num_parameters) + return self._client.IsConstant(operand) + + def BuildConstantSubGraph(self, operand): + """Builds a constant sub graph. + + Args: + operand: a LocalOp to test. + Returns: a LocalComputation that is rooted on the given `operand` which is a + compile-time constant. + """ + return self._client.BuildConstantSubGraph(operand) def Dot(self, lhs, rhs): """Enqueues a dot operation onto the computation. Args: - lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. - rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. + lhs: LocalOp for the rank 1 or rank 2 left-hand-side array. 
+ rhs: LocalOp for the rank 1 or rank 2 right-hand-side array. - Returns: a ComputationDataHandle representing the Dot operation. + Returns: a LocalOp representing the Dot operation. """ - return _wrap_data_handle( - self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + return self._client.Dot(lhs, rhs) def DotGeneral(self, lhs, rhs, dimension_numbers): """Enqueues a general dot operation onto the computation. Args: - lhs: ComputationDataHandle for the left-hand-side array. - rhs: ComputationDataHandle for the right-hand-side array. + lhs: LocalOp for the left-hand-side array. + rhs: LocalOp for the right-hand-side array. dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of integers representing the dimensions to treat as contracting dimensions and batch dimensions on each input operand. - Returns: a ComputationDataHandle representing the DotGeneral operation. + Returns: a LocalOp representing the DotGeneral operation. """ if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) - return _wrap_data_handle( - self._client.DotGeneral( - _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), - dimension_numbers)) + return self._client.DotGeneral(lhs, rhs, dimension_numbers) def Conv(self, lhs, rhs, window_strides, padding): """Enqueues a Conv operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. - Returns: a ComputationDataHandle representing the Conv operation. + Returns: a LocalOp representing the Conv operation. """ pads = _convert_padding_type_to_pad_values( padding, self.GetShape(lhs).dimensions()[2:], self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - pads, - (), - (), - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), + (), dimension_numbers) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: - lhs: ComputationDataHandle for the rank N+2 array of inputs. - rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of kernel strides. padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of dilation factors. @@ -1114,14 +1081,9 @@ class ComputationBuilder(object): A ComputationdataHandle representing the added ConvWithGeneralPadding op. 
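# --- Editor's illustrative sketch, not part of this patch ---
# DotGeneral accepts either an xla_data_pb2.DotDimensionNumbers proto or the
# nested ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) tuple form
# described above. Assumes a ComputationBuilder `c` as in the earlier sketch.
a = c.Constant(np.ones((2, 3), dtype=np.float32))
bmat = c.Constant(np.ones((3, 4), dtype=np.float32))
plain = c.Dot(a, bmat)                                   # rank-2 matmul
general = c.DotGeneral(a, bmat, (([1], [0]), ([], [])))  # same contraction, no batch dims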
""" dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) - return _wrap_data_handle( - self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), - _unwrap_data_handle(rhs), - window_strides, - padding, - lhs_dilation, - rhs_dilation, - dimension_numbers)) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1151,15 +1113,14 @@ def _forward_methods_to_local_builder(): """Generate a forwarding method that wraps/unwraps data handles.""" def forward(self, *args, **kwargs): - unwrapped_args = [_unwrap_data_handle(arg) for arg in args] + arg_list = list(args) - if is_binop and len(unwrapped_args) < 3: - unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) + if is_binop and len(arg_list) < 3: + arg_list.append(kwargs.get('broadcast_dimensions', ())) - return _wrap_data_handle( - target_method( - self._client, # pylint: disable=protected-access - *unwrapped_args)) + return target_method( + self._client, # pylint: disable=protected-access + *arg_list) return forward diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index d97264ea640..c073c02040e 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -319,7 +319,7 @@ class LocalBufferTest(LocalComputationTest): def _Execute(self, c, arguments): compiled_c = c.Build().CompileWithExampleArguments(arguments) - arg_buffers = [xla_client.LocalBuffer.from_py(arg) for arg in arguments] + arg_buffers = [xla_client.LocalBuffer.from_pyval(arg) for arg in arguments] result_buffer = compiled_c.ExecuteWithLocalBuffers(arg_buffers) return result_buffer.to_py() @@ -350,7 +350,7 @@ class LocalBufferTest(LocalComputationTest): c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) arg = NumpyArrayF32(1.11) compiled_c = c.Build().CompileWithExampleArguments([arg]) - arg_buffer = xla_client.LocalBuffer.from_py(arg) + arg_buffer = xla_client.LocalBuffer.from_pyval(arg) arg_buffer.delete() with self.assertRaises(ValueError): compiled_c.ExecuteWithLocalBuffers([arg_buffer]) @@ -1160,7 +1160,6 @@ class EmbeddedComputationsTest(LocalComputationTest): self._ExecuteAndCompareClose( c, expected=np.sum(input_array, axis=tuple(dims))) - _ReduceAndTest(0) _ReduceAndTest(0) _ReduceAndTest(0, 1) _ReduceAndTest(0, 2) @@ -1288,7 +1287,7 @@ class EmbeddedComputationsTest(LocalComputationTest): def testInfeedS32Values(self): to_infeed = NumpyArrayS32([1, 2, 3, 4]) c = self._NewComputation() - c.Infeed(xla_client.Shape.from_numpy(to_infeed[0])) + c.Infeed(xla_client.Shape.from_pyval(to_infeed[0])) compiled_c = c.Build().CompileWithExampleArguments() for item in to_infeed: xla_client.transfer_to_infeed(item) @@ -1300,7 +1299,7 @@ class EmbeddedComputationsTest(LocalComputationTest): def testInfeedThenOutfeedS32(self): to_round_trip = NumpyArrayS32([1, 2, 3, 4]) c = self._NewComputation() - x = c.Infeed(xla_client.Shape.from_numpy(to_round_trip[0])) + x = c.Infeed(xla_client.Shape.from_pyval(to_round_trip[0])) c.Outfeed(x) compiled_c = c.Build().CompileWithExampleArguments() @@ -1310,7 +1309,7 @@ class EmbeddedComputationsTest(LocalComputationTest): execution.start() xla_client.transfer_to_infeed(want) got = xla_client.transfer_from_outfeed( - xla_client.Shape.from_numpy(to_round_trip[0])) + 
xla_client.Shape.from_pyval(to_round_trip[0])) execution.join() self.assertEqual(want, got) diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index ad3a28e1193..c289c84cff7 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -90,7 +90,7 @@ std::unique_ptr> MatmulArray2DImpl( Padding padding) { return ConvArray3DGeneralDimensionsDilated( lhs, rhs, kernel_stride, padding, 1, 1, - ComputationBuilder::CreateDefaultConvDimensionNumbers(1)); + XlaBuilder::CreateDefaultConvDimensionNumbers(1)); } /*static*/ std::unique_ptr> @@ -140,7 +140,7 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( std::pair kernel_stride, Padding padding) { return ConvArray4DGeneralDimensions( lhs, rhs, kernel_stride, padding, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); } /* static */ std::unique_ptr> @@ -572,7 +572,8 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module("ReferenceUtil"); + HloModuleConfig config; + HloModule module("ReferenceUtil", config); auto computation = module.AddEntryComputation(b.Build()); HloEvaluator evaluator; diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 28d6a8c3fe8..2698ba7d79e 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -330,13 +330,14 @@ class ReferenceUtil { return result; } - // Slices with modulo-wrapping. + // Slices with index clamping template - static std::vector ModSlice1D(const tensorflow::gtl::ArraySlice& input, - int64 start, int64 size) { + static std::vector ClampSlice1D( + const tensorflow::gtl::ArraySlice& input, int64 start, int64 size) { + start = std::min(std::max(0, start), input.size() - size); std::vector result; for (int64 i = 0; i < size; ++i) { - result.push_back(input[(start + i) % input.size()]); + result.push_back(input[(start + i)]); } return result; } diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 977f8637873..0d56a9a477b 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -55,7 +55,7 @@ tf_cc_test( deps = [ ":grpc_stub", "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index b559ee4b5a3..313f11a9a95 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "grpc++/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_stub.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/lib/io/path.h" @@ -84,7 +84,7 @@ TEST_F(GRPCClientTestBase, ItsAlive) { } TEST_F(GRPCClientTestBase, AxpyTenValues) { - ComputationBuilder builder(client_.get(), "axpy_10"); + XlaBuilder builder("axpy_10"); auto alpha = builder.ConstantR0(3.1415926535); auto x = builder.ConstantR1( {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); @@ -101,8 +101,8 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) { TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer( computation, {}, nullptr)); - LiteralTestUtil::ExpectNear(*expected_literal, *result_literal, - ErrorSpec(0.0001)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal, + ErrorSpec(0.0001))); } } // namespace diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index 414829d6e76..5f4dc6bd08f 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -20,15 +20,15 @@ limitations under the License. namespace xla { /* static */ StatusOr> GRPCService::NewService( - perftools::gputools::Platform* platform) { + se::Platform* platform) { std::unique_ptr grpc_service(new GRPCService()); TF_ASSIGN_OR_RETURN(grpc_service->service_, ::xla::Service::NewService(platform)); return std::move(grpc_service); } -::grpc::Status DelegateRPC(std::function op) { - tensorflow::Status s = op(); +::grpc::Status DelegateRPC(std::function op) { + Status s = op(); return tensorflow::ToGrpcStatus(s); } @@ -75,6 +75,13 @@ namespace xla { [this, arg, result]() { return service_->Execute(arg, result); }); } +::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, + const ExecuteGraphRequest* arg, + ExecuteResponse* result) { + return DelegateRPC( + [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); +} + ::grpc::Status GRPCService::ExecuteAsync(::grpc::ServerContext* context, const ExecuteAsyncRequest* arg, ExecuteAsyncResponse* result) { diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index 7c9e484517e..50f02796f2d 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -29,7 +29,7 @@ class GRPCService : public grpc::XlaService::Service { // that the service should target. If platform is null then the default // platform is used. 
static StatusOr> NewService( - perftools::gputools::Platform* platform = nullptr); + se::Platform* platform = nullptr); ::grpc::Status Computation(::grpc::ServerContext* context, const ComputationRequest* arg, @@ -54,6 +54,10 @@ class GRPCService : public grpc::XlaService::Service { const ExecuteRequest* arg, ExecuteResponse* result) override; + ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, + const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; + ::grpc::Status ExecuteAsync(::grpc::ServerContext* context, const ExecuteAsyncRequest* arg, ExecuteAsyncResponse* result) override; diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index e1f2b0abe39..620ac6cec4f 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -20,53 +20,49 @@ namespace xla { GRPCStub::~GRPCStub() = default; -tensorflow::Status MakeRPC( +Status MakeRPC( const std::function<::grpc::Status(::grpc::ClientContext*)>& rpc_method) { ::grpc::ClientContext context; ::grpc::Status s = rpc_method(&context); return tensorflow::FromGrpcStatus(s); } -tensorflow::Status GRPCStub::TransferToClient( - const TransferToClientRequest* request, - TransferToClientResponse* response) { +Status GRPCStub::TransferToClient(const TransferToClientRequest* request, + TransferToClientResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToClient(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferToServer( - const TransferToServerRequest* request, - TransferToServerResponse* response) { +Status GRPCStub::TransferToServer(const TransferToServerRequest* request, + TransferToServerResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToServer(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferToInfeed( - const TransferToInfeedRequest* request, - TransferToInfeedResponse* response) { +Status GRPCStub::TransferToInfeed(const TransferToInfeedRequest* request, + TransferToInfeedResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferToInfeed(context, *request, response); }); } -tensorflow::Status GRPCStub::TransferFromOutfeed( - const TransferFromOutfeedRequest* request, - TransferFromOutfeedResponse* response) { +Status GRPCStub::TransferFromOutfeed(const TransferFromOutfeedRequest* request, + TransferFromOutfeedResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->TransferFromOutfeed(context, *request, response); }); } -tensorflow::Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, - ResetDeviceResponse* response) { +Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, + ResetDeviceResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ResetDevice(context, *request, response); }); } -tensorflow::Status GRPCStub::LoadComputationSnapshot( +Status GRPCStub::LoadComputationSnapshot( const LoadComputationSnapshotRequest* request, LoadComputationSnapshotResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -74,28 +70,28 @@ tensorflow::Status GRPCStub::LoadComputationSnapshot( }); } -tensorflow::Status GRPCStub::Execute(const ExecuteRequest* request, - ExecuteResponse* response) { +Status 
GRPCStub::Execute(const ExecuteRequest* request, + ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->Execute(context, *request, response); }); } -tensorflow::Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) { +Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ExecuteGraph(context, *request, response); }); } -tensorflow::Status GRPCStub::ExecuteParallel( - const ExecuteParallelRequest* request, ExecuteParallelResponse* response) { +Status GRPCStub::ExecuteParallel(const ExecuteParallelRequest* request, + ExecuteParallelResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ExecuteParallel(context, *request, response); }); } -tensorflow::Status GRPCStub::ExecuteGraphParallel( +Status GRPCStub::ExecuteGraphParallel( const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -103,38 +99,35 @@ tensorflow::Status GRPCStub::ExecuteGraphParallel( }); } -tensorflow::Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, - ExecuteAsyncResponse* response) { +Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, + ExecuteAsyncResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ExecuteAsync(context, *request, response); }); } -tensorflow::Status GRPCStub::WaitForExecution( - const WaitForExecutionRequest* request, - WaitForExecutionResponse* response) { +Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request, + WaitForExecutionResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->WaitForExecution(context, *request, response); }); } -tensorflow::Status GRPCStub::DeconstructTuple( - const DeconstructTupleRequest* request, - DeconstructTupleResponse* response) { +Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request, + DeconstructTupleResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->DeconstructTuple(context, *request, response); }); } -tensorflow::Status GRPCStub::GetComputationStats( - const ComputationStatsRequest* request, - ComputationStatsResponse* response) { +Status GRPCStub::GetComputationStats(const ComputationStatsRequest* request, + ComputationStatsResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetComputationStats(context, *request, response); }); } -tensorflow::Status GRPCStub::GetComputationGraphStats( +Status GRPCStub::GetComputationGraphStats( const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -142,81 +135,77 @@ tensorflow::Status GRPCStub::GetComputationGraphStats( }); } -tensorflow::Status GRPCStub::GetComputationShape( - const GetComputationShapeRequest* request, - GetComputationShapeResponse* response) { +Status GRPCStub::GetComputationShape(const GetComputationShapeRequest* request, + GetComputationShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetComputationShape(context, 
*request, response); }); } -tensorflow::Status GRPCStub::GetShape(const GetShapeRequest* request, - GetShapeResponse* response) { +Status GRPCStub::GetShape(const GetShapeRequest* request, + GetShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetShape(context, *request, response); }); } -tensorflow::Status GRPCStub::GetDeviceHandles( - const GetDeviceHandlesRequest* request, - GetDeviceHandlesResponse* response) { +Status GRPCStub::GetDeviceHandles(const GetDeviceHandlesRequest* request, + GetDeviceHandlesResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetDeviceHandles(context, *request, response); }); } -tensorflow::Status GRPCStub::CreateChannelHandle( - const CreateChannelHandleRequest* request, - CreateChannelHandleResponse* response) { +Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request, + CreateChannelHandleResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->CreateChannelHandle(context, *request, response); }); } // Methods used by ComputationBuilder. -tensorflow::Status GRPCStub::Computation(const ComputationRequest* request, - ComputationResponse* response) { +Status GRPCStub::Computation(const ComputationRequest* request, + ComputationResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->Computation(context, *request, response); }); } -tensorflow::Status GRPCStub::Op(const OpRequest* request, - OpResponse* response) { +Status GRPCStub::Op(const OpRequest* request, OpResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->CreateOp(context, *request, response); }); } -tensorflow::Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, - GetLocalShapeResponse* response) { +Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, + GetLocalShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->GetLocalShape(context, *request, response); }); } -tensorflow::Status GRPCStub::SetReturnValue( - const SetReturnValueRequest* request, SetReturnValueResponse* responses) { +Status GRPCStub::SetReturnValue(const SetReturnValueRequest* request, + SetReturnValueResponse* responses) { return MakeRPC([this, request, responses](::grpc::ClientContext* context) { return grpc_stub_->SetReturnValue(context, *request, responses); }); } -tensorflow::Status GRPCStub::IsConstant(const IsConstantRequest* request, - IsConstantResponse* response) { +Status GRPCStub::IsConstant(const IsConstantRequest* request, + IsConstantResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->IsConstant(context, *request, response); }); } -tensorflow::Status GRPCStub::ComputeConstant( - const ComputeConstantRequest* request, ComputeConstantResponse* response) { +Status GRPCStub::ComputeConstant(const ComputeConstantRequest* request, + ComputeConstantResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->ComputeConstant(context, *request, response); }); } -tensorflow::Status GRPCStub::ComputeConstantGraph( +Status GRPCStub::ComputeConstantGraph( const ComputeConstantGraphRequest* request, ComputeConstantResponse* response) { return MakeRPC([this, request, 
response](::grpc::ClientContext* context) { @@ -225,17 +214,16 @@ tensorflow::Status GRPCStub::ComputeConstantGraph( } // Methods used by Computation. -tensorflow::Status GRPCStub::SnapshotComputation( - const SnapshotComputationRequest* request, - SnapshotComputationResponse* response) { +Status GRPCStub::SnapshotComputation(const SnapshotComputationRequest* request, + SnapshotComputationResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->SnapshotComputation(context, *request, response); }); } // Methods used by GlobalData. -tensorflow::Status GRPCStub::Unregister(const UnregisterRequest* request, - UnregisterResponse* response) { +Status GRPCStub::Unregister(const UnregisterRequest* request, + UnregisterResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { return grpc_stub_->Unregister(context, *request, response); }); diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index fd9810d4f1a..5906d45769b 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -28,105 +28,90 @@ class GRPCStub : public ServiceInterface { explicit GRPCStub(grpc::XlaService::Stub* stub) : grpc_stub_(stub) {} ~GRPCStub() override; - tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, - TransferToClientResponse* result) override; + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override; + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override; + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override; + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override; + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; - tensorflow::Status LoadComputationSnapshot( + Status LoadComputationSnapshot( const LoadComputationSnapshotRequest* request, LoadComputationSnapshotResponse* result) override; - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override; + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; - tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* request, - ExecuteResponse* response) override; + Status ExecuteGraph(const ExecuteGraphRequest* request, + ExecuteResponse* response) override; - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override; - tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* request, - ExecuteParallelResponse* response) override; + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, + ExecuteParallelResponse* response) override; - tensorflow::Status ExecuteAsync(const 
ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override; - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override; + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; - tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) override; + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; - tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; + Status GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) override; - tensorflow::Status GetComputationGraphStats( - const ComputationGraphStatsRequest* request, - ComputationStatsResponse* response) override; + Status GetComputationGraphStats(const ComputationGraphStatsRequest* request, + ComputationStatsResponse* response) override; - tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; + Status GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) override; - tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) override; + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override; + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; - tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) override; + Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; // Methods used by ComputationBuilder. - tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; + Status Computation(const ComputationRequest* arg, + ComputationResponse* result) override; - tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; - tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; + Status Op(const OpRequest* arg, OpResponse* result) override; + Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) override; - tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; + Status SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) override; - tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; + Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) override; - tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) override; - tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Methods used by Computation. 
- tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) override; + Status SnapshotComputation(const SnapshotComputationRequest* ag, + SnapshotComputationResponse* result) override; // Methods used by GlobalData. - tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) override; + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; grpc::XlaService::Stub* service() { return grpc_stub_; } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ddc099807d3..0a50f00919a 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -12,6 +12,7 @@ package_group( ], ) +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") @@ -26,6 +27,7 @@ xla_proto_library( xla_proto_library( name = "hlo_proto", srcs = ["hlo.proto"], + visibility = ["//visibility:public"], deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) @@ -200,7 +202,22 @@ tf_cc_test( cc_library( name = "hlo_evaluator", - srcs = ["hlo_evaluator.cc"], + srcs = [ + "hlo_evaluator.cc", + "hlo_evaluator_typed_visitor.h", + "hlo_evaluator_typed_visitor_bfloat16.cc", + "hlo_evaluator_typed_visitor_bool.cc", + "hlo_evaluator_typed_visitor_complex64.cc", + "hlo_evaluator_typed_visitor_double.cc", + "hlo_evaluator_typed_visitor_float.cc", + "hlo_evaluator_typed_visitor_half.cc", + "hlo_evaluator_typed_visitor_int32.cc", + "hlo_evaluator_typed_visitor_int64.cc", + "hlo_evaluator_typed_visitor_int8.cc", + "hlo_evaluator_typed_visitor_uint32.cc", + "hlo_evaluator_typed_visitor_uint64.cc", + "hlo_evaluator_typed_visitor_uint8.cc", + ], hdrs = ["hlo_evaluator.h"], deps = [ ":hlo", @@ -233,7 +250,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -359,6 +376,7 @@ cc_library( ":hlo", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", ], ) @@ -369,6 +387,7 @@ tf_cc_test( ":hlo_matchers", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -755,6 +774,7 @@ cc_library( ":hlo", ":hlo_execution_profile", ":hlo_graph_dumper", + ":hlo_proto", ":pool", ":session_proto", ":shaped_buffer", @@ -778,6 +798,7 @@ cc_library( srcs = ["compiler.cc"], hdrs = ["compiler.h"], deps = [ + ":buffer_value", ":executable", ":hlo", ":hlo_module_config", @@ -834,7 +855,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -915,33 +935,6 @@ tf_cc_test( ], ) -cc_library( - name = "liveness_util", - srcs = ["liveness_util.cc"], - hdrs = ["liveness_util.h"], - deps = [ - ":hlo", - ":hlo_dataflow_analysis", - ":logical_buffer", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - 
"//tensorflow/compiler/xla:util", - ], -) - -tf_cc_test( - name = "liveness_util_test", - srcs = ["liveness_util_test.cc"], - deps = [ - ":hlo", - ":liveness_util", - ":tuple_points_to_analysis", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - ], -) - cc_library( name = "buffer_liveness", srcs = [ @@ -953,7 +946,6 @@ cc_library( deps = [ ":hlo", ":hlo_ordering", - ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -990,6 +982,7 @@ cc_library( ], deps = [ ":buffer_liveness", + ":buffer_value_containers", ":heap_simulator", ":hlo", ":hlo_proto", @@ -1012,6 +1005,7 @@ tf_cc_test( srcs = ["buffer_assignment_test.cc"], deps = [ ":buffer_assignment", + ":buffer_value", ":call_graph", ":computation_tracker", ":copy_insertion", @@ -1044,7 +1038,6 @@ cc_library( ":hlo_dataflow_analysis", ":hlo_proto", ":hlo_value", - ":liveness_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1077,11 +1070,11 @@ cc_library( srcs = ["heap_simulator.cc"], hdrs = ["heap_simulator.h"], deps = [ + ":buffer_value", + ":buffer_value_containers", ":hlo", ":hlo_ordering", ":hlo_proto", - ":liveness_util", - ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1093,10 +1086,11 @@ tf_cc_test( name = "heap_simulator_test", srcs = ["heap_simulator_test.cc"], deps = [ + ":buffer_value", ":heap_simulator", ":hlo", ":hlo_ordering", - ":logical_buffer", + ":hlo_value", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:status_macros", @@ -1161,6 +1155,7 @@ tf_cc_test( name = "hlo_scheduling_test", srcs = ["hlo_scheduling_test.cc"], deps = [ + ":buffer_value", ":hlo", ":hlo_ordering", ":hlo_scheduling", @@ -1204,6 +1199,7 @@ tf_cc_test( ":instruction_fusion", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1220,6 +1216,23 @@ cc_library( ], ) +tf_cc_test( + name = "hlo_creation_utils_test", + srcs = ["hlo_creation_utils_test.cc"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_evaluator", + ":hlo_matchers", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "batchnorm_expander", srcs = ["batchnorm_expander.cc"], @@ -1227,13 +1240,11 @@ cc_library( deps = [ ":hlo", ":hlo_pass", - ":hlo_query", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], @@ -1283,6 +1294,7 @@ cc_library( ":hlo_creation_utils", ":hlo_pass", ":hlo_query", + ":pattern_matcher", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1309,6 +1321,43 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + 
"//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "batch_dot_simplification", + srcs = ["batch_dot_simplification.cc"], + hdrs = ["batch_dot_simplification.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "batch_dot_simplification_test", + srcs = ["batch_dot_simplification_test.cc"], + deps = [ + ":batch_dot_simplification", + ":hlo", + ":hlo_matchers", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", @@ -1643,10 +1692,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1678,6 +1727,8 @@ tf_cc_test( ":hlo_execution_profile", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", ], ) @@ -1727,11 +1778,38 @@ tf_cc_test( ], ) +cc_library( + name = "buffer_value", + srcs = ["buffer_value.cc"], + hdrs = ["buffer_value.h"], + deps = [ + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "buffer_value_containers", + hdrs = ["buffer_value_containers.h"], + deps = [ + ":buffer_value", + ":logical_buffer", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + cc_library( name = "logical_buffer", srcs = ["logical_buffer.cc"], hdrs = ["logical_buffer.h"], deps = [ + ":buffer_value", ":hlo", ":hlo_proto", "//tensorflow/compiler/xla:shape_util", @@ -1747,6 +1825,7 @@ cc_library( srcs = ["hlo_value.cc"], hdrs = ["hlo_value.h"], deps = [ + ":buffer_value", ":hlo", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", @@ -1799,6 +1878,44 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_liveness_analysis", + srcs = ["hlo_liveness_analysis.cc"], + hdrs = ["hlo_liveness_analysis.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_value", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_liveness_analysis_test", + srcs 
= ["hlo_liveness_analysis_test.cc"], + deps = [ + ":hlo", + ":hlo_liveness_analysis", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_buffer", srcs = ["hlo_buffer.cc"], @@ -1935,10 +2052,12 @@ cc_library( deps = [ ":computation_layout", ":hlo", + ":hlo_dce", ":hlo_graph_dumper", ":hlo_pass", ":logical_buffer", ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1962,7 +2081,6 @@ cc_library( ":hlo_graph_dumper", ":hlo_ordering", ":hlo_pass", - ":liveness_util", ":logical_buffer", ":tuple_simplifier", "//tensorflow/compiler/xla:status_macros", @@ -2009,11 +2127,30 @@ cc_library( ], ) +cc_library( + name = "hlo_module_dce", + srcs = ["hlo_module_dce.cc"], + hdrs = ["hlo_module_dce.h"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_liveness_analysis", + ":hlo_pass", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_library( name = "hlo_verifier", srcs = ["hlo_verifier.cc"], hdrs = ["hlo_verifier.h"], deps = [ + ":hlo", ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", @@ -2043,13 +2180,13 @@ cc_library( hdrs = ["hlo_rematerialization.h"], deps = [ ":buffer_liveness", + ":buffer_value", ":call_graph", ":flatten_call_graph", ":hlo", ":hlo_dce", ":hlo_ordering", ":hlo_scheduling", - ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -2097,6 +2234,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "hlo_module_dce_test", + srcs = ["hlo_module_dce_test.cc"], + deps = [ + ":hlo", + ":hlo_module_dce", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_test( name = "layout_assignment_test", srcs = ["layout_assignment_test.cc"], @@ -2188,6 +2346,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -2254,8 +2413,14 @@ tf_cc_test( cc_library( name = "device_memory_allocator", - srcs = ["device_memory_allocator.cc"], - hdrs = ["device_memory_allocator.h"], + srcs = [ + "device_memory_allocator.cc", + "owning_device_memory.cc", + ], + hdrs = [ + "device_memory_allocator.h", + "owning_device_memory.h", + ], deps = [ "//tensorflow/compiler/xla:status_macros", 
"//tensorflow/compiler/xla:statusor", @@ -2290,6 +2455,24 @@ cc_library( ], ) +xla_test( + name = "elemental_ir_emitter_test", + srcs = ["elemental_ir_emitter_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + cc_library( name = "hlo_module_config", srcs = ["hlo_module_config.cc"], @@ -2361,7 +2544,6 @@ tf_cc_test( srcs = ["hlo_tfgraph_builder_test.cc"], deps = [ ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:protos_all_cc", @@ -2414,6 +2596,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", ], ) @@ -2423,6 +2606,7 @@ tf_cc_test( srcs = ["transpose_folding_test.cc"], deps = [ ":hlo", + ":hlo_matchers", ":shape_inference", ":transpose_folding", "//tensorflow/compiler/xla:literal_util", @@ -2430,10 +2614,11 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -2466,7 +2651,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -2575,7 +2760,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", - "@com_google_absl//absl/memory", ], ) @@ -2664,6 +2848,33 @@ tf_cc_test( ], ) +cc_library( + name = "while_loop_constant_sinking", + srcs = ["while_loop_constant_sinking.cc"], + hdrs = ["while_loop_constant_sinking.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "while_loop_constant_sinking_test", + srcs = ["while_loop_constant_sinking_test.cc"], + deps = [ + ":hlo_matchers", + ":while_loop_constant_sinking", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:test", + ], +) + cc_library( name = "despecializer", srcs = ["despecializer.cc"], @@ -2690,3 +2901,30 @@ cc_library( "//tensorflow/core:lib", ], ) + +cc_library( + name = "indexed_array_analysis", + srcs = ["indexed_array_analysis.cc"], + hdrs = ["indexed_array_analysis.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:util", + 
"//tensorflow/core:lib", + "//tensorflow/core:ptr_util", + ], +) + +tf_cc_test( + name = "indexed_array_analysis_test", + srcs = ["indexed_array_analysis_test.cc"], + deps = [ + ":hlo_matchers", + ":indexed_array_analysis", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 6cb1bd56695..f732ed8f398 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -44,8 +45,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { + namespace { +namespace m = match; + // Returns whether operand is a literal with the given value. bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { return operand->opcode() == HloOpcode::kConstant && @@ -88,25 +92,6 @@ bool ReshapeIsBitcast( valid_bitcast_callback(operand->shape(), reshape->shape()); } -// Adds a scalar computation to the module to enable optimizations with dot -// converting into reduction. -HloComputation* CreateScalarBinaryComputation(HloModule* module, - PrimitiveType primitive_type, - HloOpcode opcode) { - HloComputation::Builder b("scalar_computation"); - auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), - opcode, scalar_lhs, scalar_rhs)); - HloComputation* scalar_computation = - module->AddEmbeddedComputation(b.Build(scalar_op)); - return scalar_computation; -} -} // namespace - // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain // algebraic expressions to simplified forms. Note: This only supports // simplifications that simply look at the operands of an instruction. 
For the @@ -215,8 +200,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { HloInstruction* zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloComputation* AddReduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); + HloComputation* AddReduce_computation = GetOrCreateScalarAddComputation(); Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); return computation_->AddInstruction(HloInstruction::CreateReduce( shape, hlo, zero, {dim}, AddReduce_computation)); @@ -286,6 +270,26 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); + StatusOr OptimizeDotOfGather(HloInstruction* dot); + + HloComputation* GetOrCreateScalarAddComputation() { + if (scalar_add_computation_) { + return scalar_add_computation_; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(F32, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); + scalar_add_computation_ = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return scalar_add_computation_; + } + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -304,8 +308,13 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable convolution simplification on platforms where it causes a slowdown. bool enable_conv_simplification_; + + // Cached computation for adding two scalar F32. + HloComputation* scalar_add_computation_ = nullptr; }; +} // namespace + bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, @@ -350,8 +359,9 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( } Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { - auto lhs = add->mutable_operand(0); - auto rhs = add->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(add, m::Add(m::Op(&lhs), m::Op(&rhs)))); + // A + 0 => A VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { @@ -366,7 +376,7 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { // Canonicalization: Put constants on the right. This makes the reassociation // rules below simpler. VLOG(10) << "trying transform [Const + A => A + Const]"; - if (lhs->IsConstant() && !rhs->IsConstant()) { + if (Match(add, m::Add(m::Constant(), m::NonConstant()))) { return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs)); @@ -379,16 +389,13 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { // (A + C1) + (B + C2) => A + B + (C1 + C2). 
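// A minimal sketch of the matcher idiom used throughout this pass (only
// matchers that already appear in this change are assumed): Match() tests a
// pattern built from the m:: helpers in pattern_matcher.h, added to the
// includes above, and binds sub-patterns to HloInstruction pointers, e.g.
//
//   HloInstruction *x, *c;
//   if (Match(instr, m::Add(m::NonConstant(&x), m::Constant(&c)))) {
//     // x and c are now bound to the non-constant and constant operands.
//   }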
// VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]"; - if (rhs->IsConstant() && lhs->opcode() == HloOpcode::kAdd && - !lhs->operand(0)->IsConstant() && lhs->operand(1)->IsConstant()) { - auto* c1 = lhs->mutable_operand(1); - auto* c2 = rhs; - + HloInstruction *a, *c1, *c2; + if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)), + m::Constant(&c2)))) { TF_ASSIGN_OR_RETURN(auto* sum_of_constants, MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); return ReplaceWithNewInstruction( - add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, - lhs->mutable_operand(0), + add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a, sum_of_constants)); } @@ -397,11 +404,11 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { // If a bitcast feeds a bitcast, make it a single bitcast. - if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) { + HloInstruction* op; + if (Match(bitcast, m::Bitcast(m::Bitcast(m::Op(&op))))) { return ReplaceWithNewInstruction( - bitcast, HloInstruction::CreateUnary( - bitcast->shape(), HloOpcode::kBitcast, - bitcast->mutable_operand(0)->mutable_operand(0))); + bitcast, + HloInstruction::CreateUnary(bitcast->shape(), HloOpcode::kBitcast, op)); } // All bitcasts can be eliminated (assuming layout constraints are // satisified). @@ -418,11 +425,10 @@ Status AlgebraicSimplifierVisitor::HandleBitcastConvert( Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { // If a copy feeds a copy, make it a single copy. - if (copy->operand(0)->opcode() == HloOpcode::kCopy) { + HloInstruction* op; + if (Match(copy, m::Copy(m::Copy(m::Op(&op))))) { return ReplaceWithNewInstruction( - copy, HloInstruction::CreateUnary( - copy->shape(), HloOpcode::kCopy, - copy->mutable_operand(0)->mutable_operand(0))); + copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, op)); } // All copies can be eliminated (assuming layout constraints are satisified). ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); @@ -462,12 +468,10 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } else if (operands.size() == 2) { // A binary concat with a broadcasted scalar as an operand can be converted // into a pad which is simpler to fold into other operations. 
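// For instance (an illustrative sketch, shapes chosen only for clarity):
// concatenate(broadcast(s) : f32[1,4], x : f32[7,4]) along dimension 0 is the
// same value as pad(x, s) with edge padding low=1, high=0 on dimension 0,
// i.e. one leading row filled with the scalar s followed by x.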
- bool is_effective_low_pad = - operands[0]->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(operands[0]->operand(0)->shape()); - bool is_effective_high_pad = - operands[1]->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(operands[1]->operand(0)->shape()); + bool is_effective_low_pad = Match( + operands[0], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); + bool is_effective_high_pad = Match( + operands[1], m::Broadcast(m::Op().WithShape(m::Shape().IsScalar()))); if (!is_effective_low_pad && !is_effective_high_pad) { return Status::OK(); } @@ -499,13 +503,13 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( } static HloInstruction* BuildTupleConstant(HloComputation* computation, - const Literal& literal) { + const LiteralSlice& literal) { if (ShapeUtil::IsTuple(literal.shape())) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { elems.push_back( - BuildTupleConstant(computation, LiteralView::Create(literal, {i}))); + BuildTupleConstant(computation, LiteralSlice(literal, {i}))); } return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { @@ -537,8 +541,8 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { } Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { - auto lhs = sub->mutable_operand(0); - auto rhs = sub->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs)))); // A - 0 => A VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { @@ -547,7 +551,7 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { // Canonicalize subtraction of a constant to addition. 
VLOG(10) << "trying transform [A - Const => A + (-Const)]"; - if (rhs->IsConstant() && !lhs->IsConstant()) { + if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) { HloInstruction* negative_const = computation_->AddInstruction( HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs)); return ReplaceWithNewInstruction( @@ -559,56 +563,53 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { } Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { - auto lhs = divide->mutable_operand(0); - auto rhs = divide->mutable_operand(1); + Shape* shape; + HloInstruction *a, *b, *c, *d; + CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b)))); // A/1 => A VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); - if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) { + if (IsAll(b, 1) && ReplaceInstructionIfSameShape(divide, a)) { return Status::OK(); } // exp(A)/exp(B) => exp(A-B) - if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { + if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b))) + .WithShape(m::Shape(&shape)))) { VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString(); - HloInstruction* subtract = - computation_->AddInstruction(HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0), - rhs->mutable_operand(0))); + HloInstruction* subtract = computation_->AddInstruction( + HloInstruction::CreateBinary(*shape, HloOpcode::kSubtract, a, b)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, - subtract)); + divide, HloInstruction::CreateUnary(*shape, HloOpcode::kExp, subtract)); } // A/exp(B) => A*exp(-B) - if (rhs->opcode() == HloOpcode::kExp) { + if (Match(divide, m::Divide(m::Op(&a), m::Exp(m::Op(&b))))) { VLOG(10) << "transform [A/exp(B) => A*exp(-B)]: " << divide->ToString(); - HloInstruction* negate = - computation_->AddInstruction(HloInstruction::CreateUnary( - divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(0))); + HloInstruction* negate = computation_->AddInstruction( + HloInstruction::CreateUnary(divide->shape(), HloOpcode::kNegate, b)); HloInstruction* new_exp = computation_->AddInstruction( HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp, negate)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, new_exp)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kMultiply, a, new_exp)); } // A/pow(B,C) => A*pow(B,-C) - if (rhs->opcode() == HloOpcode::kPower) { + if (Match(divide, m::Divide(m::Op(&a), m::Power(m::Op(&b), m::Op(&c))))) { VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString(); // The output shape of the created negate operator should be the same as the // input. - const Shape& negate_shape = rhs->operand(1)->shape(); - HloInstruction* negate = - computation_->AddInstruction(HloInstruction::CreateUnary( - negate_shape, HloOpcode::kNegate, rhs->mutable_operand(1))); + const Shape& negate_shape = c->shape(); + HloInstruction* negate = computation_->AddInstruction( + HloInstruction::CreateUnary(negate_shape, HloOpcode::kNegate, c)); // And the power operator should retain the output shape of the old one. 
- const Shape& new_power_shape = rhs->shape(); - HloInstruction* new_power = computation_->AddInstruction( - HloInstruction::CreateBinary(new_power_shape, HloOpcode::kPower, - rhs->mutable_operand(0), negate)); + const Shape& new_power_shape = b->shape(); + HloInstruction* new_power = + computation_->AddInstruction(HloInstruction::CreateBinary( + new_power_shape, HloOpcode::kPower, b, negate)); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, new_power)); + divide->shape(), HloOpcode::kMultiply, a, new_power)); } // Simplifying integral division would produce unexpected results. @@ -620,28 +621,24 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // // (Backends can do this transformation, but generally only if the constant is // a scalar.) - if (lhs->opcode() != HloOpcode::kConstant && - rhs->opcode() == HloOpcode::kConstant) { + if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { HloInstruction* one = computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::One(lhs->shape().element_type()).CloneToUnique())); - HloInstruction* inverse = - computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kDivide, one, rhs)); + Literal::One(a->shape().element_type()).CloneToUnique())); + HloInstruction* inverse = computation_->AddInstruction( + HloInstruction::CreateBinary(b->shape(), HloOpcode::kDivide, one, b)); return ReplaceWithNewInstruction( - divide, HloInstruction::CreateBinary( - divide->shape(), HloOpcode::kMultiply, lhs, inverse)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kMultiply, a, inverse)); } // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) - if (lhs->opcode() == HloOpcode::kDivide && - rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN(auto a_times_d, MakeBinaryHlo(HloOpcode::kMultiply, - lhs->mutable_operand(0), - rhs->mutable_operand(1))); - TF_ASSIGN_OR_RETURN(auto b_times_c, MakeBinaryHlo(HloOpcode::kMultiply, - lhs->mutable_operand(1), - rhs->mutable_operand(0))); + if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), + m::Divide(m::Op(&c), m::Op(&d))))) { + TF_ASSIGN_OR_RETURN(auto a_times_d, + MakeBinaryHlo(HloOpcode::kMultiply, a, d)); + TF_ASSIGN_OR_RETURN(auto b_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, b, c)); TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kDivide, a_times_d, b_times_c)); @@ -649,24 +646,21 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { } // (A / B) / C => A / (B * C) - if (lhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN( - auto b_times_c, - MakeBinaryHlo(HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Op(&c)))) { + TF_ASSIGN_OR_RETURN(auto b_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, b, c)); return ReplaceWithNewInstruction( - divide, - HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, - lhs->mutable_operand(0), b_times_c)); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kDivide, a, b_times_c)); } // A / (B / C) => (A*C) / B - if (rhs->opcode() == HloOpcode::kDivide) { - TF_ASSIGN_OR_RETURN(auto a_times_c, MakeBinaryHlo(HloOpcode::kMultiply, lhs, - rhs->mutable_operand(1))); + if (Match(divide, m::Divide(m::Op(&a), m::Divide(m::Op(&b), m::Op(&c))))) { + TF_ASSIGN_OR_RETURN(auto a_times_c, + MakeBinaryHlo(HloOpcode::kMultiply, a, c)); return ReplaceWithNewInstruction( - 
divide, - HloInstruction::CreateBinary(divide->shape(), HloOpcode::kDivide, - a_times_c, rhs->mutable_operand(0))); + divide, HloInstruction::CreateBinary(divide->shape(), + HloOpcode::kDivide, a_times_c, b)); } return Status::OK(); @@ -674,8 +668,8 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( HloInstruction* dot) { - HloInstruction* lhs = dot->mutable_operand(0); - HloInstruction* rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); int64 lhs_collapsing_dim = dot->dot_dimension_numbers().lhs_contracting_dimensions(0); if (lhs->IsRank2Transpose()) { @@ -792,8 +786,8 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0); const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0); - HloInstruction* lhs = dot->mutable_operand(0); - HloInstruction* rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); TF_ASSIGN_OR_RETURN( HloInstruction * optimized_lhs_concat, @@ -922,9 +916,137 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( return add_result; } +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( + HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_contracting_dimensions_size() != 1 || + dnums.rhs_contracting_dimensions_size() != 1 || + dnums.lhs_batch_dimensions_size() != 0 || + dnums.rhs_batch_dimensions_size() != 0 || + dot->shape().dimensions_size() != 2) { // dot output 2D + VLOG(10) << "DotOfGather: Can only optimize 2D, non-batch dot operations."; + return nullptr; + } + + // Optimize either dot(DS(ctA), ctB)) or dot(ctB, DS(ctA)). + // Currently a Gather is a DynamicSlice. + auto is_dynamic_slice_constant_combination = + [](HloInstruction* a, HloInstruction* b, int a_contracting_dimension) { + // First operand is a DynamicSlice(Constant). + if (a->opcode() != HloOpcode::kDynamicSlice) { + return false; + } + auto* dynamic_slice_op = a->operand(0); + if (dynamic_slice_op->opcode() != HloOpcode::kConstant) { + return false; + } + // Second operand is a Constant. + if (b->opcode() != HloOpcode::kConstant) { + return false; + } + // The DynamicSlice output is a vector. + const Shape& dynamic_slice_shape = a->shape(); + if (dynamic_slice_shape.dimensions(1 - a_contracting_dimension) != 1) { + return false; + } + // Constant size is the same before and after slice in the contracting + // dimension, otherwise we either must precompute for all possible slice + // indices or dot is invalid. 
+ const Shape& dynamic_slice_op_shape = dynamic_slice_op->shape(); + if (dynamic_slice_op_shape.dimensions(a_contracting_dimension) != + dynamic_slice_shape.dimensions(a_contracting_dimension)) { + return false; + } + return true; + }; + + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + int lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + int rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + + if (!is_dynamic_slice_constant_combination( + lhs, rhs, /*a_contracting_dimension=*/lhs_contracting_dimension) && + !is_dynamic_slice_constant_combination( + rhs, lhs, /*a_contracting_dimension=*/rhs_contracting_dimension)) { + VLOG(10) << "DotOfGather: Can only optimize dot(DS(ctA), ctB)) or " + "dot(ctB, DS(ctA)), where the two constants have equal " + "contracting dimensions."; + return nullptr; + } + + // LHS is DynamicSlice: + // input: dot(DS(ctA), ctB)) + // where DS(ctA) = DS({M x K}, {start, 0}, {1, K}) and ctB = {K x N}. + // => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. + // output: DS(dot(ctA, ctB)) + // => output dimensions: DS ({M x N}, {start, 0}, {1, N}) => {1 x N}. + + // RHS is DynamicSlice: + // input: dot(ctA, DS(ctB)) + // where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, start}, {K, 1}). + // => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. + // output: DS(dot(ctA, ctB)) + // => output dimensions: DS ({M x N}, {0, start}, {M, 1}) => {M x 1}. + + bool lhs_is_dynamic_slice = lhs->opcode() == HloOpcode::kDynamicSlice; + + // ctA: + HloInstruction* left_operand = + lhs_is_dynamic_slice ? lhs->mutable_operand(0) : lhs; + // ctB: + HloInstruction* right_operand = + lhs_is_dynamic_slice ? rhs : rhs->mutable_operand(0); + // Build ctA x ctB. + const int m = left_operand->shape().dimensions(1 - lhs_contracting_dimension); + const int n = + right_operand->shape().dimensions(1 - rhs_contracting_dimension); + auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); + auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( + memoized_shape, left_operand, right_operand, dnums)); + // Get pair {start, 0} or {0, start}. + HloInstruction* original_start_indices = + lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); + // Position of start: + int index_of_non_zero_start = lhs_is_dynamic_slice + ? 1 - lhs_contracting_dimension + : 1 - rhs_contracting_dimension; + // Position of zero: + int index_of_zero_start = 1 - index_of_non_zero_start; + + // Slice out start and 0 components and reorder if necessary. + auto indices_type = original_start_indices->shape().element_type(); + Shape s_shape = ShapeUtil::MakeShape(indices_type, {1}); + Shape d_shape = ShapeUtil::MakeShape(indices_type, {2}); + HloInstruction* non_zero_start = + computation_->AddInstruction(HloInstruction::CreateSlice( + s_shape, original_start_indices, {index_of_non_zero_start}, + {index_of_non_zero_start + 1}, {1})); + HloInstruction* zero_start = + computation_->AddInstruction(HloInstruction::CreateSlice( + s_shape, original_start_indices, {index_of_zero_start}, + {index_of_zero_start + 1}, {1})); + HloInstruction* new_start_indices = + lhs_is_dynamic_slice + ? computation_->AddInstruction(HloInstruction::CreateConcatenate( + d_shape, {non_zero_start, zero_start}, 0)) + : computation_->AddInstruction(HloInstruction::CreateConcatenate( + d_shape, {zero_start, non_zero_start}, 0)); + + // Build DynamicSlice(ctA x ctB). + const int new_slice_m = lhs_is_dynamic_slice ? 
1 : m; + const int new_slice_n = lhs_is_dynamic_slice ? n : 1; + auto* memoized_lookup = + computation_->AddInstruction(HloInstruction::CreateDynamicSlice( + dot->shape(), memoized_inst, new_start_indices, + {new_slice_m, new_slice_n})); + + return memoized_lookup; +} + Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { - auto lhs = dot->mutable_operand(0); - auto rhs = dot->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(dot, m::Dot(m::Op(&lhs), m::Op(&rhs)))); // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or // below. @@ -951,6 +1073,17 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { return ReplaceInstruction(dot, dot_of_concat_optimized); } + // Simplify dot(ConstA, Gather(Index, ConstB)) to: + // Gather(Index, dot*(ConstA, ConstB)), where dot* is an appropriately + // batched version of dot. + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_gather_optimized, + OptimizeDotOfGather(dot)); + if (dot_of_gather_optimized) { + VLOG(10) << "Replaced dot(constA, gather(i, constB)) with " + "gather(i, dot*(constA, constB))"; + return ReplaceInstruction(dot, dot_of_gather_optimized); + } + if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { TF_ASSIGN_OR_RETURN(bool did_strength_reduction, HandleDotStrengthReduction(dot)); @@ -976,8 +1109,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { } Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { - auto lhs = multiply->mutable_operand(0); - auto rhs = multiply->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(multiply, m::Multiply(m::Op(&lhs), m::Op(&rhs)))); // A*1 => A VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { @@ -990,10 +1123,9 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { } // exp(A) * exp(B) => exp(A+B) - if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) { + if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { auto add = computation_->AddInstruction(HloInstruction::CreateBinary( - multiply->shape(), HloOpcode::kAdd, lhs->mutable_operand(0), - rhs->mutable_operand(0))); + multiply->shape(), HloOpcode::kAdd, lhs, rhs)); return ReplaceWithNewInstruction( multiply, HloInstruction::CreateUnary(multiply->shape(), HloOpcode::kExp, add)); @@ -1004,20 +1136,19 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { // ln(exp(A)) => A VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); - auto operand = log->mutable_operand(0); - if (operand->opcode() == HloOpcode::kExp && - ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) { + HloInstruction *a, *b; + if (Match(log, m::Log(m::Exp(m::Op(&a)))) && + ReplaceInstructionIfSameShape(log, a)) { return Status::OK(); } // ln(pow(A,B)) => B*ln(A) - if (operand->opcode() == HloOpcode::kPower) { - auto new_log = computation_->AddInstruction(HloInstruction::CreateUnary( - log->shape(), HloOpcode::kLog, operand->mutable_operand(0))); + if (Match(log, m::Log(m::Power(m::Op(&a), m::Op(&b))))) { + auto new_log = computation_->AddInstruction( + HloInstruction::CreateUnary(log->shape(), HloOpcode::kLog, a)); return ReplaceWithNewInstruction( - log, - HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, - new_log, operand->mutable_operand(1))); + log, 
HloInstruction::CreateBinary(log->shape(), HloOpcode::kMultiply, + new_log, b)); } return Status::OK(); @@ -1120,7 +1251,8 @@ bool OutputIsSubsetOfOperandElements(HloInstruction* instruction, } // namespace Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { - auto operand = broadcast->mutable_operand(0); + HloInstruction* operand; + CHECK(Match(broadcast, m::Broadcast(m::Op(&operand)))); auto dims = broadcast->dimensions(); // A degenerate broadcast of a reshape that does not change the number of // elements can be replaced by a reshape. @@ -1231,30 +1363,28 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { // Complex(Real(c), Imag(c)) -> c Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { - auto real = complex->mutable_operand(0); - auto imag = complex->mutable_operand(1); - if (real->opcode() == HloOpcode::kReal && - imag->opcode() == HloOpcode::kImag && - real->operand(0) == imag->operand(0)) { - return ReplaceInstruction(complex, real->mutable_operand(0)); + HloInstruction *c0, *c1; + if (Match(complex, m::Complex(m::Real(m::Op(&c0)), m::Imag(m::Op(&c1)))) && + c0 == c1) { + return ReplaceInstruction(complex, c0); } return Status::OK(); } // Real(Complex(r, i)) -> r Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { - auto operand = real->mutable_operand(0); - if (operand->opcode() == HloOpcode::kComplex) { - return ReplaceInstruction(real, operand->mutable_operand(0)); + HloInstruction* op; + if (Match(real, m::Real(m::Complex(m::Op(&op), m::Op())))) { + return ReplaceInstruction(real, op); } return Status::OK(); } // Imag(Complex(r, i)) -> i Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { - auto operand = imag->mutable_operand(0); - if (operand->opcode() == HloOpcode::kComplex) { - return ReplaceInstruction(imag, operand->mutable_operand(1)); + HloInstruction* op; + if (Match(imag, m::Imag(m::Complex(m::Op(), m::Op(&op))))) { + return ReplaceInstruction(imag, op); } return Status::OK(); } @@ -1351,8 +1481,8 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); - auto lhs = power->mutable_operand(0); - auto rhs = power->mutable_operand(1); + HloInstruction *lhs, *rhs; + CHECK(Match(power, m::Power(m::Op(&lhs), m::Op(&rhs)))); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( Literal::One(power->shape().element_type()).CloneToUnique()); @@ -1372,9 +1502,10 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { } // pow(exp(A),B) => exp(A*B) - if (lhs->opcode() == HloOpcode::kExp) { + HloInstruction *a, *b; + if (Match(power, m::Power(m::Exp(m::Op(&a)), m::Op(&b)))) { auto a_times_b = computation_->AddInstruction(HloInstruction::CreateBinary( - power->shape(), HloOpcode::kMultiply, lhs->operands()[0], rhs)); + power->shape(), HloOpcode::kMultiply, a, b)); return ReplaceWithNewInstruction( power, HloInstruction::CreateUnary(power->shape(), HloOpcode::kExp, a_times_b)); @@ -1424,7 +1555,6 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { return Status::OK(); } -// TODO(b/74536353): do this simplification for BroadcastDimOne as well. 
StatusOr AlgebraicSimplifierVisitor:: TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( HloInstruction* reshape_or_broadcast) { @@ -1707,7 +1837,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { HloInstruction::CreateReshape(reduce->shape(), arg)); return ReplaceWithNewInstruction( reduce, HloInstruction::CreateMap(reduce->shape(), - {reshape, init_value}, function)); + {init_value, reshape}, function)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 20c549562d5..4e082877c77 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -1699,14 +1700,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1732,8 +1733,8 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -1751,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1766,14 +1767,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ 
-1789,14 +1790,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1})); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Slice(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1924,12 +1925,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(b.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - if (!simplifier.Run(&module).ValueOrDie()) { + if (!simplifier.Run(module.get()).ValueOrDie()) { return "NO_CHANGE"; } auto* root = computation->root_instruction(); @@ -2044,15 +2045,15 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Minimum(param0, min_value), max_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2074,15 +2075,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2105,15 +2106,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + 
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2135,15 +2136,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); @@ -2167,8 +2168,8 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMinimum, fmax, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2176,7 +2177,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2201,8 +2202,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, slice); @@ -2211,10 +2212,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. 
- ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2242,8 +2243,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction* reshape = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, transpose)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reshape); @@ -2251,7 +2252,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2260,7 +2261,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2289,7 +2290,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module.AddEmbeddedComputation(builder.Build()); + add_computation = module->AddEmbeddedComputation(builder.Build()); } // Create the reduce-window. @@ -2312,15 +2313,15 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { add_computation)); // Build the computation and run the simplifier. - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); // Verify the result root = computation->root_instruction(); @@ -2341,7 +2342,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2374,7 +2375,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module.AddEmbeddedComputation(builder.Build()); + add_computation = module->AddEmbeddedComputation(builder.Build()); } // Create the reduce-window. 
@@ -2397,15 +2398,15 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { add_computation)); // Build the computation and run the simplifier. - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); // Verify the result root = computation->root_instruction(); @@ -2431,12 +2432,12 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { builder.AddInstruction( HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); @@ -2962,5 +2963,208 @@ TEST_F(AlgebraicSimplifierTest, DynamicUpdateSliceZeroUpdate) { INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, DotOfConcatSimplificationTest, ::testing::ValuesIn(kDotOfConcatTestSpecs)); + +struct DotOfGatherTestSpec { + int64 m; + int64 k; + int64 n; + int s; // start index for dynamic slice on the non-contracting dimension + int64 lcd; // left contracting dimension + int64 rcd; // right contracting dimension + bool neg; // is negative testcase +}; + +class DotOfGatherSimplificationTest + : public HloVerifiedTestBase, + public ::testing::WithParamInterface {}; + +// input: dot(DS(ctA), ctB)) +// where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}. +// => input dimensions: dot({1 x K}, {K x N}) => {1 x N}. +// output: DS(dot(ctA, ctB)) +// => output dimensions: DS ({M x N}, {s, 0}, {1, N}) => {1 x N}. +TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { + HloComputation::Builder builder(TestName()); + + DotOfGatherTestSpec spec = GetParam(); + + ASSERT_LE(spec.s, spec.m); + + // For negative tests, increase k of the dynamic slice argument to prevent the + // optimization (constants ctA, ctB must have equal contracting dimensions). + int64 k_increase = spec.neg ? 5 : 0; + int64 lhs_rows = (spec.lcd == 0) ? (spec.k + k_increase) : spec.m; + int64 lhs_cols = (spec.lcd == 0) ? spec.m : (spec.k + k_increase); + Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, + /*cols=*/lhs_cols))); + + int32 start_row = (spec.lcd == 0) ? 0 : spec.s; + int32 start_col = (spec.lcd == 0) ? spec.s : 0; + const auto start_indices = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({start_row, start_col}))); + int64 slice_row_size = (spec.lcd == 0) ? spec.k : 1; + int64 slice_col_size = (spec.lcd == 0) ? 
1 : spec.k; + Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, lhs, start_indices, {slice_row_size, slice_col_size})); + + int64 rhs_rows = (spec.rcd == 0) ? spec.k : spec.n; + int64 rhs_cols = (spec.rcd == 0) ? spec.n : spec.k; + Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, + /*cols=*/rhs_cols))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(spec.lcd); + dot_dnums.add_rhs_contracting_dimensions(spec.rcd); + + int64 dot_row_size = 1; + int64 dot_col_size = spec.n; + Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + if (spec.neg) { + EXPECT_NE(computation->root_instruction()->opcode(), + HloOpcode::kDynamicSlice); + } else { + EXPECT_THAT(computation->root_instruction(), + op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), + op::Concatenate())); + } +} + +// input: dot(ctA, DS(ctB)) +// where ctA = {M x K} and DS(ctB) = DS({K x N}, {0, s}, {K, 1}). +// => input dimensions: dot({M x K}, {K x 1}) => {M x 1}. +// output: DS(dot(ctA, ctB)) +// => output dimensions: DS ({M x N}, {0, s}, {M, 1}) => {M x 1}. +TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { + HloComputation::Builder builder(TestName()); + + DotOfGatherTestSpec spec = GetParam(); + + ASSERT_LE(spec.s, spec.n); + + int64 lhs_rows = (spec.lcd == 0) ? spec.k : spec.m; + int64 lhs_cols = (spec.lcd == 0) ? spec.m : spec.k; + Shape lhs_shape = ShapeUtil::MakeShape(F32, {lhs_rows, lhs_cols}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/lhs_rows, + /*cols=*/lhs_cols))); + + // For negative tests increase k of the dynamic slice argument to prevent the + // optimization + int64 k_increase = spec.neg ? 5 : 0; + int64 rhs_rows = (spec.rcd == 0) ? (spec.k + k_increase) : spec.n; + int64 rhs_cols = (spec.rcd == 0) ? spec.n : (spec.k + k_increase); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {rhs_rows, rhs_cols}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/rhs_rows, + /*cols=*/rhs_cols))); + + int32 start_row = (spec.rcd == 0) ? 0 : spec.s; + int32 start_col = (spec.rcd == 0) ? spec.s : 0; + const auto start_indices = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({start_row, start_col}))); + int64 slice_row_size = (spec.rcd == 0) ? spec.k : 1; + int64 slice_col_size = (spec.rcd == 0) ? 
1 : spec.k; + Shape ds_shape = ShapeUtil::MakeShape(F32, {slice_row_size, slice_col_size}); + auto* ds = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ds_shape, rhs, start_indices, {slice_row_size, slice_col_size})); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(spec.lcd); + dot_dnums.add_rhs_contracting_dimensions(spec.rcd); + + int64 dot_row_size = spec.m; + int64 dot_col_size = 1; + Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + if (spec.neg) { + EXPECT_NE(computation->root_instruction()->opcode(), + HloOpcode::kDynamicSlice); + } else { + EXPECT_THAT(computation->root_instruction(), + op::DynamicSlice(op::Dot(op::Constant(), op::Constant()), + op::Concatenate())); + } +} + +std::vector DotOfGatherPositiveNegativeTests() { + std::vector positives = { + // "Classical dot", i.e. matrix multiply: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/0, + /*neg=*/false}, + // Note: testing for m=1 and n=1 is unnecessary, as this optimizes to + // dot(ct, ct) before DotOfGather optimization kicks in. + // Contract on rows: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/0, + /*neg=*/false}, + // Reverse matrix multiply: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/0, /*rcd=*/1, + /*neg=*/false}, + // Contract on columns: + {/*m=*/10, /*k=*/10, /*n=*/5, /*s=*/0, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/20, /*k=*/20, /*n=*/3, /*s=*/2, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + {/*m=*/10, /*k=*/3, /*n=*/10, /*s=*/9, /*lcd=*/1, /*rcd=*/1, + /*neg=*/false}, + }; + std::vector all; + for (int i = 0; i < positives.size(); i++) { + DotOfGatherTestSpec positive_test = positives[i]; + all.push_back(positive_test); + DotOfGatherTestSpec negative_test = positive_test; + negative_test.neg = true; + all.push_back(negative_test); + } + return all; +} + +INSTANTIATE_TEST_CASE_P( + DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, + ::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 4f819a743c4..95b4cb6d2e6 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -31,52 +31,68 @@ limitations under the License. 
namespace xla { StatusOr AllocationTracker::Register( - std::unique_ptr shaped_buffer, const string& tag) { + ScopedShapedBuffer shaped_buffer, const string& tag) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Register"; - std::vector> replicated_buffers; + std::vector replicated_buffers; replicated_buffers.emplace_back(std::move(shaped_buffer)); return RegisterInternal(std::move(replicated_buffers), tag); } StatusOr AllocationTracker::RegisterReplicatedBuffers( - std::vector> replicated_buffers, - const string& tag) { + std::vector replicated_buffers, const string& tag) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "RegisterReplicatedBuffers"; return RegisterInternal(std::move(replicated_buffers), tag); } +// ReleaseIfScopedShapedBuffer lets RegisterInternal(b) call +// b.release() if b is a ScopedShapedBuffer, or otherwise pass b through +// unmodified. +static ShapedBuffer ReleaseIfScopedShapedBuffer(ShapedBuffer b) { return b; } +static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) { + return b.release(); +} + +template StatusOr AllocationTracker::RegisterInternal( - std::vector> replicated_buffers, - const string& tag) { + std::vector replicated_buffers, const string& tag) { + static_assert(std::is_same::value || + std::is_same::value, + "ShapedBufferTy must be ShapedBuffer or ScopedShapedBuffer."); VLOG(2) << "RegisterInternal(" << "tag: \"" << tag << "\" with " << replicated_buffers.size() << " shaped_buffers."; for (const auto& shaped_buffer : replicated_buffers) { - VLOG(2) << "shaped_buffer:" << *shaped_buffer; - if (shaped_buffer->platform() != backend_->platform()) { + VLOG(2) << "shaped_buffer:" << shaped_buffer; + if (shaped_buffer.platform() != backend_->platform()) { return InvalidArgument( "AllocationTracker for platform %s cannot register buffer from " "platform %s", backend_->platform()->Name().c_str(), - shaped_buffer->platform()->Name().c_str()); + shaped_buffer.platform()->Name().c_str()); } } int64 handle = next_handle_++; for (auto& shaped_buffer : replicated_buffers) { std::vector shape_indices; - ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), - [this, &shape_indices](const Shape& /*subshape*/, - const ShapeIndex& index) { - shape_indices.push_back(index); - }); + ShapeUtil::ForEachSubshape( + shaped_buffer.on_device_shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) { + shape_indices.push_back(index); + }); + // Add shaped_buffer's buffers to opaque_to_allocation_map_, which owns + // them. for (const ShapeIndex& index : shape_indices) { - AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index), - shaped_buffer->device_ordinal()); + AddAllocationOrIncrementRefCount(shaped_buffer.buffer(index), + shaped_buffer.device_ordinal()); } - handle_to_shaped_buffers_[handle].emplace_back(std::move(shaped_buffer)); + // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer + // into a regular ShapedBuffer, which is stored in + // handle_to_shaped_buffers_. 
+ handle_to_shaped_buffers_[handle].emplace_back(MakeUnique( + ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); } GlobalDataHandle result; @@ -85,7 +101,7 @@ StatusOr AllocationTracker::RegisterInternal( return result; } -tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { +Status AllocationTracker::Unregister(const GlobalDataHandle& data) { tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Unregister(" << "handle: " << data.handle() << ")"; @@ -103,10 +119,6 @@ tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { shaped_buffer->device_ordinal())); } } - return Reset(data); -} - -Status AllocationTracker::Reset(const GlobalDataHandle& data) { // Keep a nullptr as a tombstone for unregistered handles. This enables // better error messages. That is, "handle has been deallocated" versus // "handle does not exist". @@ -118,7 +130,7 @@ Status AllocationTracker::Reset(const GlobalDataHandle& data) { for (auto& shaped_buffer : it->second) { shaped_buffer.reset(); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr> AllocationTracker::DeconstructTuple( @@ -146,14 +158,14 @@ StatusOr> AllocationTracker::DeconstructTuple( for (int i = 0; i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape()); ++i) { - auto element_buffer = MakeUnique( + auto element_buffer = ShapedBuffer( ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i), ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i), shaped_buffer->platform(), shaped_buffer->device_ordinal()); - element_buffer->set_buffer(shaped_buffer->buffer(/*index=*/{i}), - /*index=*/{}); - std::vector> replicated_buffers; - replicated_buffers.emplace_back(std::move(element_buffer)); + element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}), + /*index=*/{}); + std::vector replicated_buffers; + replicated_buffers.push_back(std::move(element_buffer)); TF_ASSIGN_OR_RETURN( GlobalDataHandle element_handle, RegisterInternal(std::move(replicated_buffers), "deconstructed tuple")); @@ -204,32 +216,33 @@ StatusOr> AllocationTracker::ResolveInternal( } void AllocationTracker::AddAllocationOrIncrementRefCount( - perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { + se::DeviceMemoryBase device_memory, int device_ordinal) { AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; auto it = allocation_map.find(device_memory.opaque()); if (it == allocation_map.end()) { - allocation_map[device_memory.opaque()] = {device_memory, device_ordinal, - /*ref_count=*/1}; + allocation_map[device_memory.opaque()] = { + OwningDeviceMemory(device_memory, device_ordinal, + backend_->memory_allocator()), + /*ref_count=*/1}; } else { it->second.ref_count++; } } -Status AllocationTracker::DecrementRefCount( - perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { +Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory, + int device_ordinal) { AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; auto it = allocation_map.find(device_memory.opaque()); TF_RET_CHECK(it != allocation_map.end()); Allocation& allocation = it->second; TF_RET_CHECK(allocation.ref_count >= 1); if (allocation.ref_count == 1) { - TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate( - device_ordinal, &device_memory)); + allocation.device_memory.Free(); allocation_map.erase(it); } else { allocation.ref_count--; } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace 
xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index 038aee8541b..a7d8927cf7e 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -45,14 +45,13 @@ class AllocationTracker { // Registers a shaped buffer of device memory, and returns a corresponding // handle that can be used for talking to XLA clients. The given shaped buffer // will be treated as the buffer corresponding to the only replica. - StatusOr Register( - std::unique_ptr shaped_buffer, const string& tag); + StatusOr Register(ScopedShapedBuffer shaped_buffer, + const string& tag); // Registers a vector of shaped buffers of device memory, one per replica, and // returns a corresponding handle that can be used for talking to XLA clients. StatusOr RegisterReplicatedBuffers( - std::vector> replicated_buffers, - const string& tag); + std::vector replicated_buffers, const string& tag); // Unregister the allocation for the given data handle. Status Unregister(const GlobalDataHandle& data); @@ -77,10 +76,7 @@ class AllocationTracker { // Data structure encapsulating single memory allocation on the device. struct Allocation { // The pointer to this allocation. - perftools::gputools::DeviceMemoryBase device_memory; - - // The device that the memory is allocated on. - int device_ordinal; + OwningDeviceMemory device_memory; // This is the number of times this memory allocation is referred to by // registered data handles. @@ -88,28 +84,28 @@ class AllocationTracker { }; // Internal helper which resolves the given GlobalDataHandle to a - // ShapedBuffer. + // list of ScopedShapedBuffers. StatusOr> ResolveInternal( const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Internal helper which registers a vector of shaped buffers, one per - // replica. + // replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. If + // it's ShapedBuffer, all of the given buffers must already be tracked by this + // object -- presumably this is a call from DeconstructTuple. + template StatusOr RegisterInternal( - std::vector> replicated_buffers, - const string& tag) EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Resets the shaped buffers corresponding to the given handle. - Status Reset(const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + std::vector replicated_buffers, const string& tag) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Adds the given device address to the allocation tracker, or if it already - // exists, then increment it's reference count. - void AddAllocationOrIncrementRefCount( - perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) + // exists, then increment its reference count. + void AddAllocationOrIncrementRefCount(se::DeviceMemoryBase device_memory, + int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Decrements the reference count of the given device memory. Then, if it is // zero, deallocate the memory. - Status DecrementRefCount(perftools::gputools::DeviceMemoryBase device_memory, + Status DecrementRefCount(se::DeviceMemoryBase device_memory, int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_); // A map from device memory opaque value to allocation. One such map is @@ -127,11 +123,29 @@ class AllocationTracker { int64 next_handle_ GUARDED_BY(mutex_); // A map from device ordinal to AllocationMap. 
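The refcounting bookkeeping described above (AddAllocationOrIncrementRefCount / DecrementRefCount, keyed by the device pointer's opaque value) can be illustrated with a minimal stand-alone sketch. The types below are simplified stand-ins, not the real XLA classes:

#include <unordered_map>

// Editorial sketch, not part of this change: a map from device pointer to a
// refcount, freeing the buffer only when the last registered handle goes away.
struct AllocationSketch {
  int ref_count;  // number of registered handles referring to this buffer
};

class RefCountTrackerSketch {
 public:
  // Mirrors AddAllocationOrIncrementRefCount: the first registration inserts
  // with ref_count 1, later registrations only bump the count.
  void Add(void* opaque) {
    auto it = map_.find(opaque);
    if (it == map_.end()) {
      map_[opaque] = AllocationSketch{/*ref_count=*/1};
    } else {
      ++it->second.ref_count;
    }
  }

  // Mirrors DecrementRefCount: when the count reaches zero the entry is erased
  // (the real code calls OwningDeviceMemory::Free() at that point).
  void Remove(void* opaque) {
    auto it = map_.find(opaque);
    if (it == map_.end()) {
      return;
    }
    if (--it->second.ref_count == 0) {
      map_.erase(it);
    }
  }

 private:
  std::unordered_map<void*, AllocationSketch> map_;
};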
- tensorflow::gtl::FlatMap opaque_to_allocation_map_ + // + // This is not a TF FlatMap because (currently) FlatMap (and therefore + // AllocationMap) is not movable. + std::unordered_map opaque_to_allocation_map_ GUARDED_BY(mutex_); // A map from data handle to a vector of shaped buffers that represent the // buffers for different replicas. + // + // The ShapedBuffers in this map's vectors need to be unique_ptrs, because our + // public API returns pointers to them. We expect the concrete class to be + // ShapedBuffer and never ScopedShapedBuffer; deallocation of buffers is + // handled by opaque_to_allocation_map_. + // + // The elements of the vectors need to be unique_ptrs because we return + // pointers to them. (In theory we could use std::list or something instead, + // but we also want to be able to null out these elements.) + // + // The reason that the elements can't be unique_ptrs is + // the existence of DeconstructTuple(). This function allows us to create a + // non-owning "view" into a tuple's sub-buffers. The sub-buffers are then + // free'd when both the view *and* the original tuple are Unregistered. This + // refcounting is managed in opaque_to_allocation_map_. tensorflow::gtl::FlatMap>> handle_to_shaped_buffers_ GUARDED_BY(mutex_); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 05f2d062784..349b32451a6 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -31,24 +31,20 @@ limitations under the License. #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { -BackendOptions& BackendOptions::set_platform( - perftools::gputools::Platform* platform) { +BackendOptions& BackendOptions::set_platform(se::Platform* platform) { platform_ = platform; return *this; } -perftools::gputools::Platform* BackendOptions::platform() const { - return platform_; -} +se::Platform* BackendOptions::platform() const { return platform_; } BackendOptions& BackendOptions::set_intra_op_parallelism_threads( int num_threads) { @@ -77,7 +73,7 @@ struct Backend::EigenThreadPoolWrapper { /* static */ StatusOr> Backend::CreateBackend( const BackendOptions& options) { - perftools::gputools::Platform* platform = options.platform(); + se::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto stream_executors, PlatformUtil::GetStreamExecutors(platform)); @@ -121,7 +117,7 @@ StatusOr Backend::BorrowStream( } Backend::Backend( - perftools::gputools::Platform* platform, Compiler* compiler, + se::Platform* platform, Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, TransferManager* transfer_manager, ComputationPlacer* computation_placer, int intra_op_parallelism_threads) @@ -142,9 +138,6 @@ Backend::Backend( << "Service found no devices for backend " << platform_->Name() << '.'; if (platform->id() == se::host::kHostPlatformId) { - inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool( - tensorflow::Env::Default(), "xla_inter_op", - tensorflow::port::NumSchedulableCPUs())); const int num_threads = 
intra_op_parallelism_threads > 0 ? intra_op_parallelism_threads : tensorflow::port::NumSchedulableCPUs(); @@ -159,10 +152,6 @@ int Backend::default_device_ordinal() const { return default_stream_executor()->device_ordinal(); } -tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const { - return inter_op_thread_pool_.get(); -} - const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device() const { if (intra_op_thread_pool_wrapper_ == nullptr) { @@ -178,7 +167,7 @@ tensorflow::thread::ThreadPool* Backend::eigen_intra_op_thread_pool() const { return intra_op_thread_pool_wrapper_->pool.get(); } -StatusOr Backend::stream_executor( +StatusOr Backend::stream_executor( int device_ordinal) const { if (device_ordinal < 0 || device_ordinal > stream_executors_.back()->device_ordinal()) { @@ -201,9 +190,9 @@ StatusOr Backend::devices_equivalent(int device_ordinal_a, // bit crude but works for GPUs which is the important case where we compile // an executable for one GPU and want to know if it will run (well) on // another. - TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * executor_a, + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor_a, stream_executor(device_ordinal_a)); - TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * executor_b, + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor_b, stream_executor(device_ordinal_b)); return (executor_a->GetDeviceDescription().name() == executor_b->GetDeviceDescription().name()); diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index b5ca483b727..6546602473e 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -44,8 +44,8 @@ namespace xla { class BackendOptions { public: // Set the platform backing the backend, or nullptr for the default platform. - BackendOptions& set_platform(perftools::gputools::Platform* platform); - perftools::gputools::Platform* platform() const; + BackendOptions& set_platform(se::Platform* platform); + se::Platform* platform() const; // Sets the thread pool size for parallel execution of an individual operator. // The default value of -1 will result in initializing the thread pool with @@ -54,7 +54,7 @@ class BackendOptions { int intra_op_parallelism_threads() const; private: - perftools::gputools::Platform* platform_ = nullptr; + se::Platform* platform_ = nullptr; int intra_op_parallelism_threads_ = -1; }; @@ -66,7 +66,7 @@ class BackendOptions { // StreamPtr stream = backend->BorrowStream().ConsumeValueOrDie(); class Backend { public: - using StreamPtr = Pool::SmartPtr; + using StreamPtr = Pool::SmartPtr; // Creates a new backend. static StatusOr> CreateBackend( @@ -79,7 +79,7 @@ class Backend { ~Backend(); // Accessors for the various objects. - perftools::gputools::Platform* platform() const { return platform_; } + se::Platform* platform() const { return platform_; } Compiler* compiler() const { return compiler_; } DeviceMemoryAllocator* memory_allocator() const { return memory_allocator_.get(); @@ -96,19 +96,17 @@ class Backend { // Returns stream executors of all supported devices for this backend. The // executors are ordered by the device ordinal. - const std::vector& stream_executors() - const { + const std::vector& stream_executors() const { return stream_executors_; } // Returns the stream executor for the given device ordinal. 
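A hedged usage sketch of the accessors being converted to the se:: spelling here; the helper function name is illustrative, and TF_ASSIGN_OR_RETURN is assumed to be available as it is elsewhere in this change:

// Editorial sketch, not part of this change: validate the ordinal via
// stream_executor(), then borrow a pooled stream for that device. The returned
// StreamPtr gives the stream back to the Backend's pool when it is destroyed.
StatusOr<Backend::StreamPtr> BorrowStreamForDevice(Backend* backend,
                                                   int device_ordinal) {
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      backend->stream_executor(device_ordinal));
  return backend->BorrowStream(executor);
}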
- StatusOr stream_executor( - int device_ordinal) const; + StatusOr stream_executor(int device_ordinal) const; // Returns the stream executor for the default device ordinal. This stream // executor can only be used when the number of computations is 1 (replication // can be > 1). - perftools::gputools::StreamExecutor* default_stream_executor() const { + se::StreamExecutor* default_stream_executor() const { CHECK(!stream_executors_.empty()); return stream_executors_[0]; } @@ -117,8 +115,7 @@ class Backend { // internal pool, or by constructing/initializating it, and returns the result // to the caller. StatusOr BorrowStream(int device_ordinal); - StatusOr BorrowStream( - perftools::gputools::StreamExecutor* executor); + StatusOr BorrowStream(se::StreamExecutor* executor); // Returns a function to borrow a stream, as `BorrowStream` above does. // Purely for convenience, the caller could rather make this anonymous @@ -143,10 +140,6 @@ class Backend { // be equivalent to an executable compiled for the other. StatusOr devices_equivalent(int device_ordinal_a, int device_ordinal_b); - // For the host platform, returns the threadpool to use when scheduling - // parallel operators. For other platforms, returns NULL. - tensorflow::thread::ThreadPool* inter_op_thread_pool() const; - // For the host platform, returns the configured eigen threadpool device to be // used for scheduling work. For other platforms, returns NULL. const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const; @@ -157,36 +150,30 @@ class Backend { private: struct EigenThreadPoolWrapper; - Backend(perftools::gputools::Platform* platform, Compiler* compiler, - tensorflow::gtl::ArraySlice - stream_executors, + Backend(se::Platform* platform, Compiler* compiler, + tensorflow::gtl::ArraySlice stream_executors, TransferManager* transfer_manager, ComputationPlacer* computation_placer, int intra_op_parallelism_threads); Backend(const Backend&) = delete; Backend& operator=(const Backend&) = delete; - perftools::gputools::Platform* platform_; + se::Platform* platform_; Compiler* compiler_; TransferManager* transfer_manager_; ComputationPlacer* computation_placer_; // Vector of stream executors. stream_executors_[0] is the default executor. - std::vector stream_executors_; + std::vector stream_executors_; tensorflow::mutex mu_; // Mapping from stream executor to stream pools, used by `BorrowStream` above. - std::map> - stream_pools_ GUARDED_BY(mu_); + std::map> stream_pools_ GUARDED_BY(mu_); // The default memory allocator to use. std::unique_ptr memory_allocator_; - // For the CPU backend, a threadpool for scheduling parallel operators. - std::unique_ptr inter_op_thread_pool_; - // For the CPU backend, an Eigen threadpool device for use by Eigen code. std::unique_ptr intra_op_thread_pool_wrapper_; }; diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc new file mode 100644 index 00000000000..2099916509a --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -0,0 +1,99 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" + +namespace xla { +StatusOr +BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( + HloInstruction* batch_dot) { + const DotDimensionNumbers& dim_numbers = batch_dot->dot_dimension_numbers(); + HloInstruction *lhs = batch_dot->mutable_operand(0), + *rhs = batch_dot->mutable_operand(1); + const Shape& lhs_shape = lhs->shape(); + + std::vector degenerate_dims; + for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) { + if (lhs_shape.dimensions(batch_dim) == 1) { + degenerate_dims.push_back(batch_dim); + } + } + + if (degenerate_dims.empty()) { + return false; + } + + TF_ASSIGN_OR_RETURN(HloInstruction * new_lhs, + ElideDegenerateDims(lhs, degenerate_dims)); + TF_ASSIGN_OR_RETURN(HloInstruction * new_rhs, + ElideDegenerateDims(rhs, degenerate_dims)); + + DotDimensionNumbers new_dim_numbers = dim_numbers; + new_dim_numbers.clear_lhs_batch_dimensions(); + new_dim_numbers.clear_rhs_batch_dimensions(); + + for (int64 i = 0, e = dim_numbers.lhs_batch_dimensions_size() - + degenerate_dims.size(); + i < e; i++) { + new_dim_numbers.add_lhs_batch_dimensions(i); + new_dim_numbers.add_rhs_batch_dimensions(i); + } + + new_dim_numbers.set_lhs_contracting_dimensions( + 0, + new_dim_numbers.lhs_contracting_dimensions(0) - degenerate_dims.size()); + new_dim_numbers.set_rhs_contracting_dimensions( + 0, + new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, + MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); + + TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, + MakeReshapeHlo(batch_dot->shape(), new_dot)); + + VLOG(2) << "Replaced " << batch_dot->ToString() << " with " + << new_dot->ToString(); + + TF_RETURN_IF_ERROR( + batch_dot->parent()->ReplaceInstruction(batch_dot, new_dot_reshaped)); + + return true; +} + +tensorflow::StringPiece BatchDotSimplification::name() const { + return "batch-dot-simplification"; +} + +StatusOr BatchDotSimplification::Run(HloModule* module) { + bool changed = false; + std::vector dot_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), + [](HloInstruction* instr) { + return instr->opcode() == HloOpcode::kDot; + }); + } + for (HloInstruction* dot_instr : dot_instrs) { + TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, + ElideDegenerateBatchDimensionFromBatchDot(dot_instr)); + changed |= elided_batch_dim_from_one; + } + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h new file mode 100644 index 00000000000..c0ca8d8ebac --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +// Simplifies batch dot operations. +// +// Normally these would live in the algebraic simplifier, but we want to run +// this to fixpoint (this pass reaches fixed point in one execution) before we +// run the DotDecomposer. +class BatchDotSimplification : public HloPassInterface { + public: + StatusOr Run(HloModule* module) override; + tensorflow::StringPiece name() const override; + + private: + StatusOr ElideDegenerateBatchDimensionFromBatchDot( + HloInstruction* batch_dot); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc new file mode 100644 index 00000000000..38f1a5d3a64 --- /dev/null +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -0,0 +1,168 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class BatchDotSimplificationTest : public HloVerifiedTestBase {}; + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_VectorVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,3] parameter(0) + b = f32[1,3] parameter(1) + ROOT dot = f32[1] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_MatrixVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,9,3] parameter(0) + b = f32[1,3] parameter(1) + ROOT dot = f32[1,9] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideSingleDegenerateBatchDotDim_MatrixMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[1,9,3] parameter(0) + b = f32[1,3,7] parameter(1) + ROOT dot = f32[1,9,7] dot(a, b), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_VectorVector) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,3] parameter(0) + b = f32[9,1,7,1,3] parameter(1) + ROOT dot = f32[9,1,7,1] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={4} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/2))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_VectorMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,3] parameter(0) + b 
= f32[9,1,7,1,20,3] parameter(1) + ROOT dot = f32[9,1,7,1,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={4}, rhs_contracting_dims={5} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/2, /*rhs_contracting_dim=*/3))); +} + +TEST_F(BatchDotSimplificationTest, + ElideMultipleDegenerateBatchDotDims_MatrixMatrix) { + const string hlo_text = R"( +HloModule BatchDot + +main { + a = f32[9,1,7,1,19,3] parameter(0) + b = f32[9,1,7,1,3,20] parameter(1) + ROOT dot = f32[9,1,7,1,19,20] dot(a, b), lhs_batch_dims={0,1,2,3}, rhs_batch_dims={0,1,2,3}, lhs_contracting_dims={5}, rhs_contracting_dims={4} +} +)"; + + ParseAndVerifyModule(hlo_text); + BatchDotSimplification pass; + ASSERT_TRUE(pass.Run(&module()).ValueOrDie()); + + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, + op::Reshape(op::Dot( + op::Reshape(op::Parameter(0)), op::Reshape(op::Parameter(1)), + /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/2))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 38086bd7e12..96e02b82b97 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -15,35 +15,32 @@ limitations under the License. #include "tensorflow/compiler/xla/service/batchnorm_expander.h" -#include #include -#include -#include #include #include #include -#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. 
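As an editorial aside (not part of the change): per feature, the smaller operations the expander emits compute the textbook batch-norm arithmetic, with the variance obtained as E[X^2] - E[X]^2, which is why the visitor below builds an X^2 reduction. A self-contained scalar sketch:

#include <cmath>
#include <vector>

// Editorial sketch, not part of this change: the per-feature math that the
// expanded HLO graph implements with broadcasts, multiplies and add-reductions.
std::vector<float> BatchNormForFeature(const std::vector<float>& x, float scale,
                                       float offset, float epsilon) {
  float sum = 0.0f, sum_sq = 0.0f;
  for (float v : x) {
    sum += v;
    sum_sq += v * v;
  }
  const float n = static_cast<float>(x.size());
  const float mean = sum / n;
  const float var = sum_sq / n - mean * mean;  // E[X^2] - E[X]^2
  std::vector<float> y;
  y.reserve(x.size());
  for (float v : x) {
    y.push_back(scale * (v - mean) / std::sqrt(var + epsilon) + offset);
  }
  return y;
}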
class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { @@ -80,17 +77,25 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { rewrite_grad_op_(rewrite_grad_op), use_fusion_(use_fusion) {} - HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type, - HloOpcode opcode) { - HloComputation::Builder b("scalar_computation"); - auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs")); - auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs")); - auto scalar_op = b.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), - opcode, scalar_lhs, scalar_rhs)); - return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + HloComputation* GetOrCreateScalarAddComputation( + PrimitiveType primitive_type) { + HloComputation** scalar_add_computation = + &scalar_add_computations_[primitive_type]; + if (*scalar_add_computation) { + return *scalar_add_computation; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(primitive_type, {}); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); + *scalar_add_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return *scalar_add_computation; } // Current HloComputation instance the BatchNormExpander is @@ -105,6 +110,10 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { // Whether rewrite has occurred. bool changed_ = false; + // Cached computations for adding two scalars. + tensorflow::gtl::FlatMap + scalar_add_computations_; + // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. // Returns the Status representing the result of the replace operation. @@ -129,6 +138,8 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { } }; +} // namespace + bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, @@ -199,7 +210,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(ptype, HloOpcode::kAdd); + GetOrCreateScalarAddComputation(ptype); // X^2. auto operand_squared = add(HloInstruction::CreateBinary( @@ -500,7 +511,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( grad_output, activation_minus_mean)); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(ptype, HloOpcode::kAdd); + GetOrCreateScalarAddComputation(ptype); // sum(Grad[Y] * (X - E[X])). 
auto sum_grad_output_times_activiation_minus_mean = diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index c26d2feef58..ed0746980f8 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -33,7 +33,7 @@ BFloat16Propagation::BFloat16Propagation( const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} -void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( +void BFloat16Propagation::DetermineFusionComputationPrecision( HloInstruction* fusion) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) { @@ -48,15 +48,13 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( auto root = fusion->fused_instructions_computation()->root_instruction(); // Adjust root's element types according to the fusion's output shape. - ShapeUtil::ForEachMutableSubshape( - root->mutable_shape(), [&](Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() != F32) { + ShapeUtil::ForEachSubshape( + root->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.element_type() != F32) { return; } - if (ShapeUtil::GetSubshape(fusion->shape(), index).element_type() == - BF16) { - subshape->set_element_type(BF16); - changed_ = true; + if (OutputTypeAfterChange(fusion, index) == BF16) { + AddToOrRemoveFromBF16ChangeSet(root, index, BF16); VLOG(2) << "Fused root " << root->ToString() << " at shape index " << index << " changed to BF16 precision for fusion " << fusion->ToString(); @@ -67,13 +65,101 @@ void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( auto insts = fusion->fused_instructions_computation()->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); } - computations_visited_in_mutation_pass_.insert( + computations_visited_in_backward_pass_.insert( fusion->fused_instructions_computation()); + + RevertIfFusionInternalBF16Changes(fusion); } -void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision( +void BFloat16Propagation::RevertIfFusionInternalBF16Changes( + HloInstruction* fusion) { + auto has_changes = [this](HloInstruction* inst) { + auto it = changes_to_bf16_.find(inst); + return it != changes_to_bf16_.end() && !it->second.empty(); + }; + + auto root = fusion->fused_instructions_computation()->root_instruction(); + tensorflow::gtl::FlatSet changed_root_buffers; + + auto root_changes_it = changes_to_bf16_.find(root); + if (root_changes_it != changes_to_bf16_.end()) { + for (const auto& index : root_changes_it->second) { + for (const HloValue* value : + dataflow_->GetValueSet(root, index).values()) { + changed_root_buffers.insert(value); + } + } + } + + auto aliases_changed_root_buffer = + [this, &changed_root_buffers](const HloInstruction* inst) { + bool aliasing = false; + ShapeUtil::ForEachSubshape( + inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (aliasing) { + // Skip if aliasing is already found. + return; + } + // Only F32 buffers are considered for changing to BF16 in this + // pass. 
+ if (subshape.element_type() != F32) { + return; + } + for (const HloValue* value : + dataflow_->GetValueSet(inst, index).values()) { + if (ContainsKey(changed_root_buffers, value)) { + aliasing = true; + break; + } + } + }); + return aliasing; + }; + + for (auto inst : + fusion->fused_instructions_computation()->MakeInstructionPostOrder()) { + if (inst->opcode() == HloOpcode::kParameter) { + continue; + } + if (aliases_changed_root_buffer(inst)) { + continue; + } + if (inst->opcode() == HloOpcode::kFusion) { + bool parameter_reverted = false; + for (int64 i = 0; i < inst->operand_count(); ++i) { + if (has_changes(inst->mutable_operand(i))) { + // Changes on the operand have not been reverted. + continue; + } + auto* fused_parameter = inst->fused_parameter(i); + if (has_changes(fused_parameter)) { + changes_to_bf16_.erase(fused_parameter); + parameter_reverted = true; + } + } + if (parameter_reverted) { + RevertIfFusionInternalBF16Changes(inst); + } + } + if (!has_changes(inst)) { + continue; + } + bool revert_changes = true; + for (auto operand : inst->operands()) { + if (has_changes(operand)) { + revert_changes = false; + break; + } + } + if (revert_changes) { + changes_to_bf16_.erase(inst); + } + } +} + +void BFloat16Propagation::DetermineWhileComputationsPrecision( HloInstruction* while_hlo) { CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile); @@ -86,16 +172,14 @@ void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision( auto body_root = body->root_instruction(); HloComputation* condition = while_hlo->while_condition(); - ShapeUtil::ForEachMutableSubshape( - body_root->mutable_shape(), - [this, while_hlo, body_root](Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() != F32) { + ShapeUtil::ForEachSubshape( + body_root->shape(), [this, while_hlo, body_root]( + const Shape& subshape, const ShapeIndex& index) { + if (subshape.element_type() != F32) { return; } - if (ShapeUtil::GetSubshape(while_hlo->shape(), index).element_type() == - BF16) { - subshape->set_element_type(BF16); - changed_ = true; + if (OutputTypeAfterChange(while_hlo, index) == BF16) { + AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16); VLOG(2) << "While body root " << body_root->ToString() << " at shape index " << index << " changed to BF16 precision for while " @@ -106,30 +190,30 @@ void BFloat16Propagation::DetermineAndMutateWhileComputationsPrecision( auto body_insts = body->MakeInstructionPostOrder(); for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); } - computations_visited_in_mutation_pass_.insert(body); + computations_visited_in_backward_pass_.insert(body); auto condition_insts = condition->MakeInstructionPostOrder(); for (auto inst_it = condition_insts.rbegin(); inst_it != condition_insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false); } - computations_visited_in_mutation_pass_.insert(condition); + computations_visited_in_backward_pass_.insert(condition); } bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const { - auto value_set = dataflow_->GetValueSet(&hlo, index); + auto& value_set = dataflow_->GetValueSet(&hlo, index); for (const HloValue* value : value_set.values()) { if (ContainsKey(values_that_must_be_kept_as_f32_, 
value)) { return false; } - if (value->shape().element_type() == BF16) { + if (ValueTypeAfterChange(value) == BF16) { continue; } for (const HloUse& use : value->uses()) { - if (!ContainsKey(instructions_visited_in_mutation_pass_, + if (!ContainsKey(instructions_visited_in_backward_pass_, use.instruction)) { // We don't know yet whether use.instruction will consume BF16 since it // hasn't been visited. Although we visit instructions in reverse @@ -145,26 +229,23 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, // precision, or a called computation's parameters have been changed to // BF16 for fusions or whiles. if (use.instruction->opcode() == HloOpcode::kFusion) { - const auto* fused_parameter = + auto* fused_parameter = use.instruction->fused_parameter(use.operand_number); - if (ShapeUtil::GetSubshape(fused_parameter->shape(), use.operand_index) - .element_type() != BF16) { + if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) { return false; } continue; } else if (use.instruction->opcode() == HloOpcode::kWhile) { - const auto* cond_parameter = + auto* cond_parameter = use.instruction->while_condition()->parameter_instruction( use.operand_number); - if (ShapeUtil::GetSubshape(cond_parameter->shape(), use.operand_index) - .element_type() != BF16) { + if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) { return false; } - const auto* body_parameter = + auto* body_parameter = use.instruction->while_body()->parameter_instruction( use.operand_number); - if (ShapeUtil::GetSubshape(body_parameter->shape(), use.operand_index) - .element_type() != BF16) { + if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) { return false; } continue; @@ -174,19 +255,20 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, continue; } // If the op propagates precision and it outputs a BF16, then it's OK to - // supply BF16 also as the input. In the backward mutation pass, the users - // shapes should have already been processed. + // supply BF16 also as the input. In the backward pass, the users shapes + // should have already been processed. PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID; if (use.instruction->opcode() == HloOpcode::kTuple || (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && ShapeUtil::IsTuple(use.instruction->shape()))) { - user_output_type = ShapeUtil::GetSubshape( - ShapeUtil::GetSubshape(use.instruction->shape(), - {use.operand_number}), - use.operand_index) - .element_type(); + ShapeIndex use_output_index{use.operand_number}; + for (int64 i : use.operand_index) { + use_output_index.push_back(i); + } + user_output_type = + OutputTypeAfterChange(use.instruction, use_output_index); } else { - user_output_type = use.instruction->shape().element_type(); + user_output_type = OutputTypeAfterChange(use.instruction, {}); } if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( *use.instruction, use.operand_number) && @@ -199,8 +281,8 @@ bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, return true; } -void BFloat16Propagation::DetermineAndMutateInstructionPrecision( - HloInstruction* hlo, bool skip_parameters) { +void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo, + bool skip_parameters) { // We handle any fusion computation or while body/condition after the // instruction is handled, because we need to know the output shape of a // fusion or while before propagating inside its computations. 
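Condensed, the decision that AllUsersConsumeBF16 makes for each F32 output can be sketched as below; every helper name here is hypothetical, and the real code additionally inspects fusion parameters, while condition/body parameters, and tuple indices:

// Editorial pseudocode, not part of this change: an F32 value may be demoted
// to BF16 only when every use either consumes BF16 directly or propagates the
// operand's precision into an output that has already been decided to be BF16.
bool CanDemoteToBF16Sketch(const ValueSketch& value) {
  if (MustBeKeptAsF32(value)) {
    return false;  // e.g. the value aliases the module's output
  }
  for (const UseSketch& use : value.uses()) {
    if (UseConsumesBF16(use)) continue;
    if (PropagatesPrecision(use) && UserOutputIsBF16(use)) continue;
    return false;
  }
  return true;
}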
@@ -209,12 +291,12 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision( [this, hlo, &postpone_processing_called_computations] { if (!postpone_processing_called_computations) { if (hlo->opcode() == HloOpcode::kFusion) { - DetermineAndMutateFusionComputationPrecision(hlo); + DetermineFusionComputationPrecision(hlo); } else if (hlo->opcode() == HloOpcode::kWhile) { - DetermineAndMutateWhileComputationsPrecision(hlo); + DetermineWhileComputationsPrecision(hlo); } } - instructions_visited_in_mutation_pass_.insert(hlo); + instructions_visited_in_backward_pass_.insert(hlo); }); if (hlo->opcode() == HloOpcode::kWhile && @@ -245,9 +327,9 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision( CHECK(hlo->parent() != nullptr); if (hlo == hlo->parent()->root_instruction()) { if (!hlo->parent()->IsFusionComputation()) { - ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& subshape, + ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */, const ShapeIndex& index) { - if (subshape.element_type() != F32) { + if (OutputTypeAfterChange(hlo, index) != F32) { return; } for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { @@ -269,13 +351,12 @@ void BFloat16Propagation::DetermineAndMutateInstructionPrecision( return; } - ShapeUtil::ForEachMutableSubshape( - hlo->mutable_shape(), - [hlo, this](Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() == F32 && + ShapeUtil::ForEachSubshape( + hlo->shape(), + [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) { + if (OutputTypeAfterChange(hlo, index) == F32 && AllUsersConsumeBF16(*hlo, index)) { - subshape->set_element_type(BF16); - changed_ = true; + AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16); VLOG(2) << "HloInstruction output at shape index " << index << " changed to BF16 precision: " << hlo->ToString(); } @@ -308,26 +389,24 @@ void BFloat16Propagation::AdjustCalledComputationParameters( CHECK_EQ(operands.size(), computation->num_parameters()); for (int64 i = 0; i < operands.size(); ++i) { auto parameter = computation->parameter_instruction(i); - ShapeUtil::ForEachMutableSubshape( - parameter->mutable_shape(), - [this, i, hlo, &operands, parameter](Shape* subshape, + ShapeUtil::ForEachSubshape( + parameter->shape(), + [this, i, hlo, &operands, parameter](const Shape& /* subshape */, const ShapeIndex& index) { if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) { return; } PrimitiveType operand_type = - ShapeUtil::GetSubshape(operands[i]->shape(), index) - .element_type(); - if (subshape->element_type() == operand_type) { + OutputTypeAfterChange(operands[i], index); + if (OutputTypeAfterChange(parameter, index) == operand_type) { return; } - CHECK(operand_type == F32 || operand_type == BF16); - subshape->set_element_type(operand_type); - changed_ = true; + AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type); VLOG(2) << "Called computation parameter " << parameter->ToString() << " at shape index " << index - << " adjusted to match operand in HLO " - << hlo->ToString(); + << " adjusted to " + << (operand_type == BF16 ? "BF16" : "F32") + << " to match operand in HLO " << hlo->ToString(); }); } }; @@ -348,52 +427,48 @@ void BFloat16Propagation::AdjustCalledComputationParameters( void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { auto adjust_computation = [this, hlo](HloComputation* computation, - const Shape& output_shape) { + HloInstruction* output) { // Adjust root. 
HloInstruction* root = computation->root_instruction(); - ShapeUtil::ForEachMutableSubshape( - root->mutable_shape(), [this, hlo, root, &output_shape]( - Shape* subshape, const ShapeIndex& index) { - if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) { - return; - } - const PrimitiveType output_type = - ShapeUtil::GetSubshape(output_shape, index).element_type(); - if (subshape->element_type() == output_type) { - return; - } - CHECK(output_type == F32 || output_type == BF16); - subshape->set_element_type(output_type); - // It's possible that output_type is F32, but the root instruction's - // type is BF16; e.g., a fusion node's output was changed to BF16 - // initially but then adjusted back to F32, and the fusion computation - // is now being adjusted after the fusion node. - if (output_type == F32) { - for (const auto* value : - dataflow_->GetValueSet(root, index).values()) { - // We rely on the fact that this adjustment works in reverse - // topological order so that called computation will be - // processed later. Adding the value to - // values_that_must_be_kept_as_f32_ will ensure the - // correctness of the adjustment for HLOs that will be - // processed later. - values_that_must_be_kept_as_f32_.insert(value); - } - } - changed_ = true; - VLOG(2) << "Called computation root " << root->ToString() - << " at shape index " << index - << " adjusted to match output shape of " << hlo->ToString(); - }); + ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output]( + const Shape& /* subshape */, + const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) { + return; + } + const PrimitiveType output_type = OutputTypeAfterChange(output, index); + if (OutputTypeAfterChange(root, index) == output_type) { + return; + } + AddToOrRemoveFromBF16ChangeSet(root, index, output_type); + // It's possible that output_type is F32, but the root instruction's + // type is BF16; e.g., a fusion node's output was changed to BF16 + // initially but then adjusted back to F32, and the fusion computation + // is now being adjusted after the fusion node. + if (output_type == F32) { + for (const auto* value : dataflow_->GetValueSet(root, index).values()) { + // We rely on the fact that this adjustment works in reverse + // topological order so that called computation will be + // processed later. Adding the value to + // values_that_must_be_kept_as_f32_ will ensure the + // correctness of the adjustment for HLOs that will be + // processed later. + values_that_must_be_kept_as_f32_.insert(value); + } + } + VLOG(2) << "Called computation root " << root->ToString() + << " at shape index " << index << " adjusted to " + << (output_type == BF16 ? 
"BF16" : "F32") + << " to match output shape of " << hlo->ToString(); + }); }; switch (hlo->opcode()) { case HloOpcode::kFusion: - adjust_computation(hlo->fused_instructions_computation(), hlo->shape()); + adjust_computation(hlo->fused_instructions_computation(), hlo); break; case HloOpcode::kWhile: - adjust_computation(hlo->while_condition(), hlo->shape()); - adjust_computation(hlo->while_body(), hlo->shape()); + adjust_computation(hlo->while_body(), hlo); break; default: break; @@ -410,16 +485,19 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { auto hlo = *inst_it; auto adjust_hlo_output = [this, hlo, ¶meter_changed]( - Shape* subshape, const ShapeIndex& index) { - if (subshape->element_type() != F32 && subshape->element_type() != BF16) { + const Shape& /* subshape */, + const ShapeIndex& index) { + auto output_type = OutputTypeAfterChange(hlo, index); + if (output_type != F32 && output_type != BF16) { return; } PrimitiveType type = BF16; for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { - if (value->shape().element_type() == BF16) { + auto value_type = ValueTypeAfterChange(value); + if (value_type == BF16) { continue; } - CHECK_EQ(value->shape().element_type(), F32); + CHECK_EQ(value_type, F32); type = F32; break; } @@ -438,16 +516,17 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( values_that_must_be_kept_as_f32_.insert(value); } } - if (type != subshape->element_type()) { - subshape->set_element_type(type); + if (type != output_type) { + AddToOrRemoveFromBF16ChangeSet(hlo, index, type); VLOG(2) << "HloInstruction output at shape index " << index - << " adjusted to " << *subshape << ": " << hlo->ToString(); + << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": " + << hlo->ToString(); if (hlo->opcode() == HloOpcode::kParameter) { parameter_changed = true; } } }; - ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_hlo_output); + ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); AdjustCalledComputationRoot(hlo); if (hlo->opcode() == HloOpcode::kWhile) { // We need to run on the while body and condition repeatedly until a fixed @@ -464,8 +543,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), &visited_in_while)) { visited_in_while.clear(); - ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), - adjust_hlo_output); + ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output); AdjustCalledComputationRoot(hlo); } visited_computations->insert(visited_in_while.begin(), @@ -479,7 +557,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( return parameter_changed; } -Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( +void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { std::list computations_topological_order = module->MakeComputationPostOrder(); @@ -491,7 +569,9 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( } ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved); } +} +Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) { // We could have changed a fusion computation's root shape to have a different // precision than the fusion node's output, if the fusion root does not // define a buffer (e.g., a tuple). 
Now we add conversions after such fusion @@ -518,7 +598,7 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( // (2) after adding conversion // (3) after tuple simplifier and DCE. bool needs_tuple_simplifier = false; - for (auto computation : computations_topological_order) { + for (auto computation : module->MakeComputationPostOrder()) { auto insts = computation->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { auto hlo = *inst_it; @@ -588,7 +668,14 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape()); } } + if (needs_tuple_simplifier) { + TupleSimplifier tuple_simplifier; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + } + return Status::OK(); +} +Status BFloat16Propagation::ResolveConvertedConstants(HloModule* module) { // We may have converted some constants from F32 to BF16, so adjust the // constant literals in such cases. We do this here instead of when the // constant node's is changed because 1) the HloInstruction interface does not @@ -599,8 +686,7 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( // can avoid repeated conversions. // // TODO(b/73833576): Consider resetting literal in HloInstruction. - bool needs_dce = needs_tuple_simplifier; - for (auto computation : computations_topological_order) { + for (auto computation : module->MakeComputationPostOrder()) { for (auto hlo : computation->MakeInstructionPostOrder()) { if (hlo->opcode() != HloOpcode::kConstant) { continue; @@ -613,23 +699,13 @@ Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( auto new_constant = computation->AddInstruction( HloInstruction::CreateConstant(std::move(converted_literal))); TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant)); - needs_dce = true; } } } - - if (needs_tuple_simplifier) { - TupleSimplifier tuple_simplifier; - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); - } - if (needs_dce) { - HloDCE dce; - TF_RETURN_IF_ERROR(dce.Run(module).status()); - } return Status::OK(); } -Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { +Status BFloat16Propagation::SkipNoopConversions(HloModule* module) { for (auto computation : module->computations()) { for (auto hlo : computation->MakeInstructionPostOrder()) { if (hlo->opcode() != HloOpcode::kConvert) { @@ -644,7 +720,6 @@ Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { if (is_root) { computation->set_root_instruction(source); } - TF_RETURN_IF_ERROR(computation->RemoveInstructionAndUnusedOperands(hlo)); } } return Status::OK(); @@ -653,8 +728,18 @@ Status BFloat16Propagation::RemoveNoopConversions(HloModule* module) { // The algorithm first does a forward pass (parameters to root) to determine a // set of instructions to consider using bfloat16, then does a backward pass to // determine the precisions of those instructions according to the need of -// their users. +// their users. During the backward pass, the potential changes are stored in +// changes_to_bf16_ which are subject to further adjustments then applied to the +// HLOs. 
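// A minimal usage sketch of this pass under the flow described above; the
// helper name and wiring here are illustrative assumptions, not part of this
// change, and BFloat16Support stands for whichever backend-specific subclass
// is in use:
//
//   #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
//   #include "tensorflow/compiler/xla/service/bfloat16_support.h"
//
//   StatusOr<bool> RunBF16Propagation(HloModule* module,
//                                     const BFloat16Support* bf16_support) {
//     BFloat16Propagation propagation(bf16_support);
//     // Run() performs the forward and backward passes, resolves aliasing
//     // inconsistencies, and only then applies the deferred entries in
//     // changes_to_bf16_ to the instruction shapes.
//     return propagation.Run(module);
//   }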
StatusOr BFloat16Propagation::Run(HloModule* module) { + consider_using_bfloat16_.clear(); + instructions_visited_in_backward_pass_.clear(); + computations_visited_in_backward_pass_.clear(); + values_that_must_be_kept_as_f32_.clear(); + caller_counts_.clear(); + changes_to_bf16_.clear(); + changed_ = false; + TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); std::list computations_topological_order = @@ -687,8 +772,24 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { } auto insts = (*comp_it)->MakeInstructionPostOrder(); for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { - DetermineAndMutateInstructionPrecision(*inst_it, - /*skip_parameters=*/true); + DetermineInstructionPrecision(*inst_it, + /*skip_parameters=*/true); + } + } + + // It's possible that an instruction does not define a buffer, but the + // defining instruction's shape has changed. So we need to adjust the output + // shapes of instructions according to the HLO values they refer to. + ResolveInconsistencyOfAliasingBuffers(module); + + // Apply the changes in changes_to_bf16_. + for (auto& change : changes_to_bf16_) { + auto shape = change.first->mutable_shape(); + for (const auto& index : change.second) { + auto subshape = ShapeUtil::GetMutableSubshape(shape, index); + CHECK_EQ(subshape->element_type(), F32); + subshape->set_element_type(BF16); + changed_ = true; } } @@ -696,15 +797,56 @@ StatusOr BFloat16Propagation::Run(HloModule* module) { return false; } - // It's possible that an instruction does not define a buffer, but the - // defining instruction's shape has changed. So we need to adjust the output - // shapes of instructions according to the HLO values they refer to. - TF_RETURN_IF_ERROR(ResolveInconsistencyOfAliasingBuffers(module)); + TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module)); + TF_RETURN_IF_ERROR(ResolveConvertedConstants(module)); // This pass could have turned an F32 -> BF16 conversion to a no-op (BF16 -> - // BF16), so we remove them now. - TF_RETURN_IF_ERROR(RemoveNoopConversions(module)); + // BF16), so we skip them now. + TF_RETURN_IF_ERROR(SkipNoopConversions(module)); + + { + // We may have dead HLOs after ResolveInconsistentFusions, + // ResolveConvertedConstants and SkipNoopConversions. + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } return true; } +PrimitiveType BFloat16Propagation::OutputTypeAfterChange( + HloInstruction* hlo, const ShapeIndex& index) const { + PrimitiveType type_on_hlo = + ShapeUtil::GetSubshape(hlo->shape(), index).element_type(); + if (type_on_hlo != F32) { + return type_on_hlo; + } + auto it = changes_to_bf16_.find(hlo); + if (it == changes_to_bf16_.end()) { + return type_on_hlo; + } + return ContainsKey(it->second, index) ? 
BF16 : F32; +} + +PrimitiveType BFloat16Propagation::ValueTypeAfterChange( + const HloValue* value) const { + auto hlo = value->defining_instruction(); + const auto& position = value->defining_position(); + return OutputTypeAfterChange(hlo, position.index); +} + +void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet( + HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) { + if (target_type == BF16) { + auto& entry = changes_to_bf16_[hlo]; + entry.insert(index); + } else { + CHECK_EQ(target_type, F32); + auto it = changes_to_bf16_.find(hlo); + if (it == changes_to_bf16_.end()) { + return; + } + it->second.erase(index); + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 1744e9db90a..de0355ddfca 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -85,30 +86,39 @@ class BFloat16Propagation : public HloPassInterface { tensorflow::gtl::FlatSet consider_using_bfloat16_; // *************************** - // Functions called and state produced by the backward mutation pass (from - // root to parameters). + // Functions called and state produced by the backward pass (from root to + // parameters) that finds opportunities to use BF16. - // Determines the precision for the given instruction in the mutation pass. - void DetermineAndMutateInstructionPrecision(HloInstruction* hlo, - bool skip_parameters); + // Determines the precision for the given instruction in the + // opportunity-finding pass. + void DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters); - // Special handling in the mutation pass for fusion computations. + // Special handling in the opportunity-finding pass for fusion computations. // // Precondition: hlo->opcode() == kFusion - void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion); + void DetermineFusionComputationPrecision(HloInstruction* fusion); - // Special handling in the mutation pass for while computations. + // Reverts changes to BF16 that will not propagate outside a fusion + // computation. This avoids BF16 casts overhead inside a fusion which won't + // save memory bandwidth. + // + // Precondition: hlo->opcode() == kFusion + void RevertIfFusionInternalBF16Changes(HloInstruction* fusion); + + // Special handling in the opportunity-finding pass for while computations. // // Precondition: hlo->opcode() == kWhile - void DetermineAndMutateWhileComputationsPrecision(HloInstruction* while_hlo); + void DetermineWhileComputationsPrecision(HloInstruction* while_hlo); - // The set of HloInstructions that have been visited in the mutation pass. + // The set of HloInstructions that have been visited in the + // opportunity-finding pass. tensorflow::gtl::FlatSet - instructions_visited_in_mutation_pass_; + instructions_visited_in_backward_pass_; - // The set of HloComputations that have been visited in the mutation pass. + // The set of HloComputations that have been visited in the + // opportunity-finding pass. 
tensorflow::gtl::FlatSet - computations_visited_in_mutation_pass_; + computations_visited_in_backward_pass_; // *************************** // Functions called by the final inconsistency resolving pass. @@ -116,7 +126,7 @@ class BFloat16Propagation : public HloPassInterface { // Adjusts the output shapes of HloInstructions such that if two // HloInstructions have aliasing buffers in their outputs, they must have the // same precision. - Status ResolveInconsistencyOfAliasingBuffers(HloModule* module); + void ResolveInconsistencyOfAliasingBuffers(HloModule* module); // Resolves inconsistency of aliasing buffers for the given computation, and // recursively runs on a while instruction's condition and body until a fixed @@ -134,9 +144,19 @@ class BFloat16Propagation : public HloPassInterface { void AdjustCalledComputationRoot(HloInstruction* hlo); // *************************** - // Removes no-op conversions (same source and target shapes) that can be - // produced this pass. - Status RemoveNoopConversions(HloModule* module); + // Functions called after changes in changes_to_bf16_ are applied. + + // Resolves inconsistencies introduced by this pass for fusions with + // tuple-type output. + Status ResolveInconsistentFusions(HloModule* module); + + // Converts the literals in kConstant HLOs which have their types changed to + // BF16 by this pass. + Status ResolveConvertedConstants(HloModule* module); + + // Skips no-op conversions (same source and target shapes) that can be + // produced by this pass, i.e., replaces them in their uses with their operands. + Status SkipNoopConversions(HloModule* module); // *************************** // Functions called and state used by two or more passes. @@ -146,6 +166,23 @@ class BFloat16Propagation : public HloPassInterface { bool AllUsersConsumeBF16(const HloInstruction& hlo, const ShapeIndex& index) const; + // The output element type of the HLO at the given shape index after changes + // in changes_to_bf16_ are applied. + PrimitiveType OutputTypeAfterChange(HloInstruction* hlo, + const ShapeIndex& index) const; + + // The element type of the HLO value after changes in changes_to_bf16_ are + // applied. + PrimitiveType ValueTypeAfterChange(const HloValue* value) const; + + // If target_type == BF16, adds the HLO at the given index to + // changes_to_bf16_; otherwise, target_type must be F32 and this function + // removes the HLO at the given index from changes_to_bf16_ if it was earlier + // added. + void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo, + const ShapeIndex& index, + PrimitiveType target_type); + // The set of F32 HLO values that must be kept in F32. tensorflow::gtl::FlatSet values_that_must_be_kept_as_f32_; @@ -153,10 +190,28 @@ class BFloat16Propagation : public HloPassInterface { // module. Populated at the beginning of this pass. tensorflow::gtl::FlatMap caller_counts_; + // We first store the potential F32-to-BF16 changes in changes_to_bf16_, which + // are subject to further adjustment and are then finally applied to the HLOs. + // This avoids setting changed_ to true when all the changes are later + // reverted during adjustment. + struct IndexHasher { + int64 operator()(const ShapeIndex& index) const { + int64 hash = 0; + for (int64 i : index) { + hash = tensorflow::Hash64Combine(hash, std::hash()(i)); + } + return hash; + } + }; + tensorflow::gtl::FlatMap> + changes_to_bf16_; + + // Whether the last processed HLO module has been changed by this pass. 
+ bool changed_ = false; + const BFloat16Support* bfloat16_support_; std::unique_ptr dataflow_; - - bool changed_ = false; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 88f83014164..5e1499ee6b6 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -149,12 +149,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) { EXPECT_TRUE(OutputsBF16(dot->operand(1))); EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant); EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( dot->operand(0)->literal(), - *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))); - LiteralTestUtil::ExpectEqual( + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)))); + EXPECT_TRUE(LiteralTestUtil::Equal( dot->operand(1)->literal(), - *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))); + *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)))); } // Tests that BF16 can be propagated through nested tuples. @@ -323,6 +323,37 @@ TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { EXPECT_TRUE(OutputsBF16(b_f1)); } +// Tests that changes to BF16 that cannot be propagated outside a fusion are +// discarded. +TEST_F(BFloat16PropagationTest, DiscardFusionInternalBF16Changes) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + + auto builder_f = HloComputation::Builder("fusion"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add_f = builder_f.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); + HloInstruction* dot_f = builder_f.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add_f, add_f)); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + dot_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, comp_f)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + EXPECT_EQ(computation->root_instruction(), fusion); +} + // Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion // outputs are only used by a dot, and 3) one element of the tuple is used by // an add in the fusion computation, then the propagation pass should create a @@ -426,8 +457,62 @@ TEST_F(BFloat16PropagationTest, SelectOverTuples) { EXPECT_TRUE(OutputsBF16(xpose)); } -// Tests that BF16 is propagated properly through while computations. -TEST_F(BFloat16PropagationTest, PropagateThroughWhile) { +// Tests that BF16 is propagated properly through a while computation with +// non-tuple input/output. 
+TEST_F(BFloat16PropagationTest, PropagateThroughSimpleWhile) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + + auto builder_cond = HloComputation::Builder("cond"); + auto cond_param = builder_cond.AddInstruction( + HloInstruction::CreateParameter(0, shape, "cond_param")); + auto cond_dot = builder_cond.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, cond_param, cond_param)); + auto cond_root = builder_cond.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_dot, {0, 0}, {1, 1}, {1, 1})), + builder_cond.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {}), cond_dot, {1, 1}, {2, 2}, {1, 1})))); + auto cond = module->AddEmbeddedComputation(builder_cond.Build()); + + auto builder_body = HloComputation::Builder("body"); + auto body_param = builder_body.AddInstruction( + HloInstruction::CreateParameter(0, shape, "body_param")); + auto body_dot = builder_body.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, body_param, body_param)); + auto body = module->AddEmbeddedComputation(builder_body.Build()); + + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(shape, cond, body, add)); + + auto dot = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, while_hlo, while_hlo)); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE( + ShapeUtil::Equal(cond_root->shape(), ShapeUtil::MakeShape(PRED, {}))); + EXPECT_TRUE(OutputsBF16(add)); + EXPECT_TRUE(OutputsBF16(body_dot)); + EXPECT_TRUE(OutputsBF16(body_param)); + EXPECT_TRUE(OutputsBF16(cond_param)); + EXPECT_FALSE(OutputsBF16(dot)); +} + +// Tests that BF16 is propagated properly through while computations with +// tuple-shaped input/output. +TEST_F(BFloat16PropagationTest, PropagateThroughTupleWhile) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index dbe45e932cd..c0b8bf90392 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -292,112 +293,6 @@ BufferAllocationProto BufferAllocation::ToProto() const { return proto; } -std::pair> -BufferAllocation::ComputePeakMemoryLogicalBuffers() const { - if (HeapTraces().empty()) { - // Just return the largest LogicalBuffer in the allocation. 
- const LogicalBuffer* largest_buffer = nullptr; - int64 largest_size = 0; - for (const auto& pair : assigned_buffers()) { - const LogicalBuffer* buffer = pair.first; - int64 size = pair.second.size; - if (largest_buffer == nullptr) { - largest_buffer = buffer; - largest_size = size; - continue; - } - // Tie-break with LogicalBuffer::Id so the return value is stable relative - // to changing addresses. - if (size > largest_size || - ((size == largest_size) && (largest_buffer->id() > buffer->id()))) { - largest_buffer = buffer; - largest_size = size; - } - } - CHECK(largest_buffer != nullptr) - << "No logical buffers in allocation: " << ToString(); - return {largest_size, {largest_buffer}}; - } - - // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical - // buffers in this allocation. - tensorflow::gtl::FlatMap - id_to_buffer; - tensorflow::gtl::FlatMap buffer_sizes; - for (const auto& pair : assigned_buffers()) { - const LogicalBuffer* buffer = pair.first; - const OffsetSize& offset_size = pair.second; - id_to_buffer[buffer->id()] = buffer; - buffer_sizes[buffer] = offset_size.size; - } - - // Returns how much the given event increases the total size of live - // buffers. Can be negative. - auto memory_delta = [this, &id_to_buffer, &buffer_sizes]( - const HeapSimulatorTrace::Event& event) -> int64 { - const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); - const int64 buffer_size = buffer_sizes.at(buffer); - if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { - return buffer_size; - } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { - // Sharing a buffer does not change the live set size for the purposes of - // the heap simulator. Even though the shared-with buffer may be smaller, - // the entire allocation remains live. - return 0; - } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { - return -1 * buffer_size; - } - LOG(FATAL) << "Unknown event kind: " << event.kind(); - }; - - int64 total_max_live_size = 0; - std::vector live_buffers_vector; - for (const HeapSimulatorTrace& heap_trace : HeapTraces()) { - // First compute the size of the maximal live set. - int64 max_live_size = 0; - int64 live_size = 0; - for (const auto& event : heap_trace.events()) { - live_size += memory_delta(event); - if (max_live_size < live_size) { - max_live_size = live_size; - } - } - - // Next gather the set of logical buffers live at the earliest point of - // maximal live set size. - tensorflow::gtl::FlatSet live_buffers; - live_size = 0; - for (const auto& event : heap_trace.events()) { - const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); - if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { - InsertOrDie(&live_buffers, buffer); - } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { - // Nothing to do. - } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { - CHECK(ContainsKey(live_buffers, buffer)); - live_buffers.erase(buffer); - } - - live_size += memory_delta(event); - if (live_size == max_live_size) { - break; - } - } - CHECK_EQ(live_size, max_live_size); - total_max_live_size += max_live_size; - - live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(), - live_buffers.end()); - } - - // Stabily sort the live buffers. 
- std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); - return {total_max_live_size, live_buffers_vector}; -} - string BufferAllocation::ToString() const { string output; Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); @@ -610,6 +505,7 @@ BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer, BufferAllocation* allocation = NewEmptyAllocation(size, is_thread_local, is_reusable, buffer.color()); AddAssignment(allocation, buffer, /*offset=*/0, size); + allocation->peak_buffers_.push_back(&buffer); return allocation; } @@ -680,6 +576,10 @@ void BufferAssignment::CombineTempAllocations() { CHECK_EQ(temp_allocation.HeapTraces().size(), 1); combined_allocation->AddHeapTrace(temp_allocation.HeapTraces().front()); } + combined_allocation->peak_buffers_.insert( + combined_allocation->peak_buffers_.end(), + temp_allocation.peak_buffers_.begin(), + temp_allocation.peak_buffers_.end()); } // Replace all existing temporary allocations with the new combined // allocations. @@ -800,7 +700,7 @@ BufferAssignmentProto BufferAssignment::ToProto() const { BufferAssignmentProto::BufferAlias* proto_alias = proto.add_buffer_aliases(); LogicalBufferProto::Location proto_alias_location = - LogicalBuffer::ToLocationProto(*alias.instruction(), alias.index()); + BufferValue::ToLocationProto(*alias.instruction(), alias.index()); proto_alias->set_source_buffer_id(buffer.id()); proto_alias->mutable_location()->Swap(&proto_alias_location); } @@ -1184,7 +1084,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; - options.buffers_to_assign = &single_colored_set.second; + BufferValueFlatSet buffer_value_set = + ToBufferValueFlatSet(single_colored_set.second); + options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( @@ -1212,7 +1114,9 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); HeapSimulator::Options options; - options.buffers_to_assign = &single_colored_set.second; + BufferValueFlatSet buffer_value_set = + ToBufferValueFlatSet(single_colored_set.second); + options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( @@ -1228,6 +1132,89 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( return Status::OK(); } +namespace { + +// Computes and returns the set of logical buffers live at the point of maximal +// liveness in the given heap trace. LogicalBuffers are (stabily) sorted by id. +std::vector ComputePeakMemoryLogicalBuffers( + const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) { + // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical + // buffers in this allocation. + tensorflow::gtl::FlatMap + id_to_buffer; + tensorflow::gtl::FlatMap buffer_sizes; + for (const auto& pair : allocation.assigned_buffers()) { + const LogicalBuffer* buffer = pair.first; + const BufferAllocation::OffsetSize& offset_size = pair.second; + id_to_buffer[buffer->id()] = buffer; + buffer_sizes[buffer] = offset_size.size; + } + + // Returns how much the given event increases the total size of live + // buffers. Can be negative. 
+ auto memory_delta = [&id_to_buffer, &buffer_sizes]( + const HeapSimulatorTrace::Event& event) -> int64 { + const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); + const int64 buffer_size = buffer_sizes.at(buffer); + if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { + return buffer_size; + } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + // Sharing a buffer does not change the live set size for the purposes of + // the heap simulator. Even though the shared-with buffer may be smaller, + // the entire allocation remains live. + return 0; + } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { + return -1 * buffer_size; + } + LOG(FATAL) << "Unknown event kind: " << event.kind(); + }; + + // First compute the size of the maximal live set. + int64 max_live_size = 0; + int64 live_size = 0; + for (const auto& event : heap_trace.events()) { + live_size += memory_delta(event); + if (max_live_size < live_size) { + max_live_size = live_size; + } + } + + // Next gather the set of logical buffers live at the earliest point of + // maximal live set size. + tensorflow::gtl::FlatSet live_buffers; + live_size = 0; + for (const auto& event : heap_trace.events()) { + const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); + if (event.kind() == HeapSimulatorTrace::Event::ALLOC) { + InsertOrDie(&live_buffers, buffer); + } else if (event.kind() == HeapSimulatorTrace::Event::SHARE_WITH) { + // Nothing to do. + } else if (event.kind() == HeapSimulatorTrace::Event::FREE) { + CHECK(ContainsKey(live_buffers, buffer)); + live_buffers.erase(buffer); + } + + live_size += memory_delta(event); + if (live_size == max_live_size) { + break; + } + } + CHECK_EQ(live_size, max_live_size); + + std::vector live_buffers_vector; + live_buffers_vector.insert(live_buffers_vector.end(), live_buffers.begin(), + live_buffers.end()); + + // Stabily sort the live buffers. + std::sort(live_buffers_vector.begin(), live_buffers_vector.end(), + [](const LogicalBuffer* a, const LogicalBuffer* b) { + return a->id() < b->id(); + }); + return live_buffers_vector; +} + +} // namespace + void BufferAssigner::AssignBuffersFromHeapSimulator( const HeapSimulator::Result& result, BufferAssignment* assignment, LogicalBuffer::Color color) { @@ -1242,10 +1229,15 @@ void BufferAssigner::AssignBuffersFromHeapSimulator( BufferAllocation* allocation = assignment->NewEmptyAllocation( result.heap_size, /*is_thread_local=*/false, /*is_reusable=*/true, color); for (const auto& buffer_chunk : result.chunk_map) { - const LogicalBuffer& buffer = *buffer_chunk.first; + // TODO(lauj) Remove this down_cast after downstream users of + // BufferAllocation::assigned_buffers() are updated to use BufferValue. 
+ const LogicalBuffer& buffer = + *CHECK_NOTNULL(dynamic_cast(buffer_chunk.first)); const HeapSimulator::Chunk& chunk = buffer_chunk.second; assignment->AddAssignment(allocation, buffer, chunk.offset, chunk.size); } + allocation->peak_buffers_ = + ComputePeakMemoryLogicalBuffers(*allocation, result.debug_trace); VLOG(1) << "Ran heap simulation for allocation: " << allocation->ToString(); allocation->AddHeapTrace(result.debug_trace); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 3086d0e2ca0..ad0b0bf7c25 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -206,17 +206,15 @@ class BufferAllocation { return heap_traces_; } - // Compute and return the LogicalBuffers which are live at the point of peak - // memory usage for the given allocation. The point of peak memory usage is - // the point at which the total size of all live logical buffers is - // maximal. If peak memory is reached at multiple points, the set of logical - // buffers live at the earliest maximal point is returned. The vector is - // stabily asserted by LogicalBuffer::Index. - // - // The return value is a pair of total size of the logical buffers at peak, - // and the buffers themselves. - std::pair> - ComputePeakMemoryLogicalBuffers() const; + // Returns the LogicalBuffers which are live at the point of peak memory usage + // for this allocation. The point of peak memory usage is the point at which + // the total size of all live logical buffers is maximal. If peak memory is + // reached at multiple points, the set of logical buffers live at the earliest + // maximal point is returned. The vector is stabily sorted by + // LogicalBuffer::Index. + const std::vector& PeakMemoryLogicalBuffers() const { + return peak_buffers_; + } // Get the number of bytes lost to fragmentation. This is equal to the // difference between the size of the allocation and the size of the maximal @@ -291,6 +289,9 @@ class BufferAllocation { int64 fragmentation_bytes_ = 0; std::vector heap_traces_; + + // Set of buffers live at the point of peak memory usage for this allocation. + std::vector peak_buffers_; }; // Add stream operators for nicer output of CHECK/RET_CHECK failures. @@ -414,10 +415,10 @@ class BufferAssignment { // Only BufferAssigner can build or modify BufferAssignments. friend class BufferAssigner; - explicit BufferAssignment(const HloModule* module, - std::unique_ptr liveness, - LogicalBuffer::SizeFunction buffer_size, - LogicalBuffer::AlignmentFunction color_alignment) + BufferAssignment(const HloModule* module, + std::unique_ptr liveness, + LogicalBuffer::SizeFunction buffer_size, + LogicalBuffer::AlignmentFunction color_alignment) : module_(module), liveness_(std::move(liveness)), buffer_size_(std::move(buffer_size)), diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 513a8785bbd..a4fb0eefaca 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" @@ -1519,12 +1520,8 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { // single logical buffer should be exactly the logical buffer in that // allocation. const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul); - int64 peak_size; - std::vector peak_buffers; - - std::tie(peak_size, peak_buffers) = - mul_buffer.ComputePeakMemoryLogicalBuffers(); - EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(f32vec100_)); + const std::vector& peak_buffers = + mul_buffer.PeakMemoryLogicalBuffers(); ASSERT_EQ(peak_buffers.size(), 1); EXPECT_EQ(peak_buffers[0]->instruction(), mul); } @@ -1555,6 +1552,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { HloInstruction::CreateConcatenate(concat_shape, {rev, neg}, 0)); // Make the root tiny so no interior nodes can share its buffer. auto root = builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1}), concat, {0}, {1}, {1})); auto module = CreateNewModule(); @@ -1569,12 +1567,10 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { EXPECT_TRUE(buffer.IsPreallocatedTempBuffer()); ASSERT_EQ(buffer.assigned_buffers().size(), 4); - int64 peak_size; - std::vector peak_buffers; - std::tie(peak_size, peak_buffers) = buffer.ComputePeakMemoryLogicalBuffers(); + const std::vector& peak_buffers = + buffer.PeakMemoryLogicalBuffers(); // The peak live set should be concat and its inputs. - EXPECT_EQ(peak_size, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {400}))); ASSERT_EQ(peak_buffers.size(), 3); std::vector peak_instructions; for (const LogicalBuffer* logical_buffer : peak_buffers) { @@ -1583,6 +1579,69 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat)); } +TEST_F(BufferAssignmentTest, PeakBuffersWhile) { + auto module = CreateNewModule(); + const Shape shape = ShapeUtil::MakeShape(F32, {123, 123}); + HloComputation* condition; + { + auto b = HloComputation::Builder(TestName() + ".cond"); + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + condition = module->AddEmbeddedComputation(b.Build()); + } + HloComputation* body; + { + auto b = HloComputation::Builder(TestName() + ".body"); + auto param = + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + b.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param)); + body = module->AddEmbeddedComputation(b.Build()); + } + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param)); + auto while_op = builder.AddInstruction( + HloInstruction::CreateWhile(shape, condition, body, copy)); + // This broadcast should get a temporary allocation which is merged with the + // allocation for the while. Peak buffers should include the while and the + // broadcast. 
+ auto bcast = builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {123, 123, 123}), while_op, {0, 1})); + builder.AddInstruction(HloInstruction::CreateReverse( + ShapeUtil::MakeShape(F32, {123, 123, 123}), bcast, {0})); + module->AddEntryComputation(builder.Build()); + + auto buffers = RunBufferAssignment(module.get()); + const BufferAllocation& buffer = GetTopLevelAllocation(*buffers, bcast); + const std::vector& peak_buffers = + buffer.PeakMemoryLogicalBuffers(); + ASSERT_EQ(peak_buffers.size(), 2); + + // The peak buffers should include the broadcast and one of the colocated + // buffers of the while (body param, condition param, body root, or the while + // itself). + const LogicalBuffer* bcast_buffer; + const LogicalBuffer* nonbcast_buffer; + if (peak_buffers[0]->instruction() == bcast) { + bcast_buffer = peak_buffers[0]; + nonbcast_buffer = peak_buffers[1]; + } else { + bcast_buffer = peak_buffers[1]; + nonbcast_buffer = peak_buffers[0]; + } + EXPECT_EQ(bcast_buffer->instruction(), bcast); + EXPECT_TRUE( + nonbcast_buffer->instruction() == copy || + nonbcast_buffer->instruction() == while_op || + nonbcast_buffer->instruction() == body->parameter_instruction(0) || + nonbcast_buffer->instruction() == body->root_instruction() || + nonbcast_buffer->instruction() == condition->parameter_instruction(0)); +} + class WhileBufferAssignmentTest : public HloTestBase { protected: std::unique_ptr BuildWhileConditionComputation( @@ -1626,7 +1685,7 @@ class WhileBufferAssignmentTest : public HloTestBase { .ConsumeValueOrDie(); } - static int64 ByteSizeOf(const LogicalBuffer& buffer) { + static int64 ByteSizeOf(const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), sizeof(void*)); } @@ -1641,7 +1700,7 @@ static void RunCopyInsertion(HloModule* module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1816,7 +1875,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { }; // Build the entry computation as described in the comment above. 
- auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, "")); @@ -1884,7 +1943,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1929,7 +1988,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -1994,7 +2053,7 @@ static bool IsPostOrderTraversal( } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -2073,7 +2132,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { } TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = xla::MakeUnique(TestName()); + auto module = CreateNewModule(); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 37982aaef9e..810d597e730 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,7 +43,7 @@ StatusOr> BufferLiveness::Run( return std::move(liveness); } -tensorflow::Status BufferLiveness::Analyze() { +Status BufferLiveness::Analyze() { TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); for (auto* computation : module_->computations()) { if (computation->IsFusionComputation()) { @@ -71,7 +70,7 @@ tensorflow::Status BufferLiveness::Analyze() { } XLA_VLOG_LINES(3, ToString()); - return tensorflow::Status::OK(); + return Status::OK(); } string BufferLiveness::ToString() const { @@ -105,8 +104,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { // Every user of 'a' must be a predecessor of 'b' or 'b' itself. for (auto user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user, - points_to_analysis())) { + if (points_to_analysis().DoesNotUseOperandBuffer(alias.instruction(), + alias.index(), user)) { continue; } if (user != b.instruction() && @@ -132,9 +131,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, // the qualifications specified in CanShareOperandBufferWithUser. 
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { if (b.instruction()->IsUserOf(alias.instruction()) && - !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), - b.instruction(), b.index(), - points_to_analysis())) { + !points_to_analysis().CanShareOperandBufferWithUser( + alias.instruction(), alias.index(), b.instruction(), b.index())) { return false; } } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index 11834a5127e..cdd3cf4032e 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -89,7 +89,7 @@ class BufferLiveness { // Perform buffer liveness analysis. This method must be called prior to // MayInterfere or MaybeLiveOut. - tensorflow::Status Analyze(); + Status Analyze(); // Returns true if the live range of the buffer of 'a' is strictly before the // live range of the buffer of 'b' (they do not overlap). diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc new file mode 100644 index 00000000000..2bc556a9e27 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/buffer_value.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +BufferValue::BufferValue(HloInstruction* instruction, const ShapeIndex& index, + Id id) + : id_(id) { + const Shape& shape = ShapeUtil::GetSubshape(instruction->shape(), index); + is_array_ = ShapeUtil::IsArray(shape); + is_tuple_ = ShapeUtil::IsTuple(shape); +} + +BufferValue::~BufferValue() {} + +std::ostream& operator<<(std::ostream& out, const BufferValue& buffer) { + out << buffer.ToString(); + return out; +} + +/*static*/ LogicalBufferProto::Location BufferValue::ToLocationProto( + const HloInstruction& instruction, const ShapeIndex& index) { + LogicalBufferProto::Location proto; + proto.set_computation_name(instruction.parent()->name()); + proto.set_instruction_name(instruction.name()); + for (const int64 index_entry : index) { + proto.add_shape_index(index_entry); + } + return proto; +} + +LogicalBufferProto BufferValue::ToProto(const SizeFunction& size_fn) const { + LogicalBufferProto proto; + proto.set_id(id()); + proto.set_size(size_fn(*this)); + LogicalBufferProto::Location proto_location = + ToLocationProto(*instruction(), index()); + proto.mutable_defined_at()->Swap(&proto_location); + if (has_color()) { + proto.set_color(color().value()); + } + return proto; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_value.h b/tensorflow/compiler/xla/service/buffer_value.h new file mode 100644 index 00000000000..f4be16e0843 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value.h @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/int_type.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Abstract class describing a value used by one of the dataflow analyses - +// TuplePointsToAnalysis or HloDataflowAnalysis. +// TODO(b/78906445) Delete this class when TuplePointsToAnalysis is unused. +// +// XLA arrays are trivially a single BufferValue. 
Tuples are made up of more +// than one BufferValue: a BufferValue for the pointer vector, and a +// BufferValue for each child element. +// +// Every BufferValue is defined by a particular instruction and most +// instructions define only a single BufferValue. Instructions which define a +// single BufferValue include array-shaped instructions such as Add but also +// include Tuple-shaped instructions such as Tuple. The Tuple instruction +// defines a single BufferValue which is a vector of pointers to the values +// containing the Tuple instruction's operands. Though the result of the Tuple +// instruction includes multiple values, only the top-level BufferValue (the +// vector of pointers) is defined by the Tuple instruction. The values +// containing the tuple elements are defined by earlier instructions, usually +// the operands of the Tuple instruction. +// +// Instructions which construct both the tuple *and* the tuple elements define +// more than one BufferValue. This includes (at least) tuple-shaped Constant, +// Parameter, Infeed and While instructions. These tuple-shaped instructions do +// not assemble a tuple from existing BufferValues like the Tuple instruction +// does, but rather define all the BufferValues in the tuple. +// +// Some instructions, such as Bitcast, define no buffers. These instructions +// simply forward buffers from their operands. +// +// The BufferValue object describes which HLO instruction defines a buffer and +// where within that instruction's output shape the buffer is defined. The +// location within the output shape is indicated by BufferValue::index() which +// is defined identically to the index used in ShapeUtil::GetSubshape(). +// Examples: +// +// %add = Add(%foo, %bar) +// %tuple_constant = Constant({1, {42, 43}}) +// +// %add defines a single array-shaped buffer BufferValue(%add, {}) which holds +// the array result of the add operation. The nested-tuple-shaped +// %tuple_constant defines 5 buffers described by the following BufferValue +// objects: +// +// BufferValue(%tuple_constant, {}) // "Top-level" buffer: vector of +// // pointers to BufferValues at +// // indices {0} and {1} +// BufferValue(%tuple_constant, {0}) // Holds value "1" +// BufferValue(%tuple_constant, {1}) // Holds nested tuple: vector of +// // pointers to BufferValues at +// // indices {1, 0} and {1, 1} +// BufferValue(%tuple_constant, {1, 0}) // Holds value "42" +// BufferValue(%tuple_constant, {1, 1}) // Holds value "43" + +class BufferValue { + public: + TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64); + + // Id is a unique identifier for the BufferValue to facilitate efficient + // collections of BufferValues with stable iteration order. + using Id = int64; + + // Functions which return the size and alignment of a logical buffer in bytes. + using SizeFunction = std::function; + using AlignmentFunction = std::function; + + virtual ~BufferValue(); + + Id id() const { return id_; } + + // Return the instruction that defines the buffer. + virtual HloInstruction* instruction() const = 0; + + // Return the index within the output of the instruction where the buffer is + // defined. The index is defined as in ShapeUtil::GetSubshape(). + virtual const ShapeIndex& index() const = 0; + + // Return the color of the BufferValue. Differently colored buffers cannot be + // parts of the same allocation.
+ Color color() const { + CHECK_NE(color_, kInvalidColor) + << "Should not query the color of a buffer that was never colored"; + return color_; + } + + void set_color(Color color) { + CHECK_NE(color, kInvalidColor) + << "Should not set the color of a buffer to the invalid color"; + color_ = color; + } + + bool has_color() const { return color_ != kInvalidColor; } + + // Return the shape of the buffer. This reference points into the shape field + // of the instruction defining the buffer. Therefore, the returned shape will + // contain the layout of the instruction, if any. + virtual const Shape& shape() const = 0; + + // Returns true if this buffer is the top-level output buffer of the defining + // HLO instruction. This is equivalent to index == {}. + bool IsTopLevel() const { return index().empty(); } + + // Whether this buffer contains a tuple. + bool IsTuple() const { return is_tuple_; } + + // Whether this buffer contains an array. + bool IsArray() const { return is_array_; } + + // operator< is required for std::set. + bool operator<(const BufferValue& other) const { return id_ < other.id_; } + + virtual string ToString() const = 0; + + // TODO(lauj) rename LogicalBufferProto to BufferValueProto. + LogicalBufferProto ToProto(const SizeFunction& size_fn) const; + + // Returns the LogicalBufferProto::Location that serializes the given + // instruction and index. + static LogicalBufferProto::Location ToLocationProto( + const HloInstruction& instruction, const ShapeIndex& index); + + const Color kInvalidColor = Color(-1); + + protected: + BufferValue(HloInstruction* instruction, const ShapeIndex& index, Id id); + + private: + // The defining instruction and index are not stored here; they can be found + // in the LogicalBuffer and HloValue subclasses. This class exists only to + // support migrations from TuplePointsToAnalysis to HloDataflowAnalysis, by + // allowing abstract use of LogicalBuffer or HloValue. After those migrations + // are complete, this class should be deleted (b/78906445). Because we plan to + // delete LogicalBuffer and this class, we don't refactor all the shared + // features from LogicalBuffer and HloValue into this class. + Id id_ : 62; + bool is_array_ : 1; + bool is_tuple_ : 1; + Color color_ = kInvalidColor; +}; + +std::ostream& operator<<(std::ostream& out, const BufferValue& buffer); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_H_ diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h new file mode 100644 index 00000000000..305914fca82 --- /dev/null +++ b/tensorflow/compiler/xla/service/buffer_value_containers.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ + +#include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/core/lib/gtl/compactptrset.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Define various containers of BufferValues, and utilities to convert from +// containers of LogicalBuffers to containers of BufferValues. + +using BufferValueCompactPointerSet = + tensorflow::gtl::CompactPointerSet; +template +BufferValueCompactPointerSet ToBufferValueCompactPointerSet( + const LogicalBufferContainerT& logical_buffer_container) { + BufferValueCompactPointerSet output; + for (const LogicalBuffer* buffer : logical_buffer_container) { + output.insert(buffer); + } + return output; +} + +using BufferValueFlatSet = tensorflow::gtl::FlatSet; +template +BufferValueFlatSet ToBufferValueFlatSet( + const LogicalBufferContainerT& logical_buffer_container) { + BufferValueFlatSet output; + output.reserve(logical_buffer_container.size()); + for (const LogicalBuffer* buffer : logical_buffer_container) { + output.insert(buffer); + } + return output; +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index c83da9eddc8..d39fd7307ae 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -37,7 +37,7 @@ limitations under the License. namespace xla { /* static */ StatusOr> -CompileOnlyService::NewService(perftools::gputools::Platform* platform) { +CompileOnlyService::NewService(se::Platform* platform) { ServiceOptions default_options; default_options.set_platform(platform); return NewService(default_options); @@ -45,7 +45,7 @@ CompileOnlyService::NewService(perftools::gputools::Platform* platform) { /* static */ StatusOr> CompileOnlyService::NewService(const ServiceOptions& options) { - perftools::gputools::Platform* platform = options.platform(); + se::Platform* platform = options.platform(); if (platform == nullptr) { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } @@ -61,6 +61,49 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, Compiler* compiler) : Service(options, /*execute_backend=*/nullptr), compiler_(compiler) {} +StatusOr>> +CompileOnlyService::CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options) { + std::vector> hlo_modules; + for (const AotXlaComputationInstance& instance : computations) { + TF_RET_CHECK(instance.computation.has_program_shape()); + + const DebugOptions& debug_options = options.debug_options(); + + // Dump computation proto if flag is set. 
+ const string& directory_path = debug_options.xla_dump_computations_to(); + if (!directory_path.empty()) { + HloSnapshot hlo_snapshot; + *hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation; + string filename = tensorflow::strings::StrCat( + "computation_", instance.computation.id(), "__", + instance.computation.entry_computation_name()); + const string& per_host_path = tensorflow::io::JoinPath( + directory_path, tensorflow::port::Hostname()); + + TF_RETURN_IF_ERROR( + Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); + } + + const auto& program_shape = instance.computation.program_shape(); + ExecutionOptions execution_options; + *execution_options.mutable_debug_options() = debug_options; + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(program_shape, instance.argument_layouts, + &execution_options)); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr hlo_module, + HloModule::CreateFromProto(instance.computation, *module_config)); + TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); + hlo_modules.push_back(std::move(hlo_module)); + } + + return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); +} + StatusOr>> CompileOnlyService::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index 9859941c6c1..7f2ce0e8974 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -34,7 +34,7 @@ class CompileOnlyService : public Service { // platform that the service should target. If platform is null then the // default platform is used. static StatusOr> NewService( - perftools::gputools::Platform* platform); + se::Platform* platform); static StatusOr> NewService( const ServiceOptions& options); @@ -53,51 +53,64 @@ class CompileOnlyService : public Service { const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& Options); + // A description of a xla computation to compile using CompileAheadOfTime. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + struct AotXlaComputationInstance { + HloModuleProto computation; + std::vector argument_layouts; + const Shape* result_layout = nullptr; + }; + + // Compiles a list of xla computations for ahead-of-time execution. This is + // intended for use in static compilation. See + // |CompileOnlyClient::CompileAheadOfTime| for additional details. + // + // TODO(b/74197823): This is a part of a NOT YET ready refactor. + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options); + // Override Service methods that require or imply the existence of an // execute backend. Note that this does not include TransferToClient, as // computing constants produces global data that we may wish to transfer. 
- tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override { + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override { + Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override { + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } - tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override { + Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override { + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); } - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override { + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override { + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override { + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override { return Unimplemented( "CompileOnlyService does not support device data transfers."); } - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override { + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 0392d4af48a..8b01a6c4b50 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -23,26 +23,21 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" -namespace se = ::perftools::gputools; - namespace xla { /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( tensorflow::LINKER_INITIALIZED); -/* static */ std::map* +/* static */ std::map* Compiler::GetPlatformCompilerFactories() { - static auto* r = - new std::map; + static auto* r = new std::map; return r; } /* static */ -std::map>* +std::map>* Compiler::GetPlatformCompilers() { - static auto* r = new std::map>; + static auto* r = new std::map>; return r; } diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index b4b53ae2ed4..a4b59d1ba9b 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" @@ -70,7 +71,7 @@ class AotCompilationOptions { virtual ~AotCompilationOptions() = default; // Returns the ID of the platform to which these options apply. - virtual perftools::gputools::Platform::Id PlatformId() const = 0; + virtual se::Platform::Id PlatformId() const = 0; // Optional allocator that may be used for allocating temp space on the device // during compilation. @@ -109,7 +110,7 @@ class Compiler { virtual ~Compiler() {} // Returns the ID of the platform that this compiler targets. - virtual perftools::gputools::Platform::Id PlatformId() const = 0; + virtual se::Platform::Id PlatformId() const = 0; // Runs Hlo passes to optimize the given Hlo module, returns the optimized // module. @@ -120,8 +121,7 @@ class Compiler { // algorithm over those buffers, to see which variant is fastest. Any space // allocated should be deallocated before this function returns. virtual StatusOr> RunHloPasses( - std::unique_ptr module, - perftools::gputools::StreamExecutor* executor, + std::unique_ptr module, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) = 0; // Compiles the HLO module for execution on a device given by the executor, @@ -137,8 +137,7 @@ class Compiler { // // Use the overload below to compile computations that run in parallel. virtual StatusOr> RunBackend( - std::unique_ptr module, - perftools::gputools::StreamExecutor* executor, + std::unique_ptr module, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) = 0; // Compiles a set of HLO modules that can run in parallel, potentially @@ -151,8 +150,7 @@ class Compiler { // modules to RunHloPasses and RunBackends. virtual StatusOr>> Compile( std::vector> modules, - std::vector> - stream_exec, + std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) = 0; // Compiles the HLO module for ahead-of-time execution. This is intended for @@ -171,14 +169,12 @@ class Compiler { // be a singleton, so no ownership is transferred. // // Precondition: a platform kind must not be registered more than once. - static void RegisterCompilerFactory( - perftools::gputools::Platform::Id platform_id, - CompilerFactory compiler_factory); + static void RegisterCompilerFactory(se::Platform::Id platform_id, + CompilerFactory compiler_factory); // Returns the compiler singleton pointer if it is available for the given // platform, or an error status if it is not. 
- static StatusOr GetForPlatform( - const perftools::gputools::Platform* platform); + static StatusOr GetForPlatform(const se::Platform* platform); // Returns a function that computes the size in bytes of the logical // buffer that contains a shape. @@ -186,9 +182,9 @@ class Compiler { // Returns a function that computes the size in bytes of a given // logical buffer. - std::function BufferSizeBytesFunction() { + std::function BufferSizeBytesFunction() { HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction(); - return [shape_size](const LogicalBuffer& buffer) { + return [shape_size](const BufferValue& buffer) { return shape_size(buffer.shape()); }; } @@ -198,12 +194,12 @@ class Compiler { static tensorflow::mutex platform_compiler_mutex_; // Map from platform kind to compiler factory. - static std::map* + static std::map* GetPlatformCompilerFactories(); // Map from platform kind to compiler instance, if we made one already (based // on the factories above). - static std::map>* + static std::map>* GetPlatformCompilers(); }; diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index d2d4f14fcec..cb61f3da39f 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -23,12 +23,15 @@ limitations under the License. namespace xla { -ComputationLayout::ComputationLayout(const ProgramShape& program_shape) +ComputationLayout::ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts) : result_layout_(program_shape.result()) { for (auto& shape : program_shape.parameters()) { parameter_layouts_.emplace_back(shape); } - SetToDefaultLayout(); + if (ignore_layouts) { + SetToDefaultLayout(); + } } void ComputationLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 80e102411c7..53c3a3f7b73 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -34,8 +34,9 @@ class ComputationLayout { public: // Constructs a ComputationLayout from a ProgramShape. The layouts of the // parameters and results are set to the default layout. Layouts in the - // ProgramShape are ignored. - explicit ComputationLayout(const ProgramShape& program_shape); + // ProgramShape are ignored if ignore_layouts is true. + explicit ComputationLayout(const ProgramShape& program_shape, + bool ignore_layouts = true); // Returns the layout of a particular parameter. const ShapeLayout& parameter_layout(int64 param_no) const { diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index 657fba6b623..7c1bacff92b 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -32,8 +32,6 @@ limitations under the License. 
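// Illustrative sketch only, not part of the patch above: the effect of the new
// ignore_layouts flag on ComputationLayout. `program_shape` is assumed to be a
// ProgramShape whose parameter and result shapes already carry layouts.
void ComputationLayoutExample(const ProgramShape& program_shape) {
  // Default behavior (ignore_layouts = true): layouts in the ProgramShape are
  // discarded and every parameter/result is reset to the default layout.
  ComputationLayout defaulted(program_shape);

  // New behavior: keep the layouts the ProgramShape already specifies.
  ComputationLayout preserved(program_shape, /*ignore_layouts=*/false);
}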
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { @@ -132,11 +130,9 @@ StatusOr ComputationPlacer::AssignDevices( ComputationPlacer::platform_computation_placer_mutex_( tensorflow::LINKER_INITIALIZED); -/* static */ std::map* +/* static */ std::map* ComputationPlacer::GetPlatformComputationPlacers() { - static auto* r = - new std::map; + static auto* r = new std::map; return r; } @@ -147,10 +143,10 @@ static std::unique_ptr CreateComputationPlacer() { } static bool InitModule() { - xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId, - &CreateComputationPlacer); - xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId, - &CreateComputationPlacer); + xla::ComputationPlacer::RegisterComputationPlacer( + stream_executor::host::kHostPlatformId, &CreateComputationPlacer); + xla::ComputationPlacer::RegisterComputationPlacer( + stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 737ccabaa7a..737d00e93ec 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -80,13 +80,13 @@ class ComputationPlacer { // Registers a computation placer creation function for a particular platform. static void RegisterComputationPlacer( - perftools::gputools::Platform::Id platform_id, + se::Platform::Id platform_id, ComputationPlacerCreationFunction creation_function); // Returns the computation placer singleton pointer if it is available for the // given platform, or an error status if it is not. static StatusOr GetForPlatform( - const perftools::gputools::Platform* platform); + const se::Platform* platform); private: // The mutex that guards the platform-to-computation placer map. @@ -101,10 +101,9 @@ class ComputationPlacer { }; // Map from platform kind to computation placer singleton. - static std::map* - GetPlatformComputationPlacers(); + static std::map* GetPlatformComputationPlacers(); - perftools::gputools::Platform::Id platform_id_; + se::Platform::Id platform_id_; TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer); }; diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index f35de080853..e9ec796121f 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -35,7 +35,7 @@ namespace xla { // Tries to replace a conditional with a call operation of the corresponding // computation. If the given conditional has a constant predicate, tries to -// replace it with a call to its true/false computation as appropirate and then +// replace it with a call to its true/false computation as appropriate and then // inline that computation. // // Returns true if it made a change to the graph. 
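// Illustrative sketch only, not part of the patch above: running the pass
// described by the comment. A conditional whose predicate is a constant,
// e.g. `conditional(constant(true), x, y)`, is rewritten into a call of its
// true computation on `x`, and that call is then inlined. `module` is assumed
// to be an already-built HloModule.
Status SimplifyConditionals(HloModule* module) {
  ConditionalSimplifier simplifier;
  // Run() returns true iff at least one conditional was replaced.
  TF_ASSIGN_OR_RETURN(bool changed, simplifier.Run(module));
  VLOG(1) << "conditional-simplifier changed module: " << changed;
  return Status::OK();
}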
@@ -69,7 +69,7 @@ static StatusOr TryRemoveConditional(HloInstruction* conditional) { conditional->shape(), {conditional->mutable_operand(2)}, conditional->false_computation())); } - + conditional->SetupDerivedInstruction(call_op); TF_RETURN_IF_ERROR(computation->ReplaceInstruction(conditional, call_op)); TF_RETURN_IF_ERROR(CallInliner::Inline(call_op).status()); diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 40519ecc799..33d8338809d 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -65,7 +64,7 @@ struct SpecialCaseCopyPolicy { // output tuple. bool copy_root_replicated_buffers = false; // If true, insert a copy if a buffer coming from a constant or a parameter - // is found wihtin the output tuple. + // is found within the output tuple. bool copy_parameters_and_constants = false; }; diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 246b8028618..bfd85f257fb 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -89,12 +89,10 @@ cc_library( ":cpu_instruction_fusion", ":cpu_layout_assignment", ":cpu_options", - ":cpu_parallelization_preparation", ":disassembler", ":dot_op_emitter", ":ir_emission_utils", ":ir_emitter", - ":parallel_cpu_executable", ":parallel_task_assignment", ":simple_orc_jit", "//tensorflow/compiler/xla:literal_util", @@ -105,6 +103,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:batch_dot_simplification", "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", @@ -127,12 +126,14 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_scheduling", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:indexed_array_analysis", "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_constant_sinking", "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", @@ -171,11 +172,13 @@ cc_library( ":orc_jit_memory_mapper", ":runtime_fp16", ":runtime_conv2d", + ":runtime_conv2d_mkl", ":runtime_fft", ":runtime_fork_join", ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", + ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", "@llvm//:execution_engine", "@llvm//:core", @@ 
-232,35 +235,6 @@ cc_library( ], ) -cc_library( - name = "parallel_cpu_executable", - srcs = ["parallel_cpu_executable.cc"], - hdrs = [ - "parallel_cpu_executable.h", - ], - deps = [ - ":cpu_runtime", - ":shape_partition", - ":simple_orc_jit", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:buffer_assignment", - "//tensorflow/compiler/xla/service:device_memory_allocator", - "//tensorflow/compiler/xla/service:executable", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_execution_profile", - "//tensorflow/compiler/xla/service:logical_buffer", - "//tensorflow/compiler/xla/service:shaped_buffer", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", - "@llvm//:orc_jit", - ], -) - cc_library( name = "ir_emitter", srcs = [ @@ -324,6 +298,15 @@ cc_library( ], ) +cc_library( + name = "target_machine_features_fake", + testonly = 1, + hdrs = ["target_machine_features_fake.h"], + deps = [ + ":target_machine_features", + ], +) + cc_library( name = "ir_function", srcs = ["ir_function.cc"], @@ -365,6 +348,7 @@ cc_library( deps = [ ":cpu_options", ":cpu_runtime", + ":ir_emission_utils", ":target_machine_features", ":vector_support_library", "//tensorflow/compiler/xla:shape_util", @@ -394,10 +378,10 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/core:lib", ], ) @@ -437,7 +421,6 @@ cc_library( "//tensorflow/core:lib", "@llvm//:analysis", "@llvm//:core", - "@llvm//:execution_engine", "@llvm//:ipo", "@llvm//:mc", "@llvm//:object", @@ -501,6 +484,27 @@ cc_library( ], ) +cc_library( + name = "runtime_conv2d_mkl", + srcs = [ + "runtime_conv2d_mkl.cc", + ], + hdrs = ["runtime_conv2d_mkl.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + ":runtime_conv2d", + ":runtime_single_threaded_conv2d", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + "//tensorflow/core/kernels:eigen_helpers", + "//third_party/eigen3", + ] + if_mkl([ + "@mkl_dnn", + "//third_party/mkl:intel_binary_blob", + ]), +) + cc_library( name = "runtime_fft", srcs = [ @@ -513,7 +517,6 @@ cc_library( deps = [ "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:framework", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], @@ -575,6 +578,22 @@ cc_library( ], ) +cc_library( + name = "runtime_single_threaded_fft", + srcs = [ + "runtime_fft_impl.h", + "runtime_single_threaded_fft.cc", + ], + hdrs = ["runtime_single_threaded_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_single_threaded_matmul", srcs = ["runtime_single_threaded_matmul.cc"], @@ -633,6 +652,7 @@ tf_cc_test( 
"//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -661,31 +681,13 @@ cc_library( ], ) -cc_library( - name = "cpu_parallelization_preparation", - srcs = ["cpu_parallelization_preparation.cc"], - hdrs = [ - "cpu_parallelization_preparation.h", - ], - deps = [ - ":ir_emission_utils", - ":parallel_task_assignment", - ":shape_partition", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_cost_analysis", - "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/core:lib", - ], -) - cc_library( name = "ir_emission_utils", srcs = ["ir_emission_utils.cc"], hdrs = ["ir_emission_utils.h"], deps = [ ":cpu_runtime", + ":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", @@ -698,6 +700,7 @@ tf_cc_test( srcs = ["ir_emission_utils_test.cc"], deps = [ ":ir_emission_utils", + ":target_machine_features_fake", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", @@ -716,6 +719,7 @@ cc_library( deps = [ ":dot_op_emitter", ":ir_emission_utils", + ":target_machine_features", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:layout_assignment", @@ -729,6 +733,7 @@ tf_cc_test( srcs = ["cpu_layout_assignment_test.cc"], deps = [ ":cpu_layout_assignment", + ":target_machine_features_fake", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -753,6 +758,7 @@ cc_library( deps = [ ":cpu_runtime", ":ir_emission_utils", + ":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", @@ -767,6 +773,7 @@ tf_cc_test( srcs = ["conv_canonicalization_test.cc"], deps = [ ":conv_canonicalization", + ":target_machine_features_fake", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", @@ -805,6 +812,7 @@ cc_library( ":dot_op_emitter", ":ir_emission_utils", ":shape_partition", + ":target_machine_features", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", @@ -817,6 +825,7 @@ tf_cc_test( deps = [ ":cpu_executable", ":parallel_task_assignment", + ":target_machine_features_fake", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -939,3 +948,17 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +tf_cc_test( + name = "cpu_eigen_tensor_alignment_test", + size = "small", + srcs = ["cpu_eigen_tensor_alignment_test.cc"], + deps = [ + ":dot_op_emitter", + ":ir_emission_utils", + ":target_machine_features_fake", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 61b2da7a7dc..6a7eb85e3ba 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ 
b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -25,11 +25,11 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/ExecutionEngine/ObjectMemoryBuffer.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/MC/MCContext.h" #include "llvm/Object/ObjectFile.h" +#include "llvm/Support/SmallVectorMemoryBuffer.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO.h" @@ -158,7 +158,7 @@ std::unique_ptr CompilerFunctor::operator()( // Construct ObjectFile from machine code buffer. return std::unique_ptr( - new llvm::ObjectMemoryBuffer(std::move(stream_buffer))); + new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer))); } static std::vector VectorFunctionsForTargetLibraryInfoImpl() { diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 2136aeb3877..0985b9297fe 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -33,7 +33,8 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { for (HloInstruction* hlo : module->entry_computation()->MakeInstructionPostOrder()) { if (hlo->opcode() == HloOpcode::kConvolution && - !PotentiallyImplementedAsEigenConvolution(*hlo)) { + !PotentiallyImplementedAsEigenConvolution(*hlo, + target_machine_features_)) { const ConvolutionDimensionNumbers& dnums = hlo->convolution_dimension_numbers(); auto input_batch_dim = dnums.input_batch_dimension(); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index 9b2c3d82eb6..e6fd1499edd 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_ +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -32,12 +33,19 @@ namespace cpu { // convolutions can run faster. class ConvCanonicalization : public HloPassInterface { public: + explicit ConvCanonicalization( + const TargetMachineFeatures* target_machine_features) + : target_machine_features_(*target_machine_features) {} + ~ConvCanonicalization() override {} tensorflow::StringPiece name() const override { return "convolution-canonicalization"; } StatusOr Run(HloModule* module) override; + + private: + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 968f53d5c70..375b017b092 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -89,7 +90,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - ConvCanonicalization conv_canonicalization; + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + ConvCanonicalization conv_canonicalization(&target_machine_features); EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie()); const HloInstruction* output_reshape = entry_computation->root_instruction(); @@ -146,7 +151,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - ConvCanonicalization conv_canonicalization; + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + ConvCanonicalization conv_canonicalization(&target_machine_features); EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie()); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index e43777c5e5e..25b18eff20f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" @@ -56,12 +57,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" -#include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -83,12 +82,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" @@ -100,8 +101,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" -namespace se = ::perftools::gputools; - namespace xla { namespace cpu { @@ -122,10 +121,12 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const { CpuAotCompilationResult::CpuAotCompilationResult( ObjectFileData object_file_data, BufferSizes buffer_sizes, - int64 result_buffer_index) + int64 result_buffer_index, + std::unique_ptr hlo_profile_printer_data) : object_file_data_(std::move(object_file_data)), buffer_sizes_(std::move(buffer_sizes)), - result_buffer_index_(result_buffer_index) {} + result_buffer_index_(result_buffer_index), + hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {} CpuAotCompilationResult::~CpuAotCompilationResult() = default; @@ -175,14 +176,13 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { public: static StatusOr> GetCandidatesForComputation( - HloComputation* computation, + const HloComputation& computation, const std::unordered_map& assigned_indices) { std::unordered_map hlo_to_profile_idx; CollectProfileCandidates profile_candidates_for_computation( &hlo_to_profile_idx, assigned_indices); - TF_RETURN_IF_ERROR( - computation->Accept(&profile_candidates_for_computation)); + TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation)); return hlo_to_profile_idx; } @@ -233,7 +233,10 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine) { + LLVMTargetMachineFeatures target_machine_features(target_machine); + // Optimization pipeline. HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(); @@ -250,8 +253,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. 
pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(&target_machine_features); { auto& pass = pipeline.AddPass>("simplification"); @@ -262,7 +266,6 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, /*use_fusion=*/false); - pipeline.AddPass(); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -274,16 +277,19 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass(); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); } + pipeline.AddPass(); pipeline.AddPass( - [](const HloInstruction& dot, - const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot) + [&target_machine_features]( + const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return PotentiallyImplementedAsEigenDot(dot, target_machine_features) ? candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -291,12 +297,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pipeline.AddPass(/*is_layout_sensitive=*/false); pipeline.AddPass(); + pipeline.AddPass(); + ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass( - module->mutable_entry_computation_layout()); + module->mutable_device_entry_computation_layout(), + &target_machine_features); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -310,17 +319,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { module->config().intra_op_parallelism_threads() > 0 ? module->config().intra_op_parallelism_threads() : tensorflow::port::NumSchedulableCPUs(); - if (options::CpuParallelBackendRequested(module->config())) { - pipeline.AddPass(max_parallelism, - ShapeSizeBytesFunction()); - } else if (!is_aot_compile) { + if (!is_aot_compile) { // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module. // Note this is not run for AOT because it would bring in thread pool // and thread synchronization dependencies which would likely increase // binary size (and most AOT applications are single-threaded). // TODO(b/29630486) Support multi-threaded AOT. - pipeline.AddPass(max_parallelism, - ShapeSizeBytesFunction()); + pipeline.AddPass( + max_parallelism, ShapeSizeBytesFunction(), &target_machine_features); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -331,13 +337,6 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); - if (options::CpuParallelBackendRequested(module->config())) { - // Re-run the outlining, in case any copies were inserted into the entry - // computation. 
- pipeline.AddPass(max_parallelism, - ShapeSizeBytesFunction()); - pipeline.AddPass(); - } pipeline.AddPass(); return pipeline.Run(module).status(); } @@ -437,16 +436,56 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) { return Status::OK(); } +Status CreateHloProfilingArtifacts( + const HloModule& module, + std::unordered_map* + instruction_to_profile_idx, + std::unordered_map* + computation_to_profile_idx, + std::unique_ptr* hlo_profile_index_map, + std::unique_ptr* hlo_profile_printer_data) { + *hlo_profile_index_map = MakeUnique(module); + const HloComputation& entry_computation = *module.entry_computation(); + + TF_ASSIGN_OR_RETURN( + *instruction_to_profile_idx, + CollectProfileCandidates::GetCandidatesForComputation( + entry_computation, + (*hlo_profile_index_map)->instruction_to_profile_idx())); + + auto shape_size_bytes = [](const Shape& shape) { + // On the cpu, opaques are pointers. + if (ShapeUtil::IsOpaque(shape)) { + return static_cast(sizeof(void*)); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + }; + + HloCostAnalysis cost_analysis(shape_size_bytes); + TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis)); + *hlo_profile_printer_data = + CreateHloProfilePrinterData(**hlo_profile_index_map, cost_analysis); + *computation_to_profile_idx = + (*hlo_profile_index_map)->computation_to_profile_idx(); + + return Status::OK(); +} + } // namespace StatusOr> CpuCompiler::RunHloPasses( - std::unique_ptr module, - perftools::gputools::StreamExecutor* /*stream_exec*/, + std::unique_ptr module, se::StreamExecutor* /*stream_exec*/, DeviceMemoryAllocator* /*device_allocator*/) { VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); + std::unique_ptr jit_target_machine = + SimpleOrcJIT::InferTargetMachineForJIT( + CompilerTargetOptions(module->config()), + CodeGenOptLevel(module->config())); + + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false, + jit_target_machine.get())); VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); @@ -454,8 +493,7 @@ StatusOr> CpuCompiler::RunHloPasses( } StatusOr> CpuCompiler::RunBackend( - std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* /*device_allocator*/) { const string timer_message = "Compiling [" + module->name() + "] for CPU using JIT"; @@ -493,28 +531,9 @@ StatusOr> CpuCompiler::RunBackend( std::unique_ptr hlo_profile_index_map; std::unique_ptr hlo_profile_printer_data; if (module->config().hlo_profiling_enabled()) { - hlo_profile_index_map = MakeUnique(*module); - - TF_ASSIGN_OR_RETURN( - instruction_to_profile_idx, - CollectProfileCandidates::GetCandidatesForComputation( - entry_computation, - hlo_profile_index_map->instruction_to_profile_idx())); - - auto shape_size_bytes = [](const Shape& shape) { - // On the cpu, opaques are pointers. 
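// Illustrative sketch only, not part of the patch above: how the new
// CreateHloProfilingArtifacts helper is meant to be called from both the JIT
// and AOT paths. The element types of the two maps are assumptions inferred
// from the surrounding call sites.
Status BuildProfilingArtifacts(const HloModule& module) {
  std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx;
  std::unordered_map<const HloComputation*, int64> computation_to_profile_idx;
  std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
  std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
  TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
      module, &instruction_to_profile_idx, &computation_to_profile_idx,
      &hlo_profile_index_map, &hlo_profile_printer_data));
  return Status::OK();
}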
- if (ShapeUtil::IsOpaque(shape)) { - return static_cast(sizeof(void*)); - } - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); - }; - - HloCostAnalysis cost_analysis(shape_size_bytes); - TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis)); - hlo_profile_printer_data = - CreateHloProfilePrinterData(*hlo_profile_index_map, cost_analysis); - computation_to_profile_idx = - hlo_profile_index_map->computation_to_profile_idx(); + TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts( + *module, &instruction_to_profile_idx, &computation_to_profile_idx, + &hlo_profile_index_map, &hlo_profile_printer_data)); } std::unique_ptr cpu_executable; @@ -526,190 +545,82 @@ StatusOr> CpuCompiler::RunBackend( const string xla_dump_optimized_hlo_proto_to = module->config().debug_options().xla_dump_optimized_hlo_proto_to(); - if (options::CpuParallelBackendRequested(module->config())) { - VLOG(1) << "Using parallel cpu backend"; + // Select an order for emitting the HLO instructions for each + // computation. Using this sequence enables tighter buffer liveness analysis + // and reduced memory usage (as compared to using DependencyHloOrdering). + TF_ASSIGN_OR_RETURN( + SequentialHloOrdering::HloModuleSequence module_sequence, + CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction(), + DFSMemoryScheduler)); - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. - // DependencyHloOrdering is used for the parallel emitter because the order - // of HLO instruction execution is not known ahead of time. - // DependencyHloOrdering is the most conservative partial order and only - // uses data dependencies for determining order. - TF_ASSIGN_OR_RETURN( - std::unique_ptr assignment, - BufferAssigner::Run( - module.get(), xla::MakeUnique(module.get()), - BufferSizeBytesFunction(), memory_alignment)); - // BufferAssignment::ToString() includes a header, so no need for us to - // print one ourselves. - XLA_VLOG_LINES(2, assignment->ToString()); + // Run buffer analysis on the HLO graph. This analysis figures out which + // temporary buffers are required to run the computation. + TF_ASSIGN_OR_RETURN( + std::unique_ptr assignment, + BufferAssigner::Run( + module.get(), + xla::MakeUnique(module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment)); + // BufferAssignment::ToString() includes a header, so no need for us to + // print one ourselves. + XLA_VLOG_LINES(2, assignment->ToString()); - if (!xla_dump_optimized_hlo_proto_to.empty()) { - HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_optimized_hlo_proto_to, module->name())); + if (!xla_dump_optimized_hlo_proto_to.empty()) { + HloProto proto = MakeHloProto(*module, *assignment); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_optimized_hlo_proto_to, module->name())); + } + + // Each computation is a single function. Emit all embedded computations + // before the entry computation. The order of computations returned from + // GetEmbeddedComputations guarantees that a called computation occurs + // before a caller computation. 
+ + LLVMTargetMachineFeatures target_machine_features(jit->target_machine()); + IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), + std::move(instruction_to_profile_idx), + std::move(computation_to_profile_idx), + &target_machine_features, jit->external_constant_pool()); + + for (auto embedded_computation : + entry_computation->MakeEmbeddedComputationsList()) { + if (embedded_computation->IsFusionComputation()) { + continue; } + TF_RETURN_IF_ERROR( + ir_emitter + .EmitComputation(embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &module_sequence.at(embedded_computation)) + .status()); + } + string function_name_prefix = entry_computation->name().empty() + ? "__compute" + : entry_computation->name(); + TF_ASSIGN_OR_RETURN( + llvm::Function * entry_function, + ir_emitter.EmitComputation(entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + &module_sequence.at(entry_computation))); - // If we are using the parallel CPU backend, we need to create map from - // HloInstruction to the corresponding generated function name. - std::map parallel_computations; - std::unordered_map> - aligned_constants; - for (auto instruction : entry_computation->MakeInstructionPostOrder()) { - // Parameters and constants don't get their own computation. - if (instruction->opcode() == HloOpcode::kParameter) { - continue; - } - if (instruction->opcode() == HloOpcode::kConstant) { - // Copy the constant out of the ProtocolBuffer so that we can give it a - // higher alignment. - const void* data = instruction->literal().untyped_data(); - int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); - auto iter = aligned_constants.emplace( - instruction, xla::MakeUnique(size)); - CHECK_EQ(iter.second, true); - unsigned char* aligned_data = iter.first->second.get(); - memcpy(aligned_data, data, size); - continue; - } - // The parallel preparation should have ensured that the top-level - // computation consists solely of Call instructions. - TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall) - << module->ToString(); - HloComputation* to_apply = instruction->to_apply(); - parallel_computations.emplace(to_apply, instruction); - } + string function_name = llvm_ir::AsString(entry_function->getName()); + string ir_module_string; + if (embed_ir_in_executable) { + ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); + } + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); - IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - std::move(instruction_to_profile_idx), - std::move(computation_to_profile_idx), - jit->target_machine(), jit->external_constant_pool()); + XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); - std::unique_ptr> function_names( - new HloInstructionMap()); - for (auto embedded_computation : - entry_computation->MakeEmbeddedComputationsList()) { - if (embedded_computation->IsFusionComputation()) { - continue; - } - auto parallel_computation_iter = - parallel_computations.find(embedded_computation); - // All parallel computations are considered to be an entry computation for - // IR generation purposes. 
- bool computation_is_parallel = - parallel_computation_iter != parallel_computations.end(); - TF_ASSIGN_OR_RETURN( - llvm::Function * ir_function, - ir_emitter.EmitComputation( - embedded_computation, embedded_computation->name(), - /*is_top_level_computation=*/computation_is_parallel, - /*instruction_order=*/nullptr)); - // If this computation is parallel, remember it in the function name map. - // This way we know what function to execute when we try to run code for - // the Call instruction. - if (computation_is_parallel) { - HloInstruction* call_instruction = parallel_computation_iter->second; - InsertOrDie(function_names.get(), call_instruction, - llvm_ir::AsString(ir_function->getName())); - } - } + // JIT compile the LLVM IR module to in-memory machine code. + jit->AddModule(std::move(llvm_module)); + cpu_executable.reset(new CpuExecutable( + std::move(jit), std::move(assignment), std::move(module), function_name, + std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); - string ir_module_string; - if (embed_ir_in_executable) { - ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); - } - TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); - - // JIT compile the LLVM IR module to in-memory machine code. - jit->AddModule(std::move(llvm_module)); - cpu_executable.reset(new ParallelCpuExecutable( - std::move(jit), std::move(assignment), std::move(module), - std::move(function_names), std::move(aligned_constants), - std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); - - if (embed_ir_in_executable) { - static_cast(*cpu_executable) - .set_ir_module_string(ir_module_string); - } - } else { - VLOG(1) << "Using sequential cpu backend"; - - // Select an order for emitting the HLO instructions for each - // computation. Using this sequence enables tighter buffer liveness analysis - // and reduced memory usage (as compared to using DependencyHloOrdering). - TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); - - // Run buffer analysis on the HLO graph. This analysis figures out which - // temporary buffers are required to run the computation. - TF_ASSIGN_OR_RETURN( - std::unique_ptr assignment, - BufferAssigner::Run(module.get(), - xla::MakeUnique( - module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); - // BufferAssignment::ToString() includes a header, so no need for us to - // print one ourselves. - XLA_VLOG_LINES(2, assignment->ToString()); - - if (!xla_dump_optimized_hlo_proto_to.empty()) { - HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_optimized_hlo_proto_to, module->name())); - } - - // Each computation is a single function. Emit all embedded computations - // before the entry computation. The order of computations returned from - // GetEmbeddedComputations guarantees that a called computation occurs - // before a caller computation. 
- - IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - std::move(instruction_to_profile_idx), - std::move(computation_to_profile_idx), - jit->target_machine(), jit->external_constant_pool()); - - for (auto embedded_computation : - entry_computation->MakeEmbeddedComputationsList()) { - if (embedded_computation->IsFusionComputation()) { - continue; - } - TF_RETURN_IF_ERROR( - ir_emitter - .EmitComputation(embedded_computation, - embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) - .status()); - } - string function_name_prefix = entry_computation->name().empty() - ? "__compute" - : entry_computation->name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation(entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &module_sequence.at(entry_computation))); - - string function_name = llvm_ir::AsString(entry_function->getName()); - string ir_module_string; - if (embed_ir_in_executable) { - ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); - } - TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); - - XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); - - // JIT compile the LLVM IR module to in-memory machine code. - jit->AddModule(std::move(llvm_module)); - cpu_executable.reset(new CpuExecutable( - std::move(jit), std::move(assignment), std::move(module), function_name, - std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); - - if (embed_ir_in_executable) { - static_cast(*cpu_executable) - .set_ir_module_string(ir_module_string); - } + if (embed_ir_in_executable) { + static_cast(*cpu_executable) + .set_ir_module_string(ir_module_string); } VLOG(1) << "Compilation finished"; @@ -811,7 +722,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true)); + TF_RETURN_IF_ERROR( + RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get())); VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); @@ -840,12 +752,22 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, proto, xla_dump_optimized_hlo_proto_to, module->name())); } + std::unordered_map instruction_to_profile_idx; + std::unordered_map computation_to_profile_idx; + std::unique_ptr hlo_profile_index_map; + std::unique_ptr hlo_profile_printer_data; + + if (module->config().hlo_profiling_enabled()) { + TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts( + *module, &instruction_to_profile_idx, &computation_to_profile_idx, + &hlo_profile_index_map, &hlo_profile_printer_data)); + } + + LLVMTargetMachineFeatures target_machine_features(target_machine.get()); IrEmitter ir_emitter(*module, *assignment, &llvm_module, - /*instruction_to_profile_idx=*/ - std::unordered_map{}, - /*computation_to_profile_idx=*/ - std::unordered_map{}, - target_machine.get(), + std::move(instruction_to_profile_idx), + std::move(computation_to_profile_idx), + &target_machine_features, /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : @@ -886,6 +808,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_RETURN_IF_ERROR(verify_status); } + XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(llvm_module)); + Disassembler disassembler(*target_machine); CompilerFunctor compiler_functor( target_machine.get(), &disassembler, opt_level, @@ 
-919,7 +843,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, results.emplace_back(MakeUnique( std::move(object_file_data), std::move(buffer_sizes), - result_slice.index())); + result_slice.index(), std::move(hlo_profile_printer_data))); } VLOG(1) << "Compilation finished"; @@ -938,9 +862,9 @@ HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { } // namespace xla static bool InitModule() { - xla::Compiler::RegisterCompilerFactory(se::host::kHostPlatformId, []() { - return xla::MakeUnique(); - }); + xla::Compiler::RegisterCompilerFactory( + stream_executor::host::kHostPlatformId, + []() { return xla::MakeUnique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 3498139ab95..e56f9f01134 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" @@ -53,7 +54,7 @@ class CpuAotCompilationOptions : public AotCompilationOptions { RelocationModel relocation_model); ~CpuAotCompilationOptions() override; - perftools::gputools::Platform::Id PlatformId() const override; + se::Platform::Id PlatformId() const override; // The triple used for compilation, similar to clang's -target flag. const string& triple() const { return triple_; } @@ -76,10 +77,16 @@ class CpuAotCompilationOptions : public AotCompilationOptions { class CpuAotCompilationResult : public AotCompilationResult { public: - CpuAotCompilationResult(ObjectFileData object_file_data, - BufferSizes buffer_sizes, int64 result_buffer_index); + CpuAotCompilationResult( + ObjectFileData object_file_data, BufferSizes buffer_sizes, + int64 result_buffer_index, + std::unique_ptr hlo_profile_printer_data); ~CpuAotCompilationResult(); + HloProfilePrinterData* hlo_profile_printer_data() const { + return hlo_profile_printer_data_.get(); + } + const ObjectFileData& object_file_data() const { return object_file_data_; } const BufferSizes& buffer_sizes() const { return buffer_sizes_; } int64 result_buffer_index() const { return result_buffer_index_; } @@ -97,6 +104,10 @@ class CpuAotCompilationResult : public AotCompilationResult { // result of the computation. This buffer should be passed into the output // parameter when calling the compiled computation. const int64 result_buffer_index_; + + // Contains an instance of HloProfilePrinterData if HLO profiling is enabled, + // otherwise is nullptr. + std::unique_ptr hlo_profile_printer_data_; }; // CPU-targeting implementation of the XLA Compiler interface. 
@@ -112,25 +123,23 @@ class CpuCompiler : public LLVMCompiler { // Bring in // StatusOr>> Compile( // std::vector> modules, - // std::vector> + // std::vector> // stream_execs) using LLVMCompiler::Compile; StatusOr> RunHloPasses( - std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( - std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& options) override; - perftools::gputools::Platform::Id PlatformId() const override; + se::Platform::Id PlatformId() const override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; @@ -140,7 +149,8 @@ class CpuCompiler : public LLVMCompiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* module, bool is_aot_compile); + Status RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc new file mode 100644 index 00000000000..d12fa6bb9ad --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { + +// Test that we don't call into Eigen with tensors too small to be aligned +// reliably. 
+ +class CpuEigenTensorAlignmentTest : public ::testing::Test {}; + +TEST_F(CpuEigenTensorAlignmentTest, EigenDotAlignment) { + string hlo_string = R"( +HloModule DotOperation + +ENTRY DotOperation { + arg0 = f32[5,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + ROOT dot = f32[5,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + HloInstruction* dot = module->entry_computation()->root_instruction(); + + TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( + [](int64 size) { return 1; }); + + EXPECT_FALSE( + PotentiallyImplementedAsEigenDot(*dot, target_machine_with_no_alignment)); + + TargetMachineFeaturesWithFakeAlignmentLogic + target_machine_with_full_alignment([](int64 size) { + return TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + + EXPECT_TRUE(PotentiallyImplementedAsEigenDot( + *dot, target_machine_with_full_alignment)); +} + +TEST_F(CpuEigenTensorAlignmentTest, EigenConvAlignment) { + string hlo_string = R"( +HloModule ConvOperation + +ENTRY ConvOperation { + arg0 = f32[1,2,1] parameter(0) + arg1 = f32[1,1,1] parameter(1) + ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1}, dim_labels=b0f_0io->b0f +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + HloInstruction* conv = module->entry_computation()->root_instruction(); + + TargetMachineFeaturesWithFakeAlignmentLogic target_machine_with_no_alignment( + [](int64 size) { return 1; }); + + EXPECT_FALSE(PotentiallyImplementedAsEigenConvolution( + *conv, target_machine_with_no_alignment)); + + TargetMachineFeaturesWithFakeAlignmentLogic + target_machine_with_full_alignment([](int64 size) { + return TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + + EXPECT_TRUE(PotentiallyImplementedAsEigenConvolution( + *conv, target_machine_with_full_alignment)); +} +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index c053703c352..cf43b74c699 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -45,8 +45,6 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/host/host_stream.h" -namespace se = ::perftools::gputools; - namespace xla { namespace cpu { @@ -75,7 +73,7 @@ CpuExecutable::CpuExecutable( Status CpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers) { + std::vector* buffers) { CHECK_EQ(buffers->size(), assignment_->Allocations().size()); VLOG(3) << "Allocating " << assignment_->Allocations().size() << " allocations for module " << module().name(); @@ -203,61 +201,19 @@ Status CpuExecutable::ExecuteComputeFunction( return Status::OK(); } -static void LogLiveAddresses( - tensorflow::gtl::ArraySlice buffers, - const std::vector& buffers_in_result) { - if (!VLOG_IS_ON(3)) { - return; - } - - CHECK_EQ(buffers.size(), buffers_in_result.size()); - std::vector live_out_buffers; - for (int i = 0; i < buffers.size(); ++i) { - if (buffers_in_result[i]) { - live_out_buffers.push_back(buffers[i].opaque()); - } - } - VLOG(3) << "Live addresses in output marking found " - << live_out_buffers.size() << " addresses:\n" - << tensorflow::str_util::Join( - live_out_buffers, ", ", [](string* out, const void* address) { - tensorflow::strings::StrAppend( - out, tensorflow::strings::Printf("%p", address)); - }); -} - -static Status DeallocateTempBuffers( - DeviceMemoryAllocator* allocator, se::Stream* stream, - tensorflow::gtl::ArraySlice buffers, - const std::vector& buffers_in_result) { - // Keep those buffers in the output of the marked live because they are needed - // by the service. They will be deallocated by the service. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR( - allocator->Deallocate(stream->parent()->device_ordinal(), &alloc)); - } - } - - return Status::OK(); -} - -StatusOr> CpuExecutable::CreateResultShapedBuffer( +StatusOr CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - allocated_buffers, - std::vector* buffers_in_result) { + tensorflow::gtl::MutableArraySlice buffers) { se::Stream* stream = run_options->stream(); - auto result_buffer = MakeUnique( - /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), - stream->parent()->platform(), stream->parent()->device_ordinal()); + ScopedShapedBuffer result_buffer( + /*on_host_shape=*/host_result_shape(), + /*on_device_shape=*/host_result_shape(), run_options->allocator(), + stream->parent()->device_ordinal()); - // Copy DeviceMemoryBase values which contain the array(s) of the result into - // the respective location in ShapedBuffer which is returned to the caller. - TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( + // Move OwningDeviceMemory values which contain the array(s) of the result + // into the respective location in ScopedShapedBuffer which is returned to the + // caller. 
+ TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus( [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { const auto& sources = this->GetRootPointsToSet().element(index); // The points to set is unambiguous so the set should be a @@ -275,16 +231,15 @@ StatusOr> CpuExecutable::CreateResultShapedBuffer( CHECK(!slice.allocation()->is_entry_computation_parameter()); const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = allocated_buffers[buffer_index]; + OwningDeviceMemory& buffer = buffers[buffer_index]; CHECK(!buffer.is_null() || buffer.size() == 0); - *device_memory = buffer; - (*buffers_in_result)[buffer_index] = true; + *device_memory = buffer.Forget(); return Status::OK(); })); return std::move(result_buffer); } -StatusOr> CpuExecutable::ExecuteOnStream( +StatusOr CpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { @@ -294,26 +249,24 @@ StatusOr> CpuExecutable::ExecuteOnStream( se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); + std::vector buffers(assignment_->Allocations().size()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_ASSIGN_OR_RETURN( - std::unique_ptr result_buffer, - CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); + std::vector unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } + TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(), + arguments, unowning_buffers, + hlo_execution_profile)); - // Free all buffers not in the result. 
- TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, - buffers_in_result)); - - return std::move(result_buffer); + return CreateResultShapedBuffer(run_options, &buffers); } -StatusOr> CpuExecutable::ExecuteAsyncOnStream( +StatusOr CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) { if (hlo_profiling_enabled()) { @@ -322,34 +275,57 @@ StatusOr> CpuExecutable::ExecuteAsyncOnStream( "supported on CPU."); } - auto* host_stream = dynamic_cast( + auto* host_stream = dynamic_cast( run_options->stream()->implementation()); se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - + std::vector buffers(assignment_->Allocations().size()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_ASSIGN_OR_RETURN( - std::unique_ptr result_buffer, - CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); + std::vector unowning_buffers; + unowning_buffers.reserve(buffers.size()); + for (auto& buffer : buffers) { + unowning_buffers.push_back(buffer.AsDeviceMemoryBase()); + } + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + CreateResultShapedBuffer(run_options, &buffers)); - LogLiveAddresses(buffers, buffers_in_result); + // At this point, `unowning_buffers` contains unowning pointers to all of our + // buffers, and `buffers` contains owning pointers to the non-live-out + // buffers. Enqueue a task which keeps alive the non-live-out buffers. + // + // Logically we want this lambda to capture `buffers` by move, but ultimately + // our functor needs to be wrapped in an std::function, and that requires its + // functor to be copyable. Thus we perpetrate the hack of capturing buffers + // "by shared pointer". + // + // We also need to change the types of some of the variables we capture: + // run_options needs to change from a pointer to a value type, and arguments + // needs to change from an ArraySlice into a vector. We use a struct instead + // of a lambda to make this explicit. + struct AsyncRunTask { + CpuExecutable* executable; + ServiceExecutableRunOptions run_options; + std::vector arguments; + std::vector unowning_buffers; + std::shared_ptr> buffers; - host_stream->EnqueueTask([this, run_options, arguments, buffers, - buffers_in_result, memory_allocator, stream]() { - // Failing a CHECK here is not great, but I don't see an obvious way to - // return a failed Status asynchronously. - TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments, - buffers, - /*hlo_execution_profile=*/nullptr)); - TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers, - buffers_in_result)); - }); + void operator()() { + // Failing a CHECK here is not great, but I don't see an obvious way to + // return a failed Status asynchronously.
+ TF_CHECK_OK(executable->ExecuteComputeFunction( + &run_options.run_options(), arguments, unowning_buffers, + /*hlo_execution_profile=*/nullptr)); + } + }; + host_stream->EnqueueTask(AsyncRunTask{ + this, *run_options, + std::vector(arguments.begin(), arguments.end()), + unowning_buffers, + std::make_shared>(std::move(buffers))}); - return std::move(result_buffer); + return std::move(result); } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index d3502b3a03e..8dd47bfb865 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -55,12 +55,12 @@ class CpuExecutable : public Executable { std::unique_ptr hlo_profile_index_map); ~CpuExecutable() override {} - StatusOr> ExecuteOnStream( + StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteAsyncOnStream( + StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) override; @@ -90,30 +90,24 @@ class CpuExecutable : public Executable { // assignment. Each vector element corresponds to a particular Index. If // a vector element already contains a non-null DeviceMemoryBase, then no // buffer is assigned for this element. - Status AllocateBuffers( - DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers); + Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator, + int device_ordinal, + std::vector* buffers); // Calls the generated function performing the computation with the given // arguments using the supplied buffers. Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice - buffers, + tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile); - // Create a ShapedBuffer for holding the result of the computation. The - // addresses (DeviceMemoryBases) are set according to buffer assignment. - // 'buffers_in_result' should point to a vector of the same size as - // 'allocated_buffers'. An element in buffers_in_result is set to true if the - // corresponding buffer is live out of the computation (and thus contained in - // the returned ShapedBuffer). - StatusOr> CreateResultShapedBuffer( + // Creates a ScopedShapedBuffer for holding the result of the computation, + // moving buffers out of allocated_buffers and into the result as appropriate. + // The addresses are set according to buffer assignment. + StatusOr CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - allocated_buffers, - std::vector* buffers_in_result); + tensorflow::gtl::MutableArraySlice buffers); // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. 
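The AsyncRunTask comment in cpu_executable.cc above explains the core trick: the enqueued task must keep the non-live-out (move-only, owning) buffers alive, but std::function only accepts copyable callables, so the owning container is carried through a shared_ptr. A minimal, XLA-independent illustration of the same pattern follows; MoveOnlyBuffer and RunAsync are made-up names standing in for the owning buffers and the stream's EnqueueTask.

#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

struct MoveOnlyBuffer {
  std::unique_ptr<int[]> data;  // makes the struct move-only
  std::size_t size = 0;
};

// Stand-in for enqueueing work on a stream; std::function requires the
// callable to be copyable.
void RunAsync(std::function<void()> task) { task(); }

int main() {
  std::vector<MoveOnlyBuffer> buffers;
  buffers.push_back({std::make_unique<int[]>(4), 4});

  // Capturing `buffers` by move would make the lambda itself move-only and
  // therefore unusable with std::function, so hold it "by shared pointer":
  // every copy of the callable shares ownership, and the buffers stay alive
  // until the last copy is destroyed.
  auto shared_buffers =
      std::make_shared<std::vector<MoveOnlyBuffer>>(std::move(buffers));
  RunAsync([shared_buffers]() {
    std::cout << "task sees " << shared_buffers->size() << " buffer(s)\n";
  });
}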
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 0fc5a746bbb..b40d264c03a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -34,6 +34,7 @@ bool CanBeLoopFused(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kConcatenate || hlo.opcode() == HloOpcode::kDynamicSlice || hlo.opcode() == HloOpcode::kDynamicUpdateSlice || + hlo.opcode() == HloOpcode::kGather || hlo.opcode() == HloOpcode::kPad || hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kReverse || diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 6ed1cd31b18..46fe060817b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -157,37 +158,95 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { EXPECT_EQ(dot, computation->root_instruction()); } -TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { - HloComputation::Builder builder(TestName()); - HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 256}), "arg0")); - HloInstruction* arg1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1024, 256}), "arg1")); +TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_RHS) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion - HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg1)); - HloInstruction* transpose1 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); - builder.AddInstruction( - MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[1024,256] parameter(1) + exponential = s32[1024,256] exponential(arg1) + transpose = s32[256,1024] transpose(exponential), dimensions={1,0} + ROOT dot = f32[1,1024] dot(arg0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + HloComputation* computation = module->entry_computation(); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); TransposeFolding transpose_folding( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { return candidate_operands; }, TransposeFolding::NeverFoldTranspose); - EXPECT_TRUE(transpose_folding.Run(module.get()).ValueOrDie()); - EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); - EXPECT_EQ(computation->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kTransposeDot); - EXPECT_FALSE(CpuInstructionFusion().Run(module.get()).ValueOrDie()); - 
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kFusion); - EXPECT_EQ(computation->root_instruction()->fusion_kind(), - HloInstruction::FusionKind::kTransposeDot); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); +} + +TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion_LHS) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[256,1] parameter(0) + arg1 = f32[256,1024] parameter(1) + transpose = s32[1,256] transpose(arg0), dimensions={1,0} + exponential = s32[256,1024] exponential(arg1) + ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + HloComputation* computation = module->entry_computation(); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + TransposeFolding::NeverFoldTranspose); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/0)); +} + +TEST_F(InstructionFusionTest, + DotOperationFusion_TransposeFusion_LHS_NonDefault) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + transpose = s32[256,1] transpose(arg0), dimensions={1,0} + exponential = s32[256,1024] exponential(arg1) + ROOT dot = f32[1,1024] dot(transpose, exponential), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + HloComputation* computation = module->entry_computation(); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + TransposeFolding::NeverFoldTranspose); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + ASSERT_TRUE(changed); + ASSERT_THAT(computation->root_instruction(), + op::Dot(op::Parameter(0), op::Exp(op::Parameter(1)), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/0)); } class OpcodeFusionTest : public InstructionFusionTest { @@ -697,6 +756,154 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { Not(op::Fusion())); } +struct GatherLoopFusionTestSpec { + string test_name; + string hlo_computation_text; + + static string Name( + const ::testing::TestParamInfo& info) { + return info.param.test_name; + } +}; + +class GatherLoopFusionTest + : public OpcodeFusionTest, + public ::testing::WithParamInterface {}; + +TEST_P(GatherLoopFusionTest, GatherLoopFusion) { + const GatherLoopFusionTestSpec& spec = GetParam(); + string hlo_string = tensorflow::strings::StrCat( + "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kGather, HloOpcode::kAdd, HloOpcode::kBroadcast, + 
HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); +} + +std::vector GetGatherLoopFusionTestSpecs() { + std::vector result; + + result.push_back({"FusedTensorFlowGatherV2", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[3,2] broadcast(one), dimensions={} + ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherMultipleBatchDims", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} + ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherNdMultipleBatchDims", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherNd_0", R"( +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedTensorFlowGatherNd_1", R"( +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedDynamicSlice", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[1,1] broadcast(one), dimensions={} + ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) +} +)"}); + + result.push_back({"FusedBatchDynamicSlice", R"( +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} + ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) +} +)"}); + + return result; +} + 
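The spec list above is wired into the test by the INSTANTIATE_TEST_CASE_P call that follows; this is googletest's standard value-parameterized test pattern (a spec struct, a TEST_P body reading GetParam(), and a name generator). A minimal sketch of the same shape, independent of XLA, with all names invented for illustration:

#include <string>
#include <vector>
#include "gtest/gtest.h"

struct SquareTestSpec {
  std::string test_name;
  int input;
  int expected;

  static std::string Name(
      const ::testing::TestParamInfo<SquareTestSpec>& info) {
    return info.param.test_name;
  }
};

class SquareTest : public ::testing::TestWithParam<SquareTestSpec> {};

TEST_P(SquareTest, Squares) {
  const SquareTestSpec& spec = GetParam();
  EXPECT_EQ(spec.input * spec.input, spec.expected);
}

std::vector<SquareTestSpec> GetSquareTestSpecs() {
  return {{"Two", 2, 4}, {"Three", 3, 9}};
}

INSTANTIATE_TEST_CASE_P(SquareTestInstantiation, SquareTest,
                        ::testing::ValuesIn(GetSquareTestSpecs()),
                        SquareTestSpec::Name);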
+INSTANTIATE_TEST_CASE_P(GatherLoopFusionTestInstantiation, GatherLoopFusionTest, + ::testing::ValuesIn(GetGatherLoopFusionTestSpecs()), + GatherLoopFusionTestSpec::Name); } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index e8117377e61..aa872d5ec9e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -100,7 +100,8 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) { + PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_)) { const HloInstruction* convolution = instruction; const HloInstruction* lhs_instruction = convolution->operand(0); const HloInstruction* rhs_instruction = convolution->operand(1); @@ -126,7 +127,8 @@ Status CpuLayoutAssignment::AddBackendConstraints( const HloInstruction* op = instruction->operand(*op_idx); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( ColMajorShape(op->shape()), instruction, *op_idx)); - } else if (PotentiallyImplementedAsEigenDot(*instruction)) { + } else if (PotentiallyImplementedAsEigenDot(*instruction, + target_machine_features_)) { const HloInstruction* dot = instruction; // In order to implement `dot` with Eigen dot, the layouts of the lhs, // rhs, and output need to be row-major. @@ -139,13 +141,9 @@ Status CpuLayoutAssignment::AddBackendConstraints( Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); - // dot is a kDot or a kTransposeDot fusion node. In the latter case, if - // it represents X @ X, it may have just one operand. - if (dot->operand_count() > 1) { - const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); - } + const HloInstruction* rhs_instruction = dot->operand(1); + Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); // Set layouts of the instructions' shapes. TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot)); @@ -181,7 +179,7 @@ Status CpuLayoutAssignment::AddBackendConstraints( } } } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index c8edbb9e15a..3c4fe68b830 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" @@ -27,12 +28,17 @@ namespace cpu { // layout constraints for operands and results of library calls. 
class CpuLayoutAssignment : public LayoutAssignment { public: - explicit CpuLayoutAssignment(ComputationLayout* entry_computation_layout) - : LayoutAssignment(entry_computation_layout) {} + explicit CpuLayoutAssignment( + ComputationLayout* entry_computation_layout, + const TargetMachineFeatures* target_machine_features) + : LayoutAssignment(entry_computation_layout), + target_machine_features_(*target_machine_features) {} ~CpuLayoutAssignment() override {} protected: Status AddBackendConstraints(LayoutConstraints* constraints) override; + + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 6ba030fff3b..429fc7b7860 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -49,7 +50,12 @@ class CpuLayoutAssignmentTest : public HloTestBase { protected: void AssignLayouts(HloModule* module, ComputationLayout* entry_computation_layout) { - cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout, + &target_machine_features); EXPECT_IS_OK(layout_assignment.Run(module).status()); } }; @@ -311,7 +317,12 @@ static StatusOr RunDotOutputFusion( result.addend_fusion_param = fusion_instruction->operand( fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number()); - cpu::CpuLayoutAssignment layout_assignment(&computation_layout); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + cpu::CpuLayoutAssignment layout_assignment(&computation_layout, + &target_machine_features); TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, layout_assignment.Run(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index 09f028463af..f9c51f243c4 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -19,7 +19,6 @@ limitations under the License. 
namespace { -const char* const kXlaParallelCpuOption = "xla_cpu_parallel"; const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; @@ -30,12 +29,6 @@ namespace xla { namespace cpu { namespace options { -bool CpuParallelBackendRequested(const HloModuleConfig& config) { - const auto& extra_options_map = - config.debug_options().xla_backend_extra_options(); - return extra_options_map.count(kXlaParallelCpuOption) > 0; -} - bool OptimizeForSizeRequested(const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 6ba0fd24538..be62ff3cc1a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -24,7 +24,6 @@ namespace xla { namespace cpu { namespace options { -bool CpuParallelBackendRequested(const HloModuleConfig& config); bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); tensorflow::gtl::optional LlvmIrGemvTilingFactor( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc deleted file mode 100644 index 662ee609232..00000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ /dev/null @@ -1,192 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" - -#include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" -#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" - -namespace xla { -namespace cpu { - -StatusOr ParallelizationPreparation::Run(HloModule* module) { - XLA_VLOG_LINES(2, "ParallelizationPreparation ENTRY"); - XLA_VLOG_LINES(2, module->ToString()); - - bool changed = false; - TF_ASSIGN_OR_RETURN(changed, RunParallelTaskAssignment(module)); - - HloComputation* entry_computation = module->entry_computation(); - std::unordered_set outlined; - std::vector instructions_to_outline; - for (HloInstruction* instruction : - entry_computation->MakeInstructionPostOrder()) { - // If the instruction has been outlined, it no longer exists and we must not - // dereference it. - if (outlined.count(instruction) > 0) { - continue; - } - - // Skip parameters and constants, there is nothing to parallelize. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant) { - continue; - } - - // Outline 'instruction' in isolation if it was assigned parallel tasks. - if (OutlineParallelizableInstruction(instruction)) { - outlined.insert(instruction); - changed = true; - continue; - } - - instructions_to_outline.clear(); - HloInstruction* outline_candidate = instruction; - instructions_to_outline.push_back(outline_candidate); - - // Outline sole users with the current instruction. - while (CanOutlineWithUser(outline_candidate)) { - HloInstruction* prior_candidate = outline_candidate; - outline_candidate = *outline_candidate->users().begin(); - if (std::any_of(outline_candidate->operands().begin(), - outline_candidate->operands().end(), - [&](const HloInstruction* operand) { - // Do not consider any candidates which have operands - // other than the prior candidate, constants or - // parameters. Otherwise, we'd increase the fan-in which - // would reduce parallelism. - return operand->opcode() != HloOpcode::kParameter && - operand->opcode() != HloOpcode::kConstant && - operand != prior_candidate; - })) { - break; - } - instructions_to_outline.push_back(outline_candidate); - } - - outlined.insert(instructions_to_outline.begin(), - instructions_to_outline.end()); - - // Optimization to avoid replacing a single existing kCall with another - // kCall that just calls the first one. 
- if (instructions_to_outline.size() == 1 && - instructions_to_outline[0]->opcode() == HloOpcode::kCall) { - continue; - } - - module->OutlineExpressionFromComputation( - instructions_to_outline, - tensorflow::strings::StrCat("pp_", instruction->name()), - entry_computation); - changed = true; - } - - XLA_VLOG_LINES(2, "ParallelizationPreparation EXIT"); - XLA_VLOG_LINES(2, module->ToString()); - return changed; -} - -StatusOr ParallelizationPreparation::RunParallelTaskAssignment( - HloModule* module) { - VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_; - bool changed = false; - // Initialize ParallelTaskAssignment. - ParallelTaskAssignment parallel_task_assignment(max_parallelism_, shape_size_, - module); - // Assign parallel tasks to HLOs in entry computation. - HloComputation* computation = module->entry_computation(); - for (auto* instruction : computation->instructions()) { - // Calculate target parallel task count in [1, max_parallelism_]. - const int64 target_parallel_task_count = - parallel_task_assignment.GetTargetParallelTaskCount(instruction); - if (target_parallel_task_count == 1) { - continue; - } - - // Assign feasible dimension partitions (based on actual dimension sizes). - auto dim_partition_counts = ShapePartitionAssigner(instruction->shape()) - .Run(target_parallel_task_count); - const int64 total_partition_count = - ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts); - if (total_partition_count <= 1) { - // Feasible partition calculation resulting in no partitioning, so skip. - continue; - } - VLOG(2) << "Assigning parallel task count: " << total_partition_count - << " to instruction: " << instruction->name(); - // Map 'instruction' to assigned dimension partitioning. - instruction->set_outer_dimension_partitions(dim_partition_counts); - } - - return changed; -} - -bool ParallelizationPreparation::OutlineParallelizableInstruction( - HloInstruction* instruction) { - if (instruction->outer_dimension_partitions().empty()) { - return false; - } - // Store dimension partition counts before outlining (which clones - // 'instruction'). - std::vector dim_partition_counts = - instruction->outer_dimension_partitions(); - // Outline 'instruction' in its own sub-computation. - HloModule* module = instruction->parent()->parent(); - auto* call = module->OutlineExpressionFromComputation( - {instruction}, tensorflow::strings::StrCat("pp_", instruction->name()), - module->entry_computation()); - // Map previously assigned 'dim_partition_counts' to cloned root instruction. - VLOG(1) << "Outlining parallelizable" - << " caller: " << call->name() - << " callee: " << call->to_apply()->root_instruction()->name(); - call->to_apply()->root_instruction()->set_outer_dimension_partitions( - dim_partition_counts); - return true; -} - -bool ParallelizationPreparation::CanOutlineWithUser( - HloInstruction* instruction) { - if (instruction->users().size() != 1) { - // Do not outline 'instruction' with multiple users. - return false; - } - if (AssignedParallelTasks(instruction) || - AssignedParallelTasks(*instruction->users().begin())) { - // Do not outline if 'instruction' (or user) were assigned parallel tasks. 
- return false; - } - return true; -} - -bool ParallelizationPreparation::AssignedParallelTasks( - HloInstruction* instruction) { - return !instruction->outer_dimension_partitions().empty() || - (instruction->opcode() == HloOpcode::kCall && - !instruction->to_apply() - ->root_instruction() - ->outer_dimension_partitions() - .empty()); -} - -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h deleted file mode 100644 index 87be758ef5d..00000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ - -#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" - -namespace xla { -namespace cpu { - -// This pass prepares an HLO module for parallel execution by transforming -// subgraphs of the top-level computation into embedded computations which can -// be executed in parallel. -// TODO(b/29630486): Currently, it is limited to turning all instructions (which -// are not constants or parameters) in the entry computation into embedded -// computations. However, it could make sense to coarsen the parallelization to -// improve cache locality. Also, we will need to do something to intelligently -// handle While constructs. -class ParallelizationPreparation : public HloPassInterface { - public: - // 'max_parallelism': the maximum parallel task count per instruction. - // 'shape_size': shape size function used by HloCostAnalysis during parallel - // task assignment. - ParallelizationPreparation( - const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size) - : max_parallelism_(max_parallelism), shape_size_(shape_size) {} - ~ParallelizationPreparation() override {} - - tensorflow::StringPiece name() const override { - return "cpu-parallel-prepare"; - } - - // Run parallel preparation on the given computation. Returns whether the - // computation was changed. - StatusOr Run(HloModule* module) override; - - private: - // Assigns parallel task partitions to conformant instructions in 'module'. - // Returns true on success or error status otherwise. - StatusOr RunParallelTaskAssignment(HloModule* module); - - // Outlines 'instruction' from entry computation, if it had - // been assigned parallel tasks in an earlier pass through the computation. - // Returns true if 'instruction' was successfully outlined, false otherwise. 
- bool OutlineParallelizableInstruction(HloInstruction* instruction); - - // Returns true if 'instruction' can be outlined into the same sub-computation - // with its single user (parallelizable instructions are not outlined with - // each other). Returns false otherwise. - bool CanOutlineWithUser(HloInstruction* instruction); - - // Returns true if 'instruction' (or the root of the sub-computation that - // 'instruction' calls) has had parallel tasks assigned in earlier pass. - // Returns false otherwise. - bool AssignedParallelTasks(HloInstruction* instruction); - - const int64 max_parallelism_; - const HloCostAnalysis::ShapeSizeFunction shape_size_; -}; - -} // namespace cpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 872b0be1f8a..54c52bc08f9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -37,6 +37,7 @@ extern const char* const kEigenMatMulF32SymbolName = "__xla_cpu_runtime_EigenMatMulF32"; extern const char* const kEigenMatMulF64SymbolName = "__xla_cpu_runtime_EigenMatMulF64"; +extern const char* const kMKLConvF32SymbolName = "__xla_cpu_runtime_MKLConvF32"; extern const char* const kMKLMatMulF32SymbolName = "__xla_cpu_runtime_MKLMatMulF32"; extern const char* const kMKLMatMulF64SymbolName = @@ -50,6 +51,8 @@ extern const char* const kEigenConvF16SymbolName = extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; +extern const char* const kEigenSingleThreadedFftSymbolName = + "__xla_cpu_runtime_EigenSingleThreadedFft"; extern const char* const kEigenSingleThreadedMatMulF16SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index e392e231b4c..aa0e9671230 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -44,6 +44,7 @@ namespace runtime { extern const char* const kEigenMatMulF16SymbolName; extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; +extern const char* const kMKLConvF32SymbolName; extern const char* const kMKLMatMulF32SymbolName; extern const char* const kMKLMatMulF64SymbolName; extern const char* const kMKLSingleThreadedMatMulF32SymbolName; @@ -51,6 +52,7 @@ extern const char* const kMKLSingleThreadedMatMulF64SymbolName; extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; +extern const char* const kEigenSingleThreadedFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index f5e61aef534..d97802ee45d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -34,8 +34,6 @@ limitations under the License. 
#include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -90,8 +88,8 @@ CpuTransferManager::CpuTransferManager() : GenericTransferManager(se::host::kHostPlatformId, /*pointer_size=*/sizeof(void*)) {} -Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status CpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); @@ -241,21 +239,20 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( } StatusOr CpuTransferManager::TransferTupleBuffersFromOutfeed( - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, tensorflow::gtl::ArraySlice> buffer_data) { return TransferBuffersFromOutfeedInternal(executor, buffer_data, /*is_tuple=*/true); } StatusOr CpuTransferManager::TransferArrayBufferFromOutfeed( - perftools::gputools::StreamExecutor* executor, void* destination, - int64 size_bytes) { + se::StreamExecutor* executor, void* destination, int64 size_bytes) { return TransferBuffersFromOutfeedInternal( executor, {{destination, size_bytes}}, /*is_tuple=*/false); } StatusOr CpuTransferManager::TransferBuffersFromOutfeedInternal( - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, tensorflow::gtl::ArraySlice> buffer_data, bool is_tuple) { std::vector> buffers; @@ -306,8 +303,8 @@ static std::unique_ptr CreateCpuTransferManager() { } static bool InitModule() { - xla::TransferManager::RegisterTransferManager(se::host::kHostPlatformId, - &CreateCpuTransferManager); + xla::TransferManager::RegisterTransferManager( + stream_executor::host::kHostPlatformId, &CreateCpuTransferManager); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 6c7524d9471..6dfc666f09d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -37,36 +37,35 @@ class CpuTransferManager : public GenericTransferManager { CpuTransferManager(); ~CpuTransferManager() override {} - Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, - const Literal& literal) override; - Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, - int64 size, const void* source) override; - Status TransferLiteralFromOutfeed( - perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const LiteralSlice& literal) override; + Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, + const void* source) override; + Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, + const Shape& literal_shape, + Literal* literal) override; private: // Transfers infeed data to device. InfeedBuffer->Done() must be // called to clean up the memory allocated for InfeedBuffer. StatusOr TransferBufferToInfeedInternal( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source); + se::StreamExecutor* executor, int64 size, const void* source); // Helper that transfers a tuple of element buffers from the device's outfeed. 
StatusOr TransferTupleBuffersFromOutfeed( - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, tensorflow::gtl::ArraySlice> buffer_data); // Helper that transfers an array buffer from the device's outfeed. - StatusOr TransferArrayBufferFromOutfeed( - perftools::gputools::StreamExecutor* executor, void* destination, - int64 size_bytes); + StatusOr TransferArrayBufferFromOutfeed(se::StreamExecutor* executor, + void* destination, + int64 size_bytes); // On success, returns the shape that was transferred from the outfeed -- if // is_tuple is true, the returned shape will be a tuple of the returned shapes // for the given buffers. StatusOr TransferBuffersFromOutfeedInternal( - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, tensorflow::gtl::ArraySlice> buffer_data, bool is_tuple); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 29afd8ea5f9..5cdfc110aff 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -522,16 +523,16 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } // namespace -DotOpEmitter::DotOpEmitter( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, - llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config, - const TargetMachineFeatures& target_machine_features) +DotOpEmitter::DotOpEmitter(const HloInstruction& dot, + const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, + llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) : dot_(dot), - transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), @@ -541,24 +542,22 @@ DotOpEmitter::DotOpEmitter( hlo_module_config_(hlo_module_config), target_machine_features_(target_machine_features) {} -/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, +/* static */ Status DotOpEmitter::EmitDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F16 == type || F32 == 
type || F64 == type || C64 == type); - DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, - lhs_array, rhs_array, addend_array, - executable_run_options_value, ir_builder, - hlo_module_config, target_machine_features); + DotOpEmitter dot_emitter(dot, target_array, lhs_array, rhs_array, + addend_array, executable_run_options_value, + ir_builder, hlo_module_config, + target_machine_features); return dot_emitter.Emit(); } -bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } - bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (dot_.shape().dimensions_size() != 2) { return false; @@ -580,7 +579,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (mat_mult_dims.m == 1) { bool rhs_effectively_row_major = - transpose_rhs_ ^ !mat_mult_dims.rhs_column_major; + mat_mult_dims.rhs_non_canonical ^ !mat_mult_dims.rhs_column_major; if (rhs_effectively_row_major) { k = mat_mult_dims.k; m = mat_mult_dims.n; @@ -596,7 +595,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (mat_mult_dims.n == 1) { bool lhs_effectively_column_major = - transpose_lhs_ ^ mat_mult_dims.lhs_column_major; + mat_mult_dims.lhs_non_canonical ^ mat_mult_dims.lhs_column_major; if (lhs_effectively_column_major) { m = mat_mult_dims.m; k = mat_mult_dims.k; @@ -692,7 +691,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { return true; } -tensorflow::Status DotOpEmitter::Emit() { +Status DotOpEmitter::Emit() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. // @@ -736,23 +735,17 @@ tensorflow::Status DotOpEmitter::Emit() { CHECK_EQ(addend_array_, nullptr); - if (PotentiallyImplementedAsEigenDot(dot_)) { + if (PotentiallyImplementedAsEigenDot(dot_, target_machine_features_)) { return EmitCallToRuntime(); } // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special // case where the reduction dimension is 0 for both LHS and RHS. This results // in a vector dot product producing a scalar. - int64 lhs_reduction_dimension = 0; - if (ShapeUtil::Rank(lhs_shape) >= 2) { - lhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(lhs_shape, transpose_lhs_ ? -2 : -1); - } - int64 rhs_reduction_dimension = 0; - if (ShapeUtil::Rank(rhs_shape) >= 2) { - rhs_reduction_dimension = - ShapeUtil::GetDimensionNumber(rhs_shape, transpose_rhs_ ? -1 : -2); - } + int64 lhs_reduction_dimension = + dot_.dot_dimension_numbers().lhs_contracting_dimensions(0); + int64 rhs_reduction_dimension = + dot_.dot_dimension_numbers().rhs_contracting_dimensions(0); // Verify the reduction dimension in the two operands are the same size. TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) == @@ -876,10 +869,10 @@ tensorflow::Status DotOpEmitter::Emit() { // loop. ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitScalarDot() { +Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. 
llvm::Value* result; llvm::Value* lhs_value = @@ -904,12 +897,10 @@ tensorflow::Status DotOpEmitter::EmitScalarDot() { result = ir_builder_->CreateFMul(lhs_value, rhs_value); } target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status DotOpEmitter::EmitCallToRuntime() { - DCHECK(ShapesAreLegalForRuntimeDot()); - +Status DotOpEmitter::EmitCallToRuntime() { // The signature of the Eigen runtime matmul function is: // // (void)(void* run_options, float* out, float* lhs, float* rhs, @@ -990,8 +981,8 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { const llvm_ir::IrArray* lhs = &lhs_array_; const llvm_ir::IrArray* rhs = &rhs_array_; - bool transpose_lhs = transpose_lhs_; - bool transpose_rhs = transpose_rhs_; + bool transpose_lhs = mat_mult_dims.lhs_non_canonical; + bool transpose_rhs = mat_mult_dims.rhs_non_canonical; if (!mat_mult_dims.lhs_column_major) { std::swap(mat_mult_dims.m, mat_mult_dims.n); @@ -1011,7 +1002,7 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { ir_builder_->getInt64(mat_mult_dims.k), ir_builder_->getInt32(transpose_lhs), ir_builder_->getInt32(transpose_rhs)}); - return tensorflow::Status::OK(); + return Status::OK(); } DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { @@ -1019,12 +1010,16 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { const Shape& lhs_shape = lhs_array_.GetShape(); const Shape& rhs_shape = rhs_array_.GetShape(); + const DotDimensionNumbers& dim_nums = dot_.dot_dimension_numbers(); - return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), - lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), - rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), - LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, - LayoutUtil::Minor(rhs_shape.layout(), 0) == 0}; + return { + /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)), + /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0)), + /*n=*/rhs_shape.dimensions(1 - dim_nums.rhs_contracting_dimensions(0)), + /*lhs_column_major=*/LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, + /*lhs_non_canonical=*/dim_nums.lhs_contracting_dimensions(0) == 0, + /*rhs_column_major=*/LayoutUtil::Minor(rhs_shape.layout(), 0) == 0, + /*rhs_non_canonical=*/dim_nums.rhs_contracting_dimensions(0) == 1}; } llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( @@ -1064,18 +1059,39 @@ static bool IsRank2WithNoPadding(const Shape& shape) { // In a gemm operation where output = lhs * rhs, check whether the given shapes // are valid for the operation. -static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { +static bool AreValidGemmShapes( + const Shape& lhs_shape, const Shape& rhs_shape, const Shape& output_shape, + const TargetMachineFeatures& target_machine_features) { // The inputs and the output must // 1) be matrices with no padding, and // 2) have an allowed element type. 
PrimitiveType output_primitive_type = output_shape.element_type(); - return (output_primitive_type == F32 || output_primitive_type == F16) && - IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape); + if (!(output_primitive_type == F64 || output_primitive_type == F32 || + output_primitive_type == F16)) { + return false; + } + + if (!(IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && + IsRank2WithNoPadding(output_shape))) { + return false; + } + + auto is_aligned = [&](const Shape& shape) { + return GetMinimumAlignmentForArray(shape, target_machine_features) >= + TargetMachineFeatures::kEigenExpectedTensorAlignment; + }; + + if (!is_aligned(lhs_shape) || !is_aligned(rhs_shape) || + !is_aligned(output_shape)) { + return false; + } + + return true; } -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { +bool PotentiallyImplementedAsEigenDot( + const HloInstruction& hlo, + const TargetMachineFeatures& target_machine_features) { // For certain types of Dot, we can call Eigen if (hlo.opcode() == HloOpcode::kDot) { const Shape& lhs_shape = hlo.operand(0)->shape(); @@ -1092,28 +1108,18 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { // If gemm can accept the operand shapes, use it rather than a custom // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape(), + target_machine_features)) { + const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers(); // The size of the reduction dimension should match. The shape inference // guarantees this invariant, so the check here is for programming // errors. - CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); return true; } } - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - auto* dot = hlo.fused_expression_root(); - const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - return true; - } - return false; } diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 9d748eb81f7..a75b8ffcbfc 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -31,7 +31,9 @@ limitations under the License. namespace xla { namespace cpu { -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); +bool PotentiallyImplementedAsEigenDot( + const HloInstruction& hlo, + const TargetMachineFeatures& target_machine_features); // Returns the index for an operand to `hlo` that should ideally be column // major. Returns nullopt if there is no such operand or if `hlo` is not a dot @@ -55,17 +57,16 @@ class DotOpEmitter { // dimensions as the result, and the result is computed as `addend_array` + // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported // for Matrix-vector products. 
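// Aside on the is_aligned gate introduced above (sketch only; the constant
// value and the names here are assumptions, not taken from this patch): an
// operand is handed to the Eigen runtime only when the allocation it lives in
// is guaranteed at least Eigen's expected tensor alignment, where the
// allocation size is element count times element byte size.
inline bool SketchAlignedForEigen(
    long long element_count, long long bytes_per_element,
    long long (*min_alignment_for_size)(long long)) {
  const long long kSketchEigenAlignment = 16;  // assumed value, illustration only
  return min_alignment_for_size(element_count * bytes_per_element) >=
         kSketchEigenAlignment;
}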
- static tensorflow::Status EmitDotOperation( - const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, - const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, + static Status EmitDotOperation( + const HloInstruction& dot, const llvm_ir::IrArray& target_array, + const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features); private: - DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, - bool transpose_rhs, const llvm_ir::IrArray& target_array, + DotOpEmitter(const HloInstruction& dot, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, @@ -75,18 +76,18 @@ class DotOpEmitter { const TargetMachineFeatures& target_machine_features); // Emits the IR to perform the dot operation. - tensorflow::Status Emit(); + Status Emit(); // Emits instructions to perform a scalar dot product (a multiply of the // LHS and RHS) and store the results in the target. - tensorflow::Status EmitScalarDot(); + Status EmitScalarDot(); // Emit an LLVM IR implementation of the dot operation if we can. Returns // true if an LLVM IR implementation was emitted. bool EmitLlvmIrDotIfProfitable(); // Emits a call to the CPU runtime to perform the matrix multiply. - tensorflow::Status EmitCallToRuntime(); + Status EmitCallToRuntime(); // Emits a series of nested loops for iterating over an operand array in the // dot operation. Loops are constructed in major to minor dimension layout @@ -99,10 +100,6 @@ class DotOpEmitter { llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix); - // Our runtime operation requires that all arrays have the same layout, - // no padding, and a rank of two. - bool ShapesAreLegalForRuntimeDot() const; - // Represents the dimensions of a matrix-matrix multiply operation. struct MatMultDims { // The number of rows in the LHS. @@ -115,11 +112,17 @@ class DotOpEmitter { // The number of columns on the RHS. int64 n; - // True if the LHS matrix column major. + // True if the LHS matrix is column major. bool lhs_column_major; - // True if the RHS matrix column major. + // True if the LHS contraction dimension is not 1. + bool lhs_non_canonical; + + // True if the RHS matrix is column major. bool rhs_column_major; + + // True if the RHS contraction dimension is not 0. 
+ bool rhs_non_canonical; }; // Get the MatMultDims instance for the dot product this DotOpEmitter @@ -136,8 +139,6 @@ class DotOpEmitter { } const HloInstruction& dot_; - const bool transpose_lhs_; - const bool transpose_rhs_; const llvm_ir::IrArray& target_array_; const llvm_ir::IrArray& lhs_array_; const llvm_ir::IrArray& rhs_array_; diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index 99c5e16db70..e97113dfa0f 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -115,7 +115,7 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( for (int i = 0; i < hlo->operand_count(); i++) { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(i))( - ElementwiseSourceIndex(index, *hlo, 0))); + ElementwiseSourceIndex(index, *hlo, i))); operands.push_back(operand_value); } return ir_emitter_->EmitScalarCall(hlo->shape().element_type(), diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc index 7dcc4ca7fa0..c5628655915 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc @@ -26,13 +26,13 @@ limitations under the License. namespace xla { namespace cpu { -void ExternalConstantPool::Insert(string name, const Literal& literal, +void ExternalConstantPool::Insert(string name, const LiteralSlice& literal, int64 alignment) { CHECK(!ShapeUtil::IsTuple(literal.shape())); CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); CHECK(entries_.find(name) == entries_.end()); - int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); + const int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); void* raw_pointer = tensorflow::port::AlignedMalloc( literal_size, std::max(alignment, sizeof(void*))); CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h index 8008a56df4d..0677f5f0b58 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -43,7 +43,7 @@ class ExternalConstantPool { // The constant pool copies out the contents of `literal` into a buffer it // owns -- it does not keep pointers to `literal`, or to memory owned by // `literal`. - void Insert(string name, const Literal& literal, int64 alignment); + void Insert(string name, const LiteralSlice& literal, int64 alignment); // Find the constant with name `name` in this constant pool. If there isn't // such constant, return nullptr. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index f209a69e3cd..b560b7531c0 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -24,8 +24,25 @@ limitations under the License. namespace xla { namespace cpu { +int64 GetMinimumAlignmentForArray( + const Shape& shape, const TargetMachineFeatures& target_machine_features) { + CHECK(ShapeUtil::IsArray(shape)); + CHECK(!LayoutUtil::HasLayout(shape) || LayoutUtil::IsDense(shape.layout())); + + // We don't require a layout to be set on `shape`. 
This only works on CPU + // because we don't pad our tensors or otherwise have complicated data tiling + // schemes. + + int64 allocation_size_bytes = + ShapeUtil::ElementsIn(shape) * + ShapeUtil::ByteSizeOfPrimitiveType(shape.element_type()); + return target_machine_features.minimum_alignment_for_allocation( + allocation_size_bytes); +} + bool PotentiallyImplementedAsEigenConvolution( - const HloInstruction& convolution) { + const HloInstruction& convolution, + const TargetMachineFeatures& target_machine_features) { // The following conditions are necessary (but not sufficient) for // implementing `convolution` with Eigen convolution: // - the input and kernel have a non-zero number of elements. @@ -35,6 +52,18 @@ bool PotentiallyImplementedAsEigenConvolution( // To be sufficient, certain layout constraints need to be satisfied as well. const Shape& input_shape = convolution.operand(0)->shape(); const Shape& kernel_shape = convolution.operand(1)->shape(); + const Shape& output_shape = convolution.shape(); + + auto is_aligned = [&](const Shape& shape) { + return GetMinimumAlignmentForArray(shape, target_machine_features) >= + TargetMachineFeatures::kEigenExpectedTensorAlignment; + }; + + if (!is_aligned(input_shape) || !is_aligned(kernel_shape) || + !is_aligned(output_shape)) { + return false; + } + if (ShapeUtil::HasZeroElements(input_shape) || ShapeUtil::HasZeroElements(kernel_shape)) { return false; @@ -71,7 +100,6 @@ bool PotentiallyImplementedAsEigenConvolution( } } - const Shape& output_shape = convolution.shape(); return dnums.input_batch_dimension() == 0 && dnums.input_feature_dimension() == input_shape.dimensions_size() - 1 && dnums.output_batch_dimension() == 0 && diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index 34b20039169..68fbc7caaa9 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -17,13 +17,20 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { namespace cpu { bool PotentiallyImplementedAsEigenConvolution( - const HloInstruction& convolution); + const HloInstruction& convolution, + const TargetMachineFeatures& target_machine_features); + +// Computes the minimum alignment guaranteed for a tensor of shape `shape` on +// the target machine. +int64 GetMinimumAlignmentForArray( + const Shape& shape, const TargetMachineFeatures& target_machine_features); // Dynamic loop bounds are specified as an array of dimension index // [start, limit) pairs of ir values (one for each partitioned outer dimension). diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc index 215f48c4cc1..abb2471e6ae 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" @@ -39,7 +40,12 @@ ENTRY Conv { HloComputation* entry_computation = module->entry_computation(); HloInstruction* conv_instr = entry_computation->root_instruction(); - EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution(*conv_instr)); + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features( + [](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }); + EXPECT_FALSE(cpu::PotentiallyImplementedAsEigenConvolution( + *conv_instr, target_machine_features)); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 3405277d449..e4d11f27ad9 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -83,7 +83,7 @@ IrEmitter::IrEmitter( llvm::Module* llvm_module, std::unordered_map instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - llvm::TargetMachine* target_machine, + const TargetMachineFeatures* target_machine_features, ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), @@ -93,10 +93,8 @@ IrEmitter::IrEmitter( computation_to_profile_idx_(std::move(computation_to_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), - parallel_cpu_backend_( - options::CpuParallelBackendRequested(hlo_module_config_)), is_top_level_computation_(false), - target_machine_features_(target_machine), + target_machine_features_(*target_machine_features), external_constant_pool_(external_constant_pool) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() @@ -162,10 +160,8 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } -Status IrEmitter::HandleConstant(HloInstruction* constant) { - VLOG(2) << "HandleConstant: " << constant->ToString(); - const Literal& literal = constant->literal(); - llvm::GlobalVariable* global_for_const; +llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { + llvm::GlobalVariable* result; // We avoid creating large constants in the LLVM IR since LLVM is not // efficient for large constant arrays. 
We still emit "small enough" constant @@ -176,27 +172,42 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) { ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { string global_name = tensorflow::strings::StrCat( "constant_global_", external_global_constant_counter_++); - global_for_const = new llvm::GlobalVariable( + result = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/IrShapeType(literal.shape()), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::ExternalLinkage, /*Initializer=*/nullptr, /*Name=*/AsStringRef(global_name)); - global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + result->setAlignment(MinimumAlignmentForShape(literal.shape())); external_constant_pool_->Insert(global_name, literal, MinimumAlignmentForShape(literal.shape())); } else { llvm::Constant* initializer = llvm_ir::ConvertLiteralToIrConstant(literal, module_); - global_for_const = new llvm::GlobalVariable( + result = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/initializer->getType(), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/initializer, /*Name=*/""); - global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + result->setAlignment(MinimumAlignmentForShape(literal.shape())); + } + return result; +} + +Status IrEmitter::HandleConstant(HloInstruction* constant) { + VLOG(2) << "HandleConstant: " << constant->ToString(); + const Literal& literal = constant->literal(); + llvm::GlobalVariable* global_for_const; + + auto it = emitted_literals_.find(&literal); + if (it != emitted_literals_.end()) { + global_for_const = it->second; + } else { + global_for_const = EmitGlobalForLiteral(literal); + emitted_literals_[&literal] = global_for_const; } emitted_value_[constant] = global_for_const; VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); @@ -216,32 +227,6 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { } } -// Calculate the alignment of a buffer with a particular size. -int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { - // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on - // 64-bit platforms. TCMalloc returns a pointer with alignment 8 for - // allocations smaller than kMallocAlignmentThreshold bytes and at least - // alignment 16 for allocations greater than or equal to - // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound - // by explicitly allocating the memory with posix_memalign. This is - // complicated by our desire to allow parameter buffers created by clients to - // be consumed directly by the JIT. - if (buffer_size == 0) { - // No need to align empty buffers. - return 1; - } - - const int64 kMallocAlignmentThreshold = 512; - - int pointer_size = module_->getDataLayout().getPointerSize(); - int buffer_alignment = buffer_size >= kMallocAlignmentThreshold - ? 2 * pointer_size - : pointer_size; - DCHECK_GT(buffer_alignment, 0); - - return buffer_alignment; -} - // Calculate the alignment of a buffer allocated for a given primitive type. 
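// Aside on the HandleConstant change above (a simplified sketch; standard
// containers stand in for tensorflow::gtl::FlatMap and std::string stands in
// for xla::Literal, all names hypothetical): structurally equal constants now
// share a single emitted global, keyed by deep hash/equality over the
// pointed-to value.
#include <functional>
#include <string>
#include <unordered_map>

struct SketchPtrHash {
  size_t operator()(const std::string* v) const {
    return std::hash<std::string>()(*v);
  }
};
struct SketchPtrEq {
  bool operator()(const std::string* a, const std::string* b) const {
    return *a == *b;
  }
};

class SketchGlobalCache {
 public:
  // Returns the cached handle for an equal value, calling `emit` at most once
  // per distinct value.
  int GetOrEmit(const std::string* value,
                const std::function<int(const std::string&)>& emit) {
    auto it = cache_.find(value);
    if (it != cache_.end()) return it->second;
    const int handle = emit(*value);
    cache_[value] = handle;
    return handle;
  }

 private:
  std::unordered_map<const std::string*, int, SketchPtrHash, SketchPtrEq>
      cache_;
};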
int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); @@ -266,7 +251,7 @@ int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { DCHECK_GE(buffer_size, 0); DCHECK_LE(buffer_size, SIZE_MAX); - return MinimumAlignmentForBufferSize(buffer_size); + return target_machine_features_.minimum_alignment_for_allocation(buffer_size); } void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, @@ -279,7 +264,8 @@ void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size) { - int alignment = MinimumAlignmentForBufferSize(buffer_size); + int alignment = + target_machine_features_.minimum_alignment_for_allocation(buffer_size); if (alignment > 1) { llvm_ir::SetAlignmentMetadataForLoad(load, alignment); } @@ -519,7 +505,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32, BF16})); + /*supported_types=*/{F32, BF16, S32})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { @@ -816,13 +802,6 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { "Dot with multiple contracting dimensions not implemented."); } - if (dnums.lhs_contracting_dimensions(0) != - std::min(lhs->shape().dimensions_size() - 1, 1) || - dnums.rhs_contracting_dimensions(0) != 0) { - return Unimplemented( - "Dot with non-standard contracting dimensions not implemented."); - } - llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -839,8 +818,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return DotOpEmitter::EmitDotOperation( - *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, /*addend_array=*/nullptr, + *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, target_machine_features_); } @@ -856,7 +834,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); - if (PotentiallyImplementedAsEigenConvolution(*convolution)) { + // TODO(tonywy): Add PotentiallyImplementedAsMKLCovolution to support + // different data layouts. + if (PotentiallyImplementedAsEigenConvolution(*convolution, + target_machine_features_)) { const Shape& lhs_shape = lhs->shape(); const Shape& rhs_shape = rhs->shape(); const Shape& convolution_shape = convolution->shape(); @@ -944,16 +925,26 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - bool multi_threaded_eigen = + bool multi_threaded = hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + bool use_mkl_dnn = + hlo_module_config_.debug_options().xla_cpu_use_mkl_dnn(); + + // TODO(b/78639006) Singlethread MKL conv2d is not implemented due to the + // potential race condition by setting the omp_num_threads. const char* fn_name = primitive_type == F16 - ? (multi_threaded_eigen + ? (multi_threaded ? 
runtime::kEigenConvF16SymbolName : runtime::kEigenSingleThreadedConvF16SymbolName) - : (multi_threaded_eigen - ? runtime::kEigenConvF32SymbolName + : (multi_threaded + ? (use_mkl_dnn ? runtime::kMKLConvF32SymbolName + : runtime::kEigenConvF32SymbolName) : runtime::kEigenSingleThreadedConvF32SymbolName); + if (!multi_threaded && use_mkl_dnn) { + LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded " + "conv2d function."; + } llvm::Function* conv_func = llvm::cast( module_->getOrInsertFunction(fn_name, conv_type)); conv_func->setCallingConv(llvm::CallingConv::C); @@ -1012,12 +1003,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // We will accumulate the products into this sum to calculate // the output entry at the given index. PrimitiveType lhs_element_type = lhs->shape().element_type(); + llvm::Type* lhs_llvm_type = + llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_); llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_), - "convolution_sum_address", &ir_builder_, + lhs_llvm_type, "convolution_sum_address", &ir_builder_, MinimumAlignmentForPrimitiveType(lhs_element_type)); - ir_builder_.CreateStore( - llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), sum_address); + llvm::Value* constant_zero = + llvm::Constant::getNullValue(lhs_llvm_type); + ir_builder_.CreateStore(constant_zero, sum_address); llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &ir_builder_); std::vector kernel_spatial(num_spatial_dims); @@ -1171,7 +1164,13 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); - const char* fn_name = runtime::kEigenFftSymbolName; + + bool multi_threaded_eigen = + hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); + const char* fn_name = multi_threaded_eigen + ? 
runtime::kEigenFftSymbolName + : runtime::kEigenSingleThreadedFftSymbolName; + llvm::Function* fft_func = llvm::cast( module_->getOrInsertFunction(fn_name, fft_type)); fft_func->setCallingConv(llvm::CallingConv::C); @@ -2063,44 +2062,7 @@ static const HloInstruction* StripTranspose(const HloInstruction& hlo) { Status IrEmitter::HandleFusion(HloInstruction* fusion) { auto* root = fusion->fused_expression_root(); - if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) { - DCHECK(root->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - fusion->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - fusion->operand(rhs_parameter->parameter_number()); - - TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*root, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F16, F32})); - - llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); - llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); - - Shape target_shape = fusion->shape(); - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); - llvm_ir::IrArray target_array = GetIrArrayFor(fusion); - VLOG(2) << "HandleFusion kTransposeDot: "; - VLOG(2) << " lhs operand: " - << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); - VLOG(2) << " rhs operand: " - << llvm_ir::DumpToString(*rhs_array.GetBasePointer()); - VLOG(2) << " target: " - << llvm_ir::DumpToString(*target_array.GetBasePointer()); - - // Dot operation is complicated so we delegate to a helper class. - TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *root, root->operand(0)->IsRank2Transpose(), - root->operand(1)->IsRank2Transpose(), target_array, lhs_array, - rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), - &ir_builder_, hlo_module_config_, target_machine_features_)); - return Status::OK(); - } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, - assignment_)) { + if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); @@ -2143,9 +2105,9 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { GetIrArrayFor(fusion->operand(addend_param_number))); TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(), - &ir_builder_, hlo_module_config_, target_machine_features_)); + *dot, target_array, lhs_array, rhs_array, &addend_array, + GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, + target_machine_features_)); return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); @@ -2163,8 +2125,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); - if (!computation->root_instruction()->outer_dimension_partitions().empty() && - !parallel_cpu_backend_) { + if (!computation->root_instruction()->outer_dimension_partitions().empty()) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. 
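// Aside on the conv2d runtime symbol selection earlier in this hunk (condensed
// sketch; the strings returned here are descriptive placeholders, not the real
// exported symbol constants): MKL-DNN is only selected for the multi-threaded
// F32 path, and a single-threaded request with MKL-DNN enabled falls back to
// Eigen with a warning.
inline const char* SketchConvRuntime(bool is_f16, bool multi_threaded,
                                     bool use_mkl_dnn) {
  if (is_f16) {
    return multi_threaded ? "eigen_conv_f16"
                          : "eigen_single_threaded_conv_f16";
  }
  if (multi_threaded) {
    return use_mkl_dnn ? "mkl_conv_f32" : "eigen_conv_f32";
  }
  return "eigen_single_threaded_conv_f32";
}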
std::vector call_args = GetArrayFunctionCallArguments( @@ -2541,8 +2502,12 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { // nothing to do since the result was already written directly into the output // buffer. VLOG(2) << "FinishVisit root: " << root->ToString(); - llvm::Value* root_value = GetEmittedValueFor(root); - VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value); + if (root->opcode() == HloOpcode::kOutfeed) { + VLOG(2) << " outfeed with value: " + << llvm_ir::DumpToString(*GetEmittedValueFor(root->operand(0))); + } else { + VLOG(2) << " value: " << llvm_ir::DumpToString(*GetEmittedValueFor(root)); + } auto record_complete_computation = [&](llvm::Value* prof_counter) { if (prof_counter) { @@ -2550,22 +2515,6 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { } }; - // For the parallel cpu backend, we record the total for each embedded - // computation callee with its caller kCall HLO. - if (parallel_cpu_backend_ && is_top_level_computation_) { - auto* computation = root->parent(); - auto* entry_computation = computation->parent()->entry_computation(); - if (computation != entry_computation) { - for (HloInstruction* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCall && - instruction->to_apply()->root_instruction() == root) { - record_complete_computation(GetProfileCounterFor(*instruction)); - return Status::OK(); - } - } - } - } - // For the entry computation this increment is cumulative of embedded // computations since it includes cycles spent in computations invoked by // While, Call etc. diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 50944025149..f49cfc1dc37 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -76,7 +76,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { instruction_to_profile_idx, std::unordered_map computation_to_profile_idx, - llvm::TargetMachine* target_machine, + const TargetMachineFeatures* target_machine, ExternalConstantPool* external_constant_pool); ~IrEmitter() override; @@ -514,9 +514,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Calculate the alignment of a buffer allocated for a given primitive type. int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type); - // Calculate the alignment of a buffer with a particular size. - int MinimumAlignmentForBufferSize(int64 buffer_size); - // Returns the number of bytes within the shape. 
int64 ByteSizeOf(const Shape& shape) const; @@ -530,17 +527,31 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); - const HloModuleConfig& hlo_module_config_; + llvm::GlobalVariable* EmitGlobalForLiteral(const Literal& literal); - const bool parallel_cpu_backend_; + const HloModuleConfig& hlo_module_config_; bool is_top_level_computation_; - TargetMachineFeatures target_machine_features_; + const TargetMachineFeatures& target_machine_features_; int64 external_global_constant_counter_ = 0; ExternalConstantPool* external_constant_pool_; + struct LiteralPtrHashFunctor { + size_t operator()(const Literal* literal) const { return literal->Hash(); } + }; + + struct LiteralPtrEqualityFunctor { + bool operator()(const Literal* lhs, const Literal* rhs) const { + return *lhs == *rhs; + } + }; + + tensorflow::gtl::FlatMap + emitted_literals_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index 557aa4a6bfc..2e55181eed8 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -33,8 +33,8 @@ namespace cpu { // emitters for function and function argument access. // The llvm::Function is created with the standard function signature // used in the XLA CPU backend (see ir_function.cc for argument details). -// In addtion IrFunction saves the callers IR insert point during contruction, -// and restores it after desctruction. +// In addition IrFunction saves the callers IR insert point during construction, +// and restores it after destruction. // // Example usage: // diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc deleted file mode 100644 index 07a9f0efcb6..00000000000 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ /dev/null @@ -1,531 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" -#include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" - -namespace se = ::perftools::gputools; - -namespace xla { -namespace cpu { - -ParallelCpuExecutable::ParallelCpuExecutable( - std::unique_ptr jit, - std::unique_ptr assignment, - std::unique_ptr hlo_module, - std::unique_ptr> function_names, - std::unordered_map> - aligned_constants, - std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map) - : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), - std::move(hlo_profile_index_map)), - jit_(std::move(jit)), - assignment_(std::move(assignment)), - function_names_(std::move(function_names)), - aligned_constants_(std::move(aligned_constants)) {} - -// Type of the computation function we expect in the JIT. -using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, - int64*, int64*); - -// Given a pointer to an output buffer (following the CPU JIT calling -// conventions), mark addresses that are "live". The initial pointer itself is -// trivially live. If the shape of the buffer is a tuple, this analysis looks -// into the tuple's elements and marks them live as well (since tuples keep -// pointers to buffers) and also works recursively. -// address is an in-memory buffer address that contains some runtime XLA object. -// shape is its shape. marked_addresses is the set of live addresses to -// populate. -static void MarkLiveAddressesInOutput( - const void* address, const Shape& shape, - std::unordered_set* marked_addresses) { - marked_addresses->insert(address); - const uintptr_t* address_buffer = static_cast(address); - if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const uintptr_t* element_address = address_buffer + i; - const void* element = reinterpret_cast(*element_address); - MarkLiveAddressesInOutput( - element, ShapeUtil::GetTupleElementShape(shape, i), marked_addresses); - } - } -} - -namespace { - -// Executor manages the concurrent execution of 'functions' for instructions -// in 'pending' on 'thread_pool' (storing resulting data in 'results'). 
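// Aside on the Executor being removed here (a heavily simplified sketch of its
// scheduling idea; integers stand in for HLO instructions and result buffers,
// all names hypothetical): repeatedly launch every pending node whose operands
// already have results, then record each completion so its users become ready.
#include <algorithm>
#include <functional>
#include <list>
#include <map>
#include <vector>

struct SketchNode {
  int id;
  std::vector<int> operands;  // ids of the producers this node waits on
};

inline void SketchRun(std::list<SketchNode> pending,
                      const std::function<int(const SketchNode&)>& execute,
                      std::map<int, int>* results) {
  while (!pending.empty()) {
    bool progress = false;
    for (auto it = pending.begin(); it != pending.end();) {
      const bool ready =
          std::all_of(it->operands.begin(), it->operands.end(),
                      [&](int op) { return results->count(op) != 0; });
      if (ready) {
        // The real Executor dispatched this to a thread pool and harvested
        // completions from a condvar-protected queue; the sketch runs inline.
        (*results)[it->id] = execute(*it);
        progress = true;
        it = pending.erase(it);
      } else {
        ++it;
      }
    }
    if (!progress) break;  // sketch-level guard against unsatisfiable deps
  }
}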
-class Executor { - public: - Executor(const HloInstructionMap& functions, - const ServiceExecutableRunOptions* run_options, - std::list* pending, - HloInstructionMap* results, void** temps_array, - int64* profile_counters_array, const BufferAssignment* assignment) - : functions_(functions), - run_options_(run_options), - pending_(pending), - results_(results), - temps_array_(temps_array), - profile_counters_array_(profile_counters_array), - thread_pool_(CHECK_NOTNULL(run_options_->xla_intra_op_thread_pool())), - assignment_(assignment) {} - - // Executes pending list of instructions on thread pool. - // Returns OK status on success, error status otherwise. - Status Run(); - - private: - // Schedules a parallel invocation of compute function for 'instruction' on - // 'thread_pool_', storing result in 'result_buffer'. - // If 'partition_buffers' is non-null, parallel task will be invoked on - // per-dimension partition [start, limit) values stored in - // 'partition_buffers'. - void Schedule(HloInstruction* instruction, int64* partition_buffers, - void* result_buffer); - - // Returns true if 'instruction' has been assigned parallel tasks (returns - // false otherwise). - bool HasParallelTasks(HloInstruction* instruction); - - // Returns in 'partition_buffers' the partition [size, limit) for each - // dimension. - int64* GetPartitionBuffers( - const std::vector>& partition); - - // Returns array of result buffers for all operands in 'instruction'. - const void** GetOperandBuffers(HloInstruction* instruction); - - // Arguments passed into Executor. - const HloInstructionMap& functions_; - const ServiceExecutableRunOptions* run_options_; - std::list* pending_; - HloInstructionMap* results_; - void** temps_array_; - int64* profile_counters_array_; - tensorflow::thread::ThreadPool* thread_pool_; - const BufferAssignment* assignment_; - - // Members used to manage instruction execution. - tensorflow::mutex completion_queue_lock_; - tensorflow::condition_variable completion_queue_cv_; - std::deque completion_queue_; - int64 instructions_in_flight_ = 0; - std::unordered_map tasks_in_flight_; -}; - -Status Executor::Run() { - while (!pending_->empty() || instructions_in_flight_ > 0) { - auto pending_it = pending_->begin(); - while (pending_it != pending_->end()) { - HloInstruction* instruction = *pending_it; - // Skip pending instructions whose operands aren't ready. - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), - [&](HloInstruction* operand) { - return !ContainsKey(*results_, operand); - })) { - ++pending_it; - continue; - } - - // Get 'result_buffer' reference to result buffer for 'instruction'. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelSlice(instruction)); - void* result_buffer = - static_cast(temps_array_[result_slice.index()]) + - result_slice.offset(); - - if (HasParallelTasks(instruction)) { - // 'instruction' has been assigned parallel task partitions. - CHECK_EQ(HloOpcode::kCall, instruction->opcode()); - HloInstruction* root = instruction->to_apply()->root_instruction(); - - // Create ShapePartitionIterator to iterate through all outer dimension - // partitions of 'instruction'. - ShapePartitionIterator partition_iterator( - root->shape(), root->outer_dimension_partitions()); - - const int64 partition_count = - partition_iterator.GetTotalPartitionCount(); - - // Record total parallel task count for 'instruction' before dispatch. 
- { - tensorflow::mutex_lock l(completion_queue_lock_); - tasks_in_flight_.insert(std::make_pair(instruction, partition_count)); - VLOG(2) << "Schedule PARALLEL" - << " instruction: " << instruction->name() - << " instruction.callee: " - << instruction->to_apply()->root_instruction()->name() - << " partition_count: " << partition_count; - } - - for (int64 i = 0; i < partition_count; ++i) { - // Get partition [start, limit) for each dimension. - auto partition_buffers = - GetPartitionBuffers(partition_iterator.GetPartition(i)); - Schedule(instruction, partition_buffers, result_buffer); - } - - } else { - // Set tasks in-flight to '1' for sequential instruction execution. - { - tensorflow::mutex_lock l(completion_queue_lock_); - tasks_in_flight_.insert(std::make_pair(instruction, 1)); - VLOG(2) << "Schedule SEQUENTIAL" - << " instruction: " << instruction->name() - << " instruction.callee: " - << instruction->to_apply()->root_instruction()->name(); - } - Schedule(instruction, nullptr, result_buffer); - } - - ++instructions_in_flight_; - pending_it = pending_->erase(pending_it); - } - // Wait for a completed HLO instruction to be present in the queue. We will - // pop it out of the queue and make the result available to its users. - HloInstruction* instruction; - do { - tensorflow::mutex_lock l(completion_queue_lock_); - if (completion_queue_.empty()) { - completion_queue_cv_.wait(l); - } - if (!completion_queue_.empty()) { - instruction = completion_queue_.front(); - completion_queue_.pop_front(); - break; - } - } while (true); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelSlice(instruction)); - void* result_buffer = - static_cast(temps_array_[result_slice.index()]) + - result_slice.offset(); - InsertOrDie(results_, instruction, result_buffer); - --instructions_in_flight_; - } - return Status::OK(); -} - -void Executor::Schedule(HloInstruction* instruction, int64* partition_buffers, - void* result_buffer) { - // The thread pool entry takes ownership of |operand_buffers|. - auto operand_buffers = GetOperandBuffers(instruction); - - auto function = FindOrDie(functions_, instruction); - const auto* exec_run_options = &run_options_->run_options(); - thread_pool_->Schedule([this, instruction, result_buffer, operand_buffers, - partition_buffers, exec_run_options, function]() { - function(result_buffer, exec_run_options, operand_buffers, temps_array_, - partition_buffers, profile_counters_array_); - - delete[] operand_buffers; - delete[] partition_buffers; - // Push the completed HLO instruction on the queue, the main - // thread will pop it off and potentially launch more work which - // uses the result. - // TODO(b/27458679) Consider alternative task scheduling and synchronization - // schemes. For example, we could avoid the overhead associate with the - // condvar here if the thread just dequed the next instruction to execute - // on completion. - { - tensorflow::mutex_lock l(completion_queue_lock_); - // Decrement in-flight task count for this completion. - if (--FindOrDie(tasks_in_flight_, instruction) == 0) { - completion_queue_.push_back(instruction); - completion_queue_cv_.notify_all(); - tasks_in_flight_.erase(instruction); - } - } - }); -} - -int64* Executor::GetPartitionBuffers( - const std::vector>& partition) { - // Return in 'partition_buffers' partition [size, limit) for each dimension. 
- auto partition_buffers = new int64[partition.size() * 2]; - for (int i = 0; i < partition.size(); ++i) { - partition_buffers[2 * i + 0] = partition[i].first; - partition_buffers[2 * i + 1] = partition[i].first + partition[i].second; - } - return partition_buffers; -} - -bool Executor::HasParallelTasks(HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kCall && - !instruction->to_apply() - ->root_instruction() - ->outer_dimension_partitions() - .empty(); -} - -const void** Executor::GetOperandBuffers(HloInstruction* instruction) { - // We cannot use a move-only RAII type like std::unique_ptr because the - // list of operands is allocated on the main thread and transferred to the - // worker via the lambda passed to enqueue_function. In order for the - // lambda to take ownership, we would need to use generalized lambda - // capture which is a feature new to C++14. - // TODO(b/27458679) Avoid dynamic allocations in Executor. - auto operand_buffers = new const void*[instruction->operand_count()]; - std::transform(instruction->operands().begin(), instruction->operands().end(), - operand_buffers, [this](HloInstruction* operand) { - return FindOrDie(*results_, operand); - }); - return operand_buffers; -} - -} // namespace - -Status ParallelCpuExecutable::AllocateBuffers( - DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers) { - CHECK_EQ(buffers->size(), assignment_->Allocations().size()); - VLOG(3) << "Allocating " << assignment_->Allocations().size() - << " allocations for module " << module().name(); - for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); - ++i) { - auto& allocation = assignment_->GetAllocation(i); - - VLOG(3) << allocation.ToString(); - - if (allocation.is_entry_computation_parameter()) { - VLOG(3) << "allocation #" << i << " is a parameter"; - continue; - } - - if (allocation.is_thread_local()) { - VLOG(3) << "buffer #" << i << " is thread-local"; - continue; - } - - int64 buffer_size = allocation.size(); - if (!(*buffers)[i].is_null()) { - VLOG(3) << "buffer #" << i - << " is in the preallocated result ShapedBuffer"; - } else { - TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate( - device_ordinal, buffer_size)); - - VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes [" - << (*buffers)[i].opaque() << "]"; - } - - // Since the output buffer and all the temporary buffers were written into - // by the JITed code, msan has no way of knowing their memory was - // initialized. Mark them initialized so that msan doesn't flag loads from - // these buffers. - TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size); - } - - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - VLOG(3) << "result index: " << result_slice.index(); - - return Status::OK(); -} - -Status ParallelCpuExecutable::ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile) { - // Allocate profiling counters for each hlo instruction that we would like to - // profile. 
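// Aside on the GetOperandBuffers comment above (hypothetical sketch, not the
// removed code): with C++14 generalized lambda capture, ownership of the
// operand-buffer array could be moved into the task closure, removing the
// manual delete[] -- provided the scheduler accepts a move-only callable
// (std::function does not).
#include <memory>
#include <utility>

template <typename F>
void SketchRunOnWorker(F&& task) {  // stands in for a thread-pool Schedule()
  std::forward<F>(task)();
}

inline void SketchDispatch(int operand_count) {
  auto operand_buffers = std::make_unique<const void*[]>(operand_count);
  SketchRunOnWorker([buffers = std::move(operand_buffers)]() {
    // The closure owns the array; it is released when the task is destroyed.
  });
}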
- std::vector* profile_counters = nullptr; - if (hlo_execution_profile) { - profile_counters = hlo_execution_profile->mutable_profile_counters(); - } - - std::vector buffer_pointers; - buffer_pointers.reserve(buffers.size()); - for (auto device_allocation : buffers) { - buffer_pointers.push_back(device_allocation.opaque()); - } - - // Resolve functions for all the HLO instructions ahead of time. - HloInstructionMap functions; - for (auto& entry : *function_names_) { - tensorflow::mutex_lock lock(jit_mutex_); - HloInstruction* instruction = entry.first; - llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry.second); - TF_RET_CHECK(sym); - InsertOrDie( - &functions, instruction, - reinterpret_cast(cantFail(sym.getAddress()))); - } - - // Map containing pointers to result buffers for each instruction. - HloInstructionMap results; - - uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - - std::list pending; - - // Call the function for each HLO instruction in topological order. - const HloComputation& entry_computation = *module().entry_computation(); - for (auto* instruction : entry_computation.MakeInstructionPostOrder()) { - // Parameters and constants have no functions associated with them. Instead - // just copy the existing buffer into the map containing instruction - // results.. - if (instruction->opcode() == HloOpcode::kParameter) { - InsertOrDie( - &results, instruction, - arguments[instruction->parameter_number()]->root_buffer().opaque()); - } else if (instruction->opcode() == HloOpcode::kConstant) { - unsigned char* aligned_data = - FindOrDie(aligned_constants_, instruction).get(); - InsertOrDie(&results, instruction, aligned_data); - } else { - TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall); - pending.push_back(instruction); - } - } - - // TODO(b/27458679) Manage scheduling based on in-flight concurrency limits. - // For example, if we expect a library conv/matmul call to run at max - // concurrency, we should not dispatch runnable instructions until the - // library call is finished (to avoid expensive cache invalidation). - Executor executor( - functions, run_options, &pending, &results, buffer_pointers.data(), - profile_counters ? 
profile_counters->data() : nullptr, assignment_.get()); - - TF_RETURN_IF_ERROR(executor.Run()); - - uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - - { - tensorflow::mutex_lock lock(mutex_); - double nanoseconds = (end_micros - start_micros) * 1000.0; - execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - } - - return Status::OK(); -} - -StatusOr> ParallelCpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - if (GetRootPointsToSet().IsAmbiguous()) { - return Unimplemented("Points-to set of root instruction is ambiguous"); - } - - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); - - auto result_buffer = MakeUnique( - /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), - stream->parent()->platform(), stream->parent()->device_ordinal()); - - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - - TF_RETURN_IF_ERROR(ExecuteComputeFunctions(run_options, arguments, buffers, - hlo_execution_profile)); - - // Copy DeviceMemoryBase values which into the respective location in - // ShapedBuffer which is returned to the caller. - std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( - [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { - const auto& sources = this->GetRootPointsToSet().element(index); - - // The points to set is unambiguous so the set should be a singleton. - CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer such as a - // tuple element. The source instruction should have a non-parameter - // buffer assigned. - TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice(src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *device_memory = buffer; - buffers_in_result[buffer_index] = true; - return Status::OK(); - })); - - // Free all buffers not in the result. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate( - stream->parent()->device_ordinal(), &alloc)); - } - } - - return std::move(result_buffer); -} - -StatusOr> -ParallelCpuExecutable::ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { - // TODO(b/30671675): Implement asynchronous execution mode. 
- return Unimplemented( - "Asynchronous execution on stream is not yet supported on CPU."); -} - -const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const { - return assignment_->points_to_analysis().GetPointsToSet( - module().entry_computation()->root_instruction()); -} - -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h deleted file mode 100644 index 87c0a3df458..00000000000 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/buffer_assignment.h" -#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" -#include "tensorflow/compiler/xla/service/device_memory_allocator.h" -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/platform/thread_annotations.h" - -namespace xla { -namespace cpu { - -// CPU-targeting parallel implementation of the XLA Executable interface. -// -// Wraps a JIT-ed object that can be executed "on device". We JIT for the host -// architecture, so JIT-ed code and host code share the same ABI. -class ParallelCpuExecutable : public Executable { - public: - ParallelCpuExecutable( - std::unique_ptr jit, - std::unique_ptr assignment, - std::unique_ptr hlo_module, - std::unique_ptr> function_names, - std::unordered_map> - aligned_constants, - std::unique_ptr hlo_profile_printer_data, - std::unique_ptr hlo_profile_index_map); - ~ParallelCpuExecutable() override {} - - StatusOr> ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) override; - - StatusOr> ExecuteAsyncOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) override; - - // This should be called after set_ir_module_string. 
- const string& ir_module_string() const { return ir_module_string_; } - - void set_ir_module_string(const string& ir_module_string) { - ir_module_string_ = ir_module_string; - } - - static int64 ShapeSizeBytes(const Shape& shape) { - // On the cpu, opaques are pointers. - if (ShapeUtil::IsOpaque(shape)) { - return sizeof(void*); - } - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); - } - - private: - // Allocate buffers required for execution and assign them to the elements of - // "buffers". "buffers" should be sized to the number of buffers in buffer - // assignment. Each vector element corresponds to a particular Index. If - // a vector element already contains a non-null DeviceMemoryBase, then no - // buffer is assigned for this element. - Status AllocateBuffers( - DeviceMemoryAllocator* memory_allocator, int device_ordinal, - std::vector* buffers); - - // Calls the generated functions in 'function_names_', performing the - // computation with the given arguments using the supplied buffers. - Status ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice - buffers, - HloExecutionProfile* hlo_execution_profile); - - // Returns the points-to set of the root instruction of the entry - // computation. Uses points-to analysis from buffer assignment. - const PointsToSet& GetRootPointsToSet() const; - - // The JIT containing compiled modules. - tensorflow::mutex jit_mutex_; - const std::unique_ptr jit_ GUARDED_BY(jit_mutex_); - - // Buffer assignment for the buffers we need to allocate. - const std::unique_ptr assignment_; - - // The LLVM IR, in string format, of the unoptimized module generated for this - // ParallelCpuExecutable. We save a string instead of an llvm::Module* because - // leaving llvm::Module* in a singleton can cause the heap checker to emit - // false positives. - string ir_module_string_; - - // Map containing the JITted function names for each HLO instruction. - const std::unique_ptr> function_names_; - - // Map from HLO Constant instructions to a pointer to their literal data. - // The data stored in the protocol buffer might be insufficiently aligned, - // we create a sufficiently aligned copy and store it in this map. - const std::unordered_map> - aligned_constants_; - - TF_DISALLOW_COPY_AND_ASSIGN(ParallelCpuExecutable); -}; - -} // namespace cpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index fb28280fade..63d0f7b95c7 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -104,7 +104,9 @@ class DefaultCostModel : public ParallelCostModel { ParallelTaskAssignment::ParallelTaskAssignment( const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) { + const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module, + const TargetMachineFeatures* target_machine_features) + : target_machine_features_(*target_machine_features) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. 
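// Aside on the constructor change above (a minimal sketch of the injection
// pattern now used here and in IrEmitter; names hypothetical): callers hand in
// a non-owning pointer and the class keeps a reference, so the features object
// must outlive the pass that uses it.
struct SketchFeatures {
  long long min_alignment_for_allocation;
};

class SketchPass {
 public:
  explicit SketchPass(const SketchFeatures* features) : features_(*features) {}
  long long MinAlignment() const {
    return features_.min_alignment_for_allocation;
  }

 private:
  const SketchFeatures& features_;  // non-owning; caller guarantees lifetime
};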
auto cost_analysis = MakeUnique(shape_size); @@ -127,7 +129,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( // Currently, we do not assign parallel tasks to instructions with at least // one of the following properties: // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). - // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Emit custom loops (kSelectAndScatter). // *) Operations that are not thread safe (like infeed and rng). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. @@ -139,8 +141,10 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || (opcode == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) || - PotentiallyImplementedAsEigenDot(*instruction) || + PotentiallyImplementedAsEigenConvolution(*instruction, + target_machine_features_)) || + PotentiallyImplementedAsEigenDot(*instruction, + target_machine_features_) || (opcode == HloOpcode::kFusion && instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || ShapeUtil::IsTuple(instruction->shape())) { @@ -231,7 +235,8 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( void ParallelTaskAssigner::ComputeTargetParallelTasks( HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { ParallelTaskAssignment parallel_task_assignment(max_parallelism_, - shape_size_function_, module); + shape_size_function_, module, + &target_machine_features_); // Compute parallel task counts for all instructions in 'module'. for (auto* computation : module->computations()) { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 7140dabe516..8becc8fa234 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -39,7 +40,8 @@ class ParallelTaskAssignment { // 'module': the containing HloModule. ParallelTaskAssignment(const int64 max_parallelism, const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module); + HloModule* module, + const TargetMachineFeatures* target_machine_features); ~ParallelTaskAssignment() {} // Computes and returns the target parallel task count for 'instruction'. @@ -47,6 +49,7 @@ class ParallelTaskAssignment { private: std::unique_ptr cost_model_; + const TargetMachineFeatures& target_machine_features_; }; // ParallelTaskAssigner computes target parallel task counts for all HLOs @@ -63,8 +66,11 @@ class ParallelTaskAssigner : public HloPassInterface { // 'shape_size': shape size function used by HloCostAnalysis during parallel // task assignment. 
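+  // 'target_machine_features': target machine features consulted when deciding
+  // which instructions would be lowered to Eigen/MKL runtime calls and are
+  // therefore not assigned parallel tasks.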
ParallelTaskAssigner(const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size) - : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {} + const HloCostAnalysis::ShapeSizeFunction& shape_size, + const TargetMachineFeatures* target_machine_features) + : max_parallelism_(max_parallelism), + shape_size_function_(shape_size), + target_machine_features_(*target_machine_features) {} ~ParallelTaskAssigner() override {} tensorflow::StringPiece name() const override { @@ -94,6 +100,7 @@ class ParallelTaskAssigner : public HloPassInterface { int64 max_parallelism_; HloCostAnalysis::ShapeSizeFunction shape_size_function_; + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index 13eb75a5721..fc2efbaf9a2 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -31,6 +32,19 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { // Use any value larger than 2 since we only test whether a module is // parallelized or not const int max_parallelism_ = 10; + + cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; + + ParallelTaskAssignmentTest() + : target_machine_features_([](int64 shape_size) { + return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; + }) {} + + StatusOr RunParallelTaskAssigner(HloModule* module) { + return cpu::ParallelTaskAssigner(max_parallelism_, shape_size_func_, + &target_machine_features_) + .Run(module); + } }; TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { @@ -45,9 +59,7 @@ TEST_F(ParallelTaskAssignmentTest, DotOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -74,9 +86,7 @@ TEST_F(ParallelTaskAssignmentTest, )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -92,9 +102,7 @@ TEST_F(ParallelTaskAssignmentTest, RngOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } @@ -108,9 +116,7 @@ TEST_F(ParallelTaskAssignmentTest, InfeedOutfeedOperationNotParallelized) { )"; ParseAndVerifyModule(hlo_string); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cpu::ParallelTaskAssigner( - max_parallelism_, shape_size_func_) - .Run(&module())); + TF_ASSERT_OK_AND_ASSIGN(bool changed, 
RunParallelTaskAssigner(&module())); EXPECT_FALSE(changed); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc new file mode 100644 index 00000000000..c60580d6e76 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h" +#include +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int64; + +#ifdef INTEL_MKL +#include +#include "mkldnn.hpp" +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" + +namespace { + +// Downcast an int64 to int and check if value is in range. +int ToInt(int64 input) { + int output = static_cast(input); + if (static_cast(output) != input) { + std::cerr << "Error occurred in downcasting int64 to int32: Value " << input + << " is out-of-range for type int32. \n"; + exit(1); + } + return output; +} + +using mkldnn::convolution_direct; +using mkldnn::convolution_forward; +using mkldnn::engine; +using mkldnn::memory; +using mkldnn::padding_kind; +using mkldnn::primitive; +using mkldnn::prop_kind; +using mkldnn::reorder; +using mkldnn::stream; + +template +void MKLConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs, + ScalarType* rhs, int64 input_batch, int64 input_rows, + int64 input_cols, int64 input_channels, int64 kernel_rows, + int64 kernel_cols, int64 kernel_channels, int64 kernel_filters, + int64 output_rows, int64 output_cols, int64 row_stride, + int64 col_stride, int64 padding_top, int64 padding_bottom, + int64 padding_left, int64 padding_right, + int64 lhs_row_dilation, int64 lhs_col_dilation, + int64 rhs_row_dilation, int64 rhs_col_dilation) { + auto cpu_engine = engine(engine::cpu, 0); + + // Create a vector primitive to hold the network. + std::vector net; + + // Since memory::dims takes int for each dimension, we downcast the int64 + // values to int using the ToInt function defined above. + memory::dims conv1_src_dim = {ToInt(input_batch), ToInt(input_channels), + ToInt(input_rows), ToInt(input_cols)}; + memory::dims conv1_weights_dim = {ToInt(kernel_filters), + ToInt(kernel_channels), ToInt(kernel_rows), + ToInt(kernel_cols)}; + memory::dims conv1_dst_dim = {ToInt(input_batch), ToInt(kernel_filters), + ToInt(output_rows), ToInt(output_cols)}; + memory::dims conv1_strides = {ToInt(row_stride), ToInt(col_stride)}; + // Note: In MKL_DNN dilation starts from 0. + memory::dims conv1_dilates = {ToInt(rhs_row_dilation - 1), + ToInt(rhs_col_dilation - 1)}; + memory::dims conv1_padding_l = {ToInt(padding_top), ToInt(padding_left)}; + memory::dims conv1_padding_r = {ToInt(padding_bottom), ToInt(padding_right)}; + + // Create memory for user data. 
Input and output data have format of NHWC and + // kernel data has format of HWIO. + // Note that as a convention in MKL-DNN, the dimensions of the data is always + // described in NCHW/IOHW, regardless of the actual layout of the data. + auto user_src_memory = + memory({{{conv1_src_dim}, memory::data_type::f32, memory::format::nhwc}, + cpu_engine}, + lhs); + auto user_weights_memory = memory( + {{{conv1_weights_dim}, memory::data_type::f32, memory::format::hwio}, + cpu_engine}, + rhs); + auto user_dst_memory = + memory({{{conv1_dst_dim}, memory::data_type::f32, memory::format::nhwc}, + cpu_engine}, + out); + + // Create memory descriptors for convolution data with no specified format for + // best performance. + auto conv1_src_mem_desc = memory::desc( + {conv1_src_dim}, memory::data_type::f32, memory::format::any); + auto conv1_weights_mem_desc = memory::desc( + {conv1_weights_dim}, memory::data_type::f32, memory::format::any); + auto conv1_dst_mem_desc = memory::desc( + {conv1_dst_dim}, memory::data_type::f32, memory::format::any); + + // Create a convolution. + auto conv1_desc = convolution_forward::desc( + prop_kind::forward_inference, convolution_direct, conv1_src_mem_desc, + conv1_weights_mem_desc, conv1_dst_mem_desc, conv1_strides, conv1_dilates, + conv1_padding_l, conv1_padding_r, padding_kind::zero); + auto conv1_prim_desc = + convolution_forward::primitive_desc(conv1_desc, cpu_engine); + + // Create reorders for data and weights if layout requested by convolution is + // different from NCHW/OIHW. + auto conv1_src_memory = user_src_memory; + if (memory::primitive_desc(conv1_prim_desc.src_primitive_desc()) != + user_src_memory.get_primitive_desc()) { + conv1_src_memory = memory(conv1_prim_desc.src_primitive_desc()); + net.push_back(reorder(user_src_memory, conv1_src_memory)); + } + + auto conv1_weights_memory = user_weights_memory; + if (memory::primitive_desc(conv1_prim_desc.weights_primitive_desc()) != + user_weights_memory.get_primitive_desc()) { + conv1_weights_memory = memory(conv1_prim_desc.weights_primitive_desc()); + net.push_back(reorder(user_weights_memory, conv1_weights_memory)); + } + + // Check if output need layout conversion. If yes, create memory for + // intermediate layer of conv1_dst_memory. + bool need_output_conversion = + (memory::primitive_desc(conv1_prim_desc.dst_primitive_desc()) != + user_dst_memory.get_primitive_desc()); + auto conv1_dst_memory = need_output_conversion + ? memory(conv1_prim_desc.dst_primitive_desc()) + : user_dst_memory; + + // Create convolution primitive and add it to net. + net.push_back(convolution_forward(conv1_prim_desc, conv1_src_memory, + conv1_weights_memory, conv1_dst_memory)); + if (need_output_conversion) { + net.push_back(reorder(conv1_dst_memory, user_dst_memory)); + } + stream(stream::kind::eager).submit(net).wait(); +} +} // namespace +#endif // INTEL_MKL + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_MKLConvF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, + int64 input_batch, int64 input_rows, int64 input_cols, int64 input_channels, + int64 kernel_rows, int64 kernel_cols, int64 kernel_channels, + int64 kernel_filters, int64 output_rows, int64 output_cols, + int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom, + int64 padding_left, int64 padding_right, int64 lhs_row_dilation, + int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { +#ifdef INTEL_MKL + // Since MKL_DNN cannot handle transposed convolution, this is handled by + // Eigen. 
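+  // Input (lhs) dilation greater than one corresponds to a transposed or
+  // gradient convolution, which the MKL-DNN forward primitive in MKLConvImpl
+  // does not cover, hence the Eigen fallback below; kernel (rhs) dilation is
+  // supported and is passed to MKL-DNN using its zero-based dilation
+  // convention.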
+ if (lhs_row_dilation > 1 || lhs_col_dilation > 1) { + __xla_cpu_runtime_EigenConvF32( + run_options_ptr, out, lhs, rhs, input_batch, input_rows, input_cols, + input_channels, kernel_rows, kernel_cols, kernel_channels, + kernel_filters, output_rows, output_cols, row_stride, col_stride, + padding_top, padding_bottom, padding_left, padding_right, + lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation); + } else { + MKLConvImpl(nullptr, out, lhs, rhs, input_batch, input_rows, input_cols, + input_channels, kernel_rows, kernel_cols, kernel_channels, + kernel_filters, output_rows, output_cols, row_stride, + col_stride, padding_top, padding_bottom, padding_left, + padding_right, lhs_row_dilation, lhs_col_dilation, + rhs_row_dilation, rhs_col_dilation); + } +#else + std::cerr << "Attempt to call MKL Conv2D runtime library without defining " + "INTEL_MKL. Add --config=mkl to build with MKL."; + exit(1); +#endif // INTEL_MKL +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h new file mode 100644 index 00000000000..b239e71d231 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_ + +#include +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_MKLConvF32( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, + float* lhs, float* rhs, tensorflow::int64 input_batch, + tensorflow::int64 input_rows, tensorflow::int64 input_cols, + tensorflow::int64 input_channels, tensorflow::int64 kernel_rows, + tensorflow::int64 kernel_cols, tensorflow::int64 kernel_channels, + tensorflow::int64 kernel_filters, tensorflow::int64 output_rows, + tensorflow::int64 output_cols, tensorflow::int64 row_stride, + tensorflow::int64 col_stride, tensorflow::int64 padding_top, + tensorflow::int64 padding_bottom, tensorflow::int64 padding_left, + tensorflow::int64 padding_right, tensorflow::int64 lhs_row_dilation, + tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation, + tensorflow::int64 rhs_col_dilation); +} + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_MKL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h index 984cb0616e0..0bf693edd0b 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -21,8 +21,6 @@ limitations under the License. 
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/types.h" // 'tensorflow' namespace is used so that int64 and other types don't require @@ -71,11 +69,9 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = fft_shape[i]; out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -88,8 +84,8 @@ void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); // Compute the full FFT using a temporary tensor. - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(in_dims); + const Eigen::DSizes zero_start_indices; full_fft.device(device) = input.template fft(axes); @@ -112,11 +108,9 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, in_dims[0] = input_batch; Eigen::DSizes out_dims; out_dims[0] = input_batch; - TensorShape temp_shape{input_batch}; for (int i = 0; i < FFTRank; i++) { in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; out_dims[i + 1] = fft_shape[i]; - temp_shape.AddDim(fft_shape[i]); } const Eigen::TensorMap, Eigen::Aligned> @@ -129,8 +123,7 @@ void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, // region we will slice from input given fft_shape. We slice input to // fft_shape on its inner-most dimensions, except the last (which we // slice to fft_shape[-1] / 2 + 1). - Tensor temp(DataTypeToEnum::v(), temp_shape); - auto full_fft = temp.flat_inner_dims(); + Eigen::Tensor full_fft(out_dims); // Calculate the starting point and range of the source of // negative frequency part. @@ -179,7 +172,6 @@ template void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, int32 fft_type, int64 input_batch, int64 fft_length0, int64 fft_length1, int64 fft_length2) { - CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; switch (fft_type) { case ::xla::FftType::FFT: EigenFftC2C( @@ -204,7 +196,8 @@ void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, input_batch, fft_length0, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT type: " << fft_type; + // Unsupported FFT type + abort(); } } @@ -230,7 +223,8 @@ void EigenFftImpl(const EigenDevice& device, void* out, void* operand, fft_length1, fft_length2); break; default: - LOG(FATAL) << "Unsupported FFT rank " << fft_rank; + // Unsupported FFT rank + abort(); } } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc new file mode 100644 index 00000000000..2613ddb1270 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.cc @@ -0,0 +1,32 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + tensorflow::xla::EigenFftImpl(Eigen::DefaultDevice(), out, operand, fft_type, + fft_rank, input_batch, fft_length0, fft_length1, + fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h new file mode 100644 index 00000000000..dcd133d012c --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenSingleThreadedFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index b3f4609d465..167aa4adda9 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -19,10 +19,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -48,13 +48,13 @@ int main(int argc, char** argv) { client->TransferToServer(*param1_literal).ConsumeValueOrDie(); // Build computation. - xla::ComputationBuilder builder(client, ""); + xla::XlaBuilder builder(""); auto p0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param1_literal->shape(), "param1"); auto add = builder.Add(p1, p0, {0}); - xla::StatusOr computation_status = builder.Build(); - xla::Computation computation = computation_status.ConsumeValueOrDie(); + xla::StatusOr computation_status = builder.Build(); + xla::XlaComputation computation = computation_status.ConsumeValueOrDie(); // Execute and transfer result of computation. xla::ExecutionProfile profile; diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.h b/tensorflow/compiler/xla/service/cpu/shape_partition.h index 33d02b70e61..db2cda2936c 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.h +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.h @@ -38,7 +38,7 @@ namespace cpu { // // [0, 1), [1, 2), [2, 3), [3, 4), [4, 5) [5, 8) // -// Note that the last partition has residule because the dimension size is +// Note that the last partition has residual because the dimension size is // not a multiple of the partition count. // // diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index b7ce5bbe474..c4c90515ac7 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -31,12 +31,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/types.h" @@ -72,23 +74,33 @@ llvm::StringRef GetHostCpuName() { } } // namespace +/*static*/ std::unique_ptr +SimpleOrcJIT::InferTargetMachineForJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level) { + std::unique_ptr target_machine( + llvm::EngineBuilder() + .setTargetOptions(target_options) + .setOptLevel(opt_level) + .selectTarget( + /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", + /*MCPU=*/GetHostCpuName(), + /*MAttrs=*/DetectMachineAttributes())); + CHECK(target_machine != nullptr); + return target_machine; +} + SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, bool enable_fast_math, bool disable_expensive_passes, LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook) - : target_machine_( - CHECK_NOTNULL(llvm::EngineBuilder() - .setTargetOptions(target_options) - .setOptLevel(opt_level) - .selectTarget( - /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", - /*MCPU=*/GetHostCpuName(), - /*MAttrs=*/DetectMachineAttributes()))), + : target_machine_(InferTargetMachineForJIT(target_options, opt_level)), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), symbol_resolver_(llvm::orc::createLegacyLookupResolver( + execution_session_, [this](const std::string& name) -> llvm::JITSymbol { return this->ResolveRuntimeSymbol(name); }, @@ -178,6 +190,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); @@ -190,6 +203,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index f4260a95bc4..1851a3ee0bb 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -95,6 +95,12 @@ class SimpleOrcJIT { return &external_constant_pool_; } + // Creates an 
llvm::TargetMachine suitable for JITting code that will run on + // the current machine. + static std::unique_ptr InferTargetMachineForJIT( + const llvm::TargetOptions& target_options, + llvm::CodeGenOpt::Level opt_level); + private: llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc index eeb049737dd..a0cd8ee2d2b 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -18,7 +18,7 @@ limitations under the License. namespace xla { namespace cpu { -llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( +llvm::TargetTransformInfo* LLVMTargetMachineFeatures::GetTargetTransformInfoFor( const llvm::Function& function) const { auto it = target_transform_info_cache_.find(&function); if (it == target_transform_info_cache_.end()) { @@ -31,5 +31,30 @@ llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( return &it->second; } +int64 LLVMTargetMachineFeatures::minimum_alignment_for_allocation( + int64 size_bytes) const { + // GLibc malloc returns a pointer with alignment 8 on 32-bit platforms and 16 + // on 64-bit platforms. TCMalloc returns a pointer with alignment 8 for + // allocations smaller than kMallocAlignmentThreshold bytes and at least + // alignment 16 for allocations greater than or equal to + // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound + // by explicitly allocating the memory with posix_memalign. This is + // complicated by our desire to allow parameter buffers created by clients to + // be consumed directly by the JIT. + if (size_bytes == 0) { + // No need to align empty buffers. + return 1; + } + + const int64 kMallocAlignmentThreshold = 512; + + int pointer_size = target_machine_->getPointerSize(0); + int buffer_alignment = + size_bytes >= kMallocAlignmentThreshold ? 2 * pointer_size : pointer_size; + DCHECK_GT(buffer_alignment, 0); + + return buffer_alignment; +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h index 703942615e5..8b00ae9e47e 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -24,43 +24,68 @@ limitations under the License. namespace xla { namespace cpu { -// Wraps an llvm::TargetMachine and parses out some information that feeds into -// LLVM IR code generation decisions. +// Abstract interface for classes providing information about the target we're +// compiling for. class TargetMachineFeatures { public: static constexpr int kX86AvxVectorByteSize = 32; - TargetMachineFeatures(llvm::TargetMachine* target_machine) - : target_machine_(target_machine) {} + // Input and output tensor buffers must be aligned to this many bytes if we + // want to call an Eigen backed GEMM or Convolution. + static constexpr int kEigenExpectedTensorAlignment = 16; // Return the vectorization factor, which is the number of bytes of data // explicitly vectorized routines will try to process at once. - int vectorization_factor_in_bytes() const { + virtual int vectorization_factor_in_bytes() const = 0; + + // Return the size of the largest vector size in bytes. 
We need to pass in + // "function" since llvm functions can contain annotations for specializing + // them to specific micro-architectures (though currently XLA does not use + // this functionality). + virtual int vector_register_byte_size( + const llvm::Function& function) const = 0; + + // Return the number of elements of type `type` that can fit into the largest + // vector register available. We need to pass in "function" since llvm + // functions can contain annotations for specializing them to specific + // micro-architectures (though currently XLA does not use this functionality). + virtual int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const = 0; + + // Returns the minimum alignment for a buffer of size size_bytes. + virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0; + + virtual ~TargetMachineFeatures() = default; +}; + +// Implements the TargetMachineFeatures interface using an llvm::TargetMachine. +class LLVMTargetMachineFeatures : public TargetMachineFeatures { + public: + static constexpr int kX86AvxVectorByteSize = 32; + + LLVMTargetMachineFeatures(llvm::TargetMachine* target_machine) + : target_machine_(target_machine) {} + + int vectorization_factor_in_bytes() const override { // Ideally this should be a function of the cache line size (which we can // get from llvm::TargetTransformInfo::getCacheLineSize) of the target // machine. Guess a value of 128 bytes for now. return 128; } - // Return the size of the largest vector size in bytes. We need to pass in - // "function" since llvm functions can contain annotations for specializing - // them to specific micro-architectures (though currently XLA does not use - // this functionality). - int vector_register_byte_size(const llvm::Function& function) const { + int vector_register_byte_size(const llvm::Function& function) const override { llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); return tti->getRegisterBitWidth(/*Vector=*/true) / 8; } - // Return the number of elements of type `type` that can fit into the largest - // vector register available. We need to pass in "function" since llvm - // functions can contain annotations for specializing them to specific - // micro-architectures (though currently XLA does not use this functionality). int vector_register_num_elements(const llvm::Function& function, - PrimitiveType type) const { + PrimitiveType type) const override { return vector_register_byte_size(function) / (primitive_util::BitWidth(type) / 8); } + int64 minimum_alignment_for_allocation(int64 size_bytes) const override; + private: llvm::TargetTransformInfo* GetTargetTransformInfoFor( const llvm::Function& function) const; diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h new file mode 100644 index 00000000000..ffc6927cbe1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" + +namespace xla { +namespace cpu { +// Delegates calls to minimum_alignment_for_allocation to a user provided +// std::function, crashes on all other methods. +// +// Primarily useful for testing. +class TargetMachineFeaturesWithFakeAlignmentLogic + : public TargetMachineFeatures { + public: + explicit TargetMachineFeaturesWithFakeAlignmentLogic( + std::function fake_alignment_logic) + : fake_alignment_logic_(std::move(fake_alignment_logic)) {} + + int vectorization_factor_in_bytes() const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int vector_register_byte_size(const llvm::Function& function) const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + + int64 minimum_alignment_for_allocation(int64 size_bytes) const override { + return fake_alignment_logic_(size_bytes); + } + + private: + std::function fake_alignment_logic_; +}; +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_FAKE_H_ diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD new file mode 100644 index 00000000000..67f776e7b58 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -0,0 +1,176 @@ +# Description: +# Tests for LLVM-based CPU backend for XLA. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [":friends"], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +# Filegroup used to collect source files for dependency checking. 
+filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +cc_library( + name = "cpu_codegen_test", + testonly = True, + hdrs = ["cpu_codegen_test.h"], + deps = [ + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_fusion_test", + srcs = ["cpu_fusion_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_instruction_fusion", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_bytesizeof_test", + srcs = ["cpu_bytesizeof_test.cc"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_external_constants_test", + srcs = ["cpu_external_constants_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/core:test", + ], +) + +tf_cc_test( + name = "cpu_noalias_test", + srcs = ["cpu_noalias_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@llvm//:core", + ], +) + +tf_cc_test( + name = "cpu_intrinsic_test", + srcs = ["cpu_intrinsic_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_eigen_dot_operation_test", + srcs = ["cpu_eigen_dot_operation_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_infeed_test", + srcs = ["cpu_infeed_test.cc"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + 
"//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_literal_caching_test", + srcs = ["cpu_literal_caching_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "cpu_outfeed_test", + srcs = ["cpu_outfeed_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc new file mode 100644 index 00000000000..d5bbe7677ac --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_bytesizeof_test.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/test.h" + +class CpuByteSizeOfTest : public ::testing::Test {}; + +TEST_F(CpuByteSizeOfTest, ARM32) { + llvm::DataLayout data_layout( + "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"); + auto tuple_shape = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {})}); + EXPECT_EQ(xla::llvm_ir::ByteSizeOf(tuple_shape, data_layout), + data_layout.getPointerSize(0 /* default address space */)); +} + +TEST_F(CpuByteSizeOfTest, ARM64) { + llvm::DataLayout data_layout("e-m:e-i64:64-i128:128-n32:64-S128"); + auto tuple_shape = + xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {})}); + EXPECT_EQ(xla::llvm_ir::ByteSizeOf(tuple_shape, data_layout), + data_layout.getPointerSize(0 /* default address space */)); +} diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h new file mode 100644 index 00000000000..7c8d07a10ba --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h @@ -0,0 +1,30 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_ + +#include "tensorflow/compiler/xla/tests/llvm_irgen_test_base.h" + +namespace xla { +namespace cpu { + +// Tests that verify IR emitted by the CPU backend is as expected. +class CpuCodegenTest : public LLVMIRGenTestBase {}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TESTS_CPU_CODEGEN_TEST_H_ diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc new file mode 100644 index 00000000000..6fcce42eaa4 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests that we call into Eigen for dot operations as needed. 
+ +#include +#include +#include + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { + +struct DotTestSpec { + PrimitiveType primitive_type; + string filecheck_lines; +}; + +string DotTestSpecToString(const ::testing::TestParamInfo& info) { + return PrimitiveType_Name(info.param.primitive_type); +} + +class CpuEigenDotOperationTest + : public CpuCodegenTest, + public ::testing::WithParamInterface { + protected: + void CompileAndCheck(std::unique_ptr entry_computation, + const string& filecheck_lines) { + CpuAotCompilationOptions options{ + /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(entry_computation)); + + CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, + filecheck_lines, + /*match_optimized_ir=*/true); + } +}; + +TEST_P(CpuEigenDotOperationTest, SimpleDotOp) { + HloComputation::Builder builder(TestName()); + DotTestSpec spec = GetParam(); + + auto param_shape = ShapeUtil::MakeShape(spec.primitive_type, {128, 128}); + + HloInstruction* lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "input")); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "input")); + + builder.AddInstruction( + HloInstruction::CreateCanonicalDot(param_shape, lhs, rhs)); + CompileAndCheck(builder.Build(), spec.filecheck_lines); +} + +TEST_P(CpuEigenDotOperationTest, DotTransposeOp) { + HloComputation::Builder builder(TestName()); + DotTestSpec spec = GetParam(); + + auto param_shape = ShapeUtil::MakeShape(spec.primitive_type, {128, 128}); + + HloInstruction* lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "input")); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "input")); + HloInstruction* lhs_transposed = builder.AddInstruction( + HloInstruction::CreateTranspose(param_shape, lhs, {1, 0})); + + builder.AddInstruction( + HloInstruction::CreateCanonicalDot(param_shape, lhs_transposed, rhs)); + CompileAndCheck(builder.Build(), spec.filecheck_lines); +} + +std::vector GetDotTestCases() { + std::vector result; + result.push_back( + {F16, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF16)"}); + result.push_back( + {F32, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF32)"}); + result.push_back( + {F64, R"(CHECK: call void @__xla_cpu_runtime_EigenMatMulF64)"}); + return result; +} + +INSTANTIATE_TEST_CASE_P(CpuEigenDotOperationTestInstantiation, + CpuEigenDotOperationTest, + ::testing::ValuesIn(GetDotTestCases()), + DotTestSpecToString); + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc new file mode 100644 index 00000000000..ed8f375bd61 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { +class CpuExternalConstantsTest : public CpuCodegenTest { + public: + void TestWithArray(int64 rows, int64 cols, const char* filecheck_pattern) { + HloComputation::Builder builder(TestName()); + + Array2D backing_array(rows, cols); + backing_array.FillUnique(); + + auto shape = ShapeUtil::MakeShape(F32, {rows, cols}); + + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2FromArray2D(backing_array))); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, constant)); + + std::unique_ptr module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CompileAndVerifyIr(std::move(module), filecheck_pattern, + /*match_optimized_ir=*/false); + } +}; + +TEST_F(CpuExternalConstantsTest, Basic) { + TestWithArray(/*rows=*/1024, /*cols=*/1024, R"( +CHECK: @constant_global_0 = external constant [1024 x [1024 x float]], align 16 +)"); +} + +TEST_F(CpuExternalConstantsTest, BasicNegative) { + // The constant array in this test case is small enough that there is no need + // to externalize it. + TestWithArray(/*rows=*/4, /*cols=*/4, R"( +CHECK-NOT: @constant_global_0 = external constant [4 x [4 x float]], align 8 +CHECK: @0 = private constant [4 x [4 x float]] {{.*}}, align 8 +)"); +} +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc new file mode 100644 index 00000000000..23e7a3de4d8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -0,0 +1,330 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { + +class CpuFusionTest : public HloTestBase { + protected: + CpuFusionTest() {} + + ErrorSpec error_spec_{0.0001, 1e-5}; +}; + +TEST_F(CpuFusionTest, FuseTwoElementwiseOps) { + auto builder = HloComputation::Builder(TestName()); + auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + Shape vshape = input_literal1->shape(); + + auto input1 = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal1))); + auto input2 = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal2))); + + auto add1 = builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kAdd, input1, input2)); + builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, add1)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The computation root instruction was fused. Verify the fusion instruction + // is now the root. + auto computation = module->entry_computation(); + auto fusion_instruction = computation->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); + EXPECT_EQ(HloOpcode::kNegate, + fusion_instruction->fused_expression_root()->opcode()); + // There should be four fused instructions: 2 parameters, the add, and the + // negate. + EXPECT_EQ(4, fusion_instruction->fused_instruction_count()); + + // Compile and execute the computation. + auto result = ExecuteAndTransfer(std::move(module), {}); + + // Check the output correctness. 
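+  // With input1 = {1, 2, 3} and input2 = {-2, -42, 2}, the fused add yields
+  // {-1, -40, 5}, and the negate root produces {1, 40, -5}.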
+ LiteralTestUtil::ExpectR1Near({1.0, 40.0, -5.0}, *result, error_spec_); +} + +TEST_F(CpuFusionTest, FuseElementwiseOpChain) { + auto builder = HloComputation::Builder(TestName()); + auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + Shape vshape = input_literal->shape(); + + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, input)); + auto ceil = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kExp, ceil)); + auto floor = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kFloor, exp)); + auto two = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, two, floor)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The computation root instruction was fused. Verify the fusion instruction + // is now the root. + auto computation = module->entry_computation(); + auto fusion_instruction = computation->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction->opcode()); + EXPECT_EQ(HloOpcode::kMultiply, + fusion_instruction->fused_expression_root()->opcode()); + // There should be 7 fused instructions: 2 parameters and the fused + // operations. + EXPECT_EQ(7, fusion_instruction->fused_instruction_count()); + + // Compile and execute the computation. + auto result = ExecuteAndTransfer(std::move(module), {}); + + // Check the output correctness. + LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0}, *result, + error_spec_); +} + +TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { + // Test a chain of fusable ops with a non-fusable op (a reduce) thrown in the + // middle. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto input_literal = Literal::CreateR1({-1.5, -2.5, -3.0}); + Shape vshape = input_literal->shape(); + + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, input)); + auto ceil = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate)); + + auto cshape = ShapeUtil::MakeShape(F32, {6}); + auto concatenate = builder.AddInstruction( + HloInstruction::CreateConcatenate(cshape, {ceil, ceil}, /*dimension=*/0)); + + // Build an x+y computation to use in a reduce. + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + auto embedded_builder = HloComputation::Builder("f32+f32"); + embedded_builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kAdd, + embedded_builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "x")), + embedded_builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "y")))); + auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); + + // This is a nop reduction. 
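+  // (Reshaping the F32[6] concatenate result to F32[6,1] and add-reducing the
+  // size-1 dimension 1 with an init value of 0 leaves the six values
+  // unchanged, hence "nop".)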
+ auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + cshape, + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1}), concatenate)), + /*init_value=*/ + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{1}, add_f32)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(cshape, HloOpcode::kExp, reduce)); + auto floor = builder.AddInstruction( + HloInstruction::CreateUnary(cshape, HloOpcode::kFloor, exp)); + auto two = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + builder.AddInstruction( + HloInstruction::CreateBinary(cshape, HloOpcode::kMultiply, two, floor)); + + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The computation root instruction was fused. Verify the fusion instruction + // is now the root. + auto computation = module->entry_computation(); + + auto fusion_instruction1 = computation->root_instruction(); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode()); + EXPECT_EQ(HloOpcode::kMultiply, + fusion_instruction1->fused_expression_root()->opcode()); + // There should be 5 fused instructions in the root fusion instruction: 2 + // parameters, multiply, floor, and exp. + EXPECT_EQ(5, fusion_instruction1->fused_instruction_count()) + << fusion_instruction1->fused_instructions_computation()->ToString(); + + auto fusion_instruction2 = reduce->operand(0); + EXPECT_EQ(HloOpcode::kFusion, fusion_instruction1->opcode()); + EXPECT_EQ(HloOpcode::kReshape, + fusion_instruction2->fused_expression_root()->opcode()); + // There should be 5 fused instructions in the second fusion instruction: 1 + // parameter, negate, ceil, concat, and reshape. + EXPECT_EQ(5, fusion_instruction2->fused_instruction_count()) + << fusion_instruction2->fused_instructions_computation()->ToString(); + + // Compile and execute the computation. + auto result = ExecuteAndTransfer(std::move(module), {}); + + // Check the output correctness. + LiteralTestUtil::ExpectR1Near({14.0, 40.0, 40.0, 14.0, 40.0, 40.0}, + *result, error_spec_); +} + +TEST_F(CpuFusionTest, TestOperandOrderToAvoidDuplication) { + // Test that the operands of an instruction to be fused are considered in the + // proper order to avoid duplication. Test input: + // + // constant = {...} + // negate = neg(constant) + // ceil = ceil(negate) + // add1 = add(negate, ceil) + // add2 = add(ceil, negate) + // + // In this example, the operands of both add1 and add2 should be fused in the + // order {ceil, negate} even though they have different orders in their + // operand vectors. Test for this problem by counting the number of nodes in + // each fusion instruction to ensure that negate is not duplicated. 
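+  // Note: the two binary ops below are built with HloOpcode::kMultiply even
+  // though the description above calls them add1/add2; the operand-ordering
+  // behavior under test is the same either way.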
+ auto builder = HloComputation::Builder(TestName()); + auto input_literal = Literal::CreateR1({1.0, 2.0, 3.0}); + Shape vshape = input_literal->shape(); + + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kNegate, constant)); + auto ceil = builder.AddInstruction( + HloInstruction::CreateUnary(vshape, HloOpcode::kCeil, negate)); + + auto add1 = builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, negate, ceil)); + auto add2 = builder.AddInstruction( + HloInstruction::CreateBinary(vshape, HloOpcode::kMultiply, ceil, negate)); + + // Tie together the two adds with a tuple to create a single root. + auto result = + builder.AddInstruction(HloInstruction::CreateTuple({add1, add2})); + + // Create computation and module. + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + // Run fusion. + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + auto fusion1 = result->operand(0); + auto fusion2 = result->operand(1); + EXPECT_EQ(HloOpcode::kFusion, fusion1->opcode()); + EXPECT_EQ(HloOpcode::kFusion, fusion2->opcode()); + + // Each fusion instruction should have 4 fused instruction inside: add, ceil, + // negate, and the fused parameter. + EXPECT_EQ(4, fusion1->fused_instruction_count()); + EXPECT_EQ(4, fusion2->fused_instruction_count()); + + // Each fusion instruction should have one parameter and the parameter should + // be the constant. + EXPECT_EQ(1, fusion1->operand_count()); + EXPECT_EQ(constant, fusion1->operand(0)); + EXPECT_EQ(1, fusion2->operand_count()); + EXPECT_EQ(constant, fusion2->operand(0)); +} + +TEST_F(CpuFusionTest, DoNotDuplicateExpensiveOps) { + // Verify that expensive operations will not be fused if the fusion results in + // duplication. Test code: + // + // constant = 42.0 + // exp1 = exp(constant) + // negate1 = negate(exp1) + // exp2 = exp(constant) + // negate2 = negate(exp2) + // tuple = tuple(negate1, negate2, exp2) + // + // exp1 should be fused down into negate1, but exp2 will not be fused into + // negate2 because this will result in duplication of the expensive exp + // computation. The duplication is caused by the other use of exp2 in the + // tuple. + auto builder = HloComputation::Builder(TestName()); + auto input_literal1 = Literal::CreateR1({1.0, 2.0, 3.0}); + auto input_literal2 = Literal::CreateR1({-2.0, -42.0, 2.0}); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + Shape shape = constant->shape(); + + auto exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant)); + auto negate1 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp1)); + + auto exp2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, constant)); + auto negate2 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kNegate, exp2)); + + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({negate1, negate2, exp2})); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + CpuInstructionFusion fusion; + EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie()); + + // The only fusion instruction should be operand 0 of the tuple (formerly + // negate1). 
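+  // (Its fused body is the former exp1 and negate1 plus a single parameter
+  // standing in for the shared constant.)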
+ EXPECT_EQ(HloOpcode::kFusion, tuple->operand(0)->opcode()); + EXPECT_EQ(HloOpcode::kNegate, tuple->operand(1)->opcode()); + EXPECT_EQ(HloOpcode::kExp, tuple->operand(2)->opcode()); + + auto fusion_inst = tuple->operand(0); + // There should be three fused instructions: negate2, exp2, and the fused + // parameter. + EXPECT_EQ(3, fusion_inst->fused_instruction_count()); + EXPECT_EQ(1, fusion_inst->operand_count()); + EXPECT_EQ(constant, fusion_inst->operand(0)); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc new file mode 100644 index 00000000000..dd63b998e9b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_infeed_test.cc @@ -0,0 +1,294 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class InfeedTest : public ClientLibraryTestBase { + protected: + // Transfers the given literal to the infeed interface of the device, and + // check if the returned data from Infeed HLO is same as the literal. + void TestInfeedRoundTrip(const Literal& literal) { + // TODO(b/31037751) Explicitly reset the Infeed state so that the + // test is not affected by the state from the previous tests by + // adding ClearInfeed if necessary when it is implemented. For now + // don't use ResetDevice since it is not implemented on CPU. + ASSERT_IS_OK(client_->TransferToInfeed(literal)); + XlaBuilder builder(TestName()); + builder.Infeed(literal.shape()); + if (ShapeUtil::IsTuple(literal.shape())) { + // TODO(b/30609564): Use ComputeAndCompareLiteral instead. 
+ ComputeAndCompareTuple(&builder, literal, {}); + } else { + ComputeAndCompareLiteral(&builder, literal, {}); + } + } +}; + +TEST_F(InfeedTest, SingleInfeedR0Bool) { + TestInfeedRoundTrip(*Literal::CreateR0(true)); +} + +TEST_F(InfeedTest, SingleInfeedR1U32) { + TestInfeedRoundTrip(*Literal::CreateR1({1, 2, 3})); +} + +TEST_F(InfeedTest, SingleInfeedR2F32) { + TestInfeedRoundTrip(*Literal::CreateR2F32Linspace(0.0, 1.0, 128, 64)); +} + +TEST_F(InfeedTest, SingleInfeedR3F32) { + TestInfeedRoundTrip( + *Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); +} + +TEST_F(InfeedTest, SingleInfeedR3F32DifferentLayout) { + const Layout r3_dim0minor = LayoutUtil::MakeLayout({0, 1, 2}); + const Layout r3_dim0major = LayoutUtil::MakeLayout({2, 1, 0}); + + TestInfeedRoundTrip( + *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0minor)); + + TestInfeedRoundTrip( + *Literal::CreateR3WithLayout({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}, + r3_dim0major)); +} + +TEST_F(InfeedTest, SingleInfeedR4S32) { + TestInfeedRoundTrip(*Literal::CreateR4( + {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, + {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); +} + +TEST_F(InfeedTest, SingleInfeedTuple) { + TestInfeedRoundTrip( + *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), + Literal::CreateR0(false).get()})); +} + +TEST_F(InfeedTest, SingleInfeedEmptyTuple) { + TestInfeedRoundTrip(*Literal::MakeTuple({})); +} + +// Tests Infeed operation used in a while loop, as in the code below. The +// computation is launched asynchronously, and then infeed data is transferred. +// +// float acc = 0.0f; +// while (acc < 40.0f) { +// acc += reduce_add(Infeed()); +// } +// return acc; +// TODO(b/30671675) enable this test once asynchronous execution is +// implemented for CPU. +TEST_F(InfeedTest, DISABLED_SingleInfeedInWhile) { + XlaBuilder builder(TestName()); + const auto infeed_shape = ShapeUtil::MakeShape(F32, {3}); + const auto result_shape = ShapeUtil::MakeShape(F32, {}); + + // Create a computation for the condition: repeat until (prev < 40.0f) holds. + XlaComputation condition; + { + XlaBuilder builder("condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Gt(builder.ConstantR0(40.0f), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + // Create a computation for the body: add the reduced value of the Infeed + // data to the result variable. + XlaComputation body; + { + XlaBuilder builder("body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto infeed = builder.Infeed(infeed_shape); + auto addend = + builder.Reduce(infeed, builder.ConstantR0(0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + builder.Add(prev, addend); + body = builder.Build().ConsumeValueOrDie(); + } + // Create a While node with computations for the condition and the body. + auto init = builder.ConstantR0(0.0f); + builder.While(condition, body, init); + + // Build and asynchronously launch the computation. + auto computation = builder.Build().ConsumeValueOrDie(); + std::unique_ptr result; + tensorflow::Thread* computation_thread = + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions{}, "computation_thread", [&] { + result = client_->Execute(computation, {}, &execution_options_) + .ValueOrDie(); + }); + + // Send 5 Infeed data of shape F32[3]. 
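+  // reduce_add of the first three batches gives 6, 15, and 24; the
+  // accumulator reaches 45 >= 40 after the third batch, so the remaining
+  // transfers are never consumed and the expected result below is 45.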
+ ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({1, 2, 3}))); + ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({4, 5, 6}))); + ASSERT_IS_OK(client_->TransferToInfeed(*Literal::CreateR1({7, 8, 9}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*Literal::CreateR1({10, 11, 12}))); + ASSERT_IS_OK( + client_->TransferToInfeed(*Literal::CreateR1({13, 14, 15}))); + + delete computation_thread; // Joins the thread. + auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); + + // Only the first 3 infeed data should be added. + LiteralTestUtil::ExpectR0Near(45.0f, *result_literal, ErrorSpec{1e-7}); +} + +// Tests two Infeed operations with a total order. The order is enforced by +// using the result of the first while loop as the initial value of the second +// while loop. The shapes of both Infeeds are Tuples, where the first tuple +// element (R1F32) is for the data to reduce and accumulate, and the second +// tuple element (PRED) to indicate whether the loop should continue. The +// computation is launched asynchronously, and then infeed data is transferred. +// +// float acc = 0.0f; +// continue = true; +// while (!continue) { +// (data, continue) = Infeed(shape1); +// acc += reduce_add(data) +// } +// continue = true; +// while(!continue) { +// (data, continue) = Infeed(shape2); +// acc += reduce_add(data) +// } +// return acc; +// TODO(b/30671675) enable this test once asynchronous execution is +// implemented for CPU. +TEST_F(InfeedTest, DISABLED_TwoInfeedsInTotalOrder) { + XlaBuilder builder(TestName()); + const auto infeed1_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(PRED, {})}); + const auto infeed2_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(PRED, {})}); + const auto result_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(PRED, {})}); + + // Create a computation for the condition: repeat until the second tuple + // element is false. + XlaComputation condition; + { + XlaBuilder builder("condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.GetTupleElement(prev, 1); + condition = builder.Build().ConsumeValueOrDie(); + } + + // A lambda that builds the body computation of a while loop with the given + // infeed shape, and returns the computation with the ownership. + // + // The body adds the reduced value of the Infeed data (first tuple element) + // to the previous accumulator, and returns the accumulator and the continue + // flag (second tuple element) as a tuple. + const auto build_body = [this, &result_shape](const Shape& infeed_shape) { + XlaComputation body; + XlaBuilder builder("body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto infeed = builder.Infeed(infeed_shape); + auto addend = builder.Reduce( + builder.GetTupleElement(infeed, 0), builder.ConstantR0(0.0f), + CreateScalarAddComputation(F32, &builder), {0}); + auto result = builder.Add(builder.GetTupleElement(prev, 0), addend); + builder.Tuple({result, builder.GetTupleElement(infeed, 1)}); + return builder.Build().ConsumeValueOrDie(); + }; + + // Create the first while loop with infeed1_shape. + auto init = builder.Tuple( + {builder.ConstantR0(0.0f), builder.ConstantR0(true)}); + auto while1 = builder.While(condition, build_body(infeed1_shape), init); + auto result1 = builder.Tuple( + {builder.GetTupleElement(while1, 0), builder.ConstantR0(true)}); + + // Create the second while loop with infeed2_shape. 
Note that the result from + // the first while loop is used as the initial value. + auto while2 = builder.While(condition, build_body(infeed2_shape), result1); + builder.GetTupleElement(while2, 0); + + // Build the computation. + auto computation = builder.Build().ConsumeValueOrDie(); + + // Send the first 4 Infeed data of shape Tuple(F32[2], PRED). + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({3, 4}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({5, 6}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({7, 8}).get(), + Literal::CreateR0(false).get()}))); + + // Asynchronously launch the execution on the device. + std::unique_ptr result; + tensorflow::Thread* computation_thread = + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions{}, "computation_thread", [&] { + result = client_->Execute(computation, {}, &execution_options_) + .ValueOrDie(); + }); + + // Wait for a second to ensure testing that the execution is waiting on the + // Infeed data, and send the rest Infeed data of shape Tuple(F32[3], PRED). + sleep(1); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({1, 2, 3}).get(), + Literal::CreateR0(true).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({7, 8, 9}).get(), + Literal::CreateR0(false).get()}))); + ASSERT_IS_OK(client_->TransferToInfeed( + *Literal::MakeTuple({Literal::CreateR1({4, 5, 6}).get(), + Literal::CreateR0(true).get()}))); + + // Wait for the execution to be done, and transfer the result. + delete computation_thread; // Joins the thread. + auto result_literal = client_->Transfer(*result).ConsumeValueOrDie(); + + // Only the first 6 infeed data should be added. + LiteralTestUtil::ExpectR0Near(66.0f, *result_literal, ErrorSpec{1e-7}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc new file mode 100644 index 00000000000..973aac8766f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { + +const char* const kTriple_x86_64 = "x86_64-pc-linux"; +const char* const kTriple_android_arm = "armv7-none-android"; + +struct IntrinsicTestSpec { + HloOpcode opcode; + tensorflow::StringPiece triple; + tensorflow::StringPiece features; + tensorflow::StringPiece check_lines; +}; + +// Tests that unary functions get lowered using intrinsic calls. +class CpuUnaryIntrinsicTest + : public CpuCodegenTest, + public ::testing::WithParamInterface { + public: + static string Name(const ::testing::TestParamInfo& info) { + auto spec = info.param; + + string opcode = HloOpcodeString(spec.opcode); + opcode[0] = toupper(opcode[0]); + + string triple{spec.triple.data(), spec.triple.size()}; + if (triple == kTriple_x86_64) { + triple = "x86_64"; + } else if (triple == kTriple_android_arm) { + triple = "android_arm"; + } else { + triple = "Unknown"; + } + + string features{spec.features.data(), spec.features.size()}; + if (!features.empty()) { + std::replace_if(features.begin(), features.end(), + [](char c) { return c != '_' && !isalnum(c); }, '_'); + } else { + features = ""; + } + + return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(), + features.empty() ? "" : "_With", + features.c_str()); + } +}; + +// Creates a module with a call to the unary op, and tests if the +// compiler replaced it with a call to the intrinsic. +TEST_P(CpuUnaryIntrinsicTest, DoIt) { + HloComputation::Builder builder(TestName()); + IntrinsicTestSpec spec = GetParam(); + + auto param_shape = ShapeUtil::MakeShape(F32, {1024}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "input")); + builder.AddInstruction( + HloInstruction::CreateUnary(param_shape, spec.opcode, param)); + std::unique_ptr computation = builder.Build(); + + string triple{spec.triple.data(), spec.triple.size()}; + string features{spec.features.data(), spec.features.size()}; + + CpuAotCompilationOptions options{ + /*triple=*/triple, /*cpu_name=*/"", /*features=*/features, + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + string check_lines{spec.check_lines.data(), spec.check_lines.size()}; + + CompileAheadOfTimeAndVerifyIr(std::move(hlo_module), options, check_lines, + /*match_optimized_ir=*/true); +} + +IntrinsicTestSpec CpuUnaryIntrinsicTestCases[] = { + // The intrinsics are always inlined, so we match a line from it instead of + // a function call. 
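+    // For example, an entry with triple kTriple_x86_64 and features "+avx"
+    // gets the Name() suffix "_On_x86_64_With_avx": the leading '+' in the
+    // feature string is rewritten to '_' by the replace_if above.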
+ + IntrinsicTestSpec{ + HloOpcode::kExp, kTriple_x86_64, "", + R"(CHECK: fmul fast <4 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kExp, kTriple_x86_64, "+avx", + R"(CHECK: fmul fast <8 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kExp, kTriple_android_arm, "+neon", + R"(CHECK: fmul fast <4 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kTanh, kTriple_x86_64, "", + R"(CHECK: fcmp fast uge <4 x float> %wide.load, )"}, + + IntrinsicTestSpec{ + HloOpcode::kTanh, kTriple_x86_64, "+avx", + R"(CHECK: fcmp fast uge <8 x float> %wide.load, )"}, + + IntrinsicTestSpec{ + HloOpcode::kTanh, kTriple_android_arm, "", + R"(CHECK: fcmp fast uge <4 x float> %wide.load, )"}, + + IntrinsicTestSpec{ + HloOpcode::kLog, kTriple_x86_64, "", + R"(CHECK: fadd fast <4 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kLog, kTriple_x86_64, "+avx", + R"(CHECK: fadd fast <8 x float> )"}, + + IntrinsicTestSpec{ + HloOpcode::kLog, kTriple_android_arm, "", + R"(CHECK: fadd fast <4 x float> )"}}; + +INSTANTIATE_TEST_CASE_P(CpuUnaryIntrinsicTestInstantiation, + CpuUnaryIntrinsicTest, + ::testing::ValuesIn(CpuUnaryIntrinsicTestCases), + CpuUnaryIntrinsicTest::Name); + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc new file mode 100644 index 00000000000..d6e0425c554 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { +class CpuDuplicateConstantsTest : public CpuCodegenTest {}; + +TEST_F(CpuDuplicateConstantsTest, RepeatedArrayConstants) { + // We use a while loop here to force the two constant HloInstructions to be in + // different computations. Otherwise the HLO optimizer itself CSEs them. 
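+  // The FileCheck pattern below accepts exactly one
+  // "private constant [2 x [3 x [2 x float]]]" global, so the two textually
+  // identical array constants must end up sharing a single LLVM global.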
+ const string hlo_text = R"( +HloModule RepeatedConstants + +while_body { + arg_body = f32[2,3,2] parameter(0) + ROOT const = f32[2,3,2] constant( + f32[2,3,2] + {{{1, 2}, {1001, 1002}, {2001, 2002}}, + {{2, 1}, {2001, 3002}, {2001, 2002}}}) +} + +while_cond { + arg_cond = f32[2,3,2] parameter(0) + ROOT unknown = pred[] infeed() +} + +ENTRY main { + param = f32[2,3,2] parameter(0) + const_a = f32[2,3,2] constant( + f32[2,3,2] + {{{1, 2}, {1001, 1002}, {2001, 2002}}, + {{2, 1}, {2001, 3002}, {2001, 2002}}}) + const_b = f32[2,3,2] while(f32[2,3,2] const_a), condition=while_cond, body=while_body + + out0 = () outfeed(f32[2,3,2] const_a) + ROOT out1 = () outfeed(f32[2,3,2] const_b) +} +)"; + + string filecheck_pattern = R"( +CHECK: private constant [2 x [3 x [2 x float]]] +CHECK-NOT: private constant [2 x [3 x [2 x float]]] +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + +TEST_F(CpuDuplicateConstantsTest, RepeatedTupleConstants) { + // We use a while loop here to force the two constant HloInstructions to be in + // different computations. Otherwise the HLO optimizer itself CSEs them. + const string hlo_text = R"( +HloModule RepeatedConstants + +while_body { + arg_body = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) + ROOT const = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) +} + +while_cond { + arg_cond = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) + ROOT unknown = pred[] infeed() +} + +ENTRY main { + param = f32[2,3,2] parameter(0) + const_a = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) + const_b = (f32[2,1]{1,0}, f32[2]{0}) while((f32[2,1]{1,0}, f32[2]{0}) const_a), condition=while_cond, body=while_body + + out0 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_a) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b) +} +)"; + + string filecheck_pattern = R"( +CHECK: private constant [2 x float] +CHECK: private constant [2 x [1 x float]] +CHECK-NOT: private constant [2 x float] +CHECK-NOT: private constant [2 x [1 x float]] +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc new file mode 100644 index 00000000000..3b6b0ed7406 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/IR/Module.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { + +class CpuNoAliasTest : public CpuCodegenTest {}; + +// Creates a simple HLO ir_module (runs concat(concat(x, y), x)), and then +// inspects the aliasing information for loads to its buffers. +TEST_F(CpuNoAliasTest, Concat) { + HloComputation::Builder builder(TestName()); + + std::unique_ptr literal = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto param_shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* param_x = builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "x")); + HloInstruction* param_y = builder.AddInstruction( + HloInstruction::CreateParameter(1, param_shape, "y")); + HloInstruction* concat1 = + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 4}), {param_x, param_y}, 1)); + HloInstruction* concat2 = + builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(F32, {2, 6}), {concat1, param_x}, 1)); + + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. + auto status_or_buffer_assn = BufferAssigner::Run( + hlo_module.get(), MakeUnique(hlo_module.get()), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return /*alignment=*/1; }); + ASSERT_EQ(status_or_buffer_assn.status(), Status::OK()); + + llvm::LLVMContext context; + llvm_ir::AliasAnalysis aa(*hlo_module, *status_or_buffer_assn.ValueOrDie(), + &context); + + // Construct an LLVM module containing loads that we annotate as being from + // the buffers in the HLO module. We'll inspect these loads to ensure that + // they have the expected alias information. 
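+  // For reference when reading the checks below: the 2x4 concat output lives
+  // in a 32-byte buffer and the 2x6 output in a 48-byte buffer (buf_size32 /
+  // buf_size48). param_x must not alias either buffer, and the two concat
+  // buffers must not alias each other.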
+ llvm::Module ir_module("test", context); + llvm::Function* func = llvm::cast( + ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context))); + llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "body", func); + llvm::IRBuilder<> ir_builder(bb); + auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0); + llvm_ir::IrArray::Index zero2D({zero, zero}); + + llvm::ArrayType* array2d_type = llvm::ArrayType::get( + llvm::ArrayType::get(llvm::Type::getFloatTy(context), 100), 100); + + { + llvm::Value* param_x_val = + ir_module.getOrInsertGlobal("param_x", array2d_type); + llvm_ir::IrArray param_x_array(param_x_val, param_shape); + aa.AddAliasingInformationToIrArray(*param_x, ¶m_x_array); + param_x_array.EmitReadArrayElement(zero2D, &ir_builder) + ->setName("read_param_x_array"); + } + + { + llvm::Value* concat1_val = + ir_module.getOrInsertGlobal("concat1", array2d_type); + auto shape = ShapeUtil::MakeShape(F32, {2, 4}); + llvm_ir::IrArray concat1_array(concat1_val, shape); + aa.AddAliasingInformationToIrArray(*concat1, &concat1_array); + concat1_array.EmitReadArrayElement(zero2D, &ir_builder) + ->setName("read_concat1_array"); + } + + { + llvm::Value* concat2_val = + ir_module.getOrInsertGlobal("concat2", array2d_type); + auto shape = ShapeUtil::MakeShape(F32, {2, 6}); + llvm_ir::IrArray concat2_array(concat2_val, shape); + aa.AddAliasingInformationToIrArray(*concat2, &concat2_array); + concat2_array.EmitReadArrayElement(zero2D, &ir_builder) + ->setName("read_concat2_array"); + } + + // Check the AA info in the loads. + const char* filecheck_pattern = R"( + CHECK: %read_param_x_array = load {{.*}} !noalias [[param_x_noalias:![0-9]+]] + CHECK: %read_concat1_array = load {{.*}} !alias.scope [[concat1_scope:![0-9]+]], !noalias [[concat1_noalias:![0-9]+]] + CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]] + CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32 + CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48 + CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]} + CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]} + CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]} + )"; + + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_match, + RunFileCheck(llvm_ir::DumpModuleToString(ir_module), filecheck_pattern)); + EXPECT_TRUE(filecheck_match); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc new file mode 100644 index 00000000000..879372eb138 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { +class CpuOutfeedTest : public CpuCodegenTest {}; + +TEST_F(CpuOutfeedTest, OutfeedRoot) { + const string hlo_text = R"( +HloModule Outfeed + +ENTRY main { + const_a = f32[2,3,2] constant( + f32[2,3,2] + {{{1, 2}, {1001, 1002}, {2001, 2002}}, + {{2, 1}, {2001, 3002}, {2001, 2002}}}) + + ROOT out = () outfeed(f32[2,3,2] const_a) +} +)"; + + string filecheck_pattern = R"( +CHECK: private constant [2 x [3 x [2 x float]]] +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/false); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index af48f4ab6e5..cc1695b7f86 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -25,7 +25,7 @@ namespace xla { // Creates an HloPassPipeline containing multiple HloPasses that can // despecialize an optimized HloModule. This is useful to run an HloModule -// optimized for one specfic platform on a different platform (undoing platform +// optimized for one specific platform on a different platform (undoing platform // specific passes) with matching numerics for comparison. // // Current despecialization passes are Defuser, ImplicitBroadcastRemover, diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index 78e7aa48acc..e228bb56bce 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -24,45 +24,37 @@ limitations under the License. 
namespace xla { StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( - const perftools::gputools::Platform* platform, - tensorflow::gtl::ArraySlice - stream_executors) + const se::Platform* platform, + tensorflow::gtl::ArraySlice stream_executors) : DeviceMemoryAllocator(platform), stream_executors_(stream_executors.begin(), stream_executors.end()) {} -StatusOr -StreamExecutorMemoryAllocator::Allocate(int device_ordinal, uint64 size, - bool retry_on_failure) { - TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * stream_executor, +StatusOr StreamExecutorMemoryAllocator::Allocate( + int device_ordinal, uint64 size, bool retry_on_failure) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); - perftools::gputools::DeviceMemoryBase result = - stream_executor->AllocateArray(size); + se::DeviceMemoryBase result = stream_executor->AllocateArray(size); if (size > 0 && result == nullptr) { return ResourceExhausted( "Failed to allocate request for %s (%lluB) on device ordinal %d", tensorflow::strings::HumanReadableNumBytes(size).c_str(), size, device_ordinal); } - return result; + return OwningDeviceMemory(result, device_ordinal, this); } -tensorflow::Status StreamExecutorMemoryAllocator::Deallocate( - int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) { - if (!mem->is_null()) { - TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * stream_executor, +Status StreamExecutorMemoryAllocator::Deallocate(int device_ordinal, + se::DeviceMemoryBase mem) { + if (!mem.is_null()) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * stream_executor, GetStreamExecutor(device_ordinal)); - // We make a local copy of 'mem' so the original is not zeroed out by the - // Deallocate() call below. This gives us a better chance of - // catching double-free bugs, since Deallocate silently succeeds for null - // values. - perftools::gputools::DeviceMemoryBase mem_copy(*mem); - stream_executor->Deallocate(&mem_copy); + stream_executor->Deallocate(&mem); } - return tensorflow::Status::OK(); + return Status::OK(); } -StatusOr -StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) { +StatusOr StreamExecutorMemoryAllocator::GetStreamExecutor( + int device_ordinal) { if (device_ordinal < 0) { return InvalidArgument("device ordinal value (%d) must be non-negative", device_ordinal); diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index 39dfad84c1c..d87b86caf0d 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/owning_device_memory.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -33,30 +34,44 @@ class DeviceMemoryAllocator { public: // Parameter platform indicates which platform the allocator allocates memory // on. Must be non-null. - explicit DeviceMemoryAllocator(const perftools::gputools::Platform* platform) + explicit DeviceMemoryAllocator(const se::Platform* platform) : platform_(platform) {} virtual ~DeviceMemoryAllocator() {} + // Allocates memory on the device. + // + // If size > 0 and the returned StatusOr is OK, the wrapped OwningDeviceMemory + // must not be null. If size == 0, must return a null OwningDeviceMemory. 
+ // // 'retry_on_failure': If false, and the first attempt to allocate the memory - // fails, the allocation should return immediately without retrying. - // An example use case is optional scratch spaces where a failure - // has only performance impact. - // Allocate() should return a null pointer for a size-0 allocation. - // Deallocate() must be a no-op for null pointers. - virtual StatusOr Allocate( - int device_ordinal, uint64 size, bool retry_on_failure = true) = 0; - virtual tensorflow::Status Deallocate( - int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) = 0; + // fails, the allocation should return immediately without retrying. An + // example use case is optional scratch spaces where a failure has only + // performance impact. + virtual StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) = 0; + + // Two-arg version of Allocate(), which sets retry-on-failure to true. + // + // (We don't simply use a default argument on the virtual Allocate function + // because default args on virtual functions are disallowed by the Google + // style guide.) + StatusOr Allocate(int device_ordinal, uint64 size) { + return Allocate(device_ordinal, size, /*retry_on_failure=*/true); + } + + // Must be a nop for null pointers. + virtual Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) = 0; // Return the platform that the allocator allocates memory on. - const perftools::gputools::Platform* platform() const { return platform_; } + const se::Platform* platform() const { return platform_; } // Can we call Deallocate() as soon as a computation has been scheduled on // a stream, or do we have to wait for the computation to complete first? virtual bool AllowsAsynchronousDeallocation() const = 0; protected: - const perftools::gputools::Platform* platform_; + friend class OwningDeviceMemory; + const se::Platform* platform_; }; // Default memory allocator for a platform which uses @@ -64,25 +79,26 @@ class DeviceMemoryAllocator { class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { public: StreamExecutorMemoryAllocator( - const perftools::gputools::Platform* platform, - tensorflow::gtl::ArraySlice - stream_executors); + const se::Platform* platform, + tensorflow::gtl::ArraySlice stream_executors); - StatusOr Allocate( - int device_ordinal, uint64 size, bool retry_on_failure = true) override; - tensorflow::Status Deallocate( - int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override; + StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) override; + + // Pull in two-arg overload that sets retry_on_failure to true. + using DeviceMemoryAllocator::Allocate; + + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; bool AllowsAsynchronousDeallocation() const override; private: - StatusOr GetStreamExecutor( - int device_ordinal); + StatusOr GetStreamExecutor(int device_ordinal); // A vector indexed by device ordinal of StreamExecutors for each device of // the allocator's platform type. If an element is nullptr, then the device // with the respective device ordinal is not supported by XLA. 
- std::vector stream_executors_; + std::vector stream_executors_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 3f7089d6ca1..b9d7ec9c2e1 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -138,6 +138,9 @@ class DfsHloVisitorBase { virtual Status HandleExp(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleExpm1(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleFloor(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -147,6 +150,12 @@ class DfsHloVisitorBase { virtual Status HandleLog(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleClz(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleLog1p(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleCos(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } @@ -199,7 +208,6 @@ class DfsHloVisitorBase { virtual Status HandleReduce(HloInstructionPtr hlo) = 0; virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0; - virtual Status HandleBroadcastDimOne(HloInstructionPtr hlo) = 0; virtual Status HandleReshape(HloInstructionPtr hlo) = 0; virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; virtual Status HandleParameter(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index e6680ee9b87..240faebe62f 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -158,9 +158,6 @@ class DfsHloVisitorWithDefaultBase Status HandleBroadcast(HloInstructionPtr broadcast) override { return DefaultAction(broadcast); } - Status HandleBroadcastDimOne(HloInstructionPtr broadcastDimOne) override { - return DefaultAction(broadcastDimOne); - } Status HandlePad(HloInstructionPtr pad) override { return DefaultAction(pad); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index b6a0903b0ee..9a8bab353ef 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -52,6 +52,13 @@ using tensorflow::strings::StrCat; namespace { +int64 GlobalRandomValue() { + static auto* mu = new tensorflow::mutex(); + static std::mt19937_64 rng{42}; + tensorflow::mutex_lock l(*mu); + return rng(); +} + llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, int64 mantissa_bits, llvm::IRBuilder<>* ir_builder) { @@ -293,6 +300,12 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( return operand_value; } } + case HloOpcode::kClz: { + auto is_zero_undef = ir_builder_->getFalse(); + return llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::ctlz, {operand_value, is_zero_undef}, + {operand_value->getType()}, ir_builder_); + } case HloOpcode::kSign: { bool is_signed = primitive_util::IsSignedIntegralType(op->shape().element_type()); @@ -405,8 +418,12 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } case HloOpcode::kExp: return EmitExp(op->shape().element_type(), operand_value); + case HloOpcode::kExpm1: + return EmitExpm1(op->shape().element_type(), operand_value); case HloOpcode::kLog: return 
EmitLog(op->shape().element_type(), operand_value); + case HloOpcode::kLog1p: + return EmitLog1p(op->shape().element_type(), operand_value); case HloOpcode::kCos: return EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: @@ -480,6 +497,22 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComposeComplex( op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); } + case HloOpcode::kLog1p: { + // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + llvm::Type* llvm_ty = a->getType(); + auto one = llvm::ConstantFP::get(llvm_ty, 1.0); + auto a_plus_one = ir_builder_->CreateFAdd(a, one); + auto sum_sq = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(a_plus_one, a_plus_one), + ir_builder_->CreateFMul(b, b)); + TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); + TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); + auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); TF_RET_CHECK(primitive_util::IsComplexType(from_type)); @@ -510,6 +543,20 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), ir_builder_->CreateFMul(exp_a, sin_b)); } + case HloOpcode::kExpm1: { + // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value))); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); + auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); + auto real_result = + ir_builder_->CreateFSub(ir_builder_->CreateFMul(exp_a, cos_b), one); + auto imag_result = ir_builder_->CreateFMul(exp_a, sin_b); + return EmitComposeComplex(op, real_result, imag_result); + } case HloOpcode::kCos: { // cos(z) = .5(e^(iz) + e^(-iz)) // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai)) @@ -962,6 +1009,28 @@ StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, {value->getType()}, ir_builder_); } +StatusOr ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const { + auto x = value; + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto one = llvm::ConstantFP::get(type, 1.0); + auto negative_half = llvm::ConstantFP::get(type, -0.5); + // When x is large, the naive evaluation of ln(x + 1) is more + // accurate than the Taylor series. + TF_ASSIGN_OR_RETURN(auto for_large_x, + EmitLog(prim_type, ir_builder_->CreateFAdd(x, one))); + // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + …. 
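+  // The code below keeps only the first two terms of that series, i.e.
+  // for_small_x = x * (1 - x/2) = x - x^2/2, and uses it when |x| is below
+  // kAntilogarithmIsSmallThreshold.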
+ auto for_small_x = ir_builder_->CreateFMul( + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(negative_half, x), one), + x); + const auto kAntilogarithmIsSmallThreshold = 1e-4; + auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, + {type}, ir_builder_); + auto x_is_small = ir_builder_->CreateFCmpOLT( + abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); + return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); +} + StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, llvm::Value* value) const { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, @@ -980,6 +1049,29 @@ StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, {value->getType()}, ir_builder_); } +StatusOr ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const { + auto x = value; + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); + auto one = llvm::ConstantFP::get(type, 1.0); + auto half = llvm::ConstantFP::get(type, 0.5); + // When the exponent is large, the naive evaluation of e^(x) - 1 is more + // accurate than the Taylor series. + TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); + auto for_large_x = ir_builder_->CreateFSub(exp_x, one); + // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. + // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. + auto x_squared = ir_builder_->CreateFAdd(x, x); + auto x_squared_over_two = ir_builder_->CreateFMul(x_squared, half); + auto for_small_x = ir_builder_->CreateFAdd(x, x_squared_over_two); + const auto kExponentIsSmallThreshold = 1e-5; + auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, + {type}, ir_builder_); + auto x_is_small = ir_builder_->CreateFCmpOLT( + abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); + return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x); +} + StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { @@ -1169,7 +1261,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( llvm::Value* increment = ir_builder_->getInt( llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D})); - auto random_value = [hlo]() { + auto random_value_from_hlo = [hlo]() { const HloModule* module = hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent() : hlo->parent()->parent(); @@ -1191,10 +1283,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( /*Ty=*/ir_builder_->getInt64Ty(), /*isConstant=*/false, /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/ir_builder_->getInt64(random_value()), + /*Initializer=*/ir_builder_->getInt64(random_value_from_hlo()), /*Name=*/"state_ptr0"); + + // When the module config seed is 0, the expected result of a prng is a random + // value. Instead of using the random_value_from_hlo, we need a global random + // value as the graph seed. This is because if we use random_value_from_hlo + // here, then for a newly built hlo graph, it always gives the same number. uint64 graph_seed = hlo_module_config_.seed() != 0 ? 
hlo_module_config_.seed() - : random_value(); + : GlobalRandomValue(); llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable( /*M=*/*module_, /*Ty=*/ir_builder_->getInt64Ty(), @@ -1326,6 +1423,482 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( }; } +StatusOr ElementalIrEmitter::EmitElementalSelect( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, + operand_to_generator.at(hlo->operand(1))( + ElementwiseSourceIndex(index, *hlo, 1))); + TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, + operand_to_generator.at(hlo->operand(2))( + ElementwiseSourceIndex(index, *hlo, 2))); + return ir_builder_->CreateSelect( + ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()), + on_true_value, on_false_value); +} + +StatusOr ElementalIrEmitter::EmitElementalClamp( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + TF_ASSIGN_OR_RETURN(llvm::Value * min_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, + operand_to_generator.at(hlo->operand(1))( + ElementwiseSourceIndex(index, *hlo, 1))); + TF_ASSIGN_OR_RETURN(llvm::Value * max_value, + operand_to_generator.at(hlo->operand(2))( + ElementwiseSourceIndex(index, *hlo, 2))); + PrimitiveType prim_type = hlo->shape().element_type(); + if (primitive_util::IsFloatingPointType(prim_type)) { + return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); + } else if (primitive_util::IsIntegralType(prim_type)) { + bool is_signed = primitive_util::IsSignedIntegralType(prim_type); + return EmitIntegralMin( + max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed); + } else { + return Unimplemented("Clamp unimplemented for %s", + PrimitiveType_Name(prim_type).c_str()); + } +} + +StatusOr ElementalIrEmitter::EmitElementalConcatenate( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& target_index) const { + const int64 concat_dim = hlo->dimensions(0); + auto source_index = target_index; + + llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); + + // A terminator should be present iff we're emitting code + // into the middle (as opposed to the end) of a basic block. 
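For intuition, before the concatenate emitter continues below: the EmitElementalClamp helper above composes its result as min(max_value, max(min_value, arg)). A scalar sketch of that composition, with invented inputs, follows.

```cpp
#include <algorithm>
#include <cstdio>

// Scalar analogue of the floating-point path in EmitElementalClamp:
// the result is min(max_value, max(min_value, arg)).
float ClampSketch(float min_value, float arg, float max_value) {
  return std::min(max_value, std::max(min_value, arg));
}

int main() {
  std::printf("%g\n", ClampSketch(0.0f, -3.5f, 1.0f));  // 0    (clamped up)
  std::printf("%g\n", ClampSketch(0.0f, 0.25f, 1.0f));  // 0.25 (unchanged)
  std::printf("%g\n", ClampSketch(0.0f, 7.0f, 1.0f));   // 1    (clamped down)
  return 0;
}
```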
+ CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), + init_block->getTerminator() == nullptr); + + llvm::BasicBlock* exit_block; + if (ir_builder_->GetInsertPoint() == init_block->end()) { + exit_block = llvm_ir::CreateBasicBlock( + /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); + } else { + exit_block = init_block->splitBasicBlock(ir_builder_->GetInsertPoint(), + AsStringRef(IrName(hlo, "merge"))); + init_block->getTerminator()->eraseFromParent(); + } + + llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); + llvm::PHINode* output = ir_builder_->CreatePHI( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + hlo->operands().size()); + auto prior_insert_point = ir_builder_->GetInsertPoint(); + + ir_builder_->SetInsertPoint(init_block); + + for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); + ++operand_idx) { + const HloInstruction* operand = hlo->operand(operand_idx); + auto true_block = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_from_operand", operand_idx), + ir_builder_); + auto false_block = llvm_ir::CreateBasicBlock( + exit_block, StrCat("concat_index_not_from_operand", operand_idx), + ir_builder_); + auto concat_dim_size = + llvm::ConstantInt::get(source_index[concat_dim]->getType(), + operand->shape().dimensions(concat_dim)); + ir_builder_->CreateCondBr( + ir_builder_->CreateICmpULT(source_index[concat_dim], concat_dim_size), + true_block, false_block); + + // Create the terminator of the true block before calling operand + // generators, because they require non-degenerate basic blocks. + ir_builder_->SetInsertPoint( + llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); + TF_ASSIGN_OR_RETURN(llvm::Value * value, + operand_to_generator.at(operand)(source_index)); + output->addIncoming(value, ir_builder_->GetInsertBlock()); + + // Subtract the size of the concat dimension of the current operand + // from the source index. + ir_builder_->SetInsertPoint(false_block); + source_index[concat_dim] = + ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size); + } + + ir_builder_->CreateUnreachable(); + ir_builder_->SetInsertPoint(exit_block, prior_insert_point); + return output; +} + +StatusOr ElementalIrEmitter::EmitElementalDynamicSlice( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + // Emit IR to read dynamic start indices from hlo->operand(1). + const HloInstruction* input_hlo = hlo->operand(0); + const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + llvm_ir::IrArray::Index slice_start_index(rank); + for (int64 i = 0; i < rank; ++i) { + llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, + operand_to_generator.at(hlo->operand(1))(dim_index)); + + // Clamp the start index so that the sliced portion fits in the operand: + // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. 
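A hedged aside on the concatenate lowering completed above: it finds the source operand by repeatedly comparing the running index along the concat dimension against each operand's extent and subtracting when it does not hit. A scalar sketch of that walk; the operand contents are invented.

```cpp
#include <cstdio>
#include <vector>

// Scalar analogue of EmitElementalConcatenate's control flow: if the running
// index is smaller than the current operand's extent along the concat
// dimension, read from that operand; otherwise subtract the extent and move on.
int ConcatElementSketch(const std::vector<std::vector<int>>& operands,
                        int index_in_concat_dim) {
  for (const auto& operand : operands) {
    const int extent = static_cast<int>(operand.size());
    if (index_in_concat_dim < extent) {
      return operand[index_in_concat_dim];
    }
    index_in_concat_dim -= extent;
  }
  return -1;  // unreachable for in-bounds indices, mirrors CreateUnreachable()
}

int main() {
  std::vector<std::vector<int>> operands = {{10, 11}, {20}, {30, 31, 32}};
  for (int i = 0; i < 6; ++i) {
    std::printf("concat[%d] = %d\n", i, ConcatElementSketch(operands, i));
  }
  return 0;
}
```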
+ start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, + index[i]->getType()); + llvm::Value* operand_dim_size = llvm::ConstantInt::get( + start_index_value->getType(), input_hlo->shape().dimensions(i)); + llvm::Value* output_dim_size = llvm::ConstantInt::get( + start_index_value->getType(), hlo->shape().dimensions(i)); + + start_index_value = EmitIntegralMin( + ir_builder_->CreateSub(operand_dim_size, output_dim_size), + EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), + start_index_value, /*is_signed=*/true), + /*is_signed=*/true); + + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); + slice_start_index[i] = start_index_value; + } + + llvm_ir::IrArray::Index input_index(rank); + for (int64 i = 0; i < rank; ++i) { + // Emit IR which computes: + // input_index = start_index + offset_index + input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]); + } + return operand_to_generator.at(input_hlo)(input_index); +} + +StatusOr ElementalIrEmitter::EmitElementalGather( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + const Shape& operand_shape = hlo->operand(0)->shape(); + const Shape& indices_shape = hlo->operand(1)->shape(); + const Shape& output_shape = hlo->shape(); + + const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers(); + + const llvm_ir::ElementGenerator& operand_generator = + operand_to_generator.at(hlo->operand(0)); + const llvm_ir::ElementGenerator& indices_generator = + operand_to_generator.at(hlo->operand(1)); + + // This is the index into `operand` that holds the element we want to + // generate. This index "unsafe" as in the components in here may be + // out of bounds. + IrArray::Index unsafe_operand_index; + + // First copy in the window indices to unsafe_operand_index. + for (int64 i = 0, e = operand_shape.dimensions_size(), + unsafe_operand_index_dim = 0; + i < e; i++) { + if (c_binary_search(dim_numbers.elided_window_dims(), i)) { + unsafe_operand_index.push_back(ir_builder_->getInt64(0)); + } else { + unsafe_operand_index.push_back( + index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]); + } + } + + // This is the index of the index vector in the gather_indices tensor. 
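As an illustration of the dynamic-slice lowering above (not part of the patch): each requested start index is clamped so that the whole slice fits inside the operand, and the output coordinate is then offset by the clamped start. A scalar sketch with invented extents.

```cpp
#include <algorithm>
#include <cstdio>

// Scalar analogue of the start-index handling in EmitElementalDynamicSlice:
// start_index = clamp(start_index, 0, operand_dim_size - output_dim_size).
long long ClampSliceStart(long long start, long long operand_dim,
                          long long output_dim) {
  return std::min(operand_dim - output_dim, std::max(0LL, start));
}

int main() {
  const long long operand_dim = 10, output_dim = 4;
  for (long long start : {-3, 2, 9}) {
    long long clamped = ClampSliceStart(start, operand_dim, output_dim);
    std::printf("requested %lld -> clamped %lld, slice covers [%lld, %lld)\n",
                start, clamped, clamped, clamped + output_dim);
  }
  return 0;
}
```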
+ IrArray::Index gather_index_index; + { + std::vector gather_index_index_components; + for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { + if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + gather_index_index.push_back(index[i]); + } + } + + if (gather_index_index.size() != indices_shape.dimensions_size()) { + gather_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr); + } + } + + auto add_to_unsafe_operand_index = [&](llvm::Value* index_component, + int64 dim) { + llvm::Value* gather_dim_component_extended = ir_builder_->CreateSExtOrTrunc( + index_component, ir_builder_->getInt64Ty()); + unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] = + ir_builder_->CreateAdd( + unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)], + gather_dim_component_extended); + }; + + if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { + TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); + add_to_unsafe_operand_index(gather_dim_component, 0); + } else { + int64 index_vector_size = + indices_shape.dimensions(dim_numbers.index_vector_dim()); + for (int64 i = 0; i < index_vector_size; i++) { + gather_index_index[dim_numbers.index_vector_dim()] = + ir_builder_->getInt64(i); + TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component, + indices_generator(gather_index_index)); + add_to_unsafe_operand_index(gather_dim_component, i); + } + } + + IrArray::Index safe_operand_index; + for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) { + safe_operand_index.push_back(ir_builder_->CreateURem( + unsafe_operand_index[i], + ir_builder_->getInt64(operand_shape.dimensions(i)))); + } + + return operand_generator(safe_operand_index); +} + +StatusOr ElementalIrEmitter::EmitElementalDynamicUpdateSlice( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const { + const HloInstruction* input_hlo = hlo->operand(0); + const HloInstruction* update_hlo = hlo->operand(1); + const HloInstruction* start_hlo = hlo->operand(2); + // Calculate slice start/end indices. + const int64 rank = ShapeUtil::Rank(input_hlo->shape()); + llvm_ir::IrArray::Index slice_start_index(rank); + llvm_ir::IrArray::Index slice_limit_index(rank); + // Slice intersection gathers (ANDs) conditions on all ranks for which + // 'input' is set to 'update' + llvm::Value* slice_intersection = ir_builder_->getTrue(); + + for (int64 i = 0; i < rank; ++i) { + llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); + TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, + operand_to_generator.at(start_hlo)(dim_index)); + + // Clamp the start index so that the update region fits in the operand. + // start_index = clamp(start_index, 0, input_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. 
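A simplified aside on the gather lowering above: it builds an "unsafe" operand index from the output's window coordinates plus the gathered start indices, then wraps each component with URem so the final index is in bounds. The sketch below collapses the gather_dims_to_operand_dims mapping and elided window dims into a plain per-dimension addition; all shapes are invented.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch of how the final operand index is formed: window coordinates plus
// gathered start indices, each reduced modulo the operand extent.
std::vector<int64_t> SafeOperandIndex(std::vector<int64_t> window_coords,
                                      const std::vector<int64_t>& gathered_starts,
                                      const std::vector<int64_t>& operand_dims) {
  for (size_t i = 0; i < window_coords.size(); ++i) {
    window_coords[i] += gathered_starts[i];  // add_to_unsafe_operand_index
    window_coords[i] %= operand_dims[i];     // the CreateURem bounds wrap
  }
  return window_coords;
}

int main() {
  // Operand of shape [8, 8]; window coordinate (1, 2) and gathered start (9, 3).
  auto idx = SafeOperandIndex({1, 2}, {9, 3}, {8, 8});
  std::printf("safe index = (%lld, %lld)\n", static_cast<long long>(idx[0]),
              static_cast<long long>(idx[1]));  // (2, 5): 10 % 8 = 2, 5 % 8 = 5
  return 0;
}
```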
+ start_index_value = ir_builder_->CreateSExtOrBitCast(start_index_value, + index[i]->getType()); + llvm::Value* input_dim_size = llvm::ConstantInt::get( + index[i]->getType(), input_hlo->shape().dimensions(i)); + llvm::Value* update_dim_size = llvm::ConstantInt::get( + index[i]->getType(), update_hlo->shape().dimensions(i)); + + start_index_value = EmitIntegralMin( + ir_builder_->CreateSub(input_dim_size, update_dim_size), + EmitIntegralMax(llvm::ConstantInt::get(start_index_value->getType(), 0), + start_index_value, /*is_signed=*/true), + /*is_signed=*/true); + + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); + slice_start_index[i] = start_index_value; + slice_limit_index[i] = + ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); + + slice_intersection = ir_builder_->CreateAnd( + slice_intersection, + ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), + "slice_intersection"); + slice_intersection = ir_builder_->CreateAnd( + slice_intersection, + ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), + "slice_intersection"); + } + + // Emit: + // if (slice_intersection) -> return data from 'update'. + // else -> return data from 'input'. + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + "ret_value_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + slice_intersection, "slice_intersection", ir_builder_); + + // Handle true BB (return data from 'update') + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + // Compute update index for intersection case. + llvm_ir::IrArray::Index update_index(rank); + for (int64 i = 0; i < rank; ++i) { + update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]); + } + TF_ASSIGN_OR_RETURN(llvm::Value * true_value, + operand_to_generator.at(update_hlo)(update_index)); + ir_builder_->CreateStore(true_value, ret_value_addr); + + // Handle false BB (return data from 'input') + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * false_value, + operand_to_generator.at(input_hlo)(index)); + ir_builder_->CreateStore(false_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + return ir_builder_->CreateLoad(ret_value_addr); +} + +StatusOr ElementalIrEmitter::EmitElementalPad( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& padded_index) const { + auto index = padded_index; + llvm::Value* in_bounds = ir_builder_->getTrue(); + for (size_t i = 0; i < index.size(); ++i) { + auto index_typed_const = [=](int64 n) { + return llvm::ConstantInt::get(index[i]->getType(), n); + }; + const auto& pad_dim = hlo->padding_config().dimensions(i); + index[i] = ir_builder_->CreateSub( + index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = ir_builder_->CreateAnd( + in_bounds, ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), + "in_bounds"); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpEQ( + index_typed_const(0), + ir_builder_->CreateURem( + index[i], index_typed_const(pad_dim.interior_padding() + 1))), + "in_bounds"); + index[i] = ir_builder_->CreateSDiv( + index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = ir_builder_->CreateAnd( + in_bounds, + ir_builder_->CreateICmpSLT( + index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + 
"in_bounds"); + } + + // if (in_bounds) { + // ret_value = operand0[index]; // source + // } else { + // ret_value = *operand1; // padding + // } + llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + "pad_result_addr", ir_builder_); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); + SetToFirstInsertPoint(if_data.true_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))(index)); + ir_builder_->CreateStore(operand_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.false_block, ir_builder_); + TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, + operand_to_generator.at(hlo->operand(1))({})); + ir_builder_->CreateStore(padding_value, ret_value_addr); + + SetToFirstInsertPoint(if_data.after_block, ir_builder_); + // Don't create phi(operand_value, padding_value) here, because invoking + // operand_to_generator may create new basic blocks, making the parent + // of operand_value or padding_value no longer a predecessor of + // if_data.after_block. + return ir_builder_->CreateLoad(ret_value_addr); +} + +StatusOr ElementalIrEmitter::EmitElementalDot( + const HloInstruction* hlo, + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& dot_result_index) const { + auto lhs_generator = operand_to_generator.at(hlo->operand(0)); + auto rhs_generator = operand_to_generator.at(hlo->operand(1)); + + const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers(); + int64 lhs_contracting_dim = dim_numbers.lhs_contracting_dimensions(0); + int64 rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0); + + int64 contracted_dim_size = + hlo->operand(0)->shape().dimensions(lhs_contracting_dim); + int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); + int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); + + std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( + IrName(hlo, "inner"), ir_builder_->getInt64(0), + ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1), + ir_builder_); + + SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_); + PrimitiveType primitive_type = hlo->shape().element_type(); + llvm::Type* primitive_type_llvm = + llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); + llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( + primitive_type_llvm, "dot_acc", ir_builder_); + ir_builder_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), + accumulator_alloca); + + SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_); + + // This is the inner reduction loop for a dot operation that produces + // one element in the output. If the operands to the dot operation have + // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E]. 
+ // Given an output index [a,b,c,d,e] in the result, we compute: + // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) + + IrArray::Index lhs_index, rhs_index; + + for (int64 i = 0; i < lhs_dims - 1; i++) { + lhs_index.push_back(dot_result_index[i]); + } + lhs_index.InsertAt(lhs_contracting_dim, inner_loop->GetIndVarValue()); + + for (int64 i = 0; i < rhs_dims - 1; i++) { + rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); + } + rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); + + llvm::Value* current_accumulator = + ir_builder_->CreateLoad(accumulator_alloca); + TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); + TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); + llvm::Value* next_accumulator; + if (primitive_util::IsComplexType(primitive_type)) { + llvm::Value* product_real = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); + llvm::Value* product_imag = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value))); + next_accumulator = ir_builder_->CreateInsertValue( + current_accumulator, + ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), + product_real), + {0}); + next_accumulator = ir_builder_->CreateInsertValue( + next_accumulator, + ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), + product_imag), + {1}); + } else if (primitive_util::IsFloatingPointType(primitive_type)) { + next_accumulator = ir_builder_->CreateFAdd( + current_accumulator, ir_builder_->CreateFMul(lhs_value, rhs_value)); + } else { + next_accumulator = ir_builder_->CreateAdd( + current_accumulator, ir_builder_->CreateMul(lhs_value, rhs_value)); + } + ir_builder_->CreateStore(next_accumulator, accumulator_alloca); + + SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_); + return ir_builder_->CreateLoad(accumulator_alloca); +} + llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) @@ -1334,15 +1907,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: case HloOpcode::kCeil: + case HloOpcode::kClz: case HloOpcode::kConvert: case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNegate: case HloOpcode::kNot: case HloOpcode::kReal: @@ -1392,43 +1968,12 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSelect: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); - TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); - TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); - return ir_builder_->CreateSelect( - ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()), - 
on_true_value, on_false_value); + return EmitElementalSelect(hlo, operand_to_generator, index); }; case HloOpcode::kClamp: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - TF_ASSIGN_OR_RETURN(llvm::Value * min_value, - operand_to_generator.at(hlo->operand(0))( - ElementwiseSourceIndex(index, *hlo, 0))); - TF_ASSIGN_OR_RETURN(llvm::Value * arg_value, - operand_to_generator.at(hlo->operand(1))( - ElementwiseSourceIndex(index, *hlo, 1))); - TF_ASSIGN_OR_RETURN(llvm::Value * max_value, - operand_to_generator.at(hlo->operand(2))( - ElementwiseSourceIndex(index, *hlo, 2))); - PrimitiveType prim_type = hlo->shape().element_type(); - if (primitive_util::IsFloatingPointType(prim_type)) { - return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); - } else if (primitive_util::IsIntegralType(prim_type)) { - bool is_signed = primitive_util::IsSignedIntegralType(prim_type); - return EmitIntegralMin( - max_value, EmitIntegralMax(min_value, arg_value, is_signed), - is_signed); - } else { - return Unimplemented("Clamp unimplemented for %s", - PrimitiveType_Name(prim_type).c_str()); - } + return EmitElementalClamp(hlo, operand_to_generator, index); }; case HloOpcode::kReducePrecision: return [this, hlo, &operand_to_generator]( @@ -1441,70 +1986,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kConcatenate: return [this, hlo, &operand_to_generator]( const IrArray::Index target_index) -> StatusOr { - const int64 concat_dim = hlo->dimensions(0); - auto source_index = target_index; - - llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); - - // A terminator should be present iff we're emitting code - // into the middle (as opposed to the end) of a basic block. - CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), - init_block->getTerminator() == nullptr); - - llvm::BasicBlock* exit_block; - if (ir_builder_->GetInsertPoint() == init_block->end()) { - exit_block = llvm_ir::CreateBasicBlock( - /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); - } else { - exit_block = init_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge"))); - init_block->getTerminator()->eraseFromParent(); - } - - llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); - llvm::PHINode* output = - ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType( - hlo->shape().element_type(), module_), - hlo->operands().size()); - auto prior_insert_point = ir_builder_->GetInsertPoint(); - - ir_builder_->SetInsertPoint(init_block); - - for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); - ++operand_idx) { - const HloInstruction* operand = hlo->operand(operand_idx); - auto true_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_from_operand", operand_idx), - ir_builder_); - auto false_block = llvm_ir::CreateBasicBlock( - exit_block, StrCat("concat_index_not_from_operand", operand_idx), - ir_builder_); - auto concat_dim_size = - llvm::ConstantInt::get(source_index[concat_dim]->getType(), - operand->shape().dimensions(concat_dim)); - ir_builder_->CreateCondBr( - ir_builder_->CreateICmpULT(source_index[concat_dim], - concat_dim_size), - true_block, false_block); - - // Create the terminator of the true block before calling operand - // generators, because they require non-degenerate basic blocks. 
- ir_builder_->SetInsertPoint( - llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block)); - TF_ASSIGN_OR_RETURN(llvm::Value * value, - operand_to_generator.at(operand)(source_index)); - output->addIncoming(value, ir_builder_->GetInsertBlock()); - - // Subtract the size of the concat dimension of the current operand - // from the source index. - ir_builder_->SetInsertPoint(false_block); - source_index[concat_dim] = - ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size); - } - - ir_builder_->CreateUnreachable(); - ir_builder_->SetInsertPoint(exit_block, prior_insert_point); - return output; + return EmitElementalConcatenate(hlo, operand_to_generator, + target_index); }; case HloOpcode::kReverse: return [this, hlo, &operand_to_generator]( @@ -1540,184 +2023,19 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kDynamicSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - // Emit IR to read dynamic start indices from hlo->operand(1). - const HloInstruction* input_hlo = hlo->operand(0); - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); - TF_ASSIGN_OR_RETURN( - llvm::Value * start_index_value, - operand_to_generator.at(hlo->operand(1))(dim_index)); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = start_index_value; - } + return EmitElementalDynamicSlice(hlo, operand_to_generator, index); + }; - llvm_ir::IrArray::Index input_index(rank); - for (int64 i = 0; i < rank; ++i) { - // Emit IR which computes: - // input_index = (start_index + offset_index) % dim_size - // Security note: this is the code that keeps the indices in-bounds. - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast( - slice_start_index[i], index[i]->getType()); - input_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateAdd(start_index, index[i]), dim_size); - } - return operand_to_generator.at(input_hlo)(input_index); + case HloOpcode::kGather: + return [this, hlo, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + return EmitElementalGather(hlo, operand_to_generator, index); }; case HloOpcode::kDynamicUpdateSlice: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { - const HloInstruction* input_hlo = hlo->operand(0); - const HloInstruction* update_hlo = hlo->operand(1); - const HloInstruction* start_hlo = hlo->operand(2); - // Calculate slice start/end indices. - const int64 rank = ShapeUtil::Rank(input_hlo->shape()); - llvm_ir::IrArray::Index slice_start_index(rank); - llvm_ir::IrArray::Index slice_limit_index(rank); - // Slice starts at update[index - slice_start_index_adjusted], - // where adjusted value = slice_start_index when in bounds, and - // adjusted value = slice_start_index - input_dim, when wrapping. - llvm_ir::IrArray::Index slice_start_index_adjusted(rank); - - // Slice intersection gathers (ANDs) conditions on all ranks for which - // 'input' is set to 'update' - llvm::Value* slice_intersection = ir_builder_->getTrue(); - - for (int64 i = 0; i < rank; ++i) { - // Emit IR to read dynamic start indices from 'start_hlo'. 
- llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); - TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, - operand_to_generator.at(start_hlo)(dim_index)); - start_index_value->setName( - AsStringRef(IrName(hlo, StrCat("start_idx", i)))); - slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( - start_index_value, index[i]->getType()); - - llvm::Value* input_dim_size = llvm::ConstantInt::get( - index[i]->getType(), input_hlo->shape().dimensions(i)); - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - - // Generate code to handle wrapping semantics: - // slice_start_index[i] = slice_start_index[i] % input_dim_size; - // slice_limit_index[i] = slice_start_index[i] + update_dim_size. - // slice_start_index[i] is updated in place and it will now be in - // range. slice_limit_index[i] may be out of range, and it's being - // URem-ed below if so. - slice_start_index[i] = - ir_builder_->CreateURem(slice_start_index[i], input_dim_size); - slice_limit_index[i] = - ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); - - // Test if slice_limit_index[i] is in bounds - llvm::Value* in_bounds = - ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size); - llvm_ir::LlvmIfData if_in_bounds = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - - // Handle true BB (slice_limit_index[i] <= input_dim_size). - SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] && - // index[i] < slice_limit_index[i] - llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd( - slice_intersection, - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection_in"); - slice_intersection_in_bounds = ir_builder_->CreateAnd( - slice_intersection_in_bounds, - ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection_in"); - - // Handle false BB (slice_limit_index[i] > input_dim_size). - SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_); - // Check that index[i] >= slice_start_index[i] || - // index[i] < slice_limit_index[i]%input_dim_size. - llvm::Value* index_wraps = ir_builder_->CreateICmpSLT( - index[i], - ir_builder_->CreateURem(slice_limit_index[i], input_dim_size)); - llvm::Value* slice_intersection_or = ir_builder_->CreateOr( - ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - index_wraps, "slice_intersection_out"); - llvm::Value* slice_intersection_out_of_bounds = - ir_builder_->CreateAnd(slice_intersection, slice_intersection_or, - "slice_intersection_out"); - // Create value for slice_start_index_adjusted[i] when out of bounds. - // If within out-of-bounds if. - llvm_ir::LlvmIfData if_start_needs_adjustment = - llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_); - SetToFirstInsertPoint(if_start_needs_adjustment.true_block, - ir_builder_); - llvm::Value* slice_start_index_adjusted_oob = - ir_builder_->CreateSub(slice_start_index[i], input_dim_size); - SetToFirstInsertPoint(if_start_needs_adjustment.after_block, - ir_builder_); - llvm::PHINode* slice_start_index_adjusted_phi = - ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), - 2); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index_adjusted_oob, - if_start_needs_adjustment.true_block); - slice_start_index_adjusted_phi->addIncoming( - slice_start_index[i], if_start_needs_adjustment.false_block); - // End of if within if. - - // After checking in/out of bounds. 
- SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_); - llvm::PHINode* phi_slice_intersection = - ir_builder_->CreatePHI(slice_intersection->getType(), 2); - phi_slice_intersection->addIncoming(slice_intersection_in_bounds, - if_in_bounds.true_block); - phi_slice_intersection->addIncoming( - slice_intersection_out_of_bounds, - if_start_needs_adjustment.after_block); - slice_intersection = phi_slice_intersection; - - llvm::PHINode* phi_index = - ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2); - phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block); - phi_index->addIncoming(slice_start_index_adjusted_phi, - if_start_needs_adjustment.after_block); - slice_start_index_adjusted[i] = phi_index; - } - - // Emit: - // if (slice_intersection) -> return data from 'update'. - // else -> return data from 'input'. - llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - module_), - "ret_value_addr", ir_builder_); - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - slice_intersection, "slice_intersection", ir_builder_); - - // Handle true BB (return data from 'update') - SetToFirstInsertPoint(if_data.true_block, ir_builder_); - // Compute update index for intersection case. - llvm_ir::IrArray::Index update_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm::Value* update_dim_size = llvm::ConstantInt::get( - index[i]->getType(), update_hlo->shape().dimensions(i)); - // NOTE: Subtraction will be positive due to bounds checking above. - update_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]), - update_dim_size); - } - TF_ASSIGN_OR_RETURN(llvm::Value * true_value, - operand_to_generator.at(update_hlo)(update_index)); - ir_builder_->CreateStore(true_value, ret_value_addr); - - // Handle false BB (return data from 'input') - SetToFirstInsertPoint(if_data.false_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * false_value, - operand_to_generator.at(input_hlo)(index)); - ir_builder_->CreateStore(false_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - return ir_builder_->CreateLoad(ret_value_addr); + return EmitElementalDynamicUpdateSlice(hlo, operand_to_generator, + index); }; case HloOpcode::kBitcast: CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), @@ -1746,155 +2064,16 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kRng: return MakeRngElementGenerator(hlo, operand_to_generator); case HloOpcode::kPad: - return [=, &operand_to_generator]( + return [this, hlo, &operand_to_generator]( const IrArray::Index& padded_index) -> StatusOr { - auto index = padded_index; - llvm::Value* in_bounds = ir_builder_->getTrue(); - for (size_t i = 0; i < index.size(); ++i) { - auto index_typed_const = [=](int64 n) { - return llvm::ConstantInt::get(index[i]->getType(), n); - }; - const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = ir_builder_->CreateSub( - index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpEQ( - index_typed_const(0), - ir_builder_->CreateURem( - index[i], - index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = ir_builder_->CreateSDiv( - index[i], index_typed_const(pad_dim.interior_padding() + 1)); - 
in_bounds = ir_builder_->CreateAnd( - in_bounds, - ir_builder_->CreateICmpSLT( - index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), - "in_bounds"); - } - - // if (in_bounds) { - // ret_value = operand0[index]; // source - // } else { - // ret_value = *operand1; // padding - // } - llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - module_), - "pad_result_addr", ir_builder_); - llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - SetToFirstInsertPoint(if_data.true_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, - operand_to_generator.at(hlo->operand(0))(index)); - ir_builder_->CreateStore(operand_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.false_block, ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, - operand_to_generator.at(hlo->operand(1))({})); - ir_builder_->CreateStore(padding_value, ret_value_addr); - - SetToFirstInsertPoint(if_data.after_block, ir_builder_); - // Don't create phi(operand_value, padding_value) here, because invoking - // operand_to_generator may create new basic blocks, making the parent - // of operand_value or padding_value no longer a predecessor of - // if_data.after_block. - return ir_builder_->CreateLoad(ret_value_addr); + return EmitElementalPad(hlo, operand_to_generator, padded_index); }; case HloOpcode::kDot: - return [=, &operand_to_generator](const IrArray::Index& dot_result_index) + return [this, hlo, + &operand_to_generator](const IrArray::Index& dot_result_index) -> StatusOr { - auto lhs_generator = operand_to_generator.at(hlo->operand(0)); - auto rhs_generator = operand_to_generator.at(hlo->operand(1)); - int64 contracted_dim_size = hlo->operand(0)->shape().dimensions( - hlo->operand(0)->shape().dimensions_size() - 1); - int64 lhs_dims = hlo->operand(0)->shape().dimensions_size(); - int64 rhs_dims = hlo->operand(1)->shape().dimensions_size(); - - std::unique_ptr inner_loop = - llvm_ir::ForLoop::EmitForLoop( - IrName(hlo, "inner"), ir_builder_->getInt64(0), - ir_builder_->getInt64(contracted_dim_size), - ir_builder_->getInt64(1), ir_builder_); - - SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), - ir_builder_); - PrimitiveType primitive_type = hlo->shape().element_type(); - llvm::Type* primitive_type_llvm = - llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); - llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( - primitive_type_llvm, "dot_acc", ir_builder_); - ir_builder_->CreateStore( - llvm::Constant::getNullValue(primitive_type_llvm), - accumulator_alloca); - - SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_); - - // This is the inner reduction loop for a dot operation that produces - // one element in the output. If the operands to the dot operation have - // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E]. 
- // Given an output index [a,b,c,d,e] in the result, we compute: - // sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T)) - - IrArray::Index lhs_index, rhs_index; - - for (int64 i = 0; i < lhs_dims - 1; i++) { - lhs_index.push_back(dot_result_index[i]); - } - lhs_index.push_back(inner_loop->GetIndVarValue()); - - for (int64 i = 0; i < rhs_dims - 2; i++) { - rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]); - } - rhs_index.push_back(inner_loop->GetIndVarValue()); - rhs_index.push_back(dot_result_index.back()); - - llvm::Value* current_accumulator = - ir_builder_->CreateLoad(accumulator_alloca); - TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); - TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); - llvm::Value* next_accumulator; - if (primitive_util::IsComplexType(primitive_type)) { - llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); - llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - ir_builder_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value))); - next_accumulator = ir_builder_->CreateInsertValue( - current_accumulator, - ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), - product_real), - {0}); - next_accumulator = ir_builder_->CreateInsertValue( - next_accumulator, - ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), - product_imag), - {1}); - } else if (primitive_util::IsFloatingPointType(primitive_type)) { - next_accumulator = ir_builder_->CreateFAdd( - current_accumulator, - ir_builder_->CreateFMul(lhs_value, rhs_value)); - } else { - next_accumulator = ir_builder_->CreateAdd( - current_accumulator, - ir_builder_->CreateMul(lhs_value, rhs_value)); - } - ir_builder_->CreateStore(next_accumulator, accumulator_alloca); - - SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_); - return ir_builder_->CreateLoad(accumulator_alloca); + return EmitElementalDot(hlo, operand_to_generator, dot_result_index); }; default: return [this, hlo, &operand_to_generator](const IrArray::Index& index) { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index c516a826d9e..d199473374a 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -105,6 +105,9 @@ class ElementalIrEmitter { virtual StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitSin(PrimitiveType prim_type, llvm::Value* value) const; @@ -114,6 +117,9 @@ class ElementalIrEmitter { virtual StatusOr EmitExp(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const; + virtual StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const; @@ -142,6 +148,46 @@ class ElementalIrEmitter { return ir_builder_->getIntN(128, 0); } + StatusOr EmitElementalSelect( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalClamp( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const 
llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalConcatenate( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& target_index) const; + + StatusOr EmitElementalDynamicSlice( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalGather( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalDynamicUpdateSlice( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& index) const; + + StatusOr EmitElementalPad( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& padded_index) const; + + StatusOr EmitElementalDot( + const HloInstruction* hlo, + const HloToElementGeneratorMap& operand_to_generator, + const llvm_ir::IrArray::Index& dot_result_index) const; + llvm::IRBuilder<>* const ir_builder_; llvm::Module* module_; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc new file mode 100644 index 00000000000..b43dc0c65d9 --- /dev/null +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class ElementalIrEmitterExecutionTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text, config)); + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) { + const string hlo_text = R"( +HloModule FusedDot + +fused_computation { + arg0 = s32[1,2,1]{2,1,0} parameter(0) + reshape.lhs = s32[2,1]{1,0} reshape(arg0) + arg1 = s32[1,2,1]{2,1,0} parameter(1) + reshape.rhs = s32[2,1]{1,0} reshape(arg1) + ROOT dot = s32[1,1]{1,0} dot(reshape.lhs, reshape.rhs), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} + +ENTRY main { + entry_arg0 = s32[1,2,1]{2,1,0} parameter(0) + entry_arg1 = s32[1,2,1]{2,1,0} parameter(1) + ROOT fusion = s32[1,1]{1,0} fusion(entry_arg0, entry_arg1), kind=kLoop, calls=fused_computation +} +)"; + + std::unique_ptr lhs = Literal::CreateR3({{{1}, {2}}}); + std::unique_ptr rhs = Literal::CreateR3({{{3}, {4}}}); + RunTest(hlo_text, {lhs.get(), rhs.get()}); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 471d2fd6ceb..8119478ce93 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -29,18 +29,19 @@ using tensorflow::gtl::ArraySlice; namespace xla { -StatusOr>> -Executable::ExecuteOnStreams( +StatusOr> Executable::ExecuteOnStreams( ArraySlice run_options, ArraySlice> arguments) { TF_RET_CHECK(run_options.size() == arguments.size()); - std::vector> return_values(run_options.size()); + std::vector return_values; + return_values.reserve(run_options.size()); if (run_options.size() == 1) { - TF_ASSIGN_OR_RETURN(return_values[0], + TF_ASSIGN_OR_RETURN(auto rv, ExecuteOnStream(&run_options[0], arguments[0], /*hlo_execution_profile=*/nullptr)); + return_values.push_back(std::move(rv)); return std::move(return_values); } @@ -48,8 +49,9 @@ Executable::ExecuteOnStreams( // We cannot BlockHostUntilDone() on the already-launched executions in case // of error, since if the executions communicate, the initially launched // executions may never complete if not all executions are running. 
- TF_ASSIGN_OR_RETURN(return_values[i], + TF_ASSIGN_OR_RETURN(auto rv, ExecuteAsyncOnStream(&run_options[i], arguments[i])); + return_values.push_back(std::move(rv)); } for (const auto& options : run_options) { TF_RET_CHECK(options.stream() != nullptr); @@ -58,13 +60,13 @@ Executable::ExecuteOnStreams( return std::move(return_values); } -StatusOr> Executable::ExecuteOnStreamWrapper( +StatusOr Executable::ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, ArraySlice arguments) { - perftools::gputools::Stream* stream = run_options->stream(); - std::unique_ptr timer; + se::Stream* stream = run_options->stream(); + std::unique_ptr timer; if (profile != nullptr) { - timer.reset(new perftools::gputools::Timer(stream->parent())); + timer.reset(new se::Timer(stream->parent())); stream->InitTimer(timer.get()).ThenStartTimer(timer.get()); } @@ -78,7 +80,7 @@ StatusOr> Executable::ExecuteOnStreamWrapper( &hlo_profile_index_map()) : nullptr; - StatusOr> return_value = + StatusOr return_value = ExecuteOnStream(run_options, arguments, profile_ptr.get()); TF_RETURN_IF_ERROR(return_value.status()); @@ -141,6 +143,19 @@ Status Executable::DumpSessionModule() { *session_module_); } +Status Executable::DumpHloSnapshot() { + TF_RET_CHECK(dumping_snapshot()); + TF_RET_CHECK(hlo_snapshot_->has_hlo() && + hlo_snapshot_->hlo().has_hlo_module()); + const string& directory_path = + module_config().debug_options().xla_dump_executions_to(); + const auto& module = hlo_snapshot_->hlo().hlo_module(); + string filename = tensorflow::strings::Printf( + "computation_%lld__%s__execution_%lld", module.id(), + module.entry_computation_name().c_str(), ++execution_count_); + return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); +} + /* static */ Status Executable::DumpToDirectory( const string& directory_path, string filename, const SessionModule& session_module) { @@ -161,4 +176,24 @@ Status Executable::DumpSessionModule() { result); } +/* static */ Status Executable::DumpToDirectory( + const string& directory_path, string filename, + const HloSnapshot& hlo_session) { + tensorflow::Env* env = tensorflow::Env::Default(); + if (!env->IsDirectory(directory_path).ok()) { + // NB! CreateDir does not work reliably with multiple XLA threads -- two + // threads can race to observe the absence of the dump directory and + // simultaneously try to create it, causing the "losing" thread to get a + // "directory already exists" error. + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path)); + } + filename = SanitizeFileName(std::move(filename)); + string file_path = tensorflow::io::JoinPath(directory_path, filename); + string result; + TF_RET_CHECK( + tensorflow::SerializeToStringDeterministic(hlo_session, &result)); + return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, + result); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index a157235f8af..4f0466c5447 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -22,6 +22,7 @@ limitations under the License. 
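A side note on the "NB!" comment in the DumpToDirectory overload above: when several threads dump snapshots concurrently, directory creation has to tolerate losing the race. The patch relies on tensorflow::Env::RecursivelyCreateDir for this; the standalone C++17 sketch below shows the same idea with std::filesystem, which is not the API used here, and the path is invented.

```cpp
#include <cstdio>
#include <filesystem>

// Race-tolerant directory creation: create_directories with an error_code does
// not report "already exists" as an error, so a thread that loses the race to
// a concurrent dumper still succeeds.
bool EnsureDumpDirectory(const std::filesystem::path& dir) {
  std::error_code ec;
  std::filesystem::create_directories(dir, ec);
  return !ec;
}

int main() {
  std::printf("created or present: %s\n",
              EnsureDumpDirectory("/tmp/xla_dump_example") ? "yes" : "no");
  return 0;
}
```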
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -62,14 +63,14 @@ class Executable { // enabled. // // Returns a shaped buffer containing the result of the computation. - virtual StatusOr> ExecuteOnStream( + virtual StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) = 0; // Same as ExecuteOnStream(), but this call is non-blocking and returns as // soon as all of the operations are enqueued for launch on the stream. - virtual StatusOr> ExecuteAsyncOnStream( + virtual StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) = 0; @@ -77,7 +78,7 @@ class Executable { // streams. arguments[i] contains the arguments to the execution on // run_options[i]->stream() and the returned value is at index i of the // returned vector. - virtual StatusOr>> ExecuteOnStreams( + virtual StatusOr> ExecuteOnStreams( tensorflow::gtl::ArraySlice run_options, tensorflow::gtl::ArraySlice< @@ -90,14 +91,14 @@ class Executable { // has completed. virtual Status PopulateExecutionProfile( HloExecutionProfile* hlo_execution_profile, - perftools::gputools::StreamExecutor* executor) { + se::StreamExecutor* executor) { return Status::OK(); } // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a // timer for the execution, sets up HLO profiling if enabled, and fills in the // given ExecutionProfile if non-null. - StatusOr> ExecuteOnStreamWrapper( + StatusOr ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, tensorflow::gtl::ArraySlice arguments); @@ -139,11 +140,11 @@ class Executable { // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. - const Shape& result_shape() const { - return hlo_module_->config().entry_computation_layout().result_shape(); + const Shape& host_result_shape() const { + return hlo_module_->config().host_entry_computation_layout().result_shape(); } - // Dumping helpers. + // TODO(b/74197823): Delete the session module dumping helpers. void set_session_module(std::unique_ptr session_module) { session_module_ = std::move(session_module); } @@ -151,10 +152,22 @@ class Executable { SessionModule* session_module() const { return session_module_.get(); } Status DumpSessionModule(); + // Dumping helpers. + void set_hlo_snapshot(std::unique_ptr hlo_snapshot) { + hlo_snapshot_ = std::move(hlo_snapshot); + } + bool dumping_snapshot() const { return hlo_snapshot_ != nullptr; } + HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); } + Status DumpHloSnapshot(); + // Dump session_module to directory_path/filename. static Status DumpToDirectory(const string& directory_path, string filename, const SessionModule& session_module); + // Dump hlo snapshot to directory_path/filename. 
+ static Status DumpToDirectory(const string& directory_path, string filename, + const HloSnapshot& hlo_session); + protected: mutable tensorflow::mutex mutex_; @@ -169,6 +182,9 @@ class Executable { // SessionModule this was compiled from. Null if not dumping executions. std::unique_ptr session_module_; + // HloSnapshot this was compiled from. Null if not dumping executions. + std::unique_ptr hlo_snapshot_; + // Execution count, used to generate a unique filename for each dumped // execution. int64 execution_count_ = 0; diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 2f0b9ed2bd9..6794cfe297b 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -37,11 +37,11 @@ AsyncExecution::AsyncExecution(Backend* backend, } } -tensorflow::Status AsyncExecution::BlockUntilDone() const { +Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); } - return tensorflow::Status::OK(); + return Status::OK(); } ExecutionTracker::ExecutionTracker() : next_handle_(1) {} @@ -61,7 +61,7 @@ ExecutionHandle ExecutionTracker::Register( return execution_handle; } -tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { +Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { @@ -69,7 +69,7 @@ tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { handle.handle()); } handle_to_execution_.erase(handle.handle()); - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr ExecutionTracker::Resolve( diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h index 5b6bddf9f16..4458152dd9a 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.h +++ b/tensorflow/compiler/xla/service/execution_tracker.h @@ -43,7 +43,7 @@ class AsyncExecution { AsyncExecution(Backend* backend, std::vector streams, const ExecutionProfile& profile, GlobalDataHandle result); - tensorflow::Status BlockUntilDone() const; + Status BlockUntilDone() const; const GlobalDataHandle& result() const { return result_; } @@ -77,7 +77,7 @@ class ExecutionTracker { GlobalDataHandle data); // Unregisters the execution for the given handle. - tensorflow::Status Unregister(const ExecutionHandle& handle); + Status Unregister(const ExecutionHandle& handle); // Resolves the given ExecutionHandle to an AsyncExecution. 
Returns an // error status if the given handle is not found, which means that the diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 221ff7900f3..2d3e4b1fcdf 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -28,9 +28,15 @@ using tensorflow::gtl::ArraySlice; static StatusOr TransposeIndexVectorDimToLast( HloInstruction* gather_indices, int64 index_vector_dim) { const Shape& gather_indices_shape = gather_indices->shape(); + + if (gather_indices_shape.dimensions_size() == index_vector_dim) { + return gather_indices; + } + if (index_vector_dim == (gather_indices_shape.dimensions_size() - 1)) { return gather_indices; } + std::vector permutation; permutation.reserve(gather_indices_shape.dimensions_size()); for (int64 i = 0, e = gather_indices_shape.dimensions_size(); i < e; i++) { @@ -42,55 +48,35 @@ static StatusOr TransposeIndexVectorDimToLast( return MakeTransposeHlo(gather_indices, permutation); } -// If the gather_indices holds scalar indices (i.e. gather_indices has rank N -// and index_vector_dim is N) then reshape it to have a trailing degenerate -// dimension. This makes the code for slicing out the index vector more -// uniform. -static StatusOr DeScalarizeGatherIndices( - HloInstruction* gather_indices, int64 index_vector_dim) { - const Shape& gather_indices_shape = gather_indices->shape(); - if (index_vector_dim != gather_indices_shape.dimensions_size()) { - return gather_indices; - } - - DCHECK_EQ(index_vector_dim, gather_indices_shape.dimensions_size()); - - std::vector result_shape_dims; - c_copy(gather_indices_shape.dimensions(), - std::back_inserter(result_shape_dims)); - result_shape_dims.push_back(1); - - return MakeReshapeHlo(result_shape_dims, gather_indices); -} - // Canonicalizes the gather_indices tensors so that we only have deal with some // specific cases in the while loop that does the heavy lifting. // // See the "High Level Algorithm" section for a broader picture. static StatusOr CanonicalizeGatherIndices( HloInstruction* gather_indices, int64 index_vector_dim) { - // If gather_indices holds scalar indices, normalize it to hold index vectors - // of size 1. - TF_ASSIGN_OR_RETURN( - HloInstruction * descalarized_gather_indices, - DeScalarizeGatherIndices(gather_indices, index_vector_dim)); - // Transpose the non-index-vector dimensions to the front. - TF_ASSIGN_OR_RETURN(HloInstruction * transposed_gather_indices, - TransposeIndexVectorDimToLast(descalarized_gather_indices, - index_vector_dim)); + TF_ASSIGN_OR_RETURN( + HloInstruction * transposed_gather_indices, + TransposeIndexVectorDimToLast(gather_indices, index_vector_dim)); + bool indices_are_scalar = + index_vector_dim == gather_indices->shape().dimensions_size(); + + // The number of dimensions in gather_indices that are index dimensions. + const int64 index_dims_in_gather_indices = indices_are_scalar ? 0 : 1; // If there is only one index (i.e. gather_indices has rank 1 and this gather // is really just a dynamic slice) add a leading degenerate dimension for // uniformity. Otherwise create a "collapsed" leading dimension that subsumes // all of the non-index-vector dimensions. 
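As a small illustration of TransposeIndexVectorDimToLast above: it builds a permutation that keeps every dimension in its relative order except index_vector_dim, which moves to the back (and returns the input unchanged when the indices are already in that layout or are scalar). The rank and dimension values below are invented.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch of the permutation built by TransposeIndexVectorDimToLast.
std::vector<int64_t> IndexVectorDimToLast(int64_t rank, int64_t index_vector_dim) {
  std::vector<int64_t> permutation;
  permutation.reserve(rank);
  for (int64_t i = 0; i < rank; i++) {
    if (i != index_vector_dim) {
      permutation.push_back(i);
    }
  }
  permutation.push_back(index_vector_dim);
  return permutation;
}

int main() {
  // gather_indices of rank 4 with index_vector_dim = 1 -> permutation {0, 2, 3, 1}.
  for (int64_t d : IndexVectorDimToLast(4, 1)) {
    std::printf("%lld ", static_cast<long long>(d));
  }
  std::printf("\n");
  return 0;
}
```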
const Shape& shape = transposed_gather_indices->shape(); - if (shape.dimensions_size() == 1) { - return ExpandFirstDimIntoNDims(transposed_gather_indices, - {1, shape.dimensions(0)}); + if (shape.dimensions_size() == index_dims_in_gather_indices) { + return PrependDegenerateDims(transposed_gather_indices, 1); } else { - return CollapseFirstNDims(transposed_gather_indices, - shape.dimensions_size() - 1); + // Collapse all but the dimensions (0 or 1) in gather_indices containing the + // index vectors. + return CollapseFirstNDims( + transposed_gather_indices, + shape.dimensions_size() - index_dims_in_gather_indices); } } @@ -112,11 +98,7 @@ static StatusOr AdjustGatherDimsInAccumulator( // dynamic-slice. In that case, there is a leading degenerate gather // dimension that we added to make this special case play well with the // general while loop which we need to remove now. - CHECK_EQ(accumulator->shape().dimensions(0), 1); - ArraySlice reshaped_dim_sizes = - AsInt64Slice(accumulator->shape().dimensions()); - reshaped_dim_sizes.remove_prefix(1); - return MakeReshapeHlo(reshaped_dim_sizes, accumulator); + return ElideDegenerateDims(accumulator, {0}); } return ExpandFirstDimIntoNDims(accumulator, output_gather_dim_bounds); @@ -161,50 +143,73 @@ static StatusOr ExpandIndexVectorIntoOperandSpace( static StatusOr> GatherLoopBody( const HloInstruction& gather, HloInstruction* induction_var, const std::vector& incoming_loop_state) { + const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers(); CHECK_EQ(incoming_loop_state.size(), 3); HloInstruction* const operand = incoming_loop_state[0]; HloInstruction* const gather_indices = incoming_loop_state[1]; HloInstruction* const output_accumulator = incoming_loop_state[2]; - int64 index_vector_size = gather_indices->shape().dimensions(1); + bool has_scalar_indices = gather_indices->shape().dimensions_size() == 1; + CHECK_EQ(has_scalar_indices, + dim_numbers.index_vector_dim() == + gather.operand(1)->shape().dimensions_size()); TF_ASSIGN_OR_RETURN( HloInstruction * induction_var_as_vector, MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/{1})); - TF_ASSIGN_OR_RETURN( - HloInstruction * index_into_gather_indices, - PadVectorWithZeros(induction_var_as_vector, - /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); + HloInstruction* index_vector; + + if (has_scalar_indices) { + // In this case gather_indices has rank 1 and induction_var_as_vector (of + // shape {1}) is an index into this rank 1 tensor. + TF_ASSIGN_OR_RETURN( + index_vector, + MakeDynamicSliceHlo(gather_indices, induction_var_as_vector, {1})); + } else { + // In this case gather_indices has rank 2 and induction_var_as_vector (of + // shape {1}) is an index into just the first dimension of this rank 2 + // tensor. 
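After the transpose, CanonicalizeGatherIndices reduces to a small shape computation: if the only dimensions left are the index dimensions, a degenerate leading dimension is prepended; otherwise all leading non-index dimensions are collapsed into one. A hypothetical standalone sketch of the resulting dimension vector, assuming CollapseFirstNDims folds the collapsed dimensions into their product (CanonicalIndicesDims is an illustrative name only):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// `dims` are the dimensions of gather_indices after the index-vector dimension
// has been moved to the back; `index_dims` is 0 for scalar indices and 1
// otherwise. Returns the canonicalized dimensions: either a prepended
// degenerate dimension (the single-index case) or a single collapsed leading
// dimension followed by the index dimensions.
std::vector<int64_t> CanonicalIndicesDims(std::vector<int64_t> dims,
                                          int64_t index_dims) {
  if (static_cast<int64_t>(dims.size()) == index_dims) {
    dims.insert(dims.begin(), 1);  // PrependDegenerateDims(..., 1)
    return dims;
  }
  // CollapseFirstNDims(..., dims.size() - index_dims): fold the leading
  // non-index dimensions into one.
  int64_t leading = std::accumulate(dims.begin(), dims.end() - index_dims,
                                    int64_t{1}, std::multiplies<int64_t>());
  std::vector<int64_t> result = {leading};
  result.insert(result.end(), dims.end() - index_dims, dims.end());
  return result;
}

int main() {
  auto a = CanonicalIndicesDims({5, 7, 2}, 1);  // -> {35, 2}
  auto b = CanonicalIndicesDims({2}, 0);        // scalar indices, as in the test below -> {2}
  (void)a;
  (void)b;
}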
+ TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_gather_indices, + PadVectorWithZeros(induction_var_as_vector, + /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); + + int64 index_vector_size = gather_indices->shape().dimensions(1); + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_2d, + MakeDynamicSliceHlo(gather_indices, index_into_gather_indices, + {1, index_vector_size})); + + TF_ASSIGN_OR_RETURN(index_vector, + ElideDegenerateDims(index_vector_2d, {0})); + } TF_ASSIGN_OR_RETURN( - HloInstruction * index_vector_2d, - MakeDynamicSliceHlo(gather_indices, index_into_gather_indices, - {1, index_vector_size})); - - TF_ASSIGN_OR_RETURN(HloInstruction * index_vector, - ElideDegenerateDims(index_vector_2d, {0})); - - TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice_start, - ExpandIndexVectorIntoOperandSpace( - index_vector, gather.gather_dimension_numbers(), - operand->shape().dimensions_size())); + HloInstruction * gathered_slice_start, + ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers, + operand->shape().dimensions_size())); TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice, MakeDynamicSliceHlo(operand, gathered_slice_start, gather.gather_window_bounds())); + TF_ASSIGN_OR_RETURN( + HloInstruction * gathered_slice_with_dims_elided, + ElideDegenerateDims(gathered_slice, + AsInt64Slice(dim_numbers.elided_window_dims()))); + TF_ASSIGN_OR_RETURN( HloInstruction * gathered_slice_for_update, - ExpandFirstDimIntoNDims(gathered_slice, - {1, gathered_slice->shape().dimensions(0)})); + PrependDegenerateDims(gathered_slice_with_dims_elided, 1)); TF_ASSIGN_OR_RETURN( HloInstruction * index_vector_into_accumulator, PadVectorWithZeros( induction_var_as_vector, /*zeros_to_prepend=*/0, - /*zeros_to_append=*/gathered_slice->shape().dimensions_size())); + /*zeros_to_append=*/ + gathered_slice_with_dims_elided->shape().dimensions_size())); TF_ASSIGN_OR_RETURN( HloInstruction * updated_accumulator, @@ -220,26 +225,20 @@ static StatusOr> GatherLoopBody( static StatusOr CreateGatherLoopAccumulatorInitValue( HloComputation* computation, PrimitiveType element_type, - ArraySlice window_bounds, int64 gather_loop_trip_count) { + ArraySlice window_bounds, int64 gather_loop_trip_count, + const GatherDimensionNumbers& dim_numbers) { std::vector accumulator_state_shape_dims; accumulator_state_shape_dims.reserve(1 + window_bounds.size()); accumulator_state_shape_dims.push_back(gather_loop_trip_count); - c_copy(window_bounds, std::back_inserter(accumulator_state_shape_dims)); + for (int64 i = 0; i < window_bounds.size(); i++) { + if (!c_binary_search(dim_numbers.elided_window_dims(), i)) { + accumulator_state_shape_dims.push_back(window_bounds[i]); + } + } return BroadcastZeros(computation, element_type, accumulator_state_shape_dims); } -static StatusOr ElideWindowDimsFromAccumulator( - HloInstruction* accumulator, const GatherDimensionNumbers& dim_numbers) { - std::vector dims_to_elide; - dims_to_elide.reserve(dim_numbers.elided_window_dims_size()); - for (int64 elided_window_dim : dim_numbers.elided_window_dims()) { - dims_to_elide.push_back(elided_window_dim + 1); - } - - return ElideDegenerateDims(accumulator, dims_to_elide); -} - // `accumulator` is almost the tensor the gather operation would have produced, // except that it has the dimensions in the wrong order -- the gather dimensions // are the major dimensions and the window dimensions are the minor dimensions. 
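CreateGatherLoopAccumulatorInitValue now builds the accumulator with the elided window dimensions already left out, rather than eliding them after the loop. A small sketch of that shape computation, assuming elided_window_dims is sorted (which the c_binary_search call relies on); AccumulatorDims is an illustrative name, not part of the patch:

#include <algorithm>
#include <cstdint>
#include <vector>

// Accumulator dimensions: the loop trip count, followed by every window bound
// whose dimension is NOT listed in elided_window_dims.
std::vector<int64_t> AccumulatorDims(int64_t trip_count,
                                     const std::vector<int64_t>& window_bounds,
                                     const std::vector<int64_t>& elided_window_dims) {
  std::vector<int64_t> dims = {trip_count};
  for (int64_t i = 0; i < static_cast<int64_t>(window_bounds.size()); ++i) {
    if (!std::binary_search(elided_window_dims.begin(),
                            elided_window_dims.end(), i)) {
      dims.push_back(window_bounds[i]);
    }
  }
  return dims;
}

int main() {
  // Mirrors the AvoidDegenerateDims test below: window_bounds = {3, 1},
  // elided_window_dims = {1}, trip count 2 -> accumulator dims {2, 3}.
  auto dims = AccumulatorDims(2, {3, 1}, {1});
  (void)dims;
}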
@@ -338,7 +337,8 @@ StatusOr GatherExpander::ExpandGather( HloInstruction * accumulator_init, CreateGatherLoopAccumulatorInitValue( computation, output_shape.element_type(), - gather_instr->gather_window_bounds(), gather_loop_trip_count)); + gather_instr->gather_window_bounds(), gather_loop_trip_count, + gather_instr->gather_dimension_numbers())); StatusOr> gather_loop_result_or_error = WhileUtil::MakeCountedLoop( @@ -353,14 +353,10 @@ StatusOr GatherExpander::ExpandGather( gather_loop_result_or_error); HloInstruction* accumulator_result = gather_loop_result.back(); - TF_ASSIGN_OR_RETURN( - HloInstruction * accumulator_with_window_dims_elided, - ElideWindowDimsFromAccumulator(accumulator_result, dim_numbers)); TF_ASSIGN_OR_RETURN( HloInstruction * accumulator_with_output_gather_dims_decanonicalized, - AdjustGatherDimsInAccumulator(gather_indices->shape(), - accumulator_with_window_dims_elided, + AdjustGatherDimsInAccumulator(gather_indices->shape(), accumulator_result, dim_numbers.index_vector_dim())); return PermuteGatherAndWindowDims( diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index ba41ee8428c..1c72ca06650 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -47,5 +47,62 @@ ENTRY main { "indices are not supported.")); } +TEST(GatherExpanderTest, AvoidDegenerateDims) { + const string hlo_text = R"( +HloModule TensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); + ASSERT_TRUE(changed); + + HloInstruction* while_instr = nullptr; + for (auto* instr : module->entry_computation()->instructions()) { + if (instr->opcode() == HloOpcode::kWhile) { + ASSERT_EQ(while_instr, nullptr) + << "Expected exactly one while instruction in the entry computation " + "after gather expansion"; + while_instr = instr; + } + } + + ASSERT_NE(while_instr, nullptr) + << "Expected exactly one while instruction in the entry computation " + "after gather expansion"; + + // We want to avoid create while loop with shapes that have degenerate + // dimensions for TF gather. In this case we expect the loop state to be of + // the shape (sNN[], s32[3,3]{1,0}, s32[2]{0}, s32[2,3]{1,0}). The leading + // sNN is an implementation detail from WhileUtil::MakeCountedLoop so we don't + // check it here (though in theory the form of the while loop state is itself + // an implementation detail from WhileUtil::MakeCountedLoop). 
+ + const Shape& while_shape = while_instr->shape(); + ASSERT_TRUE(ShapeUtil::IsTuple(while_shape)); + ASSERT_EQ(ShapeUtil::TupleElementCount(while_shape), 4); + + EXPECT_TRUE(ShapeUtil::SameDimensions( + ShapeUtil::MakeShape(S32, {3, 3}), + ShapeUtil::GetTupleElementShape(while_shape, 1))); + + EXPECT_TRUE(ShapeUtil::SameDimensions( + ShapeUtil::MakeShape(S32, {2}), + ShapeUtil::GetTupleElementShape(while_shape, 2))); + + EXPECT_TRUE(ShapeUtil::SameDimensions( + ShapeUtil::MakeShape(S32, {2, 3}), + ShapeUtil::GetTupleElementShape(while_shape, 3))); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index a99e2b7794a..5ee67ccb4ae 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -32,8 +32,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id, @@ -45,9 +43,9 @@ se::Platform::Id GenericTransferManager::PlatformId() const { } Status GenericTransferManager::WriteSingleTupleIndexTable( - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, - const Shape& shape, perftools::gputools::DeviceMemoryBase* region) { + const Shape& shape, se::DeviceMemoryBase* region) { TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); std::vector element_pointers; @@ -91,7 +89,7 @@ GenericTransferManager::TransferLiteralFromDevice( } Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const Literal& literal, + se::StreamExecutor* executor, const LiteralSlice& literal, const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " @@ -117,7 +115,7 @@ Status GenericTransferManager::TransferLiteralToDevice( TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); // Element is array-shaped: transfer array data to device buffer. 
- const auto subliteral = LiteralView::Create(literal, index); + const auto subliteral = LiteralSlice(literal, index); std::unique_ptr relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), @@ -139,25 +137,24 @@ Status GenericTransferManager::TransferLiteralToDevice( } Status GenericTransferManager::TransferLiteralToInfeed( - se::StreamExecutor* executor, const Literal& literal) { + se::StreamExecutor* executor, const LiteralSlice& literal) { return Unimplemented("Generic transfer to Infeed"); } Status GenericTransferManager::TransferBufferToInfeed( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source) { + se::StreamExecutor* executor, int64 size, const void* source) { return Unimplemented("Generic transfer to Infeed"); } Status GenericTransferManager::TransferLiteralFromOutfeed( - perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, + se::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) { return Unimplemented( "Outfeed is not supported on this platform (b/30467474)"); } Status GenericTransferManager::ResetDevices( - tensorflow::gtl::ArraySlice + tensorflow::gtl::ArraySlice /*executors*/) { return Unimplemented( "Device reset is not yet supported on this platform (b/30481585)"); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 63a7c820cf4..3da9570ef7e 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -36,46 +36,41 @@ namespace xla { // infeed. class GenericTransferManager : public TransferManager { public: - GenericTransferManager(perftools::gputools::Platform::Id platform_id, - size_t pointer_size); + GenericTransferManager(se::Platform::Id platform_id, size_t pointer_size); ~GenericTransferManager() override {} - perftools::gputools::Platform::Id PlatformId() const override; + se::Platform::Id PlatformId() const override; StatusOr> TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const ShapedBuffer& device_buffer) override; + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; - Status TransferLiteralToDevice(perftools::gputools::StreamExecutor* executor, - const Literal& literal, + Status TransferLiteralToDevice(se::StreamExecutor* executor, + const LiteralSlice& literal, const ShapedBuffer& device_buffer) override; - Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, - const Literal& literal) override; - Status TransferLiteralFromOutfeed( - perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) override; + Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const LiteralSlice& literal) override; + Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, + const Shape& literal_shape, + Literal* literal) override; Status ResetDevices( - tensorflow::gtl::ArraySlice - executors) override; + tensorflow::gtl::ArraySlice executors) override; int64 GetByteSizeRequirement(const Shape& shape) const override; protected: - Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, - int64 size, const void* source) override; + Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, + const void* source) override; Status WriteSingleTupleIndexTable( - perftools::gputools::StreamExecutor* executor, - tensorflow::gtl::ArraySlice - elements, - const 
Shape& shape, - perftools::gputools::DeviceMemoryBase* region) override; + se::StreamExecutor* executor, + tensorflow::gtl::ArraySlice elements, + const Shape& shape, se::DeviceMemoryBase* region) override; private: // The platform this transfer manager targets. - const perftools::gputools::Platform::Id platform_id_; + const se::Platform::Id platform_id_; // The size in bytes of pointers on this platform. const size_t pointer_size_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index f1707442fe3..4012f87f2bf 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -291,6 +291,7 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin", @@ -388,8 +389,10 @@ cc_library( deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", + "//tensorflow/compiler/xla/service:pattern_matcher", ], ) @@ -620,6 +623,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 2029c303d47..ab5149dcdb0 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -28,8 +28,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { @@ -39,11 +37,11 @@ void BufferAllocations::Builder::RegisterBuffer(BufferAllocation::Index index, } StatusOr> BufferAllocations::Builder::Build( - const BufferAssignment& buffer_assignment, int device_ordinal, + const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator) { - const int64 num_buffers = buffer_assignment.Allocations().size(); - auto buffer_allocations = WrapUnique( - new BufferAllocations(num_buffers, device_ordinal, memory_allocator)); + const int64 num_buffers = buffer_assignment->Allocations().size(); + auto buffer_allocations = WrapUnique(new BufferAllocations( + num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { // If buffer #i's address is already registered (e.g. external arguments or @@ -64,28 +62,28 @@ StatusOr> BufferAllocations::Builder::Build( // Allocate each allocation that might escape, or is the temp buffer. 
bool seen_temp_buffer = false; - const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + const BufferAllocation& allocation = buffer_assignment->GetAllocation(i); if (allocation.maybe_live_out() || allocation.IsPreallocatedTempBuffer()) { const int64 buffer_size = allocation.size(); se::DeviceMemoryBase buffer_address; if (buffer_size > 0) { - TF_ASSIGN_OR_RETURN(buffer_address, memory_allocator->Allocate( - device_ordinal, buffer_size)); - if (buffer_address == nullptr) { - return ResourceExhausted( - "Out of memory when allocating %s for buffer %lld.", - tensorflow::strings::HumanReadableNumBytes(buffer_size).c_str(), - i); - } - if (reinterpret_cast(buffer_address.opaque()) % + OwningDeviceMemory buffer; + TF_ASSIGN_OR_RETURN( + buffer, memory_allocator->Allocate(device_ordinal, buffer_size)); + if (reinterpret_cast(buffer.opaque()) % kCudaMallocAlignBytes != 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " "multiple of %llx, but was %p", - kCudaMallocAlignBytes, buffer_address.opaque()); + kCudaMallocAlignBytes, buffer.opaque()); } + // We do manual memory management within BufferAllocations. Be sure not + // to do a TF_RETURN_IF_ERROR between this line and the + // buffer_allocations->SetBuffer(buffer_address) call below! + buffer_address = buffer.Forget(); } + buffer_allocations->SetBuffer(i, buffer_address); if (allocation.IsPreallocatedTempBuffer()) { if (seen_temp_buffer) { @@ -105,28 +103,42 @@ StatusOr> BufferAllocations::Builder::Build( << "B)"; } } - return std::move(buffer_allocations); } -tensorflow::Status BufferAllocations::TearDown( - const std::set& live_addresses, - const BufferAssignment& buffer_assignment) { - // Deallocate temporary buffers. - const int64 num_buffers = buffer_assignment.Allocations().size(); +BufferAllocations::~BufferAllocations() { + if (!torn_down_) { + // Presumably if we're executing this branch, the caller is in an error + // state, otherwise it would have explicitly called TearDown so it could + // save some set of live addresses. So ignoring any errors in TearDown is + // sensible. + TearDown(/*live_addresses=*/{}).IgnoreError(); + } +} + +Status BufferAllocations::TearDown( + const std::set& live_addresses) { + // Deallocate temporary buffers, taking care to try to deallocate all of them + // even if one of the deallocations fails. + Status status; + const int64 num_buffers = buffer_assignment_->Allocations().size(); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { - const BufferAllocation& allocation = buffer_assignment.GetAllocation(i); + const BufferAllocation& allocation = buffer_assignment_->GetAllocation(i); se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index()); // Deallocate buffers marked "maybe_live_out" but aren't actually live out, // and temp buffers. 
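The warning above about not putting a TF_RETURN_IF_ERROR between buffer.Forget() and buffer_allocations->SetBuffer(...) comes from the hand-off of ownership: once Forget() runs, nothing frees the memory until TearDown() sees the recorded address. A toy sketch of that pattern (OwningBuffer is a stand-in invented for illustration, not the real OwningDeviceMemory):

#include <cstdlib>

// Toy stand-in for an owning device-memory handle: frees its pointer on
// destruction unless ownership has been released with Forget().
class OwningBuffer {
 public:
  explicit OwningBuffer(void* ptr) : ptr_(ptr) {}
  OwningBuffer(OwningBuffer&& other) : ptr_(other.ptr_) { other.ptr_ = nullptr; }
  ~OwningBuffer() { std::free(ptr_); }  // freeing nullptr is a no-op

  // Releases ownership: after this, the caller must free the memory itself.
  void* Forget() {
    void* p = ptr_;
    ptr_ = nullptr;
    return p;
  }

 private:
  void* ptr_;
};

int main() {
  OwningBuffer buffer(std::malloc(64));
  void* raw = buffer.Forget();
  // An early return between here and the point where `raw` is recorded (the
  // SetBuffer() call in the patch) would leak the allocation: no destructor
  // owns it any more, and TearDown() only frees addresses it knows about.
  std::free(raw);
}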
if ((allocation.maybe_live_out() && !live_addresses.count(buffer_address)) || allocation.IsPreallocatedTempBuffer()) { - TF_RETURN_IF_ERROR( - memory_allocator_->Deallocate(device_ordinal_, &buffer_address)); + auto dealloc_result = + memory_allocator_->Deallocate(device_ordinal_, buffer_address); + if (!dealloc_result.ok() && status.ok()) { + status = dealloc_result; + } } } - return tensorflow::Status::OK(); + torn_down_ = true; + return status; } se::DeviceMemoryBase BufferAllocations::GetDeviceAddress( diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h index ea7f0eb3745..63662350259 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h @@ -41,21 +41,22 @@ class BufferAllocations { // user-specified result buffers) to the given buffer index. The builder // will skip allocating buffers for registered buffer indices. void RegisterBuffer(BufferAllocation::Index index, - perftools::gputools::DeviceMemoryBase address); + se::DeviceMemoryBase address); // Builds a BufferAllocations object from the given buffer assignment. // `memory_allocator` is what this function uses to allocate device memory. // `device_ordinal` is the number of the device this function allocates // memory on. StatusOr> Build( - const BufferAssignment& buffer_assignment, int device_ordinal, + const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator); private: - std::map - registered_buffers_; + std::map registered_buffers_; }; + ~BufferAllocations(); + BufferAllocations(const BufferAllocations&) = delete; BufferAllocations& operator=(const BufferAllocations&) = delete; @@ -65,46 +66,45 @@ class BufferAllocations { // Returns the device address of buffer `buffer_index`. `buffer_index` must be // a valid index, i.e., in [0, buffer_count). This function returns null if // `buffer_index` is not assigned to a buffer address. - perftools::gputools::DeviceMemoryBase GetDeviceAddress( + se::DeviceMemoryBase GetDeviceAddress( BufferAllocation::Index buffer_index) const; // Same as above, but also adjusts the returned address for the offset and // size contained in the given slice. - perftools::gputools::DeviceMemoryBase GetDeviceAddress( + se::DeviceMemoryBase GetDeviceAddress( const BufferAllocation::Slice& buffer_slice) const; - perftools::gputools::DeviceMemoryBase GetTempBufferBase() const { - return temp_buffer_base_; - } + se::DeviceMemoryBase GetTempBufferBase() const { return temp_buffer_base_; } // Tears down all buffers allocated by this object that are not in // `live_addresses`. - tensorflow::Status TearDown( - const std::set& live_addresses, - const BufferAssignment& buffer_assignment); + Status TearDown(const std::set& live_addresses); private: BufferAllocations(BufferAllocation::Index buffer_count, int device_ordinal, - DeviceMemoryAllocator* memory_allocator) + DeviceMemoryAllocator* memory_allocator, + const BufferAssignment* buffer_assignment) : buffers_(buffer_count), device_ordinal_(device_ordinal), - memory_allocator_(memory_allocator) {} + memory_allocator_(memory_allocator), + buffer_assignment_(buffer_assignment) {} // Sets the device address of buffer `buffer_index`. void SetBuffer(BufferAllocation::Index buffer_index, - perftools::gputools::DeviceMemoryBase buffer); + se::DeviceMemoryBase buffer); // An array of device pointers that stores the address of each buffer // indexed by Index. 
Each element can point to a temporary buffer, an // input buffer, or nullptr if no buffer is needed for that Index. - std::vector buffers_; + std::vector buffers_; // The base address of the memory block that contains all temporary buffers. - perftools::gputools::DeviceMemoryBase temp_buffer_base_; + se::DeviceMemoryBase temp_buffer_base_; int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; + const BufferAssignment* buffer_assignment_; + bool torn_down_ = false; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 790ca535b11..77a48965e03 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -35,18 +35,18 @@ ConditionalThunk::ConditionalThunk( true_thunk_(std::move(true_thunk_sequence), hlo), false_thunk_(std::move(false_thunk_sequence), hlo) {} -Status ConditionalThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable)); - TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable)); +Status ConditionalThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable, executor)); + TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable, executor)); return Status::OK(); } Status ConditionalThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { + const BufferAllocations& buffer_allocations, se::Stream* stream) { // Copy the predicate value from device. bool predicate; - perftools::gputools::DeviceMemoryBase predicate_address = + se::DeviceMemoryBase predicate_address = buffer_allocations.GetDeviceAddress(predicate_buffer_index_); stream->ThenMemcpy(&predicate, predicate_address, sizeof(bool)); diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index 7725c46a3b4..ee03865d174 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -47,9 +47,10 @@ class ConditionalThunk : public Thunk { ConditionalThunk(const ConditionalThunk&) = delete; ConditionalThunk& operator=(const ConditionalThunk&) = delete; - Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + se::Stream* stream) override; private: BufferAllocation::Slice predicate_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 461747b699b..f0881124128 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -25,17 +25,10 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { using se::dnn::AlgorithmDesc; -using se::dnn::BatchDescriptor; -using se::dnn::ConvolutionDescriptor; -using se::dnn::DataLayout; -using se::dnn::FilterDescriptor; -using se::dnn::FilterLayout; ConvolutionThunk::ConvolutionThunk( CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 900d9cb6243..6d845025b1a 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -66,23 +66,21 @@ class ConvolutionThunk : public Thunk { // Does the convolution for the thunk on "stream". Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + se::Stream* stream) override; private: class ScratchAllocator; - Status Convolve( - const perftools::gputools::dnn::BatchDescriptor& input_descriptor, - perftools::gputools::DeviceMemory input_data, - const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, - perftools::gputools::DeviceMemory filter_data, - const perftools::gputools::dnn::BatchDescriptor& output_descriptor, - perftools::gputools::DeviceMemory output_data, - const perftools::gputools::dnn::ConvolutionDescriptor& - convolution_descriptor, - const perftools::gputools::dnn::AlgorithmConfig& algorithm_config, - perftools::gputools::Stream* stream, ScratchAllocator* scratch_allocator, - perftools::gputools::dnn::ProfileResult* profile_result); + Status Convolve(const se::dnn::BatchDescriptor& input_descriptor, + se::DeviceMemory input_data, + const se::dnn::FilterDescriptor& filter_descriptor, + se::DeviceMemory filter_data, + const se::dnn::BatchDescriptor& output_descriptor, + se::DeviceMemory output_data, + const se::dnn::ConvolutionDescriptor& convolution_descriptor, + const se::dnn::AlgorithmConfig& algorithm_config, + se::Stream* stream, ScratchAllocator* scratch_allocator, + se::dnn::ProfileResult* profile_result); const CudnnConvKind convolution_kind_; diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index f4498663b1c..ee38c0318a8 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -29,13 +29,12 @@ HostToDeviceCopyThunk::HostToDeviceCopyThunk( destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status HostToDeviceCopyThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { - perftools::gputools::DeviceMemoryBase destination_data = +Status HostToDeviceCopyThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); stream->ThenMemcpy(&destination_data, source_address_, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( @@ -47,15 +46,14 @@ DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status DeviceToDeviceCopyThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { - 
perftools::gputools::DeviceMemoryBase destination_data = +Status DeviceToDeviceCopyThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + se::DeviceMemoryBase destination_data = buffer_allocations.GetDeviceAddress(destination_buffer_); - perftools::gputools::DeviceMemoryBase source_data = + se::DeviceMemoryBase source_data = buffer_allocations.GetDeviceAddress(source_buffer_); stream->ThenMemcpy(&destination_data, source_data, mem_size_); - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index e2783fd2552..8b128386f61 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -39,9 +39,8 @@ class HostToDeviceCopyThunk : public Thunk { HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete; HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const void* source_address_; @@ -63,9 +62,8 @@ class DeviceToDeviceCopyThunk : public Thunk { DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const BufferAllocation::Slice source_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 58d9c8caff3..68099fd6384 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -28,7 +28,6 @@ limitations under the License. namespace xla { namespace gpu { -namespace se = ::perftools::gputools; namespace dnn = se::dnn; static std::pair> AllocateBytes( - se::Stream* stream, int64 byte_size) override; + StatusOr> AllocateBytes(se::Stream* stream, + int64 byte_size) override; private: const int device_ordinal_; DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; + std::vector allocated_buffers_; int64 total_allocated_bytes_ = 0; }; -ScratchAllocator::~ScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. 
- LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - -se::port::StatusOr> ScratchAllocator::AllocateBytes( +StatusOr> ScratchAllocator::AllocateBytes( se::Stream* stream, int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes(stream)) { @@ -76,19 +61,14 @@ se::port::StatusOr> ScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Failed to allocate %lld bytes on device %d.", - byte_size, device_ordinal_)); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory(buffer_addr); } // Determines whether we can safely perform a winograd non-fused convolution for @@ -199,22 +179,42 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // We don't put any data in these buffers, because (in theory, anyway) the // speed of a conv isn't affected by the data being convolved. ScratchAllocator input_output_allocator(device_ordinal, allocator); - se::port::StatusOr input_buf = + StatusOr maybe_input_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(input_shape)); - se::port::StatusOr filter_buf = + StatusOr maybe_filter_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(filter_shape)); - se::port::StatusOr output_buf = + StatusOr maybe_output_buf = input_output_allocator.AllocateBytes(&stream, ShapeUtil::ByteSizeOf(output_shape)); - if (!input_buf.ok() || !filter_buf.ok() || !output_buf.ok()) { + if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() || + !maybe_output_buf.ok()) { LOG(WARNING) << "Couldn't allocate space for input/filter/output of convolution " << instr->ToString() << ". Falling back to default algorithm."; return nullopt; } + DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie(); + DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie(); + DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie(); + + // Although we don't have evidence this matters, zero out the buffers before + // autotuning. It's conceivable that using uninitialized memory as the inputs + // might affect performance if e.g. the inputs contain denormals, and this is + // easy enough. + if (!stream.ThenMemZero(&input_buf, input_buf.size()) + .ThenMemZero(&filter_buf, filter_buf.size()) + .ThenMemZero(&output_buf, output_buf.size()) + .BlockHostUntilDone() + .ok()) { + LOG(WARNING) + << "Couldn't zero out input/filter/output buffer for convolution " + << instr->ToString() << ". 
Falling back to default algorithm."; + return nullopt; + } + const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( input_shape, output_shape, dnums, stream_exec_); se::dnn::ProfileResult best_result; @@ -227,12 +227,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - bool launch_ok = RunCudnnConvolution( - kind, input_shape, filter_shape, output_shape, - input_buf.ValueOrDie(), filter_buf.ValueOrDie(), - output_buf.ValueOrDie(), &scratch_allocator, window, - dnums, AlgorithmConfig(alg), &stream, &profile_result) - .ok(); + bool launch_ok = + RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + input_buf, filter_buf, output_buf, + &scratch_allocator, window, dnums, + AlgorithmConfig(alg), &stream, &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index 516210ec2e5..bc5d1ce94af 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -33,9 +33,8 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { // If the `allocator` parameter is not null, we will use it to allocate temp // memory while timing the various convolution algorithms. If it's null, // we'll use the default allocator on the StreamExecutor. - CudnnConvolutionAlgorithmPicker( - perftools::gputools::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator) + CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* allocator) : stream_exec_(stream_exec), allocator_(allocator) {} tensorflow::StringPiece name() const override { @@ -52,7 +51,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); - perftools::gputools::StreamExecutor* stream_exec_; // never null + se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index e4ae839e1dd..10b4c3de899 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -22,8 +22,6 @@ namespace xla { namespace gpu { namespace { -namespace se = ::perftools::gputools; - using se::DeviceMemory; using se::DeviceMemoryBase; using se::Stream; @@ -215,14 +213,12 @@ string CudnnConvKindToString(CudnnConvKind kind) { Status RunCudnnConvolution( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, - perftools::gputools::DeviceMemoryBase filter_buf, - perftools::gputools::DeviceMemoryBase output_buf, - perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window, + const Shape& output_shape, se::DeviceMemoryBase input_buf, + se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, + se::DeviceMemoryBase scratch_buf, const Window& window, const ConvolutionDimensionNumbers& dnums, - perftools::gputools::dnn::AlgorithmConfig algorithm, - 
perftools::gputools::Stream* stream, - perftools::gputools::dnn::ProfileResult* profile_result) { + se::dnn::AlgorithmConfig algorithm, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, input_buf, filter_buf, output_buf, @@ -232,14 +228,12 @@ Status RunCudnnConvolution( Status RunCudnnConvolution( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, - perftools::gputools::DeviceMemoryBase filter_buf, - perftools::gputools::DeviceMemoryBase output_buf, - perftools::gputools::ScratchAllocator* scratch_allocator, - const Window& window, const ConvolutionDimensionNumbers& dnums, - perftools::gputools::dnn::AlgorithmConfig algorithm, - perftools::gputools::Stream* stream, - perftools::gputools::dnn::ProfileResult* profile_result) { + const Shape& output_shape, se::DeviceMemoryBase input_buf, + se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, + se::ScratchAllocator* scratch_allocator, const Window& window, + const ConvolutionDimensionNumbers& dnums, + se::dnn::AlgorithmConfig algorithm, se::Stream* stream, + se::dnn::ProfileResult* profile_result) { PrimitiveType output_primitive_type = output_shape.element_type(); CHECK(output_primitive_type == F32 || output_primitive_type == F16) << ShapeUtil::HumanString(output_shape); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index 3dbfa2730da..944e4ac686d 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -72,25 +72,21 @@ string CudnnConvKindToString(CudnnConvKind kind); // that size, if you like. 
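For readers following the ScratchAllocator changes a few hunks above: the allocator is mostly bookkeeping, rejecting requests above a per-allocation byte limit and tracking the running total that TotalAllocatedBytes() reports, while the actual device allocations are held as OwningDeviceMemory handles so they are released automatically. A toy, host-only sketch of just the bookkeeping part (ScratchBudget is an invented name; no real device allocation is performed):

#include <cstdint>

// Host-only model of the scratch-allocation bookkeeping: each request is
// checked against a per-allocation byte limit, and the running total is
// recorded so it can be reported after autotuning.
class ScratchBudget {
 public:
  explicit ScratchBudget(int64_t limit_bytes) : limit_bytes_(limit_bytes) {}

  // Returns false ("resource exhausted") if the request exceeds the limit.
  bool Allocate(int64_t byte_size) {
    if (byte_size < 0 || byte_size > limit_bytes_) {
      return false;
    }
    total_allocated_bytes_ += byte_size;
    return true;
  }

  int64_t TotalAllocatedBytes() const { return total_allocated_bytes_; }

 private:
  int64_t limit_bytes_;
  int64_t total_allocated_bytes_ = 0;
};

int main() {
  ScratchBudget budget(int64_t{1} << 30);  // an illustrative 1 GiB limit
  (void)budget.Allocate(int64_t{1} << 20);
}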
Status RunCudnnConvolution( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, - perftools::gputools::DeviceMemoryBase filter_buf, - perftools::gputools::DeviceMemoryBase output_buf, - perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window, + const Shape& output_shape, se::DeviceMemoryBase input_buf, + se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, + se::DeviceMemoryBase scratch_buf, const Window& window, const ConvolutionDimensionNumbers& dnums, - perftools::gputools::dnn::AlgorithmConfig algorithm, - perftools::gputools::Stream* stream, - perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + se::dnn::AlgorithmConfig algorithm, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); Status RunCudnnConvolution( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, - perftools::gputools::DeviceMemoryBase filter_buf, - perftools::gputools::DeviceMemoryBase output_buf, - perftools::gputools::ScratchAllocator* scratch_allocator, - const Window& window, const ConvolutionDimensionNumbers& dnums, - perftools::gputools::dnn::AlgorithmConfig algorithm, - perftools::gputools::Stream* stream, - perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + const Shape& output_shape, se::DeviceMemoryBase input_buf, + se::DeviceMemoryBase filter_buf, se::DeviceMemoryBase output_buf, + se::ScratchAllocator* scratch_allocator, const Window& window, + const ConvolutionDimensionNumbers& dnums, + se::dnn::AlgorithmConfig algorithm, se::Stream* stream, + se::dnn::ProfileResult* profile_result = nullptr); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 5af7a77ea85..e5e2a0478a0 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -227,6 +227,11 @@ StatusOr GpuElementalIrEmitter::EmitLog( return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitLog1p( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitSin( PrimitiveType prim_type, llvm::Value* value) const { return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); @@ -242,6 +247,11 @@ StatusOr GpuElementalIrEmitter::EmitExp( return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitExpm1( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type); +} + StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 77d4569b1e8..91f4d960aa6 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -64,6 +64,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitLog(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitLog1p(PrimitiveType prim_type, + 
llvm::Value* value) const override; + StatusOr EmitSin(PrimitiveType prim_type, llvm::Value* value) const override; @@ -73,6 +76,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitExp(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) const override; + StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const override; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 66931bdc8b1..e14ee6918bf 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -24,8 +24,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { @@ -33,23 +31,12 @@ FftScratchAllocator::FftScratchAllocator( int device_ordinal, DeviceMemoryAllocator* memory_allocator) : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} -FftScratchAllocator::~FftScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default. return kFftScratchSize; } -se::port::StatusOr> FftScratchAllocator::AllocateBytes( +StatusOr> FftScratchAllocator::AllocateBytes( se::Stream* stream, int64 byte_size) { CHECK_GE(byte_size, 0) << "byte_size must be positive."; if (byte_size > GetMemoryLimitInBytes(stream)) { @@ -60,18 +47,14 @@ se::port::StatusOr> FftScratchAllocator::AllocateBytes( byte_size, GetMemoryLimitInBytes(stream))); } - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return tensorflow::errors::ResourceExhausted( - "Failed to allocate %lld bytes on device %d.", byte_size, - device_ordinal_); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); + TF_ASSIGN_OR_RETURN(OwningDeviceMemory allocated_buffer, + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false)); total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); + + se::DeviceMemoryBase buffer_addr = allocated_buffer.AsDeviceMemoryBase(); + allocated_buffers_.push_back(std::move(allocated_buffer)); + return se::DeviceMemory(buffer_addr); } namespace { @@ -123,8 +106,8 @@ FftThunk::FftThunk(FftType fft_type, input_shape_(input_shape), output_shape_(output_shape) {} -tensorflow::Status FftThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); VLOG(3) << "Output shape: " @@ -224,7 +207,7 @@ tensorflow::Status FftThunk::ExecuteOnStream( LOG(FATAL) << "unsupported fft type"; } if (launch_ok) { - return tensorflow::Status::OK(); + return Status::OK(); } return 
InternalError("Unable to launch fft for thunk %p with type %s", this, FftTypeToString(fft_type_).c_str()); diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 52fb8c376d7..b0a22564f3a 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -34,24 +34,22 @@ namespace gpu { // released on destruction. // // Not thread-safe in that AllocateBytes, destructor are not locked. -class FftScratchAllocator : public perftools::gputools::ScratchAllocator { +class FftScratchAllocator : public se::ScratchAllocator { public: FftScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator); - ~FftScratchAllocator() override; - - int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override; + int64 GetMemoryLimitInBytes(se::Stream* stream) override; int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - perftools::gputools::port::StatusOr> - AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override; + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override; private: const int device_ordinal_; DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; + std::vector allocated_buffers_; int64 total_allocated_bytes_ = 0; }; @@ -73,17 +71,16 @@ class FftThunk : public Thunk { FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ // Does the FFT for the thunk on "stream". - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: - const perftools::gputools::fft::Type fft_type_; + const se::fft::Type fft_type_; const std::vector fft_length_; float scale_factor_; - std::unique_ptr fft_plan_; + std::unique_ptr fft_plan_; const BufferAllocation::Slice input_buffer_; const BufferAllocation::Slice output_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 283d21ca222..b36539e0cb8 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -30,20 +30,20 @@ ForThunk::ForThunk(const int64 loop_limit, body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -tensorflow::Status ForThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); - return tensorflow::Status::OK(); +Status ForThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); + return Status::OK(); } -tensorflow::Status ForThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { +Status ForThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { for (int64 i = 0; i < loop_limit_; ++i) { // Invoke loop body thunk sequence. 
TF_RETURN_IF_ERROR( body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index 832494d17e9..41ddfe0ceb1 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -36,10 +36,10 @@ class ForThunk : public Thunk { ForThunk(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete; - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const int64 loop_limit_; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 38668ff455a..79fca43d022 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -22,8 +22,6 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { @@ -217,14 +215,32 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) { } } +DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) { + if (hlo_instruction.opcode() == HloOpcode::kDot) { + return hlo_instruction.dot_dimension_numbers(); + } + CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion); + CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput); + CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(), + HloOpcode::kMultiply); + // Try to find the dot inside the output fusion node. 
+ const HloInstruction* dot = + hlo_instruction.fused_expression_root()->operand(0); + if (dot->opcode() != HloOpcode::kDot) { + dot = hlo_instruction.fused_expression_root()->operand(1); + } + CHECK_EQ(dot->opcode(), HloOpcode::kDot); + + return dot->dot_dimension_numbers(); +} + } // namespace GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, bool transpose_lhs, - bool transpose_rhs, double alpha, + const Shape& output_shape, double alpha, const HloInstruction* hlo_instruction) : Thunk(Kind::kGemm, hlo_instruction), lhs_buffer_(lhs_buffer), @@ -233,12 +249,10 @@ GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer, lhs_shape_(lhs_shape), rhs_shape_(rhs_shape), output_shape_(output_shape), - transpose_lhs_(transpose_lhs), - transpose_rhs_(transpose_rhs), alpha_(alpha) {} -tensorflow::Status GemmThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { VLOG(2) << "Executing a GemmThunk"; se::DeviceMemoryBase lhs_data = @@ -286,10 +300,12 @@ tensorflow::Status GemmThunk::ExecuteOnStream( shape.dimensions(!is_row_major)); }; - const MatrixDescriptor lhs_descriptor = - make_descriptor(lhs_data, lhs_shape_, transpose_lhs_); - const MatrixDescriptor rhs_descriptor = - make_descriptor(rhs_data, rhs_shape_, transpose_rhs_); + DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); + + const MatrixDescriptor lhs_descriptor = make_descriptor( + lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0); + const MatrixDescriptor rhs_descriptor = make_descriptor( + rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1); // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to // autotune this gemm to figure out the best algorithm. @@ -352,7 +368,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( if (!launch_ok) { return InternalError("Unable to launch cuBLAS gemm on stream %p", stream); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index df3edcefef8..7a4830d64e7 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -35,29 +35,25 @@ namespace gpu { class GemmThunk : public Thunk { public: // Constructs a thunk that computes "output = (lhs rhs) * alpha" using - // BLAS gemm. transpose_lhs and transpose_rhs indicate whether gemm should - // transpose the lhs and rhs operand. hlo_instruction is as in Thunk. alpha is - // a constant. + // BLAS gemm. hlo_instruction is as in Thunk. alpha is a constant. GemmThunk(const BufferAllocation::Slice& lhs_buffer, const BufferAllocation::Slice& rhs_buffer, const BufferAllocation::Slice& output_buffer, const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape, bool transpose_lhs, bool transpose_rhs, - double alpha, const HloInstruction* hlo_instruction); + const Shape& output_shape, double alpha, + const HloInstruction* hlo_instruction); GemmThunk(const GemmThunk&) = delete; GemmThunk& operator=(const GemmThunk&) = delete; // Does the gemm operation for the thunk on "stream", which must be non-null. 
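With transpose_lhs/transpose_rhs removed from GemmThunk, the transpose decision is recovered from the DotDimensionNumbers: a plain gemm contracts dimension 1 of the lhs with dimension 0 of the rhs, so any other contracting dimension marks that operand as transposed, which is what the third argument to make_descriptor now encodes. A minimal sketch of that mapping (FlagsFromContractingDims is an illustrative helper, not part of the patch):

#include <cassert>
#include <cstdint>

// For a rank-2 dot "output = lhs · rhs", an untransposed gemm contracts
// dimension 1 of lhs (its columns) with dimension 0 of rhs (its rows); any
// other contracting dimension is treated as a transpose of that operand.
struct TransposeFlags {
  bool transpose_lhs;
  bool transpose_rhs;
};

TransposeFlags FlagsFromContractingDims(int64_t lhs_contracting_dim,
                                        int64_t rhs_contracting_dim) {
  return TransposeFlags{lhs_contracting_dim == 0,   // contracting along lhs rows
                        rhs_contracting_dim == 1};  // contracting along rhs columns
}

int main() {
  // Canonical matmul: lhs contracts dim 1, rhs contracts dim 0 -> no transposes.
  TransposeFlags f = FlagsFromContractingDims(1, 0);
  assert(!f.transpose_lhs && !f.transpose_rhs);
}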
- tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; // Returns true if we'll perform autotuning if run on the given stream. If // so, we want the GPU to be quiescent during autotuning, so as not to // introduce noise in our results. - bool ShouldHaltAllActivityBeforeRunning( - perftools::gputools::Stream* stream) override { + bool ShouldHaltAllActivityBeforeRunning(se::Stream* stream) override { return autotune_results_.count( stream->parent()->GetDeviceDescription().name()) != 0; } @@ -71,16 +67,13 @@ class GemmThunk : public Thunk { const Shape rhs_shape_; const Shape output_shape_; - const bool transpose_lhs_; - const bool transpose_rhs_; const double alpha_; // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune // results. The map's value is the best algorithm we've found for this thunk // on this device, or an error if none of the algorithms worked and we should // use the regular gemm without an algorithm. - std::unordered_map> + std::unordered_map> autotune_results_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 07be2a0cf90..d50153d8a31 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -91,8 +91,6 @@ limitations under the License. #include "tensorflow/core/platform/tracing.h" #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { @@ -102,7 +100,7 @@ namespace gpu { namespace { -using tensorflow::port::Tracing; +namespace tracing = tensorflow::tracing; // Returns the directory containing nvvm libdevice files. config_cuda_data_dir // should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the @@ -130,9 +128,8 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule(HloModule* hlo_module, - se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { +Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker(); @@ -250,7 +247,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, { HloPassPipeline pipeline("layout_assignment"); pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); + hlo_module->mutable_device_entry_computation_layout()); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -285,12 +282,12 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); } } - return tensorflow::Status::OK(); + return Status::OK(); } // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { +Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. 
For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output @@ -412,7 +409,7 @@ void WarnIfBadDriverJITVersion() { // code (i.e. a cubin) as a byte array. StatusOr> CompilePtx(const string& ptx, int cc_major, int cc_minor) { - Tracing::TraceMe annotation("Compile PTX", /*is_expensive=*/true); + tracing::ScopedActivity activity("Compile PTX", /*is_expensive=*/true); const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); VLOG(2) << "Using ptxas at " << ptxas_path; @@ -483,8 +480,8 @@ StatusOr> GpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); - Tracing::TraceMe annotation("HLO Transforms", module->name(), - /*is_expensive=*/true); + tracing::ScopedActivity activity("HLO Transforms", module->name(), + /*is_expensive=*/true); TF_RETURN_IF_ERROR( OptimizeHloModule(module.get(), stream_exec, device_allocator)); return std::move(module); @@ -694,7 +691,7 @@ std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, int cc_major, int cc_minor) { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::CompilePtxOrGetCachedResult"); - Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true); + tracing::ScopedActivity activity("PTX->CUBIN", /*is_expensive=*/true); bool inserted; decltype(compilation_cache_.begin()) iter; // Pointers into compilation_cache_ where the ptx and (optional) cubin are @@ -779,9 +776,9 @@ se::Platform::Id GpuCompiler::PlatformId() const { } // namespace xla static bool InitModule() { - xla::Compiler::RegisterCompilerFactory(se::cuda::kCudaPlatformId, []() { - return xla::MakeUnique(); - }); + xla::Compiler::RegisterCompilerFactory( + stream_executor::cuda::kCudaPlatformId, + []() { return xla::MakeUnique(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index c352d4d8462..f3b02ae5d88 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -45,25 +45,23 @@ class GpuCompiler : public LLVMCompiler { // Bring in // StatusOr>> Compile( // std::vector> modules, - // std::vector> + // std::vector> // stream_execs) using LLVMCompiler::Compile; StatusOr> RunHloPasses( - std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( - std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::vector> module, AotCompilationOptions const& options) override; - perftools::gputools::Platform::Id PlatformId() const override; + se::Platform::Id PlatformId() const override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { // Capture just the pointer size, not the entire GpuCompiler object. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 28f93447953..25d8f720ea4 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -32,14 +32,15 @@ limitations under the License. 
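The InitModule() change above keeps the long-standing idiom of registering the CUDA compiler factory from a static initializer so the factory is available before main() runs. A self-contained sketch of that registration idiom; the registry, class names, and string key below are stand-ins, not the XLA API:

#include <functional>
#include <map>
#include <memory>
#include <string>

// Stand-in compiler hierarchy; the real code registers xla::GpuCompiler
// keyed on stream_executor::cuda::kCudaPlatformId.
struct CompilerSketch {
  virtual ~CompilerSketch() = default;
};
struct CudaCompilerSketch : CompilerSketch {};

class CompilerRegistrySketch {
 public:
  using Factory = std::function<std::unique_ptr<CompilerSketch>()>;
  static void Register(const std::string& platform_id, Factory factory) {
    Factories()[platform_id] = std::move(factory);
  }
  static std::unique_ptr<CompilerSketch> Create(const std::string& platform_id) {
    auto it = Factories().find(platform_id);
    return it == Factories().end() ? nullptr : it->second();
  }

 private:
  static std::map<std::string, Factory>& Factories() {
    static auto* factories = new std::map<std::string, Factory>();
    return *factories;
  }
};

// Runs during static initialization, mirroring InitModule() above.
static bool InitModuleSketch() {
  CompilerRegistrySketch::Register("cuda", [] {
    return std::unique_ptr<CompilerSketch>(new CudaCompilerSketch());
  });
  return true;
}
static bool module_initialized_sketch = InitModuleSketch();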
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { namespace { +using tensorflow::tracing::ScopedAnnotation; + // A helper class for profiling HLO in the course of GPU program execution. // All of the profiling is guarded internally, to avoid the caller needing to // have lots of conditionals sprinkled around. @@ -136,9 +137,10 @@ Status GpuExecutable::ExecuteThunks( const BufferAllocations& buffer_allocations, bool block_host_until_done, HloExecutionProfile* hlo_execution_profile) { se::Stream* main_stream = run_options->stream(); + se::StreamExecutor* executor = main_stream->parent(); std::pair stream_compute_compatibility; - main_stream->parent()->GetDeviceDescription().cuda_compute_capability( + executor->GetDeviceDescription().cuda_compute_capability( &stream_compute_compatibility.first, &stream_compute_compatibility.second); TF_RET_CHECK(stream_compute_compatibility == compute_capability_) @@ -157,21 +159,39 @@ Status GpuExecutable::ExecuteThunks( sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); - TF_ASSIGN_OR_RETURN( - sub_streams.back(), - run_options->BorrowStream(main_stream->parent()->device_ordinal())); + TF_ASSIGN_OR_RETURN(sub_streams.back(), + run_options->BorrowStream(executor->device_ordinal())); } HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, sub_streams, hlo_module_->entry_computation()); uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - // The next event enqueued on stream N must not run until the thunk at - // last_blocking_thunk_for_stream[N] completes. - std::map last_blocking_thunk_for_stream; + // This top-level trace serves two purposes: + // 1) It marks the scope of the whole XLA module. + // 2) It tells us whether tracing is enabled. We use this to avoid the + // expensive HloInstruction::ToString() calls inside the loop below if + // tracing is disabled. + ScopedAnnotation top_level_annotation(hlo_module_->name(), "XLA GPU module"); + std::map> thunk_to_finish_event; for (Thunk* thunk : thunk_schedule_->TotalOrder()) { - TF_RETURN_IF_ERROR(thunk->Initialize(*this)); + // Annotate execution of this op if tracing was enabled when we started + // running this module. If tracing is enabled *while* we're running the + // module, we won't get any data, but that's probably an OK trade-off. + // + // TODO(jlebar): Should we cache the results of HloInstruction::ToString(), + // since we expect it to be an expensive call? + tensorflow::gtl::optional op_annotation; + if (top_level_annotation.IsEnabled()) { + op_annotation.emplace( + thunk->hlo_instruction() != nullptr + ? 
thunk->hlo_instruction()->ToString(HloPrintOptions::Canonical()) + : "", + "XLA op"); + } + + TF_RETURN_IF_ERROR(thunk->Initialize(*this, executor)); int32 stream_no = thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction()); se::Stream* stream = @@ -181,18 +201,10 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } - if (last_blocking_thunk_for_stream.count(stream_no)) { - stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, - last_blocking_thunk_for_stream[stream_no]) - .get()); - last_blocking_thunk_for_stream.erase(stream_no); - } - // If this thunk requests it, wait for all currently-executing thunks to // finish. This is useful e.g. if the thunk is about to perform autotuning. if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); - last_blocking_thunk_for_stream.clear(); } profiler.StartOperation(); @@ -200,22 +212,11 @@ Status GpuExecutable::ExecuteThunks( << thunk->hlo_instruction()->ToString() << " on stream " << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); - if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) { + if (thunk_schedule_->Depended(thunk)) { auto finish_event = MakeUnique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); - - if (thunk->ShouldBlockFutureThunks()) { - // Set last_blocking_thunk_for_stream on all streams other than this one - // so that all other streams will wait for this thunk to complete before - // executing any events that occur later in the total order. - for (int32 i = 0; i < sub_streams.size() + 1; ++i) { - if (i != stream_no) { - last_blocking_thunk_for_stream[i] = thunk; - } - } - } } profiler.FinishOperation(thunk->hlo_instruction()); } @@ -252,7 +253,7 @@ Status GpuExecutable::ExecuteThunks( return Status::OK(); } -StatusOr> GpuExecutable::ExecuteOnStream( +StatusOr GpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { @@ -288,8 +289,8 @@ StatusOr> GpuExecutable::ExecuteOnStream( se::StreamExecutor* executor = run_options->stream()->parent(); TF_ASSIGN_OR_RETURN( auto buffer_allocations, - buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(), - memory_allocator)); + buffer_allocations_builder.Build( + assignment_.get(), executor->device_ordinal(), memory_allocator)); bool block_host_until_done = !memory_allocator->AllowsAsynchronousDeallocation(); @@ -299,13 +300,13 @@ StatusOr> GpuExecutable::ExecuteOnStream( HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); auto device_ordinal = executor->device_ordinal(); - auto shaped_buffer = MakeUnique( - root->shape(), root->shape(), executor->platform(), device_ordinal); + ScopedShapedBuffer shaped_buffer(root->shape(), root->shape(), + memory_allocator, device_ordinal); // Copy DeviceMemoryBase values which contain the array(s) of the result into // the respective location in ShapedBuffer. 
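The ExecuteThunks hunk above opens one top-level annotation for the whole module and then builds a per-op annotation only when that top-level scope reports tracing as enabled, so the expensive HloInstruction::ToString() call is skipped in the common untraced case. A stand-alone sketch of that pattern; ScopedAnnotationSketch and the cost model are stand-ins, not the TensorFlow tracing API:

#include <iostream>
#include <optional>
#include <string>

// Stand-in for tracing::ScopedAnnotation: records a named scope and reports
// whether tracing was enabled when the scope was opened.
class ScopedAnnotationSketch {
 public:
  explicit ScopedAnnotationSketch(std::string name)
      : name_(std::move(name)), enabled_(tracing_enabled) {
    if (enabled_) std::cout << "begin " << name_ << "\n";
  }
  ~ScopedAnnotationSketch() {
    if (enabled_) std::cout << "end " << name_ << "\n";
  }
  bool IsEnabled() const { return enabled_; }

  // Toggled by the tracing backend in the real code; a plain global here.
  static bool tracing_enabled;

 private:
  std::string name_;
  bool enabled_;
};
bool ScopedAnnotationSketch::tracing_enabled = false;

// Models HloInstruction::ToString(): expensive, so only call when needed.
std::string ExpensiveToString(int op_id) { return "op_" + std::to_string(op_id); }

void RunOps(int num_ops) {
  // One scope for the whole module; also tells us whether tracing is on.
  ScopedAnnotationSketch top_level("XLA GPU module");
  for (int i = 0; i < num_ops; ++i) {
    // Only pay for the expensive string when tracing is actually enabled.
    std::optional<ScopedAnnotationSketch> op_annotation;
    if (top_level.IsEnabled()) {
      op_annotation.emplace(ExpensiveToString(i));
    }
    // ... launch the thunk for op i here ...
  }
}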
std::set buffers_in_result; - TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( + TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachMutableElementWithStatus( [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { const auto& sources = this->GetRootPointsToSet().element(index); @@ -324,20 +325,19 @@ StatusOr> GpuExecutable::ExecuteOnStream( this->assignment_->GetUniqueSlice(src_hlo, sources[0]->index())); CHECK(!slice.allocation()->is_entry_computation_parameter()); - perftools::gputools::DeviceMemoryBase src_base = + se::DeviceMemoryBase src_base = buffer_allocations->GetDeviceAddress(slice.index()); CHECK(!src_base.is_null() || src_base.size() == 0); *device_memory = src_base; buffers_in_result.insert(src_base); return Status::OK(); })); - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown(buffers_in_result, *assignment_)); + TF_RETURN_IF_ERROR(buffer_allocations->TearDown(buffers_in_result)); return std::move(shaped_buffer); } -StatusOr> GpuExecutable::ExecuteAsyncOnStream( +StatusOr GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index dcb3991f41a..80ec38c3ac1 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -74,12 +74,12 @@ class GpuExecutable : public Executable { // ExecuteOnStream will fail if the compute capability of the stream doesn't // match the compute capability passed to this object's constructor. - StatusOr> ExecuteOnStream( + StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteAsyncOnStream( + StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index af9897769fd..7bb8df6581b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -33,8 +33,6 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { // TODO(b/30467474) Once GPU infeed implementation settles, consider @@ -46,8 +44,8 @@ GpuTransferManager::GpuTransferManager() /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) .getPointerSize(0 /* default address space */)) {} -Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, - const Literal& literal) { +Status GpuTransferManager::TransferLiteralToInfeed( + se::StreamExecutor* executor, const LiteralSlice& literal) { const Shape& shape = literal.shape(); VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); @@ -153,8 +151,8 @@ static std::unique_ptr CreateGpuTransferManager() { } static bool InitModule() { - xla::TransferManager::RegisterTransferManager(se::cuda::kCudaPlatformId, - &CreateGpuTransferManager); + xla::TransferManager::RegisterTransferManager( + stream_executor::cuda::kCudaPlatformId, &CreateGpuTransferManager); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index 9aa369c6683..09f8227f508 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -36,21 +36,20 @@ class GpuTransferManager : public GenericTransferManager { GpuTransferManager(); ~GpuTransferManager() override {} - Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, - const Literal& literal) override; - Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, - int64 size, const void* source) override; + Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const LiteralSlice& literal) override; + Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size, + const void* source) override; private: // Initiates the infeed data transfers. InfeedBuffer->Done() must be // called to clean up the memory allocated for InfeedBuffer. StatusOr TransferBufferToInfeedInternal( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source); + se::StreamExecutor* executor, int64 size, const void* source); // Enqueues infeed data buffers with the infeed manager after their // transfer completes. - Status EnqueueBuffersToInfeed(perftools::gputools::StreamExecutor* executor, + Status EnqueueBuffersToInfeed(se::StreamExecutor* executor, std::vector buffers); TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 42c1539e86c..f766f968826 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" @@ -199,7 +200,7 @@ StatusOr> HloSchedule::Build( TF_ASSIGN_OR_RETURN( schedule->thunk_launch_order_, CreateMemoryMinimizingSequence( - *entry_computation, [pointer_size](const LogicalBuffer& buffer) { + *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); } else { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index ece9fa04dce..e230d538cc2 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -42,6 +42,15 @@ class HloScheduleTest : public HloTestBase { .ConsumeValueOrDie(); } + std::unique_ptr CreateNewModule() { + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_disable_multi_streaming(false); + config.set_debug_options(debug_options); + return MakeUnique("test_module", VersionedComputationHandle(), + config); + } + HloVec RemoveHlo(const HloVec& input, const std::unordered_set& remove) { HloVec result(input); @@ -65,9 +74,9 @@ TEST_F(HloScheduleTest, SequentialMatMul) { HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -193,11 +202,11 @@ TEST_F(HloScheduleTest, ConcurrentMatMul) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, dot2)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(add)); @@ -259,24 +268,24 @@ TEST_F(HloScheduleTest, LatticeMatMul) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); } - HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( - f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d00 = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); HloInstruction* d11 = builder.AddInstruction( - 
HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); HloInstruction* d30 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index ee5b447c9cd..3ddc1c0789d 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -19,8 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/core/platform/logging.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h index 73d5a5ce354..d5f2216d460 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -46,7 +46,7 @@ namespace gpu { // the client. The client manages the memory of the buffer. class InfeedBuffer { public: - InfeedBuffer(perftools::gputools::StreamExecutor* executor, int64 length) + InfeedBuffer(se::StreamExecutor* executor, int64 length) : executor_(executor), length_(length) { device_memory_ = executor_->AllocateArray(length); CHECK(!device_memory_.is_null()); @@ -60,14 +60,12 @@ class InfeedBuffer { // client to manage memory for the infeed buffers. void Done() { delete this; } - perftools::gputools::DeviceMemoryBase* device_memory() { - return &device_memory_; - } + se::DeviceMemoryBase* device_memory() { return &device_memory_; } private: - perftools::gputools::StreamExecutor* executor_; // Not owned. + se::StreamExecutor* executor_; // Not owned. const int64 length_; - perftools::gputools::DeviceMemoryBase device_memory_; + se::DeviceMemoryBase device_memory_; }; // Client-side class used to enqueue infeed buffers. @@ -100,8 +98,7 @@ class InfeedManager { // new stream on the first invocation. On subsequent invocations, if // the cached executor is not the same as the requested executor, // returns null. 
- perftools::gputools::Stream* GetStream( - perftools::gputools::StreamExecutor* executor); + se::Stream* GetStream(se::StreamExecutor* executor); private: // TODO(b/30467474): Revisit if this mutex becomes a point of @@ -121,10 +118,10 @@ class InfeedManager { tensorflow::gtl::FlatSet dequeued_buffer_; // Cached host to device stream for queuing infeed data. - std::unique_ptr host_to_device_stream_; + std::unique_ptr host_to_device_stream_; // Executor that the host_to_device_stream belongs to. Not owned. - perftools::gputools::StreamExecutor* host_to_device_executor_; + se::StreamExecutor* host_to_device_executor_; }; // Singleton creator-or-accessor: Returns the GPU infeed manager. diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 2ac95ceb692..ea34d5b30c9 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -31,10 +31,10 @@ InfeedThunk::InfeedThunk( destination_buffer_(destination_buffer) {} Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { + se::Stream* stream) { VLOG(2) << "Infeeding to GPU "; - perftools::gputools::DeviceMemoryBase destination_address = + se::DeviceMemoryBase destination_address = buffer_allocations.GetDeviceAddress(destination_buffer_); InfeedManager* infeed_manager = GetOrCreateInfeedManager(); @@ -45,7 +45,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, std::vector tuple_element_addresses; for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { - perftools::gputools::DeviceMemoryBase tuple_element_address = + se::DeviceMemoryBase tuple_element_address = buffer_allocations.GetDeviceAddress(tuple_element_buffer); InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 86918705fa0..93713cb12de 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -44,7 +44,7 @@ class InfeedThunk : public Thunk { InfeedThunk& operator=(const InfeedThunk&) = delete; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + se::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 85ecbe8fdb3..5d5bef6b57b 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -17,7 +17,9 @@ limitations under the License. 
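The InfeedManager comments above describe GetStream() as creating and caching a host-to-device stream on first use and returning null if a different executor is requested later. A minimal sketch of that behaviour with stand-in types (not StreamExecutor):

#include <memory>
#include <mutex>

struct ExecutorSketch {};  // stands in for se::StreamExecutor
struct StreamSketch {      // stands in for se::Stream
  explicit StreamSketch(ExecutorSketch* e) : executor(e) {}
  ExecutorSketch* executor;
};

class InfeedStreamCacheSketch {
 public:
  // Returns the cached stream, creating it on first use. Returns nullptr if a
  // different executor is requested later, mirroring GetStream() above.
  StreamSketch* GetStream(ExecutorSketch* executor) {
    std::lock_guard<std::mutex> lock(mu_);
    if (stream_ == nullptr) {
      executor_ = executor;
      stream_ = std::make_unique<StreamSketch>(executor);
    } else if (executor != executor_) {
      return nullptr;
    }
    return stream_.get();
  }

 private:
  std::mutex mu_;
  ExecutorSketch* executor_ = nullptr;    // not owned
  std::unique_ptr<StreamSketch> stream_;  // cached host-to-device stream
};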
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -46,41 +48,100 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } +bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { + if (constant->opcode() != HloOpcode::kConstant || + !ShapeUtil::IsScalar(constant->shape())) { + return false; + } + auto type = constant->shape().element_type(); + return type == F16 || type == F32 || type == F64; +} + } // namespace +/*static*/ bool GpuInstructionFusion::IsExpensive( + const HloInstruction& instruction) { + switch (instruction.opcode()) { + // We say that floating-point division is cheap on the GPU. + case HloOpcode::kDivide: + return !ShapeUtil::ElementIsFloating(instruction.shape()) && + InstructionFusion::IsExpensive(instruction); + + default: + return InstructionFusion::IsExpensive(instruction); + } +} + bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha - if (producer->opcode() == HloOpcode::kDot) { - if (consumer->opcode() == HloOpcode::kMultiply) { - CHECK_EQ(consumer->operand_count(), 2); - int64 other_operand_index = 1 - operand_index; - const HloInstruction* alpha = consumer->operand(other_operand_index); - if (alpha->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalar(alpha->shape())) { + if (consumer->operand_count() == 2 && + (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot))) { + int64 other_operand_index = 1 - operand_index; + const HloInstruction* alpha = consumer->operand(other_operand_index); + HloInstruction* op1 = nullptr; + HloInstruction* op2 = nullptr; + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kLoop && + Match(consumer->fused_expression_root(), + match::Op() + .WithOpcode(HloOpcode::kMultiply) + .WithOperand(0, match::Op(&op1)) + .WithOperand(1, match::Op(&op2)))) { + CHECK(op1 != nullptr && op2 != nullptr); + // If 'consumer' is a fusion node, it should consist of a broadcast of a + // scalar constant fused into a multiply, but nothing more. So one operand + // should be a parameter, and the other should be a broadcast. + if (op1->opcode() != HloOpcode::kParameter) { + std::swap(op1, op2); + } + if (op1->opcode() != HloOpcode::kParameter || + op2->opcode() != HloOpcode::kBroadcast) { + return false; + } + if (IsIEEEFloatingPointScalarConstant(alpha)) { + return true; + } + } else if (consumer->opcode() == HloOpcode::kMultiply) { + // Fuse if 'alpha' is a broadcast of a scalar constant. + if (alpha->opcode() == HloOpcode::kBroadcast && + alpha->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(alpha->operand(0))) { return true; } } } - // Only allow to fuse transpose into an output fusion. + // Only allow fusing transpose or broadcast into an output fusion that is + // implemented as a Gemm call. 
if (consumer->opcode() == HloOpcode::kFusion && - consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) { - if (producer->opcode() != HloOpcode::kTranspose) { - return false; - } - // Check that the transpose is the operand of a dot. + consumer->fusion_kind() == HloInstruction::FusionKind::kOutput && + ImplementedAsGemm(*consumer)) { auto producer_operand_index = consumer->operand_index(producer); auto fused_parameter = consumer->fused_parameter(producer_operand_index); const std::vector& fused_parameter_users = fused_parameter->users(); - return (fused_parameter_users.size() == 1 && - fused_parameter_users[0]->opcode() == HloOpcode::kDot); + if (fused_parameter_users.size() != 1) { + return false; + } + if (producer->opcode() == HloOpcode::kTranspose) { + // Check that the transpose is an operand of a dot. + return fused_parameter_users[0]->opcode() == HloOpcode::kDot; + } + if (producer->opcode() == HloOpcode::kBroadcast) { + // Check that the broadcast is a broadcast of a scalar constant into a + // multiply. + return producer->dimensions().empty() && + IsIEEEFloatingPointScalarConstant(producer->operand(0)) && + fused_parameter_users[0]->opcode() == HloOpcode::kMultiply; + } } - // Output fusion is not currently supported on GPUs. + // Other output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { return false; } @@ -121,7 +182,9 @@ HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( if (IsReductionToVector(*consumer)) { return HloInstruction::FusionKind::kInput; } - if (producer->opcode() == HloOpcode::kDot) { + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { return HloInstruction::FusionKind::kOutput; } if (HloOpcode::kFusion == consumer->opcode()) { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index bb2990e6dfc..9fb06b0a244 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -27,6 +27,8 @@ class GpuInstructionFusion : public InstructionFusion { explicit GpuInstructionFusion(bool may_duplicate) : InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {} + static bool IsExpensive(const HloInstruction& instruction); + bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; HloInstruction::FusionKind ChooseKind( diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 4b231c449f8..760e0e90f58 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -108,8 +108,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1)); @@ -125,8 +125,8 @@ TEST_F(InstructionFusionTest, 
PotentialBitcastTransposeOfDotUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(S32, {1, 1}), "0")); - auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0)); + auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot( + ShapeUtil::MakeShape(S32, {1, 1}), param0, param0)); auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1})); @@ -232,12 +232,13 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { auto module = tools::Parse(R"( HloModule test_module ENTRY OutputFusion { - constant = f32[] constant(3) + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} p0 = f32[4,3]{1,0} parameter(0) p1 = f32[4,3]{1,0} parameter(1) transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0} - dot = f32[4,4]{1,0} dot(p0, transpose) - ROOT mul = f32[4,4] multiply(constant, dot) + dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT mul = f32[4,4] multiply(dot, broadcast) })") .ValueOrDie(); @@ -247,10 +248,93 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput); EXPECT_THAT( root->fused_expression_root(), - op::Multiply(op::Parameter(), - op::Dot(op::Parameter(), op::Transpose(op::Parameter())))); + op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), + op::Broadcast(op::Parameter()))); +} + +// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is +// duplicated and fused into both reduces. +TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = f32[] constant(0) + one = f32[] constant(1) + p0 = f32[100] parameter(0) + recip = f32[100] divide(one, p0) + sum1 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = f32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT root = (f32[], f32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())); +} + +// Compute sum(100/p0), where p0 has type s32, twice. Check that the division +// is *not* duplicated and fused into both reduces, because we say that integer +// division is not cheap. 
+TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { + auto module = tools::Parse(R"( + HloModule test_module + Add { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(lhs, rhs) + } + ENTRY TestComputation { + zero = s32[] constant(0) + one_hundred = s32[] constant(100) + p0 = s32[100] parameter(0) + recip = s32[100] divide(one_hundred, p0) + sum1 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + sum2 = s32[] reduce(recip, zero), dimensions={0}, to_apply=Add + ROOT mul = (s32[], s32[]) tuple(sum1, sum2) + })") + .ValueOrDie(); + + EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); +} + +TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY NoOutputFusion { + alpha = f32[] constant(3) + broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={} + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[3,4]{1,0} parameter(1) + dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + d = f32[4,4]{1,0} multiply(dot, dot) + ROOT mul = f32[4,4] multiply(d, broadcast) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); + EXPECT_THAT(root->fused_expression_root(), + op::Multiply(op::Multiply(op::Parameter(), op::Parameter()), + op::Broadcast(op::Parameter()))); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 532d436ee82..22e71509952 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -59,6 +59,25 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, !ShapeUtil::HasZeroElements(lhs_shape) && !ShapeUtil::HasZeroElements(rhs_shape); } + +bool DotImplementedAsGemm(const HloInstruction& dot) { + CHECK_EQ(dot.opcode(), HloOpcode::kDot); + const Shape& lhs_shape = dot.operand(0)->shape(); + const Shape& rhs_shape = dot.operand(1)->shape(); + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)), + rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0))); + return true; + } + return false; +} } // namespace bool ImplementedAsGemm(const HloInstruction& hlo) { @@ -69,24 +88,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { // For certain types of Dot, we can call pre-canned BLAS gemm. if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. 
- CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); - return true; - } - } - - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - return true; + return DotImplementedAsGemm(hlo); } if (hlo.opcode() == HloOpcode::kFusion && @@ -98,7 +100,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { dot = hlo.fused_expression_root()->operand(1); } if (dot->opcode() == HloOpcode::kDot) { - return ImplementedAsGemm(*dot); + return DotImplementedAsGemm(*dot); } } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h index 3790ed313b9..a78b4ff8307 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -32,7 +32,7 @@ class IrEmitterContext { public: IrEmitterContext(const HloModule* hlo_module, const BufferAssignment* buffer_assignment, - const perftools::gputools::DeviceDescription* device_desc, + const se::DeviceDescription* device_desc, llvm::Module* llvm_module) : hlo_module_(hlo_module), buffer_assignment_(buffer_assignment), @@ -47,7 +47,7 @@ class IrEmitterContext { const BufferAssignment& buffer_assignment() const { return *buffer_assignment_; } - const perftools::gputools::DeviceDescription& device_description() const { + const se::DeviceDescription& device_description() const { return *device_desc_; } llvm::Module* llvm_module() { return llvm_module_; } @@ -56,7 +56,7 @@ class IrEmitterContext { private: const HloModule* hlo_module_; const BufferAssignment* buffer_assignment_; - const perftools::gputools::DeviceDescription* device_desc_; + const se::DeviceDescription* device_desc_; llvm::Module* llvm_module_; NameUniquer name_uniquer_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 71aada080ae..bb47a428054 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/core/lib/core/status.h" @@ -116,6 +117,26 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { Status IrEmitterNested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { + // For MOF we give the loop emitter an array for every output it should + // generate. 
+ if (hlo.IsMultiOutputFusion()) { + std::vector target_arrays; + for (int64 i = 0, e = ShapeUtil::TupleElementCount(hlo.shape()); i != e; + ++i) { + target_arrays.push_back(GetIrArray(hlo, hlo, {i})); + } + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, target_arrays, &ir_builder_) + .EmitLoop()); + + std::vector tuple_operand_ptrs; + for (const llvm_ir::IrArray& array : target_arrays) { + tuple_operand_ptrs.push_back(array.GetBasePointer()); + } + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, + module_); + return Status::OK(); + } return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &ir_builder_) .EmitLoop(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 26e497762f2..957733fe8aa 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -257,8 +257,39 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( return kernel; } +namespace { +// Computes the maximum valid unroll factor for a given instruction. +int ComputeMaxUnrollFactor(const HloInstruction* hlo) { + int max_unroll_factor = hlo->GetModule() + ->config() + .debug_options() + .xla_gpu_max_kernel_unroll_factor(); + + // Find the largest possible power of two to unroll by. + // TODO(kramerb): Make this smarter. + const Shape& element_shape = hlo->IsMultiOutputFusion() + ? ShapeUtil::GetSubshape(hlo->shape(), {0}) + : hlo->shape(); + int64 num_elements = ShapeUtil::ElementsIn(element_shape); + for (int i = max_unroll_factor; i > 1; i /= 2) { + if (num_elements % i == 0) { + return i; + } + } + + // Cannot unroll. + return 1; +} +} // namespace + Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { - thunk_sequence_->emplace_back(BuildKernelThunk(hlo)); + int unroll_factor = 1; + // Unfused elementwise operations are usually memory bound, unroll them. + if (hlo->IsElementwise()) { + unroll_factor = ComputeMaxUnrollFactor(hlo); + } + + thunk_sequence_->emplace_back(BuildKernelThunk(hlo, unroll_factor)); return IrEmitter::DefaultAction(hlo); } @@ -537,24 +568,8 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { return Status::OK(); } - int max_unroll_factor = fusion->GetModule() - ->config() - .debug_options() - .xla_gpu_max_kernel_unroll_factor(); - - // Find the largest possible power of two to unroll by. - // TODO(kramerb): Make this smarter. 
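ComputeMaxUnrollFactor above starts from the configured maximum and picks the largest power of two that evenly divides the element count, falling back to 1. A self-contained version of that search, with plain int64_t standing in for the XLA types and the maximum assumed to come from the xla_gpu_max_kernel_unroll_factor debug option:

#include <cassert>
#include <cstdint>

// Mirrors the loop in ComputeMaxUnrollFactor; max_unroll_factor is assumed to
// be a power of two.
int ComputeUnrollFactorSketch(int64_t num_elements, int max_unroll_factor) {
  for (int i = max_unroll_factor; i > 1; i /= 2) {
    if (num_elements % i == 0) {
      return i;
    }
  }
  return 1;  // cannot unroll
}

int main() {
  assert(ComputeUnrollFactorSketch(/*num_elements=*/1024, /*max=*/4) == 4);
  assert(ComputeUnrollFactorSketch(/*num_elements=*/6, /*max=*/4) == 2);
  assert(ComputeUnrollFactorSketch(/*num_elements=*/7, /*max=*/4) == 1);
  return 0;
}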
- int unroll_factor = 1; - if (!fusion->IsMultiOutputFusion()) { - CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); - int64 num_elements = ShapeUtil::ElementsIn(fusion->shape()); - for (int i = max_unroll_factor; i > 1; i /= 2) { - if (num_elements % i == 0) { - unroll_factor = i; - break; - } - } - } + CHECK(fusion->fusion_kind() == HloInstruction::FusionKind::kLoop); + int unroll_factor = ComputeMaxUnrollFactor(fusion); thunk_sequence_->emplace_back(BuildKernelThunk(fusion, unroll_factor)); return IrEmitter::HandleFusion(fusion); @@ -2178,6 +2193,21 @@ std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( /*destination_buffer=*/GetAllocationSlice(*inst), inst); } +namespace { +double GetScalarConstantAsDouble(const Literal& literal) { + switch (literal.shape().element_type()) { + case F16: + return static_cast(literal.Get({})); + case F32: + return literal.Get({}); + case F64: + return literal.Get({}); + default: + LOG(FATAL) << "Unsupported type."; + } +} +} // namespace + std::unique_ptr IrEmitterUnnested::BuildGemmThunk( const HloInstruction* inst) { if (inst->opcode() == HloOpcode::kDot) { @@ -2190,65 +2220,48 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( lhs->shape(), // The shape of LHS. rhs->shape(), // The shape of RHS. inst->shape(), // The shape of the output. - false, // Do not transpose LHS. - false, // Do not transpose RHS. 1.0, // alpha. inst); } if (inst->opcode() == HloOpcode::kFusion) { - if (inst->fusion_kind() == HloInstruction::FusionKind::kOutput) { - const HloInstruction* mul = inst->fused_expression_root(); - const HloInstruction* dot = mul->operand(0); - const HloInstruction* alpha = mul->operand(1); - if (dot->opcode() != HloOpcode::kDot) { - std::swap(dot, alpha); - } - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - inst->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - inst->operand(rhs_parameter->parameter_number()); - - return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*mul), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. - dot->operand(0)->IsRank2Transpose(), // Transpose LHS. - dot->operand(1)->IsRank2Transpose(), // Transpose RHS. - alpha->literal().Get({0}), // alpha. - inst); - } else { - const HloInstruction* dot = inst->fused_expression_root(); - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); - DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && - rhs_parameter->opcode() == HloOpcode::kParameter); - const HloInstruction* lhs = - inst->operand(lhs_parameter->parameter_number()); - const HloInstruction* rhs = - inst->operand(rhs_parameter->parameter_number()); - - return MakeUnique( - GetAllocationSlice(*lhs), // The buffer assigned to LHS. - GetAllocationSlice(*rhs), // The buffer assigned to RHS. - GetAllocationSlice(*inst), // The output buffer. - lhs->shape(), // The shape of LHS. - rhs->shape(), // The shape of RHS. - inst->shape(), // The shape of the output. 
- dot->operand(0)->IsRank2Transpose(), // Transpose LHS. - dot->operand(1)->IsRank2Transpose(), // Transpose RHS. - 1.0, // Alpha. - inst); + CHECK_EQ(inst->fusion_kind(), HloInstruction::FusionKind::kOutput); + const HloInstruction* mul = inst->fused_expression_root(); + const HloInstruction* dot = mul->operand(0); + const HloInstruction* alpha = mul->operand(1); + if (dot->opcode() != HloOpcode::kDot) { + std::swap(dot, alpha); } + if (alpha->opcode() == HloOpcode::kBroadcast) { + alpha = alpha->operand(0); + } + alpha = inst->operand(alpha->parameter_number()); + // TODO(b/74185543): Remove the following if block once we support fusion + // with a non-constant as well. Then we will just always use the constant + // on the device. + if (alpha->opcode() == HloOpcode::kCopy) { + alpha = alpha->operand(0); + } + + DCHECK(dot->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && + rhs_parameter->opcode() == HloOpcode::kParameter); + const HloInstruction* lhs = + inst->operand(lhs_parameter->parameter_number()); + const HloInstruction* rhs = + inst->operand(rhs_parameter->parameter_number()); + + return MakeUnique( + GetAllocationSlice(*lhs), // The buffer assigned to LHS. + GetAllocationSlice(*rhs), // The buffer assigned to RHS. + GetAllocationSlice(*inst), // The output buffer. + lhs->shape(), // The shape of LHS. + rhs->shape(), // The shape of RHS. + inst->shape(), // The shape of the output. + GetScalarConstantAsDouble(alpha->literal()), // alpha. + inst); } LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); @@ -2524,16 +2537,14 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( .EmitLoop(IrName(&hlo)); } - CHECK_EQ(unroll_factor, 1) - << "multi-output fusion does not support unrolling"; - // For multiple outputs fusion, we need to emit each operand and the root. std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, - launch_dimensions, &ir_builder_) + launch_dimensions, &ir_builder_, + unroll_factor) .EmitLoop(IrName(&hlo))); std::vector tuple_operand_ptrs; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b842f480c62..b41ab2162ab 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -38,7 +38,7 @@ namespace gpu { // // Examples of things that are not unnested computations: // -// - The reducer of a kReduce HLO. This is emited using IrEmitterNested. +// - The reducer of a kReduce HLO. This is emitted using IrEmitterNested. // - The body of a fusion node. IrEmitterUnenested emits the relevant code // within a kernel function using FusedIrEmitter. (FusedIrEmitter is not // really an IrEmitter, but is more an "IR generator generator".) diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index c24dc1457f8..f56c1ce69f1 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -23,8 +23,6 @@ limitations under the License. 
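GetScalarConstantAsDouble above widens an f16/f32/f64 scalar literal to double so the GemmThunk alpha no longer has to be f32. A small stand-alone sketch of that widening, with a tagged union standing in for the XLA Literal:

#include <stdexcept>

struct ScalarLiteralSketch {
  enum class Type { F16, F32, F64 } type;
  // As a simplification, f16 values are assumed to be stored already widened
  // to float.
  union {
    float f16_or_f32;
    double f64;
  };
};

double GetScalarConstantAsDoubleSketch(const ScalarLiteralSketch& literal) {
  switch (literal.type) {
    case ScalarLiteralSketch::Type::F16:
    case ScalarLiteralSketch::Type::F32:
      return static_cast<double>(literal.f16_or_f32);
    case ScalarLiteralSketch::Type::F64:
      return literal.f64;
  }
  throw std::invalid_argument("unsupported element type");
}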
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { @@ -37,26 +35,38 @@ KernelThunk::KernelThunk( kernel_name_(kernel_name), unroll_factor_(unroll_factor) {} -tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { +Status KernelThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { tensorflow::mutex_lock lock(mutex_); - if (loader_spec_) { - // Already initialized by another thread. - return tensorflow::Status::OK(); + if (!loader_spec_) { + loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); + tensorflow::StringPiece ptx = executable.ptx(); + // Convert tensorflow::StringPiece to se::port::StringPiece because + // StreamExecutor uses the latter. + loader_spec_->AddCudaPtxInMemory( + se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); + + if (!executable.cubin().empty()) { + loader_spec_->AddCudaCubinInMemory( + reinterpret_cast(executable.cubin().data()), + kernel_name_); + } } - loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - tensorflow::StringPiece ptx = executable.ptx(); - // Convert tensorflow::StringPiece to se::port::StringPiece because - // StreamExecutor uses the latter. - loader_spec_->AddCudaPtxInMemory( - se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); - - if (!executable.cubin().empty()) { - loader_spec_->AddCudaCubinInMemory( - reinterpret_cast(executable.cubin().data()), kernel_name_); + // Load the kernel into the device if necessary. + // + // We could alternatively do this within ExecuteOnStream, but doing it here + // lets the time spent loading the kernel not count towards our execution + // profiles. + auto it = kernel_cache_.find(executor); + if (kernel_cache_.end() == it) { + it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; + if (!executor->GetKernel(*loader_spec_, &it->second)) { + return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + } } - return tensorflow::Status::OK(); + return Status::OK(); } void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { @@ -64,21 +74,18 @@ void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) { launch_dimensions_ = launch_dims; } -tensorflow::Status KernelThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { // Load the kernel. 
se::StreamExecutor* executor = stream->parent(); LaunchDimensions launch_dimensions; const se::KernelBase* kernel = nullptr; + { tensorflow::mutex_lock lock(mutex_); auto it = kernel_cache_.find(executor); - if (kernel_cache_.end() == it) { - it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; - if (!executor->GetKernel(*loader_spec_, &it->second)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); - } - } + CHECK(it != kernel_cache_.end()) + << "Initialize() not called for StreamExecutor " << executor; launch_dimensions = launch_dimensions_; kernel = &it->second; } @@ -99,7 +106,7 @@ tensorflow::Status KernelThunk::ExecuteOnStream( *kernel_args)) { return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index df8971b083f..7def27e189b 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -57,12 +57,12 @@ class KernelThunk : public Thunk { int unroll_factor() const { return unroll_factor_; } void SetLaunchDimensions(const LaunchDimensions& launch_dims); - tensorflow::Status Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; // Executes the kernel for the thunk on "stream", which must be non-null. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // Buffers passed to the kernel as arguments. @@ -82,13 +82,12 @@ class KernelThunk : public Thunk { // Describes how to load this kernel. ExecuteOnStream reuses this loader // specification for all executions. mutable tensorflow::mutex mutex_; - std::unique_ptr loader_spec_ - GUARDED_BY(mutex_); + std::unique_ptr loader_spec_ GUARDED_BY(mutex_); - // Loaded kernels for each `StreamExecutor` - std::unordered_map - kernel_cache_ GUARDED_BY(mutex_); + // Loaded kernels for each `StreamExecutor`. Requires pointer stability of + // values. 
+ std::unordered_map kernel_cache_ + GUARDED_BY(mutex_); }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 86c4ac18b05..7de8f9e1ee9 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -47,7 +47,6 @@ cc_library( "@llvm//:scalar", "@llvm//:support", "@llvm//:target", - "@llvm//:transform_utils", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index df9d9be889c..917c5768234 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -77,8 +77,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Since CUDA 9.0, all GPU versions are included in a single file const char* unified_libdevice_filename = "libdevice.10.bc"; std::vector unified_libdevice_files; - const tensorflow::Status status = - tensorflow::Env::Default()->GetMatchingPaths( + const Status status = tensorflow::Env::Default()->GetMatchingPaths( tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename), &unified_libdevice_files); if (status.ok() && unified_libdevice_files.size() == 1) { @@ -311,11 +310,11 @@ bool CouldNeedLibdevice(const llvm::Module& module) { } // Links libdevice into the given module if the module needs libdevice. -tensorflow::Status LinkLibdeviceIfNecessary( - llvm::Module* module, std::pair compute_capability, - const string& libdevice_dir_path) { +Status LinkLibdeviceIfNecessary(llvm::Module* module, + std::pair compute_capability, + const string& libdevice_dir_path) { if (!CouldNeedLibdevice(*module)) { - return tensorflow::Status::OK(); + return Status::OK(); } llvm::Linker linker(*module); @@ -336,7 +335,7 @@ tensorflow::Status LinkLibdeviceIfNecessary( return tensorflow::errors::Internal(tensorflow::strings::StrCat( "Error linking libdevice from ", libdevice_path)); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr CompileModuleToPtx(llvm::Module* module, @@ -491,7 +490,7 @@ StatusOr CompileToPtx(llvm::Module* module, string ptx; { - tensorflow::port::Tracing::TraceMe annotation( + tensorflow::tracing::ScopedActivity activity( "Compiling IR", llvm_ir::AsString(module->getName()), /*is_expensive=*/true); XLA_SCOPED_LOGGING_TIMER("Compile module " + diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc index 18e673542c5..d4100a898b5 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.cc @@ -19,8 +19,6 @@ limitations under the License. 
namespace xla { namespace gpu { -namespace se = ::perftools::gputools; - Status MemzeroThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { se::DeviceMemoryBase dest_data = buffer_allocations.GetDeviceAddress(dest_); diff --git a/tensorflow/compiler/xla/service/gpu/memset_thunk.h b/tensorflow/compiler/xla/service/gpu/memset_thunk.h index b4bb74d1dd6..51c332d287d 100644 --- a/tensorflow/compiler/xla/service/gpu/memset_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/memset_thunk.h @@ -36,7 +36,7 @@ class MemzeroThunk : public Thunk { : Thunk(Kind::kMemzero, hlo), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + se::Stream* stream) override; private: const BufferAllocation::Slice dest_; @@ -52,7 +52,7 @@ class Memset32BitValueThunk : public Thunk { : Thunk(Kind::kMemset32BitValue, hlo), value_(value), dest_(dest) {} Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + se::Stream* stream) override; private: uint32 value_; diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 7bda4e2fcd4..c8f0d4185c6 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -370,23 +370,35 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( return true; } +StatusOr PadInsertion::RunOnComputation(HloComputation* computation) { + bool changed = false; + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(instr); + } + } + for (HloInstruction* instruction : convs) { + const auto& target = instruction->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + changed |= CanonicalizeForwardConvolution(instruction); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + changed |= CanonicalizeBackwardFilterConvolution(instruction); + } else if (target == kCudnnConvBackwardInputCallTarget) { + changed |= CanonicalizeBackwardInputConvolution(instruction); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instruction->ToString(); + } + } + return changed; +} + StatusOr PadInsertion::Run(HloModule* module) { bool changed = false; - for (HloInstruction* instruction : - module->entry_computation()->MakeInstructionPostOrder()) { - if (IsCustomCallToDnnConvolution(*instruction)) { - const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - changed |= CanonicalizeBackwardFilterConvolution(instruction); - } else if (target == kCudnnConvBackwardInputCallTarget) { - changed |= CanonicalizeBackwardInputConvolution(instruction); - } else { - LOG(FATAL) << "Unknown custom call target for cudnn conv: " - << instruction->ToString(); - } - } + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; } return changed; } diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index 5e1c68701da..67e51509e4c 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -31,6 +31,7 @@ 
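PadInsertion::Run now walks every non-fusion computation rather than only the entry computation, and the new RunOnComputation first snapshots the cuDNN convolution custom-calls into a vector before canonicalizing them, so the rewrite does not mutate the instruction list while iterating it. A hedged sketch of that collect-then-rewrite shape, with toy types standing in for HloComputation and HloInstruction:

#include <string>
#include <vector>

struct ToyInstruction { std::string target; };
struct ToyComputation { std::vector<ToyInstruction*> instructions; };

bool Canonicalize(ToyInstruction*) { return true; }  // Pretend rewrite.

bool RunOnComputation(ToyComputation* computation) {
  bool changed = false;
  // Pass 1: snapshot the instructions we intend to rewrite.
  std::vector<ToyInstruction*> convs;
  for (ToyInstruction* instr : computation->instructions) {
    if (instr->target == "cudnn-conv") convs.push_back(instr);
  }
  // Pass 2: rewrite from the snapshot; the live instruction list may change.
  for (ToyInstruction* conv : convs) changed |= Canonicalize(conv);
  return changed;
}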
class PadInsertion : public HloPassInterface { StatusOr Run(HloModule* module) override; private: + StatusOr RunOnComputation(HloComputation* computation); // Returns if any changes are made to the parent computation. bool CanonicalizeForwardConvolution(HloInstruction* conv); bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv); diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index 5283d51cd10..d3fd0544fb6 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -29,8 +29,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 42d2d2af2e3..c125474edb1 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -57,8 +57,7 @@ std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims); LaunchDimensions CalculateLaunchDimensions( - const Shape& shape, - const perftools::gputools::DeviceDescription& device_desc, + const Shape& shape, const se::DeviceDescription& device_desc, int unroll_factor = 1); } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index d8a43091d40..b50f5b5a903 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -24,21 +24,20 @@ SequentialThunk::SequentialThunk(std::vector>&& thunks, const HloInstruction* hlo) : Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {} -tensorflow::Status SequentialThunk::Initialize( - const GpuExecutable& executable) { +Status SequentialThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { for (auto& thunk : thunks_) { - TF_RETURN_IF_ERROR(thunk->Initialize(executable)); + TF_RETURN_IF_ERROR(thunk->Initialize(executable, executor)); } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status SequentialThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { +Status SequentialThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { for (const auto& thunk : thunks_) { TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index 32c5b748aba..3537110bb5c 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -38,10 +38,10 @@ class SequentialThunk : public Thunk { const std::vector>& thunks() const { return thunks_; } - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; + Status ExecuteOnStream(const 
BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: // The list of sub-thunks. diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 8c98956f1a9..696fa7e0194 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -28,6 +28,15 @@ namespace gpu { class StreamAssignmentTest : public HloTestBase { protected: + std::unique_ptr CreateNewModule() { + HloModuleConfig config; + auto debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_disable_multi_streaming(false); + config.set_debug_options(debug_options); + return MakeUnique("test_module", VersionedComputationHandle(), + config); + } + // Pre-canned shapes. Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2}); }; @@ -41,9 +50,9 @@ TEST_F(StreamAssignmentTest, SequentialMatMul) { HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/2, f32_2x2_, /*name=*/"z")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z)); + HloInstruction::CreateCanonicalDot(f32_2x2_, dot1, z)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(dot2)); @@ -60,9 +69,9 @@ TEST_F(StreamAssignmentTest, ConcurrentMatMul) { HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/1, f32_2x2_, /*name=*/"y")); HloInstruction* dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y)); + HloInstruction::CreateCanonicalDot(f32_2x2_, x, y)); HloInstruction* dot2 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x)); + HloInstruction::CreateCanonicalDot(f32_2x2_, y, x)); HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2)); @@ -91,24 +100,24 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); } - HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary( - f32_2x2_, HloOpcode::kDot, params[2], params[3])); + HloInstruction* d00 = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); HloInstruction* d10 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[1], d00)); HloInstruction* d11 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d00, params[4])); HloInstruction* d20 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10)); + HloInstruction::CreateCanonicalDot(f32_2x2_, params[0], d10)); HloInstruction* d21 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d10, d11)); HloInstruction* d22 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5])); + HloInstruction::CreateCanonicalDot(f32_2x2_, d11, params[5])); HloInstruction* 
d30 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d20, d21)); HloInstruction* d31 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d21, d22)); HloInstruction* d40 = builder.AddInstruction( - HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31)); + HloInstruction::CreateCanonicalDot(f32_2x2_, d30, d31)); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build(d40)); diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 9eea958d121..931c0bffab8 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -70,11 +70,14 @@ class Thunk { Kind kind() const { return kind_; } const HloInstruction* hlo_instruction() const { return hlo_instruction_; } - // Prepares for executing the thunk. This method is called only once over - // Thunk's lifetime. For example, KernelThunk::Initialize loads the PTX of a - // kernel, which is the same in every execution. - virtual tensorflow::Status Initialize(const GpuExecutable& executable) { - return tensorflow::Status::OK(); + // Prepares the thunk for execution on the given StreamExecutor. + // + // This may be called multiple times. Its main purpose is to give us a chance + // to do initialization outside of ExecuteOnStream() so that the + // time spent initializing doesn't count towards our execution profile. + virtual Status Initialize(const GpuExecutable& /*executable*/, + se::StreamExecutor* /*executor*/) { + return Status::OK(); } // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) @@ -85,27 +88,17 @@ class Thunk { // This value is not required to be constant for a given Thunk. For example, // a Thunk that performs autotuning may return true for its first run and // false thereafter. - virtual bool ShouldHaltAllActivityBeforeRunning( - perftools::gputools::Stream* /*stream*/) { + virtual bool ShouldHaltAllActivityBeforeRunning(se::Stream* /*stream*/) { return false; } - // Indicates whether thunks scheduled after this one should wait for this one - // to complete before running. For example, a convolution thunk creates a - // scratch allocator, then kicks off a convolution in cudnn via the stream - // executor. When the stream executor call returns, the scratch allocator goes - // out of scope, and the scratch memory is deallocated. In this case, the - // convolution thunk needs to return true so that future thunks wait for the - // convolution thunk to avoid reusing the deallocated memory until the - // convolution thunk is done with it. - virtual bool ShouldBlockFutureThunks() { return false; } - // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. Stream argument must be non-null. - virtual tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) = 0; + // + // Precondition: Initialize(stream->parent()) has been called. 
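The rewritten Thunk comments above spell out the new contract: Initialize() may be called more than once (it is now per StreamExecutor) and must run before any ExecuteOnStream() call, so that setup work stays out of the execution profile. A self-contained sketch of a driver honoring that contract; MiniThunk, MiniExecutor and MiniStream are illustrative stand-ins rather than XLA types:

#include <memory>
#include <vector>

struct MiniExecutor {};
struct MiniStream { MiniExecutor* parent = nullptr; };

class MiniThunk {
 public:
  virtual ~MiniThunk() = default;
  // May be called several times (e.g. once per executor); must be idempotent.
  virtual bool Initialize(MiniExecutor* /*executor*/) { return true; }
  virtual bool ExecuteOnStream(MiniStream* /*stream*/) = 0;
};

bool RunThunks(const std::vector<std::unique_ptr<MiniThunk>>& thunks,
               MiniStream* stream) {
  for (const auto& thunk : thunks)
    if (!thunk->Initialize(stream->parent)) return false;  // Un-profiled setup.
  for (const auto& thunk : thunks)
    if (!thunk->ExecuteOnStream(stream)) return false;     // Profiled execution.
  return true;
}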
+ virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) = 0; private: Kind kind_; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index bd65e72393a..97cb04c38fb 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -17,13 +17,11 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" -namespace se = ::perftools::gputools; - namespace xla { namespace gpu { -tensorflow::Status TupleThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, se::Stream* stream) { +Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) { std::vector tuple_element_buffer_addresses; for (BufferAllocation::Slice tuple_element_buffer : tuple_element_buffers_) { tuple_element_buffer_addresses.push_back( @@ -42,7 +40,7 @@ tensorflow::Status TupleThunk::ExecuteOnStream( tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(), sizeof(void*) * tuple_element_buffer_addresses.size()); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h index 3b1a4963285..951f809b519 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h @@ -45,9 +45,8 @@ class TupleThunk : public Thunk { TupleThunk(const TupleThunk&) = delete; TupleThunk& operator=(const TupleThunk&) = delete; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + se::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index c21559af6d2..30b9640c4c7 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -34,15 +34,17 @@ WhileThunk::WhileThunk( body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -Status WhileThunk::Initialize(const GpuExecutable& executable) { - TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable)); - TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); +Status WhileThunk::Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) { + TF_RETURN_IF_ERROR( + condition_thunk_sequence_->Initialize(executable, executor)); + TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); return Status::OK(); } Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { - perftools::gputools::DeviceMemoryBase condition_result_data = + se::Stream* stream) { + se::DeviceMemoryBase condition_result_data = buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); while (true) { diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 4c9f45de9e4..22176685a92 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -45,9 +45,10 @@ class WhileThunk : public Thunk { WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; - Status 
Initialize(const GpuExecutable& executable) override; + Status Initialize(const GpuExecutable& executable, + se::StreamExecutor* executor) override; Status ExecuteOnStream(const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + se::Stream* stream) override; private: const BufferAllocation::Slice condition_result_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index e6caec8625f..ad55728c455 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -144,7 +144,7 @@ class ExprTree { TF_RETURN_IF_ERROR(pair.second->Match(instruction->operand(pair.first), tagged_instructions)); } - return tensorflow::Status::OK(); + return Status::OK(); } private: @@ -169,7 +169,7 @@ class MatcherBase { // Attempts to match each ExprTree in 'expr_trees_'. // Returns OK on the first successful match, error status otherwise. - virtual tensorflow::Status Run() { + virtual Status Run() { Status status; for (const ExprTree& expr_tree : expr_trees_) { status = MatchExprTree(expr_tree); @@ -201,7 +201,7 @@ class MatcherBase { } else if (type == S64) { *const_value = literal.GetFirstElement(); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr GetTaggedInstruction( @@ -315,7 +315,7 @@ class WhileConditionComputationMatcher : public MatcherBase { gte_fusion_param0->name().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; @@ -379,7 +379,7 @@ class WhileInitOperandMatcher : public MatcherBase { GetTaggedInstruction("loop_start", tagged_instructions)); TF_RETURN_IF_ERROR(ParseConstInteger(const_hlo, &loop_start_)); - return tensorflow::Status::OK(); + return Status::OK(); } const HloInstruction* while_hlo_; @@ -477,7 +477,7 @@ class WhileBodyComputationMatcher : public MatcherBase { } } } - return tensorflow::Status::OK(); + return Status::OK(); } const HloComputation* computation_; diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 05017008e2d..acf66114869 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -82,7 +82,8 @@ HloComputation* CallForwardingComputation(HloComputation* computation, // instructions. Sets the computation as the entry to an HLO module and returns // the module. std::unique_ptr MakeBigGraph() { - auto module = MakeUnique("BigGraph"); + HloModuleConfig config; + auto module = MakeUnique("BigGraph", config); auto builder = HloComputation::Builder("TestBigGraphvizGraph"); diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 3dd4c4a0794..06a5e0351b6 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -32,7 +31,7 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); const HloComputation* entry_computation = module.entry_computation(); const std::vector& instruction_sequence = @@ -47,7 +46,7 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, /*module_sequence=*/nullptr); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, @@ -73,11 +72,11 @@ Status HeapSimulator::RunComputation( // 'used_buffers' is the reverse map - it tracks which buffers were used by an // instruction, so that we can remove the instructions from a buffer's live // set after they are visited. - FlatMap> live_buffers; - FlatMap> used_buffers; + FlatMap> live_buffers; + FlatMap> used_buffers; auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( const HloInstruction* user, - const LogicalBuffer* buffer) { + const BufferValue* buffer) { if (!IgnoreBuffer(buffer)) { VLOG(4) << " Adding user " << user->name() << " to buffer " << buffer->ToString(); @@ -96,7 +95,7 @@ Status HeapSimulator::RunComputation( const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); for (const HloInstruction* user : instruction->users()) { if (user->opcode() != HloOpcode::kGetTupleElement) { - for (const LogicalBuffer* buffer : buffer_set) { + for (const BufferValue* buffer : buffer_set) { add_user_to_buffer(user, buffer); } } else { @@ -104,12 +103,12 @@ Status HeapSimulator::RunComputation( // alive. It only needs the buffers that relate to the element its // extracting, and the tuple it's extracting from, but not the buffers // for the other elements. 
- for (const LogicalBuffer* buffer : points_to.element({})) { + for (const BufferValue* buffer : points_to.element({})) { add_user_to_buffer(user, buffer); } const PointsToSet& gte_points_to = points_to_analysis.GetPointsToSet(user); - for (const LogicalBuffer* buffer : gte_points_to.CreateFlattenedSet()) { + for (const BufferValue* buffer : gte_points_to.CreateFlattenedSet()) { add_user_to_buffer(user, buffer); } } @@ -117,24 +116,25 @@ Status HeapSimulator::RunComputation( } const HloInstruction* root = computation.root_instruction(); - auto output_source_buffers = - points_to_analysis.GetPointsToSet(root).CreateFlattenedSet(); + BufferValueCompactPointerSet output_source_buffers = + ToBufferValueCompactPointerSet( + points_to_analysis.GetPointsToSet(root).CreateFlattenedSet()); - std::vector dead_buffers_to_free; - std::vector operand_buffers_to_free; + std::vector dead_buffers_to_free; + std::vector operand_buffers_to_free; for (const HloInstruction* instruction : instruction_sequence) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); VLOG(3) << "Instruction: " << instruction->ToString(); - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { VLOG(4) << " Defines: " << buffer->ToString() << (IgnoreBuffer(buffer) ? " (Ignored)" : ""); } dead_buffers_to_free.clear(); - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } @@ -161,7 +161,7 @@ Status HeapSimulator::RunComputation( // have no instructions left to visit are moved from live_buffers to // operand_buffers_to_free. operand_buffers_to_free.clear(); - for (const LogicalBuffer* operand_buffer : used_buffers[instruction]) { + for (const BufferValue* operand_buffer : used_buffers[instruction]) { if (IgnoreBuffer(operand_buffer)) { continue; } @@ -177,7 +177,7 @@ Status HeapSimulator::RunComputation( } // Sort to get a deterministic iteration order. std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), - [](const LogicalBuffer* x, const LogicalBuffer* y) { + [](const BufferValue* x, const BufferValue* y) { return x->id() < y->id(); }); @@ -188,7 +188,7 @@ Status HeapSimulator::RunComputation( // // INVARIANT: Either Alloc or ShareBuffer will be called for each buffer // that we should assign. - for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } @@ -199,12 +199,12 @@ Status HeapSimulator::RunComputation( // we must be the last user of the buffer. 
bool shared = false; if (options_.may_reuse_operand_buffers) { - for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { + for (const BufferValue* operand_buffer : operand_buffers_to_free) { if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && buffer->instruction()->opcode() != HloOpcode::kCopy && - CanShareOperandBufferWithUser( + points_to_analysis.CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), points_to_analysis)) { + buffer->instruction(), buffer->index())) { VLOG(3) << " Sharing: " << buffer->ToString() << " with " << operand_buffer->ToString(); ShareBuffer(buffer, operand_buffer, instruction); @@ -248,11 +248,11 @@ Status HeapSimulator::RunComputation( // Free buffers that are no longer live. This is the earliest point that we // can de-allocate; right after the last use of the buffer. - for (const LogicalBuffer* buffer : dead_buffers_to_free) { + for (const BufferValue* buffer : dead_buffers_to_free) { VLOG(3) << " Freeing dead: " << buffer->ToString(); Free(buffer, instruction); } - for (const LogicalBuffer* buffer : operand_buffers_to_free) { + for (const BufferValue* buffer : operand_buffers_to_free) { VLOG(3) << " Freeing operand: " << buffer->ToString(); Free(buffer, instruction); } @@ -261,10 +261,10 @@ Status HeapSimulator::RunComputation( // Any remaining live buffers must be entry parameters or output source // buffers, which had a nullptr sentry added. Free them now, in a // deterministic order. - std::vector to_free; + std::vector to_free; to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { - const LogicalBuffer* buffer = buffer_pending.first; + const BufferValue* buffer = buffer_pending.first; const FlatSet& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; @@ -272,10 +272,10 @@ Status HeapSimulator::RunComputation( } std::sort(to_free.begin(), to_free.end(), - [](const LogicalBuffer* x, const LogicalBuffer* y) { + [](const BufferValue* x, const BufferValue* y) { return x->id() < y->id(); }); - for (const LogicalBuffer* buffer : to_free) { + for (const BufferValue* buffer : to_free) { VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); } @@ -285,7 +285,7 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, const Options& options, + const BufferValue::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), @@ -297,7 +297,7 @@ HeapSimulator::HeapSimulator( HeapSimulator::~HeapSimulator() {} -bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { +bool HeapSimulator::IgnoreBuffer(const BufferValue* buffer) const { // Buffers for constants are ignored unless the alloc_constants option is // set. Also ignore buffers that we're not meant to assign. // @@ -311,7 +311,7 @@ bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { } // Alloc always calls the underlying heap algorithm. 
-void HeapSimulator::Alloc(const LogicalBuffer* buffer, +void HeapSimulator::Alloc(const BufferValue* buffer, const HloInstruction* instruction) { CHECK(allocated_buffers_.count(buffer) == 0) << "Alloc called on allocated buffer: " << *buffer; @@ -331,7 +331,7 @@ void HeapSimulator::Alloc(const LogicalBuffer* buffer, // buffers whose group liveness has expired. Shared group liveness is tracked // by maintaining a refcount; the Free call on the last buffer in the group // causes Free to be called on the underlying algorithm. -void HeapSimulator::Free(const LogicalBuffer* buffer, +void HeapSimulator::Free(const BufferValue* buffer, const HloInstruction* instruction) { auto shared_it = shared_buffers_.find(buffer); if (shared_it != shared_buffers_.end()) { @@ -362,8 +362,8 @@ void HeapSimulator::Free(const LogicalBuffer* buffer, // The 'buffer' must be a non-allocated, non-freed buffer, just like in calls to // Alloc. The 'shared' buffer must be a previously allocated or shared buffer. // Both 'buffer' and 'shared' will be associated with the same SharedGroup. -void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer, - const LogicalBuffer* shared, +void HeapSimulator::ShareBuffer(const BufferValue* buffer, + const BufferValue* shared, const HloInstruction* instruction) { CHECK_LE(size_fn_(*buffer), size_fn_(*shared)) << "ShareBuffer oversized buffer" << *buffer << " shared: " << *shared; @@ -374,7 +374,7 @@ void HeapSimulator::ShareBuffer(const LogicalBuffer* buffer, CHECK(freed_buffers_.count(shared) == 0) << "ShareBuffer called on freed shared buffer: " << *shared; - const LogicalBuffer* canonical = nullptr; + const BufferValue* canonical = nullptr; auto shared_it = shared_buffers_.find(shared); if (shared_it != shared_buffers_.end()) { // The 'shared' buffer already has a group; it might be the canonical, but @@ -408,7 +408,7 @@ HeapSimulator::Result HeapSimulator::Finish() { // collecting statistics, e.g. NoFragmentationStatsHeap. 
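Free() and ShareBuffer() above implement the refcounting described in the comment: every member of a shared group aliases one canonical allocation, and only the Free of the last member reaches the underlying heap algorithm. A hedged sketch of that control-block scheme with toy types:

#include <map>
#include <memory>
#include <string>

// Each shared buffer points at one control block; the underlying heap only
// sees a Free() when the whole group's refcount reaches zero.
struct SharedGroup {
  std::string canonical;
  int refcount = 0;
};

class SharedGroupTracker {
 public:
  // Record that `buffer` aliases `with` (or `with`'s existing group).
  void Share(const std::string& buffer, const std::string& with) {
    std::shared_ptr<SharedGroup>& slot = groups_[with];
    if (slot == nullptr) {
      slot = std::make_shared<SharedGroup>();
      slot->canonical = with;
      slot->refcount = 1;  // Counts the canonical buffer itself.
    }
    slot->refcount++;
    groups_[buffer] = slot;
  }

  // Returns true when the caller should forward the Free to the real heap.
  bool Free(const std::string& buffer) {
    auto it = groups_.find(buffer);
    if (it == groups_.end()) return true;  // Never shared: always forward.
    return --it->second->refcount == 0;    // Only the last member forwards.
  }

 private:
  std::map<std::string, std::shared_ptr<SharedGroup>> groups_;
};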
if (!result.chunk_map.empty()) { for (const auto& share_pair : shared_buffers_) { - const LogicalBuffer* buffer = share_pair.first; + const BufferValue* buffer = share_pair.first; std::shared_ptr group = share_pair.second; if (buffer != group->canonical) { // The canonical must already exist in the chunk_map, since we called @@ -437,9 +437,9 @@ HeapSimulator::Result HeapSimulator::Finish() { } void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, - const LogicalBuffer* buffer, + const BufferValue* buffer, const HloInstruction* instruction, - const LogicalBuffer* share_with_canonical) { + const BufferValue* share_with_canonical) { HeapSimulatorTrace::Event* event = debug_trace_.add_events(); event->set_kind(kind); event->set_buffer_id(buffer->id()); @@ -453,14 +453,14 @@ void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, } } -void NoFragmentationStatsHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { current_heap_size_ += size; if (current_heap_size_ > max_heap_size_) { max_heap_size_ = current_heap_size_; } } -void NoFragmentationStatsHeap::Free(const LogicalBuffer* buffer, int64 size) { +void NoFragmentationStatsHeap::Free(const BufferValue* buffer, int64 size) { current_heap_size_ -= size; } @@ -472,12 +472,12 @@ HeapSimulator::Result NoFragmentationStatsHeap::Finish() { return result; } -void DecreasingSizeRunsHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void DecreasingSizeRunsHeap::Alloc(const BufferValue* buffer, int64 size) { SetMode(kAlloc); run_.emplace_back(Op{buffer, size}); } -void DecreasingSizeRunsHeap::Free(const LogicalBuffer* buffer, int64 size) { +void DecreasingSizeRunsHeap::Free(const BufferValue* buffer, int64 size) { CHECK(mode_ != kInit) << "Free called on empty heap: " << *buffer; SetMode(kFree); run_.emplace_back(Op{buffer, size}); @@ -518,7 +518,7 @@ void DecreasingSizeRunsHeap::CallAndDrainRun() { run_.clear(); } -void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) { +void LazyBestFitHeap::Alloc(const BufferValue* buffer, int64 size) { // Degenerate case: 0-sized buffers are always allocated at offset 0. if (size == 0) { result_.chunk_map.emplace(buffer, Chunk{0, 0}); @@ -586,7 +586,7 @@ void LazyBestFitHeap::Alloc(const LogicalBuffer* buffer, int64 size) { result_.chunk_map.emplace(buffer, Chunk{kLazyAllocOffset, size}); } -void LazyBestFitHeap::Free(const LogicalBuffer* buffer, int64 size) { +void LazyBestFitHeap::Free(const BufferValue* buffer, int64 size) { auto alloc_it = result_.chunk_map.find(buffer); CHECK(alloc_it != result_.chunk_map.end()) << "Free called on non-allocated buffer: " << *buffer; diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 636f19dd39f..8b2b43a37a5 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -21,11 +21,12 @@ limitations under the License. 
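NoFragmentationStatsHeap, switched to BufferValue above, only tracks the running and peak heap size: Alloc adds the buffer's size and updates the maximum, Free subtracts it. A standalone sketch of that peak tracking:

#include <algorithm>
#include <cstdint>

class PeakTracker {
 public:
  void Alloc(int64_t size) {
    current_ += size;
    peak_ = std::max(peak_, current_);
  }
  void Free(int64_t size) { current_ -= size; }
  int64_t peak() const { return peak_; }

 private:
  int64_t current_ = 0;
  int64_t peak_ = 0;
};

// Example: Alloc(16); Alloc(8); Free(16); Alloc(32);  => peak() == 40.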
#include #include +#include "tensorflow/compiler/xla/service/buffer_value.h" +#include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -43,7 +44,7 @@ class HeapAlgorithm; // don't need to return the assignment of buffer offsets until the very end. class HeapSimulator { public: - // Chunk represents a contiguous piece of memory. Each LogicalBuffer will be + // Chunk represents a contiguous piece of memory. Each BufferValue will be // associated with a chunk in the assignment result. struct Chunk { int64 offset; @@ -55,7 +56,7 @@ class HeapSimulator { // Result represents the result of the heap simulation. struct Result { // The assignment of buffers to chunks. - tensorflow::gtl::FlatMap chunk_map; + tensorflow::gtl::FlatMap chunk_map; // The total size in bytes of the heap, containing all assigned chunks. int64 heap_size = 0; @@ -81,7 +82,7 @@ class HeapSimulator { bool alloc_constants; // If 'buffers_to_assign' is provided, only those buffers are assigned // offsets, otherwise all buffers defined by the instructions are assigned. - const tensorflow::gtl::FlatSet* buffers_to_assign; + const BufferValueFlatSet* buffers_to_assign; }; // Run the heap simulation with the given algorithm, assuming the given @@ -97,7 +98,7 @@ class HeapSimulator { std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, + const BufferValue::SizeFunction& size_fn, const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' @@ -109,7 +110,7 @@ class HeapSimulator { const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, + const BufferValue::SizeFunction& size_fn, const Options& options = Options()); private: @@ -118,7 +119,7 @@ class HeapSimulator { // be run recursively. I.e. the simulation is run over the whole module. 
HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, const Options& options, + const BufferValue::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence); ~HeapSimulator(); @@ -127,21 +128,21 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis); - bool IgnoreBuffer(const LogicalBuffer* buffer) const; - void Alloc(const LogicalBuffer* buffer, const HloInstruction* instruction); - void Free(const LogicalBuffer* buffer, const HloInstruction* instruction); - void ShareBuffer(const LogicalBuffer* buffer, const LogicalBuffer* shared, + bool IgnoreBuffer(const BufferValue* buffer) const; + void Alloc(const BufferValue* buffer, const HloInstruction* instruction); + void Free(const BufferValue* buffer, const HloInstruction* instruction); + void ShareBuffer(const BufferValue* buffer, const BufferValue* shared, const HloInstruction* instruction); Result Finish(); void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, - const LogicalBuffer* buffer, + const BufferValue* buffer, const HloInstruction* instruction, - const LogicalBuffer* shared_with_canonical); + const BufferValue* shared_with_canonical); const std::unique_ptr no_fragmentation_stats_; const std::unique_ptr algorithm_; - const LogicalBuffer::SizeFunction size_fn_; + const BufferValue::SizeFunction size_fn_; const Options options_; const SequentialHloOrdering::HloModuleSequence* module_sequence_; @@ -160,15 +161,15 @@ class HeapSimulator { // The shared_buffers_ map associates each shared buffer (including the // canonical) to its SharedGroup control block. struct SharedGroup { - const LogicalBuffer* canonical = nullptr; + const BufferValue* canonical = nullptr; int64 refcount = 0; }; - tensorflow::gtl::FlatMap> + tensorflow::gtl::FlatMap> shared_buffers_; // Hold some sets for error-checking the sequence of Alloc and Free calls. - tensorflow::gtl::FlatSet allocated_buffers_; - tensorflow::gtl::FlatSet freed_buffers_; + tensorflow::gtl::FlatSet allocated_buffers_; + tensorflow::gtl::FlatSet freed_buffers_; // Debugging information filled in while the heap simulator runs. HeapSimulatorTrace debug_trace_; @@ -186,10 +187,10 @@ class HeapAlgorithm { virtual ~HeapAlgorithm() = default; // Alloc allocates a buffer of 'size' bytes. - virtual void Alloc(const LogicalBuffer* buffer, int64 size) = 0; + virtual void Alloc(const BufferValue* buffer, int64 size) = 0; // Free de-allocates a previously allocated buffer. - virtual void Free(const LogicalBuffer* buffer, int64 size) = 0; + virtual void Free(const BufferValue* buffer, int64 size) = 0; // Finish collects the buffer offset assignment results. Free may only be // called once, after the Alloc and Free calls. 
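The HeapAlgorithm interface above reduces to three calls: Alloc assigns an offset to a buffer, Free releases it, and Finish reports the chunk assignment and total heap size. The deliberately trivial bump allocator below (toy types, never reuses freed space) is only meant to make that contract concrete; it is not one of XLA's algorithms:

#include <cstdint>
#include <map>
#include <string>
#include <utility>

struct Chunk { int64_t offset; int64_t size; };

class BumpHeap {
 public:
  void Alloc(const std::string& buffer, int64_t size) {
    chunks_[buffer] = Chunk{heap_size_, size};  // Next free offset.
    heap_size_ += size;
  }
  void Free(const std::string& /*buffer*/, int64_t /*size*/) {
    // A bump allocator never reuses space, so Free is a no-op here.
  }
  std::pair<std::map<std::string, Chunk>, int64_t> Finish() const {
    return {chunks_, heap_size_};
  }

 private:
  std::map<std::string, Chunk> chunks_;
  int64_t heap_size_ = 0;
};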
@@ -205,8 +206,8 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { NoFragmentationStatsHeap() = default; ~NoFragmentationStatsHeap() override = default; - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: @@ -223,14 +224,14 @@ class DecreasingSizeRunsHeap : public HeapAlgorithm { : algorithm_(std::move(algorithm)) {} ~DecreasingSizeRunsHeap() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: // A single Alloc or Free operation that we've buffered in run_. struct Op { - const LogicalBuffer* buffer; + const BufferValue* buffer; int64 size; }; @@ -266,8 +267,8 @@ class LazyBestFitHeap : public HeapAlgorithm { LazyBestFitHeap(int64 alignment) : alignment_(alignment) {} ~LazyBestFitHeap() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override; - void Free(const LogicalBuffer* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; Result Finish() override; private: diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 688a271712a..6271652412c 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -20,11 +20,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -38,7 +39,7 @@ const char kFree[] = "Free"; const char kFinish[] = "Finish"; // CallSequence records a sequence of Alloc/Free/Finish calls. -using CallSequence = std::vector>; +using CallSequence = std::vector>; // HeapCallRecorder is a dummy heap algorithm that simply records its calls. class HeapCallRecorder : public HeapAlgorithm { @@ -46,7 +47,7 @@ class HeapCallRecorder : public HeapAlgorithm { explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {} ~HeapCallRecorder() override {} - void Alloc(const LogicalBuffer* buffer, int64 size) override { + void Alloc(const BufferValue* buffer, int64 size) override { calls_->emplace_back(kAlloc, buffer); // Instead of assigning a real offset, we set the cardinality of the Alloc // call. 
This isn't a valid assignment, but allows us to easily test for @@ -54,7 +55,7 @@ class HeapCallRecorder : public HeapAlgorithm { const int64 offset = result_.chunk_map.size(); result_.chunk_map.emplace(buffer, Chunk{offset, size}); } - void Free(const LogicalBuffer* buffer, int64 size) override { + void Free(const BufferValue* buffer, int64 size) override { calls_->emplace_back(kFree, buffer); } Result Finish() override { @@ -76,7 +77,8 @@ class HeapSimulatorTracker { HeapSimulatorTracker( const string& name, std::unique_ptr computation, const std::vector& instruction_sequence) { - module_ = MakeUnique(name); + HloModuleConfig config; + module_ = MakeUnique(name, config); module_->AddEntryComputation(std::move(computation)); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -84,7 +86,7 @@ class HeapSimulatorTracker { // size of the buffers doesn't matter, so we always return 0. We rely on // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by // buffer id, for determinism in the tests. - auto zero_size = [](const LogicalBuffer& buffer) { return 0; }; + auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = MakeUnique( MakeUnique(&actual_calls_)); result_ = HeapSimulator::Run( @@ -94,7 +96,8 @@ class HeapSimulatorTracker { } explicit HeapSimulatorTracker(const string& name) { - module_ = MakeUnique(name); + HloModuleConfig config; + module_ = MakeUnique(name, config); } // Similar to the single entry computation constructor above, but runs the @@ -115,9 +118,9 @@ class HeapSimulatorTracker { // Hack the size_fn so that it returns a decreasing value as we step through // the sequence. This lets us ensure the Alloc calls are in the sequence - // order. The Free calls are sorted by LogicalBuffer.id, which is at least + // order. The Free calls are sorted by BufferValue.id, which is at least // deterministic. - auto size_fn = [&reverse_position](const LogicalBuffer& buffer) { + auto size_fn = [&reverse_position](const BufferValue& buffer) { return reverse_position[buffer.instruction()]; }; auto algorithm = MakeUnique( @@ -130,8 +133,8 @@ class HeapSimulatorTracker { HloModule* module() { return module_.get(); } // Returns the buffer defined at the given instruction and index. - const LogicalBuffer* BufferAt(const HloInstruction* instruction, - const ShapeIndex& index) const { + const BufferValue* BufferAt(const HloInstruction* instruction, + const ShapeIndex& index) const { return points_to_analysis_->GetBufferDefinedAt(instruction, index) .ConsumeValueOrDie(); } @@ -147,8 +150,8 @@ class HeapSimulatorTracker { const ShapeIndex& index_a, const HloInstruction* instruction_b, const ShapeIndex& index_b) { - const LogicalBuffer* a = BufferAt(instruction_a, index_a); - const LogicalBuffer* b = BufferAt(instruction_b, index_b); + const BufferValue* a = BufferAt(instruction_a, index_a); + const BufferValue* b = BufferAt(instruction_b, index_b); EXPECT_EQ(result_.chunk_map[a].offset, result_.chunk_map[b].offset) << *a << ", " << *b; } @@ -522,7 +525,7 @@ TEST_F(HeapSimulatorTest, WholeModule) { // Now the final cond less-than buffer is allocated. {kAlloc, tracker.BufferAt(cond_lt, {})}, - // The order of the remaining Free calls is based on the LogicalBuffer.id, + // The order of the remaining Free calls is based on the BufferValue.id, // which is deterministic, but not obvious. 
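The test helpers above lean on the size function: zero_size makes every buffer equal so DecreasingSizeRunsHeap falls back to buffer-id order, while the size_fn hack returns the reverse schedule position so that sorting a run by decreasing size reproduces the schedule order. A small, self-contained illustration of that second trick:

#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Buffers in schedule order, with strictly decreasing "sizes" as produced
  // by the test's size_fn hack (reverse position in the instruction sequence).
  std::vector<std::pair<std::string, int>> run = {{"a", 3}, {"b", 2}, {"c", 1}};
  // Sorting the run by size, largest first, then replays the original order.
  std::stable_sort(run.begin(), run.end(),
                   [](const auto& x, const auto& y) { return x.second > y.second; });
  for (const auto& [name, size] : run) std::cout << name << " ";
  std::cout << "\n";  // Prints: a b c
  return 0;
}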
{kFree, tracker.BufferAt(param, {})}, {kFree, tracker.BufferAt(param, {0})}, @@ -544,40 +547,40 @@ TEST_F(HeapSimulatorTest, WholeModule) { class HeapAlgorithmTestBase : public ::testing::Test { protected: HeapAlgorithmTestBase() : builder_("heap_simulator_test") { - buffer_a_ = DummyLogicalBuffer(); - buffer_b_ = DummyLogicalBuffer(); - buffer_c_ = DummyLogicalBuffer(); - buffer_d_ = DummyLogicalBuffer(); - buffer_e_ = DummyLogicalBuffer(); - buffer_f_ = DummyLogicalBuffer(); - buffer_g_ = DummyLogicalBuffer(); - buffer_h_ = DummyLogicalBuffer(); - buffer_i_ = DummyLogicalBuffer(); + buffer_a_ = DummyBufferValue(); + buffer_b_ = DummyBufferValue(); + buffer_c_ = DummyBufferValue(); + buffer_d_ = DummyBufferValue(); + buffer_e_ = DummyBufferValue(); + buffer_f_ = DummyBufferValue(); + buffer_g_ = DummyBufferValue(); + buffer_h_ = DummyBufferValue(); + buffer_i_ = DummyBufferValue(); } ~HeapAlgorithmTestBase() override {} - const LogicalBuffer* buffer_a_; - const LogicalBuffer* buffer_b_; - const LogicalBuffer* buffer_c_; - const LogicalBuffer* buffer_d_; - const LogicalBuffer* buffer_e_; - const LogicalBuffer* buffer_f_; - const LogicalBuffer* buffer_g_; - const LogicalBuffer* buffer_h_; - const LogicalBuffer* buffer_i_; + const BufferValue* buffer_a_; + const BufferValue* buffer_b_; + const BufferValue* buffer_c_; + const BufferValue* buffer_d_; + const BufferValue* buffer_e_; + const BufferValue* buffer_f_; + const BufferValue* buffer_g_; + const BufferValue* buffer_h_; + const BufferValue* buffer_i_; private: - // Create a dummy LogicalBuffer to pass to the heap algorithm. - const LogicalBuffer* DummyLogicalBuffer() { - const LogicalBuffer::Id id = buffers_.size(); + // Create a dummy BufferValue to pass to the heap algorithm. + const BufferValue* DummyBufferValue() { + const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - buffers_.emplace_back(MakeUnique(const0, ShapeIndex{}, id)); + buffers_.emplace_back(MakeUnique(id, const0, ShapeIndex{})); return buffers_.back().get(); } HloComputation::Builder builder_; - std::vector> buffers_; + std::vector> buffers_; }; class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {}; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 8fd7f8945c7..1f7c1cffd32 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -147,6 +147,9 @@ message HloInstructionProto { repeated int64 called_computation_ids = 38; xla.OpSharding sharding = 40; + + // Backend configuration for the instruction. Has backend-specific meaning. + string backend_config = 43; } // Serialization of HloComputation. @@ -296,3 +299,20 @@ message HloProto { HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } + +// Encapsulates HloProto together with the arguments, result, and +// execution_platform. This message is used for purposes such as +// analysis/replay/file-storage. +message HloSnapshot { + // The hlo graph. + HloProto hlo = 1; + + // The arguments passed to the graph. + repeated LiteralProto arguments = 2; + + // The result of the graph. + LiteralProto result = 3; + + // The name of the platform used to run the graph. 
+ string execution_platform = 4; +} diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 594413e88fb..63c3dc4a593 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -347,6 +347,11 @@ std::list HloComputation::MakeEmbeddedComputationsList() // To avoid special handling of this computation, cast away const of // 'this'. 'this' is immediately removed from the post order after // construction. + // + // TODO(b/78350259): This violates const-correctness, since while the original + // computation is not returned, we still retrieve non-const computations from + // a const one. Consider also avoiding const for HloComputation, or review XLA + // for const-correctness of non-HloInstruction* types like this. ComputeComputationPostOrder(const_cast(this), &visited, &post_order); @@ -360,25 +365,38 @@ std::list HloComputation::MakeEmbeddedComputationsList() string HloComputation::ToString(const HloPrintOptions& options) const { std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << " "; } - if (options.print_percent()) { - s << "%"; - } - s << name(); - if (options.print_program_shape()) { - s << " " << ShapeUtil::HumanString(ComputeProgramShape()); - } - s << " {\n"; - for (const HloInstruction* instruction : MakeInstructionPostOrder()) { - for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + + if (!options.is_in_nested_computation()) { + if (options.print_percent()) { + s << "%"; } - s << " " << (instruction == root_instruction_ ? "ROOT " : "") - << instruction->ToString(options) << "\n"; + s << name() << " "; } + + if (options.print_program_shape()) { + s << ShapeUtil::HumanString(ComputeProgramShape()) << " "; + } + s << "{\n"; + { + // Print the instructions in this computation. + HloPrintOptions new_options = options; + new_options.set_indent_amount(options.indent_amount() + 1) + .set_is_in_nested_computation(true); + CanonicalNameMap name_map; + for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (int i = 0; i < new_options.indent_amount(); i++) { + s << " "; + } + s << (instruction == root_instruction_ ? 
"ROOT " : "") + << instruction->ToStringWithCanonicalNameMap(new_options, &name_map) + << "\n"; + } + } + for (int i = 0; i < options.indent_amount(); i++) { - s << " "; + s << " "; } s << "}"; return s.str(); @@ -402,27 +420,37 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( - HloModule* module, const HloComputationProto& proto, + const HloComputationProto& proto, const tensorflow::gtl::FlatMap& computation_map) { - std::vector> instructions; tensorflow::gtl::FlatMap instruction_map; + tensorflow::gtl::FlatMap to_proto_id; + std::vector> instructions; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { TF_ASSIGN_OR_RETURN( std::unique_ptr instruction, - HloInstruction::CreateFromProto(module, instruction_proto, - instruction_map, computation_map)); + HloInstruction::CreateFromProto(instruction_proto, instruction_map, + computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id())); instruction_map[instruction_proto.id()] = instruction.get(); + to_proto_id[instruction.get()] = instruction_proto.id(); instructions.push_back(std::move(instruction)); } TF_RET_CHECK(proto.root_id() != -1); TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id())); HloInstruction* root = instruction_map.at(proto.root_id()); + + // Sort the instructions in the proto id's order. + std::sort(instructions.begin(), instructions.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + return WrapUnique(new HloComputation(proto.name(), parameter_count, &instructions, root, /*fusion_instruction=*/nullptr)); @@ -723,18 +751,25 @@ Status HloComputation::Accept( return this->Accept(&visitor); } -std::unique_ptr HloComputation::Clone(const string& suffix, - HloModule* module) { +std::unique_ptr HloComputation::Clone( + const string& suffix, HloModule* module, + HloInstruction::CloneMap* clone_map) { return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - module, suffix); + module, clone_map, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloModule* module, const string& suffix) { + HloModule* module, HloInstruction::CloneMap* clone_map, + const string& suffix) { + HloInstruction::CloneMap local_clone_map; + if (clone_map == nullptr) { + clone_map = &local_clone_map; + } + // Look up instr in the replacements map, and return either the replacement, // or instr, if the replacement isn't present. // @@ -756,24 +791,19 @@ std::unique_ptr HloComputation::CloneWithReplacements( } } - std::unordered_map clone_map; std::vector> instructions; std::unique_ptr new_instr = nullptr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); - // If replaced_operand is null, that means 'replacements' asked us not to - // include operand in the new computation. But we can't do that, because - // operand is used by instr. 
CHECK_NE(replaced_operand, nullptr) - << "replacements map tried to eliminate a used instruction " - << operand->ToString() << ", used by " << instr->ToString(); - new_operands.push_back(FindOrDie(clone_map, replaced_operand)); + << "Replacements map specifies to leave out " << operand->ToString() + << ", but it is used by " << instr->ToString() << "."; + new_operands.push_back(FindOrDie(*clone_map, replaced_operand)); } - new_instr = - instr->CloneWithNewOperands(instr->shape(), new_operands, module); - InsertOrDie(&clone_map, instr, new_instr.get()); + new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands, + module, clone_map); instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); @@ -781,27 +811,24 @@ std::unique_ptr HloComputation::CloneWithReplacements( builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction()))); + /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { - HloInstruction* new_instr = FindOrDie(clone_map, instr); + HloInstruction* new_instr = FindOrDie(*clone_map, instr); for (auto successor : instr->control_successors()) { auto replaced_successor = replace(successor); - - // successor may not be in clone_map, because it might have been - // removed by the replacements map. - if (replaced_successor == nullptr) { - continue; - } + CHECK_NE(replaced_successor, nullptr) + << "Replacements map specifies to leave out " << successor->ToString() + << ", but it is control-depended-on by " << instr->ToString() << "."; TF_CHECK_OK(new_instr->AddControlDependencyTo( - FindOrDie(clone_map, replaced_successor))); + FindOrDie(*clone_map, replaced_successor))); } } // We cloned the elements of 'replacements', so they're all going to be - // destroyed. HloInstructions need to be detached from their operands before + // destroyed. HloInstructions need to be detached from their operands before // they're destroyed, otherwise they stick around in the operands' users lists // and cause use-after-frees. for (auto& kv : replacements) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 9d3f6e9a2c2..8bc97df0365 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -49,9 +49,20 @@ class HloModule; // Describes a computation at the HLO level. // -// An HloComputation contains a directed acyclic graph of HLO instructions. The -// computation has a single root instruction which produces the output of the -// computation. +// You can think of an HloComputation like a function. It has some inputs +// (parameters) and returns exactly one value (the value of its root node). If +// you want to return multiple values, you can return a tuple. +// +// The instructions inside of a computation do not have an explicit total order. +// Instead, they have a partial order determined by their data and control +// dependencies. +// +// An HloModule contains one "entry computation" -- this is like main() in a C +// program. Every other computation inside of a module is attached to one or +// more HloInstructions, as a "nested computation". For example, the kMap +// instruction has a nested computation and "applies" it to every element of its +// input, elementwise. (That is, the input [x, y, z] is transformed to [f(x), +// f(y), f(z)].) 
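// To make the "computation as a function" analogy above concrete, here is a
// minimal, hedged sketch (it only reuses the Builder and HloInstruction
// factory calls that already appear in the tests touched by this change; the
// names and shapes are made up) of an entry computation that takes two f32[4]
// parameters and "returns" two values via a tuple root:
//
//   HloComputation::Builder b("add_and_mul");
//   Shape vec = ShapeUtil::MakeShape(F32, {4});
//   HloInstruction* x =
//       b.AddInstruction(HloInstruction::CreateParameter(0, vec, "x"));
//   HloInstruction* y =
//       b.AddInstruction(HloInstruction::CreateParameter(1, vec, "y"));
//   HloInstruction* sum = b.AddInstruction(
//       HloInstruction::CreateBinary(vec, HloOpcode::kAdd, x, y));
//   HloInstruction* prod = b.AddInstruction(
//       HloInstruction::CreateBinary(vec, HloOpcode::kMultiply, x, y));
//   b.AddInstruction(HloInstruction::CreateTuple({sum, prod}));  // root
//   module->AddEntryComputation(b.Build());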
class HloComputation { public: // Builder class for HloComputation. @@ -157,14 +168,12 @@ class HloComputation { // Creates a computation from the given proto. Arguments: // - // module: the module which will contain the computation. The newly created - // computation is *not* added to the module, however. // proto: the proto to convert from. // computation_map: a map from computation id to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. static StatusOr> CreateFromProto( - HloModule* module, const HloComputationProto& proto, + const HloComputationProto& proto, const tensorflow::gtl::FlatMap& computation_map); // Gets the instructions in this computation. @@ -291,11 +300,17 @@ class HloComputation { const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - // If the module pointer is not nullptr, it will be the module where - // the cloned computations will be added to (in order to support deep - // cloning). - std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr); + // + // If the module pointer is not nullptr, then the cloned computations will be + // added to this module in order to support deep cloning. Otherwise the module + // of the computation is used. + // + // If clone_map is not nullptr, then each original instruction that is cloned + // will be inserted and map to its clone. clone_map should not already contain + // any of the instructions to clone. + std::unique_ptr Clone( + const string& suffix = "clone", HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -305,7 +320,9 @@ class HloComputation { std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloModule* module = nullptr, const string& suffix = "clone"); + HloModule* module = nullptr, + HloInstruction::CloneMap* clone_map = nullptr, + const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. 
// Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 7b7588f4ba9..25469a54c48 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -550,6 +550,108 @@ TEST_F(HloComputationTest, Reachability) { EXPECT_FALSE(reachability->IsReachable(constant2, copy)); } +TEST_F(HloComputationTest, Stringification) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(computation->ToString(options), + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + +TEST_F(HloComputationTest, StringificationIndent) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = + HloPrintOptions().set_print_metadata(false).set_indent_amount(2); + EXPECT_EQ(computation->ToString(options), + R"( %TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })"); +} + +TEST_F(HloComputationTest, StringificationCanonical) { + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout 
= ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto options = HloPrintOptions().set_print_metadata(false); + EXPECT_EQ(computation->ToString(options), + R"(%TransposeDot (x: f32[5,10], y: f32[20,10]) -> f32[5,20] { + %x = f32[5,10]{1,0} parameter(0) + %y = f32[20,10]{1,0} parameter(1) + %transpose = f32[10,20]{1,0} transpose(f32[20,10]{1,0} %y), dimensions={1,0} + ROOT %dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} %transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); + + options = HloPrintOptions().Canonical(); + EXPECT_EQ(computation->ToString(options), R"(TransposeDot { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 7aa38c6b79e..35ecd4428d0 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -69,8 +69,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { // Broadcasts dramatically increase the size of constants, which is often // detrimental to performance and memory capacity, so do not fold // broadcasts. 
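// (A rough worked example of the blow-up this skips: folding a broadcast of a
// single f32 scalar constant into shape f32[1024,1024] would replace a 4-byte
// literal with a 1024 * 1024 * 4 bytes = 4 MiB one.)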
- if (instruction->opcode() == HloOpcode::kBroadcast || - instruction->opcode() == HloOpcode::kBroadcastDimOne) { + if (instruction->opcode() == HloOpcode::kBroadcast) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 7b552ee5b17..5d05ccfc0b2 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) { const int64 slice_limits[] = {10, 8, 6, 5, 9}; const int64 slice_strides[] = {1, 1, 1, 1, 1}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index ea4dd62fdb5..94c9c7eabcc 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -142,19 +142,25 @@ Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) { } Status HloCostAnalysis::HandleParameter(const HloInstruction*) { + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleConstant(const HloInstruction*) { + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { // GetTupleElement forwards a pointer and does not touch each element in the // output. + current_should_compute_bottleneck_time_ = false; current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } @@ -329,6 +335,7 @@ Status HloCostAnalysis::HandleSelectAndScatter( Status HloCostAnalysis::HandleBitcast(const HloInstruction*) { // A bitcast does no computation and touches no memory. current_properties_[kBytesAccessedKey] = 0; + current_properties_[kOptimalSecondsKey] = 0; return Status::OK(); } @@ -336,11 +343,6 @@ Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleBroadcastDimOne( - const HloInstruction* broadcastDimOne) { - return Status::OK(); -} - Status HloCostAnalysis::HandlePad(const HloInstruction*) { return Status::OK(); } @@ -560,11 +562,13 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) { } Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { - // We can't do anything sane with CustomCalls, since we don't know what they - // do, and returning an error status will stop iteration over this - // computation, which is probably also not what we want. So just punt and - // return OK. 
This will cause all of the properties to be reported as 0, - // which is fine. + // Mark applicable fields as "unknown", since we don't know what CustomCall + // does. This is better than returning an error, which would stop iteration, + // and therefore would prevent us from getting *any* stats for a computation + // which contains a CustomCall. + current_properties_[kOptimalSecondsKey] = -1; + current_properties_[kBytesAccessedKey] = -1; + current_properties_[kFlopsKey] = -1; current_should_compute_bottleneck_time_ = false; return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index a9f6845747a..d17678d20f2 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -95,7 +95,6 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleSelectAndScatter(const HloInstruction* instruction) override; Status HandleBitcast(const HloInstruction* bitcast) override; Status HandleBroadcast(const HloInstruction* broadcast) override; - Status HandleBroadcastDimOne(const HloInstruction* broadcastDimOne) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; Status HandleTranspose(const HloInstruction* transpose) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 3d055b327ee..16fdda8a8b9 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -20,16 +20,13 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/logging.h" @@ -58,11 +55,10 @@ class HloCostAnalysisTest : public ::testing::Test { // whitebox accesses to the user computation built from the client, // as shown in the BuildHloGraph functions below. 
service_(static_cast(ClientLibrary::GetXlaService( - static_cast(client_)->platform()))), - computation_tracker_(service_->computation_tracker()) { + static_cast(client_)->platform()))) { // Create a computation for a unary user function: x => exp(x + 0.5) { - ComputationBuilder builder(client_, "add_and_exp"); + XlaBuilder builder("add_and_exp"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto half = builder.ConstantR0(0.5); builder.Exp(builder.Add(x, half)); @@ -73,7 +69,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary user function: (x, y) => x + y { - ComputationBuilder builder(client_, "add"); + XlaBuilder builder("add"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Add(x, y); @@ -84,7 +80,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x)) { - ComputationBuilder builder(client_, "sigmoid"); + XlaBuilder builder("sigmoid"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto one = builder.ConstantR0(1.0); builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x)))); @@ -95,7 +91,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary max function: (x, y) => max (x, y) { - ComputationBuilder builder(client_, "max"); + XlaBuilder builder("max"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Max(x, y); @@ -106,7 +102,7 @@ class HloCostAnalysisTest : public ::testing::Test { // Create a computation for a binary GT function: (x, y) => x > y { - ComputationBuilder builder(client_, "gt"); + XlaBuilder builder("gt"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); builder.Gt(x, y); @@ -117,35 +113,30 @@ class HloCostAnalysisTest : public ::testing::Test { } // Build HLO graph from the given builder and return the HLO module. - std::unique_ptr BuildHloGraph(ComputationBuilder* builder) { + std::unique_ptr BuildHloGraph(XlaBuilder* builder) { auto computation_status = builder->Build(); TF_CHECK_OK(computation_status.status()); auto computation = computation_status.ConsumeValueOrDie(); - auto user_computation_status = - computation_tracker_.Resolve(computation.handle()); - TF_CHECK_OK(user_computation_status.status()); - auto user_computation = user_computation_status.ConsumeValueOrDie(); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - return std::move( - computation_tracker_.BuildHloModule(versioned_handle, HloModuleConfig()) - .ValueOrDie()); + auto config = HloModule::CreateModuleConfigFromProto(computation.proto(), + DebugOptions()) + .ConsumeValueOrDie(); + return HloModule::CreateFromProto(computation.proto(), config) + .ConsumeValueOrDie(); } Client* client_; Service* service_; - const ComputationTracker& computation_tracker_; // User computations used for higher order operations (e.g., Map, Reduce). 
- Computation add_; - Computation add_and_exp_; - Computation sigmoid_; - Computation max_; - Computation gt_; + XlaComputation add_; + XlaComputation add_and_exp_; + XlaComputation sigmoid_; + XlaComputation max_; + XlaComputation gt_; }; TEST_F(HloCostAnalysisTest, MatrixMultiply) { - ComputationBuilder builder(client_, "matrix_multiply"); + XlaBuilder builder("matrix_multiply"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs"); auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs"); auto result = builder.Dot(lhs, rhs); @@ -167,7 +158,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { } TEST_F(HloCostAnalysisTest, Map) { - ComputationBuilder builder(client_, "map"); + XlaBuilder builder("map"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); auto result = builder.Map({input}, add_and_exp_, {0}); @@ -184,7 +175,7 @@ TEST_F(HloCostAnalysisTest, Map) { } TEST_F(HloCostAnalysisTest, Convolution) { - ComputationBuilder builder(client_, "convolution"); + XlaBuilder builder("convolution"); auto input = builder.Parameter( 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, @@ -213,7 +204,7 @@ TEST_F(HloCostAnalysisTest, Convolution) { } TEST_F(HloCostAnalysisTest, Reduce) { - ComputationBuilder builder(client_, "reduce"); + XlaBuilder builder("reduce"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto result = @@ -231,7 +222,7 @@ TEST_F(HloCostAnalysisTest, Reduce) { } TEST_F(HloCostAnalysisTest, ReduceWindow) { - ComputationBuilder builder(client_, "reduce_window"); + XlaBuilder builder("reduce_window"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto result = builder.ReduceWindow(input, builder.ConstantR0(0), add_, @@ -248,7 +239,7 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) { } TEST_F(HloCostAnalysisTest, SelectAndScatter) { - ComputationBuilder builder(client_, "select_and_scatter"); + XlaBuilder builder("select_and_scatter"); auto operand = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input"); auto source = @@ -269,7 +260,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) { } TEST_F(HloCostAnalysisTest, Broadcast) { - ComputationBuilder b(client_, "broadcast"); + XlaBuilder b("broadcast"); b.Broadcast(b.ConstantR0(42), {10, 7}); auto hlo_module = BuildHloGraph(&b); HloCostAnalysis analysis(ShapeSize); @@ -280,7 +271,7 @@ TEST_F(HloCostAnalysisTest, Broadcast) { // Calculates the computation cost of a graph with more than one HLO node. 
TEST_F(HloCostAnalysisTest, FullyConnectedForward) { - ComputationBuilder builder(client_, "fully_connected_forward"); + XlaBuilder builder("fully_connected_forward"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input"); auto weight = @@ -305,7 +296,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis conv_analysis(ShapeSize); { - ComputationBuilder builder(client_, "conv_looking_matmul"); + XlaBuilder builder("conv_looking_matmul"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), "input"); auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), @@ -318,7 +309,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { HloCostAnalysis matmul_analysis(ShapeSize); { - ComputationBuilder builder(client_, "matmul"); + XlaBuilder builder("matmul"); auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input"); auto rhs = @@ -370,8 +361,8 @@ TEST_F(FusionCostAnalysis, LoopFusion) { HloInstruction::CreateBinary(r2f32, HloOpcode::kSubtract, mul, clamp)); auto tuple = HloInstruction::CreateTuple({sub, sub, mul, c1}); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -412,8 +403,8 @@ TEST_F(FusionCostAnalysis, NoLayout) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( shape_with_layout, HloOpcode::kAdd, c1, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add, broadcast}, HloInstruction::FusionKind::kLoop); @@ -427,7 +418,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { TEST_F(HloCostAnalysisTest, TupleCost) { HloCostAnalysis analysis(ShapeSize); { - ComputationBuilder builder(client_, "matmul"); + XlaBuilder builder("matmul"); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {123}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {42}), "y"); auto tuple = builder.Tuple({x, y}); @@ -443,7 +434,7 @@ TEST_F(HloCostAnalysisTest, TupleCost) { } TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { - ComputationBuilder builder(client_, "BaseDilatedConvolution"); + XlaBuilder builder("BaseDilatedConvolution"); auto input = builder.Parameter( 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10, @@ -458,7 +449,7 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { auto result = builder.ConvGeneralDilated( input, kernel, /*window_strides=*/{1, 1}, /*padding=*/{{1, 1}, {1, 1}}, /*lhs_dilation=*/{3, 5}, /*rhs_dilation=*/{7, 11}, - ComputationBuilder::CreateDefaultConvDimensionNumbers(2)); + XlaBuilder::CreateDefaultConvDimensionNumbers(2)); // Run HLO cost analysis. 
auto hlo_module = BuildHloGraph(&builder); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index b186767ce79..0fb65c845a6 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -162,7 +162,20 @@ StatusOr MakeConcatHlo(ArraySlice operands, HloInstruction::CreateConcatenate(concat_shape, operands, dimension)); } +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers) { + HloComputation* computation = lhs->parent(); + CHECK_EQ(computation, rhs->parent()); + TF_ASSIGN_OR_RETURN( + Shape dot_shape, + ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); + return computation->AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); +} + StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { + CHECK_GT(n, 0); + const Shape& operand_shape = operand->shape(); CHECK_GE(operand_shape.dimensions_size(), n); int64 new_shape_leading_bound = 1; @@ -184,6 +197,17 @@ StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n) { return MakeReshapeHlo(output_shape, operand); } +StatusOr PrependDegenerateDims(HloInstruction* operand, + int64 n) { + CHECK_GT(n, 0); + std::vector new_shape_dims; + const Shape& operand_shape = operand->shape(); + new_shape_dims.reserve(n + operand_shape.dimensions_size()); + new_shape_dims.insert(new_shape_dims.begin(), n, 1); + c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); + return MakeReshapeHlo(new_shape_dims, operand); +} + StatusOr ExpandFirstDimIntoNDims( HloInstruction* operand, ArraySlice expanded_dims) { CHECK_GT(operand->shape().dimensions_size(), 0); @@ -256,7 +280,7 @@ StatusOr BroadcastZeros( StatusOr> CreateComputationWithSignature( ArraySlice domain, const Shape& range, tensorflow::StringPiece name) { - HloComputation::Builder b(name.ToString()); + HloComputation::Builder b{std::string(name)}; int64 param_idx = 0; for (const Shape* param_shape : domain) { b.AddInstruction(HloInstruction::CreateParameter( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index d99e32a737e..49b1402d689 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -97,18 +97,33 @@ StatusOr MakeGetTupleElementHlo(HloInstruction* operand, StatusOr MakeConcatHlo( tensorflow::gtl::ArraySlice operands, int64 dimension); +// Creates a Dot HLO instruction and adds it to the computation containing `lhs` +// and `rhs` (both must be in the same computation). +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers); + // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of // these add all the instructions they generate into the computation containing // their operand(s). // Collapses (via reshape) the first N (logical) dimensions of `operand` into a -// single leading dimension. `operand` must have rank > n. +// single leading dimension. `operand` must have rank > `n` and `n` must not be +// 0. // // For instance if `operand` has shape f32[7,8,9] and n is 2 then the output is // the `operand` reshaped to [56,9]. 
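// A hedged usage sketch (the instruction name `op` is made up; the pattern
// mirrors the new hlo_creation_utils_test.cc added below): given `op` with
// shape f32[7,8,9] inside some computation,
//
//   TF_ASSIGN_OR_RETURN(HloInstruction * collapsed, CollapseFirstNDims(op, 2));
//   // `collapsed` has shape f32[56,9] and was added to op's computation.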
StatusOr CollapseFirstNDims(HloInstruction* operand, int64 n); +// Prepends `n` degenerate dimensions (dimensions with bound = 1) to `operand` +// using a reshape. +// +// For instance if operand has shape f32[3,4,5] then this returns the operand +// reshaped to f32[1,3,4,5]. If the operand is a f32 scalar (i.e. has shape +// f32[]) then this returns the operand reshaped to f32[1]. +StatusOr PrependDegenerateDims(HloInstruction* operand, + int64 n); + // Expands (via reshape) the first (logical) dimension of `operand` into a // sequence of `expanded_dims` dimensions. `operand` must at least be of rank 1 // and the number of elements in its first dimension must be equal to the diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc new file mode 100644 index 00000000000..7e7c4f95fed --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -0,0 +1,239 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { +using tensorflow::gtl::ArraySlice; + +class HloCreationUtilsTest : public HloTestBase { + protected: + static std::unique_ptr CreateModuleWithProgramShape( + PrimitiveType primitive_type, ArraySlice input_shape_dims, + ArraySlice output_shape_dims, HloInstruction** param, + HloComputation** entry_computation) { + Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims); + Shape output_shape = + ShapeUtil::MakeShape(primitive_type, output_shape_dims); + auto module = CreateNewModule("test"); + *entry_computation = module->AddEntryComputation( + CreateComputationWithSignature({&input_shape}, output_shape, "entry") + .ValueOrDie()); + *param = (*entry_computation)->parameter_instruction(0); + return module; + } +}; + +TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{2}, /*output_shape_dims=*/{2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed, + CollapseFirstNDims(param, 1)); + entry_computation->set_root_instruction(first_1_dims_collapsed); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR1({3, 4})})); + CHECK_EQ(*result_literal, *Literal::CreateR1({3, 4})); +} + +TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) { + 
HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_2_dims_collapsed, + CollapseFirstNDims(param, 2)); + entry_computation->set_root_instruction(first_2_dims_collapsed); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, + {Literal::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{-1, -2}, {-3, -4}, {-5, -6}}})})); + CHECK_EQ(*result_literal, + *Literal::CreateR2( + {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}})); +} + +TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended, + PrependDegenerateDims(param, 1)); + entry_computation->set_root_instruction(with_1_degenerate_dim_prepended); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR1({9, 10})})); + CHECK_EQ(*result_literal, *Literal::CreateR2({{9, 10}})); +} + +TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{2}, /*output_shape_dims=*/{1, 1, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, + PrependDegenerateDims(param, 2)); + entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR1({9, 10})})); + CHECK_EQ(*result_literal, *Literal::CreateR3({{{9, 10}}})); +} + +TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{}, /*output_shape_dims=*/{1, 1}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended, + PrependDegenerateDims(param, 2)); + entry_computation->set_root_instruction(with_2_degenerate_dims_prepended); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR0(9)})); + CHECK_EQ(*result_literal, *Literal::CreateR2({{9}})); +} + +TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{6}, /*output_shape_dims=*/{3, 1, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_dim_expanded, + ExpandFirstDimIntoNDims(param, {3, 1, 2})); + entry_computation->set_root_instruction(first_dim_expanded); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR1({1, 2, 3, 4, 5, 6})})); + CHECK_EQ(*result_literal, + *Literal::CreateR3({{{1, 2}}, {{3, 4}}, {{5, 6}}})); +} + +TEST_F(HloCreationUtilsTest, PadVectorWithZeros) { + HloInstruction* param; + HloComputation* 
entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{2}, /*output_shape_dims=*/{6}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN( + HloInstruction * zero_padded_param, + PadVectorWithZeros(param, /*zeros_to_prepend=*/3, /*zeros_to_append=*/1)); + entry_computation->set_root_instruction(zero_padded_param); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR1({3, 4})})); + CHECK_EQ(*result_literal, *Literal::CreateR1({0, 0, 0, 3, 4, 0})); +} + +TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + S32, + /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN( + HloInstruction * zeros, + BroadcastZeros(module->entry_computation(), S32, {2, 2})); + entry_computation->set_root_instruction(zeros); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR0(0)})); + CHECK_EQ(*result_literal, *Literal::CreateR2({{0, 0}, {0, 0}})); +} + +TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) { + HloInstruction* param; + HloComputation* entry_computation; + + std::unique_ptr module = CreateModuleWithProgramShape( + F32, + /*input_shape_dims=*/{}, /*output_shape_dims=*/{2, 2}, ¶m, + &entry_computation); + + TF_ASSERT_OK_AND_ASSIGN( + HloInstruction * zeros, + BroadcastZeros(module->entry_computation(), F32, {2, 2})); + entry_computation->set_root_instruction(zeros); + + HloEvaluator evaluator; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, + evaluator.Evaluate>( + *module, {Literal::CreateR0(0.0f)})); + CHECK_EQ(*result_literal, + *Literal::CreateR2({{0.0f, 0.0f}, {0.0f, 0.0f}})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index cd7cbbdd717..c17c26c5a43 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { @@ -88,6 +89,20 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { return changed; } +// An instruction is considered to be equivalent to another only if they +// share the exact same set of operands. +int64 CseHash(const HloInstruction* instruction) { + int64 hash = std::hash()(static_cast(instruction->opcode())); + hash = tensorflow::Hash64Combine( + hash, instruction->opcode() == HloOpcode::kGetTupleElement + ? 
instruction->tuple_index() + : -1); + for (auto operand : instruction->operands()) { + hash = tensorflow::Hash64Combine(hash, operand->unique_id()); + } + return hash; +} + } // namespace StatusOr HloCSE::Run(HloModule* module) { @@ -95,17 +110,32 @@ StatusOr HloCSE::Run(HloModule* module) { const std::function eq_instructions = std::equal_to(); const std::function - eq_computations = std::equal_to(); + eq_computations = [](const HloComputation* lhs, + const HloComputation* rhs) { return *lhs == *rhs; }; + + auto cse_equal = [&](const HloInstruction* lhs, const HloInstruction* rhs) { + return lhs->Identical(*rhs, eq_instructions, eq_computations, + is_layout_sensitive_); + }; + for (auto* computation : module->computations()) { + if (only_fusion_computations_ && !computation->IsFusionComputation()) { + continue; + } + changed |= CombineConstants(computation, is_layout_sensitive_); - std::list post_order = - computation->MakeInstructionPostOrder(); - std::set removed_instructions; - for (auto instruction : post_order) { - // If the instruction has already been removed by CSE skip over it. - if (removed_instructions.count(instruction) > 0 || - instruction->operand_count() == 0) { + // HLO instructions are grouped into equivalency classes by using the + // cse_equal predicate defined above. This set holds a representative + // instruction for each class. + tensorflow::gtl::FlatSet + representatives(/*N=*/1024, &CseHash, cse_equal); + + for (auto instruction : computation->MakeInstructionPostOrder()) { + // If the instruction has zero operands (constants, parameters, etc.) skip + // over it. + if (instruction->operand_count() == 0) { continue; } @@ -114,31 +144,16 @@ StatusOr HloCSE::Run(HloModule* module) { continue; } - // An instruction is considered to be equivalent to another only if they - // share the exact same set of operands. So to find equivalent - // instructions, we just search among instructions which share operand(0) - // of this instruction. - const HloInstruction* operand = instruction->operand(0); - - tensorflow::gtl::InlinedVector - equivalent_instructions; - for (HloInstruction* user : operand->users()) { - if (user != instruction && !user->HasSideEffect() && - user->Identical(*instruction, eq_instructions, eq_computations, - is_layout_sensitive_)) { - equivalent_instructions.push_back(user); - } - } - - // Replace all equivalent instructions with this instruction. - for (HloInstruction* equivalent_instruction : equivalent_instructions) { + auto it = representatives.find(instruction); + if (it != representatives.end()) { + HloInstruction* equivalent_instruction = *it; TF_RETURN_IF_ERROR( - equivalent_instruction->ReplaceAllUsesWith(instruction)); - TF_RETURN_IF_ERROR( - computation->RemoveInstruction(equivalent_instruction)); - removed_instructions.insert(equivalent_instruction); + instruction->ReplaceAllUsesWith(equivalent_instruction)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); changed = true; + continue; } + representatives.insert(instruction); } } return changed; diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h index 70096e07a24..5e2b348bdda 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.h +++ b/tensorflow/compiler/xla/service/hlo_cse.h @@ -29,9 +29,11 @@ class HloCSE : public HloPassInterface { public: // If is_layout_sensitive is true, then the simplifier preserves layout during // transformation. Otherwise, layout is ignored. 
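// A minimal usage sketch of the constructor flag introduced in this change
// (module setup elided; the pass-running pattern matches the updated
// hlo_cse_test.cc):
//
//   // Deduplicate only inside fusion computations, ignoring layout.
//   HloCSE cse(/*is_layout_sensitive=*/false,
//              /*only_fusion_computations=*/true);
//   TF_ASSIGN_OR_RETURN(bool changed, cse.Run(module));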
- explicit HloCSE(bool is_layout_sensitive) - : is_layout_sensitive_(is_layout_sensitive) {} - ~HloCSE() override {} + explicit HloCSE(bool is_layout_sensitive, + bool only_fusion_computations = false) + : is_layout_sensitive_(is_layout_sensitive), + only_fusion_computations_(only_fusion_computations) {} + ~HloCSE() override = default; tensorflow::StringPiece name() const override { return "cse"; } // Run CSE on the given module. Returns whether the module was changed (common @@ -39,7 +41,8 @@ class HloCSE : public HloPassInterface { StatusOr Run(HloModule* module) override; private: - bool is_layout_sensitive_; + const bool is_layout_sensitive_; + const bool only_fusion_computations_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index df8853f34f6..9735764b692 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -72,7 +73,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR0(84.0); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { @@ -104,7 +105,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { @@ -134,7 +135,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { auto result = ExecuteAndTransfer(std::move(module), {}); auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4))); } TEST_F(HloCseTest, ConstantsSameValueDifferentType) { @@ -469,5 +470,36 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { EXPECT_THAT(root, op::Add(op::Map(op::Constant()), op::Map(op::Constant()))); } +TEST_F(HloCseTest, CompareComputations) { + auto module = tools::Parse(R"( + HloModule m + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + add_computation2 { + add_lhs2 = f32[] parameter(0) + add_rhs2 = f32[] parameter(1) + ROOT add_root2 = f32[] add(add_lhs2, add_rhs2) + } + + ENTRY entry { + p = f32[10]{0} parameter(0) + c = f32[] constant(0) + r1 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation + r2 = f32[] reduce(p, c), dimensions={0}, to_apply=add_computation2 + ROOT f2 = (f32[],f32[]) tuple(r1, r2) + })") + .ValueOrDie(); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->operand(0), 
root->operand(1)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 0c37a8d75f3..b06e6c9f3e6 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -878,4 +878,128 @@ Status HloDataflowAnalysis::Verify() const { return Status::OK(); } +bool HloDataflowAnalysis::DoesNotUseOperandBuffer( + const HloInstruction* operand, const ShapeIndex& index, + const HloInstruction* user) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + // Iterate through all users of all uses of the fusion parameter value. + // Return false if any uses are detected, returns true otherwise. + const HloValue& value = GetValueDefinedAt(fusion_param, index); + return value.uses().empty(); + } else { + // Return false if no value at 'operand' and 'index' is used at 'user'. + for (const HloValue* value : GetValueSet(operand, index).values()) { + for (const HloUse& use : value->uses()) { + if (use.instruction == user) { + return false; + } + } + } + } + + return true; +} + +bool HloDataflowAnalysis::CanShareOperandBufferWithUser( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* user, const ShapeIndex& user_index) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + const Shape& operand_subshape = + ShapeUtil::GetSubshape(operand->shape(), operand_index); + const Shape& user_subshape = + ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. + if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { + return false; + } + + if (user->opcode() == HloOpcode::kFusion) { + // Get the parameter associated with 'operand'; + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + + const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); + if (value.uses().size() != 1) { + return false; + } + const HloUse& use = value.uses()[0]; + + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && + user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is kDot or kConvolution. 
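// (Concrete illustration, matching the FusedDotAdd test added below: for an
// output fusion computing add(dot(a, b), bias), the fusion result may share
// bias's buffer, because each element of bias is read exactly once, by the
// fused root itself, when the corresponding output element is produced.)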
+ auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); + if (add_operand_it == add->operands().end()) { + return false; + } + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). + return use.instruction == user->fused_expression_root() && + use.operand_number == other_add_operand_index; + } + } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; + } + if (user->opcode() == HloOpcode::kCall) { + // Get all uses of value defined by 'operand' at 'operand_index'. + const auto& uses = GetValueDefinedAt(operand, operand_index).uses(); + // Return true iff: + // *) There exists two uses of 'operand'. + // *) One use is by 'user' (caller). + // *) One use is by root instruction of called computation (callee root). + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + const bool found_caller_use = + std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + return use.instruction == user; + }) != uses.end(); + auto* callee_root = user->to_apply()->root_instruction(); + const bool found_elementwise_callee_use = + std::find_if( + uses.begin(), uses.end(), [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); + return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; + } + // Check if 'user' is element-wise. + return user->IsElementwise(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 7b8a74b096f..9868746b611 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -118,6 +118,23 @@ class HloDataflowAnalysis { string ToString() const; + // Returns true if 'user' cannot possibly use the buffer at 'index' in + // 'operand'. Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) const; + + // Returns true if 'user' (at 'user_index') can share a buffer with its + // operand 'operand' (at 'operand_index'). Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. 
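// A hedged call-site sketch (the variable names are made up; the argument
// pattern matches the new tests in hlo_dataflow_analysis_test.cc): a
// buffer-assignment style pass would typically ask
//
//   if (dataflow.CanShareOperandBufferWithUser(operand, /*operand_index=*/{},
//                                              user, /*user_index=*/{})) {
//     // safe to give 'user' the same buffer as 'operand' at these indices
//   }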
+ bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index) const; + protected: HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 07f69b8e133..5798326dcbf 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1873,5 +1873,346 @@ INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); +class HloDataflowAnalysisTestBase : public HloTestBase { + protected: + void BuildModule(std::unique_ptr computation) { + module_ = CreateNewModule(); + computation_ = module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); + dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); + } + + void BuildModuleAndRunAnalysis(std::unique_ptr computation) { + BuildModule(std::move(computation)); + RunAnalysis(); + } + + std::unique_ptr module_; + HloComputation* computation_ = nullptr; + std::unique_ptr dataflow_analysis_; +}; + +class DoesNotUseOperandBufferTest : public HloDataflowAnalysisTestBase {}; + +TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { + auto builder = HloComputation::Builder(TestName()); + + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0)); + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1)); +} + +TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. 
+  auto starts = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+  auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+  auto dynamic_update_slice =
+      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+          data_shape, gte1, update, starts));
+  builder.AddInstruction(
+      HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+
+  BuildModule(builder.Build());
+  auto fusion = computation_->CreateFusionInstruction(
+      {dynamic_update_slice, starts, update, gte1},
+      HloInstruction::FusionKind::kLoop);
+  RunAnalysis();
+
+  // The fusion instruction never uses tuple element 0, but does use element 1.
+  EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion));
+  EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion));
+}
+
+class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {};
+
+TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
+  auto builder = HloComputation::Builder(TestName());
+
+  Shape shape = ShapeUtil::MakeShape(F32, {8});
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param"));
+  auto exp = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
+  auto log = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp));
+
+  BuildModuleAndRunAnalysis(builder.Build());
+
+  EXPECT_TRUE(
+      dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {}));
+  EXPECT_TRUE(
+      dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
+  auto builder = HloComputation::Builder(TestName());
+
+  Shape in_shape = ShapeUtil::MakeShape(F32, {8});
+  Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, in_shape, "param0"));
+  auto param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, in_shape, "param1"));
+  auto result = builder.AddInstruction(
+      HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1));
+
+  BuildModuleAndRunAnalysis(builder.Build());
+
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+                                                                 result, {}));
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+                                                                 result, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
+  auto builder = HloComputation::Builder(TestName());
+
+  Shape shape = ShapeUtil::MakeShape(F32, {8});
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param"));
+  auto exp = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kExp, param));
+  auto copy = builder.AddInstruction(
+      HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp));
+
+  BuildModuleAndRunAnalysis(builder.Build());
+
+  EXPECT_TRUE(
+      dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {}));
+  EXPECT_TRUE(
+      dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
+  auto builder = HloComputation::Builder(TestName());
+
+  Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+  auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+      0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+  // Create a DynamicUpdateSlice instruction of tuple element 1.
+  auto starts = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+  auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+  auto dynamic_update_slice =
+      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+          data_shape, gte1, update, starts));
+  builder.AddInstruction(
+      HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+
+  BuildModule(builder.Build());
+  auto fusion = computation_->CreateFusionInstruction(
+      {dynamic_update_slice, starts, update, gte1},
+      HloInstruction::FusionKind::kLoop);
+  RunAnalysis();
+
+  // The fusion instruction can share with tuple element 1.
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {0},
+                                                                 fusion, {}));
+  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(tuple, {1},
+                                                                fusion, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
+  auto builder = HloComputation::Builder(TestName());
+
+  Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+  Shape update_shape = ShapeUtil::MakeShape(F32, {4});
+  Shape starts_shape = ShapeUtil::MakeShape(S32, {1});
+  auto data = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, data_shape, "data"));
+  auto update = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, update_shape, "update"));
+  auto starts = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, starts_shape, "starts"));
+  auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+      data_shape, data, update, starts));
+
+  BuildModuleAndRunAnalysis(builder.Build());
+
+  // The DynamicUpdateSlice instruction can share with the data operand, but not
+  // with update or starts.
+  EXPECT_TRUE(
+      dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {}));
+  EXPECT_FALSE(
+      dataflow_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {}));
+  EXPECT_FALSE(
+      dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+  auto a = builder.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
+  auto b = builder.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  auto dot = builder.AddInstruction(
+      HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
+
+  auto one = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+  auto add_operand = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(data_shape, one, {1}));
+
+  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+      data_shape, HloOpcode::kAdd, dot, add_operand));
+
+  BuildModule(builder.Build());
+  auto fusion = computation_->CreateFusionInstruction(
+      {add, dot}, HloInstruction::FusionKind::kOutput);
+  RunAnalysis();
+
+  // Output fused dot add should be able to share buffer with 'add_operand'.
+  EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(add_operand, {},
+                                                                fusion, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
+  auto builder = HloComputation::Builder(TestName());
+  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+  auto one = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+  auto operand = builder.AddInstruction(
+      HloInstruction::CreateBroadcast(data_shape, one, {1}));
+
+  auto reverse = builder.AddInstruction(
+      HloInstruction::CreateReverse(data_shape, operand, {0, 1}));
+
+  auto two = builder.AddInstruction(HloInstruction::CreateConstant(
+      Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+
+  auto add = builder.AddInstruction(
+      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));
+
+  BuildModule(builder.Build());
+  auto fusion = computation_->CreateFusionInstruction(
+      {add, two, reverse}, HloInstruction::FusionKind::kOutput);
+  RunAnalysis();
+
+  // Output fused operand->reverse->add cannot alias operand buffer 'operand'.
+  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
+                                                                 fusion, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
+  Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+
+  auto make_cond = [this, &data_shape]() {
+    auto builder = HloComputation::Builder(TestName() + ".Cond");
+    auto data = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, data_shape, "data"));
+    builder.AddInstruction(HloInstruction::CreateBinary(
+        ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data));
+    return builder.Build();
+  };
+
+  auto make_body = [this, &data_shape]() {
+    auto builder = HloComputation::Builder(TestName() + ".Body");
+    auto data = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, data_shape, "data"));
+    builder.AddInstruction(
+        HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data));
+    return builder.Build();
+  };
+
+  module_ = CreateNewModule();
+  HloComputation* cond_computation =
+      module_->AddEmbeddedComputation(make_cond());
+  HloComputation* body_computation =
+      module_->AddEmbeddedComputation(make_body());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto data = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, data_shape, "data"));
+  auto whil = builder.AddInstruction(HloInstruction::CreateWhile(
+      data_shape, cond_computation, body_computation, data));
+  computation_ = module_->AddEntryComputation(builder.Build());
+
+  RunAnalysis();
+
+  // The While instruction can share with the data operand.
+  EXPECT_TRUE(
+      dataflow_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {}));
+}
+
+// Tests that Call can alias operand buffer if the only use of the operand
+// in the called computation is an elementwise instruction.
+TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
+  Shape shape = ShapeUtil::MakeShape(F32, {8});
+  // Build sub-computation with fusion root.
+  auto sub_builder = HloComputation::Builder(TestName() + "_sub");
+  auto sub_param = sub_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "sub_param"));
+  auto one = sub_builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+  auto ones = sub_builder.AddInstruction(
+      HloInstruction::CreateBroadcast(shape, one, {1}));
+  auto add = sub_builder.AddInstruction(
+      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
+
+  module_ = CreateNewModule();
+  auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build());
+  sub_computation->CreateFusionInstruction({add, ones},
+                                           HloInstruction::FusionKind::kLoop);
+
+  // Build entry-computation with kCall which calls 'sub_computation'.
+  auto builder = HloComputation::Builder(TestName());
+
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape, "param"));
+  auto reverse =
+      builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
+  auto call = builder.AddInstruction(
+      HloInstruction::CreateCall(shape, {reverse}, sub_computation));
+  computation_ = module_->AddEntryComputation(builder.Build());
+
+  RunAnalysis();
+
+  EXPECT_TRUE(
+      dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {}));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
index c782d1b0add..d236f83aeb9 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -178,24 +178,37 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
       if (hlo->shape().element_type() == eliminate_type_) {
         Shape shape =
             ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_);
+
         new_hlo = computation->AddInstruction(
             hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule()));
+        TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
+
         new_hlo = ToElementType(new_hlo, eliminate_type_);
       } else if (ShapeUtil::IsTuple(hlo->shape())) {
         Shape old_shape = hlo->shape();
         Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_,
                                                  replace_with_type_);
+
         new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands(
             new_shape, new_operands, hlo->GetModule()));
+        TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
+
         // Convert the elements of the result of `new_hlo` to produce a new
         // tuple with shape `old_shape`.
         new_hlo = ConvertTupleElements(new_hlo, old_shape);
       } else {
         new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands(
             hlo->shape(), new_operands, hlo->GetModule()));
+        TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
       }
-      TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo));
+      TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo));
+      TF_RETURN_IF_ERROR(hlo->DropAllControlDeps());
+
+      // NB! We want to replace and remove side-effecting instructions like Rng
+      // as well, so we can't rely on HloComputation::ReplaceInstruction to
+      // reliably remove the replaced instruction.
+      TF_RETURN_IF_ERROR(computation->RemoveInstruction(hlo));
       changed = true;
     }
   }
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
index cb94d9f19b8..5c5a059e0fd 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
@@ -22,6 +22,12 @@ namespace {
 namespace op = xla::testing::opcode_matchers;
 
+using ::testing::Contains;
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Not;
+using ::testing::ResultOf;
+
 class HloElementTypeConverterTest : public HloTestBase {
  public:
   std::unique_ptr<HloModule> CreateModuleFromHloString(
@@ -117,5 +123,65 @@ TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) {
                           op::Convert(op::GetTupleElement(batch_norm, 2))));
 }
 
+TEST_F(HloElementTypeConverterTest, RngIsRemoved) {
+  const string& hlo_string = R"(
+HloModule RngIsRemoved
+
+ENTRY main {
+  constant.3 = bf16[] constant(0)
+  constant.4 = bf16[] constant(1)
+  ROOT rng = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform
+}
+  )";
+  auto module = CreateModuleFromHloString(hlo_string);
+  HloElementTypeConverter type_converter(BF16, F32);
+  TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
+  EXPECT_TRUE(converted);
+
+  std::function<bool(const HloInstruction*)> is_bf16_rng =
+      [](const HloInstruction* inst) {
+        return inst->shape().element_type() == BF16 &&
+               inst->opcode() == HloOpcode::kRng;
+      };
+
+  EXPECT_THAT(module->entry_computation()->instructions(),
+              Not(Contains(ResultOf(is_bf16_rng, Eq(true)))));
+}
+
+TEST_F(HloElementTypeConverterTest, RngCtrlDep) {
+  const string& hlo_string = R"(
+HloModule RngIsRemoved
+
+ENTRY main {
+  constant.3 = bf16[] constant(0)
+  constant.4 = bf16[] constant(1)
+  rng0 = bf16[1,2000,20]{2,1,0} rng(constant.3, constant.4), distribution=rng_uniform
+  ROOT rng1 = bf16[1,1000,20]{2,1,0} rng(constant.3, constant.4), control-predecessors={%rng0}, distribution=rng_uniform
+}
+  )";
+  auto module = CreateModuleFromHloString(hlo_string);
+
+  HloElementTypeConverter type_converter(BF16, F32);
+  TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get()));
+  EXPECT_TRUE(converted);
+
+  HloInstruction *rng0, *rng1;
+  for (auto* inst : module->entry_computation()->instructions()) {
+    if (inst->opcode() == HloOpcode::kRng) {
+      const Shape& shape = inst->shape();
+      ASSERT_EQ(shape.dimensions_size(), 3);
+      ASSERT_TRUE(shape.dimensions(1) == 2000 || shape.dimensions(1) == 1000);
+      if (shape.dimensions(1) == 2000) {
+        rng0 = inst;
+      } else {
+        rng1 = inst;
+      }
+    }
+  }
+
+  EXPECT_THAT(rng0->control_successors(), ElementsAre(rng1));
+  EXPECT_THAT(rng1->control_predecessors(), ElementsAre(rng0));
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index b4f9a9db9cb..fa59a5fb203 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -52,25 +52,11 @@ namespace xla { namespace { using tensorflow::gtl::ArraySlice; -using tensorflow::gtl::FlatSet; -using tensorflow::gtl::optional; - -template -struct is_complex_t : public std::false_type {}; - -template <> -struct is_complex_t : public std::true_type {}; - -template -struct is_complex64_t : public std::false_type {}; - -template <> -struct is_complex64_t : public std::true_type {}; template StatusOr> Compare(const Shape& shape, HloOpcode opcode, - const Literal& lhs_literal, - const Literal& rhs_literal) { + LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -108,7 +94,7 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -119,8 +105,8 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, template <> StatusOr> Compare( - const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, - const Literal& rhs_literal) { + const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, + LiteralSlice rhs_literal) { std::function compare_op; switch (opcode) { case HloOpcode::kEq: @@ -138,7 +124,7 @@ StatusOr> Compare( << HloOpcodeString(opcode); } - auto result = Literal::CreateFromShape(shape); + auto result = MakeUnique(shape); TF_RETURN_IF_ERROR(result->Populate([&](ArraySlice multi_index) { return compare_op(lhs_literal.Get(multi_index), rhs_literal.Get(multi_index)); @@ -147,2064 +133,48 @@ StatusOr> Compare( return std::move(result); } -template -StatusOr> ElementWiseUnaryOpImpl( - HloInstruction* instruction, - const std::function& unary_op, - const Literal& operand_literal) { - const auto shape = instruction->shape(); - const auto* operand = instruction->operand(0); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. - if (!ShapeUtil::SameDimensions(shape, operand->shape())) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); - } - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - return unary_op(operand_literal.Get(multi_index)); - })); - return std::move(result); -} - -// For one particular placement of a window in a base shape (the placement is -// represented as `window_count_index`), iterates inside the window. Translates -// the window index into base index. 
If the base index is within bound, call `f` -// with the base index. -void IterateThroughWindow( - const Shape& window_shape, const Window& window, const Shape& base_shape, - const ArraySlice& window_count_index, - const std::function&)>& f) { - const int64 rank = ShapeUtil::Rank(base_shape); - DimensionVector window_index(rank); - std::fill(window_index.begin(), window_index.end(), 0); - do { - std::vector base_index(rank); - bool out_of_bound = false; - for (int64 i = 0; i < rank; ++i) { - base_index[i] = window_count_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); - if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { - out_of_bound = true; - break; - } - } - if (!out_of_bound) { - f(base_index); - } - } while (IndexUtil::BumpIndices(window_shape, &window_index)); -} - -// Creates a vector of multipliers which can be used to create a linear index -// into shape. -// -// Given the multidimensional index {i1, ..., iN} and -// M = MakeDimMultipliers(shape), the corresponding linear index LI is simply -// -// LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N]. -// -// This lets you calculate LI given the multidimensional indices in any order. -DimensionVector MakeDimMultipliers(const Shape& shape) { - DimensionVector v(ShapeUtil::Rank(shape)); - int64 scale = 1; - for (auto dim : LayoutUtil::MinorToMajor(shape)) { - v[dim] = scale; - scale *= shape.dimensions(dim); - } - return v; -} - } // namespace -template -class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { - public: - explicit TypedVisitor(HloEvaluator* p) : parent_(p) {} - - // The following higher-order functions convert a function with ElementwiseT - // to a function with ReturnT. - std::function ConvertUnaryFunction( - const std::function& unary_op) { - return [&unary_op](ReturnT arg) { - return static_cast(unary_op(static_cast(arg))); - }; - } - std::function ConvertBinaryFunction( - const std::function& - binary_op) { - return [&binary_op](ReturnT arg1, ReturnT arg2) { - return static_cast(binary_op(static_cast(arg1), - static_cast(arg2))); - }; - } - std::function ConvertTernaryFunction( - const std::function& ternary_op) { - return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { - return static_cast(ternary_op(static_cast(arg1), - static_cast(arg2), - static_cast(arg3))); - }; - } - - Status DefaultAction(HloInstruction* hlo_instruction) override { - return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", - HloOpcodeString(hlo_instruction->opcode()).c_str()); - } - - // TODO(b/35950897): many of the stl functions used in the handlers are not - // overloaded for every XLA primitive types. 
- - template ::value>::type* = - nullptr> - Status HandleAbs(HloInstruction* abs) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](NativeT elem_operand) { - return elem_operand; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleAbs(HloInstruction* abs) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](NativeT elem_operand) { - return std::abs(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleAbs(HloInstruction* abs) { - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(abs->operand(0)); - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[abs], - (ElementWiseUnaryOpImpl( - abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, - operand_literal))); - - return Status::OK(); - } - - Status HandleAbs(HloInstruction* abs) override { - // If the operand is of C64 type, the return type of abs will be F32. - // However, ElementwiseT would still be the return type, F32, and thus - // specifying the ElementwiseT explicitly as C64 is needed below. - if (abs->operand(0)->shape().element_type() == C64) { - return HandleAbs(abs); - } - return HandleAbs(abs); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRound(HloInstruction* round) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[round], - ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { - return std::round(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRound(HloInstruction* round) { - return InvalidArgument("Unsupported type for Round"); - } - - Status HandleRound(HloInstruction* round) override { - return HandleRound(round); - } - - Status HandleBroadcast(HloInstruction* broadcast) override { - parent_->evaluated_[broadcast] = - Literal::CreateFromShape(broadcast->shape()); - auto output = parent_->evaluated_[broadcast].get(); - const Literal& operand_to_broadcast = - parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); - std::vector broadcast_indices( - ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); - - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand_to_broadcast.shape())) - << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand_to_broadcast.shape()); - // Checks that operand's dimensions are the same as the broadcast's - // dimensions along the dimensions to be broadcasted. 
- for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand_to_broadcast.shape().dimensions(i)); - } - - return output->Populate([&](ArraySlice multi_index) { - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; - } - return operand_to_broadcast.Get(broadcast_indices); - }); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleCeil(HloInstruction* ceil) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], - ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { - return std::ceil(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleCeil(HloInstruction* ceil) { - return InvalidArgument("Unsupported type for Ceil"); - } - - Status HandleCeil(HloInstruction* ceil) override { - return HandleCeil(ceil); - } - - Status HandleConvert(HloInstruction* convert) override { - const HloInstruction* operand = convert->operand(0); - TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, - parent_->GetEvaluatedLiteralFor(operand).Convert( - convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } - return Status::OK(); - } - - Status HandleBitcastConvert(HloInstruction* convert) override { - const HloInstruction* operand = convert->operand(0); - TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); - TF_ASSIGN_OR_RETURN(std::unique_ptr result, - parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( - convert->shape().element_type())); - - if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { - parent_->evaluated_[convert] = std::move(result); - } else { - parent_->evaluated_[convert] = - result->Relayout(convert->shape().layout()); - } - return Status::OK(); - } - - Status HandleExp(HloInstruction* exp) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], - ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { - return std::exp(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleFloor(HloInstruction* floor) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[floor], - ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { - return std::floor(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleFloor(HloInstruction* floor) { - return InvalidArgument("Unsupported type for Floor"); - } - - Status HandleFloor(HloInstruction* floor) override { - return HandleFloor(floor); - } - - Status HandleLog(HloInstruction* log) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], - ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { - return std::log(elem_operand); - })); - return Status::OK(); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return ~elem_operand; - })); - return Status::OK(); - } - - template 
::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return !elem_operand; - })); - return Status::OK(); - } - - template ::value>::type* = - nullptr> - Status HandleNot(HloInstruction* not_) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { - return !elem_operand; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleNot(HloInstruction* not_) { - return InvalidArgument("Unsupported type for Not"); - } - - Status HandleNot(HloInstruction* not_) override { - return HandleNot(not_); - } - - template ::value && - !std::is_floating_point::value>::type* = nullptr> - Status HandleNegate(HloInstruction* negate) { - using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[negate], - ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { - return NativeT(-type(elem_operand)); - })); - return Status::OK(); - } - - template ::value || - std::is_floating_point::value>::type* = nullptr> - Status HandleNegate(HloInstruction* negate) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[negate], - ElementWiseUnaryOp( - negate, [](ElementwiseT elem_operand) { return -elem_operand; })); - return Status::OK(); - } - - Status HandleNegate(HloInstruction* negate) override { - return HandleNegate(negate); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleSign(HloInstruction* sign) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { - return (ElementwiseT(0) < elem_operand) - - (elem_operand < ElementwiseT(0)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleSign(HloInstruction* sign) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { - auto abs_val = std::abs(elem_operand); - return 0 == abs_val ? 
ElementwiseT(0) - : elem_operand / abs_val; - })); - return Status::OK(); - } - - Status HandleSign(HloInstruction* sign) override { - return HandleSign(sign); - } - - template ::value>::type* = nullptr> - Status HandleAtan2(HloInstruction* atan2) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], - ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return std::atan2(lhs_elem, rhs_elem); - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleAtan2(HloInstruction* atan2) { - return InvalidArgument("Unsupported type for Atan2"); - } - - Status HandleAtan2(HloInstruction* atan2) override { - return HandleAtan2(atan2); - } - - Status HandleTanh(HloInstruction* tanh) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], - ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { - return std::tanh(elem_operand); - })); - return Status::OK(); - } - - template ::value && - !std::is_floating_point::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return NativeT(type(lhs_elem) * type(rhs_elem)); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - std::is_floating_point::value || - is_complex_t::value>::type* = nullptr> - Status HandleMultiply(HloInstruction* multiply) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem * rhs_elem; - })); - return Status::OK(); - } - - Status HandleMultiply(HloInstruction* multiply) override { - return HandleMultiply(multiply); - } - - Status HandleSubtract(HloInstruction* subtract) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[subtract], - ElementWiseBinaryOp(subtract, - [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { - return lhs_elem - rhs_elem; - })); - return Status::OK(); - } - - Status HandleAdd(HloInstruction* add) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], - ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return lhs_elem + rhs_elem; - })); - return Status::OK(); - } - - Status HandleDivide(HloInstruction* divide) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], - ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, - ElementwiseT rhs_elem) { - return lhs_elem / rhs_elem; - })); - return Status::OK(); - } - - template ::value>::type* = - nullptr> - Status HandleMaximum(HloInstruction* maximum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { - return std::max(lhs, rhs); - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleMaximum(HloInstruction* maximum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { - return ((lhs >= rhs) || std::isnan(lhs)) ? 
lhs : rhs; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleMaximum(HloInstruction* maximum) { - return InvalidArgument("Unsupported type for Maximum"); - } - - Status HandleMaximum(HloInstruction* maximum) override { - return HandleMaximum(maximum); - } - - template ::value>::type* = - nullptr> - Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::min(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? lhs_el : rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleMinimum(HloInstruction* minimum) { - return InvalidArgument("Unsupported type for Minimum"); - } - - Status HandleMinimum(HloInstruction* minimum) override { - return HandleMinimum(minimum); - } - - Status HandlePower(HloInstruction* power) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::pow(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRemainder(HloInstruction* remainder) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], - ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, - ElementwiseT rhs_el) { - return std::fmod(lhs_el, rhs_el); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleRemainder(HloInstruction* remainder) { - return InvalidArgument("Unsupported type for Remainder"); - } - - Status HandleRemainder(HloInstruction* remainder) override { - return HandleRemainder(remainder); - } - - template ::value>::type* = - nullptr> - Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el & rhs_el; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleAnd(HloInstruction* and_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el && rhs_el; - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleAnd(HloInstruction* and_) { - return InvalidArgument("Unsupported type for And"); - } - - Status HandleAnd(HloInstruction* and_) override { - return HandleAnd(and_); - } - - template ::value>::type* = - nullptr> - Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el | rhs_el; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleOr(HloInstruction* or_) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { - return lhs_el || rhs_el; - })); - return Status::OK(); - } - - template < - typename 
NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleOr(HloInstruction* or_) { - return InvalidArgument("Unsupported type for Or"); - } - - Status HandleOr(HloInstruction* or_) override { - return HandleOr(or_); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftLeft(HloInstruction* shl) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shl], - ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { - return IsShiftOutOfBounds(rhs_elem) ? 0 - : (lhs_elem << rhs_elem); - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftLeft(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftLeft"); - } - - Status HandleShiftLeft(HloInstruction* shl) override { - return HandleShiftLeft(shl); - } - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftRightArithmetic(HloInstruction* shr) { - typedef typename std::make_signed::type SignedT; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shr], - ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - SignedT lhs_signed = static_cast(lhs_elem); - if (IsShiftOutOfBounds(rhs_elem)) { - return lhs_signed < 0 ? static_cast(-1) : 0; - } else { - return lhs_signed >> rhs_elem; - } - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftRightArithmetic(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightArithmetic"); - } - - Status HandleShiftRightArithmetic(HloInstruction* shra) override { - return HandleShiftRightArithmetic(shra); - } - - template ::value && - !std::is_same::value>::type* = nullptr> - Status HandleShiftRightLogical(HloInstruction* shr) { - typedef typename std::make_unsigned::type UnsignedT; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[shr], - ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - // If shift amount is greater than the number of bits, then return 0. 
- if (IsShiftOutOfBounds(rhs_elem)) { - return static_cast(0); - } - return static_cast(static_cast(lhs_elem) >> - rhs_elem); - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = - nullptr> - Status HandleShiftRightLogical(HloInstruction*) { - return InvalidArgument("Unsupported type for ShiftRightLogical"); - } - - Status HandleShiftRightLogical(HloInstruction* shrl) override { - return HandleShiftRightLogical(shrl); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction* clamp) { - std::function - clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { - return std::fmin(high, std::fmax(value, low)); - }; - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[clamp], - ElementwiseTernaryOp(clamp, - std::move(ConvertTernaryFunction(clamp_op)))); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleClamp(HloInstruction*) { - return InvalidArgument("Unsupported type for Clamp"); - } - - Status HandleClamp(HloInstruction* clamp) override { - return HandleClamp(clamp); - } - - Status HandleSelect(HloInstruction* select) override { - CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); - CHECK(!ShapeUtil::IsTuple(select->shape())); - std::function select_op = - [](bool pred, ReturnT on_true, ReturnT on_false) { - if (pred) { - return on_true; - } - return on_false; - }; - TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], - ElementwiseTernaryOp(select, std::move(select_op))); - return Status::OK(); - } - - Status HandleReverse(HloInstruction* reverse) override { - const auto result_shape = reverse->shape(); - const auto reverse_dimensions = reverse->dimensions(); - - auto operand = reverse->operand(0); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferReverseShape(operand->shape(), - reverse_dimensions)); - - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(result_shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = Literal::CreateFromShape(result_shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice out_index) { - std::vector from_index(out_index.begin(), out_index.end()); - for (const int64 dim : reverse_dimensions) { - from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; - } - return operand_literal.Get(from_index); - })); - - parent_->evaluated_[reverse] = std::move(result); - return Status::OK(); - } - - Status HandleConvolution(HloInstruction* conv) override { - auto lhs = conv->operand(0); - auto rhs = conv->operand(1); - const auto& window = conv->window(); - const Shape& result_shape = conv->shape(); - const Shape& lhs_shape = lhs->shape(); - const Shape& rhs_shape = rhs->shape(); - - TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); - TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); - CHECK(ShapeUtil::IsArray(lhs_shape)); - CHECK(ShapeUtil::IsArray(rhs_shape)); - CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); - CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); - - const auto& dnums = conv->convolution_dimension_numbers(); - const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); - CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); - CHECK_EQ(num_spatial_dims, 
dnums.kernel_spatial_dimensions_size()); - CHECK_GE(num_spatial_dims, 0); - CHECK_EQ(window.dimensions_size(), num_spatial_dims); - - const auto lhs_rank = ShapeUtil::Rank(lhs_shape); - const auto rhs_rank = ShapeUtil::Rank(rhs_shape); - - CHECK_EQ(num_spatial_dims + 2, lhs_rank); - CHECK_EQ(num_spatial_dims + 2, rhs_rank); - - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, - window, dnums)); - CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(result_shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - std::vector window_dimension_sizes; - for (auto i : dnums.kernel_spatial_dimensions()) { - window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); - } - - const Shape& window_shape = - ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); - - DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape); - DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape); - - auto lhs_literal_data = lhs_literal.data(); - auto rhs_literal_data = rhs_literal.data(); - - auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, - &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data](ArraySlice out_index) { - // Dimension number applicable for input (lhs). - const int64 input_batch_dim = dnums.input_batch_dimension(); - const int64 input_z_dim = dnums.input_feature_dimension(); - // Dimension number applicable for kernel (rhs). - const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); - const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); - // Dimension number applicable for output. - const int64 output_batch_dim = dnums.output_batch_dimension(); - const int64 output_z_dim = dnums.output_feature_dimension(); - - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); - - ElementwiseT result_val = static_cast(0); - DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), - 0); - - // Convolve input feature with kernel. - do { - for (int64 iz = 0; iz < z_size; ++iz) { - int64 lhs_linear_index = 0; - lhs_linear_index += out_index[output_batch_dim] * - lhs_dim_multipliers[input_batch_dim]; - lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; - - int64 rhs_linear_index = 0; - rhs_linear_index += out_index[output_z_dim] * - rhs_dim_multipliers[kernel_output_z_dim]; - rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; - - // Find corresponding spatial dimension index for input (lhs). - for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { - // Spatial dimension number for input (lhs) and output. - const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); - const int64 output_spatial_dim = - dnums.output_spatial_dimensions(ki); - - // Calculate lhs (input) index without taking base dilation into - // account. - const auto& window_dim = window.dimensions(ki); - const int64 undilated_index = - out_index[output_spatial_dim] * window_dim.stride() - - window_dim.padding_low() + - rhs_spatial_index[ki] * window_dim.window_dilation(); - // Skip if the lhs (input) index is to be dilated. As an - // optimization, skip this mod if there's no dilation. 
- if (window_dim.base_dilation() > 1 && - undilated_index % window_dim.base_dilation() != 0) { - goto cnt; - } - - // Calculate the actual lhs (input) index after dilation. As an - // optimization, skip this integer divide if there's no dilation. - int64 lhs_spatial_index; - if (window_dim.base_dilation() > 1) { - lhs_spatial_index = undilated_index / window_dim.base_dilation(); - } else { - lhs_spatial_index = undilated_index; - } - lhs_linear_index += - lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; - - // Skip if input index is not in bounds. - if (!(lhs_spatial_index >= 0 && - lhs_spatial_index < - lhs_shape.dimensions(input_spatial_dim))) { - goto cnt; - } - - rhs_linear_index += - (window_dim.window_reversal() - ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) - : rhs_spatial_index[ki]) * - rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; - } - - result_val += - static_cast(lhs_literal_data[lhs_linear_index]) * - static_cast(rhs_literal_data[rhs_linear_index]); - } - cnt : {} - } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); - - return static_cast(result_val); - }; - - auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR(result->PopulateParallel(func)); - - parent_->evaluated_[conv] = std::move(result); - return Status::OK(); - } - - Status HandleDot(HloInstruction* dot) override { - auto lhs = dot->operand(0); - auto rhs = dot->operand(1); - CHECK(ShapeUtil::IsArray(dot->shape())); - CHECK(ShapeUtil::IsArray(lhs->shape())); - CHECK(ShapeUtil::IsArray(rhs->shape())); - - const auto& dnums = dot->dot_dimension_numbers(); - - const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); - const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); - - CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); - CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); - - // There must be 1 and only 1 Contracting dimension for lhs and rhs. - CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); - CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); - const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); - const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); - // Contracted dimension sizes must be the same. 
- CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), - rhs->shape().dimensions(rhs_contracting_dimension)) - << "lhs contracted dimension: " - << lhs->shape().dimensions(lhs_contracting_dimension) - << " rhs contracted dimension: " - << rhs->shape().dimensions(rhs_contracting_dimension); - const int64 contracted_dimension_size = - lhs->shape().dimensions(lhs_contracting_dimension); - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = Literal::CreateFromShape(dot->shape()); - - CHECK_EQ(dnums.lhs_batch_dimensions_size(), - dnums.rhs_batch_dimensions_size()); - - std::vector lhs_non_contracting_dims; - for (int64 i = 0; i < lhs_rank; i++) { - if (i != lhs_contracting_dimension) { - lhs_non_contracting_dims.push_back(i); - } - } - - std::vector rhs_non_batch_non_contracting_dims; - FlatSet batch_dims_set(dnums.rhs_batch_dimensions().begin(), - dnums.rhs_batch_dimensions().end()); - for (int64 i = 0; i < rhs_rank; i++) { - if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { - rhs_non_batch_non_contracting_dims.push_back(i); - } - } - - const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); - const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); - - DimensionVector lhs_index(lhs_rank); - DimensionVector rhs_index(rhs_rank); - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice result_index) { - ElementwiseT result_val = static_cast(0); - - // Find the corresponding non-contracting indices for lhs and rhs. - // - // For `result_index`, its batch dimension, if exists, will be at the - // same dimension as the batch dimension of lhs and rhs. More - // specifically: - // - For lhs, the non-contracting dimensions, including the batch - // dimension have the same index as the `result_index`. - // - For rhs, the batch dimension is set seperately from other - // non-contracting dimensions, since these other non-contracting - // dimensions in rhs follow the non-contracting dimensions of lhs in - // the resulting index. - // - // As an example, for a resulting index: - // result_index [result_batch, result_x, result_y] - // the effecting lhs and rhs indices are: - // lhs [result_batch, lhs_non_contracting_dim, contracting_dim - // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] - // `result_x` is only affected by the lhs_non_contracting_dim and - // likewise `result_y` only depends on rhs_non_contracting_dim. - // - // so we can look up the lhs and rhs indices by: - // - // lhs: - // batch index is the same as `result_batch`. - // non-contracting dimension is the same as - // result_index[lhs_non_contracting_dim] - // rhs: - // batch index: the same as `result_batch`. - // non-contracting dimension index: *not* the same as - // result_index[rhs_non_contractng_dim], since the - // non-contracting dimensions of lhs are included in the - // result_index first. Instead, the non_contracting_dim of rhs must - // be calculated as following: - // lhs_non_contracting_dimensions_size + - // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 - // - // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is - // the index offset to the result_index that only depends on - // the non_batch and non-contracting dimensions of rhs. -1 at the - // end translates size to index. 
- for (auto i : lhs_non_contracting_dims) { - lhs_index[i] = result_index[i]; - } - for (auto i : dnums.rhs_batch_dimensions()) { - rhs_index[i] = result_index[i]; - } - for (auto i : rhs_non_batch_non_contracting_dims) { - const int64 rhs_non_batch_non_contracting_dim = - lhs_non_contracting_size + (i - batch_dim_size) - 1; - rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; - } - - // Accumulates resulting product along the contracted dimension. - for (int64 i = 0; i < contracted_dimension_size; ++i) { - lhs_index[lhs_contracting_dimension] = i; - rhs_index[rhs_contracting_dimension] = i; - - result_val += - static_cast(lhs_literal.Get(lhs_index)) * - static_cast(rhs_literal.Get(rhs_index)); - } - - return static_cast(result_val); - })); - - parent_->evaluated_[dot] = std::move(result); - return Status::OK(); - } - - Status HandlePad(HloInstruction* pad) override { - CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); - // Padding value must be scalar. - CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); - CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), - pad->padding_config().dimensions_size()); - - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferPadShape( - /*operand_shape=*/pad->operand(0)->shape(), - /*padding_value_shape=*/pad->operand(1)->shape(), - /*padding_config=*/pad->padding_config())); - CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - // Create new HLO of padded shape with padding value. - ReturnT scalar = - parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); - auto result = Literal::CreateFromShape(pad->shape()); - TF_RETURN_IF_ERROR(result->Populate( - [&scalar](ArraySlice multi_index) { return scalar; })); - - const Literal& evaluated_operand = - parent_->GetEvaluatedLiteralFor(pad->operand(0)); - - std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), - 0); - std::vector target_index(ShapeUtil::Rank(result->shape()), 0); - - // Loop through each element of the operand, assign them to the - // corresponding index of the resulting padded literal. - const PaddingConfig& pad_config = pad->padding_config(); - - auto func = [&](ArraySlice input_index) { - for (auto i = 0; i < input_index.size(); ++i) { - // Interior padding occurs logically before edge padding, so in the case - // of negative edge padding elements are removed from the - // interior-padded operand. - target_index[i] = - pad_config.dimensions(i).edge_padding_low() + - input_index[i] * (pad_config.dimensions(i).interior_padding() + 1); - - // Account for negative low and high padding: skip assignment if the - // any target index is out of range. 
- if (!(target_index[i] >= 0 && - target_index[i] < pad->shape().dimensions(i))) { - return true; - } - } - result->Set(target_index, - evaluated_operand.Get(input_index)); - return true; - }; - - std::vector zero_base(evaluated_operand.shape().dimensions_size(), - 0); - std::vector step(evaluated_operand.shape().dimensions_size(), 1); - - ShapeUtil::ForEachIndex( - evaluated_operand.shape(), zero_base, - AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); - - parent_->evaluated_[pad] = std::move(result); - return Status::OK(); - } - - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { - auto operand = dynamic_slice->operand(0); - auto start_indices = dynamic_slice->operand(1); - auto result_shape = dynamic_slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferDynamicSliceShape( - operand->shape(), start_indices->shape(), - dynamic_slice->dynamic_slice_sizes())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - TF_RET_CHECK( - primitive_util::IsIntegralType(start_indices->shape().element_type())); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); - - switch (start_indices->shape().element_type()) { - case S32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case S64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case U32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - case U64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_slice], - DynamicSlice(operand_literal, start_indices_literal, - result_shape)); - } break; - default: - LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " - "start_indices: " - << PrimitiveType_Name(start_indices->shape().element_type()); - } - - return Status::OK(); - } - - Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override { - auto operand = dynamic_update_slice->operand(0); - auto update = dynamic_update_slice->operand(1); - auto start_indices = dynamic_update_slice->operand(2); - auto result_shape = dynamic_update_slice->shape(); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferDynamicUpdateSliceShape( - operand->shape(), update->shape(), start_indices->shape())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(result_shape) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - TF_RET_CHECK( - primitive_util::IsIntegralType(start_indices->shape().element_type())); - TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape())); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); - const Literal& start_indices_literal = - parent_->GetEvaluatedLiteralFor(start_indices); - - switch (start_indices->shape().element_type()) { - case S32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - 
DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case S64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case U32: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - case U64: { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[dynamic_update_slice], - DynamicUpdateSlice(operand_literal, update_literal, - start_indices_literal)); - } break; - default: - LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for " - "start_indices: " - << PrimitiveType_Name(start_indices->shape().element_type()); - } - - return Status::OK(); - } - - template - StatusOr> MapImpl(HloInstruction* map) { - auto operands = map->operands(); - HloComputation* computation = map->to_apply(); - - auto result = Literal::CreateFromShape(map->shape()); - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - std::vector> arg_literals; - arg_literals.reserve(operands.size()); - - // Construct scalar literal parameters to be passed to the map - // computation. - for (auto operand : operands) { - const Literal& arg_literal = - parent_->GetEvaluatedLiteralFor(operand); - - auto curr_val = arg_literal.Get(multi_index); - auto curr_val_literal = Literal::CreateR0(curr_val); - - arg_literals.push_back(std::move(curr_val_literal)); - } - - std::unique_ptr computed_result = - embedded_evaluator - .Evaluate>(*computation, - arg_literals) - .ConsumeValueOrDie(); - // Clear visit states so that the we can use the evaluate again on - // the same computation. 
- embedded_evaluator.ResetVisitStates(); - - return computed_result->Get({}); - })); - return std::move(result); - } - - Status HandleMap(HloInstruction* map) override { - switch (map->operand(0)->shape().element_type()) { - case PRED: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U8: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case U64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S8: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case S64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case F16: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], - MapImpl(map)); - break; - } - case F32: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case F64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - case C64: { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); - break; - } - default: - LOG(FATAL) << "HandleMap: unhandled primitive type for " - "input operand: " - << PrimitiveType_Name( - map->operand(0)->shape().element_type()); - } - - return Status::OK(); - } - - Status HandleReduce(HloInstruction* reduce) override { - auto arg = reduce->operand(0); - auto init_value = reduce->operand(1); - ArraySlice dimensions(reduce->dimensions()); - HloComputation* function = reduce->to_apply(); - TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == - ShapeUtil::Rank(arg->shape()) - dimensions.size()); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferReduceShape( - /*arg=*/arg->shape(), - /*init_value=*/init_value->shape(), - /*dimensions_to_reduce=*/dimensions, - /*to_apply=*/function->ComputeProgramShape())); - TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) - << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); - VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); - const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); - VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(reduce->shape()); - - const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); - std::vector arg_dim_steps(arg_dimensions.size()); - std::vector arg_dim_counts(arg_dimensions.size()); - for (const int64 dim : dimensions) { - arg_dim_steps[dim] = 1; - arg_dim_counts[dim] = arg_dimensions[dim]; - } - - // Map each dimension in the result to a dimension in arg that isn't - // being reduced. - std::vector result_to_arg_index; - for (int64 i = 0; i < arg_dimensions.size(); ++i) { - if (arg_dim_steps[i] == 0) { - result_to_arg_index.push_back(i); - } - } - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - // For each resulting dimension, calculate and assign computed value. 
- TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - ReturnT result_val = init_scalar; - - std::vector base(arg_dimensions.size()); - for (int64 i = 0; i < multi_index.size(); ++i) { - base[result_to_arg_index[i]] = multi_index[i]; - } - - // When the reduction is addition of floats, accumulate in a double - // for better precision. Also, avoid creating Literals for the - // intermediate results; it's much faster. - if (ShapeUtil::ElementIsFloating(init_literal.shape()) && - IsScalarAdd(function)) { - double computed_result = 0; - auto func = [&](ArraySlice input_index) { - computed_result += arg_literal.Get(input_index); - return true; - }; - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return static_cast(computed_result); - } - auto func = [&](ArraySlice input_index) { - auto curr_val = arg_literal.Get(input_index); - - // Evaluate computation with specified literal operands. - auto curr_val_literal = Literal::CreateR0(curr_val); - auto result_val_literal = Literal::CreateR0(result_val); - std::vector args = {curr_val_literal.get(), - result_val_literal.get()}; - - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) - .ConsumeValueOrDie(); - // Clear visit states so that we can use the evaluator again on - // the same computation. - embedded_evaluator.ResetVisitStates(); - // Assign computed result to result_val. - result_val = computed_result->Get({}); - return true; - }; - // Computes one element of the result, reducing all dimensions that - // contribute to that element. - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); - return result_val; - })); - - parent_->evaluated_[reduce] = std::move(result); - return Status::OK(); - } - - bool IsScalarAdd(HloComputation* computation) { - HloInstruction* instruction = computation->root_instruction(); - if (instruction->opcode() == HloOpcode::kAdd && - computation->num_parameters() == 2) { - const HloInstruction* lhs = instruction->operand(0); - const HloInstruction* rhs = instruction->operand(1); - return lhs->opcode() == HloOpcode::kParameter && - ShapeUtil::IsScalar(lhs->shape()) && - rhs->opcode() == HloOpcode::kParameter && - ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; - } - return false; - } - - Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { - auto operand = select_and_scatter->operand(0); - auto source = select_and_scatter->operand(1); - const Window& window = select_and_scatter->window(); - - const Literal& init_literal = - parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(select_and_scatter->shape()); - - // Initialize result array with the init value. 
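The reduce fast path above skips the scalar-literal machinery when the reduction is a plain floating-point add and accumulates in a double before casting back to the element type. A minimal sketch of why the wider accumulator is used (standalone helper, names not taken from the patch):

#include <vector>

// Accumulate in double, convert to float once at the end. Summing directly in
// float can drop low-order bits once the running total dwarfs the addends.
float SumWithDoubleAccumulator(const std::vector<float>& values) {
  double acc = 0.0;
  for (float v : values) {
    acc += v;
  }
  return static_cast<float>(acc);
}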
- TF_RETURN_IF_ERROR(result->Populate( - [&](ArraySlice output_index) { return init_scalar; })); - - std::vector window_dimension_sizes; - for (const auto& window_dimension : window.dimensions()) { - window_dimension_sizes.push_back(window_dimension.size()); - } - const Shape window_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), window_dimension_sizes); - - HloComputation* select = select_and_scatter->select(); - HloComputation* scatter = select_and_scatter->scatter(); - - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); - - int64 rank = ShapeUtil::Rank(operand_literal.shape()); - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - DimensionVector source_index(rank); - - std::fill(source_index.begin(), source_index.end(), 0); - do { - // For each element in `source`, we place a window in `operand`. For each - // window placement, we iterate inside the window twice: - // - // 1. Find the selected index by applying `select` function to all - // elements. E.g., If the `select` function is GreaterEqual, the first - // iteration through the window finds the biggest value and returns its - // index. - // - // 2. Using the selected index, scatter value from `source` to result. We - // do this by iterating through the window, and compare each index with - // the selected index. - optional selected_val; - optional> selected_index; - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), source_index, - [&](const std::vector& operand_index) { - auto curr_val = operand_literal.Get(operand_index); - if (!selected_val) { - selected_val = curr_val; - selected_index = operand_index; - } - const auto curr_val_literal = Literal::CreateR0(curr_val); - const auto selected_val_literal = - Literal::CreateR0(*selected_val); - - const std::vector args = { - selected_val_literal.get(), curr_val_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*select, args) - .ConsumeValueOrDie(); - bool selected = !computed_result->Get({}); - if (selected) { - selected_val = curr_val; - selected_index = operand_index; - } - embedded_evaluator.ResetVisitStates(); - }); - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), source_index, - [&](const std::vector& operand_index) { - if (std::equal(operand_index.begin(), operand_index.end(), - selected_index->begin())) { - auto source = source_literal.Get(source_index); - auto scattered = result->Get(operand_index); - const auto source_literal = Literal::CreateR0(source); - const auto scattered_literal = - Literal::CreateR0(scattered); - - const std::vector args = { - source_literal.get(), scattered_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*scatter, args) - .ConsumeValueOrDie(); - result->Set(operand_index, computed_result->Get({})); - // Clear visit states so that the we can use the evaluator again - // on the same computation. 
- embedded_evaluator.ResetVisitStates(); - } - }); - } while (IndexUtil::BumpIndices(source->shape(), &source_index)); - - parent_->evaluated_[select_and_scatter] = std::move(result); - return Status::OK(); - } - - Status HandleReduceWindow(HloInstruction* reduce_window) override { - auto operand = reduce_window->operand(0); - const Window& window = reduce_window->window(); - HloComputation* function = reduce_window->to_apply(); - TF_ASSIGN_OR_RETURN( - auto inferred_return_shape, - ShapeInference::InferReduceWindowShape( - /*operand_shape=*/reduce_window->operand(0)->shape(), - /*init_value=*/reduce_window->operand(1)->shape(), window, - /*to_apply_shape=*/function->ComputeProgramShape())); - TF_RET_CHECK( - ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) - << "return shape is set to: " - << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) - << "but is inferred to be: " - << ShapeUtil::HumanStringWithLayout(inferred_return_shape); - - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); - VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); - const Literal& init_literal = - parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); - VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); - TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); - auto init_scalar = init_literal.Get({}); - - auto result = Literal::CreateFromShape(reduce_window->shape()); - - // Creates a Shape object from window, for iteration below. - std::vector window_dimension_sizes; - for (const auto& window_dimension : window.dimensions()) { - window_dimension_sizes.push_back(window_dimension.size()); - } - const Shape window_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), window_dimension_sizes); - - DimensionVector window_index(window.dimensions_size()); - DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); - - HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - // For each resulting dimension, calculate and assign computed value. - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice output_index) { - ReturnT result_val = init_scalar; - - std::fill(window_index.begin(), window_index.end(), 0); - std::fill(operand_index.begin(), operand_index.end(), 0); - - IterateThroughWindow( - window_shape, window, operand_literal.shape(), output_index, - [&](const std::vector& operand_index) { - auto curr_val = operand_literal.Get(operand_index); - - // Evaluate computation with specified literal operands. - const auto curr_val_literal = - Literal::CreateR0(curr_val); - const auto result_val_literal = - Literal::CreateR0(result_val); - const std::vector args = { - curr_val_literal.get(), result_val_literal.get()}; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) - .ConsumeValueOrDie(); - - // Clear visit states so that the we can use the evaluate again - // on the same computation. 
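Stepping back to the select-and-scatter handler that finished just above: for every source element it makes two passes over one window of the operand, a select pass that applies the `select` computation pairwise to choose a single index, then a scatter pass that folds the source value into the result at that index. A 1-D sketch of the same two-pass idea, hard-coding max-selection and add-scatter and assuming every window lies fully inside the operand (a simplification; the names are invented):

#include <cstddef>
#include <vector>

// 1-D select-and-scatter: pick the window's max index (select pass), then add
// the source value into the result at that index (scatter pass).
std::vector<float> SelectAndScatter1D(const std::vector<float>& operand,
                                      const std::vector<float>& source,
                                      std::size_t window_size,
                                      std::size_t stride, float init_value) {
  std::vector<float> result(operand.size(), init_value);
  for (std::size_t s = 0; s < source.size(); ++s) {
    const std::size_t window_start = s * stride;
    std::size_t selected = window_start;                   // select pass
    for (std::size_t i = window_start; i < window_start + window_size; ++i) {
      if (operand[i] > operand[selected]) selected = i;
    }
    result[selected] += source[s];                         // scatter pass
  }
  return result;
}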
- embedded_evaluator.ResetVisitStates(); - - result_val = computed_result->Get({}); - }); - - return result_val; - })); - - parent_->evaluated_[reduce_window] = std::move(result); - return Status::OK(); - } - - Status HandleSlice(HloInstruction* slice) override { - auto operand = slice->operand(0); - const Shape& shape = slice->shape(); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferSliceShape( - operand->shape(), slice->slice_starts(), - slice->slice_limits(), slice->slice_strides())); - TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const int64 rank = ShapeUtil::Rank(operand->shape()); - const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto func = [&](ArraySlice out_index) { - DimensionVector operand_index(rank); - for (int64 i = 0; i < rank; ++i) { - operand_index[i] = - slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); - } - return operand_literal.Get(operand_index); - }; - - auto result = Literal::CreateFromDimensions( - shape.element_type(), AsInt64Slice(shape.dimensions())); - TF_RETURN_IF_ERROR(result->Populate(func)); - parent_->evaluated_[slice] = std::move(result); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleSin(HloInstruction* sin) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], - ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { - return std::sin(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleSin(HloInstruction* sin) { - return InvalidArgument("Unsupported type for Sin"); - } - - Status HandleSin(HloInstruction* sin) override { - return HandleSin(sin); - } - - template ::value>::type* = nullptr> - Status HandleCos(HloInstruction* cos) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], - ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { - return std::cos(elem_operand); - })); - return Status::OK(); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleCos(HloInstruction* cos) { - return InvalidArgument("Unsupported type for Cos"); - } - - Status HandleCos(HloInstruction* cos) override { - return HandleCos(cos); - } - - template ::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[reduce_precision], - ElementWiseUnaryOp(reduce_precision, [reduce_precision]( - ElementwiseT elem) { - uint32_t value_as_int = tensorflow::bit_cast(elem); - const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); - const uint32_t exponent_bits = reduce_precision->exponent_bits(); - - // Code is based on the CPU/GPU implementation in LLVM-emitting code. - // - // Bits in float type: - // mantissa : bits [0:22] - // exponent : bits [23:30] - // sign : bits [31] - if (mantissa_bits < 23) { - const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); - - // Compute rounding bias for round-to-nearest with ties to even. - // This is equal to a base value of 0111... plus one bit if the last - // remaining mantissa bit is 1. 
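As a concrete instance of the bias computation that follows, suppose 10 of the 23 mantissa bits are kept (a value chosen only for illustration): last_mantissa_bit_mask is 1 << 13 = 0x2000, so the base bias is 0x0FFF; if the last kept mantissa bit of the input is 1 the bias becomes 0x1000, otherwise it stays 0x0FFF. Adding that bias and then masking with ~0x1FFF rounds the 13 discarded bits to the nearest kept value, and an exact tie (discarded bits equal to 0x1000) rounds toward the value whose last kept bit is 0, i.e. ties to even.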
- const uint32_t base_rounding_bias = - (last_mantissa_bit_mask >> 1) - 1; - const uint32_t x_last_mantissa_bit = - (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); - const uint32_t x_rounding_bias = - x_last_mantissa_bit + base_rounding_bias; - - // Add rounding bias, and mask out truncated bits. Note that the - // case where adding the rounding bias overflows into the exponent - // bits is correct; the non-masked mantissa bits will all be zero, - // and the exponent will be incremented by one. - const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - value_as_int = value_as_int + x_rounding_bias; - value_as_int = value_as_int & truncation_mask; - } - if (exponent_bits < 8) { - // Masks for f32 values. - const uint32_t f32_sign_bit_mask = 1u << 31; - const uint32_t f32_exp_bits_mask = 0xffu << 23; - - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the - // most- significant bit -- is equal to 1.0f for all exponent sizes. - // Adding 2^(n-1)-1 to this gives us the highest non-infinite - // exponent for a bit- size of n, and subtracting 2^(n-1)-1 from - // this gives us the lowest' exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n - // is (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. - const uint32_t f32_exponent_bias = (1 << 7) - 1; - const uint32_t reduced_exponent_bias = - (1 << (exponent_bits - 1)) - 1; - const uint32_t reduced_max_exponent = - f32_exponent_bias + reduced_exponent_bias; - const uint32_t reduced_min_exponent = - f32_exponent_bias - reduced_exponent_bias; - - // Do we overflow or underflow? - const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; - const bool x_overflows = x_exponent > (reduced_max_exponent << 23); - const bool x_underflows = - x_exponent <= (reduced_min_exponent << 23); - - // Compute appropriately-signed values of zero and infinity. - const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; - const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; - - // Force to zero or infinity if overflow or underflow. (Note that - // this truncates all denormal values to zero, rather than rounding - // them.) - value_as_int = x_overflows ? x_signed_inf : value_as_int; - value_as_int = x_underflows ? x_signed_zero : value_as_int; - } - - float reduced_result = tensorflow::bit_cast(value_as_int); - if (std::isnan(elem)) { - reduced_result = mantissa_bits > 0 - ? 
elem - : std::numeric_limits::infinity(); - } - return reduced_result; - })); - return Status::OK(); - } - - template ::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Double not supported for reduce precision"); - } - - template < - typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> - Status HandleReducePrecision(HloInstruction* reduce_precision) { - return InvalidArgument("Unsupported type for reduce precision"); - } - - Status HandleReducePrecision(HloInstruction* reduce_precision) override { - return HandleReducePrecision(reduce_precision); - } - - private: - template - StatusOr> DynamicSlice( - const Literal& operand_literal, const Literal& start_indices_literal, - const Shape& result_shape) { - auto start_indices_typed = start_indices_literal.data(); - std::vector start(start_indices_typed.begin(), - start_indices_typed.end()); - - std::vector operand_indices(start.size()); - - auto result = Literal::CreateFromShape(result_shape); - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - for (int64 i = 0; i < operand_indices.size(); ++i) { - CHECK_GE(multi_index[i] + start[i], 0); - // Mod is only used here to be consistent with the existing - // backends' behavior. - operand_indices[i] = (multi_index[i] + start[i]) % - operand_literal.shape().dimensions(i); - } - - auto result = operand_literal.Get(operand_indices); - return result; - })); - - return std::move(result); - } - - template - StatusOr> DynamicUpdateSlice( - const Literal& operand_literal, const Literal& update_literal, - const Literal& start_indices_literal) { - auto result = operand_literal.CloneToUnique(); - auto start_indices_typed = start_indices_literal.data(); - const auto rank = ShapeUtil::Rank(result->shape()); - std::vector start(rank, 0); - for (int64 i = 0; i < rank; ++i) { - // All other implementations currently wrap-around the index, so this - // should do so as well. - start[i] = (start_indices_typed[i] % result->shape().dimensions(i)); - start[i] += (start[i] < 0) * result->shape().dimensions(i); - } - std::vector result_index(rank, 0); - - auto func = [&](ArraySlice update_index) { - std::transform(update_index.begin(), update_index.end(), start.begin(), - result_index.begin(), std::plus()); - // Same as above, wrap-around only to match other implementations' - // semantics. 
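Both dynamic-slice helpers here wrap start indices modulo the result dimensions rather than clamping, purely to stay consistent with the other backends, as the comments note. In isolation the index rule amounts to this (standalone helper, invented name):

#include <cstdint>

// Wrap an index around a dimension size; a negative remainder is shifted back
// into [0, dim_size), matching the start[i] adjustment above.
int64_t WrapIndex(int64_t index, int64_t dim_size) {
  int64_t wrapped = index % dim_size;
  if (wrapped < 0) wrapped += dim_size;
  return wrapped;
}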
- std::transform(result_index.begin(), result_index.end(), - result->shape().dimensions().begin(), result_index.begin(), - std::modulus()); - result->Set(result_index, - update_literal.Get(update_index)); - return true; - }; - - std::vector base(update_literal.shape().dimensions_size(), 0); - std::vector step(update_literal.shape().dimensions_size(), 1); - ShapeUtil::ForEachIndex(update_literal.shape(), base, - AsInt64Slice(update_literal.shape().dimensions()), - step, func); - - return std::move(result); - } - - StatusOr> ElementWiseUnaryOp( - HloInstruction* instruction, - const std::function& unary_op) { - const Literal& operand_literal = - parent_->GetEvaluatedLiteralFor(instruction->operand(0)); - TF_ASSIGN_OR_RETURN( - auto result_literal, - (ElementWiseUnaryOpImpl( - instruction, ConvertUnaryFunction(unary_op), operand_literal))); - - return std::move(result_literal); - } - - StatusOr> ElementWiseBinaryOp( - HloInstruction* instruction, - const std::function& - binary_op) { - const auto shape = instruction->shape(); - const auto* lhs = instruction->operand(0); - const auto* rhs = instruction->operand(1); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast - // is removed. - if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - return ConvertBinaryFunction(binary_op)( - lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); - })); - return std::move(result); - } - - template - StatusOr> ElementwiseTernaryOp( - HloInstruction* instruction, - const std::function& ternary_op) { - const auto shape = instruction->shape(); - const auto* lhs = instruction->operand(0); - const auto* rhs = instruction->operand(1); - const auto* ehs = instruction->operand(2); - - // TODO(b/35950897, b/27796129): add DCHECK back once implicit - // broadcast is removed. 
- if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && - ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && - ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { - return Unimplemented( - "Implicit broadcasting is currently unsupported in HLO evaluator " - "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str(), - ShapeUtil::HumanString(ehs->shape()).c_str()); - } - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - - auto result = Literal::CreateFromShape(shape); - - TF_RETURN_IF_ERROR( - result->Populate([&](ArraySlice multi_index) { - return ternary_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index), - ehs_literal.Get(multi_index)); - })); - - return std::move(result); - } - - template - static bool IsShiftOutOfBounds(NativeT rhs) { - typedef typename std::make_unsigned::type UnsignedT; - UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT; - UnsignedT rhs_unsigned = static_cast(rhs); - return rhs_unsigned >= lhs_size_unsigned; - } - - HloEvaluator* parent_; -}; // class HloEvaluator::TypedVisitor HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { - typed_visitors_[PRED] = MakeUnique>(this); - typed_visitors_[U8] = MakeUnique>(this); + typed_visitors_[PRED] = MakeUnique>(this); + typed_visitors_[U8] = MakeUnique>(this); typed_visitors_[U16] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: U16."); + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "U16."); }); - typed_visitors_[U32] = MakeUnique>(this); - typed_visitors_[U64] = MakeUnique>(this); - typed_visitors_[S8] = MakeUnique>(this); + typed_visitors_[U32] = MakeUnique>(this); + typed_visitors_[U64] = MakeUnique>(this); + typed_visitors_[S8] = MakeUnique>(this); typed_visitors_[S16] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: S16."); + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "S16."); }); - typed_visitors_[S32] = MakeUnique>(this); - typed_visitors_[S64] = MakeUnique>(this); - typed_visitors_[F16] = MakeUnique>(this); - typed_visitors_[F32] = MakeUnique>(this); - typed_visitors_[F64] = MakeUnique>(this); - typed_visitors_[C64] = MakeUnique>(this); + typed_visitors_[S32] = MakeUnique>(this); + typed_visitors_[S64] = MakeUnique>(this); + typed_visitors_[F16] = + MakeUnique>(this); + typed_visitors_[F32] = MakeUnique>(this); + typed_visitors_[F64] = MakeUnique>(this); + typed_visitors_[C64] = MakeUnique>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). To make evaluator work with BF16, we set all // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. 
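The BF16 visitor registered on the next line is the one case where the storage type and the compute type differ: the template arguments are elided in this rendering, but per the comment the visitor stores bfloat16 elements while doing its elementwise arithmetic in float. The cast-in / cast-out wrapping that makes this work is the same pattern as the ConvertUnaryFunction helper in the new header further below; a generic sketch, with ordinary float/double standing in for the bfloat16/float pair:

#include <functional>

// Wraps a compute-type function so callers only ever see the storage type.
template <typename StorageT, typename ComputeT>
std::function<StorageT(StorageT)> InComputeType(
    const std::function<ComputeT(ComputeT)>& op) {
  return [op](StorageT x) {
    return static_cast<StorageT>(op(static_cast<ComputeT>(x)));
  };
}

For example, InComputeType<float, double>([](double v) { return v * v; }) squares a float by way of a double, the same shape of round trip a bfloat16 op takes through float.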
- typed_visitors_[BF16] = MakeUnique>(this); + typed_visitors_[BF16] = + MakeUnique>(this); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVistor: unhandled primitive type: TUPLE."); + "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); }); typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { return Unimplemented( - "HloEvaluator::TypedVisitor: unhandled primitive type: OPAQUE."); + "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); }); } @@ -2508,6 +478,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { } break; case F16: return Unimplemented("unhandled primitive type: F16."); + case BF16: { + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), opcode, + lhs_literal, rhs_literal)); + } break; case F32: { TF_ASSIGN_OR_RETURN( evaluated_[compare], @@ -2935,9 +910,10 @@ Status HloEvaluator::HandleCall(HloInstruction* call) { } Status HloEvaluator::HandleFusion(HloInstruction* fusion) { + HloModuleConfig config; // Attach cloned computation to an empty HLO module so the existing ones are // not modified. - HloModule empty_hlo_module("EmptyModuleForFusion"); + HloModule empty_hlo_module("EmptyModuleForFusion", config); auto cloned_fused_computation = fusion->fused_instructions_computation()->Clone( /*suffix=*/"clone_with_layout", &empty_hlo_module); @@ -2975,8 +951,8 @@ Status HloEvaluator::HandleConditional(HloInstruction* conditional) { auto* true_computation = conditional->true_computation(); auto* false_computation = conditional->false_computation(); - auto result = Literal::CreateFromShape(conditional->shape()); HloEvaluator embedded_evaluator; + std::unique_ptr result; if (pred.Get({})) { result = embedded_evaluator .Evaluate(*true_computation, @@ -3000,7 +976,7 @@ Status HloEvaluator::HandleSelect(HloInstruction* select) { // If predicate is of scalar type, no element-wise selection would be needed. // This would also handle output array of tuple types as the DefaultAction - // would go through the TypedVisitor which doesn't handle tuples. + // would go through the HloEvaluatorTypedVisitor which doesn't handle tuples. if (ShapeUtil::IsScalar(pred.shape())) { if (pred.Get({})) { evaluated_[select] = on_true.CloneToUnique(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index c0dcee0c3e3..566d53a4142 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -109,19 +110,16 @@ class HloEvaluator : public DfsHloVisitorWithDefault { substitutions); protected: - // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting - // literal type of each evaluated Handle* method of a TypedVisitor. - // There are however a few notable exceptions to this rule, notably: - // - HandleCompare and HandleIsFinite: where the resulting literal type is - // always boolean. - // These operations are handled outside of the parent HloEvaluator handlers - // instead of from within TypedVisitor. + // Make HloEvaluatorTypedVisitor a friend because it is logically part of this + // class. 
// - // Type params: - // - ReturnT: The type of input and output of each operation. - // - ElementwiseT: The type in which internal computation are done. - template - class TypedVisitor; + // A straightforward implementation would be to make it a nested class + // declared and defined in hlo_evaluator.cc. Instead HloEvaluatorTypedVisitor + // lives as a separate class with its own header because its template gets + // instantiated many times and we want to use extern templates to shard out + // the compilation of those instantiations across multiple cc files. + template + friend class HloEvaluatorTypedVisitor; // Wraps around instruction handling to infer types before dispatching to // the corresponding typed Visitor. @@ -168,7 +166,6 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; - private: // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. @@ -183,14 +180,6 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return *(it->second); } - // Map from a primitive type to its associated (templated) DfsHloVisitor. - // Note: the hash function here is only needed because current gcc std::hash - // does not specialize for enum types. This should however be fixed in the - // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 - tensorflow::gtl::FlatMap, - std::hash> - typed_visitors_; - // Tracks the HLO instruction and its evaluated literal result. // TODO(b/35950897): have better memory management here to free instructions // that are no longer a parent for any other subsequent instruction in @@ -199,6 +188,41 @@ class HloEvaluator : public DfsHloVisitorWithDefault { tensorflow::gtl::FlatMap> evaluated_; + private: + template + static StatusOr> ElementWiseUnaryOpImpl( + HloInstruction* instruction, + const std::function& unary_op, + const Literal& operand_literal) { + const auto shape = instruction->shape(); + const auto* operand = instruction->operand(0); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is + // removed. + if (!ShapeUtil::SameDimensions(shape, operand->shape())) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(operand->shape()).c_str()); + } + + auto result = MakeUnique(shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return unary_op(operand_literal.Get(multi_index)); + })); + return std::move(result); + } + + // Map from a primitive type to its associated (templated) DfsHloVisitor. + // Note: the hash function here is only needed because current gcc std::hash + // does not specialize for enum types. This should however be fixed in the + // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 + tensorflow::gtl::FlatMap, + std::hash> + typed_visitors_; + // Caches pointers to input literals, assuming they are in post-order. // Literals are not owned by this class, and they must outlive the lifetime of // each invocation to the Evaluate* method. 
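ElementWiseUnaryOpImpl, hoisted into the header above so that both HloEvaluator and the typed visitor can call it, boils down to a dimension check followed by populating each output element with unary_op applied to the matching operand element. A flat, literal-free analogue of the same shape of code (std::vector in place of Literal, invented names):

#include <cstddef>
#include <functional>
#include <stdexcept>
#include <vector>

// Elementwise unary op over a flat buffer: the size check plays the role of
// the SameDimensions test, the loop the role of Literal::Populate.
template <typename OutT, typename InT>
std::vector<OutT> ElementwiseUnary(const std::vector<InT>& operand,
                                   std::size_t expected_size,
                                   const std::function<OutT(InT)>& unary_op) {
  if (operand.size() != expected_size) {
    throw std::runtime_error("shape mismatch");  // stands in for Unimplemented(...)
  }
  std::vector<OutT> result;
  result.reserve(operand.size());
  for (const InT& v : operand) {
    result.push_back(unary_op(v));
  }
  return result;
}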
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index dd14dd38537..ae5b5e0412e 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -82,9 +82,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, auto element_type = expected->shape().element_type(); if (element_type == F32 || element_type == F64) { ErrorSpec error(aabs); - LiteralTestUtil::ExpectNear(*expected, *result, error); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error)); } else { - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } } @@ -100,7 +100,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface, std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } bool use_bfloat16_; @@ -129,7 +129,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) { auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { @@ -150,7 +150,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select @@ -175,7 +175,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) { auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs @@ -307,7 +307,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies Reshape operation is correctly evaluated. 
@@ -315,7 +315,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, - LiteralTestUtil::CreateRandomLiteral( + Literal::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = @@ -351,7 +351,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) { std::unique_ptr result = Evaluate({}); - LiteralTestUtil::ExpectEqual(*result, *output_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); } TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { @@ -370,7 +370,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { std::unique_ptr result = Evaluate({}); - LiteralTestUtil::ExpectEqual(*result, *output_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal)); } TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { @@ -392,7 +392,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { auto expected = Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { @@ -413,7 +413,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({100, 200}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { @@ -432,7 +432,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { @@ -452,7 +452,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { std::unique_ptr result = Evaluate(); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } PaddingConfig CreatePaddingConfig( @@ -490,7 +490,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto expected = Literal::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { @@ -525,7 +525,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { auto expected = Literal::CreateR4FromArray4D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, NegativePadding2D) { @@ -567,7 +567,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0x1.0P-5))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -606,7 +606,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { auto expected_array = MakeUnique>(0, 9); auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank1) { @@ -651,7 +651,7 @@ 
TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // clang-format on auto expected = Literal::CreateR2FromArray2D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank1AndRank2) { @@ -688,7 +688,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { auto expected = Literal::CreateR1({22.f, 28.f}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DotRank2AndRank2) { @@ -737,7 +737,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { }); auto expected = Literal::CreateR2FromArray2D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SimpleConv1D) { @@ -785,7 +785,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { Array3D expected_array = {{{11.f, 18.f, 9.f}}}; auto expected = Literal::CreateR3FromArray3D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { @@ -827,7 +827,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -847,7 +847,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { // clang-format on auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { @@ -927,7 +927,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { auto expected = Literal::CreateR4FromArray4D( use_bfloat16_ ? expected_array_bf16 : expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { @@ -1004,7 +1004,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { auto expected = Literal::CreateR4FromArray4D( use_bfloat16_ ? 
expected_array_bf16 : expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { @@ -1046,7 +1046,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1067,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { @@ -1109,7 +1109,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1131,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, @@ -1180,7 +1180,7 @@ TEST_P(HloEvaluatorTest, *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(2); + XlaBuilder::CreateDefaultConvDimensionNumbers(2); const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( @@ -1203,7 +1203,7 @@ TEST_P(HloEvaluatorTest, })); auto expected = Literal::CreateR4FromArray4D(expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; @@ -1319,7 +1319,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { auto expected = Literal::CreateR1({6, 18}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowMax) { @@ -1370,7 +1370,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{6, 7}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd) { @@ -1427,7 +1427,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { @@ -1490,7 +1490,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = Literal::CreateFullWithDescendingLayout(output_dims, 8.0f); - LiteralTestUtil::ExpectEqual(*result_literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result)); } TEST_P(HloEvaluatorTest, StridedSlice) { @@ -1523,7 +1523,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { 
{19}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DynamicSlice) { @@ -1556,7 +1556,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { {6, 7, 8}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } // Verifies that the HloEvaluator's implementation goes along with existing @@ -1591,7 +1591,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { {6, 7, 8}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { @@ -1627,7 +1627,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { {5, -6, -7}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SetAndGetTuples) { @@ -1662,7 +1662,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { {5, 6, 7}, }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { @@ -1703,7 +1703,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { result_inner_literal.get(), }); - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, Reverse) { @@ -1756,7 +1756,7 @@ TEST_P(HloEvaluatorTest, Reverse) { }); // clang-format on - LiteralTestUtil::ExpectEqual(*expected, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { @@ -1776,8 +1776,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { add, {{param0, Literal::CreateR1({1, 2, 3, 4}).get()}, {square, Literal::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), - *result.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } // Check that EvaluateWithSubstitutions works if one of the operands to the op @@ -1800,8 +1800,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { auto result = evaluator.EvaluateWithSubstitutions( add, {{square, Literal::CreateR1({10, 20, 30, 40}).get()}}); TF_ASSERT_OK(result.status()); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), - *result.ValueOrDie()); + EXPECT_TRUE(LiteralTestUtil::Equal( + *Literal::CreateR1({11, 22, 33, 44}), *result.ValueOrDie())); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) { @@ -1823,9 +1823,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) { @@ -1847,9 +1847,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 3}, {4, 6}, {7, 9}}), - *Evaluate({operand.get(), gather_indices.get()})); + *Evaluate({operand.get(), 
gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) { @@ -1872,10 +1872,10 @@ ENTRY main { Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR2({{0, 2}, {2, 1}}); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR3( {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}), - *Evaluate({operand.get(), gather_indices.get()})); + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) { @@ -1900,9 +1900,9 @@ ENTRY main { {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = Literal::CreateR2({{0, 0}, {1, 0}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{-1, 1}, {-4, 4}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, 1}, {-4, 4}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, @@ -1928,9 +1928,9 @@ ENTRY main { {{-7, 7}, {-8, 8}, {-9, 9}}}); std::unique_ptr gather_indices = Literal::CreateR2({{0, 0}, {1, 0}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{-2, 2}, {-1, 1}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-2, 2}, {-1, 1}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) { @@ -1952,9 +1952,9 @@ ENTRY main { std::unique_ptr operand = Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{5}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{5}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) { @@ -1977,9 +1977,9 @@ ENTRY main { Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); std::unique_ptr gather_indices = Literal::CreateR2({{2, 1}, {1, 1}}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR3({{{8}}, {{5}}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR3({{{8}}, {{5}}}), + *Evaluate({operand.get(), gather_indices.get()}))); } TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) { @@ -2000,9 +2000,50 @@ ENTRY main { ParseAndVerifyModule(hlo_text); std::unique_ptr operand = Literal::CreateR2({{}, {}, {}}); std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); - LiteralTestUtil::ExpectEqual( - *Literal::CreateR2({{}, {}}), - *Evaluate({operand.get(), gather_indices.get()})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{}, {}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) { + const string hlo_text = R"( +HloModule GatherXd + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + ROOT gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1} +} +)"; + ParseAndVerifyModule(hlo_text); + + std::unique_ptr operand = Literal::CreateR1({0, 1, 2}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0}, {1}}, {{2}, {1}}}); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{0, 1}, {2, 1}}), + *Evaluate({operand.get(), gather_indices.get()}))); +} + +// Verifies that 
HloEvaluator evaluates a HLO instruction that performs +// element-wise comparison with 2 bfloat16 operands. +TEST_P(HloEvaluatorTest, DoesCompareBF16) { + // lhs >= rhs + auto lhs = Literal::CreateR2( + {{bfloat16(0.25), bfloat16(0.35), bfloat16(0.125)}, + {bfloat16(-0.25), bfloat16(-0.35), bfloat16(-0.125)}}); + auto rhs = Literal::CreateR2( + {{bfloat16(0.5), bfloat16(0.125), bfloat16(0.125)}, + {bfloat16(0.25), bfloat16(-0.375), bfloat16(-0.127)}}); + auto expected = + Literal::CreateR2({{false, true, true}, {false, true, true}}); + TestBinaryOp(HloOpcode::kGe, std::move(expected), std::move(lhs), + std::move(rhs)); } INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h new file mode 100644 index 00000000000..024e8751f79 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -0,0 +1,2170 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// TODO(b/79274244): We'd like these type traits to live inside of +// HloEvaluatorTypedVisitor so they don't pollute namespace xla, but that +// crashes clang in the frontend. +// +// Anyway this is relatively safe as-is because hlo_evaluator_typed_visitor.h is +// a "private" header that's not exposed outside of hlo_evaluator.cc. +template +using is_complex_t = std::is_same; +template +using is_complex64_t = std::is_same; + +// Templated DfsHloVisitor for use by HloEvaluator. +// +// Typically ReturnT here indicates the resulting literal type of each evaluated +// Handle* method of a TypedVisitor. There are however a few notable exceptions +// to this rule, notably: +// - HandleCompare and HandleIsFinite: where the resulting literal type is +// always boolean. +// These operations are handled outside of the parent HloEvaluator handlers +// instead of from within TypedVisitor. +// +// Type params: +// - ReturnT: The type of input and output of each operation. +// - ElementwiseT: The type in which internal computation are done. +// +// This a logically a private part of HloEvaluator. It lives in this header +// file rather than in hlo_evaluator.cc because we use extern templates and a +// bunch of independent cc files to speed up compiling the many instantiations +// of this class. 
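The comment block above is the rationale for splitting the visitor into its own header: each HloEvaluatorTypedVisitor instantiation is large, so declaring the instantiations extern in the header and explicitly instantiating each one in its own .cc file lets the build compile them in parallel instead of all at once inside hlo_evaluator.cc. Those shard files are not part of this excerpt; the toy snippet below only illustrates the extern-template mechanism being referred to, with made-up names:

// mini_visitor.h
template <typename T>
class MiniVisitor {
 public:
  T Twice(T x) { return x + x; }
};
extern template class MiniVisitor<float>;   // declared here, compiled elsewhere
extern template class MiniVisitor<double>;

// mini_visitor_float.cc (one shard):
//   #include "mini_visitor.h"
//   template class MiniVisitor<float>;      // the only TU that instantiates it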
+template +class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { + public: + explicit HloEvaluatorTypedVisitor(HloEvaluator* p) : parent_(p) {} + + // The following higher-order functions convert a function with ElementwiseT + // to a function with ReturnT. + std::function ConvertUnaryFunction( + const std::function& unary_op) { + return [&unary_op](ReturnT arg) { + return static_cast(unary_op(static_cast(arg))); + }; + } + std::function ConvertBinaryFunction( + const std::function& + binary_op) { + return [&binary_op](ReturnT arg1, ReturnT arg2) { + return static_cast(binary_op(static_cast(arg1), + static_cast(arg2))); + }; + } + std::function ConvertTernaryFunction( + const std::function& ternary_op) { + return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { + return static_cast(ternary_op(static_cast(arg1), + static_cast(arg2), + static_cast(arg3))); + }; + } + + Status DefaultAction(HloInstruction* hlo_instruction) override { + return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", + HloOpcodeString(hlo_instruction->opcode()).c_str()); + } + + // TODO(b/35950897): many of the stl functions used in the handlers are not + // overloaded for every XLA primitive type. + + template ::value>::type* = + nullptr> + Status HandleAbs(HloInstruction* abs) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return elem_operand; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + return std::abs(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(abs->operand(0)); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[abs], + (HloEvaluator::ElementWiseUnaryOpImpl( + abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, + operand_literal))); + + return Status::OK(); + } + + Status HandleAbs(HloInstruction* abs) override { + // If the operand is of C64 type, the return type of abs will be F32. + // However, ElementwiseT would still be the return type, F32, and thus + // specifying the ElementwiseT explicitly as C64 is needed below. 
+ if (abs->operand(0)->shape().element_type() == C64) { + return HandleAbs(abs); + } + return HandleAbs(abs); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[round], + ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { + return std::round(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + return InvalidArgument("Unsupported type for Round"); + } + + Status HandleRound(HloInstruction* round) override { + return HandleRound(round); + } + + Status HandleBroadcast(HloInstruction* broadcast) override { + const Literal& operand_to_broadcast = + parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); + std::vector broadcast_indices( + ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); + + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(operand_to_broadcast.shape())) + << "broadcast dimensions is of size: " << broadcast->dimensions().size() + << " and rank of operand_to_broadcast is: " + << ShapeUtil::Rank(operand_to_broadcast.shape()); + // Checks that operand's dimensions are the same as the broadcast's + // dimensions along the dimensions to be broadcasted. + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == + operand_to_broadcast.shape().dimensions(i)); + } + + auto output = MakeUnique(broadcast->shape()); + TF_RETURN_IF_ERROR(output->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; + } + return operand_to_broadcast.Get(broadcast_indices); + })); + parent_->evaluated_[broadcast] = std::move(output); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], + ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { + return std::ceil(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + return InvalidArgument("Unsupported type for Ceil"); + } + + Status HandleCeil(HloInstruction* ceil) override { + return HandleCeil(ceil); + } + + Status HandleConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).Convert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); + } + return Status::OK(); + } + + Status HandleBitcastConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).BitcastConvert( + convert->shape().element_type())); + + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + 
parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); + } + return Status::OK(); + } + + Status HandleExp(HloInstruction* exp) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], + ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { + return std::exp(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleExpm1(HloInstruction* expm1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[expm1], + ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + return std::expm1(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleExpm1(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Expm1"); + } + + Status HandleExpm1(HloInstruction* floor) override { + return HandleExpm1(floor); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[floor], + ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { + return std::floor(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Floor"); + } + + Status HandleFloor(HloInstruction* floor) override { + return HandleFloor(floor); + } + + Status HandleLog(HloInstruction* log) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], + ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { + return std::log(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleLog1p(HloInstruction* expm1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[expm1], + ElementWiseUnaryOp(expm1, [](ElementwiseT elem_operand) { + return std::log1p(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleLog1p(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Log1p"); + } + + Status HandleLog1p(HloInstruction* floor) override { + return HandleLog1p(floor); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return ~elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + return InvalidArgument("Unsupported type for Not"); + } + + Status HandleNot(HloInstruction* not_) override { + return HandleNot(not_); + } + + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status 
HandleNegate(HloInstruction* negate) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { + return NativeT(-type(elem_operand)); + })); + return Status::OK(); + } + + template ::value || + std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp( + negate, [](ElementwiseT elem_operand) { return -elem_operand; })); + return Status::OK(); + } + + Status HandleNegate(HloInstruction* negate) override { + return HandleNegate(negate); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return (ElementwiseT(0) < elem_operand) - + (elem_operand < ElementwiseT(0)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + auto abs_val = std::abs(elem_operand); + return 0 == abs_val ? ElementwiseT(0) + : elem_operand / abs_val; + })); + return Status::OK(); + } + + Status HandleSign(HloInstruction* sign) override { + return HandleSign(sign); + } + + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], + ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return std::atan2(lhs_elem, rhs_elem); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + return InvalidArgument("Unsupported type for Atan2"); + } + + Status HandleAtan2(HloInstruction* atan2) override { + return HandleAtan2(atan2); + } + + Status HandleTanh(HloInstruction* tanh) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], + ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { + return std::tanh(elem_operand); + })); + return Status::OK(); + } + + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return NativeT(type(lhs_elem) * type(rhs_elem)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + std::is_floating_point::value || + is_complex_t::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem * rhs_elem; + })); + return Status::OK(); + } + + Status HandleMultiply(HloInstruction* multiply) override { + return HandleMultiply(multiply); + } + + Status HandleSubtract(HloInstruction* subtract) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[subtract], + ElementWiseBinaryOp(subtract, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem - rhs_elem; + })); + return Status::OK(); + } + + Status HandleAdd(HloInstruction* add) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], + ElementWiseBinaryOp(add, 
[](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem + rhs_elem; + })); + return Status::OK(); + } + + Status HandleDivide(HloInstruction* divide) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem / rhs_elem; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return std::max(lhs, rhs); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[maximum], + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { + return ((lhs >= rhs) || std::isnan(lhs)) ? lhs : rhs; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + return InvalidArgument("Unsupported type for Maximum"); + } + + Status HandleMaximum(HloInstruction* maximum) override { + return HandleMaximum(maximum); + } + + template ::value>::type* = + nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::min(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return ((lhs_el <= rhs_el) || std::isnan(lhs_el)) ? 
lhs_el : rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + return InvalidArgument("Unsupported type for Minimum"); + } + + Status HandleMinimum(HloInstruction* minimum) override { + return HandleMinimum(minimum); + } + + Status HandlePower(HloInstruction* power) override { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], + ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::pow(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmod(lhs_el, rhs_el); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + return InvalidArgument("Unsupported type for Remainder"); + } + + Status HandleRemainder(HloInstruction* remainder) override { + return HandleRemainder(remainder); + } + + template ::value>::type* = + nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el & rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el && rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + return InvalidArgument("Unsupported type for And"); + } + + Status HandleAnd(HloInstruction* and_) override { + return HandleAnd(and_); + } + + template ::value>::type* = + nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el | rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { + return lhs_el || rhs_el; + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + return InvalidArgument("Unsupported type for Or"); + } + + Status HandleOr(HloInstruction* or_) override { + return HandleOr(or_); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftLeft(HloInstruction* shl) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shl], + ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { + return IsShiftOutOfBounds(rhs_elem) ? 
0 + : (lhs_elem << rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftLeft(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftLeft"); + } + + Status HandleShiftLeft(HloInstruction* shl) override { + return HandleShiftLeft(shl); + } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr) { + typedef typename std::make_signed::type SignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + SignedT lhs_signed = static_cast(lhs_elem); + if (IsShiftOutOfBounds(rhs_elem)) { + return lhs_signed < 0 ? static_cast(-1) : 0; + } else { + return lhs_signed >> rhs_elem; + } + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightArithmetic(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + } + + Status HandleShiftRightArithmetic(HloInstruction* shra) override { + return HandleShiftRightArithmetic(shra); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightLogical(HloInstruction* shr) { + typedef typename std::make_unsigned::type UnsignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + // If shift amount is greater than the number of bits, then return 0. + if (IsShiftOutOfBounds(rhs_elem)) { + return static_cast(0); + } + return static_cast(static_cast(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightLogical(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightLogical"); + } + + Status HandleShiftRightLogical(HloInstruction* shrl) override { + return HandleShiftRightLogical(shrl); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleClamp(HloInstruction* clamp) { + std::function + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { + return std::fmin(high, std::fmax(value, low)); + }; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleClamp(HloInstruction*) { + return InvalidArgument("Unsupported type for Clamp"); + } + + Status HandleClamp(HloInstruction* clamp) override { + return HandleClamp(clamp); + } + + Status HandleSelect(HloInstruction* select) override { + CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape())); + CHECK(!ShapeUtil::IsTuple(select->shape())); + std::function select_op = + [](bool pred, ReturnT on_true, ReturnT on_false) { + if (pred) { + return on_true; + } + return on_false; + }; + TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], + ElementwiseTernaryOp(select, std::move(select_op))); + return Status::OK(); + } + + Status HandleReverse(HloInstruction* reverse) override { + const auto result_shape = reverse->shape(); + const auto reverse_dimensions = reverse->dimensions(); + + auto operand = reverse->operand(0); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferReverseShape(operand->shape(), + reverse_dimensions)); + + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << 
"return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + auto result = MakeUnique(result_shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice out_index) { + std::vector from_index(out_index.begin(), out_index.end()); + for (const int64 dim : reverse_dimensions) { + from_index[dim] = result_shape.dimensions(dim) - 1 - out_index[dim]; + } + return operand_literal.Get(from_index); + })); + + parent_->evaluated_[reverse] = std::move(result); + return Status::OK(); + } + + Status HandleConvolution(HloInstruction* conv) override { + auto lhs = conv->operand(0); + auto rhs = conv->operand(1); + const auto& window = conv->window(); + const Shape& result_shape = conv->shape(); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + + TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); + TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); + CHECK(ShapeUtil::IsArray(lhs_shape)); + CHECK(ShapeUtil::IsArray(rhs_shape)); + CHECK(ShapeUtil::SameElementType(lhs_shape, rhs_shape)); + CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); + + const auto& dnums = conv->convolution_dimension_numbers(); + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); + CHECK_GE(num_spatial_dims, 0); + CHECK_EQ(window.dimensions_size(), num_spatial_dims); + + const auto lhs_rank = ShapeUtil::Rank(lhs_shape); + const auto rhs_rank = ShapeUtil::Rank(rhs_shape); + + CHECK_EQ(num_spatial_dims + 2, lhs_rank); + CHECK_EQ(num_spatial_dims + 2, rhs_rank); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, + window, dnums)); + CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + std::vector window_dimension_sizes; + for (auto i : dnums.kernel_spatial_dimensions()) { + window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); + } + + const Shape& window_shape = + ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes); + + DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape); + DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape); + + auto lhs_literal_data = lhs_literal.data(); + auto rhs_literal_data = rhs_literal.data(); + + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, + &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, + rhs_literal_data]( + tensorflow::gtl::ArraySlice out_index) { + // Dimension number applicable for input (lhs). + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_z_dim = dnums.input_feature_dimension(); + // Dimension number applicable for kernel (rhs). + const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); + const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); + // Dimension number applicable for output. 
+ const int64 output_batch_dim = dnums.output_batch_dimension(); + const int64 output_z_dim = dnums.output_feature_dimension(); + + const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + + ElementwiseT result_val = static_cast(0); + DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), + 0); + + // Convolve input feature with kernel. + do { + for (int64 iz = 0; iz < z_size; ++iz) { + int64 lhs_linear_index = 0; + lhs_linear_index += out_index[output_batch_dim] * + lhs_dim_multipliers[input_batch_dim]; + lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim]; + + int64 rhs_linear_index = 0; + rhs_linear_index += out_index[output_z_dim] * + rhs_dim_multipliers[kernel_output_z_dim]; + rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; + + // Find corresponding spatial dimension index for input (lhs). + for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { + // Spatial dimension number for input (lhs) and output. + const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = + dnums.output_spatial_dimensions(ki); + + // Calculate lhs (input) index without taking base dilation into + // account. + const auto& window_dim = window.dimensions(ki); + const int64 undilated_index = + out_index[output_spatial_dim] * window_dim.stride() - + window_dim.padding_low() + + rhs_spatial_index[ki] * window_dim.window_dilation(); + // Skip if the lhs (input) index is to be dilated. As an + // optimization, skip this mod if there's no dilation. + if (window_dim.base_dilation() > 1 && + undilated_index % window_dim.base_dilation() != 0) { + goto cnt; + } + + // Calculate the actual lhs (input) index after dilation. As an + // optimization, skip this integer divide if there's no dilation. + int64 lhs_spatial_index; + if (window_dim.base_dilation() > 1) { + lhs_spatial_index = undilated_index / window_dim.base_dilation(); + } else { + lhs_spatial_index = undilated_index; + } + lhs_linear_index += + lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim]; + + // Skip if input index is not in bounds. + if (!(lhs_spatial_index >= 0 && + lhs_spatial_index < + lhs_shape.dimensions(input_spatial_dim))) { + goto cnt; + } + + rhs_linear_index += + (window_dim.window_reversal() + ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]) * + rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)]; + } + + result_val += + static_cast(lhs_literal_data[lhs_linear_index]) * + static_cast(rhs_literal_data[rhs_linear_index]); + } + cnt : {} + } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); + + return static_cast(result_val); + }; + + auto result = MakeUnique(result_shape); + TF_RETURN_IF_ERROR(result->PopulateParallel(func)); + + parent_->evaluated_[conv] = std::move(result); + return Status::OK(); + } + + Status HandleDot(HloInstruction* dot) override { + auto lhs = dot->operand(0); + auto rhs = dot->operand(1); + CHECK(ShapeUtil::IsArray(dot->shape())); + CHECK(ShapeUtil::IsArray(lhs->shape())); + CHECK(ShapeUtil::IsArray(rhs->shape())); + + const auto& dnums = dot->dot_dimension_numbers(); + + const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); + const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); + + CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); + CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); + + // There must be 1 and only 1 Contracting dimension for lhs and rhs. 
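+    // (For example, a plain [m, k] x [k, n] matrix multiply has
+    // lhs_contracting_dimensions = {1} and rhs_contracting_dimensions = {0},
+    // both of extent k.)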
+ CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); + const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + // Contracted dimension sizes must be the same. + CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), + rhs->shape().dimensions(rhs_contracting_dimension)) + << "lhs contracted dimension: " + << lhs->shape().dimensions(lhs_contracting_dimension) + << " rhs contracted dimension: " + << rhs->shape().dimensions(rhs_contracting_dimension); + const int64 contracted_dimension_size = + lhs->shape().dimensions(lhs_contracting_dimension); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + + std::vector lhs_non_contracting_dims; + for (int64 i = 0; i < lhs_rank; i++) { + if (i != lhs_contracting_dimension) { + lhs_non_contracting_dims.push_back(i); + } + } + + std::vector rhs_non_batch_non_contracting_dims; + tensorflow::gtl::FlatSet batch_dims_set( + dnums.rhs_batch_dimensions().begin(), + dnums.rhs_batch_dimensions().end()); + for (int64 i = 0; i < rhs_rank; i++) { + if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { + rhs_non_batch_non_contracting_dims.push_back(i); + } + } + + const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); + const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); + + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); + auto result = MakeUnique(dot->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice result_index) { + ElementwiseT result_val = static_cast(0); + + // Find the corresponding non-contracting indices for lhs and rhs. + // + // For `result_index`, its batch dimension, if exists, will be at the + // same dimension as the batch dimension of lhs and rhs. More + // specifically: + // - For lhs, the non-contracting dimensions, including the batch + // dimension have the same index as the `result_index`. + // - For rhs, the batch dimension is set seperately from other + // non-contracting dimensions, since these other non-contracting + // dimensions in rhs follow the non-contracting dimensions of lhs in + // the resulting index. + // + // As an example, for a resulting index: + // result_index [result_batch, result_x, result_y] + // the effecting lhs and rhs indices are: + // lhs [result_batch, lhs_non_contracting_dim, contracting_dim + // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] + // `result_x` is only affected by the lhs_non_contracting_dim and + // likewise `result_y` only depends on rhs_non_contracting_dim. + // + // so we can look up the lhs and rhs indices by: + // + // lhs: + // batch index is the same as `result_batch`. + // non-contracting dimension is the same as + // result_index[lhs_non_contracting_dim] + // rhs: + // batch index: the same as `result_batch`. + // non-contracting dimension index: *not* the same as + // result_index[rhs_non_contractng_dim], since the + // non-contracting dimensions of lhs are included in the + // result_index first. 
Instead, the non_contracting_dim of rhs must + // be calculated as following: + // lhs_non_contracting_dimensions_size + + // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 + // + // Note that (rhs_non_batch_contracting_dim - batch_dim_size) is + // the index offset to the result_index that only depends on + // the non_batch and non-contracting dimensions of rhs. -1 at the + // end translates size to index. + for (auto i : lhs_non_contracting_dims) { + lhs_index[i] = result_index[i]; + } + for (auto i : dnums.rhs_batch_dimensions()) { + rhs_index[i] = result_index[i]; + } + for (auto i : rhs_non_batch_non_contracting_dims) { + const int64 rhs_non_batch_non_contracting_dim = + lhs_non_contracting_size + (i - batch_dim_size) - 1; + rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; + } + + // Accumulates resulting product along the contracted dimension. + for (int64 i = 0; i < contracted_dimension_size; ++i) { + lhs_index[lhs_contracting_dimension] = i; + rhs_index[rhs_contracting_dimension] = i; + + result_val += + static_cast(lhs_literal.Get(lhs_index)) * + static_cast(rhs_literal.Get(rhs_index)); + } + + return static_cast(result_val); + })); + + parent_->evaluated_[dot] = std::move(result); + return Status::OK(); + } + + Status HandlePad(HloInstruction* pad) override { + CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); + // Padding value must be scalar. + CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape())); + CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()), + pad->padding_config().dimensions_size()); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferPadShape( + /*operand_shape=*/pad->operand(0)->shape(), + /*padding_value_shape=*/pad->operand(1)->shape(), + /*padding_config=*/pad->padding_config())); + CHECK(ShapeUtil::Compatible(pad->shape(), inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(pad->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + // Create new HLO of padded shape with padding value. + ReturnT scalar = + parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get({}); + auto result = MakeUnique(pad->shape()); + TF_RETURN_IF_ERROR(result->Populate( + [&scalar](tensorflow::gtl::ArraySlice multi_index) { + return scalar; + })); + + const Literal& evaluated_operand = + parent_->GetEvaluatedLiteralFor(pad->operand(0)); + + std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), + 0); + std::vector target_index(ShapeUtil::Rank(result->shape()), 0); + + // Loop through each element of the operand, assign them to the + // corresponding index of the resulting padded literal. + const PaddingConfig& pad_config = pad->padding_config(); + + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + for (auto i = 0; i < input_index.size(); ++i) { + // Interior padding occurs logically before edge padding, so in the case + // of negative edge padding elements are removed from the + // interior-padded operand. + target_index[i] = + pad_config.dimensions(i).edge_padding_low() + + input_index[i] * (pad_config.dimensions(i).interior_padding() + 1); + + // Account for negative low and high padding: skip assignment if the + // any target index is out of range. 
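+      // (For example, with edge_padding_low = -1 and interior_padding = 0,
+      // input index 0 maps to target index -1 and is skipped, while input
+      // index 1 maps to target index 0.)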
+ if (!(target_index[i] >= 0 && + target_index[i] < pad->shape().dimensions(i))) { + return true; + } + } + result->Set(target_index, + evaluated_operand.Get(input_index)); + return true; + }; + + std::vector zero_base(evaluated_operand.shape().dimensions_size(), + 0); + std::vector step(evaluated_operand.shape().dimensions_size(), 1); + + ShapeUtil::ForEachIndex( + evaluated_operand.shape(), zero_base, + AsInt64Slice(evaluated_operand.shape().dimensions()), step, func); + + parent_->evaluated_[pad] = std::move(result); + return Status::OK(); + } + + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { + auto operand = dynamic_slice->operand(0); + auto start_indices = dynamic_slice->operand(1); + auto result_shape = dynamic_slice->shape(); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferDynamicSliceShape( + operand->shape(), start_indices->shape(), + dynamic_slice->dynamic_slice_sizes())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(result_shape) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + TF_RET_CHECK( + primitive_util::IsIntegralType(start_indices->shape().element_type())); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& start_indices_literal = + parent_->GetEvaluatedLiteralFor(start_indices); + + switch (start_indices->shape().element_type()) { + case S32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case S64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case U32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + case U64: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_slice], + DynamicSlice(operand_literal, start_indices_literal, + result_shape)); + } break; + default: + LOG(FATAL) << "HandleDynamicSlice: unhandled primitive type for " + "start_indices: " + << PrimitiveType_Name(start_indices->shape().element_type()); + } + + return Status::OK(); + } + + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override { + auto operand = dynamic_update_slice->operand(0); + auto update = dynamic_update_slice->operand(1); + auto start_indices = dynamic_update_slice->operand(2); + auto result_shape = dynamic_update_slice->shape(); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferDynamicUpdateSliceShape( + operand->shape(), update->shape(), start_indices->shape())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(result_shape) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + TF_RET_CHECK( + primitive_util::IsIntegralType(start_indices->shape().element_type())); + TF_RET_CHECK(ShapeUtil::Compatible(result_shape, operand->shape())); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& update_literal = parent_->GetEvaluatedLiteralFor(update); + const Literal& start_indices_literal = + parent_->GetEvaluatedLiteralFor(start_indices); + + switch (start_indices->shape().element_type()) { + case S32: { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[dynamic_update_slice], + 
+            DynamicUpdateSlice<int32>(operand_literal, update_literal,
+                                      start_indices_literal));
+      } break;
+      case S64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<int64>(operand_literal, update_literal,
+                                      start_indices_literal));
+      } break;
+      case U32: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<uint32>(operand_literal, update_literal,
+                                       start_indices_literal));
+      } break;
+      case U64: {
+        TF_ASSIGN_OR_RETURN(
+            parent_->evaluated_[dynamic_update_slice],
+            DynamicUpdateSlice<uint64>(operand_literal, update_literal,
+                                       start_indices_literal));
+      } break;
+      default:
+        LOG(FATAL) << "HandleDynamicUpdateSlice: unhandled primitive type for "
+                      "start_indices: "
+                   << PrimitiveType_Name(start_indices->shape().element_type());
+    }
+
+    return Status::OK();
+  }
+
+  template <typename NativeT>
+  StatusOr<std::unique_ptr<Literal>> MapImpl(HloInstruction* map) {
+    auto operands = map->operands();
+    HloComputation* computation = map->to_apply();
+
+    auto result = MakeUnique<Literal>(map->shape());
+
+    HloEvaluator embedded_evaluator(parent_->max_loop_iterations_);
+    TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+        [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
+          std::vector<std::unique_ptr<Literal>> arg_literals;
+          arg_literals.reserve(operands.size());
+
+          // Construct scalar literal parameters to be passed to the map
+          // computation.
+          for (auto operand : operands) {
+            const Literal& arg_literal =
+                parent_->GetEvaluatedLiteralFor(operand);
+
+            auto curr_val = arg_literal.Get<NativeT>(multi_index);
+            auto curr_val_literal = Literal::CreateR0<NativeT>(curr_val);
+
+            arg_literals.push_back(std::move(curr_val_literal));
+          }
+
+          std::unique_ptr<Literal> computed_result =
+              embedded_evaluator
+                  .Evaluate<std::unique_ptr<Literal>>(*computation,
+                                                      arg_literals)
+                  .ConsumeValueOrDie();
+          // Clear visit states so that we can use the evaluator again on
+          // the same computation.
+ embedded_evaluator.ResetVisitStates(); + + return computed_result->Get({}); + })); + return std::move(result); + } + + Status HandleMap(HloInstruction* map) override { + switch (map->operand(0)->shape().element_type()) { + case PRED: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F16: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], + MapImpl(map)); + break; + } + case F32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case C64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + default: + LOG(FATAL) << "HandleMap: unhandled primitive type for " + "input operand: " + << PrimitiveType_Name( + map->operand(0)->shape().element_type()); + } + + return Status::OK(); + } + + Status HandleReduce(HloInstruction* reduce) override { + auto arg = reduce->operand(0); + auto init_value = reduce->operand(1); + tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + HloComputation* function = reduce->to_apply(); + TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == + ShapeUtil::Rank(arg->shape()) - dimensions.size()); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferReduceShape( + /*arg=*/arg->shape(), + /*init_value=*/init_value->shape(), + /*dimensions_to_reduce=*/dimensions, + /*to_apply=*/function->ComputeProgramShape())); + TF_RET_CHECK(ShapeUtil::Compatible(reduce->shape(), inferred_return_shape)) + << "return shape is set to: " << ShapeUtil::HumanString(reduce->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& arg_literal = parent_->GetEvaluatedLiteralFor(arg); + VLOG(3) << "HandleReduce arg_literal: " << arg_literal.ToString(); + const Literal& init_literal = parent_->GetEvaluatedLiteralFor(init_value); + VLOG(3) << "HandleReduce init_literal: " << init_literal.ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + const auto arg_dimensions = AsInt64Slice(arg_literal.shape().dimensions()); + std::vector arg_dim_steps(arg_dimensions.size()); + std::vector arg_dim_counts(arg_dimensions.size()); + for (const int64 dim : dimensions) { + arg_dim_steps[dim] = 1; + arg_dim_counts[dim] = arg_dimensions[dim]; + } + + // Map each dimension in the result to a dimension in arg that isn't + // being reduced. + std::vector result_to_arg_index; + for (int64 i = 0; i < arg_dimensions.size(); ++i) { + if (arg_dim_steps[i] == 0) { + result_to_arg_index.push_back(i); + } + } + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce->shape()); + // For each resulting dimension, calculate and assign computed value. 
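+    // (For example, reducing an f32[4,5,6] operand over dimensions = {1}
+    // yields an f32[4,6] result; result_to_arg_index is then {0, 2}, and each
+    // output element accumulates the 5 values along the reduced dimension.)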
+ TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + ReturnT result_val = init_scalar; + + std::vector base(arg_dimensions.size()); + for (int64 i = 0; i < multi_index.size(); ++i) { + base[result_to_arg_index[i]] = multi_index[i]; + } + + // When the reduction is addition of floats, accumulate in a double + // for better precision. Also, avoid creating Literals for the + // intermediate results; it's much faster. + if (ShapeUtil::ElementIsFloating(init_literal.shape()) && + IsScalarAdd(function)) { + double computed_result = 0; + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + computed_result += arg_literal.Get(input_index); + return true; + }; + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return static_cast(computed_result); + } + auto func = [&](tensorflow::gtl::ArraySlice input_index) { + auto curr_val = arg_literal.Get(input_index); + + // Evaluate computation with specified literal operands. + auto curr_val_literal = Literal::CreateR0(curr_val); + auto result_val_literal = Literal::CreateR0(result_val); + std::vector args = {result_val_literal.get(), + curr_val_literal.get()}; + + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*function, args) + .ConsumeValueOrDie(); + // Clear visit states so that we can use the evaluator again on + // the same computation. + embedded_evaluator.ResetVisitStates(); + // Assign computed result to result_val. + result_val = computed_result->Get({}); + return true; + }; + // Computes one element of the result, reducing all dimensions that + // contribute to that element. + ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, + arg_dim_steps, func); + return result_val; + })); + + parent_->evaluated_[reduce] = std::move(result); + return Status::OK(); + } + + bool IsScalarAdd(HloComputation* computation) { + HloInstruction* instruction = computation->root_instruction(); + if (instruction->opcode() == HloOpcode::kAdd && + computation->num_parameters() == 2) { + const HloInstruction* lhs = instruction->operand(0); + const HloInstruction* rhs = instruction->operand(1); + return lhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(lhs->shape()) && + rhs->opcode() == HloOpcode::kParameter && + ShapeUtil::IsScalar(rhs->shape()) && lhs != rhs; + } + return false; + } + + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { + auto operand = select_and_scatter->operand(0); + auto source = select_and_scatter->operand(1); + const Window& window = select_and_scatter->window(); + + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + auto result = MakeUnique(select_and_scatter->shape()); + + // Initialize result array with the init value. 
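+    // (In the common max-select / add-scatter configuration, for example, the
+    // init value is 0, so any output element that never receives a scattered
+    // source value stays 0.)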
+ TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice output_index) { + return init_scalar; + })); + + std::vector window_dimension_sizes; + for (const auto& window_dimension : window.dimensions()) { + window_dimension_sizes.push_back(window_dimension.size()); + } + const Shape window_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), window_dimension_sizes); + + HloComputation* select = select_and_scatter->select(); + HloComputation* scatter = select_and_scatter->scatter(); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); + + int64 rank = ShapeUtil::Rank(operand_literal.shape()); + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + DimensionVector source_index(rank, 0); + + // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid + // dynamic memory allocations. + auto curr_val_literal = Literal::CreateR0(ReturnT()); + auto selected_val_literal = Literal::CreateR0(ReturnT()); + auto source_literal_scatter = Literal::CreateR0(ReturnT()); + auto scattered_literal = Literal::CreateR0(ReturnT()); + do { + // For each element in `source`, we place a window in `operand`. For each + // window placement, we iterate inside the window twice: + // + // 1. Find the selected index by applying `select` function to all + // elements. E.g., If the `select` function is GreaterEqual, the first + // iteration through the window finds the biggest value and returns its + // index. + // + // 2. Using the selected index, scatter value from `source` to result. We + // do this by iterating through the window, and compare each index with + // the selected index. + tensorflow::gtl::optional selected_val; + tensorflow::gtl::optional> selected_index; + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + if (!selected_val) { + selected_val = curr_val; + selected_index = operand_index; + } + curr_val_literal->Set({}, curr_val); + selected_val_literal->Set({}, *selected_val); + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate( + *select, + {selected_val_literal.get(), curr_val_literal.get()}) + .ConsumeValueOrDie(); + bool selected = !computed_result->Get({}); + if (selected) { + selected_val = curr_val; + selected_index = operand_index; + } + embedded_evaluator.ResetVisitStates(); + }); + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + if (std::equal(operand_index.begin(), operand_index.end(), + selected_index->begin())) { + auto source = source_literal.Get(source_index); + auto scattered = result->Get(operand_index); + source_literal_scatter->Set({}, source); + scattered_literal->Set({}, scattered); + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate(*scatter, + {source_literal_scatter.get(), + scattered_literal.get()}) + .ConsumeValueOrDie(); + result->Set(operand_index, computed_result->Get({})); + // Clear visit states so that the we can use the evaluator again + // on the same computation. 
+ embedded_evaluator.ResetVisitStates(); + } + }); + } while (IndexUtil::BumpIndices(source->shape(), &source_index)); + + parent_->evaluated_[select_and_scatter] = std::move(result); + return Status::OK(); + } + + Status HandleReduceWindow(HloInstruction* reduce_window) override { + auto operand = reduce_window->operand(0); + const Window& window = reduce_window->window(); + HloComputation* function = reduce_window->to_apply(); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferReduceWindowShape( + /*operand_shape=*/reduce_window->operand(0)->shape(), + /*init_value=*/reduce_window->operand(1)->shape(), window, + /*to_apply_shape=*/function->ComputeProgramShape())); + TF_RET_CHECK( + ShapeUtil::Compatible(reduce_window->shape(), inferred_return_shape)) + << "return shape is set to: " + << ShapeUtil::HumanStringWithLayout(reduce_window->shape()) + << "but is inferred to be: " + << ShapeUtil::HumanStringWithLayout(inferred_return_shape); + + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(reduce_window->operand(0)); + VLOG(3) << "HandleReduceWindow arg_literal: " << operand_literal.ToString(); + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(reduce_window->operand(1)); + VLOG(3) << "HandleReduceWindow init_literal: " << init_literal.ToString(); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + // Creates a Shape object from window, for iteration below. + std::vector window_dimension_sizes; + for (const auto& window_dimension : window.dimensions()) { + window_dimension_sizes.push_back(window_dimension.size()); + } + const Shape window_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), window_dimension_sizes); + + DimensionVector window_index(window.dimensions_size()); + DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + + HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); + auto result = MakeUnique(reduce_window->shape()); + // For each resulting dimension, calculate and assign computed value. + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice output_index) { + ReturnT result_val = init_scalar; + + std::fill(window_index.begin(), window_index.end(), 0); + std::fill(operand_index.begin(), operand_index.end(), 0); + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), output_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + + // Evaluate computation with specified literal operands. + const auto curr_val_literal = + Literal::CreateR0(curr_val); + const auto result_val_literal = + Literal::CreateR0(result_val); + const std::vector args = { + result_val_literal.get(), curr_val_literal.get()}; + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*function, args) + .ConsumeValueOrDie(); + + // Clear visit states so that the we can use the evaluate again + // on the same computation. 
+ embedded_evaluator.ResetVisitStates(); + + result_val = computed_result->Get({}); + }); + + return result_val; + })); + + parent_->evaluated_[reduce_window] = std::move(result); + return Status::OK(); + } + + Status HandleSlice(HloInstruction* slice) override { + auto operand = slice->operand(0); + const Shape& shape = slice->shape(); + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferSliceShape( + operand->shape(), slice->slice_starts(), + slice->slice_limits(), slice->slice_strides())); + TF_RET_CHECK(ShapeUtil::Compatible(shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const int64 rank = ShapeUtil::Rank(operand->shape()); + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + auto func = [&](tensorflow::gtl::ArraySlice out_index) { + DimensionVector operand_index(rank); + for (int64 i = 0; i < rank; ++i) { + operand_index[i] = + slice->slice_starts(i) + out_index[i] * slice->slice_strides(i); + } + return operand_literal.Get(operand_index); + }; + + auto result = Literal::CreateFromDimensions( + shape.element_type(), AsInt64Slice(shape.dimensions())); + TF_RETURN_IF_ERROR(result->Populate(func)); + parent_->evaluated_[slice] = std::move(result); + return Status::OK(); + } + + // Enable CLZ only for int32, uint32, int64 and uint64. + template < + typename NativeT, + typename std::enable_if< + (std::is_floating_point::value || + std::is_integral::value || is_complex_t::value) && + !(std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value)>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + return InvalidArgument("Unsupported type for Clz"); + } + + template ::value || + std::is_same::value>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], + ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { + return 31 - tensorflow::Log2Floor(elem_operand); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = nullptr> + Status HandleClz(HloInstruction* clz) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], + ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { + return 63 - tensorflow::Log2Floor64(elem_operand); + })); + return Status::OK(); + } + + Status HandleClz(HloInstruction* clz) override { + return HandleClz(clz); + } + + template ::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], + ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { + return std::sin(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + return InvalidArgument("Unsupported type for Sin"); + } + + Status HandleSin(HloInstruction* sin) override { + return HandleSin(sin); + } + + template ::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], + ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { + return std::cos(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + return InvalidArgument("Unsupported type for Cos"); + } + + Status HandleCos(HloInstruction* cos) override { + 
return HandleCos(cos); + } + + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[reduce_precision], + ElementWiseUnaryOp(reduce_precision, [reduce_precision]( + ElementwiseT elem) { + uint32_t value_as_int = tensorflow::bit_cast(elem); + const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); + const uint32_t exponent_bits = reduce_precision->exponent_bits(); + + // Code is based on the CPU/GPU implementation in LLVM-emitting code. + // + // Bits in float type: + // mantissa : bits [0:22] + // exponent : bits [23:30] + // sign : bits [31] + if (mantissa_bits < 23) { + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. + // This is equal to a base value of 0111... plus one bit if the last + // remaining mantissa bit is 1. + const uint32_t base_rounding_bias = + (last_mantissa_bit_mask >> 1) - 1; + const uint32_t x_last_mantissa_bit = + (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); + const uint32_t x_rounding_bias = + x_last_mantissa_bit + base_rounding_bias; + + // Add rounding bias, and mask out truncated bits. Note that the + // case where adding the rounding bias overflows into the exponent + // bits is correct; the non-masked mantissa bits will all be zero, + // and the exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + value_as_int = value_as_int + x_rounding_bias; + value_as_int = value_as_int & truncation_mask; + } + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the + // most- significant bit -- is equal to 1.0f for all exponent sizes. + // Adding 2^(n-1)-1 to this gives us the highest non-infinite + // exponent for a bit- size of n, and subtracting 2^(n-1)-1 from + // this gives us the lowest' exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n + // is (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = + (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; + const bool x_overflows = x_exponent > (reduced_max_exponent << 23); + const bool x_underflows = + x_exponent <= (reduced_min_exponent << 23); + + // Compute appropriately-signed values of zero and infinity. + const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; + const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; + + // Force to zero or infinity if overflow or underflow. (Note that + // this truncates all denormal values to zero, rather than rounding + // them.) + value_as_int = x_overflows ? x_signed_inf : value_as_int; + value_as_int = x_underflows ? 
x_signed_zero : value_as_int; + } + + float reduced_result = tensorflow::bit_cast(value_as_int); + if (std::isnan(elem)) { + reduced_result = mantissa_bits > 0 + ? elem + : std::numeric_limits::infinity(); + } + return reduced_result; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Double not supported for reduce precision"); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Unsupported type for reduce precision"); + } + + Status HandleReducePrecision(HloInstruction* reduce_precision) override { + return HandleReducePrecision(reduce_precision); + } + + private: + // Creates a vector of multipliers which can be used to create a linear index + // into shape. + // + // Given the multidimensional index {i1, ..., iN} and + // M = MakeDimMultipliers(shape), the corresponding linear index LI is simply + // + // LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N]. + // + // This lets you calculate LI given the multidimensional indices in any order. + static DimensionVector MakeDimMultipliers(const Shape& shape) { + DimensionVector v(ShapeUtil::Rank(shape)); + int64 scale = 1; + for (auto dim : LayoutUtil::MinorToMajor(shape)) { + v[dim] = scale; + scale *= shape.dimensions(dim); + } + return v; + } + + // For one particular placement of a window in a base shape (the placement is + // represented as `window_count_index`), iterates inside the window. + // Translates the window index into base index. If the base index is within + // bound, call `f` with the base index. + static void IterateThroughWindow( + const Shape& window_shape, const Window& window, const Shape& base_shape, + const tensorflow::gtl::ArraySlice& window_count_index, + const std::function&)>& f) { + const int64 rank = ShapeUtil::Rank(base_shape); + DimensionVector window_index(rank); + std::fill(window_index.begin(), window_index.end(), 0); + do { + std::vector base_index(rank); + bool out_of_bound = false; + for (int64 i = 0; i < rank; ++i) { + base_index[i] = window_count_index[i] * window.dimensions(i).stride() + + window_index[i] - window.dimensions(i).padding_low(); + if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { + out_of_bound = true; + break; + } + } + if (!out_of_bound) { + f(base_index); + } + } while (IndexUtil::BumpIndices(window_shape, &window_index)); + } + + template + StatusOr> DynamicSlice( + const Literal& operand_literal, const Literal& start_indices_literal, + const Shape& result_shape) { + auto start_indices_typed = start_indices_literal.data(); + std::vector start(start_indices_typed.begin(), + start_indices_typed.end()); + + // Clamp the start indices so the slice is in-bounds w.r.t the operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. 
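+    // (For example, slicing 4 elements from a dimension of size 10 with a
+    // requested start index of 9 clamps the start to min(max(0, 9), 10 - 4),
+    // i.e. 6, so the slice stays in bounds.)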
+ for (int64 i = 0; i < start.size(); ++i) { + start[i] = std::min( + std::max(0LL, start[i]), + operand_literal.shape().dimensions(i) - result_shape.dimensions(i)); + } + + std::vector operand_indices(start.size()); + auto result = MakeUnique(result_shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + for (int64 i = 0; i < operand_indices.size(); ++i) { + CHECK_GE(multi_index[i] + start[i], 0); + operand_indices[i] = multi_index[i] + start[i]; + } + + auto result = operand_literal.Get(operand_indices); + return result; + })); + + return std::move(result); + } + + template + StatusOr> DynamicUpdateSlice( + const Literal& operand_literal, const Literal& update_literal, + const Literal& start_indices_literal) { + auto result = operand_literal.CloneToUnique(); + auto start_indices_typed = start_indices_literal.data(); + const auto rank = ShapeUtil::Rank(result->shape()); + std::vector start(start_indices_typed.begin(), + start_indices_typed.end()); + // Clamp the update start indices so the slice is in-bounds w.r.t the + // operand. + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to oficially document different behavior. + for (int64 i = 0; i < rank; ++i) { + start[i] = std::min( + std::max(0, start[i]), + result->shape().dimensions(i) - update_literal.shape().dimensions(i)); + } + std::vector result_index(rank, 0); + + auto func = [&](tensorflow::gtl::ArraySlice update_index) { + std::transform(update_index.begin(), update_index.end(), start.begin(), + result_index.begin(), std::plus()); + result->Set(result_index, + update_literal.Get(update_index)); + return true; + }; + + std::vector base(update_literal.shape().dimensions_size(), 0); + std::vector step(update_literal.shape().dimensions_size(), 1); + ShapeUtil::ForEachIndex(update_literal.shape(), base, + AsInt64Slice(update_literal.shape().dimensions()), + step, func); + + return std::move(result); + } + + StatusOr> ElementWiseUnaryOp( + HloInstruction* instruction, + const std::function& unary_op) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(instruction->operand(0)); + TF_ASSIGN_OR_RETURN( + auto result_literal, + (HloEvaluator::ElementWiseUnaryOpImpl( + instruction, ConvertUnaryFunction(unary_op), operand_literal))); + + return std::move(result_literal); + } + + StatusOr> ElementWiseBinaryOp( + HloInstruction* instruction, + const std::function& + binary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast + // is removed. 
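// [Editorial example, not part of the original change.] An add of f32[3,4]
// with f32[4] would rely on implicit broadcasting and fails the
// SameDimensions check below; the f32[4] operand has to be broadcast
// explicitly (e.g. with a kBroadcast to f32[3,4]) before evaluation.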
+ if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + + auto result = MakeUnique(shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return ConvertBinaryFunction(binary_op)( + lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); + return std::move(result); + } + + template + StatusOr> ElementwiseTernaryOp( + HloInstruction* instruction, + const std::function& ternary_op) { + const auto shape = instruction->shape(); + const auto* lhs = instruction->operand(0); + const auto* rhs = instruction->operand(1); + const auto* ehs = instruction->operand(2); + + // TODO(b/35950897, b/27796129): add DCHECK back once implicit + // broadcast is removed. + if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && + ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && + ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { + return Unimplemented( + "Implicit broadcasting is currently unsupported in HLO evaluator " + "Shape Mismatch: %s vs %s vs %s vs %s: ", + ShapeUtil::HumanString(shape).c_str(), + ShapeUtil::HumanString(lhs->shape()).c_str(), + ShapeUtil::HumanString(rhs->shape()).c_str(), + ShapeUtil::HumanString(ehs->shape()).c_str()); + } + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); + + auto result = MakeUnique(shape); + + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return ternary_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index), + ehs_literal.Get(multi_index)); + })); + + return std::move(result); + } + + template + static bool IsShiftOutOfBounds(NativeT rhs) { + typedef typename std::make_unsigned::type UnsignedT; + UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT; + UnsignedT rhs_unsigned = static_cast(rhs); + return rhs_unsigned >= lhs_size_unsigned; + } + + HloEvaluator* parent_; +}; + +// These extern templates prevent users of this class from implicitly +// instantiating it. We explicitly instantiate this class in the various +// hlo_evaluator_typed_visitor*.cc files. 
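// [Editorial note, not part of the original change.] This is the usual
// extern-template pattern: the header declares, for example,
//   extern template class HloEvaluatorTypedVisitor<float>;
// so translation units that include it do not instantiate the visitor
// themselves, while hlo_evaluator_typed_visitor_float.cc (added below)
// provides the single explicit instantiation
//   template class HloEvaluatorTypedVisitor<float>;
// which keeps compile time and object size for users of this header down.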
+extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc new file mode 100644 index 00000000000..39c352dfb96 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bfloat16.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc new file mode 100644 index 00000000000..289b40fa06d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_bool.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc new file mode 100644 index 00000000000..9cb4eb921fd --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_complex64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc new file mode 100644 index 00000000000..5e6252fbf8c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_double.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc new file mode 100644 index 00000000000..ee793ae77b1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_float.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc new file mode 100644 index 00000000000..038d9d39e4a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_half.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc new file mode 100644 index 00000000000..b1952ca6193 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int32.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc new file mode 100644 index 00000000000..0cbaffb40b7 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc new file mode 100644 index 00000000000..6f4bf2a392b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_int8.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc new file mode 100644 index 00000000000..10235447e0d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint32.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc new file mode 100644 index 00000000000..8abeaa6ffca --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint64.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc new file mode 100644 index 00000000000..6dabd1c176e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor_uint8.cc @@ -0,0 +1,22 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" + +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" + +namespace xla { +template class HloEvaluatorTypedVisitor; +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index 6fb91b9bef9..be989846ef5 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -88,7 +88,7 @@ std::unique_ptr CreateHloProfilePrinterData( // down how much time each HLO took. class HloExecutionProfile { public: - using DeviceDescription = perftools::gputools::DeviceDescription; + using DeviceDescription = se::DeviceDescription; HloExecutionProfile(const HloProfilePrinterData* hlo_profile_printer_data, const HloProfileIndexMap* hlo_profile_index_map); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index a0cb28246d3..4900c813fdf 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -16,52 +16,32 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -class HloExecutionProfileTest : public HloTestBase { - protected: - static constexpr int64 kInstructionCyclesIndex = 0; - static constexpr int64 kInstructionNameIndex = 19; -}; +using tensorflow::strings::StrCat; +using ::testing::AllOf; +using ::testing::ContainsRegex; -// Splits `lines` into a sequence of lines delimited by newlines and then split -// each of those lines into a sequence of words delimited by spaces. Filter out -// empty words. -std::vector> SplitIntoLinesAndWords( - tensorflow::StringPiece lines) { - std::vector> result; - for (const string& line : tensorflow::str_util::Split(lines, '\n')) { - std::vector words; - for (const string& word : tensorflow::str_util::Split(line, ' ')) { - if (!word.empty()) { - words.push_back(word); - } - } - result.push_back(std::move(words)); - } - - return result; -} +class HloExecutionProfileTest : public HloTestBase {}; TEST_F(HloExecutionProfileTest, Basic) { - std::unique_ptr hlo_module = CreateNewModule(); - - HloComputation::Builder builder(TestName()); + auto hlo_module = tools::Parse(R"( + HloModule test_module + ENTRY entry_computation { + lhs = f32[30,30]{1,0} parameter(0) + rhs = f32[30,30]{1,0} parameter(1) + add = f32[30,30]{1,0} add(lhs, rhs) + ROOT dot = f32[30,30]{1,0} dot(lhs, add), lhs_contracting_dims={1}, rhs_contracting_dims={0} + })") + .ValueOrDie(); + const HloInstruction* dot_instruction = + hlo_module->entry_computation()->root_instruction(); + const HloInstruction* add_instruction = dot_instruction->operand(1); Shape shape = ShapeUtil::MakeShape(F32, {30, 30}); - HloInstruction* param_lhs = - builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); - HloInstruction* param_rhs = - builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); - HloInstruction* add_instruction = - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kAdd, param_lhs, param_rhs)); - HloInstruction* dot_instruction = - builder.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, param_lhs, add_instruction)); - - hlo_module->AddEntryComputation(builder.Build()); auto shape_size_function = [&](const Shape& shape) { const int64 pointer_size = 8; @@ -84,20 +64,12 @@ TEST_F(HloExecutionProfileTest, Basic) { execution_profile.SetCyclesTakenBy(add_instruction, add_cycles); execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles); - string rendered_profile = execution_profile.ToString( - backend().default_stream_executor()->GetDeviceDescription()); - std::vector> lines_and_words = - SplitIntoLinesAndWords(rendered_profile); - ASSERT_EQ(lines_and_words.size(), 8); - - const std::vector& line_2 = lines_and_words[2]; - const std::vector& line_3 = lines_and_words[3]; - - EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles)); - EXPECT_EQ(line_2[kInstructionNameIndex], '%' + dot_instruction->name()); - - EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles)); - EXPECT_EQ(line_3[kInstructionNameIndex], '%' + add_instruction->name()); + EXPECT_THAT(execution_profile.ToString( + backend().default_stream_executor()->GetDeviceDescription()), + AllOf(ContainsRegex(StrCat(dot_cycles, R"(\b.*%)", + dot_instruction->name())), + 
ContainsRegex(StrCat(add_cycles, R"(\b.*%)", + add_instruction->name())))); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index c35783c456c..17e3c405f1e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -322,11 +322,13 @@ class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, const DebugOptions& debug_options, bool show_metadata, - const HloExecutionProfile* profile, NodeFilter filter) + bool show_backend_config, const HloExecutionProfile* profile, + NodeFilter filter) : computation_(computation), - label_(label.ToString()), + label_(std::string(label)), debug_options_(debug_options), show_metadata_(show_metadata), + show_backend_config_(show_backend_config), profile_(profile), filter_(std::move(filter)) {} @@ -365,6 +367,7 @@ class HloDotDumper { string GetInstructionNodeShape(const HloInstruction* instr); string GetInstructionNodeLabel(const HloInstruction* instr); string GetInstructionNodeMetadata(const HloInstruction* instr); + string GetInstructionNodeBackendConfig(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr); string GetInstructionNodeInlinedOperands(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr); @@ -393,6 +396,7 @@ class HloDotDumper { const string label_; // overall name for the graph const DebugOptions& debug_options_; const bool show_metadata_; + const bool show_backend_config_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -611,6 +615,10 @@ tooltip = " "; if (!extra_info.empty()) { StrAppend(&subcomp_label, "
", extra_info); } + string node_backend_config = GetInstructionNodeBackendConfig(parent_instr); + if (!node_backend_config.empty()) { + StrAppend(&subcomp_label, "
", node_backend_config); + } bool highlight = filter_.Highlight(parent_instr); const char* fillcolor; @@ -765,6 +773,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string node_shape = GetInstructionNodeShape(instr); string node_label = GetInstructionNodeLabel(instr); string node_metadata = GetInstructionNodeMetadata(instr); + string node_backend_config = GetInstructionNodeBackendConfig(instr); string extra_info = GetInstructionNodeExtraInfo(instr); string inlined_constants = GetInstructionNodeInlinedOperands(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); @@ -782,8 +791,8 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } // Build the text that will be displayed inside the node. string node_body = node_label; - for (const string& s : - {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) { + for (const string& s : {trivial_subcomputation, node_metadata, + node_backend_config, extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } @@ -804,7 +813,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. - if (ShapeUtil::HasZeroElements(shape)) { + if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) { return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); } @@ -816,7 +825,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8) { + if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -909,12 +918,14 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kClz: case HloOpcode::kComplex: case HloOpcode::kConvert: case HloOpcode::kCos: case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: @@ -922,6 +933,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -956,7 +968,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: - case HloOpcode::kBroadcastDimOne: // De-emphasize nodes which broadcast a scalar within a fusion node -- // these are essentially free. if (instr->IsFused() && @@ -1078,13 +1089,23 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { return Join(lines, "
"); } +string HloDotDumper::GetInstructionNodeBackendConfig( + const HloInstruction* instr) { + if (!show_backend_config_ || instr->backend_config().empty()) { + return ""; + } + + return StrCat("backend_config=\"", instr->backend_config(), "\""); +} + string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { std::vector lines; // Get the instruction's extra attributes excluding the names of its // subcomputations, since those are drawn explicitly in the graph. for (const auto& line : instr->ExtraAttributesToString( - HloPrintOptions().set_print_subcomputation_references(false))) { + HloPrintOptions().set_print_subcomputation_mode( + HloPrintOptions::PrintSubcomputationMode::kOff))) { lines.push_back(HtmlLikeStringSanitize(line)); } @@ -1404,7 +1425,7 @@ string ExportGraph(const string& graph, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; string graph; if (debug_options.xla_hlo_dump_as_graphdef()) { @@ -1414,9 +1435,10 @@ string DumpGraph(const HloComputation& computation, const string& label, &graph)); graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { - graph = HloDotDumper(&computation, label, debug_options, show_metadata, - hlo_execution_profile, NodeFilter()) - .Dump(); + graph = + HloDotDumper(&computation, label, debug_options, show_metadata, + show_backend_config, hlo_execution_profile, NodeFilter()) + .Dump(); graph_kind = GraphRendererInterface::DOT_GRAPH; } @@ -1427,15 +1449,15 @@ string DumpGraph(const HloComputation& computation, const string& label, } string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata) { + bool show_metadata, bool show_backend_config) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); - string graph = - HloDotDumper(node.parent(), label, debug_options, show_metadata, - /*profile=*/nullptr, filter) - .Dump(); + string graph = HloDotDumper(node.parent(), label, debug_options, + show_metadata, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 2704aae1e3b..fc8e1468aca 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_metadata = false); + bool show_metadata = false, bool show_backend_config = false); // Like DumpGraph, but renders only nodes "near" the given node in the graph. // @@ -64,7 +64,8 @@ string DumpGraph(const HloComputation& computation, const string& label, // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. 
string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata = false); + bool show_metadata = false, + bool show_backend_config = false); // Dumps the HloModule::ToString() as a file into the provided directory path // suffixed with the provided label. diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 1f00aa41dc7..8e52d926d85 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -64,8 +64,8 @@ TEST(HloGraphDumperTest, NestedFusion) { sums.push_back(b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, sums[i], params[i + 2]))); } - - HloModule m(TestName()); + HloModuleConfig config; + HloModule m(TestName(), config); m.AddEntryComputation(b.Build()); HloComputation* root_computation = m.entry_computation(); @@ -122,7 +122,8 @@ TEST(HloGraphDumperTest, Constant) { auto instruction = b.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(-42))); instruction->set_name("i_am_a_constant_root_instruction"); - HloModule m(TestName()); + HloModuleConfig config; + HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build()); string graph = hlo_graph_dumper::DumpGraph( *root_computation, /*label=*/"an_empty_graph", DebugOptions()); @@ -130,5 +131,23 @@ TEST(HloGraphDumperTest, Constant) { EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction"))); } +TEST(HloGraphDumperTest, TupleConstant) { + Shape tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(S32, {4, 5})}); + HloComputation::Builder b("b"); + auto constant = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromShape(tuple_shape))); + auto gte = b.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeShape(F32, {3, 2}), constant, 0)); + + HloModuleConfig config; + HloModule m(TestName(), config); + HloComputation* root_computation = m.AddEntryComputation(b.Build(gte)); + string graph = hlo_graph_dumper::DumpGraph( + *root_computation, /*label=*/"tuple_constant", DebugOptions()); + EXPECT_THAT(graph, HasSubstr("tuple_constant")); + EXPECT_THAT(graph, HasSubstr("constant (f32[3,2], s32[4,5])")); +} + } // anonymous namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 56cb241087c..db1c33e2f0d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -51,7 +51,7 @@ using ::tensorflow::strings::StrCat; /* static */ StatusOr> HloInstruction::CreateFromProto( - HloModule* module, const HloInstructionProto& proto, + const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, const tensorflow::gtl::FlatMap& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); @@ -109,6 +109,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->name_ = proto.name(); instruction->metadata_ = proto.metadata(); + instruction->set_backend_config(proto.backend_config()); if (proto.has_literal()) { TF_ASSIGN_OR_RETURN(instruction->literal_, Literal::CreateFromProto(proto.literal())); @@ -254,11 +255,14 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCeil: case HloOpcode::kCopy: case HloOpcode::kCos: + case HloOpcode::kClz: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: 
case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -436,7 +440,7 @@ HloInstruction::CreateCrossReplicaSum( << "Outfeed shape " << shape << " must be compatible with operand shape " << operand->shape(); instruction->AppendOperand(operand); - instruction->outfeed_config_ = outfeed_config.ToString(); + instruction->outfeed_config_ = std::string(outfeed_config); instruction->outfeed_shape_ = shape; return instruction; } @@ -700,15 +704,6 @@ HloInstruction::CreateSelectAndScatter( return instruction; } -/* static */ std::unique_ptr -HloInstruction::CreateBroadcastDimOne(const Shape& shape, - HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBroadcastDimOne, shape)); - instruction->AppendOperand(operand); - return instruction; -} - /* static */ std::unique_ptr HloInstruction::CreateBroadcastSequence( const Shape& output_shape, HloInstruction* operand, @@ -800,23 +795,11 @@ HloInstruction::CreateBroadcastSequence( return instruction; } -// We put the fusion kind into the instruction's name for transpose-dot fusions, -// since those fusions are really just describing a type of dot rather than -// generating a novel computation. -static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { - switch (fusion_kind) { - case HloInstruction::FusionKind::kTransposeDot: - return "dot_fusion"; - default: - return "fusion"; - } -} - /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); instruction->fusion_kind_ = fusion_kind; - instruction->name_ = FusionNodeName(fusion_kind); + instruction->name_ = "fusion"; instruction->set_parent(fused_root->parent()); instruction->set_metadata(fused_root->metadata()); instruction->CloneAndFuseInternal(fused_root); @@ -832,7 +815,7 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { instruction->AppendOperand(operand); } instruction->fusion_kind_ = fusion_kind; - instruction->name_ = FusionNodeName(fusion_kind); + instruction->name_ = "fusion"; instruction->called_computations_.push_back(fusion_computation); fusion_computation->SetFusionInstruction(instruction.get()); return instruction; @@ -1131,7 +1114,7 @@ RandomDistribution HloInstruction::random_distribution() const { return distribution_; } -bool HloInstruction::HasSideEffect() const { +bool HloInstruction::HasSideEffectNoRecurse() const { switch (opcode_) { case HloOpcode::kSend: case HloOpcode::kSendDone: @@ -1143,16 +1126,22 @@ bool HloInstruction::HasSideEffect() const { case HloOpcode::kTrace: case HloOpcode::kHostCompute: return true; - default: { - // Check if any of the called computations has a side effect. - for (const auto& computation : called_computations()) { - if (computation->HasSideEffect()) { - return true; - } - } + default: return false; + } +} + +bool HloInstruction::HasSideEffect() const { + if (HasSideEffectNoRecurse()) { + return true; + } + // Check if any of the called computations has a side effect. 
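// [Editorial example, not part of the original change.] A kWhile or kCall is
// not side-effecting on its own, but if its called computation contains,
// say, a kSend or kInfeed, the recursive check below reports the caller as
// side-effecting as well.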
+ for (const auto& computation : called_computations()) { + if (computation->HasSideEffect()) { + return true; } } + return false; } /* static */ std::unique_ptr HloInstruction::CreateCall( @@ -1175,7 +1164,7 @@ bool HloInstruction::HasSideEffect() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->custom_call_target_ = custom_call_target.ToString(); + instruction->custom_call_target_ = std::string(custom_call_target); return instruction; } @@ -1187,7 +1176,7 @@ bool HloInstruction::HasSideEffect() const { for (auto operand : operands) { instruction->AppendOperand(operand); } - instruction->channel_name_ = channel_name.ToString(); + instruction->channel_name_ = std::string(channel_name); instruction->cost_estimate_ns_ = cost_estimate_ns; return instruction; } @@ -1239,12 +1228,15 @@ bool HloInstruction::HasSideEffect() const { std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, - HloModule* module) const { + HloModule* module, CloneMap* clone_map) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { VLOG(3) << " %" << new_operand->name(); } + if (module == nullptr) { + module = GetModule(); + } std::unique_ptr clone; @@ -1257,13 +1249,16 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kRoundNearestAfz: case HloOpcode::kBitcast: case HloOpcode::kCeil: + case HloOpcode::kClz: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -1311,10 +1306,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBroadcast(shape, new_operands[0], dimensions_); break; - case HloOpcode::kBroadcastDimOne: - CHECK_EQ(new_operands.size(), 1); - clone = CreateBroadcastDimOne(shape, new_operands[0]); - break; case HloOpcode::kCall: clone = CreateCall(shape, new_operands, to_apply()); break; @@ -1353,7 +1344,8 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kFft: CHECK_EQ(new_operands.size(), 1); - return CreateFft(shape, new_operands[0], fft_type_, fft_length_); + clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); + break; case HloOpcode::kCrossReplicaSum: clone = CreateCrossReplicaSum(shape, new_operands); break; @@ -1426,9 +1418,15 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kConstant: clone = CreateConstant(literal_->CloneToUnique()); break; - case HloOpcode::kFusion: - clone = CloneFusionWithNewOperands(shape, new_operands, module); + case HloOpcode::kFusion: { + CHECK_NE(module, nullptr); + auto new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", module, clone_map)); + clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), + /*operands=*/new_operands, + /*fusion_computation=*/new_fused_computation); break; + } case HloOpcode::kParameter: clone = CreateParameter(parameter_number_, shape, name_); break; @@ -1492,15 +1490,19 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); + clone->set_backend_config(backend_config()); + if (clone_map != nullptr) { + InsertOrDie(clone_map, this, clone.get()); 
+ } return clone; } HloInstruction::~HloInstruction() {} -std::unique_ptr HloInstruction::Clone(const string& suffix, - HloModule* module) const { +std::unique_ptr HloInstruction::Clone( + const string& suffix, HloModule* module, CloneMap* clone_map) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, module); + CloneWithNewOperands(shape_, operands_, module, clone_map); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1537,71 +1539,6 @@ std::unique_ptr HloInstruction::Clone(const string& suffix, return clone; } -std::unique_ptr HloInstruction::CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(parent() != nullptr); - - auto new_instruction = - WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - // Add the operands to our new fusion instruction. - for (HloInstruction* new_operand : operands) { - new_instruction->AppendOperand(new_operand); - } - // Clone all the fused instructions for the new fusion instruction. - HloInstructionMap old_to_new; - std::list> new_fused_instructions; - // Create the list of fused parameters by mapping through the cloned, - // fused instructions. - for (HloInstruction* old_fused_parameter : - fused_instructions_computation()->parameter_instructions()) { - new_fused_instructions.push_back( - old_fused_parameter->Clone("clone", module)); - HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); - InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); - } - for (auto old_fused_instruction : - fused_instructions_computation()->MakeInstructionPostOrder()) { - if (old_fused_instruction->opcode() == HloOpcode::kParameter) { - FindOrDie(old_to_new, old_fused_instruction); - continue; - } - std::vector new_operands; - for (int64 operand_idx = 0; - operand_idx < old_fused_instruction->operand_count(); ++operand_idx) { - HloInstruction* old_operand = - old_fused_instruction->mutable_operand(operand_idx); - new_operands.push_back(FindOrDie(old_to_new, old_operand)); - } - new_fused_instructions.push_back( - old_fused_instruction->CloneWithNewOperands( - old_fused_instruction->shape(), new_operands, module)); - HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); - new_fused_instruction->set_parent(parent_); - InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); - } - new_instruction->fusion_kind_ = fusion_kind_; - auto computation_builder = HloComputation::Builder( - fused_instructions_computation()->name() + ".clone", - new_instruction.get()); - // We iterated the fusion instructions in reverse post order which means - // that we must reverse our new list of fusion instructions. 
- for (auto new_fused_instruction_iter = new_fused_instructions.rbegin(); - new_fused_instruction_iter != new_fused_instructions.rend(); - ++new_fused_instruction_iter) { - computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); - } - if (module == nullptr) { - module = GetModule(); - } - auto fused_root_ = fused_expression_root(); - new_instruction->called_computations_.push_back( - CHECK_NOTNULL(module)->AddEmbeddedComputation( - computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); - return new_instruction; -} - std::pair HloInstruction::LatestNonGteAncestorAndIndex() const { const HloInstruction* hlo = this; @@ -1630,6 +1567,8 @@ const Literal& HloInstruction::literal() const { return *literal_; } +bool HloInstruction::HasLiteral() const { return literal_ != nullptr; } + bool HloInstruction::CanHaveDimensionsField() const { return (opcode() == HloOpcode::kReverse || opcode() == HloOpcode::kConcatenate || @@ -1689,14 +1628,35 @@ Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { } Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) { - auto succ_it = std::find(control_successors_.begin(), - control_successors_.end(), instruction); - TF_RET_CHECK(succ_it != control_successors_.end()); - control_successors_.erase(succ_it); - auto pred_it = std::find(instruction->control_predecessors_.begin(), - instruction->control_predecessors_.end(), this); - TF_RET_CHECK(pred_it != instruction->control_predecessors_.end()); - instruction->control_predecessors_.erase(pred_it); + TF_RET_CHECK(instruction->parent() == parent()); + TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction)); + TF_RETURN_IF_ERROR( + EraseElementFromVector(&instruction->control_predecessors_, this)); + return Status::OK(); +} + +Status HloInstruction::DropAllControlDeps() { + for (auto* ctrl_succ : control_successors_) { + TF_RETURN_IF_ERROR( + EraseElementFromVector(&ctrl_succ->control_predecessors_, this)); + } + for (auto* ctrl_pred : control_predecessors_) { + TF_RETURN_IF_ERROR( + EraseElementFromVector(&ctrl_pred->control_successors_, this)); + } + control_successors_.clear(); + control_predecessors_.clear(); + return Status::OK(); +} + +Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) { + for (auto* ctrl_pred : inst->control_predecessors()) { + TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this)); + } + + for (auto* ctrl_succ : inst->control_successors()) { + TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ)); + } return Status::OK(); } @@ -1729,25 +1689,30 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const { + eq_computations) const { // Perform opcode specific checks. switch (opcode()) { // The result of these instructions only depend upon their opcode and // operands. 
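// [Editorial example, not part of the original change.] Two kAdd
// instructions whose operands already compared equal are identical as far
// as this switch is concerned and simply return true; opcodes that carry
// extra state (dimensions, literals, windows, ...) are compared field by
// field in the cases further below.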
case HloOpcode::kAbs: case HloOpcode::kAtan2: - case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kClz: case HloOpcode::kComplex: + case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: @@ -1755,6 +1720,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kAnd: case HloOpcode::kNot: case HloOpcode::kOr: @@ -1767,6 +1733,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: + case HloOpcode::kReshape: + case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: @@ -1778,6 +1746,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kTuple: return true; + // Broadcast, Concatenate, and Transpose need the same dimensions field. + case HloOpcode::kBroadcast: + case HloOpcode::kConcatenate: + case HloOpcode::kTranspose: + return dimensions() == other.dimensions(); + case HloOpcode::kFusion: return fusion_kind() == other.fusion_kind() && eq_computations(fused_instructions_computation(), @@ -1790,10 +1764,7 @@ bool HloInstruction::IdenticalSlowPath( return false; case HloOpcode::kParameter: - return parameter_number() == other.parameter_number() && - // Check the shape too because `this` and `other` may be in - // different HloComputations. - eq_shapes(shape(), other.shape()); + return parameter_number() == other.parameter_number(); case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: @@ -1805,12 +1776,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kConstant: return literal() == other.literal(); - // A convert result is determined by the primitive type that the operand is - // converted into. - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - return shape().element_type() == other.shape().element_type(); - // A reduce-precision operation is determined by the bit sizes. case HloOpcode::kReducePrecision: return exponent_bits() == other.exponent_bits() && @@ -1853,24 +1818,8 @@ bool HloInstruction::IdenticalSlowPath( eq_computations(scatter(), other.scatter()) && protobuf_util::ProtobufEquals(window(), other.window()); - case HloOpcode::kReshape: - return eq_shapes(shape(), other.shape()); - - // Transpose result is determined by the final shape and the permutation. - case HloOpcode::kTranspose: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); // Remaining instructions with special values. 
- case HloOpcode::kBitcast: - case HloOpcode::kBroadcastDimOne: - case HloOpcode::kDynamicUpdateSlice: - return eq_shapes(shape(), other.shape()); - case HloOpcode::kBroadcast: - return eq_shapes(shape(), other.shape()) && - dimensions() == other.dimensions(); - case HloOpcode::kConcatenate: - return dimensions() == other.dimensions(); case HloOpcode::kGetTupleElement: return tuple_index() == other.tuple_index(); case HloOpcode::kPad: @@ -1880,9 +1829,6 @@ bool HloInstruction::IdenticalSlowPath( return slice_starts_ == other.slice_starts_ && slice_limits_ == other.slice_limits_ && slice_strides_ == other.slice_strides_; - case HloOpcode::kDynamicSlice: - return eq_shapes(shape(), other.shape()) && - dynamic_slice_sizes_ == other.dynamic_slice_sizes_; case HloOpcode::kCall: case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); @@ -2149,28 +2095,68 @@ string PrintName(const string& name, const HloPrintOptions& options) { } // namespace string HloInstruction::ToString(const HloPrintOptions& options) const { - string result = - StrCat(PrintName(name(), options), " = ", - ShapeUtil::HumanStringWithLayout(shape()), " ", - HloOpcodeString(opcode()), "(", OperandsToString(options), ")"); + CanonicalNameMap new_map; + return ToStringWithCanonicalNameMap(options, &new_map); +} + +string HloInstruction::ToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + string result = ""; + + // Logic to print the instruction name (e.g. "%foo = "). + if (options.canonicalize_instruction_names()) { + if (options.is_in_nested_computation()) { + // If we are canonicalizing instruction names and this is a top-level + // HloInstruction::ToString() call, don't print an instruction name. + StrAppend(&result, + PrintName(canonical_name_map->LookupOrInsert(name()), options), + " = "); + } + } else { + StrAppend(&result, PrintName(name(), options), " = "); + } + + // Print opcode, operand(s) and shape. + StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ", + HloOpcodeString(opcode()), "(", + OperandsToStringWithCanonicalNameMap(options, canonical_name_map), + ")"); + + // Print additional attributes. If an instruction contains a subcomputation, + // the subcomputation is also printed here. for (const string& extra : ExtraAttributesToString(options)) { StrAppend(&result, ", ", extra); } + if (options.print_metadata() && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } + if (options.print_backend_config() && !backend_config().empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\""); + } return result; } string HloInstruction::OperandsToString(const HloPrintOptions& options) const { + CanonicalNameMap new_map; + return OperandsToStringWithCanonicalNameMap(options, &new_map); +} + +string HloInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. - if ((!ShapeUtil::IsTuple(shape()) && - ShapeUtil::ElementsIn(shape()) <= 10) || - options.print_large_constants()) { + // + // In HloInstruction, sometimes a constant literal is not constructed due + // to its size. Skip the printing in this case. 
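// [Editorial note, not part of the original change.] HasLiteral(), added in
// this change, simply reports whether literal_ is non-null, so the check
// below avoids dereferencing a missing literal for such placeholder
// constants.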
+ if (HasLiteral() && ((!ShapeUtil::IsTuple(shape()) && + ShapeUtil::ElementsIn(shape()) <= 10) || + options.print_large_constants())) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); @@ -2204,7 +2190,14 @@ string HloInstruction::OperandsToString(const HloPrintOptions& options) const { if (options.print_operand_shape()) { str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); } - if (!options.compact_operands()) { + + // In a top-level HloInstruction::ToString() call, the operand name is not + // part of the canonical string. + if (options.canonicalize_instruction_names() && + options.is_in_nested_computation()) { + str.push_back(PrintName( + canonical_name_map->LookupOrInsert(operand->name()), options)); + } else if (!options.compact_operands()) { str.push_back(PrintName(operand->name(), options)); } StrAppend(out, Join(str, " ")); @@ -2273,7 +2266,8 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); } - if (options.print_subcomputation_references()) { + if (options.print_subcomputation_mode() == + HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { extra.push_back( StrCat("condition=", PrintName(while_condition()->name(), options))); @@ -2301,8 +2295,45 @@ std::vector HloInstruction::ExtraAttributesToString( PrintName(computation->name(), options)); }))); } + } else if (options.print_subcomputation_mode() == + HloPrintOptions::PrintSubcomputationMode::kFullBodies) { + HloPrintOptions new_options = options; + new_options.set_is_in_nested_computation(true); + switch (opcode()) { + case HloOpcode::kWhile: + extra.push_back( + StrCat("condition=\n", while_condition()->ToString(new_options))); + extra.push_back(StrCat("body=\n", while_body()->ToString(new_options))); + break; + case HloOpcode::kSelectAndScatter: + extra.push_back(StrCat("select=\n", select()->ToString(new_options))); + extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options))); + break; + case HloOpcode::kConditional: + extra.push_back(StrCat("true_computation=\n", + true_computation()->ToString(new_options))); + extra.push_back(StrCat("false_computation=\n", + false_computation()->ToString(new_options))); + break; + case HloOpcode::kCall: + case HloOpcode::kMap: + case HloOpcode::kReduceWindow: + case HloOpcode::kReduce: + extra.push_back( + StrCat("to_apply=\n", to_apply()->ToString(new_options))); + break; + default: + if (!called_computations().empty()) { + extra.push_back( + StrCat("calls=\n", + Join(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, computation->ToString(new_options)); + }))); + } + break; + } } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { extra.push_back(StrCat("channel_id=", channel_id_)); @@ -2340,12 +2371,13 @@ std::vector HloInstruction::ExtraAttributesToString( } // By contract, we print the custom call target even if - // !options.print_subcomputation_references(), because the call target is not + // options.print_subcomputation_mode() == kOff, because the call target is not // an HloComputation. 
if (opcode() == HloOpcode::kCustomCall) { extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); } + return extra; } @@ -2375,6 +2407,7 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; + proto.set_backend_config(backend_config()); if (literal_ != nullptr) { *proto.mutable_literal() = literal_->ToProto(); } @@ -2440,12 +2473,10 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_fft_length(fft_len); } - if (gather_dimension_numbers_ != nullptr) { - *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; - } - for (int64 bound : gather_window_bounds_) { - proto.add_gather_window_bounds(bound); + if (has_sharding()) { + *proto.mutable_sharding() = sharding().ToProto(); } + proto.set_channel_name(channel_name_); proto.set_cost_estimate_ns(cost_estimate_ns_); @@ -2482,8 +2513,6 @@ string HloInstruction::ToCategory() const { return "input fusion"; case FusionKind::kOutput: return "output fusion"; - case FusionKind::kTransposeDot: - return "dot"; case FusionKind::kCustom: return "custom fusion"; } @@ -2668,12 +2697,18 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleNegate(this); case HloOpcode::kExp: return visitor->HandleExp(this); + case HloOpcode::kExpm1: + return visitor->HandleExpm1(this); case HloOpcode::kFloor: return visitor->HandleFloor(this); case HloOpcode::kCeil: return visitor->HandleCeil(this); + case HloOpcode::kClz: + return visitor->HandleClz(this); case HloOpcode::kLog: return visitor->HandleLog(this); + case HloOpcode::kLog1p: + return visitor->HandleLog1p(this); case HloOpcode::kTanh: return visitor->HandleTanh(this); case HloOpcode::kCos: @@ -2692,8 +2727,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleBitcast(this); case HloOpcode::kBroadcast: return visitor->HandleBroadcast(this); - case HloOpcode::kBroadcastDimOne: - return visitor->HandleBroadcastDimOne(this); case HloOpcode::kPad: return visitor->HandlePad(this); case HloOpcode::kReshape: @@ -2966,6 +2999,7 @@ Status HloInstruction::AcceptOrdered( continue; } + // TODO(b/78350259): Eliminate const laundering. 
HloInstruction* instruction = const_cast(const_instruction); @@ -3015,15 +3049,18 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: case HloOpcode::kCeil: + case HloOpcode::kClz: case HloOpcode::kConvert: case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFloor: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -3088,7 +3125,7 @@ bool HloInstruction::IsElementwise() const { bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { CHECK(IsElementwise()); - return !ShapeUtil::Equal(shape(), operand(operand_idx)->shape()); + return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape()); } namespace { @@ -3264,8 +3301,6 @@ string ToString(HloInstruction::FusionKind kind) { return "kInput"; case HloInstruction::FusionKind::kOutput: return "kOutput"; - case HloInstruction::FusionKind::kTransposeDot: - return "kTransposeDot"; case HloInstruction::FusionKind::kCustom: return "kCustom"; } @@ -3282,9 +3317,6 @@ StatusOr StringToFusionKind( if (kind_name == "kOutput") { return HloInstruction::FusionKind::kOutput; } - if (kind_name == "kTransposeDot") { - return HloInstruction::FusionKind::kTransposeDot; - } if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 49aa0750299..234dbc8399d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -60,51 +60,75 @@ class HloModule; // A bunch of switches that control how the hlo text should be printed. class HloPrintOptions { public: + enum class PrintSubcomputationMode { + kOff, // Do not print anything about subcomputations. + kNameOnly, // Only print the name of subcomputations. + kFullBodies, // Print the full bodies of subcomputations. + }; + // Constructs the default print options: don't print large constants, don't // compact operands, no indentation. HloPrintOptions() : print_large_constants_(false), - print_subcomputation_references_(true), + print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly), print_metadata_(true), + print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), print_program_shape_(true), print_percent_(true), - indent_amount_(0) {} + canonicalize_instruction_names_(false), + indent_amount_(0), + is_in_nested_computation_(false) {} static HloPrintOptions ShortParsable() { return HloPrintOptions() .set_print_large_constants(true) - .set_print_subcomputation_references(true) + .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly) .set_print_metadata(false) + .set_print_backend_config(false) .set_print_operand_shape(false) .set_print_program_shape(false) .set_print_percent(false); } + // Options to produce the canonical string representing an isomorphic + // computation graph. + static HloPrintOptions Canonical() { + return HloPrintOptions() + .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) + .set_print_metadata(false) + .set_compact_operands(true) + .set_print_operand_shape(true) + .set_print_program_shape(false) + .set_print_percent(false) + .set_canonicalize_instruction_names(true); + } + // If true, large constants will be printed out. 
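// Illustrative sketch: ShortParsable() and Canonical() above are pre-canned
// combinations of the individual setters below, so callers can also chain
// their own mix, e.g. full subcomputation bodies with metadata and backend
// configs suppressed:
//
//   HloPrintOptions options =
//       HloPrintOptions()
//           .set_print_subcomputation_mode(
//               HloPrintOptions::PrintSubcomputationMode::kFullBodies)
//           .set_print_metadata(false)
//           .set_print_backend_config(false);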
HloPrintOptions& set_print_large_constants(bool value) { print_large_constants_ = value; return *this; } - // If true, the names of subcomputations (e.g. a fusion node's fused - // computation) won't be printed. This makes the resulting text not parsable. - // - // A CustomCall's call target is printed even if - // print_subcomputation_references is false, because the call target isn't an - // HloComputation. - HloPrintOptions& set_print_subcomputation_references(bool value) { - print_subcomputation_references_ = value; + HloPrintOptions& set_print_subcomputation_mode( + PrintSubcomputationMode value) { + print_subcomputation_mode_ = value; return *this; } - // If true, metatdata will be printed. + // If true, metadata will be printed. HloPrintOptions& set_print_metadata(bool value) { print_metadata_ = value; return *this; } + // If true, backend_config will be printed. + HloPrintOptions& set_print_backend_config(bool value) { + print_backend_config_ = value; + return *this; + } + // If true, operands' shapes will be printed. HloPrintOptions& set_print_operand_shape(bool value) { print_operand_shape_ = value; @@ -130,54 +154,175 @@ class HloPrintOptions { return *this; } + // If true, canonicalizes instructions' name. Instead of using "%foo.1" as + // the name of an instruction, we use "%tmp_1", "%tmp_2" etc. + HloPrintOptions& set_canonicalize_instruction_names(bool value) { + canonicalize_instruction_names_ = value; + return *this; + } + // The indent of the hlo text block. HloPrintOptions& set_indent_amount(int value) { indent_amount_ = value; return *this; } + // If true, indicates the instruction being printed is inside a nested + // computation. + HloPrintOptions& set_is_in_nested_computation(bool value) { + is_in_nested_computation_ = value; + return *this; + } + bool print_large_constants() const { return print_large_constants_; } - bool print_subcomputation_references() const { - return print_subcomputation_references_; + PrintSubcomputationMode print_subcomputation_mode() const { + return print_subcomputation_mode_; } bool print_metadata() const { return print_metadata_; } + bool print_backend_config() const { return print_metadata_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } + bool canonicalize_instruction_names() const { + return canonicalize_instruction_names_; + } int indent_amount() const { return indent_amount_; } + int is_in_nested_computation() const { return is_in_nested_computation_; } private: bool print_large_constants_; - bool print_subcomputation_references_; + PrintSubcomputationMode print_subcomputation_mode_; bool print_metadata_; + bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; bool print_program_shape_; bool print_percent_; + bool canonicalize_instruction_names_; int indent_amount_; + bool is_in_nested_computation_; }; -// HLO instructions are the IR used by the high-level compiler. +// For canonical string output, we need to have a canonical way to rename +// each instruction and its operands. Each operand is renamed as "tmp_", +// where is an index starting from 0. 
+class CanonicalNameMap { + public: + CanonicalNameMap() : index(0) {} + + string LookupOrInsert(const string& old_name) { + auto iter = canonical_name_map.find(old_name); + if (iter != canonical_name_map.end()) { + return iter->second; + } + + string new_name = tensorflow::strings::StrCat("tmp_", index++); + canonical_name_map[old_name] = new_name; + return new_name; + } + void Clear() { + canonical_name_map.clear(); + index = 0; + } + + private: + int64 index; + tensorflow::gtl::FlatMap canonical_name_map; +}; + +// HLO instructions are the atomic unit of the high-level compiler's IR. +// +// HloInstructions live inside of an HloComputation, which is analogous to a +// function in other programming languages. Nodes have no total order within +// their computation. Instead, they have a partial ordering determined by their +// data and control dependencies. +// +// HLO does not have basic blocks or explicit "branch" instructions. Instead, +// certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode +// control flow. For example, the kConditional HLO executes one of two possible +// computations, depending on the runtime value of a predicate. +// +// HLO is pure (mostly). It has no concept of mutable state. Instead, data +// values are produced by one HLO and flow into consumers across dependency +// edges. class HloInstruction { public: + // A fusion node computes the same value a call to its fusion computation + // would compute. However, the choice of fusion kind dictates codegen + // strategy for the backend. + // + // To generate code for a kFusion HloInstruction, most backends do something + // like the following: + // + // 1) Identify the "primary" HloInstruction of the fused computation. + // 2) Emit code that does the work of the primary node, creating its inputs + // and transforming its outputs as specified by the fused computation. + // + // In step (2), the code emitted is usually similar to the code that would be + // emitted for an *unfused* version of the primary node, except that + // + // - when the primary node reads an element of one of its operands, instead + // of loading the value from memory, it *computes* the value based on the + // contents of the fused computation. + // - when the primary node outputs a value, instead of storing it to memory, + // it forwards the value to its users, which then perform additional + // computations before the value is finally stored to memory at the root of + // the fusion node. + // + // An HloInstruction's FusionKind helps us find the kFusion instruction's + // primary node, and can also affect how we generate code in step (2). + // + // - kInput: The primary node is the root of the fused instruction. + // + // - kOutput: The primary node is not the root of the fused instruction. + // This fusion kind requires that one operand buffer of the fusion + // instruction be able to alias the output buffer. This constraint is + // usually enough to let backends find the primary node unambiguously. + // + // - kLoop: The primary node is the root of the fused computation, but, + // unlike in input fusion, we prescribe a specific implementation for + // codegen. Rather than generating code that looks like the code we'd emit + // for an unfused version of the primary/root node, we emit code that + // generates one element of the root at a time. + // + // - kCustom: Custom category for backend-specific fusions that don't fit + // into the above patterns. 
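// For example, a backend pass might wrap an elementwise instruction in a
// loop fusion node as follows (illustrative sketch; `computation` and `exp`
// are placeholders, as in the tests further below):
//
//   HloInstruction* fusion = computation->CreateFusionInstruction(
//       {exp}, HloInstruction::FusionKind::kLoop);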
+ // + // Not all backends support all fusion kinds, and given a particular fused + // computation, it's not in general safe to change its fusion kind. Creation + // of fusion nodes is always backend-specific. + // + // For elementwise ops (e.g. kAdd), most backends would emit a + // one-element-at-a-time implementation for the unfused version, so loop + // fusion and input fusion are probably equivalent if the root node is + // elementwise. They're not necessarily equivalent e.g. for kReduce, where an + // implementation might emit something more sophisticated for an unfused or + // input-fusion reduce, but will emit the naive code that reduces one element + // at a time for loop fusion with a reduce as the root. + // + // Another way to think of loop fusion is that it's equivalent to input + // fusion, but where the root node is an implicit identity node, whose + // unfused implementation is "read one element, write one element". + // + // TODO(b/79869434): This categorization scheme is not great. For one thing, + // input and loop fusion are basically the same thing: There is no reason for + // the HLO to encode backend-specific decisions about how e.g. a reduce that's + // the root of a fusion should be lowered. In addition, this scheme as + // written doesn't work for multi-output fusion, where the primary node is + // never actually the root (which is a kTuple instruction that gathers the + // multiple outputs of the fusion). enum class FusionKind { - kLoop, // Fused into a loop. - kInput, // Op's input is fused into the op itself. - kOutput, // Op's output is fused into the op itself. - // REQUIRES: At least one operand buffer must be able - // to alias the output buffer. - kTransposeDot, // Fused into a dot with transposed operands. - kCustom, // Custom category for backend-specific fusions that - // do not match any of the more specific ones. + kLoop, + kInput, + kOutput, + kCustom, }; ~HloInstruction(); // Creates an instruction from the given proto. Arguments: // - // module: the module which will contain the instruction. The newly created - // instruction is *not* added to the module or any computation, however. // proto: the proto to convert from. // instruction_map: a map from instruction id to HloInstruction*. This map // must contain all operands of the newly constructed instruction. @@ -185,7 +330,7 @@ class HloInstruction { // must contain all computations which the newly constructed instruction // calls. static StatusOr> CreateFromProto( - HloModule* module, const HloInstructionProto& proto, + const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, const tensorflow::gtl::FlatMap& computation_map); @@ -401,10 +546,6 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions); - // Creates a broadcast-size-one-dimensions instruction. - static std::unique_ptr CreateBroadcastDimOne( - const Shape& shape, HloInstruction* operand); - // Creates a sequence of instructions that performs an explicit broadcast of // the operand to the target shape. // @@ -507,6 +648,10 @@ class HloInstruction { // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } + // Returns true if this instruction has a side effect, irrespective of whether + // any called computations may contain an instruction with side effects. + bool HasSideEffectNoRecurse() const; + // Returns true if this instruction has a side effect. 
An instruction has a // side effect if it uses certain opcodes or calls a computation with a side // effect. @@ -561,6 +706,18 @@ class HloInstruction { // 'instruction'. Status RemoveControlDependencyTo(HloInstruction* instruction); + // Drops all control predecessors and successors from this HLO instruction. + Status DropAllControlDeps(); + + // Copies the control predecessors and successors on this HLO instruction to + // `inst`. Does not do a deep copy so this makes sense only if `inst` and + // this HLO are in the same module. + // + // Depending on the use cases we see in practice, in the future we may + // consider folding the logic here into Clone, CloneWithNewOperands and + // ReplaceAllUsesWith by treating control dependencies like data dependencies. + Status CopyAllControlDepsFrom(const HloInstruction* inst); + // Returns the set of control predecessors (successors) of this // instruction. Control predecessors (successors) must execute before (after) // the current instruction. @@ -589,10 +746,8 @@ class HloInstruction { if (opcode() != other.opcode()) { return false; } - using EqShapeFuncType = bool (*)(const Shape&, const Shape&); - EqShapeFuncType eq_shapes = - layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible; - if (!eq_shapes(shape(), other.shape())) { + if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape()) + : ShapeUtil::Compatible(shape(), other.shape()))) { return false; } if (operands().size() != other.operands().size()) { @@ -607,7 +762,7 @@ class HloInstruction { } } - return IdenticalSlowPath(other, eq_computations, eq_shapes); + return IdenticalSlowPath(other, eq_computations); } // Returns whether the instruction has a constant operand. @@ -635,6 +790,8 @@ class HloInstruction { // Detaches an instruction from its operands. That is, remove the instruction // from each operand's user set. This should only be called prior to // deallocating the instruction. + // + // TODO(b/78305363): Make this automatic when deleting an instruction. void DetachFromOperands(); // Performs a postorder DFS visit using this node as the root. If @@ -687,6 +844,9 @@ class HloInstruction { // Note: only constant and parameter opcodes have an associated literal. const Literal& literal() const; + // Returns whether there is literal associated with this instruction. + bool HasLiteral() const; + // Returns the parameter number associated with this instruction. // // Note: only parameter opcodes have an associated parameter number. @@ -948,6 +1108,14 @@ class HloInstruction { void clear_sharding() { sharding_ = nullptr; } // Return true if this operator has a sharding assigned. bool has_sharding() const { return sharding_ != nullptr; } + // Checks whether the instruction has compatible sharding with the other + // instruction. + bool has_compatible_sharding(const HloInstruction* other) const { + if (!has_sharding()) { + return !other->has_sharding(); + } + return other->has_sharding() ? sharding() == other->sharding() : false; + } // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain @@ -1149,23 +1317,30 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kRng RandomDistribution random_distribution() const; + // See documentation for Clone(). + using CloneMap = std::unordered_map; + // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. 
"this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of - // the instruction to form the name of the cloned instruction. - // If the module pointer is not nullptr, it will be the module where - // the cloned computations will be added to (in order to support deep - // cloning). + // the instruction to form the name of the cloned instruction. Ignores the + // control predecessors and successors of this HLO instruction. + // + // If the module pointer is not nullptr, then any cloned computations will be + // added to this module in order to support deep cloning. Otherwise the module + // of the instruction is used. + // + // If clone_map is not nullptr, then each original instruction that is cloned + // will be inserted and map to its clone. clone_map should not already contain + // any of the instructions to clone. std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr) const; + HloModule* module = nullptr, + CloneMap* clone_map = nullptr) const; // Clones the HLO instruction as above but with new shape and operands. - // If the module pointer is not nullptr, it will be the module where - // the cloned computations will be added to (in order to support deep - // cloning). std::unique_ptr CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; + HloModule* module = nullptr, CloneMap* clone_map = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -1237,7 +1412,7 @@ class HloInstruction { // Gets/sets the string identifier for this instruction. const string& name() const { return name_; } - void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); } + void set_name(tensorflow::StringPiece name) { name_ = std::string(name); } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1254,6 +1429,19 @@ class HloInstruction { // if no id has been assigned yet). int unique_id() const { return unique_id_; } + // Returns the backend-specific configuration for how a backend should compile + // this HLO. The meaning of the field is backend specific. Not for use before + // or during general HLO optimization, since HLO optimizations do not preserve + // this field and they cannot interpret it due to its meaning being backend + // specific. + // + // TODO(b/78194644): Introduce structured configuration format as per + // go/xla-heuristics. + const string& backend_config() const { return backend_config_; } + void set_backend_config(string backend_config) { + backend_config_ = std::move(backend_config); + } + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1275,6 +1463,7 @@ class HloInstruction { // Get/Set the number of partitions per outer dimension (in order, starting // with outer-most dimension first). Currently used by the parallel cpu // backend to partition HLOs into parallel tasks. + // // TODO(b/62783254) Replace these methods with a more general way to // annotate HLOs with backend-specific information. const std::vector& outer_dimension_partitions() const { @@ -1290,20 +1479,34 @@ class HloInstruction { const ShapeIndex& shape_index = {}); private: + // Prints an instruction to a string. 
+ // + // The canonical string representation needs to name operands and instruction + // names in a consistent way. This is implemented through the + // canonical_name_map. + string ToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const; + + // Prints an operand to a string. + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const; + + // Allow HloInstruction to access the ToStringWithCanonicalNameMap() and + // OperandsToStringWithCanonicalNameMap() functions. + friend class HloComputation; + enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; // Helper class for computing OperandElementUse for kFusion. class FusionReusesParamElements; // See comments on Identical(). - // eq_shapes() is used to check shapes for equality, and would normally be - // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on - // whether we want a layout-sensitive check or not. bool IdenticalSlowPath( const HloInstruction& other, const std::function& - eq_computations, - const std::function& eq_shapes) const; + eq_computations) const; // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( @@ -1502,6 +1705,10 @@ class HloInstruction { // The string representation of the infeed configuration. string infeed_config_; + // The backend-specific configuration for how a backend should compile this + // HLO. See the documentation on backend_config(). + string backend_config_; + // String identifier for instruction. string name_; @@ -1532,13 +1739,20 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // an HloInstruction* or a const HloInstruction*. // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of -// the hlo. +// the hlo. Exception: null pointer values compare less than non-null. // // Note that this cannot be used for HLO instructions across multiple modules // since the id of HLO instructions are only unique within each HLO module. struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, const HloInstruction* const& rhs) const { + if (rhs == nullptr) { + // Nothing compares less than nullptr. 
+ return false; + } + if (lhs == nullptr) { + return true; + } return lhs->unique_id() < rhs->unique_id(); } }; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index f2980d309d0..a61c472c728 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -149,8 +149,8 @@ TEST_F(HloInstructionTest, UserWithTwoOperands) { builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar)); EXPECT_THAT(foo->users(), UnorderedElementsAre(add)); @@ -186,8 +186,8 @@ TEST_F(HloInstructionTest, MultipleUsers) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -219,8 +219,8 @@ TEST_F(HloInstructionTest, RepeatedUser) { builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(1, foo->user_count()); @@ -254,8 +254,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1)); auto addtotal = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(addtotal->Accept(&visitor)); @@ -303,8 +303,8 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright)); auto neg2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(neg2->Accept(&visitor)); @@ -325,7 +325,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Shape r0f32 = ShapeUtil::MakeShape(F32, {}); Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10}); - HloModule module(TestName()); + auto module = CreateNewModule(); // Builds an x+1.0 computation to use in a Map. 
auto embedded_builder = HloComputation::Builder("f32+1"); @@ -335,7 +335,7 @@ TEST_F(HloInstructionTest, TrivialMap) { HloInstruction::CreateConstant(Literal::CreateR0(1.0))); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); - auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build()); + auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and feeds it to the map. HloComputation::Builder builder(TestName()); @@ -343,7 +343,7 @@ TEST_F(HloInstructionTest, TrivialMap) { HloInstruction::CreateParameter(0, f32a100x10, "")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10, {param0}, add_f32)); - module.AddEntryComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(map->Accept(&visitor)); @@ -373,8 +373,8 @@ TEST_F(HloInstructionTest, TrivialReduce) { HloInstruction::CreateParameter(1, r0f32, "y")); embedded_builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy)); - HloModule module(TestName()); - auto add_f32 = module.AddEmbeddedComputation(embedded_builder.Build()); + auto module = CreateNewModule(); + auto add_f32 = module->AddEmbeddedComputation(embedded_builder.Build()); // Builds a parameter and an initial value and feeds them to the reduce. HloComputation::Builder builder(TestName()); @@ -387,7 +387,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { auto reduce = builder.AddInstruction( HloInstruction::CreateReduce(f32v100, param0, const0, /*dimensions_to_reduce=*/{1}, add_f32)); - module.AddEntryComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); OpAndUserCollectingVisitor visitor; ASSERT_IS_OK(reduce->Accept(&visitor)); @@ -414,8 +414,8 @@ TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -449,8 +449,8 @@ TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) { builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo})); auto add_foobar = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_THAT(foo->users(), UnorderedElementsAre(tuple, add_foobar)); @@ -477,8 +477,8 @@ TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto log = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log)); @@ -514,8 +514,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) { HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, add_foobar, add_foofoo)); - HloModule 
module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(2, foo->user_count()); EXPECT_EQ(1, bar->user_count()); @@ -544,8 +544,8 @@ TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) { auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({foo, bar})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(3, foo->user_count()); EXPECT_EQ(2, bar->user_count()); @@ -609,8 +609,8 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); NodeCollectorAndPostProcessor visitor; ASSERT_IS_OK(add->Accept(&visitor)); @@ -627,8 +627,8 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); auto exp = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); @@ -645,8 +645,8 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { HloInstruction::CreateConstant(Literal::CreateR0(42.1f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {add}, HloInstruction::FusionKind::kLoop); @@ -667,8 +667,8 @@ TEST_F(HloInstructionTest, ChainFusionOp) { auto exp3 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -690,8 +690,8 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { exp1->set_metadata(metadata); exp2->set_metadata(metadata); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp2, exp1}, HloInstruction::FusionKind::kLoop); @@ -746,13 +746,13 @@ TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. 
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - HloModule module(TestName()); + auto module = CreateNewModule(); auto make_map_computation = [&]() { auto builder = HloComputation::Builder("FusionMap"); builder.AddInstruction( HloInstruction::CreateParameter(0, scalar_shape, "param")); - return module.AddEmbeddedComputation(builder.Build()); + return module->AddEmbeddedComputation(builder.Build()); }; HloComputation* computation_x = make_map_computation(); @@ -767,7 +767,7 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { scalar_shape, {map_1_x}, computation_x, /*static_operands=*/{})); auto map_3_y = builder.AddInstruction(HloInstruction::CreateMap( scalar_shape, {map_2_x}, computation_y, /*static_operands=*/{})); - auto* computation = module.AddEntryComputation(builder.Build()); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {map_3_y}, HloInstruction::FusionKind::kLoop); @@ -814,8 +814,8 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1})); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); @@ -940,8 +940,8 @@ TEST_F(HloInstructionTest, FunctionVisitor) { HloInstruction::CreateUnary(f32, HloOpcode::kExp, param)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); int visit_num = 0; std::unordered_map visit_order; @@ -969,8 +969,8 @@ TEST_F(HloInstructionTest, FullyElementwise) { builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y)); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_TRUE(add->IsElementwise()); for (int i = 0; i < add->operand_count(); ++i) { @@ -1013,8 +1013,8 @@ TEST_F(HloInstructionTest, PartiallyElementwise) { HloInstruction* max = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -1056,8 +1056,8 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kSubtract, min, broadcast)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {sub, broadcast, min}, HloInstruction::FusionKind::kLoop); EXPECT_FALSE(fusion->IsElementwise()); @@ -1099,10 +1099,10 @@ 
TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( - {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); + {dot, reshape}, HloInstruction::FusionKind::kLoop); auto fusion2 = fusion->Clone(); const HloInstruction* root = fusion->fused_expression_root(); @@ -1118,7 +1118,7 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { } TEST_F(HloInstructionTest, FusionEquality) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create two fusion instructions containing a single unary operation. @@ -1128,7 +1128,7 @@ TEST_F(HloInstructionTest, FusionEquality) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter)); auto neg = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter)); - auto* computation = module.AddEntryComputation(builder.Build()); + auto* computation = module->AddEntryComputation(builder.Build()); auto* fusion = computation->CreateFusionInstruction( {exp}, HloInstruction::FusionKind::kLoop); auto* fusion2 = computation->CreateFusionInstruction( @@ -1140,7 +1140,7 @@ TEST_F(HloInstructionTest, FusionEquality) { } TEST_F(HloInstructionTest, NestedFusionEquality) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Build a nested fusion computation. @@ -1166,10 +1166,10 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { data_shape, HloOpcode::kSubtract, dot, add_operand)); builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub)); - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); auto nested_fusion = computation->CreateFusionInstruction( - {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); + {dot, b_t}, HloInstruction::FusionKind::kLoop); auto fusion = computation->CreateFusionInstruction( {add, nested_fusion}, HloInstruction::FusionKind::kOutput); @@ -1244,15 +1244,8 @@ TEST_F(HloInstructionTest, Stringification) { "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(builder.Build()); - HloInstruction* fusion = computation->CreateFusionInstruction( - {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); - - EXPECT_EQ( - fusion->ToString(options), - "%dot_fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " - "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); @@ -1295,8 +1288,8 @@ TEST_F(HloInstructionTest, StringifyGather_0) { /*index_vector_dim=*/4), /*window_bounds=*/{30, 29, 28, 27, 26})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), "%gather = 
f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " @@ -1331,8 +1324,8 @@ TEST_F(HloInstructionTest, StringifyGather_1) { /*index_vector_dim=*/2), /*window_bounds=*/{30, 29, 28, 27, 26})); - HloModule module(TestName()); - module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); EXPECT_EQ(gather_instruction->ToString(), "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " @@ -1343,5 +1336,163 @@ TEST_F(HloInstructionTest, StringifyGather_1) { "index_vector_dim=2, window_bounds={30,29,28,27,26}"); } +TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { + // Tests stringification of a simple op, fusion, while, and conditional. + const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto options = HloPrintOptions().Canonical(); + + EXPECT_EQ(dot->ToString(options), + "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), " + "lhs_contracting_dims={1}, rhs_contracting_dims={0}"); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + HloInstruction* fusion = computation->CreateFusionInstruction( + {dot, reshape}, HloInstruction::FusionKind::kLoop); + + EXPECT_EQ( + fusion->ToString(options), + R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"); +} + +TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { + // Tests stringification of a simple op, fusion, while, and conditional. 
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({dot, reshape}, + HloInstruction::FusionKind::kLoop); + + HloInstruction* loop = builder.AddInstruction( + HloInstruction::CreateWhile(sout, computation, computation, x)); + + auto options = HloPrintOptions().Canonical(); + EXPECT_EQ(loop->ToString(options), + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, body= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"); +} + +TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { + // Tests stringification of a simple op, fusion, while, and conditional. 
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); + const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); + const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); + const Shape sout = ShapeUtil::MakeShape(F32, {5, 20}); + + HloComputation::Builder builder("TransposeDot"); + HloInstruction* x = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x")); + HloInstruction* y = + builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); + HloInstruction* reshape = + builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({dot, reshape}, + HloInstruction::FusionKind::kLoop); + + builder.AddInstruction( + HloInstruction::CreateWhile(sout, computation, computation, x)); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + sout, pred, x, computation, x, computation)); + auto options = HloPrintOptions().Canonical(); + EXPECT_EQ( + conditional->ToString(options), + R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, false_computation= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc new file mode 100644 index 00000000000..43c41ece6ef --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -0,0 +1,306 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +using Worklist = std::deque; +using Workset = std::unordered_set; + +namespace { + +void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, + Workset* workset) { + if (workset->count(instruction) == 0) { + worklist->push_back(instruction); + workset->insert(instruction); + VLOG(3) << "ADD instruction: " << instruction->name(); + } +} + +using VisitorFunction = std::function; + +void ForEachLiveIndex(const ShapeTree& index_tree, + const VisitorFunction& func) { + index_tree.ForEachElement([&](const ShapeIndex& shape_index, bool live) { + if (live) { + func(shape_index); + } + }); +} + +// Marks 'instruction' output live at 'shape_index'. +// Adds to 'worklist' iff: +// *) 'instruction' is not already on worklist. +// *) 'shape_index' has not yet been visited. +void MarkLiveAtIndex(const HloInstruction* instruction, + const ShapeIndex& shape_index, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + auto it_added = live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/false)); + it = it_added.first; + } + if (it->second.element(shape_index) == false) { + AddToWorklist(instruction, worklist, workset); + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } +} + +// Marks 'instruction' live at all shape indices in its output. +void MarkLiveAtAllIndices(const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, + Worklist* worklist, Workset* workset) { + bool add_to_worklist = false; + auto it = live_index_map->find(instruction); + if (it == live_index_map->end()) { + live_index_map->emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/true)); + add_to_worklist = true; + } else { + ShapeUtil::ForEachSubshape( + instruction->shape(), + [&](const Shape& sub_shape, const ShapeIndex& shape_index) { + if (it->second.element(shape_index) == false) { + add_to_worklist = true; + *it->second.mutable_element(shape_index) = true; + VLOG(3) << "MARK instruction: " << instruction->name() + << " shape_index: " << shape_index.ToString(); + } + }); + } + if (add_to_worklist) { + AddToWorklist(instruction, worklist, workset); + } +} + +// Propagates liveness through Tuple instructions. 
+// *) For each tuple operand:
+//   *) For each tuple output shape index associated with that operand:
+//     *) Propagate live shape indices to the tuple operand at the associated
+//        shape index in the operand's output, and add it to the worklist.
+void PropagateLivenessThroughTuple(
+    const HloInstruction* instruction,
+    HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+    Workset* workset) {
+  CHECK_EQ(instruction->opcode(), HloOpcode::kTuple);
+  for (int64 operand_index = 0; operand_index < instruction->operand_count();
+       ++operand_index) {
+    const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+    ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
+      if (shape_index.empty() || shape_index[0] != operand_index) {
+        return;
+      }
+      // Mark top-level index of operand at 'operand_index'.
+      MarkLiveAtIndex(instruction->operand(operand_index), {}, live_index_map,
+                      worklist, workset);
+      // Mark sub-shape index of operand at 'operand_index'.
+      ShapeIndex operand_shape_index;
+      for (int i = 1; i < shape_index.size(); ++i) {
+        operand_shape_index.push_back(shape_index[i]);
+      }
+      MarkLiveAtIndex(instruction->operand(operand_index), operand_shape_index,
+                      live_index_map, worklist, workset);
+    });
+  }
+}
+
+// Propagates liveness through GetTupleElement instructions.
+// *) For each live index in the GetTupleElement output, mark the GTE operand
+//    live at the associated shape index in its output, and add it to the
+//    worklist.
+void PropagateLivenessThroughGTE(
+    const HloInstruction* instruction,
+    HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+    Workset* workset) {
+  CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement);
+  // Mark operand top-level index.
+  MarkLiveAtIndex(instruction->operand(0), {}, live_index_map, worklist,
+                  workset);
+  const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+  // Propagate live shape indices along the GTE -> Tuple edge.
+  ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
+    ShapeIndex operand_shape_index(shape_index);
+    operand_shape_index.push_front(instruction->tuple_index());
+    MarkLiveAtIndex(instruction->operand(0), operand_shape_index,
+                    live_index_map, worklist, workset);
+  });
+}
+
+// Propagates liveness through While instructions.
+// *) For each live index in the While output, mark the same shape index of
+//    while.body.root and of the while operand (adding each to the worklist).
+// *) Mark while.cond.root and add it to the worklist.
+void PropagateLivenessThroughWhile(
+    const HloInstruction* instruction,
+    HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+    Workset* workset) {
+  CHECK_EQ(instruction->opcode(), HloOpcode::kWhile);
+  const ShapeTree<bool>& index_tree = FindOrDie(*live_index_map, instruction);
+
+  ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
+    // Propagate liveness to the while body computation root instruction.
+    MarkLiveAtIndex(instruction->while_body()->root_instruction(), shape_index,
+                    live_index_map, worklist, workset);
+    // Propagate liveness to the tuple-shaped operand.
+    MarkLiveAtIndex(instruction->operand(0), shape_index, live_index_map,
+                    worklist, workset);
+  });
+
+  // Propagate liveness to the while condition computation root instruction.
+  MarkLiveAtIndex(instruction->while_condition()->root_instruction(), {},
+                  live_index_map, worklist, workset);
+}
+
+// Propagates liveness out of Parameter instructions to callers and aliasing
+// positions. This can occur if liveness propagates to a parameter in the
+// while.condition computation, requiring liveness to propagate out to the
+// calling kWhile instruction (and to while.body.root).
+void PropagateLivenessToParameterCallers(
+    const HloInstruction* instruction,
+    HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist,
+    Workset* workset, CallGraph* call_graph) {
+  CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
+  const CallGraphNode& call_graph_node =
+      call_graph->GetNode(instruction->parent());
+  if (call_graph_node.context() == CallContext::kSequential) {
+    for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+      if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
+        auto* xla_while = callsite.instruction();
+        const ShapeTree<bool>& index_tree =
+            FindOrDie(*live_index_map, instruction);
+        ForEachLiveIndex(index_tree, [&](const ShapeIndex& shape_index) {
+          // Propagate liveness to while result{shape_index}.
+          MarkLiveAtIndex(xla_while, shape_index, live_index_map, worklist,
+                          workset);
+          // Propagate liveness to while body root{shape_index}.
+          MarkLiveAtIndex(xla_while->while_body()->root_instruction(),
+                          shape_index, live_index_map, worklist, workset);
+          // Propagate liveness to operand(0){shape_index}.
+          MarkLiveAtIndex(xla_while->operand(0), shape_index, live_index_map,
+                          worklist, workset);
+        });
+      }
+    }
+  }
+}
+
+}  // namespace
+
+HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module)
+    : module_(module), call_graph_(CallGraph::Build(&module)) {}
+
+// Runs liveness analysis on 'module_'.
+// Initializes the worklist with the entry root instruction (and any
+// instruction with side effects), marking all of their output shape indices
+// live.
+// Visits elements on the worklist, propagating liveness from an instruction's
+// live output shape indices to its called computations and operands.
+void HloLivenessAnalysis::RunAnalysis() {
+  Worklist worklist;
+  Workset workset;
+  // Add entry computation root instruction.
+  MarkLiveAtAllIndices(module_.entry_computation()->root_instruction(),
+                       &live_index_map_, &worklist, &workset);
+  for (auto* computation : module_.computations()) {
+    for (auto* instruction : computation->instructions()) {
+      if (instruction->HasSideEffectNoRecurse()) {
+        // Add instructions with side effects.
+        MarkLiveAtAllIndices(instruction, &live_index_map_, &worklist,
+                             &workset);
+      }
+    }
+  }
+
+  while (!worklist.empty()) {
+    const HloInstruction* instruction = worklist.front();
+    worklist.pop_front();
+    workset.erase(workset.find(instruction));
+    VLOG(1) << "VISIT instruction: " << instruction->name();
+
+    if (instruction->opcode() == HloOpcode::kTuple) {
+      PropagateLivenessThroughTuple(instruction, &live_index_map_, &worklist,
+                                    &workset);
+    } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
+      PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist,
+                                  &workset);
+    } else if (instruction->opcode() == HloOpcode::kWhile &&
+               ShapeUtil::IsTuple(instruction->shape())) {
+      PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist,
+                                    &workset);
+    } else if (instruction->opcode() == HloOpcode::kParameter &&
+               ShapeUtil::IsTuple(instruction->shape())) {
+      PropagateLivenessToParameterCallers(instruction, &live_index_map_,
+                                          &worklist, &workset,
+                                          call_graph_.get());
+    } else {
+      // Propagate liveness to called computations.
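+      // (Conservative default for every other opcode, e.g. kCall, kMap or
+      // kFusion: all output indices of each called computation's root and of
+      // every operand are marked live.)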
+ for (auto* called_computation : instruction->called_computations()) { + MarkLiveAtAllIndices(called_computation->root_instruction(), + &live_index_map_, &worklist, &workset); + } + // Propagate liveness to operands. + for (HloInstruction* operand : instruction->operands()) { + MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset); + } + } + } +} + +bool HloLivenessAnalysis::IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const { + if (ContainsKey(live_index_map_, instruction)) { + return FindOrDie(live_index_map_, instruction).element(shape_index); + } + return false; +} + +/* static */ +StatusOr> HloLivenessAnalysis::Run( + const HloModule& module) { + VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); + XLA_VLOG_LINES(2, module.ToString()); + + auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module)); + + liveness_analysis->RunAnalysis(); + + return std::move(liveness_analysis); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.h b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h new file mode 100644 index 00000000000..fe55a8070a4 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ + +#include + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_value.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Analysis which identifies all live {HloInstruction, ShapeIndex} pairs in +// an HLO module. +// +// HloLivenessAnalysis marks the shape index of each live output of each +// instruction in the module, by propagating live shape index information +// from an instruction to its called computations and operands. +class HloLivenessAnalysis { + public: + // Maps from an HloInstruction to its live/dead output shape indices. + using HloIndexMap = + std::unordered_map>; + + // Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object + // which exports liveness for each {HloInstruction, ShapeIndex} in 'module'. + static StatusOr> Run( + const HloModule& module); + + // Returns true if output of 'instruction' at 'shape_index' is live. + // Returns false otherwise. 
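+  //
+  // Example (illustrative; 'while_instr' is a hypothetical tuple-shaped
+  // instruction in the analyzed module):
+  //   auto liveness = HloLivenessAnalysis::Run(*module).ConsumeValueOrDie();
+  //   if (!liveness->IsLive(while_instr, /*shape_index=*/{1})) {
+  //     // Tuple element 1 of 'while_instr' is never observed by the entry
+  //     // root or by any side-effecting instruction.
+  //   }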
+ bool IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const; + + private: + HloLivenessAnalysis(const HloModule& module); + + void RunAnalysis(); + + const HloModule& module_; + std::unique_ptr call_graph_; + HloIndexMap live_index_map_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc new file mode 100644 index 00000000000..8e2e2c7627b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -0,0 +1,402 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class HloLivenessAnalysisTest : public HloTestBase { + protected: + HloLivenessAnalysisTest() {} + + // Run liveness analysis on the member module. For convenience returns a + // reference to the generated analysis stored in analysis_. + const HloLivenessAnalysis& RunLiveness(HloModule* module) { + liveness_ = HloLivenessAnalysis::Run(*module).ConsumeValueOrDie(); + return *liveness_; + } + + HloInstruction* GetInstruction(HloModule* module, const string& name) { + HloInstruction* to_return = nullptr; + for (auto* comp : module->computations()) { + for (auto* inst : comp->instructions()) { + if (inst->name() == name) { + to_return = inst; + break; + } + } + } + return CHECK_NOTNULL(to_return); + } + + std::unique_ptr liveness_; +}; + +// Test that add instruction at entry root is live at all output shape indices. +TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + ROOT add = s32[] add(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Test that a dead add instruction is marked as dead by analysis. 
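+// (add.1 below is not reachable from the entry root and has no side effects,
+// so it is never added to the worklist and stays dead.)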
+TEST_F(HloLivenessAnalysisTest, DeadAdd) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + add.1 = s32[] add(constant.1, constant.2) + ROOT add.2 = s32[] add(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "add.1"), {})); +} + +// Test that all output shape indices of entry root tuple (and defining +// instruction in its output) are marked live. +TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + ROOT tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Tests that all outputs of nested tuple and entry root (and defining +// instruction values appearing in its output) are marked live. +TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(1) + constant.2 = s32[] constant(2) + constant.3 = s32[] constant(3) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + ROOT tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Tests that GTE at entry root of Tuple instruction only propgates liveness +// to the live elements in tuple. 
+TEST_F(HloLivenessAnalysisTest, GteOfTuple) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + tuple.1 = (s32[], s32[]) tuple(constant.1, constant.2) + ROOT get-tuple-element.1 = s32[] get-tuple-element(tuple.1), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); +} + +// Tests that GTE at entry root of nested Tuple instruction only propgates +// liveness to the live elements in tuple. +TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + constant.3 = s32[] constant(2) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + ROOT get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {0})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Tests that GTE of GTE (at entry root) of nested Tuple instruction only +// propgates liveness to the live elements in tuple. 
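+// (Here the root chain get-tuple-element.2 -> get-tuple-element.1{0} ->
+// tuple.2{1, 0} -> tuple.1{0} keeps only constant.2 live; constant.1 and
+// constant.3 stay dead.)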
+TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { + auto module = tools::Parse(R"( + HloModule SimpleModule + ENTRY SimpleComputation { + constant.1 = s32[] constant(0) + constant.2 = s32[] constant(1) + constant.3 = s32[] constant(2) + tuple.1 = (s32[], s32[]) tuple(constant.2, constant.3) + tuple.2 = (s32[], s32[]) tuple(constant.1, tuple.1) + get-tuple-element.1 = (s32[], s32[]) get-tuple-element(tuple.2), index=1 + ROOT get-tuple-element.2 = s32[] get-tuple-element(get-tuple-element.1), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.2"), {})); + + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.1"), {})); + EXPECT_TRUE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {0})); + EXPECT_FALSE(liveness.IsLive( + GetInstruction(module.get(), "get-tuple-element.1"), {1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 0})); + EXPECT_FALSE( + liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1, 1})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.2"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +// Test that live/dead while tuple elements are marked live/dead correctly. +TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.4"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. 
+ EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + + // While body. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_FALSE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a tuple element live in while.cond computation, propagates +// liveness to while.body.root/while.result/while.operand (where it is unused). +TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add.0 = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply.0 = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.0 = (s32[], s32[3]{0}) tuple(add.0, multiply.0) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=1 + add.1 = s32[] add(get-tuple-element.3, get-tuple-element.4) + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(add.1, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while.0 = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.0), index=0 + })") + .ValueOrDie(); + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.0"), {1})); + + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.4"), {})); + + // While body. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.0"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.0"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "multiply.0"), {})); +} + +// Tests that a use of while.result{0} propagates liveness to +// while.body.param{1} to while.body.root{1}, and then to while.body.param{2}. 
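+// (Chain below: while.1{0} -> tuple.1{0} (add.1) -> loop_var.1{0} and {1};
+// liveness of loop_var.1{1} flows back out to while.1{1} and tuple.1{1}
+// (get-tuple-element.3), which in turn makes loop_var.1{2}, tuple.1{2}
+// (multiply.1) and while.1{2} live, so every tuple element ends up live.)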
+TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + add.1 = s32[] add(get-tuple-element.1, get-tuple-element.2) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.1), index=2 + multiply.1 = s32[] multiply(get-tuple-element.3, get-tuple-element.3) + ROOT tuple.1 = (s32[], s32[], s32[]) tuple(add.1, get-tuple-element.3, multiply.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[], s32[]) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.2), index=0 + constant.1 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.4, constant.1) + } + ENTRY SimpleLoop { + constant.2 = s32[] constant(0) + constant.3 = s32[] constant(1) + constant.4 = s32[] constant(2) + tuple.2 = (s32[], s32[], s32[]) tuple(constant.2, constant.3, constant.4) + while.1 = (s32[], s32[], s32[]) while(tuple.2), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.5 = s32[] get-tuple-element(while.1), index=0 + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE( + liveness.IsLive(GetInstruction(module.get(), "get-tuple-element.5"), {})); + + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "while.1"), {2})); + // While operand. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.2"), {2})); + // While body root. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "tuple.1"), {2})); + // While body param. + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {0})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {1})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index bc74c4bc10c..7e4b8834357 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -17,10 +17,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace testing { +using ::tensorflow::str_util::Join; + bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { @@ -132,6 +135,104 @@ bool HloCustomCallMatcher::MatchAndExplain( return result; } +bool HloShapeMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (ShapeUtil::Compatible(instruction->shape(), shape_)) { + return true; + } + *listener << instruction->ToString() << " has incorrect shape (expected: " + << ShapeUtil::HumanString(shape_) << ")"; + return false; +} + +void HloShapeMatcher::DescribeTo(std::ostream* os) const { + *os << ShapeUtil::HumanString(shape_); +} + +bool HloShapeAndLayoutMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (ShapeUtil::Equal(instruction->shape(), shape_)) { + return true; + } + *listener << instruction->ToString() << " has incorrect shape (expected: " + << ShapeUtil::HumanStringWithLayout(shape_) << ")"; + return false; +} + +void HloShapeAndLayoutMatcher::DescribeTo(std::ostream* os) const { + *os << ShapeUtil::HumanStringWithLayout(shape_); +} + +bool HloShardingMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!sharding_.has_value()) { + if (!instruction->has_sharding()) { + return true; + } + *listener << instruction->ToString() << " expected to have no sharding."; + return false; + } + if (instruction->has_sharding()) { + if (instruction->sharding() == sharding_.value()) { + return true; + } + *listener << instruction->ToString() + << " has incorrect sharding (expected: " << sharding_->ToString() + << ")"; + return false; + } else { + *listener << instruction->ToString() + << " has no sharding (expected: " << sharding_->ToString() << ")"; + return false; + } +} + +void HloShardingMatcher::DescribeTo(std::ostream* os) const { + if (sharding_.has_value()) { + *os << sharding_->ToString(); + } else { + *os << ""; + } +} + +bool HloDotWithContractingDimsMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + + const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers(); + if (dim_nums.lhs_contracting_dimensions_size() != 1 || + dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { + *listener << instruction->ToString() + << " has wrong lhs_contracting_dimensions (got {" + << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" + << lhs_contracting_dim_ << "})"; + return false; + } + + if (dim_nums.rhs_contracting_dimensions_size() != 1 || + dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { + *listener << instruction->ToString() + << " has wrong rhs_contracting_dimensions (got {" + << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" + << rhs_contracting_dim_ << "})"; + return false; + } + + return true; +} + +void HloDotWithContractingDimsMatcher::DescribeTo(std::ostream* os) const { + HloMatcher::DescribeTo(os); + *os << " with lhs_contracting_dims={" << lhs_contracting_dim_ + << "} and rhs_contracting_dims={" << rhs_contracting_dim_ << "}"; +} + } // namespace testing void PrintTo(const HloInstruction* 
inst, ::std::ostream* os) { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 103f04a2cb7..c33bdadf1c7 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/core/lib/gtl/optional.h" namespace xla { namespace testing { @@ -86,6 +87,71 @@ class HloCustomCallMatcher : public HloMatcher { ::testing::Matcher call_target_matcher_; }; +class HloShapeMatcher + : public ::testing::MatcherInterface { + public: + explicit HloShapeMatcher(const Shape& shape) : shape_(shape) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + Shape shape_; +}; + +class HloShapeAndLayoutMatcher + : public ::testing::MatcherInterface { + public: + explicit HloShapeAndLayoutMatcher(const Shape& shape) : shape_(shape) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + Shape shape_; +}; + +// Verify the sharding of an instruction against the provided HloSharding. If a +// nullopt is provided for the expected sharding then it checks that no sharding +// is present for an instruction. +class HloShardingMatcher + : public ::testing::MatcherInterface { + public: + explicit HloShardingMatcher( + const tensorflow::gtl::optional& sharding) + : sharding_(sharding) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + tensorflow::gtl::optional sharding_; +}; + +// Matches a Dot HLO instruction with specific LHS and RHS contracting +// dimensions. +class HloDotWithContractingDimsMatcher : public HloMatcher { + public: + explicit HloDotWithContractingDimsMatcher( + ::testing::Matcher lhs, + ::testing::Matcher rhs, int64 lhs_contracting_dim, + int64 rhs_contracting_dim) + : HloMatcher(HloOpcode::kDot, /*operands=*/{lhs, rhs}), + lhs_contracting_dim_(lhs_contracting_dim), + rhs_contracting_dim_(rhs_contracting_dim) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + int64 lhs_contracting_dim_; + int64 rhs_contracting_dim_; +}; + // HloInstruction* matchers for opcode and operands. Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -113,7 +179,6 @@ HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); HLO_MATCHER(Divide); -HLO_MATCHER(Dot); HLO_MATCHER(DynamicSlice); HLO_MATCHER(DynamicUpdateSlice); HLO_MATCHER(Eq); @@ -231,6 +296,64 @@ inline ::testing::Matcher CustomCall() { new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {})); } +// Verifies the shape or the shape and the layout of an HLO instruction against +// the provided shape object. 
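+// For example (illustrative):
+//   EXPECT_THAT(instruction, op::Shape("f32[5,7]"));
+//   EXPECT_THAT(instruction, op::ShapeWithLayout("f32[5,7]{0,1}"));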
+inline ::testing::Matcher Shape( + const class Shape& shape) { + return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape)); +} +inline ::testing::Matcher Shape( + tensorflow::StringPiece shape) { + return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( + ShapeUtil::ParseShapeString(shape).ValueOrDie())); +} +inline ::testing::Matcher ShapeWithLayout( + const class Shape& shape) { + return ::testing::MakeMatcher( + new ::xla::testing::HloShapeAndLayoutMatcher(shape)); +} +inline ::testing::Matcher ShapeWithLayout( + tensorflow::StringPiece shape) { + return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( + ShapeUtil::ParseShapeString(shape).ValueOrDie())); +} + +// Verifies the value of the HloSharing against the provided sharding object. +inline ::testing::Matcher Sharding( + const HloSharding& sharding) { + return ::testing::MakeMatcher( + new ::xla::testing::HloShardingMatcher(sharding)); +} +// Verifies that no HloSharding is set for an HLO instruction. +inline ::testing::Matcher NoSharding() { + return ::testing::MakeMatcher( + new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt)); +} + +inline ::testing::Matcher Dot( + ::testing::Matcher lhs_matcher, + ::testing::Matcher rhs_matcher) { + return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( + ::xla::HloOpcode::kDot, {lhs_matcher, rhs_matcher})); +} + +// Matches a Dot HLO instruction if it has exactly one lhs contracting dimension +// equal to `lhs_contracting_dim` and exactly one rhs contracting dimension +// equal to `rhs_contracting_dim`. +// +// Currently the HLO verifier rejects Dot operations with more than one +// contracting dimension (even though we can represent these in the +// DotDimensionNumbers proto) so there is no need to generalize this to support +// multiple contracting dimensions. +inline ::testing::Matcher Dot( + ::testing::Matcher lhs_matcher, + ::testing::Matcher rhs_matcher, + int64 lhs_contracting_dim, int64 rhs_contracting_dim) { + return ::testing::MakeMatcher( + new ::xla::testing::HloDotWithContractingDimsMatcher( + lhs_matcher, rhs_matcher, lhs_contracting_dim, rhs_contracting_dim)); +} + #undef HLO_MATCHER } // namespace opcode_matchers diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 1c21703a45e..016cc01e338 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace op = xla::testing::opcode_matchers; using ::testing::_; @@ -100,5 +101,106 @@ TEST(HloMatchersTest, CustomCallMatcher) { R"(custom-call with call target that is equal to "foo_target")"); } +TEST(HloMatchersTest, ShapeMatcher) { + auto p0 = HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}), "param"); + + EXPECT_THAT(p0.get(), op::Shape(ShapeUtil::MakeShape(F32, {5, 7}))); + EXPECT_THAT(p0.get(), op::Shape("f32[5,7]")); + EXPECT_THAT( + p0.get(), + ::testing::Not(op::ShapeWithLayout(ShapeUtil::MakeShape(F32, {5, 7})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[5,7]"))); + EXPECT_THAT(p0.get(), + ::testing::Not(op::Shape(ShapeUtil::MakeShape(F32, {7, 5})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::Shape("f32[7,5]"))); + EXPECT_THAT( + p0.get(), + ::testing::Not(op::ShapeWithLayout(ShapeUtil::MakeShape(F32, {7, 5})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[7,5]"))); + EXPECT_THAT(p0.get(), + op::Shape(ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}))); + EXPECT_THAT(p0.get(), op::Shape("f32[5,7]{0,1}")); + EXPECT_THAT(p0.get(), op::ShapeWithLayout(ShapeUtil::MakeShapeWithLayout( + F32, {5, 7}, {0, 1}))); + EXPECT_THAT(p0.get(), op::ShapeWithLayout("f32[5,7]{0,1}")); + EXPECT_THAT(p0.get(), + ::testing::Not(op::ShapeWithLayout( + ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {1, 0})))); + EXPECT_THAT(p0.get(), ::testing::Not(op::ShapeWithLayout("f32[5,7]{1,0}"))); + + EXPECT_THAT(Explain(p0.get(), op::Shape(ShapeUtil::MakeShape(F32, {7, 5}))), + "%param = f32[5,7]{0,1} parameter(0) has incorrect shape " + "(expected: f32[7,5])"); + EXPECT_THAT( + Explain(p0.get(), op::ShapeWithLayout(ShapeUtil::MakeShapeWithLayout( + F32, {7, 5}, {1, 0}))), + "%param = f32[5,7]{0,1} parameter(0) has incorrect shape " + "(expected: f32[7,5]{1,0})"); +} + +TEST(HloMatchersTest, ShardingMatcher) { + auto p0 = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {5}), + "param.0"); + p0->clear_sharding(); + auto p1 = HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {7}), + "param.1"); + p1->set_sharding(HloSharding::AssignDevice(1)); + + EXPECT_THAT(p0.get(), op::NoSharding()); + EXPECT_THAT(p0.get(), + ::testing::Not(op::Sharding(HloSharding::AssignDevice(1)))); + EXPECT_THAT(p1.get(), ::testing::Not(op::NoSharding())); + EXPECT_THAT(p1.get(), + ::testing::Not(op::Sharding(HloSharding::AssignDevice(0)))); + EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1))); + + EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), + "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " + "{maximal device=1})"); + EXPECT_THAT(Explain(p1.get(), op::NoSharding()), + "%param.1 = f32[7]{0} parameter(1), sharding={maximal device=1} " + "expected to have no sharding."); + EXPECT_THAT(Explain(p1.get(), op::Sharding(HloSharding::AssignDevice(0))), + "%param.1 = f32[7]{0} parameter(1), sharding={maximal device=1} " + "has incorrect sharding (expected: {maximal device=0})"); +} + +TEST(HloMatchersTest, DotMatcher) { + string hlo_string = R"( +HloModule DotOperationFusion_TransposeFusion + +ENTRY DotOperationFusion_TransposeFusion { + arg0 = f32[1,256] parameter(0) + arg1 = f32[256,1024] parameter(1) + ROOT dot = f32[1,1024] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + 
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + HloInstruction* root = module->entry_computation()->root_instruction(); + + EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/0)); + + EXPECT_THAT( + Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/0, + /*rhs_contracting_dim=*/0)), + "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "lhs_contracting_dimensions (got {1} want {0})"); + + EXPECT_THAT( + Explain(root, op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, + /*rhs_contracting_dim=*/1)), + "%dot = f32[1,1024]{1,0} dot(f32[1,256]{1,0} %arg0, f32[256,1024]{1,0} " + "%arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} has wrong " + "rhs_contracting_dimensions (got {0} want {1})"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 08b9a29aeda..fbf1d58007e 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -41,14 +41,23 @@ HloModule::HloModule(const string& name, entry_computation_handle_(entry_computation_handle), unique_id_(next_unique_module_id_++) {} -HloModule::HloModule(const string& name) - : name_(NameUniquer::GetSanitizedName(name)), - unique_id_(next_unique_module_id_++) {} HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(NameUniquer::GetSanitizedName(name)), config_(config), unique_id_(next_unique_module_id_++) {} +StatusOr HloModule::LaunderConstInstructionFromModule( + const HloInstruction* hlo) { + if (hlo == nullptr) { + return nullptr; + } + + TF_RET_CHECK(hlo->GetModule() == this); + + // TODO(b/78350259): Eliminate const laundering. + return const_cast(hlo); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, bool uniquify_names) { @@ -58,7 +67,7 @@ HloComputation* HloModule::AddComputationInternal( // If the module configuration has no entry layout computation set, create a // default one based on the program shape. 
- if (!config_.has_entry_computation_layout()) { + if (!config_.has_host_entry_computation_layout()) { config_.SetDefaultComputationLayout( entry_computation_->ComputeProgramShape()); } @@ -232,11 +241,14 @@ StatusOr> HloModule::CreateFromProto( TF_RET_CHECK(proto.has_program_shape()) << "No program shape found in the proto"; const auto& expected_program_shape = proto.program_shape(); - TF_RET_CHECK(expected_program_shape.parameters_size() == - module_config.entry_computation_layout().parameter_count()); + TF_RET_CHECK( + expected_program_shape.parameters_size() == + module_config.device_entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { const Shape& parameter_shape = - module_config.entry_computation_layout().parameter_layout(i).shape(); + module_config.device_entry_computation_layout() + .parameter_layout(i) + .shape(); TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i), parameter_shape)) << "HloModuleConfig has different shape for parameter " << i @@ -246,7 +258,7 @@ StatusOr> HloModule::CreateFromProto( << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape); } const Shape& result_shape = - module_config.entry_computation_layout().result_layout().shape(); + module_config.device_entry_computation_layout().result_layout().shape(); TF_RET_CHECK( ShapeUtil::Compatible(expected_program_shape.result(), result_shape)) << "HloModuleConfig has different result shape than the HLO module. " @@ -254,24 +266,44 @@ StatusOr> HloModule::CreateFromProto( << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); - auto module = MakeUnique(proto.name(), entry_computation_handle, - module_config); - tensorflow::gtl::FlatMap computation_map; + tensorflow::gtl::FlatMap to_proto_id; + std::vector> computations; + HloComputation* entry = nullptr; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, computation_map)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr computation, + HloComputation::CreateFromProto(computation_proto, computation_map)); CHECK_NE(computation.get(), nullptr); int64 computation_id = computation_proto.id(); TF_RET_CHECK(computation_id != -1); TF_RET_CHECK(!ContainsKey(computation_map, computation_id)); + computation_map[computation_id] = computation.get(); + to_proto_id[computation.get()] = computation_id; + if (computation_id == proto.entry_computation_id()) { + entry = computation.get(); + } + computations.push_back(std::move(computation)); + } + TF_RET_CHECK(entry != nullptr); + + auto module = MakeUnique(proto.name(), entry_computation_handle, + module_config); + + // Sort the computations in the proto id's order. + std::sort(computations.begin(), computations.end(), + [&](const std::unique_ptr& a, + const std::unique_ptr& b) { + return to_proto_id[a.get()] < to_proto_id[b.get()]; + }); + + // Add sorted computations to the module. + for (auto& computation : computations) { + bool is_entry = computation.get() == entry; // Don't uniquify names because we want names to be stable across // serialization and deserialization. 
- computation_map[computation_id] = module->AddComputationInternal( - std::move(computation), - /*is_entry=*/proto.entry_computation_id() == computation_id, - /*uniquify_names=*/false); + module->AddComputationInternal(std::move(computation), is_entry, + /*uniquify_names=*/false); } TF_RET_CHECK(module->entry_computation_ != nullptr); @@ -306,7 +338,7 @@ StatusOr HloModule::CreateModuleConfigFromProto( // The module config is constructed with default layouts regardless of what is // passed in via the ProgramShape. Set the layouts to the appropriate values. ComputationLayout* entry_layout = - module_config.mutable_entry_computation_layout(); + module_config.mutable_host_entry_computation_layout(); for (int64 i = 0; i < entry_layout->parameter_count(); ++i) { TF_RETURN_IF_ERROR( entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -314,6 +346,8 @@ StatusOr HloModule::CreateModuleConfigFromProto( } TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape( program_shape.result())); + *module_config.mutable_device_entry_computation_layout() = + module_config.host_entry_computation_layout(); return module_config; } @@ -479,8 +513,7 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = MakeUnique(name_ + "-" + suffix); - module->config_ = config_; + auto module = MakeUnique(name_ + "-" + suffix, config_); module->entry_computation_handle_ = entry_computation_handle_; module->has_entry_computation_handle_ = has_entry_computation_handle_; @@ -539,6 +572,14 @@ uint64 HloModule::RandomNew64() const { return rng_(); } +HloComputation* HloModule::GetComputationWithName( + tensorflow::StringPiece name) { + auto it = c_find_if(computations(), [&](HloComputation* computation) { + return computation->name() == name; + }); + return it == computations().end() ? nullptr : *it; +} + /* static */ std::atomic HloModule::next_unique_module_id_(0); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 9f7f25202ba..02918c37777 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" @@ -41,10 +42,18 @@ namespace xla { // Describes a compilation unit at the HLO level. // -// A HLO module contains one or more HLO computations. The module contains one -// "entry" computation which produces the result. The module also includes any -// embedded computations used by instructions such as "map" and "reduce". All -// computations are owned by the module. +// HloModule is the top-level unit in the HLO IR. It corresponds to a whole +// "program". Running a module, from beginning to end, is the only way to run +// an XLA program. +// +// A module contains one "entry computation"; this HloComputation is like main() +// in a C program. The result of running the module is the result of running +// this computation. +// +// A module also contains some number of "nested computations". 
Each nested +// computation is attached to an HloInstruction within some other computation. +// The meaning of the nested computation depends on the instruction it's +// attached to. class HloModule { public: HloModule(const string& name, @@ -55,7 +64,6 @@ class HloModule { // only be used for HloModules used outside of the XLA service (eg // tests). The versioned handle is used by the service in the compilation // cache. A default configuration is created for this module. - explicit HloModule(const string& name); explicit HloModule(const string& name, const HloModuleConfig& config); // Adds an entry computation to the module. A module can only have one entry @@ -99,12 +107,20 @@ class HloModule { return entry_computation_; } - ComputationLayout* mutable_entry_computation_layout() { - return config_.mutable_entry_computation_layout(); + ComputationLayout* mutable_host_entry_computation_layout() { + return config_.mutable_host_entry_computation_layout(); } - const ComputationLayout& entry_computation_layout() const { - return config_.entry_computation_layout(); + const ComputationLayout& host_entry_computation_layout() const { + return config_.host_entry_computation_layout(); + } + + ComputationLayout* mutable_device_entry_computation_layout() { + return config_.mutable_device_entry_computation_layout(); + } + + const ComputationLayout& device_entry_computation_layout() const { + return config_.device_entry_computation_layout(); } const VersionedComputationHandle& entry_computation_handle() const { @@ -131,6 +147,10 @@ class HloModule { MakeUnwrappingIterator(computations_.end())}; } + // Returns the computation in this module that has the name `name`. Returns + // null if there is no such computation. + HloComputation* GetComputationWithName(tensorflow::StringPiece name); + // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } @@ -205,6 +225,25 @@ class HloModule { // the lifetime of this process. int unique_id() const { return unique_id_; } + // Returns a non-const version of the passed-in const HloInstruction*. This is + // safe on the argument that if you have a non-const module, then you can + // access all instructions in the module as non-const. + // + // Returns an error if the passed-in instruction is not from this module, + // except that it is allowed to pass in a null pointer. + // + // TODO(b/78350259): Eliminate const laundering. The argument above is not + // reliable since at any time someone could add or discover a way for a + // non-const module to transitively contain a const HloInstruction. The + // reliable way to do this would be to create a const laundering map from a + // module, mapping each encountered HloInstruction to its non-const version + // and then look up each instruction in need of laundering in that map, but + // this is much more expensive and complicated. This returns a Status instead + // of doing a CHECK-failure in part to make it strongly apparent that this is + // something that can fail. 
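+  //
+  // Example (illustrative; 'hlo' is a hypothetical const HloInstruction*
+  // obtained from an analysis over this module):
+  //   TF_ASSIGN_OR_RETURN(HloInstruction * mutable_hlo,
+  //                       module->LaunderConstInstructionFromModule(hlo));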
+ StatusOr LaunderConstInstructionFromModule( + const HloInstruction* hlo); + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 4205b0402cb..dae5578a315 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -31,11 +31,13 @@ using tensorflow::strings::StrAppend; HloModuleConfig::HloModuleConfig() {} HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape) - : entry_computation_layout_(program_shape) {} + : host_entry_computation_layout_(program_shape), + device_entry_computation_layout_(program_shape) {} void HloModuleConfig::SetDefaultComputationLayout( const ProgramShape& program_shape) { - entry_computation_layout_ = ComputationLayout(program_shape); + host_entry_computation_layout_ = ComputationLayout(program_shape); + device_entry_computation_layout_ = ComputationLayout(program_shape); } string HloModuleConfig::compilation_cache_key() const { @@ -44,11 +46,18 @@ string HloModuleConfig::compilation_cache_key() const { StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : - entry_computation_layout_->parameter_layouts()) { + host_entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", - entry_computation_layout_->result_shape().SerializeAsString()); + host_entry_computation_layout_->result_shape().SerializeAsString()); + for (const ShapeLayout& param_layout : + device_entry_computation_layout_->parameter_layouts()) { + params.push_back(param_layout.shape().DebugString()); + } + StrAppend( + &key, tensorflow::str_util::Join(params, ", "), ") => ", + device_entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. static std::atomic counter{0}; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 586a03d4126..cdb0b29a239 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -41,26 +41,44 @@ class HloModuleConfig { explicit HloModuleConfig(const ProgramShape& program_shape); // Checks if this config has an entry computation layout already. - bool has_entry_computation_layout() const { - return entry_computation_layout_.has_value(); + bool has_host_entry_computation_layout() const { + return host_entry_computation_layout_.has_value(); + } + + bool has_device_entry_computation_layout() const { + return device_entry_computation_layout_.has_value(); } // Sets the entry computation layout for this config. If the entry computation // layout already exists, it is silently replaced. void SetDefaultComputationLayout(const ProgramShape& program_shape); - // Returns a constant reference to the layout of the entry computation. - // Assumes the layout was set. - const ComputationLayout& entry_computation_layout() const { - CHECK(entry_computation_layout_.has_value()); - return *entry_computation_layout_; + // Returns a constant reference to the on-host layout of the entry + // computation. Assumes the layout was set. 
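+  // (The host layout presumably describes how the client passes parameters and
+  // receives results, while the device layout is the one the compiled
+  // executable uses internally; both are initialized to the same
+  // ComputationLayout, see SetDefaultComputationLayout.)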
+ const ComputationLayout& host_entry_computation_layout() const { + CHECK(host_entry_computation_layout_.has_value()); + return *host_entry_computation_layout_; } - // Returns a mutable pointer to the layout of the entry computation. Assumes - // the layout was set. - ComputationLayout* mutable_entry_computation_layout() { - CHECK(entry_computation_layout_.has_value()); - return &(*entry_computation_layout_); + // Returns a mutable pointer to the layout of the on-host entry computation. + // Assumes the layout was set. + ComputationLayout* mutable_host_entry_computation_layout() { + CHECK(host_entry_computation_layout_.has_value()); + return &(*host_entry_computation_layout_); + } + + // Returns a constant reference to the on-device layout of the entry + // computation. Assumes the layout was set. + const ComputationLayout& device_entry_computation_layout() const { + CHECK(device_entry_computation_layout_.has_value()); + return *device_entry_computation_layout_; + } + + // Returns a mutable pointer to the layout of the on-device entry computation. + // Assumes the layout was set. + ComputationLayout* mutable_device_entry_computation_layout() { + CHECK(device_entry_computation_layout_.has_value()); + return &(*device_entry_computation_layout_); } // Returns whether to enable HLO-level profiling. @@ -109,7 +127,8 @@ class HloModuleConfig { private: // If you add new members, be sure to update compilation_cache_key. - tensorflow::gtl::optional entry_computation_layout_; + tensorflow::gtl::optional host_entry_computation_layout_; + tensorflow::gtl::optional device_entry_computation_layout_; // Whether this is a 'host module'. bool is_host_module_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc new file mode 100644 index 00000000000..98d20315e39 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_liveness_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +bool HasSendRecv(HloComputation* computation) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kSendDone || + instruction->opcode() == HloOpcode::kRecv || + instruction->opcode() == HloOpcode::kRecvDone) { + return true; + } + for (auto* sub_computation : instruction->called_computations()) { + if (HasSendRecv(sub_computation)) { + return true; + } + } + } + return false; +} + +StatusOr RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + + const auto* xla_while = instruction; + auto* while_body_comp = xla_while->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + + if (!ShapeUtil::IsTuple(xla_while->shape()) || + while_body_root->opcode() != HloOpcode::kTuple || + HasSendRecv(while_body_comp)) { + // Only run DCE on tuple-shaped while loops where body root is Tuple, + // with no send/recv instructions. + VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); + continue; + } + + // Remove dead tuple elements. + const int64 tuple_element_count = + ShapeUtil::TupleElementCount(xla_while->shape()); + for (int64 i = 0; i < tuple_element_count; ++i) { + if (liveness->IsLive(xla_while, {i})) { + continue; + } + VLOG(1) << "WhileDCE Dead while tuple element." + << " while: " << xla_while->name() << " tuple_index: " << i; + // Transform while.body computation to make tuple element at + // 'shape_index' a simple pass-through parameter (which can + // be removed later by a simplification pass). + HloInstruction* pass_thru_gte = while_body_comp->AddInstruction( + HloInstruction::CreateGetTupleElement( + while_body_param->shape().tuple_shapes(i), while_body_param, + i)); + // Replace while.body.root Tuple operand at 'tuple_index' with + // 'pass_thru_gte', making the prior operand a dead root (to be cleaned + // up with a subsequent DCE pass).
+ TF_RETURN_IF_ERROR( + while_body_root->ReplaceOperandWith(i, pass_thru_gte)); + changed = true; + } + } + } + return changed; +} + +} // namespace + +StatusOr HloModuleDCE::Run(HloModule* module) { + VLOG(2) << "Before HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + std::unique_ptr liveness; + TF_ASSIGN_OR_RETURN(liveness, HloLivenessAnalysis::Run(*module)); + + // Sweep through while instructions, transforming dead while tuple element + // computations to pass through tuple values (creating dead roots in while + // body computation in the process). + TF_ASSIGN_OR_RETURN(bool hlo_module_dce_changed, + RunWhileDCE(module, liveness.get())); + + // Run HloDCE to clean up any dead code created during HloModuleDCE. + HloDCE hlo_dce; + TF_ASSIGN_OR_RETURN(bool hlo_dce_changed, hlo_dce.Run(module)); + + VLOG(2) << "After HloModuleDCE:"; + XLA_VLOG_LINES(3, module->ToString()); + + return hlo_module_dce_changed | hlo_dce_changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h new file mode 100644 index 00000000000..29024085c10 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass which removes dead code from computations in the module using +// HloModule-scoped analysis (HloLivenessAnalysis). +// +// Sweeps through live instructions which cross computation boundaries (kWhile), +// and removes code at dead shape indices. +// +class HloModuleDCE : public HloPassInterface { + public: + ~HloModuleDCE() override {} + tensorflow::StringPiece name() const override { return "hlo-module-dce"; } + + // Run the pass on the given module. Returns whether the module was changed + // (instructions were removed). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_DCE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc new file mode 100644 index 00000000000..53b7d0ed396 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -0,0 +1,371 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
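[Editor's sketch] HloModuleDCE::Run above works in two phases: RunWhileDCE rewires each dead while-loop tuple element so the body simply passes the corresponding parameter element through, and a follow-up HloDCE sweep then deletes the now-unused producers. A toy, self-contained sketch of the rewrite phase over a made-up loop-body representation (plain C++, hypothetical types, not the XLA API):

#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for a while body: output_exprs[i] is either a computed
// expression or a plain copy of parameter element i ("pass-through").
struct ToyWhileBody {
  std::vector<std::string> output_exprs;  // e.g. "add(param.0, 1)"
};

// Mirrors the RunWhileDCE idea: for every tuple index that liveness reports
// dead, replace the root operand with a pass-through of the parameter. The
// old expression becomes unused, and a later DCE sweep can remove it.
bool MakeDeadElementsPassThrough(ToyWhileBody* body,
                                 const std::vector<bool>& is_live) {
  assert(body->output_exprs.size() == is_live.size());
  bool changed = false;
  for (size_t i = 0; i < is_live.size(); ++i) {
    if (is_live[i]) continue;
    body->output_exprs[i] =
        "get-tuple-element(param), index=" + std::to_string(i);
    changed = true;
  }
  return changed;
}

int main() {
  ToyWhileBody body{{"add(param.0, 1)", "multiply(param.1, param.1)"}};
  // Pretend liveness analysis found tuple element {1} dead outside the loop.
  bool changed = MakeDeadElementsPassThrough(&body, {true, false});
  std::cout << (changed ? "rewritten" : "unchanged") << "\n";
  for (const auto& e : body.output_exprs) std::cout << e << "\n";
  return 0;
}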
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module_dce.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class HloModuleDceTest : public HloTestBase { + protected: + HloModuleDceTest() {} + + // Returns whether the given instruction exists in the given computation. + bool HasInstruction(const HloComputation& computation, + const HloInstruction* instruction) { + return std::find(computation.instructions().begin(), + computation.instructions().end(), + instruction) != computation.instructions().end(); + } + + // Returns whether the while instruction with name 'while_name' in + // 'computation' passes through its tuple element at 'tuple_index' from + // parameter to root instruction. + bool WhileBodyHasPassThroughTupleElement(const HloComputation* computation, + const string& while_name, + const int64 tuple_index) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile && + instruction->name() == while_name) { + auto* while_body_comp = instruction->while_body(); + auto* while_body_param = while_body_comp->parameter_instruction(0); + auto* while_body_root = while_body_comp->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + return false; + } + auto* operand = while_body_root->operand(tuple_index); + if (operand->opcode() == HloOpcode::kGetTupleElement && + operand->tuple_index() == tuple_index && + operand->operand(0) == while_body_param) { + return true; + } + return false; + } + } + return false; + } +}; + +// Tests that a while with all outputs live is unmodified. 
+TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests a while loop with one unused output (which is used in the while loop +// body by an instruction with side-effects: rng) is unmodified. +TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], f32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = f32[] get-tuple-element(loop_var.1), index=1 + constant.2 = f32[] constant(1.0) + rng = f32[] rng(constant.2, get-tuple-element.2), distribution=rng_uniform + add.1 = s32[] add(get-tuple-element.2, constant.2) + ROOT tuple = (s32[], f32[]) tuple(add, add.1) + } + SimpleLoop.condition { + loop_var.2 = (s32[], f32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.3 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.3) + } + ENTRY SimpleLoop { + constant.4 = s32[] constant(0) + constant.5 = f32[] constant(0.0) + tuple.1 = (s32[], f32[]) tuple(constant.4, constant.5) + while = (s32[], f32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a while loop with one dead tuple element at {1} has its while +// loop body modified to make that tuple element pass-through the while body. 
+TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[3]{0} constant({0, 1, 2}) + tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4) + while = (s32[], s32[3]{0}) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} should now be pass-through after ModuleDCE. + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that a tuple element {1} used by condition computation (which appears +// dead in while.body{1} and at while.result{1}) propagates liveness of this +// tuple element to while.body{1} and at while.result{1}. +TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body { + loop_var.1 = (s32[], s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[] get-tuple-element(loop_var.1), index=1 + multiply = s32[] multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[]) tuple(add, multiply) + } + SimpleLoop.condition { + loop_var.2 = (s32[], s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + constant.4 = s32[] constant(0) + tuple.1 = (s32[], s32[]) tuple(constant.3, constant.4) + while = (s32[], s32[]) while(tuple.1), condition= + SimpleLoop.condition, body=SimpleLoop.body + ROOT get-tuple-element.4 = s32[] get-tuple-element(while), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // While tuple element {1} should not be pass-through before ModuleDCE. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 0)); + // While tuple element {1} should still not be pass-through after ModuleDCE.
+ EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at index {1} between +// two dependent while loops. +TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[], s32[3]{0}) tuple(constant.5, constant.6) + while.1 = (s32[], s32[3]{0}) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=0 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1 and while.2 should not have pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1 and while.2 should have pass-thru elements, + // after being modified to pass through unused tuple element {1}. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and +// while.2{1}, between two dependent while loops. 
+TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { + auto module = tools::Parse(R"( + HloModule SimpleLoop + SimpleLoop.body0 { + loop_var.1 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=1 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=0 + multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple = (s32[3]{0}, s32[]) tuple(multiply, add) + } + SimpleLoop.condition0 { + loop_var.2 = (s32[3]{0}, s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=1 + constant.2 = s32[] constant(5) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + SimpleLoop.body1 { + loop_var.3 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(loop_var.3), index=0 + constant.3 = s32[] constant(1) + add.1 = s32[] add(get-tuple-element.4, constant.3) + get-tuple-element.5 = s32[3]{0} get-tuple-element(loop_var.3), index=1 + multiply.1 = s32[3]{0} multiply(get-tuple-element.5, get-tuple-element.5) + ROOT tuple.1 = (s32[], s32[3]{0}) tuple(add.1, multiply.1) + } + SimpleLoop.condition1 { + loop_var.4 = (s32[], s32[3]{0}) parameter(0) + get-tuple-element.6 = s32[] get-tuple-element(loop_var.4), index=0 + constant.4 = s32[] constant(5) + ROOT less-than.1 = pred[] less-than(get-tuple-element.6, constant.4) + } + ENTRY SimpleLoop { + constant.5 = s32[] constant(0) + constant.6 = s32[3]{0} constant({0, 1, 2}) + tuple.2 = (s32[3]{0}, s32[]) tuple(constant.6, constant.5) + while.1 = (s32[3]{0}, s32[]) while(tuple.2), condition= + SimpleLoop.condition0, body=SimpleLoop.body0 + get-tuple-element.7 = s32[] get-tuple-element(while.1), index=1 + tuple.3 = (s32[], s32[3]{0}) tuple(get-tuple-element.7, constant.6) + while.2 = (s32[], s32[3]{0}) while(tuple.3), condition= + SimpleLoop.condition1, body=SimpleLoop.body1 + ROOT get-tuple-element.8 = s32[] get-tuple-element(while.2), index=0 + })") + .ValueOrDie(); + + HloModuleDCE dce; + // Before HloModuleDCE while.1{0} and while.2{1} should not be pass-thru. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + // After HloModuleDCE while.1{0} and while.2{1} not be pass-thru elements. + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 1)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.1", 0)); + EXPECT_FALSE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 0)); + EXPECT_TRUE(WhileBodyHasPassThroughTupleElement(module->entry_computation(), + "while.2", 1)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 54c34ce1166..b4cd3c730e3 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" +#include #include #include @@ -47,13 +48,16 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { case ComputationKind::kConditionalFalse: repr += ":CONDITIONAL_FALSE"; break; + case ComputationKind::kCallFunction: + repr += ":CALL"; + break; } return repr; } /* static */ StatusOr> HloModuleGroupMetadata::Build(const std::vector& modules) { - auto metadata = absl::make_unique(modules); + auto metadata = MakeUnique(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); } @@ -107,6 +111,31 @@ Status HloModuleGroupMetadata::Build() { TF_RETURN_IF_ERROR(computation->Accept(visitor)); } } + TF_RETURN_IF_ERROR(VerifyCompanionSets()); + return Status::OK(); +} + +Status HloModuleGroupMetadata::VerifyCompanionSets() const { + // TODO(dlibenzi): Migrate this to use the device instead of module ID, once + // the kDomain CL goes in. + for (const auto& companions : companion_sets_) { + // A companion set must be composed at most of an instruction per + // device/module. + std::unordered_set devices; + for (HloInstruction* instruction : *companions) { + int64 device = GetModuleId(instruction->parent()->parent()); + if (!devices.insert(device).second) { + std::stringstream ss; + ss << "Companion set:" << std::endl; + for (HloInstruction* hlo : *companions) { + ss << " " << hlo->name() << " (" + << GetModuleId(hlo->parent()->parent()) << ")" << std::endl; + } + ss << "has multiple instructions on the same device"; + return FailedPrecondition("%s", ss.str().c_str()); + } + } + } return Status::OK(); } @@ -206,6 +235,9 @@ Status HloModuleGroupMetadata::RecordInstructions() { TrackedInstruction(hlo, ComputationKind::kConditionalTrue); tracked_instructions_[hlo->false_computation()] = TrackedInstruction(hlo, ComputationKind::kConditionalFalse); + } else if (hlo->opcode() == HloOpcode::kCall) { + tracked_instructions_[hlo->to_apply()] = + TrackedInstruction(hlo, ComputationKind::kCallFunction); } if (!IsChannelInstruction(hlo)) { return Status::OK(); @@ -258,14 +290,15 @@ Status HloModuleGroupMetadata::RecordInstructions() { Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, HloInstruction* instruction2) { TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile || - instruction1->opcode() == HloOpcode::kConditional); + instruction1->opcode() == HloOpcode::kConditional || + instruction1->opcode() == HloOpcode::kCall); VLOG(2) << "adding as companions:" << instruction1->ToString() << " and " << instruction2->ToString(); if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - absl::make_unique>()); + tensorflow::MakeUnique>()); auto companion_set = companion_sets_.back().get(); companion_set->insert(instruction1); companion_set->insert(instruction2); @@ -336,21 +369,11 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { } } - // Check if channel instructions are used only in allowed computations. 
- const auto allowed = [this](HloInstruction* hlo) { - HloComputation* computation = hlo->parent(); - const HloModule* module = computation->parent(); - if (module->entry_computation() == computation || - tracked_instructions_.count(computation) > 0) { - return true; - } - return false; - }; for (const Channel& channel : channels_) { - if (!allowed(channel.send) || !allowed(channel.send_done) || - !allowed(channel.recv) || !allowed(channel.recv_done)) { - return FailedPrecondition("channel is used in disallowed computation"); - } + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv)); + TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done)); } // Check if the nest levels match for each channel. for (const Channel& channel : channels_) { @@ -368,4 +391,15 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { return Status::OK(); } +Status HloModuleGroupMetadata::CheckCommunicatingInstruction( + HloInstruction* instruction) const { + HloComputation* computation = instruction->parent(); + const HloModule* module = computation->parent(); + if (module->entry_computation() == computation || + tracked_instructions_.count(computation) > 0) { + return Status::OK(); + } + return FailedPrecondition("channel is used in disallowed computation"); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index c48a7ab0b59..3ef4542f912 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -60,6 +60,7 @@ class HloModuleGroupMetadata { kWhileBody, kConditionalTrue, kConditionalFalse, + kCallFunction, }; // Tracks the instruction mapped to a given computation, and the computation @@ -202,6 +203,15 @@ class HloModuleGroupMetadata { Status AddCompanion(HloInstruction* instruction1, HloInstruction* instruction2); + // Checks whether a communicating instruction is placed in a valid position + // within the graph. + Status CheckCommunicatingInstruction(HloInstruction* instruction) const; + + // Performs a consistency check on the companion sets built for the input + // modules. Check that a companion set does not include instructions from the + // same module/device. + Status VerifyCompanionSets() const; + // Retrieves a pointer to the stored TrackedInstruction associated with a // tracked computation, or nullptr in case such computation is not tracked. const TrackedInstruction* GetTrackedInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 289c96b0a7b..5a0d1e264eb 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -289,7 +290,7 @@ HloModuleGroupUtil::ComputeReachability( TF_RETURN_IF_ERROR( VisitTopologicalOrder(&visit_states, visit_function, root)); } - auto reachability = absl::make_unique(post_order); + auto reachability = MakeUnique(post_order); for (HloInstruction* hlo : post_order) { reachability->SetReachabilityToUnion(GlobalPredecessors(hlo), hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index dddc72480f9..ac7cd2f2f51 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -54,10 +54,10 @@ namespace xla { V(kBitcast, "bitcast") \ V(kBitcastConvert, "bitcast-convert") \ V(kBroadcast, "broadcast") \ - V(kBroadcastDimOne, "broadcast-dim-one") \ V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil") \ V(kClamp, "clamp") \ + V(kClz, "count-leading-zeros") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ V(kConditional, "conditional") \ @@ -74,6 +74,7 @@ namespace xla { V(kDynamicUpdateSlice, "dynamic-update-slice") \ V(kEq, "equal-to", kHloOpcodeIsComparison) \ V(kExp, "exponential") \ + V(kExpm1, "exponential-minus-one") \ V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ @@ -87,6 +88,7 @@ namespace xla { V(kIsFinite, "is-finite") \ V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ V(kLog, "log") \ + V(kLog1p, "log-plus-one") \ V(kAnd, "and") \ V(kNot, "not") \ V(kOr, "or") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index e89d94bede6..dcd4725fe78 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -170,10 +169,10 @@ bool HloOrdering::UseIsBeforeValueDefinition( // is before the def if the instruction allows buffer sharing (in place // computation). 
if (use.instruction == value.defining_instruction() && - CanShareOperandBufferWithUser( + dataflow.CanShareOperandBufferWithUser( use.instruction->mutable_operand(use.operand_number), use.operand_index, value.defining_instruction(), - value.defining_index(), dataflow)) { + value.defining_index())) { VLOG(4) << " use is value def, and instruction can share use buffer"; return true; } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 5120775737b..d8f1ab916b5 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -90,7 +90,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { return Status::OK(); }; - string prefix = name().ToString() + ": pipeline start"; + string prefix = std::string(name()) + ": pipeline start"; bool changed = false; string message; TF_RETURN_IF_ERROR( @@ -98,12 +98,12 @@ StatusOr HloPassPipeline::Run(HloModule* module) { const string xla_dump_per_pass_hlo_proto_to = module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, name().ToString(), - "pipeline_start"); + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, + std::string(name()), "pipeline_start"); } for (auto& pass : passes_) { - if (disabled_passes.count(pass->name().ToString()) > 0) { + if (disabled_passes.count(std::string(pass->name())) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() << ", disabled by --xla_disable_hlo_passes"; continue; @@ -121,7 +121,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { run_invariant_checkers(StrCat("after running pass: ", pass->name()))); if (!xla_dump_per_pass_hlo_proto_to.empty()) { DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - name().ToString(), pass->name().ToString()); + std::string(name()), std::string(pass->name())); } changed |= changed_this_pass; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index b0632448933..39b85de0f12 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -30,7 +31,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -273,9 +273,8 @@ ItemList GetUsers(const InstructionList& instruction_list, for (const BufferAlias& buffer_alias : points_to_analysis.GetBufferAliases(*logical_buffer)) { for (const HloInstruction* user : buffer_alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(buffer_alias.instruction(), - buffer_alias.index(), user, - points_to_analysis)) { + if (points_to_analysis.DoesNotUseOperandBuffer( + buffer_alias.instruction(), buffer_alias.index(), user)) { // The alias may be an operand of 'user', but the LogicalBuffer cannot // possibly be used by the instruction so ignore 'user'. This is the // case, for example, for the tuple element buffers in a GetTupleElement @@ -1216,7 +1215,7 @@ StatusOr HloRematerialization::Run( // Create initial sequence of HLO instructions. TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( *module, - [this](const LogicalBuffer& buffer) { + [this](const BufferValue& buffer) { return size_function_(buffer.shape()); }, scheduler_algorithm_)); diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 2e834a79d9f..2a601ec3d18 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -30,8 +29,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { /*static*/ StatusOr> @@ -109,33 +106,31 @@ StatusOr> HloRunner::Execute( const ExecutableRunOptions& run_options = service_run_options.run_options(); // Copy arguments to device. - std::vector> argument_buffers; - std::vector argument_buffer_ptrs; + std::vector argument_buffers; for (Literal* argument : arguments) { TF_ASSIGN_OR_RETURN( - std::unique_ptr argument_buffer, + ScopedShapedBuffer argument_buffer, backend().transfer_manager()->AllocateScopedShapedBuffer( argument->shape(), run_options.allocator(), run_options.device_ordinal())); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - stream.parent(), *argument, *argument_buffer)); + stream.parent(), *argument, argument_buffer)); argument_buffers.push_back(std::move(argument_buffer)); - argument_buffer_ptrs.push_back(argument_buffers.back().get()); + } + + std::vector argument_buffer_ptrs; + argument_buffer_ptrs.reserve(argument_buffers.size()); + for (const auto& buf : argument_buffers) { + argument_buffer_ptrs.push_back(&buf); } TF_ASSIGN_OR_RETURN( - std::unique_ptr result, + ScopedShapedBuffer result, executable->ExecuteOnStreamWrapper( &service_run_options, /*profile=*/nullptr, argument_buffer_ptrs)); - // Create a ScopedShapedBuffer of the result to manage deallocation. This will - // deallocate all the device memory when it goes out of scope. 
- TF_ASSIGN_OR_RETURN( - std::unique_ptr scoped_result, - ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator())); - auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice( - stream.parent(), *scoped_result); + stream.parent(), result); if (result_literal.ok()) { VLOG(4) << "Executed binary and got result: " << result_literal.ValueOrDie()->ToString(); @@ -157,7 +152,13 @@ StatusOr>> HloRunner::ExecuteReplicated( backend().computation_placer()->AssignDevices(options.num_replicas, 1)); std::vector> streams; std::vector service_run_options; - std::vector> argument_buffers; + + std::vector argument_buffers; + // This reserve() call is necessary for correctness, because + // argument_buffer_ptrs contains pointers into the elements of + // argument_buffers. + argument_buffers.reserve(options.num_replicas * options.arguments.size()); + // Plus one so we can safely get &argument_buffer_ptrs[0] in case there are // no arguments. std::vector argument_buffer_ptrs( @@ -169,7 +170,7 @@ StatusOr>> HloRunner::ExecuteReplicated( int64 device = device_assignment(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); - streams.push_back(absl::make_unique(executor)); + streams.push_back(MakeUnique(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( device, streams.back().get(), &device_assignment)); @@ -177,13 +178,13 @@ StatusOr>> HloRunner::ExecuteReplicated( // Copy arguments to device. for (const Literal* argument : options.arguments) { TF_ASSIGN_OR_RETURN( - std::unique_ptr argument_buffer, + ScopedShapedBuffer argument_buffer, backend().transfer_manager()->AllocateScopedShapedBuffer( argument->shape(), backend().memory_allocator(), device)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - executor, *argument, *argument_buffer)); + executor, *argument, argument_buffer)); argument_buffers.push_back(std::move(argument_buffer)); - argument_buffer_ptrs[index++] = argument_buffers.back().get(); + argument_buffer_ptrs[index++] = &argument_buffers.back(); } argument_buffer_slices.emplace_back( &argument_buffer_ptrs[index - options.arguments.size()], @@ -196,7 +197,7 @@ StatusOr>> HloRunner::ExecuteReplicated( num_threads += options.num_replicas; } if (num_threads > 0) { - pool = absl::make_unique( + pool = MakeUnique( tensorflow::Env::Default(), "infeed_outfeed", /*num_threads=*/num_threads); } @@ -227,7 +228,7 @@ StatusOr>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = absl::make_unique(); + auto literal = MakeUnique(); TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, options.outfeed_shape, literal.get())); if (options.outfeed_values != nullptr) { @@ -242,19 +243,16 @@ StatusOr>> HloRunner::ExecuteReplicated( } LOG(INFO) << "Replicated execution started"; - TF_ASSIGN_OR_RETURN(std::vector> results, + TF_ASSIGN_OR_RETURN(std::vector results, executable->ExecuteOnStreams(service_run_options, argument_buffer_slices)); LOG(INFO) << "Replicated execution terminated"; std::vector> exec_results; for (int64 i = 0; i < options.num_replicas; ++i) { - TF_ASSIGN_OR_RETURN(std::unique_ptr result, - ScopedShapedBuffer::MakeScoped( - results[i].get(), backend().memory_allocator())); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, backend().transfer_manager()->TransferLiteralFromDevice( - 
streams[i]->parent(), *result)); + streams[i]->parent(), results[i])); exec_results.push_back(std::move(literal)); } return std::move(exec_results); @@ -279,14 +277,14 @@ ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice( run_options.set_device_ordinal(device); run_options.set_stream(stream); run_options.set_allocator(backend().memory_allocator()); - run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); run_options.set_intra_op_thread_pool( backend().eigen_intra_op_thread_pool_device()); if (device_assignment != nullptr) { run_options.set_device_assignment(device_assignment); } - return ServiceExecutableRunOptions(run_options, backend().StreamBorrower(), - backend().inter_op_thread_pool()); + return ServiceExecutableRunOptions( + run_options, backend().StreamBorrower(), + /*xla_intra_op_thread_pool=*/backend().eigen_intra_op_thread_pool()); } Backend& HloRunner::backend() { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index f54fb44766e..53f7c6fe4a0 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -80,7 +80,7 @@ class HloRunner { bool run_hlo_passes = false; }; - explicit HloRunner(::perftools::gputools::Platform* platform); + explicit HloRunner(se::Platform* platform); ~HloRunner(); @@ -149,8 +149,7 @@ class HloRunner { // will be used to configure the replication parameters. Replicated executions // should pass the device_assignment parameter. ServiceExecutableRunOptions GetServiceRunOptionsForDevice( - int64 device, ::perftools::gputools::Stream* stream, - DeviceAssignment* device_assignment); + int64 device, se::Stream* stream, DeviceAssignment* device_assignment); std::unique_ptr backend_; }; diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 1a767628f6e..51c29d47a19 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -62,7 +62,34 @@ StatusOr MinimumMemoryForSequence( namespace { // Class implementing a list scheduler of HLO instructions which produces a -// sequence which minimizes memory usage. +// sequence which minimizes memory usage by preferring to schedule the node that +// frees bigger buffer and defines smaller outputs. +// +// Note that list scheduler is a greedy algorithm which cannot guarantee a +// global optimal solution. As a counterexample, considering the following +// graph: +// +// +--> B ===> C -------+ +// A -> | | +// | v +// +--> D ---> F=======>G +// | ^ +// | | +// +--> E -----+ +// +// --> : Buffer with size 1 +// ==> : Buffer with size 2 +// +// The list scheduler will always try to defer scheduling B in a greedy way +// since its output buffer is bigger than input. The sequence it creates will +// be: +// A D E F B C G +// , which has a maximum memory usage of 6 (B is alive while F is executing). +// +// An optimal way to shedule the previous graph is: +// A B C D E F G +// , which has a maximum memory usage of 5 (when F is executing). 
+// class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions @@ -70,8 +97,11 @@ class ListScheduler { static StatusOr> Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - ListScheduler scheduler(computation, points_to_analysis, size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + ListScheduler scheduler(computation, points_to_analysis, size_function, + memory_by_computation); return scheduler.CreateSchedule(); } @@ -92,10 +122,13 @@ class ListScheduler { ListScheduler(const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) : computation_(computation), points_to_analysis_(points_to_analysis), - size_function_(size_function) { + size_function_(size_function), + memory_by_computation_(memory_by_computation) { // Create a map containing the LogicalBuffer uses for each HLO // instruction. An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by @@ -185,6 +218,12 @@ class ListScheduler { } // Returns the number of bytes freed if the HLO instruction is scheduled. + // If the instruction calls subcomputations, we count the memory used by the + // subcomputations as memory "defined" by the instruction. This is not + // entirely accurate, because subcomputation memory will be freed after the + // instruction finishes. But it is more accurate than not taking + // subcomputations into account at all. In the future, we may improve + // accounting for subcomputation memory (b/65409243). int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { @@ -194,7 +233,19 @@ class ListScheduler { freed_bytes += size_function_(*buffer); } } - return freed_bytes - entry.bytes_defined; + // We only count the memory usage of the largest subcomputation, instead of + // adding them all, because subcomputations won't execute in parallel. + int64 max_subcomputation_bytes = 0; + for (const auto* c : entry.instruction->called_computations()) { + auto it = memory_by_computation_.find(c); + if (it != memory_by_computation_.end()) { + int64 subcomputation_bytes = it->second; + if (subcomputation_bytes > max_subcomputation_bytes) { + max_subcomputation_bytes = subcomputation_bytes; + } + } + } + return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; } // Constructs the scheduling priority of the given instruction. @@ -315,6 +366,11 @@ class ListScheduler { const HloComputation& computation_; const TuplePointsToAnalysis& points_to_analysis_; const LogicalBuffer::SizeFunction& size_function_; + // Computations are analyzed in post-order. When scheduling an instruction + // that includes subcomputations, such as a while loop, we use this map to + // look up the memory needed by subcomputations. + const tensorflow::gtl::FlatMap& + memory_by_computation_; // A map containing the LogicalBuffers that each instruction uses. 
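[Editor's sketch] The accounting described in the comments above can be read as a single formula: an instruction's priority is the bytes its scheduling frees, minus the bytes it defines, minus the memory of the largest subcomputation it calls (a max rather than a sum, since subcomputations do not run in parallel). A small standalone sketch of that arithmetic (plain C++ with toy names; the 17- and 16-byte figures are only illustrative, echoing the scheduling test later in this patch):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy version of the list scheduler's priority: bytes freed by scheduling an
// instruction, minus bytes it defines, minus the largest memory footprint of
// any subcomputation it calls (max, not sum, since they don't run in parallel).
int64_t BytesFreedIfScheduled(
    int64_t freed_bytes, int64_t bytes_defined,
    const std::vector<std::string>& called_computations,
    const std::map<std::string, int64_t>& memory_by_computation) {
  int64_t max_subcomputation_bytes = 0;
  for (const auto& name : called_computations) {
    auto it = memory_by_computation.find(name);
    if (it != memory_by_computation.end()) {
      max_subcomputation_bytes = std::max(max_subcomputation_bytes, it->second);
    }
  }
  return freed_bytes - bytes_defined - max_subcomputation_bytes;
}

int main() {
  std::map<std::string, int64_t> memory_by_computation = {
      {"WhileCond", 17}, {"WhileBody", 16}};
  // A while instruction that frees 16 bytes, defines 16 bytes and calls both
  // subcomputations: only the larger one (17) is charged against it.
  std::cout << BytesFreedIfScheduled(16, 16, {"WhileCond", "WhileBody"},
                                     memory_by_computation)
            << "\n";  // prints -17
  return 0;
}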
tensorflow::gtl::FlatMap> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm) { + return algorithm(computation, points_to_analysis, size_function, + memory_by_computation); + } + return DefaultMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation); +} + +} // namespace + StatusOr MinimumMemoryForComputation( const HloComputation& computation, const std::vector& sequence, @@ -352,24 +426,12 @@ StatusOr MinimumMemoryForComputation( return result.heap_size; } -StatusOr> CreateMemoryMinimizingSequence( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { - VLOG(2) << "Computation: " << computation.name(); - if (algorithm) { - return algorithm(computation, points_to_analysis, size_function); - } - return DefaultMemoryScheduler(computation, points_to_analysis, size_function); -} - -} // namespace - StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { // This ordering is based on DFS post-order, with a heuristic to decide which // operand to visit first. The heuristic is based on 'extra_users', which is // simply users-1 for each instruction. By subtracting 1, we're saying that @@ -421,19 +483,35 @@ StatusOr> DFSMemoryScheduler( })); CHECK_EQ(sequence.size(), computation.instruction_count()); return sequence; -} +} // namespace xla StatusOr> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - return ListScheduler::Run(computation, points_to_analysis, size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + return ListScheduler::Run(computation, points_to_analysis, size_function, + memory_by_computation); +} + +StatusOr> PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { + const auto& post_order = computation.MakeInstructionPostOrder(); + return std::vector{post_order.begin(), + post_order.end()}; } StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation) { // We try both a list-scheduler based ordering and a DFS based ordering, and // choose whichever returns a lower min-memory, not accounting for // fragmentation. @@ -443,30 +521,48 @@ StatusOr> DefaultMemoryScheduler( // within the caller's context. But it's good enough for now. 
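[Editor's sketch] DefaultMemoryScheduler, as changed here, simulates three candidate orderings (list, DFS, and naive post-order) and keeps whichever sequence has the lowest simulated minimum memory, preferring the list result on ties and then the DFS result. A compact sketch of just that selection step (plain C++; the candidate names, sequences, and byte counts are made up for illustration, and the real code threads points-to analysis and size functions through each scheduler):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct Candidate {
  std::string name;                   // "list", "dfs", "post-order"
  std::vector<std::string> sequence;  // toy instruction sequence
  int64_t min_memory;                 // simulated peak memory for the sequence
};

// Keep the candidate whose simulated peak memory is smallest, preferring the
// earlier entry on ties (list first, then dfs), matching the patch's checks.
const Candidate& ChooseSchedule(const std::vector<Candidate>& candidates) {
  return *std::min_element(candidates.begin(), candidates.end(),
                           [](const Candidate& a, const Candidate& b) {
                             return a.min_memory < b.min_memory;
                           });
}

int main() {
  std::vector<Candidate> candidates = {
      {"list", {"A", "D", "E", "F", "B", "C", "G"}, 6},
      {"dfs", {"A", "B", "C", "D", "E", "F", "G"}, 5},
      {"post-order", {"A", "B", "D", "E", "C", "F", "G"}, 7},
  };
  const Candidate& best = ChooseSchedule(candidates);
  std::cout << "chose " << best.name << " at " << best.min_memory << " bytes\n";
  return 0;
}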
TF_ASSIGN_OR_RETURN( std::vector list_sequence, - ListMemoryScheduler(computation, points_to_analysis, size_function)); + ListMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); TF_ASSIGN_OR_RETURN( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, points_to_analysis, size_function)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN( - std::vector dfs_sequence, - DFSMemoryScheduler(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, + DFSMemoryScheduler(computation, points_to_analysis, + size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, size_function)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); - if (list_memory <= dfs_memory) { + TF_ASSIGN_OR_RETURN( + std::vector post_order_sequence, + PostOrderMemoryScheduler(computation, points_to_analysis, size_function, + memory_by_computation)); + TF_ASSIGN_OR_RETURN( + const int64 post_order_memory, + MinimumMemoryForComputation(computation, post_order_sequence, + points_to_analysis, size_function)); + VLOG(2) << "Min-memory post order sequence: " + << HumanReadableNumBytes(post_order_memory); + + auto min_memory = std::min({dfs_memory, post_order_memory, list_memory}); + + if (min_memory == list_memory) { VLOG(2) << "Chose min-memory list sequence: " << HumanReadableNumBytes(list_memory); return list_sequence; - } else { + } else if (min_memory == dfs_memory) { VLOG(2) << "Chose min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); return dfs_sequence; + } else { + VLOG(2) << "Chose min-memory post_order sequence: " + << HumanReadableNumBytes(post_order_memory); + return post_order_sequence; } } @@ -477,24 +573,32 @@ CreateMemoryMinimizingSequence(const HloModule& module, SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); - for (const auto* computation : module.MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN( - sequence[computation], - CreateMemoryMinimizingSequence(*computation, *points_to_analysis, - size_function, algorithm)); + tensorflow::gtl::FlatMap memory_by_computation; + for (const auto* computation : module.MakeComputationPostOrder()) { + if (!computation->IsFusionComputation()) { + TF_ASSIGN_OR_RETURN(auto one_computation_sequence, + CreateMemoryMinimizingSequence( + *computation, *points_to_analysis, size_function, + algorithm, memory_by_computation)); + memory_by_computation[computation] = + MinimumMemoryForComputation(*computation, one_computation_sequence, + *points_to_analysis, size_function) + .ValueOrDie(); + sequence[computation] = std::move(one_computation_sequence); + } } return sequence; } StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { + const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); + tensorflow::gtl::FlatMap empty_map; return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function, algorithm); + size_function, nullptr, empty_map); } } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 068e68383de..49b927eefd2 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -34,26 +34,47 @@ StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); +// Returns the minimum memory required to compute the given computation, +// assuming no fragmentation. +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + // A memory scheduler computes an execution sequence for the HLO instructions in // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. typedef std::function>( const HloComputation&, const TuplePointsToAnalysis&, - const LogicalBuffer::SizeFunction&)> + const LogicalBuffer::SizeFunction&, + const tensorflow::gtl::FlatMap&)> MemorySchedulerAlgorithm; // List scheduler StatusOr> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // DFS-order scheduler StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); + +// Naive Post Order scheduler +StatusOr> PostOrderMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, @@ -61,7 +82,9 @@ StatusOr> DFSMemoryScheduler( StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + const tensorflow::gtl::FlatMap& + memory_by_computation); // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes @@ -72,10 +95,10 @@ CreateMemoryMinimizingSequence(const HloModule& module, const MemorySchedulerAlgorithm& algorithm = {}); // Overload of above that computes the sequence for a single computation. +// Currently only used by the GPU backend. 
StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}); + const LogicalBuffer::SizeFunction& size_function); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 74544c4a67a..c018ba2ffc4 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -77,7 +77,7 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - auto size_fn = [](const LogicalBuffer& buffer) { + auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; @@ -124,7 +124,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, [](const LogicalBuffer& buffer) { + CreateMemoryMinimizingSequence(*module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. @@ -160,7 +160,7 @@ ENTRY root { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, tools::Parse(module_str)); - auto size_fn = [](const LogicalBuffer& buffer) { + auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; TF_ASSERT_OK_AND_ASSIGN( @@ -190,5 +190,104 @@ ENTRY root { instructions_by_name.at("e"))); } +TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { + // %WhileCond (cond_param: f32[4]) -> pred[] { + // %cond_param = f32[4]{0} parameter(0) + // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } }) + // ROOT %not-equal-to = pred[] not-equal-to( + // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant) + // } + // %WhileBody (body_param: f32[4]) -> f32[4] { + // %body_param = f32[4]{0} parameter(0) + // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + // ROOT %subtract = f32[4]{0} subtract( + // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) + // } + // %SubcomputationsNotAccounted () -> f32[2,4] { + // %constant.3 = f32[2,4]{1,0} constant( + // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) + // %transpose = f32[2,4]{1,0} transpose( + // f32[2,4]{1,0} %constant.3), dimensions={0,1} + // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) + // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2), + // condition=%WhileCond, + // body=%WhileBody + // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0} + // ROOT %add = f32[2,4]{1,0} add( + // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) + // } + + auto module = CreateNewModule(); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // param != 0 + // Needs 17 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* zero_vector = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{0, 0, 0, 0}}))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // 
Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + // transpose(matrix) + bcast(while) + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({{1, 1, 1, 1}}))); + // Creates 16 bytes, ignoring subcomputations + HloInstruction* while_loop = + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + // Creates 32 bytes and frees 16 + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); + + HloInstruction* matrix = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2( + {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); + // Creates 32 bytes + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); + + // Creates 32 bytes and frees 64 + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); + + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence( + *module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }, + ListMemoryScheduler)); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // This schedule is an example of List's greedy heuristics being suboptimal. + // The while_loop is more expensive than transpose, so it would have been + // better to schedule it first, instead of during the busy time. + EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); + EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); + EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 1b42349b0b3..7f7e3f7dab0 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -256,37 +256,24 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, ", input_shape=", ShapeUtil::HumanString(shape)); } - // The tile shape must not be the same as the input shape without maximal_ - // also set. If this is the case, we're not actually sharded and the correct - // constructor should have been used. - if (ShapeUtil::Equal(shape, tile_shape_)) { + // The correct constructor have to be used to create tile maximal shardings. + if (tile_assignment_.num_elements() == 1) { return tensorflow::errors::InvalidArgument( - "Tile shape is the same as the input shape. If a replicated sharding " - "was intended, use HloSharding::Replicated(). 
If a device placement " - "was intended, use HloSharding::AssignDevice()"); + "Tile assignment only contains a single device. If a replicated " + "sharding was intended, use HloSharding::Replicated(). If a device " + "placement was intended, use HloSharding::AssignDevice()"); } - // The tile shape must not be greater than the input shape in any dimension. - for (int64 i = 0, e = ShapeUtil::Rank(shape); i != e; ++i) { - auto tile_dim = tile_shape_.dimensions(i); - auto shape_dim = shape.dimensions(i); - if (tile_dim > shape_dim) { - return tensorflow::errors::InvalidArgument( - StrCat("Tile is larger than input shape (dimension ", i, ", ", - tile_dim, " > ", shape_dim)); - } - } - - // The tile assignment tensor must be exactly dimensioned to ceil(shape[dim] - // tile[dim]) for every dimension contained within tile. + // The tile assignment tensor must contain enough element to cover the full + // shape with tiles of the specified size. for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) { - int64 expected_dim = - CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i)); - if (tile_assignment_.dimensions()[i] != expected_dim) { + int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i); + if (shape.dimensions(i) > total_tile_size) { return tensorflow::errors::InvalidArgument( - StrCat("Tile assignment tensor has incorrect shape. Dimension ", i, - " expected ", expected_dim, " but got ", - tile_assignment_.dimensions()[i])); + StrCat("Tile assignment tensor has too few element to cover the full " + "shape. Dimension ", + i, ", shape ", shape.dimensions(i), ", total size ", + total_tile_size)); } } @@ -380,10 +367,14 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, const ShapeIndex& index) const { CHECK(IsTuple()); - ShapeTree sub_shape_tree(ShapeUtil::GetSubshape(shape, index), - Replicate()); + Shape sub_shape = ShapeUtil::GetSubshape(shape, index); + ShapeTree sub_shape_tree(sub_shape, Replicate()); sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); - return Tuple(sub_shape_tree); + if (ShapeUtil::IsTuple(sub_shape)) { + return Tuple(sub_shape_tree); + } else { + return sub_shape_tree.element({}); + } } std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 69ea4233e45..3bf0d25efb7 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -88,7 +88,7 @@ TEST_F(HloShardingTest, Tile) { } { - // Test should pass. + // Test should fail because of more devices used then `num_device`. Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); HloSharding sharding = HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); @@ -97,17 +97,8 @@ TEST_F(HloShardingTest, Tile) { } { - // Test should fail due to the tile being larger than the input space. - Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); - HloSharding sharding = - HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); - EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {2, 2}), - /*num_devices=*/4)); - } - - { - // Test should fail due to the tile not dividing the input space into 4 - // sections (even with padding). + // Test should fail because the total tiled size in dimension 0 is 4 but we + // have 6 elements along that dimensions. 
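+ // (With tile_shape {2, 3} and the 2x2 tile assignment below, the new
+ // total_tile_size check in HloSharding::ValidateNonTuple covers only
+ // 2 * 2 = 4 elements in dimension 0, so validating a shape with 6 elements
+ // along that dimension is expected to fail.)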
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); HloSharding sharding = HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index f8d98f06785..be156d765dc 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 05b7dce3d1e..7b27dbfec37 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -29,9 +29,11 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -69,7 +71,7 @@ std::ostream& operator<<(std::ostream& out, const HloUse& use) { HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi) - : id_(id), is_phi_(is_phi) { + : BufferValue(instruction, index, id), is_phi_(is_phi) { // The defining position is always the first element in the positions_ vector. positions_.push_back(HloPosition{instruction, index}); } @@ -90,8 +92,8 @@ string HloValue::ToShortString() const { string index_str = ShapeUtil::IsTuple(defining_instruction()->shape()) ? defining_index().ToString() : ""; - return StrCat(id_, " ", is_phi_ ? "PHI " : "", defining_instruction()->name(), - index_str); + return StrCat(id(), " ", is_phi_ ? "PHI " : "", + defining_instruction()->name(), index_str); } string HloValue::ToString(int indent) const { diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index 2a711e8b425..a1151f65e07 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -16,16 +16,20 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VALUE_H_ -#include +#include #include #include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -80,30 +84,9 @@ struct HloUse { std::ostream& operator<<(std::ostream& out, const HloUse& use); -// Class describing a value used by the dataflow analysis. XLA arrays are -// trivially a single HloValue. 
Tuples are made up of more than one HloValue: an -// HloValue for the pointer vector, and an HloValue for each child element. -// -// Every HloValue is defined by a particular instruction and most instructions -// define only a single HloValue. Instructions which define a single HloValue -// include array-shaped instructions such as Add but also includes Tuple-shaped -// instructions such as Tuple. The Tuple instruction defines a single HloValue -// which is a vector of pointers to the values containing the Tuple -// instruction's operands. Though the result of the Tuple instruction includes -// multiple values only the top-level HloValue (the vector of pointers) is -// defined by the Tuple instruction. The values containing the tuple elements -// are defined by earlier instructions, usually the operands of the Tuple -// instruction. -// -// Instructions which construct both the tuple *and* the tuple elements define -// more than one HloValue. This includes (at least) tuple-shaped Constant, -// Parameter, Infeed and While instructions. These tuple-shaped instructions do -// not assemble a tuple from existing HloValues like the Tuple instruction does, -// but rather define all the HloValues in the tuple. -class HloValue { +// HloDataflowAnalysis uses this subclass of BufferValue. +class HloValue : public BufferValue { public: - using Id = int64; - // Predicate comparing HloValues by increasing id, useful for std::sort. static bool IdLessThan(const HloValue* a, const HloValue* b) { return a->id() < b->id(); @@ -120,6 +103,7 @@ class HloValue { // dataflow analysis (HloDataflowAnalysis::ssa_form_ is true). HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); + ~HloValue() override {} // Sets the positions in the module at which the HloValue appears. Updates // uses. Should be called once and only once. The defining position should not @@ -127,10 +111,6 @@ class HloValue { void SetPositionsAndComputeUses( tensorflow::gtl::ArraySlice positions); - // Return a unique identifier for this HloValue. This value is used for stable - // sorting and iteration - Id id() const { return id_; } - // Returns whether this value is a phi value. bool is_phi() const { return is_phi_; } @@ -142,12 +122,18 @@ class HloValue { return defining_position().instruction; } + HloInstruction* instruction() const override { + return defining_instruction(); + } + // Return the shape index at which this HloValue is defined in the output of // its defining instruction. const ShapeIndex& defining_index() const { return defining_position().index; } + const ShapeIndex& index() const override { return defining_index(); } + // Return the shape of this HloValue. - const Shape& shape() const { return defining_position().shape(); } + const Shape& shape() const override { return defining_position().shape(); } // Return all positions of the HloValue in the module. const std::vector& positions() const { return positions_; } @@ -164,12 +150,11 @@ class HloValue { // Return a single-line string representation of the value. string ToShortString() const; - string ToString(int indent = 0) const; + string ToString(int indent) const; + + string ToString() const override { return ToString(0); } private: - // Unique identifier for this HloValue. Used for stable sorting and iteration. - const Id id_; - // Whether this instruction is a phi value. 
const bool is_phi_; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 63ec5964eb9..7d6d0d9eaf7 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -105,9 +106,7 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -Status ShapeVerifier::HandleInfeed(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleInfeed(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // Outfeed has a separate shape field for the value which is outfed to the @@ -115,7 +114,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { // produces no HLO value in the graph. if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { - return InvalidArgument( + return InternalError( "Expected outfeed to have shape compatible with operand's shape %s, " "actual shape is %s:\n%s", ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), @@ -126,12 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { } Status ShapeVerifier::HandleHostCompute(HloInstruction*) { - return tensorflow::Status::OK(); + return Status::OK(); } -Status ShapeVerifier::HandleRng(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleRng(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { return CheckShape( @@ -163,7 +160,7 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { @@ -174,32 +171,15 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape())); TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == broadcast->dimensions().size()); - for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { - int64 output_dimension = broadcast->dimensions()[i]; + for (int64 operand_dimension = 0; + operand_dimension < ShapeUtil::Rank(operand_shape); + ++operand_dimension) { + int64 output_dimension = broadcast->dimensions()[operand_dimension]; TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == - operand_shape.dimensions(i)) + operand_shape.dimensions(operand_dimension)) << broadcast->ToString() << " operand shape " << operand_shape; } - return tensorflow::Status::OK(); -} - -Status ShapeVerifier::HandleBroadcastDimOne(HloInstruction* broadcastDimOne) { - const Shape& operand_shape = broadcastDimOne->operand(0)->shape(); - int64 operand_rank = ShapeUtil::Rank(operand_shape); - const Shape& output_shape = broadcastDimOne->shape(); - // Check for mixed precision. 
- TF_RETURN_IF_ERROR(CheckShape(broadcastDimOne, output_shape)); - TF_RET_CHECK(operand_rank == ShapeUtil::Rank(output_shape)); - for (int64 i = 0; i < operand_rank; ++i) { - int64 operand_dimension = operand_shape.dimensions(i); - int64 output_dimension = output_shape.dimensions(i); - TF_RET_CHECK(operand_dimension == 1 || - operand_dimension == output_dimension) - << "Dimension " << i << " of broadcastDimOne " - << broadcastDimOne->ToString() << " is " << operand_dimension - << ", expected 1 or " << output_dimension; - } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { @@ -207,7 +187,7 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == ShapeUtil::ElementsIn(reshape->operand(0)->shape())); - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { @@ -216,22 +196,18 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { transpose->operand(0)->shape(), transpose->dimensions())); } -Status ShapeVerifier::HandleParameter(HloInstruction*) { - return tensorflow::Status::OK(); +Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { + return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleCall(HloInstruction* call) { // The shape of kCall should match the shape of the computation it calls. return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); } -Status ShapeVerifier::HandleCustomCall(HloInstruction*) { - return tensorflow::Status::OK(); -} +Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } Status ShapeVerifier::HandleSlice(HloInstruction* slice) { return CheckShape(slice, @@ -426,7 +402,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { if (fp_type == PRIMITIVE_TYPE_INVALID) { fp_type = subshape.element_type(); } else if (fp_type != subshape.element_type()) { - return FailedPrecondition( + return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", instruction->ToString().c_str()); @@ -506,14 +482,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } } if (!compatible) { - return InvalidArgument( + return InternalError( "Expected instruction to have shape compatible with %s, actual " "shape is %s:\n%s", ShapeUtil::HumanString(inferred_shape).c_str(), ShapeUtil::HumanString(instruction->shape()).c_str(), instruction->ToString().c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, @@ -557,13 +533,13 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2) { if (instr1->channel_id() != instr2->channel_id()) { - return FailedPrecondition( + return InternalError( "Expected to have the same channel id, actual channel ids are: %s " "(%lld), %s (%lld)", instr1->ToString().c_str(), instr1->channel_id(), instr2->ToString().c_str(), instr2->channel_id()); } - return tensorflow::Status::OK(); + return Status::OK(); } string ComputationsToString( @@ -587,22 +563,22 @@ string 
ComputationsToString( Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { - return FailedPrecondition("Computation %s has a null parent pointer", - computation->name().c_str()); + return InternalError("Computation %s has a null parent pointer", + computation->name().c_str()); } if (computation->parent() != module) { - return FailedPrecondition( + return InternalError( "Computation %s parent() does not point to parent module", computation->name().c_str()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { - return FailedPrecondition("Instruction %s has a null parent pointer", - instruction->name().c_str()); + return InternalError("Instruction %s has a null parent pointer", + instruction->name().c_str()); } if (instruction->parent() != computation) { - return FailedPrecondition( + return InternalError( "Instruction %s parent() does not point to parent computation", instruction->name().c_str()); } @@ -618,7 +594,7 @@ Status VerifyHloStructure(HloModule* module) { for (int i = 0; i < instruction->operand_count(); ++i) { const HloInstruction* operand = instruction->operand(i); if (operand->parent() != instruction->parent()) { - return FailedPrecondition( + return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", i, operand->name().c_str(), instruction->name().c_str(), @@ -628,14 +604,14 @@ Status VerifyHloStructure(HloModule* module) { } } } - return tensorflow::Status::OK(); + return Status::OK(); } Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // The parent fusion instruction of the fusion computation must be 'fusion'. HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { - return FailedPrecondition( + return InternalError( "Instruction of fused computation does not match expected instruction " "%s.", fusion->ToString().c_str()); @@ -651,37 +627,37 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto* instruction : fused_computation->instructions()) { if (fused_root == instruction) { if (root_owned) { - return FailedPrecondition("Root appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Root appears more than once in %s.", + fusion->ToString().c_str()); } root_owned = true; } for (int i = 0; i < fused_parameters.size(); ++i) { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { - return FailedPrecondition("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + return InternalError("Parameter appears more than once in %s.", + fusion->ToString().c_str()); } parameter_owned[i] = true; } } } if (!root_owned) { - return FailedPrecondition("Root not found in computation of %s.", - fusion->ToString().c_str()); + return InternalError("Root not found in computation of %s.", + fusion->ToString().c_str()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { - return FailedPrecondition("Parameter %d not found in computation of %s.", - i, fusion->ToString().c_str()); + return InternalError("Parameter %d not found in computation of %s.", i, + fusion->ToString().c_str()); } } // Fused root must have no users. 
if (fused_root->user_count() != 0) { - return FailedPrecondition("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", + fusion->ToString().c_str()); } // All uses of fused instructions must be in the fusion computation, and every @@ -690,13 +666,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->fused_instructions_computation()->instructions()) { if (instruction != fused_root) { if (instruction->user_count() == 0) { - return FailedPrecondition( - "Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + return InternalError("Non-root instruction %s in %s must have users.", + instruction->ToString().c_str(), + fusion->ToString().c_str()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { - return FailedPrecondition( + return InternalError( "Non-root instruction %s in %s may not have external users.", instruction->ToString().c_str(), fusion->ToString().c_str()); } @@ -711,41 +687,107 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return FailedPrecondition( - "Unexpected negative parameter number %lld in %s.", param_no, - fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %lld in %s.", + param_no, fusion->ToString().c_str()); } if (param_no >= fused_parameters.size()) { - return FailedPrecondition( + return InternalError( "Unexpected parameter number %lld in %s: higher then number of " "parameters %lu.", param_no, fusion->ToString().c_str(), fused_parameters.size()); } if (parameter_numbers[param_no]) { - return FailedPrecondition( + return InternalError( "Did not expect parameter number %lld more than once in %s.", param_no, fusion->ToString().c_str()); } parameter_numbers[param_no] = true; if (!ShapeUtil::Compatible(fused_param->shape(), fusion->operand(param_no)->shape())) { - return FailedPrecondition( + return InternalError( "Shape mismatch between parameter number %lld and its operand in %s.", param_no, fusion->ToString().c_str()); } } - // Make sure all the parameter_numbers entries were seen + // Make sure all the parameter_numbers entries were seen. for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { - return FailedPrecondition("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + return InternalError("Did not see parameter number %d in %s.", i, + fusion->ToString().c_str()); } } // TODO(b/65423525): We'd like to check that all operands are distinct. // This is currently disabled due to the invariant being violated by // multi-output fusion. 
- return tensorflow::Status::OK(); + return Status::OK(); +} + +Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { + auto* while_cond = instruction->while_condition(); + auto* while_body = instruction->while_body(); + if (while_cond->num_parameters() != 1) { + return FailedPrecondition( + "While condition must have exactly 1 parameter; had %lld : %s", + while_cond->num_parameters(), while_cond->ToString().c_str()); + } + if (while_body->num_parameters() != 1) { + return FailedPrecondition( + "While body must have exactly 1 parameter; had %lld : %s", + while_body->num_parameters(), while_body->ToString().c_str()); + } + if (instruction->operand_count() != 1) { + return FailedPrecondition( + "While loop must have exactly one operand; had %lld : %s", + instruction->operand_count(), instruction->ToString().c_str()); + } + auto* init = instruction->operand(0); + auto* cond_param = while_cond->parameter_instruction(0); + if (!ShapeUtil::Compatible(init->shape(), cond_param->shape())) { + return FailedPrecondition( + "While condition's parameter must have the same shape as the " + "loop's 'init'. init: %s, param: %s", + init->ToString().c_str(), cond_param->ToString().c_str()); + } + auto* cond_root = while_cond->root_instruction(); + if (!ShapeUtil::Compatible(cond_root->shape(), + ShapeUtil::MakeShape(PRED, {}))) { + return FailedPrecondition("While condition should have shape PRED: %s", + cond_root->ToString().c_str()); + } + auto* body_param = while_body->parameter_instruction(0); + if (!ShapeUtil::Compatible(init->shape(), body_param->shape())) { + return FailedPrecondition( + "While body's parameter must have the same shape as the loop's" + " 'init'. init: %s, param: %s", + init->ToString().c_str(), body_param->ToString().c_str()); + } + auto* body_root = while_body->root_instruction(); + if (!ShapeUtil::Compatible(init->shape(), body_root->shape())) { + return FailedPrecondition( + "While body should have same shape as the loop's 'init'." + "init: %s, body: %s", + init->ToString().c_str(), body_root->ToString().c_str()); + } + return Status::OK(); +} + +Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { + const Shape& out_shape = instruction->shape(); + for (HloInstruction* operand : instruction->operands()) { + const Shape& operand_shape = operand->shape(); + if (!ShapeUtil::IsScalar(operand_shape) && + !ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { + return FailedPrecondition( + "Implicit broadcast is not allowed in HLO." 
+ "Found non-compatible shapes for instruction %s.\n" + "output: %s\noperand: %s\n", + HloOpcodeString(instruction->opcode()).c_str(), + ShapeUtil::HumanString(out_shape).c_str(), + ShapeUtil::HumanString(operand_shape).c_str()); + } + } + return Status::OK(); } StatusOr HloVerifier::Run(HloModule* module) { @@ -788,39 +830,9 @@ StatusOr HloVerifier::Run(HloModule* module) { << instruction->dimensions().size() << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); } else if (instruction->opcode() == HloOpcode::kWhile) { - auto* while_cond = instruction->while_condition(); - auto* while_body = instruction->while_body(); - TF_RET_CHECK(while_cond->num_parameters() == 1) - << "While condition must have exactly 1 parameter; had " - << while_cond->num_parameters() << ": " << while_cond->ToString(); - TF_RET_CHECK(while_body->num_parameters() == 1) - << "While body must have exactly 1 parameter; had " - << while_body->num_parameters() << ": " << while_body->ToString(); - TF_RET_CHECK(instruction->operand_count() == 1) - << "While loop must have exactly one operand; had " - << instruction->operand_count() << ": " << instruction->ToString(); - - auto* init = instruction->operand(0); - auto* cond_param = while_cond->parameter_instruction(0); - TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), cond_param->shape())) - << "While condition's parameter must have the same shape as the " - "loop's 'init'. init: " - << init->ToString() << ", param: " << cond_param->ToString(); - auto* cond_root = while_cond->root_instruction(); - TF_RET_CHECK(ShapeUtil::Compatible(cond_root->shape(), - ShapeUtil::MakeShape(PRED, {}))) - << "While condition should have shape PRED: " - << cond_root->ToString(); - - auto* body_param = while_body->parameter_instruction(0); - TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_param->shape())) - << "While body's parameter must have the same shape as the loop's " - "'init'. init: " - << init->ToString() << ", param: " << body_param->ToString(); - auto* body_root = while_body->root_instruction(); - TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_root->shape())) - << "While body should have same shape as the loop's 'init'. 
init: " - << init->ToString() << ", body: " << body_root->ToString(); + TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); + } else if (instruction->IsElementwise()) { + TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction)); } auto previous = instructions.find(instruction->name()); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index a4dff977ba2..1392a78097a 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -54,7 +54,6 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleReduce(HloInstruction* reduce) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandleBroadcastDimOne(HloInstruction* broadcastDimOne) override; Status HandleReshape(HloInstruction* reshape) override; Status HandleTranspose(HloInstruction* transpose) override; Status HandleParameter(HloInstruction*) override; @@ -83,9 +82,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; - Status FinishVisit(HloInstruction*) override { - return tensorflow::Status::OK(); - } + Status FinishVisit(HloInstruction*) override { return Status::OK(); } protected: // Check the instruction's shape against the shape given by ShapeInference @@ -103,7 +100,7 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckTernaryShape(const HloInstruction* instruction); Status CheckVariadicShape(const HloInstruction* instruction); - // Checks if the given two instructions shares the same channel id. + // Checks if the given two instructions share the same channel id. Status CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2); @@ -145,9 +142,15 @@ class HloVerifier : public HloPassInterface { // CHECKs various invariants of a fusion instruction. Status CheckFusionInstruction(HloInstruction* fusion) const; + Status CheckWhileInstruction(HloInstruction* instruction); + + // Checks that the non-scalar operand shapes are compatible to the output + // shape, i.e., that there are no implicit broadcasts of size-one dimensions. + Status CheckElementwiseInstruction(HloInstruction* instruction); + // Creates a ShapeVerifier that checks that shapes match inferred - // expectations. This is a factory function because ShapeVerifier, Note that - // ShapeVerifier, being a DfsHloVisitor, is stateful. We want a clean object + // expectations. This is a factory function because ShapeVerifier, + // being a DfsHloVisitor, is stateful. We want a clean object // for each run of the verifier. 
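+ // For instance, the factory could be as simple as
+ //   [] { return MakeUnique<ShapeVerifier>(); }
+ // (an illustrative sketch only -- it assumes ShapeVerifierFactory is a
+ // std::function returning a unique_ptr, which is not shown in this hunk).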
ShapeVerifierFactory shape_verifier_factory_; }; diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index 13e4557317f..dc3bfce0c49 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -27,6 +27,7 @@ using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; using tensorflow::strings::Printf; using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; string HumanReadableProfileBuilder::ToString() const { string s; @@ -35,20 +36,26 @@ string HumanReadableProfileBuilder::ToString() const { computation_name_.c_str(), HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); - auto append_op = [&](const OpInfo& op) { + auto print_op = [&](const OpInfo& op) { + // Skip ops with 0 optimal seconds and 0 actual cycles. These are ops that + // were expected to be free and are actually free -- things like (on most + // backends) kParameter or kConstant HLOs. There's no need to clutter the + // profile with these. + if (op.optimal_seconds == 0 && op.cycles == 0) { + return; + } + string bytes_per_sec; string bytes_per_cycle; - if (op.cycles <= 0 || op.bytes_accessed < 0) { - bytes_per_sec = ""; - bytes_per_cycle = ""; - } else { - bytes_per_sec = - HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)); + if (op.cycles > 0 && op.bytes_accessed >= 0) { + bytes_per_sec = StrCat( + HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)), + "/s"); + double bpc = static_cast(op.bytes_accessed) / op.cycles; if (op.bytes_accessed > op.cycles) { - bytes_per_cycle = HumanReadableNumBytes(op.bytes_accessed / op.cycles); + bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle"); } else { - bytes_per_cycle = - Printf("%.3fB", static_cast(op.bytes_accessed) / op.cycles); + bytes_per_cycle = Printf("%.3fB/cycle", bpc); } } @@ -59,14 +66,16 @@ string HumanReadableProfileBuilder::ToString() const { double nsecs = op.cycles / clock_rate_ghz_; Appendf(&s, - "%15lld cycles (%6.2f%%) :: %12.1f usec (%12.1f optimal) :: %18s " - ":: %18s :: %12s/s :: %12s/cycle :: %s\n", + "%15lld cycles (%6.2f%%) :: %12.1f usec %22s :: %18s " + ":: %18s :: %14s :: %16s :: %s\n", op.cycles, cycles_percent, CyclesToMicroseconds(op.cycles), - op.optimal_seconds * 1e6, + op.optimal_seconds < 0 + ? "" + : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), op.flop_count <= 0 - ? "" + ? "" : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), - op.transcendental_count <= 0 ? "" + op.transcendental_count <= 0 ? 
"" : HumanReadableNumTranscendentalOps( op.transcendental_count, nsecs) .c_str(), @@ -78,24 +87,26 @@ string HumanReadableProfileBuilder::ToString() const { int64 total_transcendentals = 0.; int64 total_bytes = 0; for (const auto& op : op_infos_) { - optimal_seconds_sum += op.optimal_seconds; - total_flops += op.flop_count; - total_transcendentals += op.transcendental_count; - total_bytes += op.bytes_accessed; + if (op.optimal_seconds > 0) { + optimal_seconds_sum += op.optimal_seconds; + } + total_flops += std::max(op.flop_count, int64{0}); + total_transcendentals += std::max(op.transcendental_count, int64{0}); + total_bytes += std::max(op.bytes_accessed, int64{0}); } VLOG(1) << "Total floating point ops: " << total_flops; - append_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, - total_transcendentals, total_bytes, optimal_seconds_sum}); + print_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, + total_transcendentals, total_bytes, optimal_seconds_sum}); - // Sort ops in decreasing order of cycles. + // Sort ops in decreasing order of cycles, and print them. std::vector sorted_ops(op_infos_); std::sort( sorted_ops.begin(), sorted_ops.end(), [](const OpInfo& a, const OpInfo& b) { return a.cycles > b.cycles; }); for (const auto& op : sorted_ops) { - append_op(op); + print_op(op); } if (total_cycles_ <= 0) { @@ -109,8 +120,20 @@ string HumanReadableProfileBuilder::ToString() const { table.SetMetricName("microseconds above estimated optimum"); table.SetEntryName("ops"); table.SetShowCategoryTable(); + table.SetShowAllEntries(); float total_discrepancy_in_microseconds = 0.0f; - for (const auto& op : sorted_ops) { + for (const auto& op : op_infos_) { + // Skip ops with < 0 optimal seconds. These are ops for which we don't + // know the optimal time. + if (op.optimal_seconds < 0) { + continue; + } + // Also skip ops with 0 actual cycles. These ops were free; there's no + // need to clutter the "above estimated optimum" table with them, + // because they can't be optimized further. + if (op.cycles == 0) { + continue; + } MetricTableReport::Entry entry; entry.text = op.name; entry.short_text = op.short_name; @@ -128,7 +151,14 @@ string HumanReadableProfileBuilder::ToString() const { table.SetMetricName("microseconds"); table.SetEntryName("ops"); table.SetShowCategoryTable(); - for (const auto& op : sorted_ops) { + table.SetShowAllEntries(); + for (const auto& op : op_infos_) { + // Skip ops with 0 optimal seconds and 0 actual cycles. As in + // print_op(), these are uninteresting because they're expected to be + // free, and they were actually free. 
+ if (op.cycles == 0 && op.optimal_seconds == 0) { + continue; + } MetricTableReport::Entry entry; entry.text = op.name; entry.short_text = op.short_name; diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index fc24acd2713..6f56c3aa82e 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -32,7 +32,7 @@ class HumanReadableProfileBuilder { explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name, int64 total_cycles, double clock_rate_ghz) - : computation_name_(computation_name.ToString()), + : computation_name_(std::string(computation_name)), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 1e-9); @@ -41,15 +41,17 @@ class HumanReadableProfileBuilder { int64 total_cycles() const { return total_cycles_; } // Adds an operation to the profile. If you don't know the number of - // floating-point ops or bytes touched by the op, pass -1 for that param. + // floating-point ops or bytes touched by the op, or if you don't know how + // fast it would run optimally, pass -1 for that param. void AddOp(tensorflow::StringPiece op_name, tensorflow::StringPiece short_name, tensorflow::StringPiece category, int64 cycles, int64 flop_count, int64 transcendental_count, int64 bytes_accessed, float optimal_seconds) { - op_infos_.push_back( - {op_name.ToString(), short_name.ToString(), category.ToString(), cycles, - flop_count, transcendental_count, bytes_accessed, optimal_seconds}); + op_infos_.push_back({std::string(op_name), std::string(short_name), + std::string(category), cycles, flop_count, + transcendental_count, bytes_accessed, + optimal_seconds}); } // Gets the human-readable profile. @@ -61,10 +63,10 @@ class HumanReadableProfileBuilder { string short_name; string category; int64 cycles; - int64 flop_count; + int64 flop_count; // -1 if unknown int64 transcendental_count; - int64 bytes_accessed; - float optimal_seconds; + int64 bytes_accessed; // -1 if unknown + float optimal_seconds; // -1 if unknown }; double CyclesToSeconds(int64 cycles) const { diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc new file mode 100644 index 00000000000..15b2d8f4990 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -0,0 +1,269 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace gtl = ::tensorflow::gtl; + +namespace { +using Analysis = IndexedArrayAnalysis; +using UnknownArray = Analysis::UnknownArray; +using ConstantArray = Analysis::ConstantArray; +using ScalarIndexedArray = Analysis::ScalarIndexedArray; +} // namespace + +string IndexedArrayAnalysis::ToString(Array* root) { + switch (root->kind()) { + case Array::kUnknown: { + auto* unknown_tensor = root->as(); + return tensorflow::strings::StrCat("%", + unknown_tensor->instruction().name()); + } + + case Array::kConstant: { + return tensorflow::strings::StrCat( + "(constant ", ShapeUtil::HumanString(root->shape()), ")"); + } + + case Array::kScalarIndexedConstant: + case Array::kScalarIndexed: { + auto* indexed_array = root->as(); + string name = root->kind() == Array::kScalarIndexedConstant + ? "scalar-indexed-const" + : "scalar-indexed"; + return tensorflow::strings::StrCat( + "(", name, " ", ToString(indexed_array->source()), " ", + ToString(indexed_array->indices()), " ", indexed_array->source_dim(), + "->[", tensorflow::str_util::Join(indexed_array->output_dims(), ","), + "])"); + } + } +} + +Analysis::Array* IndexedArrayAnalysis::GetArrayFor( + const HloInstruction* instr) { + auto it = cache_.find(instr); + if (it != cache_.end()) { + return it->second; + } + + TraverseAndPopulateCache(instr); + return FindOrDie(cache_, instr); +} + +void IndexedArrayAnalysis::TraverseAndPopulateCache( + const HloInstruction* root) { + // Depth first search over the DAG, invoking ComputeArrayFor in post order. + // The HLO instructions already in the cache are considered leaves. 
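+ // In more detail: an instruction is pushed in state kDiscovered; the first
+ // time it reaches the top of the stack its not-yet-cached operands are
+ // pushed above it and its state becomes kVisited; the second time it
+ // reaches the top all of its operands are already in cache_, so
+ // ComputeArrayFor can look them up before the result itself is cached.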
+ + gtl::InlinedVector stack; + + enum DfsState { kDiscovered, kVisited }; + gtl::FlatMap dfs_state_map; + + stack.push_back(root); + InsertOrDie(&dfs_state_map, root, kDiscovered); + + do { + const HloInstruction* instr = stack.back(); + if (cache_.count(instr)) { + stack.pop_back(); + continue; + } + + switch (FindOrDie(dfs_state_map, instr)) { + case kDiscovered: { + for (const HloInstruction* operand : instr->operands()) { + if (!cache_.count(operand)) { + stack.push_back(operand); + CHECK(!dfs_state_map.count(operand) || + dfs_state_map[operand] == kDiscovered); + dfs_state_map[operand] = kDiscovered; + } + } + dfs_state_map[instr] = kVisited; + break; + } + + case kVisited: + stack.pop_back(); + InsertOrDie(&cache_, instr, ComputeArrayFor(instr)); + break; + } + } while (!stack.empty()); +} + +Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor( + const HloInstruction* instr) { + Array* computed_array; + switch (instr->opcode()) { + default: + computed_array = nullptr; + break; + case HloOpcode::kConstant: + computed_array = ComputeArrayForConstant(instr->literal()); + break; + case HloOpcode::kGather: + computed_array = ComputeArrayForGather( + instr->shape(), instr->gather_dimension_numbers(), + instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1))); + break; + } + + if (!computed_array) { + computed_array = Construct(instr); + } + + return computed_array; +} + +Analysis::Array* IndexedArrayAnalysis::ComputeArrayForConstant( + const Literal& literal) { + return Construct(&literal); +} + +ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64 source_dim, + tensorflow::gtl::ArraySlice output_dims, Shape shape) { + // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). + // `source` is the inner Gather(A, X). + + Array* a = source->source(); + Array* x = source->indices(); + Array* y = indices; + + // This bit is slightly tricky, so we do a naive "simulation" of the two + // consecutive gather operations to infer what the composed gather should look + // like. + + enum class IndexComponent { Ungathered, GatheredFirst, GatheredSecond }; + + std::vector simulated_index(a->shape().dimensions_size(), + IndexComponent::Ungathered); + + // Simulate the first gather. + simulated_index.erase(simulated_index.begin() + source->source_dim()); + for (int64 gather_dim : source->output_dims()) { + simulated_index.insert(simulated_index.begin() + gather_dim, + IndexComponent::GatheredFirst); + } + + // Simulate the second gather. 
+ simulated_index.erase(simulated_index.begin() + source_dim); + for (int64 output_dim : output_dims) { + simulated_index.insert(simulated_index.begin() + output_dim, + IndexComponent::GatheredSecond); + } + + int64 source_dim_for_index_array = + FindIndex(source->output_dims(), source_dim); + CHECK_NE(source_dim_for_index_array, source->output_dims().size()); + + std::vector output_dims_for_index_array; + int64 gathered_index_components_seen = 0; + for (IndexComponent simulation_dim : simulated_index) { + if (simulation_dim == IndexComponent::GatheredSecond) { + output_dims_for_index_array.push_back(gathered_index_components_seen); + } + if (simulation_dim != IndexComponent::Ungathered) { + gathered_index_components_seen++; + } + } + + std::vector dim_sizes_for_composed_index; + std::vector output_dims_for_new_gather; + for (int64 i = 0, e = simulated_index.size(); i < e; i++) { + if (simulated_index[i] != IndexComponent::Ungathered) { + dim_sizes_for_composed_index.push_back(shape.dimensions(i)); + output_dims_for_new_gather.push_back(i); + } + } + + Array* inner_indices = ConstructScalarIndexedArray( + x, y, source_dim_for_index_array, output_dims_for_index_array, + ShapeUtil::MakeShape(x->shape().element_type(), + dim_sizes_for_composed_index)); + return ConstructScalarIndexedArray(a, inner_indices, source->source_dim(), + output_dims_for_new_gather, + std::move(shape)); +} + +Analysis::Array* IndexedArrayAnalysis::ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices) { + if (dim_numbers.index_vector_dim() != indices->shape().dimensions_size()) { + return nullptr; + } + + CHECK_EQ(dim_numbers.gather_dims_to_operand_dims_size(), 1); + if (!c_binary_search(dim_numbers.elided_window_dims(), + dim_numbers.gather_dims_to_operand_dims(0))) { + return nullptr; + } + + int64 source_dim = dim_numbers.gather_dims_to_operand_dims(0); + std::vector output_dims; + for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { + if (!c_binary_search(dim_numbers.output_window_dims(), i)) { + output_dims.push_back(i); + } + } + + if (auto* indexed = dynamic_cast(source)) { + auto it = c_find(indexed->output_dims(), source_dim); + if (it != indexed->output_dims().end()) { + return FoldGatherOfGather(indexed, indices, source_dim, output_dims, + shape); + } + } else if (auto* constant = dynamic_cast(source)) { + return Construct(constant, indices, source_dim, + output_dims, shape); + } + + return Construct(source, indices, source_dim, output_dims, + shape); +} + +tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { + return "indexed-array-analysis-printer-pass"; +} + +StatusOr IndexedArrayAnalysisPrinterPass::Run(HloModule* module) { + if (!VLOG_IS_ON(2)) { + return false; + } + + IndexedArrayAnalysis analysis; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instr : computation->instructions()) { + auto* t = analysis.GetArrayFor(instr); + if (!dynamic_cast(t) && !dynamic_cast(t)) { + VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t); + } + } + } + + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h new file mode 100644 index 00000000000..b132a8f2515 --- /dev/null +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -0,0 +1,298 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace xla { + +// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a +// gather from another array. It does this by mapping HLO instructions to +// instances of IndexedArrayAnalysis::Array, which can be inspected to discover +// whether said HLO is equivalent to a gather. +class IndexedArrayAnalysis { + public: + // IndexedArrayAnalysis maps each HLO instruction to an instance of a Array. + // Array really just a sum type of the classes that inherit from it. The + // meaning of each of the subtypes is documented on the subtype declaration. + // + // Array instances are immutable once created. + class Array { + public: + enum Kind { kUnknown, kConstant, kScalarIndexedConstant, kScalarIndexed }; + + virtual Kind kind() const = 0; + virtual const Shape& shape() const = 0; + + // Does a checked downcast from `Array` to `T` which must be one of its + // subtypes. + template + T* as() { + static_assert((std::is_base_of::value), + "target type not derived from source type"); + // We skip the CHECK and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + CHECK_NE(dynamic_cast(this), nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + + return static_cast(this); + } + + virtual ~Array() = default; + + Array& operator=(const Array& other) = delete; + }; + + // Represents an HLO instruction that was not analyzable by this + // IndexedArrayAnalysis. Instances of UnknownArray just wrap an existing + // HloInstruction. + class UnknownArray : public Array { + public: + Kind kind() const override { return kUnknown; } + const Shape& shape() const override { return instruction().shape(); } + const HloInstruction& instruction() const { return instruction_; } + + private: + explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {} + + const HloInstruction& instruction_; + + friend class IndexedArrayAnalysis; + }; + + // Represents a constant value. This constant value may be present in the HLO + // module being analyzed, or it could have been created on the fly by the + // analysis. 
+ class ConstantArray : public Array {
+ public:
+ Kind kind() const override { return kConstant; }
+ const Shape& shape() const override { return literal()->shape(); }
+ const Literal* literal() const { return literal_; }
+
+ private:
+ explicit ConstantArray(const Literal* literal) : literal_(literal) {}
+ const Literal* literal_;
+
+ friend class IndexedArrayAnalysis;
+ };
+
+ // ---------------------------------------------------------------------------
+ // Indexed Array Overview
+ // ---------------------------------------------------------------------------
+ //
+ // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this
+ // analysis. ScalarIndexedConstantArray is just a specialization of
+ // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this
+ // overview.
+ //
+ // A ScalarIndexedArray represents an array that can be computed by indexing
+ // into a "source" array using an "indices" tensor. A simple example is a
+ // gather operation gathering 12 rows out of a [100,100] matrix -- such an
+ // operation will be represented by an instance of a ScalarIndexedArray with
+ // the [100,100] matrix as the "source" array and the [12]-shaped indices
+ // array as the "indices" tensor. The ScalarIndexedArray operation itself
+ // will be of shape [12,100] (assuming we were gathering with axis=0).
+ //
+ // Gather operations are not the only operation that maps to
+ // ScalarIndexedArray instances (if that were true there would be little point
+ // in having a separate analysis). We can often infer ScalarIndexedArrays for
+ // other operations too. For instance, consider:
+ //
+ // %source = f32[100,100] constant
+ // %indices = s32[12] ...
+ // %gather = f32[12,100] ... gather from %source using %indices at axis 0
+ // %dot = dot(%gather, other_constant) [canonical contracting dims]
+ //
+ // The dot operation itself is also a ScalarIndexedArray with source =
+ // dot(constant, other_constant) and indices = %indices. A reshape of %gather
+ // to [12,5,20] too is a ScalarIndexedArray with source = an appropriately
+ // reshaped constant and indices = %indices.
+
+ // Represents the result of a gather operation. This gather operation may
+ // explicitly be present in the HLO module being analyzed, or it could have
+ // been created on the fly by the analysis.
+ //
+ // An instance of ScalarIndexedArray represents an array whose I'th element can
+ // be mapped to the J'th element of the `source` array (where I and J are
+ // multidimensional indices) in this way:
+ //
+ // I' = remove components at positions `output_dims` from I
+ // G' = remove components not at positions `output_dims` from I
+ // T = indices[G']
+ // J = I' with T inserted at position `source_dim`
+ //
+ // For example, if source is of shape [11,13,17,19], indices is of shape
+ // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of
+ // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the
+ // input index [B,D,indices[A,C],E].
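+ //
+ // As a rough sketch of the rule above (using hypothetical local names, not
+ // members of this class), the source index J for an output index I could be
+ // computed as:
+ //
+ //   std::vector<int64> I_prime, G_prime;
+ //   for (int64 d = 0, e = I.size(); d < e; ++d) {
+ //     // d is an "output dim" iff it appears in output_dims.
+ //     (is_output_dim(d) ? G_prime : I_prime).push_back(I[d]);
+ //   }
+ //   int64 T = indices[G_prime];  // scalar element of the indices array
+ //   std::vector<int64> J = I_prime;
+ //   J.insert(J.begin() + source_dim, T);
+ //
+ // For the example above this yields J = [B,D,indices[A,C],E].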
+  class ScalarIndexedArray : public Array {
+   public:
+    Kind kind() const override { return kScalarIndexed; }
+    const Shape& shape() const override { return shape_; }
+
+    Array* source() const { return source_; }
+    Array* indices() const { return indices_; }
+    int64 source_dim() const { return source_dim_; }
+    tensorflow::gtl::ArraySlice<int64> output_dims() const {
+      return output_dims_;
+    }
+
+   private:
+    explicit ScalarIndexedArray(Array* source, Array* indices,
+                                int64 source_dim,
+                                std::vector<int64> output_dims, Shape shape)
+        : source_(source),
+          indices_(indices),
+          source_dim_(source_dim),
+          output_dims_(std::move(output_dims)),
+          shape_(std::move(shape)) {}
+
+    Array* source_;
+    Array* indices_;
+    int64 source_dim_;
+    std::vector<int64> output_dims_;
+    Shape shape_;
+
+    friend class IndexedArrayAnalysis;
+  };
+
+  // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to
+  // have a ConstantArray instance as the source. This is an ergonomic
+  // concession -- in theory it is possible to just keep ScalarIndexedArray and
+  // check source()->kind().
+  class ScalarIndexedConstantArray : public ScalarIndexedArray {
+   public:
+    Kind kind() const override { return kScalarIndexedConstant; }
+
+    const Literal& literal() const {
+      return *source()->as<ConstantArray>()->literal();
+    }
+
+   private:
+    explicit ScalarIndexedConstantArray(Array* source, Array* indices,
+                                        int64 source_dim,
+                                        std::vector<int64> output_dims,
+                                        Shape shape)
+        : ScalarIndexedArray(source, indices, source_dim,
+                             std::move(output_dims), std::move(shape)) {
+      CHECK(dynamic_cast<ConstantArray*>(source));
+    }
+
+    friend class IndexedArrayAnalysis;
+  };
+
+  // Returns an Array instance for `instr`. The IndexedArrayAnalysis instance
+  // keeps ownership of the returned Array instance.
+  //
+  // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO
+  // instructions to IndexedArrayAnalysis::Array instances. This entire cache
+  // becomes stale and may cause the analysis to return incorrect results if
+  // any transitive operand (stopping at the containing computation) is
+  // modified for any HLO instruction on which GetArrayFor has been invoked.
+  //
+  // NB! By inspecting the implementation, you may be able to infer a stronger
+  // caching guarantee than what is mentioned above. Nevertheless, what is
+  // stated above is the contract.
+  Array* GetArrayFor(const HloInstruction* instr);
+
+  // Pretty-prints the expression rooted at `root`.
+  string ToString(Array* root);
+
+ private:
+  // Helper function that ensures that every HLO instruction that is
+  // transitively used by `root` has an entry in `cache_`.
+  void TraverseAndPopulateCache(const HloInstruction* root);
+
+  // Creates an Array instance for `instr` under the assumption that all
+  // operands of `instr` are present in `cache_`.
+  Array* ComputeArrayFor(const HloInstruction* instr);
+
+  Array* ComputeArrayForConstant(const Literal& literal);
+
+  Array* ComputeArrayForGather(
+      const Shape& shape, const GatherDimensionNumbers& dim_numbers,
+      tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+      Array* indices);
+
+  // This tries to fold a ScalarIndexedArray which has another
+  // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has
+  // a ScalarIndexedArray as indices. If `source` happened to be a
+  // ScalarIndexedConstantArray this can result in an expression that is more
+  // canonical.
+  //
+  // As an example, consider a gather operation, G0, gathering 7 elements from
+  // an array "Arr" of shape [100] resulting in an array of shape [7], and a
+  // second gather operation, G1, which gathers 3 elements out of the result
+  // of G0, resulting in an array of shape [3]. Let the indices used by G0 be
+  // I0 (of shape [7]) and the indices used by G1 be I1 (of shape [3]). We can
+  // instead rewrite G1 to gather directly from "Arr" with the three indices
+  // from I0 selected by I1. In other words, we can rewrite:
+  //
+  //   G0 = [Arr[i] for i in I0]
+  //   G1 = [G0[i] for i in I1]
+  //
+  // into
+  //
+  //   I2 = [I0[i] for i in I1]
+  //   G1 = [Arr[i] for i in I2]
+  ScalarIndexedArray* FoldGatherOfGather(
+      ScalarIndexedArray* source, Array* indices, int64 source_dim,
+      tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape);
+
+  template <typename T, typename... Args>
+  T* Construct(Args&&... args) {
+    T* new_tensor = new T(std::forward<Args>(args)...);
+    owned_tensors_.push_back(std::unique_ptr<T>(new_tensor));
+    return new_tensor;
+  }
+
+  ScalarIndexedArray* ConstructScalarIndexedArray(
+      Array* source, Array* indices, int64 source_dim,
+      std::vector<int64> output_dims, Shape shape) {
+    if (source->kind() == Array::kConstant) {
+      return Construct<ScalarIndexedConstantArray>(source, indices, source_dim,
+                                                   std::move(output_dims),
+                                                   std::move(shape));
+    } else {
+      return Construct<ScalarIndexedArray>(source, indices, source_dim,
+                                           std::move(output_dims),
+                                           std::move(shape));
+    }
+  }
+
+  std::vector<std::unique_ptr<Array>> owned_tensors_;
+  std::vector<std::unique_ptr<Literal>> owned_literals_;
+  tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
+};
+
+// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
+// This pass is a no-op if !VLOG_IS_ON(2), so it should be fine to add it
+// unconditionally to the regular HLO pass pipeline.
+class IndexedArrayAnalysisPrinterPass : public HloPassInterface {
+ public:
+  tensorflow::StringPiece name() const override;
+  StatusOr<bool> Run(HloModule* module) override;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
new file mode 100644
index 00000000000..b2731b7c51a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -0,0 +1,191 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/indexed_array_analysis.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" + +namespace xla { +namespace { +class IndexedArrayAnalysisTest : public HloVerifiedTestBase { + protected: + void AssertArrayForRootExpressionIs(const string& hlo_text, + const string& root_expression) { + IndexedArrayAnalysis indexed_tensor_analysis; + ParseAndVerifyModule(hlo_text); + + string result = + indexed_tensor_analysis.ToString(indexed_tensor_analysis.GetArrayFor( + module().entry_computation()->root_instruction())); + LOG(INFO) << result; + ASSERT_EQ(result, root_expression); + } +}; + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneGather) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[5] parameter(1) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, SimpleOneToOneConstantGather) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices = s32[5] parameter(0) + ROOT gather = s32[5,3] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}}) + indices_a = s32[5] parameter(0) + indices_b = s32[2] parameter(1) + gather_a = s32[5,3] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} + ROOT gather_b = s32[2,3] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,3} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,3]) (scalar-indexed %indices_a " + "%indices_b 0->[0]) 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithOneToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[2] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[5,3,2] gather(gather_a, indices_b), + output_window_dims={0,1}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=1, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 1->[1]) 1->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,6] parameter(0) + indices_a = s32[2] 
parameter(1) + indices_b = s32[5,7] parameter(2) + gather_a = s32[2,6] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,6} + ROOT gather_b = s32[5,6,7] gather(gather_a, indices_b), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,6} +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, + "(scalar-indexed %operand (scalar-indexed " + "%indices_a %indices_b 0->[0,1]) 0->[0,2])"); +} + +TEST_F(IndexedArrayAnalysisTest, GatherOfGather_ManyToOneWithManyToOne) { + string hlo_text = R"( +HloModule SimpleGather + +ENTRY main { + operand = s32[3,2] parameter(0) + indices_a = s32[5,7] parameter(1) + indices_b = s32[4,8] parameter(2) + gather_a = s32[5,3,7] gather(operand, indices_a), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3,1} + ROOT gather_b = s32[4,5,3,8] gather(gather_a, indices_b), + output_window_dims={1,2}, + elided_window_dims={2}, + gather_dims_to_operand_dims={2}, + index_vector_dim=2, + window_bounds={5,3,1} +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed %operand (scalar-indexed %indices_a %indices_b " + "1->[0,2]) 1->[0,1,3])"); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 7aa1c7c8358..d2af261008f 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR1({4, 3, 3, 4}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } // Test that `constant` function is changed to `broadcast`. @@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } TEST_F(InlinerTest, MapSubtractOppositeOrder) { @@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); auto expected = Literal::CreateR1({3, 1, -1, -3}); - LiteralTestUtil::ExpectEqual(*result, *expected); + EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected)); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 3f4dbf897df..cb6c98c4817 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -28,6 +28,25 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { +// These nodes can always be duplicated into consumers, even if +// InstructionFusion::may_duplicate_ is false. +// +// In general these should be nodes that get *cheaper* the more they're +// duplicated (and fused into consumers). +// +// TODO(jlebar): Duplicating instructions when we have a variable called "may +// duplicate" that's equal to false is not pretty. 
+bool IsAlwaysDuplicable(const HloInstruction& instruction) { + // We are always willing to duplicate a widening type-conversion instruction + // if it means we can fuse the convert into a consumer. This allows the + // consumer to read less memory, which is almost always a performance win. + return instruction.opcode() == HloOpcode::kConvert && + ShapeUtil::ByteSizeOf(instruction.operand(0)->shape()) < + ShapeUtil::ByteSizeOf(instruction.shape()); +} +} // namespace + /*static*/ bool InstructionFusion::IsExpensive( const HloInstruction& instruction) { switch (instruction.opcode()) { @@ -37,9 +56,9 @@ namespace xla { case HloOpcode::kBitcast: case HloOpcode::kBitcastConvert: case HloOpcode::kBroadcast: - case HloOpcode::kBroadcastDimOne: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kClz: case HloOpcode::kComplex: case HloOpcode::kConcatenate: case HloOpcode::kConstant: @@ -101,11 +120,13 @@ namespace xla { case HloOpcode::kDivide: case HloOpcode::kDot: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kGather: case HloOpcode::kHostCompute: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kMap: case HloOpcode::kParameter: case HloOpcode::kPower: @@ -128,11 +149,11 @@ namespace xla { return false; } -// An "effectively unary" operation is one that has one "large" +// An "effectively at most unary" operation is one that has at most one "large" // input with the others being negligible in terms of memory usage. // We use "has a smaller true rank than the output" as a heuristic // for "negligible" memory usage. -bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) { +bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { int64 output_rank = 0; ShapeUtil::ForEachSubshape( hlo->shape(), @@ -143,8 +164,7 @@ bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) { }); return std::count_if(hlo->operands().begin(), hlo->operands().end(), [output_rank](HloInstruction* operand) { - if (operand->opcode() == HloOpcode::kBroadcast || - operand->opcode() == HloOpcode::kBroadcastDimOne) { + if (operand->opcode() == HloOpcode::kBroadcast) { return false; } if (operand->opcode() == HloOpcode::kConstant && @@ -157,66 +177,91 @@ bool InstructionFusion::EffectivelyUnary(HloInstruction* hlo) { } bool InstructionFusion::CanFuseOnAllPaths( - const HloReachabilityMap& reachability_map, HloInstruction* producer, - HloInstruction* consumer, DoNotFuseSet* do_not_fuse) { - auto could_fuse_on_all_paths = [&] { - // First check to see if we have already marked this producer as infeasible - // to fuse into consumer. - if (do_not_fuse->count(producer) > 0) { - return false; - } - // Make sure it is possible for producer and consumer to exist in a fusion - // node. - if (!producer->IsFusable() || !consumer->IsFusable()) { - return false; - } - // We do an upward walk of the graph from consumer towards all paths which - // lead to producer to find any unfusable paths. - for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { - auto* consumer_operand = consumer->mutable_operand(i); - if (consumer_operand == producer) { - // This is the base case: our upward crawl ends but we need to make sure - // that fusion from consumer can happen. - if (!ShouldFuse(consumer, i)) { - return false; - } - } else if (reachability_map.IsReachable(producer, consumer_operand)) { - // The reachability map told us that consumer_operand is a node on the - // path to producer. 
We need to further investigate from - // consumer_operand. - - // First check if we have already ruled out fusing producer into - // consumer_operand. - if (do_not_fuse->count(consumer_operand) > 0) { - return false; - } - // Make sure it is possible for consumer_operand to exist in a fusion - // node. - if (!consumer_operand->IsFusable()) { - return false; - } - // The producer is reachable from consumer_operand which means we need - // to be able to fuse consumer_operand into consumer in order for - // producer to be fusable into consumer on all paths. - if (!ShouldFuse(consumer, i)) { - return false; - } - // Perform the recursive step: make sure producer can be fused into - // consumer_operand on all paths. - if (!CanFuseOnAllPaths(reachability_map, producer, consumer_operand, - do_not_fuse)) { - return false; - } - } - } - return true; - }; - if (could_fuse_on_all_paths()) { + HloInstruction* producer, HloInstruction* consumer, + const HloReachabilityMap& reachability_map, + const DoNotFuseSet& do_not_fuse) { + if (consumer == producer) { return true; } - // We couldn't fuse on all paths, record this result. - do_not_fuse->insert(producer); - return false; + if (!consumer->IsFusable()) { + return false; + } + for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { + auto* consumer_operand = consumer->mutable_operand(i); + // If the operand is not on a path to the producer, it doesn't matter + // whether it's fusable. + if (!reachability_map.IsReachable(producer, consumer_operand)) { + continue; + } + if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) { + return false; + } + // The producer is reachable from consumer_operand which means we need + // to be able to fuse consumer_operand into consumer in order for + // producer to be fusable into consumer on all paths. + // Perform the recursive step: make sure producer can be fused into + // consumer_operand on all paths. + if (!CanFuseOnAllPaths(producer, consumer_operand, reachability_map, + do_not_fuse)) { + return false; + } + } + return true; +} + +InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable( + tensorflow::gtl::ArraySlice post_order) { + auto reachability = computation_->ComputeReachability(); + + // Forbid fusion of producers that: + // a) Need to be duplicated, unless they can be fused into all consumers + // via all paths. + // b) Are more than unary, that is, fusing them would likely lead to an + // increase in memory bandwidth use. + // + // Note that if we allow fusion by these global rules, we may still forbid + // fusing operations that require duplication later depending on + // is_expensive_(). + DoNotFuseSet do_not_fuse; + for (HloInstruction* consumer : post_order) { + for (HloInstruction* producer : consumer->operands()) { + if (do_not_fuse.count(producer) > 0) { + continue; + } + + // If the producer is effectively not more than unary, duplicating it + // will not increase the number of relevant inputs read, as the fusion + // node will only need to read at most 1 relevant input (the input of + // the producer). In that case, we do not forbid fusion of the operation + // here. + if (EffectivelyAtMostUnary(producer)) { + continue; + } + // Otherwise we will forbid fusing the op unless we can fuse it into + // all of its consumers on all paths. + // + // That means, that for: + // A --> B (fusable) + // \-> C (non-fusable) + // A will be not allowed to be fused into B, as it cannot be fused into C. 
+ // + // Similarly, for: + // A -------------> B + // \-> C -> D -/ + // If: + // - A is fusable into B and C, and D is fusable into B + // - C is *not* fusable into D + // A will be not allowed to be fused into B, as it cannot be fused via + // all paths. + if (producer->IsFusable() && + CanFuseOnAllPaths(producer, consumer, *reachability, do_not_fuse)) { + continue; + } + do_not_fuse.insert(producer); + } + } + + return do_not_fuse; } StatusOr InstructionFusion::Run(HloModule* module) { @@ -245,37 +290,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { InsertOrDie(&post_order_index, post_order[i], i); } - DoNotFuseSet do_not_fuse; - auto reachability = computation->ComputeReachability(); - - auto cheap_to_duplicate = [this](HloInstruction* producer) { - if (producer->opcode() == HloOpcode::kBroadcast || - producer->opcode() == HloOpcode::kBroadcastDimOne) { - return true; - } - if (producer->opcode() == HloOpcode::kConstant && - ShapeUtil::IsEffectiveScalar(producer->shape())) { - return true; - } - if (EffectivelyUnary(producer)) { - return true; - } - return false; - }; - - for (HloInstruction* consumer : post_order) { - for (HloInstruction* producer : consumer->operands()) { - if (cheap_to_duplicate(producer)) { - continue; - } - if (CanFuseOnAllPaths(*reachability, producer, consumer, - &do_not_fuse)) { - CHECK_EQ(do_not_fuse.count(producer), 0); - } else { - CHECK_GT(do_not_fuse.count(producer), 0); - } - } - } + DoNotFuseSet do_not_fuse = ComputeGloballyUnfusable(post_order); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -399,12 +414,9 @@ StatusOr InstructionFusion::Run(HloModule* module) { return changed; } -HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, - HloInstruction* consumer) { +HloInstruction* InstructionFusion::AddFusionInstruction( + HloInstruction* producer, HloInstruction* consumer) { HloInstruction* fusion_instruction; - - VLOG(2) << "Fusing " << producer->ToString() << " into " - << consumer->ToString(); auto kind = ChooseKind(producer, consumer); if (consumer->opcode() == HloOpcode::kFusion) { fusion_instruction = consumer; @@ -416,17 +428,35 @@ HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction)); } + return fusion_instruction; +} +HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, + HloInstruction* consumer) { + VLOG(2) << "Fusing " << producer->ToString() << " into " + << consumer->ToString(); + HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); fusion_instruction->FuseInstruction(producer); return fusion_instruction; } +HloInstruction* InstructionFusion::FuseIntoMultiOutput( + HloInstruction* producer, HloInstruction* consumer) { + VLOG(2) << "Multi-output fusing " << producer->ToString() << " into " + << consumer->ToString(); + HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); + fusion_instruction->FuseInstructionIntoMultiOutput(producer); + return fusion_instruction; +} + bool InstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); + // Cost condition: don't duplicate expensive instructions. 
if (FusionWouldDuplicate(*producer, *consumer) && - (is_expensive_(*producer) || !may_duplicate_)) { + (!may_duplicate_ || is_expensive_(*producer)) && + !IsAlwaysDuplicable(*producer)) { return false; } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 152d0886ee9..c3c2ed0aaa8 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -70,11 +70,18 @@ class InstructionFusion : public HloPassInterface { virtual HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer); - // An "effectively unary" operation is one that has one "large" + // Creates a new fusion instruction containing `producer` and `consumer`. A + // tuple is added as the fusion instruction's root, which consumes from both, + // `producer` and `consumer`. This style of fusion is referred to as + // multi-output fusion. + virtual HloInstruction* FuseIntoMultiOutput(HloInstruction* producer, + HloInstruction* consumer); + + // An "effectively unary" operation is one that has at most one "large" // input with the others being negligible in terms of memory usage. // We use "has a smaller true rank than the output" as a heuristic // for "negligible" memory usage. - bool EffectivelyUnary(HloInstruction* hlo); + bool EffectivelyAtMostUnary(HloInstruction* hlo); // Returns true if fusing producer into consumer would cause producer to be // duplicated. This is the case if producer has uses other than consumer. @@ -95,11 +102,19 @@ class InstructionFusion : public HloPassInterface { // The set of producers whose consumers we cannot fuse into. using DoNotFuseSet = std::unordered_set; - // Whether or not we can fuse consumer into original_producer on all paths + HloInstruction* AddFusionInstruction(HloInstruction* producer, + HloInstruction* consumer); + + // Whether or not we can fuse producer into consumer on all paths // from the producer to the consumer where nodes are HLOs and edges are uses. - bool CanFuseOnAllPaths(const HloReachabilityMap& reachability_map, - HloInstruction* producer, HloInstruction* consumer, - DoNotFuseSet* do_not_fuse); + bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer, + const HloReachabilityMap& reachability_map, + const DoNotFuseSet& do_not_fuse); + + // Computes the set of nodes that we do not want to fuse into any of their + // consumers based on a global analysis of the HLO graph. + DoNotFuseSet ComputeGloballyUnfusable( + tensorflow::gtl::ArraySlice post_order); // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 0fa2c95fb45..df109df7877 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -17,11 +17,99 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { +namespace op = xla::testing::opcode_matchers; + using InstructionFusionTest = HloTestBase; +// Subclass of InstructionFusion exposing the protected methods Fuse and +// FuseIntoMultiOutput for testing. 
+class InstructionFusionForTesting : public InstructionFusion { + public: + explicit InstructionFusionForTesting(HloModule* module) + : InstructionFusion(InstructionFusion::IsExpensive) { + module_ = module; + computation_ = module->entry_computation(); + } + + HloInstruction* Fuse(HloInstruction* producer, + HloInstruction* consumer) override { + return InstructionFusion::Fuse(producer, consumer); + } + + HloInstruction* FuseIntoMultiOutput(HloInstruction* producer, + HloInstruction* consumer) override { + return InstructionFusion::FuseIntoMultiOutput(producer, consumer); + } +}; + +TEST_F(InstructionFusionTest, FuseInstructions) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY entry_computation { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + ROOT sub = f32[4,3]{1,0} subtract(add, p0) + })") + .ValueOrDie(); + HloInstruction* sub = module->entry_computation()->root_instruction(); + HloInstruction* add = sub->mutable_operand(0); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).Fuse(add, sub); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), + op::Subtract(op::Add(), op::Parameter())) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { + auto module = tools::Parse(R"( + HloModule test_module + fused_computation { + p1 = f32[4,3] parameter(0) + add = f32[4,3] add(p1, p1) + } + ENTRY entry_computation { + p0 = f32[4,3] parameter(0) + abs = f32[4,3] abs(p0) + ROOT fusion = f32[4,3] fusion(abs), kind=kLoop, calls=fused_computation + })") + .ValueOrDie(); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* abs = root->mutable_operand(0); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).Fuse(abs, root); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), op::Add(op::Abs(), op::Abs())) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY entry_computation { + p0 = f32[4,3]{1,0} parameter(0) + abs = f32[4,3]{1,0} abs(p0) + tanh = f32[4,3]{1,0} tanh(abs) + ROOT add = f32[4,3]{1,0} add(abs, tanh) + })") + .ValueOrDie(); + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction* abs = root->mutable_operand(0); + HloInstruction* tanh = root->mutable_operand(1); + HloInstruction* fusion = + InstructionFusionForTesting(module.get()).FuseIntoMultiOutput(abs, tanh); + + ASSERT_THAT(fusion, op::Fusion()) << module->ToString(); + EXPECT_THAT(fusion->fused_expression_root(), op::Tuple(op::Tanh(), op::Abs())) + << module->ToString(); +} + TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) { HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( @@ -89,7 +177,172 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { EXPECT_FALSE( InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()); + .ValueOrDie()) + << module->ToString(); +} + +// Counts the number of HLO ops with a given op code in the specified module. 
+static int Count(const HloModule& module, HloOpcode op) { + int count = 0; + for (const auto* computation : module.computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == op) { + ++count; + } + } + } + return count; +} + +TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + ROOT root = f32[4,3]{1,0} subtract(add, add) + })") + .ValueOrDie(); + // Expect the add and subtraction to be fused. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + + // Make sure the add hasn't been duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); +} + +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { + // Make sure we do not duplicate the add, as we cannot fuse through the rng. + // + // p0 -> add -------------------------> sub + // \-> abs1 -> rng -> abs2 -/ + auto module = tools::Parse(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + abs1 = f32[4,3]{1,0} abs(add) + rng = f32[4,3]{1,0} rng(abs1), distribution=rng_uniform + abs2 = f32[4,3]{1,0} abs(rng) + ROOT root = f32[4,3]{1,0} subtract(abs2, add) + })") + .ValueOrDie(); + // We expect abs2 to be fused into root. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Subtract(op::Abs(op::Parameter()), op::Parameter())) + << module->ToString(); + + // Make sure the add hasn't been duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); + + // Use a log node with a second consumer to break the fusion. + // + // p0 -> add -------------------------> sub + // \-> abs1 -> log -> abs2 -/ + // \-> send + module = tools::Parse(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add = f32[4,3]{1,0} add(p0, p0) + abs1 = f32[4,3]{1,0} abs(add) + log = f32[4,3]{1,0} log(abs1) + send = f32[4,3]{1,0} send(log), channel_id=0 + abs2 = f32[4,3]{1,0} abs(log) + ROOT root = f32[4,3]{1,0} subtract(abs2, add) + })") + .ValueOrDie(); + + // We expect abs2 to be fused into root and abs1 to be fused into log. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 2) << module->ToString(); + + // Make sure the add hasn't been duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); + + // Make sure we still fuse ops where one operand in the chain to the producer + // can't be fused. 
+ // + // p0 ---> add1 -----------> sub + // \ \-> add2 -/ + // \-> log -/ + // \-> send + module = tools::Parse(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add1 = f32[4,3]{1,0} add(p0, p0) + log = f32[4,3]{1,0} log(p0) + send = f32[4,3]{1,0} send(log), channel_id=0 + add2 = f32[4,3]{1,0} add(log, add1) + ROOT root = f32[4,3]{1,0} subtract(add1, add2) + })") + .ValueOrDie(); + + // Expect the add1 and add2 to be fused into root. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString(); + + // Make sure we didn't duplicate any adds. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); + + // A variant of the above that allows the algorithm to put add2 into the set + // of unfusable ops to short-circuit the decision whether add1 should be fused + // into sub2. + // + // /---------------\ + // p0 ---> add1 ---> add2 ------> sub2 + // \------> sub1 + // log -/ + // \-> send + module = tools::Parse(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + add1 = f32[4,3]{1,0} add(p0, p0) + add2 = f32[4,3]{1,0} add(add1, add1) + log = f32[4,3]{1,0} log(add2) + send = f32[4,3]{1,0} send(log), channel_id=0 + sub1 = f32[4,3]{1,0} subtract(log, add2) + sub2 = f32[4,3]{1,0} subtract(add2, add1) + ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2) + })") + .ValueOrDie(); + + // Expect sub1 and sub2 to be fused into root. + EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Tuple(op::Subtract(op::Parameter(), op::Parameter()), + op::Subtract(op::Parameter(), op::Parameter()))) + << module->ToString(); + + // Make sure we didn't duplicate any adds. + EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); } TEST_F(InstructionFusionTest, AllowUnaryDuplication) { @@ -135,4 +388,29 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { .ValueOrDie()); } +TEST_F(InstructionFusionTest, + WideningConvertsAreAlwaysDuplicableIntoConsumers) { + auto module = tools::Parse(R"( + HloModule test_module + ENTRY Test { + p0 = f16[100] parameter(0) + c = f32[100] convert(p0) + add = f32[100] add(c, c) + ROOT mul = f32[100] multiply(c, c) + })") + .ValueOrDie(); + + // The convert should be fused into the add and mul, even though may_duplicate + // is false, because it's always beneficial to fuse/duplicate widening + // converts into consumers. 
+ EXPECT_TRUE( + InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/false) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion(op::Parameter())); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 45505484951..524d3234eb4 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -18,7 +18,6 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform_id", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -117,6 +116,5 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_headers_lib", - "//tensorflow/core:stream_executor_no_cuda", ], ) diff --git a/tensorflow/compiler/xla/service/interpreter/README.md b/tensorflow/compiler/xla/service/interpreter/README.md index 4c19a1b916d..0b21b251c3f 100644 --- a/tensorflow/compiler/xla/service/interpreter/README.md +++ b/tensorflow/compiler/xla/service/interpreter/README.md @@ -5,7 +5,7 @@ evaluating the result of the HLO graph directly with HloEvaluator, without lowering it further (to LLVM IR for example) before execution as other backends (CPU and GPU for example) do. -Its key componenets are: +Its key components are: * [`InterpreterCompiler`] despite the inherited naming of "compiler", all `InterpreterCompiler` really does is the following: diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 5b9bf5faf36..c1666530687 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -34,22 +34,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace interpreter { -namespace se = ::perftools::gputools; -namespace sep = ::perftools::gputools::interpreter; - Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass( - hlo_module->mutable_entry_computation_layout()); - + hlo_module->mutable_device_entry_computation_layout()); return pipeline.Run(hlo_module).status(); } @@ -74,7 +69,8 @@ StatusOr> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. 
std::unique_ptr executable = - xla::MakeUnique(std::move(hlo_module)); + xla::MakeUnique(std::move(hlo_module), + xla::MakeUnique()); return std::move(executable); } @@ -96,7 +92,7 @@ InterpreterCompiler::CompileAheadOfTime( } se::Platform::Id InterpreterCompiler::PlatformId() const { - return sep::kXlaInterpreterPlatformId; + return se::interpreter::kXlaInterpreterPlatformId; } HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() @@ -104,16 +100,14 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() return InterpreterExecutable::ShapeSizeBytes; } -static std::unique_ptr CreateComputationPlacer() { - return xla::MakeUnique(); -} - static bool InitModule() { - xla::Compiler::RegisterCompilerFactory(sep::kXlaInterpreterPlatformId, []() { - return xla::MakeUnique(); - }); + xla::Compiler::RegisterCompilerFactory( + se::interpreter::kXlaInterpreterPlatformId, []() { + return xla::MakeUnique(); + }); xla::ComputationPlacer::RegisterComputationPlacer( - sep::kXlaInterpreterPlatformId, &CreateComputationPlacer); + se::interpreter::kXlaInterpreterPlatformId, + []() { return xla::MakeUnique(); }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index c8660c04d86..e90ae3e8185 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -44,19 +44,16 @@ class InterpreterCompiler : public Compiler { ~InterpreterCompiler() override {} StatusOr> RunHloPasses( - std::unique_ptr hlo_module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( - std::unique_ptr hlo_module, - perftools::gputools::StreamExecutor* stream_exec, + std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr>> Compile( std::vector> hlo_modules, - std::vector> - stream_exec, + std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) override; StatusOr>> @@ -65,7 +62,7 @@ class InterpreterCompiler : public Compiler { HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override; - perftools::gputools::Platform::Id PlatformId() const override; + se::Platform::Id PlatformId() const override; private: Status RunHloOptimization(HloModule* hlo_module); diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 883063d0f07..029e71058a7 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -32,22 +31,21 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace interpreter { -namespace se = ::perftools::gputools; - InterpreterExecutable::InterpreterExecutable( - std::unique_ptr hlo_module) + std::unique_ptr hlo_module, + std::unique_ptr evaluator) : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, - /*hlo_profile_index_map=*/nullptr) {} + /*hlo_profile_index_map=*/nullptr), + evaluator_(std::move(evaluator)) {} InterpreterExecutable::~InterpreterExecutable() {} -StatusOr> InterpreterExecutable::ExecuteOnStream( +StatusOr InterpreterExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { @@ -84,18 +82,21 @@ StatusOr> InterpreterExecutable::ExecuteOnStream( } // Execute the graph using the HloEvaluator. - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - std::unique_ptr result_literal, - evaluator.Evaluate>(*computation, arg_literals)); + std::unique_ptr result_literal; + { + tensorflow::mutex_lock lock(evaluator_lock_); + TF_ASSIGN_OR_RETURN(result_literal, + evaluator_->Evaluate>( + *computation, arg_literals)); + } // Transform the result literal back into a ShapedBuffer. - TF_ASSIGN_OR_RETURN(std::unique_ptr result, - transfer_manager->AllocateShapedBuffer( + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + transfer_manager->AllocateScopedShapedBuffer( result_literal->shape(), run_options->allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( - executor, *result_literal, *result)); + executor, *result_literal, result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -108,8 +109,7 @@ StatusOr> InterpreterExecutable::ExecuteOnStream( return std::move(result); } -StatusOr> -InterpreterExecutable::ExecuteAsyncOnStream( +StatusOr InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) { return tensorflow::errors::Unimplemented( diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 410110a1adf..91d8148d26d 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" @@ -30,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -40,20 +42,27 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. 
class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module); + InterpreterExecutable(std::unique_ptr hlo_module, + std::unique_ptr evaluator); ~InterpreterExecutable() override; - StatusOr> ExecuteOnStream( + StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) override; + HloExecutionProfile* hlo_execution_profile) override + LOCKS_EXCLUDED(evaluator_lock_); - StatusOr> ExecuteAsyncOnStream( + StatusOr ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) override; static int64 ShapeSizeBytes(const Shape& shape); + protected: + // The interpreter interprets executables with an HloEvaluator. + std::unique_ptr evaluator_ PT_GUARDED_BY(evaluator_lock_); + mutable tensorflow::mutex evaluator_lock_; + private: TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable); }; diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 3caf9e7b82b..97e9fa2c8e8 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -19,8 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { host::HostStream *AsExecutorStream(Stream *stream) { @@ -119,5 +118,4 @@ DeviceDescription *XlaInterpreterExecutor::PopulateDeviceDescription() const { } } // namespace interpreter -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 77426b0820d..9b109022fbf 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -44,8 +44,7 @@ limitations under the License. #include "tensorflow/stream_executor/stream_executor_internal.h" #include "tensorflow/stream_executor/timer.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { using Args = tensorflow::gtl::ArraySlice; @@ -213,7 +212,6 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { }; } // namespace interpreter -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_EXECUTOR_H_ diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc index 3cf8506d1c4..d27cd7502f1 100644 --- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc @@ -21,12 +21,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" -namespace sei = ::perftools::gputools::interpreter; - namespace xla { InterpreterTransferManager::InterpreterTransferManager() - : GenericTransferManager(sei::kXlaInterpreterPlatformId, + : GenericTransferManager(se::interpreter::kXlaInterpreterPlatformId, /*pointer_size=*/sizeof(void*)) {} } // namespace xla @@ -38,7 +36,8 @@ CreateInterpreterTransferManager() { static bool InitModule() { xla::TransferManager::RegisterTransferManager( - sei::kXlaInterpreterPlatformId, &CreateInterpreterTransferManager); + stream_executor::interpreter::kXlaInterpreterPlatformId, + &CreateInterpreterTransferManager); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index 015e00e1e8e..42c2c28997d 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/interpreter/executor.h" -#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" @@ -28,20 +27,16 @@ limitations under the License. #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" -namespace se = ::perftools::gputools; -namespace sep = ::perftools::gputools::interpreter; - -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { -XlaInterpreterPlatform::XlaInterpreterPlatform() : name_("Interpreter") {} +XlaInterpreterPlatform::XlaInterpreterPlatform(const string& name, + const Platform::Id& id) + : name_(name), id_(id) {} XlaInterpreterPlatform::~XlaInterpreterPlatform() {} -Platform::Id XlaInterpreterPlatform::id() const { - return kXlaInterpreterPlatformId; -} +Platform::Id XlaInterpreterPlatform::id() const { return id_; } int XlaInterpreterPlatform::VisibleDeviceCount() const { return 1; } @@ -75,8 +70,8 @@ port::StatusOr XlaInterpreterPlatform::GetExecutor( port::StatusOr> XlaInterpreterPlatform::GetUncachedExecutor( const StreamExecutorConfig& config) { - auto executor = port::MakeUnique( - this, port::MakeUnique(config.plugin_config)); + auto executor = MakeUnique( + this, MakeUnique(config.plugin_config)); auto init_status = executor->Init(config.ordinal, config.device_options); if (!init_status.ok()) { return port::Status{ @@ -99,18 +94,16 @@ void XlaInterpreterPlatform::UnregisterTraceListener(TraceListener* listener) { } static void InitializeXlaInterpreterPlatform() { - std::unique_ptr platform(new sep::XlaInterpreterPlatform); - SE_CHECK_OK(se::MultiPlatformManager::RegisterPlatform(std::move(platform))); + std::unique_ptr platform(new XlaInterpreterPlatform); + SE_CHECK_OK(MultiPlatformManager::RegisterPlatform(std::move(platform))); } } // namespace interpreter -} // namespace gputools -} // namespace perftools +} // namespace stream_executor -REGISTER_MODULE_INITIALIZER(interpreter_platform, - sep::InitializeXlaInterpreterPlatform()); - -DECLARE_MODULE_INITIALIZER(multi_platform_manager); +REGISTER_MODULE_INITIALIZER( + interpreter_platform, + stream_executor::interpreter::InitializeXlaInterpreterPlatform()); // Note that module initialization sequencing is not supported in 
the // open-source project, so this will be a no-op there. diff --git a/tensorflow/compiler/xla/service/interpreter/platform.h b/tensorflow/compiler/xla/service/interpreter/platform.h index 2f71b29be44..0187f6d473b 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.h +++ b/tensorflow/compiler/xla/service/interpreter/platform.h @@ -18,18 +18,19 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/stream_executor/executor_cache.h" #include "tensorflow/stream_executor/plugin.h" #include "tensorflow/stream_executor/stream_executor.h" #include "tensorflow/stream_executor/trace_listener.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { class XlaInterpreterPlatform : public Platform { public: - XlaInterpreterPlatform(); + XlaInterpreterPlatform(const string& name = "Interpreter", + const Platform::Id& id = kXlaInterpreterPlatformId); ~XlaInterpreterPlatform() override; Platform::Id id() const override; @@ -56,6 +57,8 @@ class XlaInterpreterPlatform : public Platform { private: // This platform's name. string name_; + // This platform's id. + Platform::Id id_; // Cache of created StreamExecutors. ExecutorCache executor_cache_; @@ -64,7 +67,6 @@ class XlaInterpreterPlatform : public Platform { }; } // namespace interpreter -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_H_ diff --git a/tensorflow/compiler/xla/service/interpreter/platform_id.cc b/tensorflow/compiler/xla/service/interpreter/platform_id.cc index b7fb365b70d..3272396ce50 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform_id.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform_id.cc @@ -14,12 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { PLATFORM_DEFINE_ID(kXlaInterpreterPlatformId); } // namespace interpreter -} // namespace gputools -} // namespace perftools +} // namespace stream_executor diff --git a/tensorflow/compiler/xla/service/interpreter/platform_id.h b/tensorflow/compiler/xla/service/interpreter/platform_id.h index 292f958449b..a6cc10bcc1e 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform_id.h +++ b/tensorflow/compiler/xla/service/interpreter/platform_id.h @@ -18,14 +18,12 @@ limitations under the License. #include "tensorflow/stream_executor/platform.h" -namespace perftools { -namespace gputools { +namespace stream_executor { namespace interpreter { extern const Platform::Id kXlaInterpreterPlatformId; } // namespace interpreter -} // namespace gputools -} // namespace perftools +} // namespace stream_executor #endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_PLATFORM_ID_H_ diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 2494569db53..7067b6f86a0 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -31,10 +31,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -400,9 +402,9 @@ string LayoutConstraints::ToString() const { } Status LayoutAssignment::AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints) { + const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, HloComputation* computation, + LayoutConstraints* constraints) { VLOG(3) << "Adding mandatory layout constraints to computation " << computation->name(); @@ -424,11 +426,16 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( instruction->outfeed_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kParameter) { - // Parameter layouts must match the respective layout in - // ComputationLayout. - shape_with_layout = - &computation_layout.parameter_layout(instruction->parameter_number()) - .shape(); + if (computation_layout != nullptr) { + const ShapeLayout& parameter_layout = + computation_layout->parameter_layout( + instruction->parameter_number()); + if (parameter_layout.LayoutIsSet()) { + // Parameter layouts must match the respective layout in + // ComputationLayout, if there is one. + shape_with_layout = ¶meter_layout.shape(); + } + } } if (shape_with_layout != nullptr) { TF_RETURN_IF_ERROR( @@ -493,9 +500,8 @@ Status LayoutAssignment::AddMandatoryConstraints( HloComputation* body = instruction->while_body(); HloComputation* condition = instruction->while_condition(); const HloInstruction* init = instruction->operand(0); - const ComputationLayout& body_layout = - FindOrDie(computation_layouts_, body); - const ComputationLayout& condition_layout = + ComputationLayout& body_layout = FindOrDie(computation_layouts_, body); + ComputationLayout& condition_layout = FindOrDie(computation_layouts_, condition); // Check a few invariants irrespective of layout. @@ -508,26 +514,19 @@ Status LayoutAssignment::AddMandatoryConstraints( condition_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape())); - // Return error if earlier layout assignment of the embedded computations - // has produced conflicting layouts. 
- if (!ShapeUtil::Equal(body_layout.result_shape(), - body_layout.parameter_shape(0))) { - return InternalError( - "Parameter and result of body computation %s of while instruction " - "%s have different layouts: %s vs %s", - body->name().c_str(), instruction->name().c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str(), - ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str()); + if (body_layout.result_layout() != body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while body parameter layout: body=" << body->name() + << " while=" << instruction->name() + << " shape=" << body_layout.result_layout().ToString(); + *body_layout.mutable_parameter_layout(0) = body_layout.result_layout(); } - if (!ShapeUtil::Equal(body->root_instruction()->shape(), - condition->parameter_instruction(0)->shape())) { - return InternalError( - "Parameter of condition computation %s of while instruction " - "%s does not match body computation %s result: %s vs %s", - condition->name().c_str(), instruction->name().c_str(), - body->name().c_str(), - ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(), - ShapeUtil::HumanString(body_layout.result_shape()).c_str()); + if (condition_layout.parameter_layout(0) != + body_layout.parameter_layout(0)) { + VLOG(2) << "Reset %while condition parameter layout: cond=" + << condition->name() << " while=" << instruction->name() + << " shape=" << body_layout.parameter_layout(0).ToString(); + *condition_layout.mutable_parameter_layout(0) = + body_layout.parameter_layout(0); } // Constrain the output and the operand of the while instruction to match @@ -557,7 +556,20 @@ Status LayoutAssignment::AddMandatoryConstraints( true_computation_layout.parameter_shape(0))); DCHECK(ShapeUtil::Compatible( false_operand->shape(), false_computation_layout.parameter_shape(0))); - + if (true_computation_layout.result_layout() != + false_computation_layout.result_layout()) { + // We assign layouts in DFS fashion, so the true and false computations + // might have negotiated a different layout. But for the conditional + // instruction POV the layout must match, so we run again on the false + // computation, this time with proper computation layout. + VLOG(2) << "Reset %conditional false computation result layout: " + "false_computation=" + << false_computation->name() + << " conditional=" << instruction->name() << " shape=" + << true_computation_layout.result_layout().ToString(); + *false_computation_layout.mutable_result_layout() = + true_computation_layout.result_layout(); + } TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( true_computation_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( @@ -593,10 +605,14 @@ Status LayoutAssignment::AddMandatoryConstraints( } } } - - // Finally set the result layout to match ComputationLayout. - return constraints->SetResultLayout( - computation_layout.result_layout().shape()); + // Finally set the result layout to match ComputationLayout, if there is one. 
+ if (computation_layout != nullptr) { + const ShapeLayout& result_layout = computation_layout->result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape())); + } + } + return Status::OK(); } namespace { @@ -760,6 +776,7 @@ StatusOr LayoutAssignment::CreateCopyWithNewLayout( HloInstruction* copy = instruction->parent()->AddInstruction(HloInstruction::CreateUnary( instruction->shape(), HloOpcode::kCopy, instruction)); + RegisterAddedCopy(copy); SetupCopiedInstruction(*instruction, copy, {}); LayoutUtil::ClearLayout(copy->mutable_shape()); TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( @@ -783,13 +800,19 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer( TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + VLOG(5) << "Operand " << operand->ToString() << " layout matches in " + << instruction->ToString(); // Operand layout already matches our constraint. Nothing to do. return Status::OK(); } + VLOG(4) << "Operand " << operand->ToString() << " layout does not match " + << operand_layout.ToString() << " in " << instruction->ToString(); TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, CreateCopyWithNewLayout(operand_layout.shape(), operand)); + VLOG(4) << "New copy of " << operand->ToString() << " is " + << operand_copy->ToString(); return instruction->ReplaceOperandWith(operand_no, operand_copy); } @@ -896,15 +919,16 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { } } } - - // Finally verify the result layout matches the layout of the entry + // Finally verify the result layout, if set, matches the layout of the entry // computation root. - TF_RET_CHECK(ShapeUtil::Equal( - module->entry_computation()->root_instruction()->shape(), + const ShapeLayout& result_layout = FindOrDie(computation_layouts_, module->entry_computation()) - .result_layout() - .shape())); - + .result_layout(); + if (result_layout.LayoutIsSet()) { + TF_RET_CHECK(ShapeUtil::Equal( + module->entry_computation()->root_instruction()->shape(), + result_layout.shape())); + } return Status::OK(); } @@ -913,18 +937,13 @@ LayoutAssignment::LayoutAssignment( ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), channel_layout_constraints_(channel_constraints) { - VLOG(1) << "entry computation layout given to layout assignment: " + VLOG(1) << "Entry computation layout given to layout assignment: " << entry_computation_layout_->ToString(); // Layouts of all parameter instructions must be set. for (const ShapeLayout& parameter_layout : entry_computation_layout_->parameter_layouts()) { CHECK(parameter_layout.LayoutIsSet()); } - // If the result layout is not set, then choose the default. - // TODO(b/29118294): Choose a better layout in this case. 
- if (!entry_computation_layout_->result_layout().LayoutIsSet()) { - entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout(); - } } std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout( @@ -1484,16 +1503,60 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, return Status::OK(); } +Status LayoutAssignment::CalculateComputationLayout( + HloComputation* computation) { + ComputationLayout computation_layout(computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + InsertOrDie(&computation_layouts_, computation, computation_layout); + VLOG(2) << " Calculated ComputationLayout = " + << computation_layout.ToString(); + return Status::OK(); +} + +Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { + // Clear existing layouts of the instructions. All layouts must be assigned + // by the LayoutAssignment pass, except for those on infeeds, parameters, + // and the computation result. The latter two are specified in + // computation_layout, so we only need to keep the existing layouts for + // infeeds. Clearing the layouts here avoids hiding potential bugs in the + // layout assignment pass that may accidentally use the existing layout. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBitcast) { + // bitcasts are inherently layout sensitive and so a bitcast instruction + // present in the IR before layout assignment is a bug. + return InternalError( + "Unexpected bitcast operation seen during layout assignment: %s.", + instruction->ToString().c_str()); + } + if (instruction->opcode() != HloOpcode::kInfeed) { + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + } + return Status::OK(); +} + Status LayoutAssignment::RunOnComputation( - const ComputationLayout& computation_layout, + ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints) { - DCHECK(computation_layout.LayoutIsSet()); - InsertOrDie(&computation_layouts_, computation, computation_layout); VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() << ")"; - VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); + TF_RETURN_IF_ERROR(ClearComputationLayouts(computation)); + if (computation_layout != nullptr) { + auto it = computation_layouts_.find(computation); + if (it == computation_layouts_.end()) { + VLOG(2) << " New ComputationLayout = " << computation_layout->ToString(); + computation_layouts_.emplace(computation, *computation_layout); + } else { + TF_RET_CHECK(computation_layout == &it->second || + computation_layout == entry_computation_layout_); + VLOG(2) << " Existing ComputationLayout = " + << computation_layout->ToString(); + } + } else { + VLOG(2) << " No ComputationLayout specified (will be calculated)"; + } // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(points_to_analysis, computation); @@ -1536,12 +1599,19 @@ Status LayoutAssignment::RunOnComputation( CHECK_LT(constraints.unconstrained_buffer_ids().size(), unconstrained_count); } - // All logical buffers should have constraints at this point. All that // remains is assign the constraints to the buffers and infer layouts for // aliased buffers.
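// A minimal standalone sketch (not XLA code; ToyComputationLayout,
// ToyConstraints and AddMandatoryConstraints are invented names) of the
// pattern this change introduces: parameter and result layouts may be left
// unset, and a mandatory constraint is only added for the layouts the caller
// actually pinned down, leaving the rest to be filled in by propagation.
#include <cassert>
#include <map>
#include <optional>
#include <vector>

using Layout = std::vector<int>;  // minor-to-major dimension order

struct ToyComputationLayout {
  std::vector<std::optional<Layout>> parameter_layouts;
  std::optional<Layout> result_layout;
};

struct ToyConstraints {
  std::map<int, Layout> parameter_constraints;  // parameter number -> layout
  std::optional<Layout> result_constraint;
};

// Adds mandatory constraints only for the layouts the caller pinned down; a
// nullptr layout means nothing is mandatory and everything is left to flow.
void AddMandatoryConstraints(const ToyComputationLayout* layout,
                             ToyConstraints* constraints) {
  if (layout == nullptr) return;
  for (int i = 0; i < static_cast<int>(layout->parameter_layouts.size()); ++i) {
    if (layout->parameter_layouts[i].has_value()) {
      constraints->parameter_constraints[i] = *layout->parameter_layouts[i];
    }
  }
  if (layout->result_layout.has_value()) {
    constraints->result_constraint = *layout->result_layout;
  }
}

int main() {
  ToyComputationLayout layout;
  layout.parameter_layouts = {Layout{1, 0}, std::nullopt};  // parameter 1 free
  layout.result_layout = std::nullopt;                      // result free

  ToyConstraints constraints;
  AddMandatoryConstraints(&layout, &constraints);
  assert(constraints.parameter_constraints.count(0) == 1);
  assert(constraints.parameter_constraints.count(1) == 0);
  assert(!constraints.result_constraint.has_value());
  return 0;
}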
TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation)); + // If the computation layout wasn't specified, now it is the time to compute + // it according to the parameters and root instruction layouts. + // This allows the first pass through this API to record the best flowing + // layout to parameters and root instruction. + if (computation_layout == nullptr) { + TF_RETURN_IF_ERROR(CalculateComputationLayout(computation)); + } + // Record the layouts assigned for any communication ops in // channel_constraints so that they are constrained for future modules. for (HloInstruction* instruction : computation->instructions()) { @@ -1556,6 +1626,34 @@ Status LayoutAssignment::RunOnComputation( return Status::OK(); } +Status LayoutAssignment::PropagateComputationLayouts( + HloComputation* computation, ComputationLayout* computation_layout) { + ComputationLayout computed_computation_layout( + computation->ComputeProgramShape(), + /*ignore_layouts=*/false); + for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) { + ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i); + if (!param_layout->LayoutIsSet()) { + VLOG(4) << "Assigning layout to parameter " << i << " of computation " + << computation->name() << ": " + << computed_computation_layout.parameter_layout(i).ToString(); + *param_layout = computed_computation_layout.parameter_layout(i); + } else { + TF_RET_CHECK(computed_computation_layout.parameter_layout(i) == + *param_layout); + } + } + ShapeLayout* result_layout = computation_layout->mutable_result_layout(); + if (!result_layout->LayoutIsSet()) { + VLOG(4) << "Assigning result layout of computation " << computation->name() + << ": " << computed_computation_layout.result_layout().ToString(); + *result_layout = computed_computation_layout.result_layout(); + } else { + TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout); + } + return Status::OK(); +} + StatusOr LayoutAssignment::Run(HloModule* module) { VLOG(2) << "Running layout assignment on module " << module->name(); XLA_VLOG_LINES(3, module->ToString()); @@ -1564,52 +1662,45 @@ StatusOr LayoutAssignment::Run(HloModule* module) { "before layout assignment", module->config().debug_options()); } + TF_RETURN_IF_ERROR(Init()); - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // Assign layouts to computations in an order such that a callee computation - // is handled before its caller computation. This ensures that the layout of - // all callers of a computation will agree. - std::list computation_post_order = - module->MakeComputationPostOrder(); - for (auto* computation : module->MakeComputationPostOrder()) { - if (computation->IsFusionComputation()) { - continue; - } - // Clear existing layouts of the instructions. All layouts must be assigned - // by the LayoutAssignment pass, except for those on infeeds, parameters, - // and the computation result. The latter two are specified in - // computation_layout, so we only need to keep the existing layouts for - // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidently use the existing layout. - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kBitcast) { - // bitcasts are inherently layout sensitive and so a bitcast instruction - // present in the IR before layout assignment is a bug. 
- return InternalError( - "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString().c_str()); + // We do two passes. In the first one we pass a nullptr ComputationLayout to + // the RunOnComputation() calls (for non-entry computations), and we register + // the ComputationLayouts which naturally flow in DFS fashion to the + // parameters and root instruction. + // Walking in DFS mode, though, means that we can end up with incorrect + // layouts when seen from an outer instruction, which has across-computation + // constraints to impose. + // For example, the kWhile instruction needs to enforce the same layouts for + // the parameters and root of the body, as well as the condition parameters. + // Similarly, the kConditional instruction needs to enforce the same layouts + // for the roots of the true and false computations. + // So in the first pass, while allowing the layouts to flow to parameters and + // root, we also fix up any inconsistent ComputationLayouts, which will then + // be made mandatory by the second pass. + for (int64 i = 0; i < 2; ++i) { + TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module)); + TF_ASSIGN_OR_RETURN(auto points_to_analysis, + TuplePointsToAnalysis::Run(module)); + for (auto* computation : module->MakeComputationPostOrder()) { + if (computation->IsFusionComputation()) { + continue; } - if (instruction->opcode() != HloOpcode::kInfeed) { - LayoutUtil::ClearLayout(instruction->mutable_shape()); + if (computation == module->entry_computation()) { + TF_RETURN_IF_ERROR(RunOnComputation( + entry_computation_layout_, *points_to_analysis, + module->entry_computation(), channel_layout_constraints_)); + } else { + ComputationLayout* computation_layout = + (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation); + TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, + *points_to_analysis, computation, + channel_layout_constraints_)); } } - if (computation == module->entry_computation()) { - TF_RETURN_IF_ERROR(RunOnComputation( - *entry_computation_layout_, *points_to_analysis, - module->entry_computation(), channel_layout_constraints_)); - } else { - ComputationLayout computation_layout(computation->ComputeProgramShape()); - // Setting all embedded computations to the default layout is potentially - // suboptimal. - computation_layout.SetToDefaultLayout(); - TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, - *points_to_analysis, computation, - channel_layout_constraints_)); - } } - + TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(), + entry_computation_layout_)); TF_RETURN_IF_ERROR(CheckLayouts(module)); VLOG(3) << "After layout assignment:"; @@ -1619,9 +1710,54 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) { "after layout assignment", module->config().debug_options()); } - // All layouts are reset then reassigned by this pass. return true; } +Status LayoutAssignment::Init() { + computation_layouts_.clear(); + return Status::OK(); +} + +Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) { + // Clear all the copies which have been added, and all the related + // instructions (like GTE and tuples).
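// A minimal standalone sketch (not the HLO API; ToyGraph and Node are
// invented names) of the bookkeeping that RegisterAddedCopy() and
// ClearPreviousPassSideEffects() implement: copies inserted during one pass
// are remembered in a set, and before the next pass each remembered copy is
// bypassed by rewiring its users to the copy's operand. A real pass would
// then delete the dead copy and run tuple simplification and DCE, as the
// patch does with TupleSimplifier and HloDCE.
#include <cassert>
#include <memory>
#include <set>
#include <vector>

// Single-operand toy node standing in for an HLO instruction.
struct Node {
  int id = 0;
  Node* operand = nullptr;
  std::vector<Node*> users;
};

struct ToyGraph {
  std::vector<std::unique_ptr<Node>> nodes;
  std::set<Node*> added_copies;  // copies registered by the current pass

  Node* AddNode(int id, Node* operand) {
    nodes.push_back(std::make_unique<Node>());
    Node* node = nodes.back().get();
    node->id = id;
    node->operand = operand;
    if (operand != nullptr) operand->users.push_back(node);
    return node;
  }

  // Remember a copy so it can be undone before the next pass.
  Node* AddRegisteredCopy(Node* operand) {
    Node* copy = AddNode(/*id=*/1000 + operand->id, operand);
    added_copies.insert(copy);
    return copy;
  }

  // Rewire every user of a registered copy to the copy's operand, then forget
  // the copy. (Deleting the now-dead copy node is left to a cleanup step.)
  void ClearPreviousPassSideEffects() {
    for (Node* copy : added_copies) {
      for (Node* user : copy->users) {
        user->operand = copy->operand;
        copy->operand->users.push_back(user);
      }
      copy->users.clear();
    }
    added_copies.clear();
  }
};

int main() {
  ToyGraph graph;
  Node* a = graph.AddNode(0, nullptr);
  Node* copy = graph.AddRegisteredCopy(a);
  Node* b = graph.AddNode(1, copy);  // b initially reads 'a' through the copy
  graph.ClearPreviousPassSideEffects();
  assert(b->operand == a);           // b now reads 'a' directly
  return 0;
}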
+ int64 removed_copies = 0; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kCopy && + added_copies_.count(instruction) > 0) { + VLOG(5) << "Removing added copy: " << instruction->ToString(); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWith(instruction->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction)); + ++removed_copies; + } + } + } + added_copies_.clear(); + if (removed_copies > 0) { + TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + return Status::OK(); +} + +Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction, + int64 operand_number) { + HloInstruction* operand = instruction->mutable_operand(operand_number); + if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) { + HloInstruction* copy = + instruction->parent()->AddInstruction(HloInstruction::CreateUnary( + operand->shape(), HloOpcode::kCopy, operand)); + SetupCopiedInstruction(*operand, copy, {}); + LayoutUtil::ClearLayout(copy->mutable_shape()); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy)); + } + return Status::OK(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index ae4986d6ad9..c287cca0c54 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -281,8 +282,8 @@ class LayoutAssignment : public HloPassInterface { // the case that no particular layout is requested. // // channel_constraints is both an input and output. Any sends or recvs that - // are present in channel_constraints will be layed out as constrained. Any - // unconstrained sends or recvs will be layed out as locally optimal and their + // are present in channel_constraints will be laid out as constrained. Any + // unconstrained sends or recvs will be laid out as locally optimal and their // layout will be added as a constraint to channel_constraints. // // If channel_constraints is nullptr, no kSend or kRecvs must be contained @@ -362,12 +363,15 @@ class LayoutAssignment : public HloPassInterface { int64 operand_no); private: + // Initializes the layout assignment object for a new Run() call. + Status Init(); + // Adds constraints which must be satisfied for correctness on all // backends. Called once prior to propagating constraints. - Status AddMandatoryConstraints( - const ComputationLayout& computation_layout, - const ChannelLayoutConstraints* channel_constraints, - HloComputation* computation, LayoutConstraints* constraints); + Status AddMandatoryConstraints(const ComputationLayout* computation_layout, + ChannelLayoutConstraints* channel_constraints, + HloComputation* computation, + LayoutConstraints* constraints); // This method can be overridden to add backend-specific constraints to the // layout of the instructions of a computation. 
This method is called after @@ -378,10 +382,12 @@ class LayoutAssignment : public HloPassInterface { } // Construct contraints and assign layouts to all instructions in the - // computation satisfying the given ComputationLayout. Layouts constraints are - added, then propagated until all LogicalBuffers in the computation are - constrained. - Status RunOnComputation(const ComputationLayout& computation_layout, + // computation satisfying the given ComputationLayout, if not nullptr. + // Otherwise the ComputationLayout will be calculated by propagating the + // computation instruction constraints. + // Layout constraints are added, then propagated until all LogicalBuffers in + // the computation are constrained. + Status RunOnComputation(ComputationLayout* computation_layout, const TuplePointsToAnalysis& points_to_analysis, HloComputation* computation, ChannelLayoutConstraints* channel_constraints); @@ -402,6 +408,25 @@ class LayoutAssignment : public HloPassInterface { // necessary conditions. Status CheckLayouts(HloModule* module); + // Computes the ComputationLayout of the given computation based on the + // layouts assigned to parameters and root instruction, and inserts it into + // the computation_layouts_ map. + Status CalculateComputationLayout(HloComputation* computation); + + // Clears all the layouts which can be cleared within a computation. + Status ClearComputationLayouts(HloComputation* computation); + + // Clears the side effects of a previous pass, like added copy instructions. + Status ClearPreviousPassSideEffects(HloModule* module); + + // Propagates the layouts computed by the layout assignment pass on the given + // computation, to the computation layout passed in to this API. + // This API propagates missing layouts, and also checks that the layouts the + // caller specified have been respected, by comparing them with the layouts + // of the computation's parameters and root instruction. + Status PropagateComputationLayouts(HloComputation* computation, + ComputationLayout* computation_layout); + ComputationLayout* entry_computation_layout_; protected: @@ -418,21 +443,37 @@ class LayoutAssignment : public HloPassInterface { // Creates and returns a copy of the given instruction with a different // layout. Tuple-shaped instructions will be deep-copied, and the last Tuple // instruction producing the copy is returned. - static StatusOr<HloInstruction*> CreateCopyWithNewLayout( + StatusOr<HloInstruction*> CreateCopyWithNewLayout( const Shape& shape_with_layout, HloInstruction* instruction); // Creates a copy of the given operand if the operand's layout does not match // the given layout. This copy replaces the use in the given instruction. // Tuple operands will be deep-copied. - static Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, - HloInstruction* instruction, - int64 operand_no); + Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no); + + // Registers a copy instruction added by the layout assignment pass. + void RegisterAddedCopy(HloInstruction* copy) { + CHECK_EQ(copy->opcode(), HloOpcode::kCopy); + added_copies_.insert(copy); + } + + // Adds a copy for the operand of an instruction, unless the operand is + // already a copy and has a single user (which is necessarily the instruction + // itself). + Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number); // Map containing the layouts of all computations assigned so // far.
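// A minimal standalone sketch (not the XLA API; PropagateLayouts is an
// invented name) of what PropagateComputationLayouts() is documented to do:
// entries the caller left unset are filled in from the layouts the pass
// computed, while entries the caller did set are checked against them.
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

using Layout = std::vector<int>;  // minor-to-major dimension order

// `computed` holds the layouts the pass actually assigned; `requested` is the
// caller-provided layout where some entries may still be unset.
void PropagateLayouts(const std::vector<Layout>& computed,
                      std::vector<std::optional<Layout>>* requested) {
  assert(computed.size() == requested->size());
  for (std::size_t i = 0; i < computed.size(); ++i) {
    std::optional<Layout>& slot = (*requested)[i];
    if (!slot.has_value()) {
      slot = computed[i];            // fill in a layout the caller left open
    } else {
      assert(*slot == computed[i]);  // a pinned layout must have been honored
    }
  }
}

int main() {
  std::vector<Layout> computed = {{1, 0}, {0, 1}};
  std::vector<std::optional<Layout>> requested = {std::nullopt, Layout{0, 1}};
  PropagateLayouts(computed, &requested);
  assert(requested[0].has_value() && *requested[0] == Layout({1, 0}));
  return 0;
}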
Computations are handled in a topological sort where computations are // handled before their caller instructions so the layouts of caller // instructions can be set to match the computation. std::map computation_layouts_; + + // Every copy added to the module by the layout assignment pass is registered + // here. + tensorflow::gtl::FlatSet added_copies_; + ChannelLayoutConstraints* channel_layout_constraints_; }; diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 4b1c9bad41d..7508013199a 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -660,13 +660,12 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - EXPECT_EQ( - ::tensorflow::Status::OK(), - backend() - .compiler() - ->RunBackend(std::move(module), backend().default_stream_executor(), - /*device_allocator=*/nullptr) - .status()); + EXPECT_EQ(Status::OK(), backend() + .compiler() + ->RunBackend(std::move(module), + backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .status()); } // A GTE inside of a fusion node inherits the layout of its operand (which diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc deleted file mode 100644 index 68c99256a24..00000000000 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ /dev/null @@ -1,379 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/liveness_util.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" - -namespace xla { - -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { - // GetTupleElement instructions only access the top-level buffer of their - // operand. - return true; - } else if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. 
- auto it = std::find_if( - user->fused_parameters().begin(), user->fused_parameters().end(), - [=](HloInstruction* fused_param) { - return user->operand(fused_param->parameter_number()) == operand; - }); - CHECK(it != user->fused_parameters().end()); - // Iterate through all users of all buffer aliases of the buffer in the - // points-to set of fusion parameter at 'index'. - // Return false if any uses are detected at 'index', returns true otherwise. - const LogicalBuffer* buffer = - points_to_analysis.GetBufferDefinedAt(*it, index).ValueOrDie(); - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. - return false; - } - } - // Return true: found no uses of 'operand' at 'index' in 'user'. - return true; - } - return false; -} - -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const HloDataflowAnalysis& dataflow) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - // Iterate through all users of all uses of the fusion parameter value. - // Return false if any uses are detected, returns true otherwise. - const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index); - return value.uses().empty(); - } else { - // Return false if no value at 'operand' and 'index' is used at 'user'. - for (const HloValue* value : - dataflow.GetValueSet(operand, index).values()) { - for (const HloUse& use : value->uses()) { - if (use.instruction == user) { - return false; - } - } - } - } - - return true; -} - -namespace { - -// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. -// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) -// where 'user' is a user of an alias of 'instruction' at 'index', and -// 'operand_index' is the operand index at which the alias appears in the -// operand list of 'user'. -std::vector> GetAllUsesOfInstructionAtIndex( - HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis) { - std::vector> uses; - const PointsToSet::BufferList& points_to = - points_to_analysis.GetPointsToSet(instruction).element(index); - for (const LogicalBuffer* buffer : points_to) { - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* alias_user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - alias_user, points_to_analysis)) { - continue; - } - for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { - uses.emplace_back(alias_user, op_idx); - } - } - } - } - return uses; -} - -// Returns true if there is exactly one use of 'operand' at 'operand_index' -// in 'fusion.fused_instructions', where the singleton use is the fused -// root at operand index 'use_operand_index'. Returns false otherwise. -// -// REQUIRES: 'fusion' opcode is a kFusion instruction. 
-bool HasUniqueFusedUseOfOperandAt( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* fusion, const int64 use_operand_index, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); - // Check that 'operand' is unique in the operand list of 'fusion'. - if (fusion->OperandIndices(operand).size() > 1) { - return false; - } - // Find fusion parameter associated with 'operand'. - const auto& fused_params = fusion->fused_parameters(); - auto fused_param_it = std::find_if( - fused_params.begin(), fused_params.end(), - [&](HloInstruction* fused_param) { - return fusion->operand(fused_param->parameter_number()) == operand; - }); - if (fused_param_it == fused_params.end()) { - return false; - } - auto* fused_param = *fused_param_it; - // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. - auto fused_param_uses = GetAllUsesOfInstructionAtIndex( - fused_param, operand_index, points_to_analysis); - // Return true iff there is exactly one use of 'operand' at 'index', and - // this singleton use is the fused root (at index in 'use_operand_indices'). - return fused_param_uses.size() == 1 && - fused_param_uses[0].first == fusion->fused_expression_root() && - fused_param_uses[0].second == use_operand_index; -} - -} // namespace - -// User and operand can share buffers iff both instructions emit the same shape -// and layout, and 'user' meets one of the following qualifications: -// -// (1) Is element-wise. Or... -// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' -// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root -// at operand 0. Or... -// (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion -// instruction where the only use of 'operand' at 'index' in the set -// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... -// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index -// 0. -// -// (2) and (3) can only be determined if points-to analysis is available. -bool CanShareOperandBufferWithUser( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - const Shape& operand_subshape = - ShapeUtil::GetSubshape(operand->shape(), operand_index); - const Shape& user_subshape = - ShapeUtil::GetSubshape(user->shape(), user_index); - // Check that operand and user emit the same shape and layout. - if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { - return false; - } - if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, - points_to_analysis); - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { - // Output fusion with kAdd fused root. - - // Check if one operand of kAdd fused root is either kDot, or nested - // kFusion of kind kTransposeDot. 
- auto* add = user->fused_expression_root(); - auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot || - (operand->opcode() == HloOpcode::kFusion && - operand->fusion_kind() == - HloInstruction::FusionKind::kTransposeDot); - }); - if (add_operand_it == add->operands().end()) { - return false; - } - auto* matched_add_operand = *add_operand_it; - // Calculate operand index of 'add' operand which was not matched above. - const int64 other_add_operand_index = - matched_add_operand == add->operand(0) ? 1 : 0; - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root (at operand - // index 'other_add_operand_index'). - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, - other_add_operand_index, - points_to_analysis); - } - } - if (user->opcode() == HloOpcode::kDynamicUpdateSlice || - user->opcode() == HloOpcode::kWhile) { - // We eliminated other users in BufferLiveness::live_range_strictly_before, - // so here we just need to check that the use is at operand index 0. - std::vector operand_indices = user->OperandIndices(operand); - return operand_indices.size() == 1 && operand_indices[0] == 0; - } - if (user->opcode() == HloOpcode::kCall) { - // TODO(b/62548313): Remove when buffer assignment is module scoped and - // does not assign buffers to calls. - // Find called computation parameter associated with 'operand'. - const std::vector operand_indices = user->OperandIndices(operand); - if (operand_indices.size() > 1) { - return false; - } - CHECK_EQ(1, operand_indices.size()); - auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); - // Get all uses of 'operand' at 'index' in called computation. - auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index, - points_to_analysis); - - // Return true iff: - // *) There exists exactly one use of 'operand' in called computation. - // *) The unique use is by the root instruction of called computation. - // (Note: we check the root of the called computation, because the - // root result buffer is required to alias with the Call result buffer). - // *) The root instruction of the called computation is element-wise on - // 'operand'. - auto* callee_root = user->to_apply()->root_instruction(); - return param_uses.size() == 1 && param_uses[0].first == callee_root && - callee_root->IsElementwiseOnOperand(param_uses[0].second); - } - // Check if 'user' is element-wise. - return user->IsElementwise(); -} - -bool CanShareOperandBufferWithUser(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* user, - const ShapeIndex& user_index, - const HloDataflowAnalysis& dataflow) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - const Shape& operand_subshape = - ShapeUtil::GetSubshape(operand->shape(), operand_index); - const Shape& user_subshape = - ShapeUtil::GetSubshape(user->shape(), user_index); - // Check that operand and user emit the same shape and layout. 
- if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { - return false; - } - - if (user->opcode() == HloOpcode::kFusion) { - // Get the parameter associated with 'operand'; - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - - const HloValue& value = - dataflow.GetValueDefinedAt(fusion_param, operand_index); - if (value.uses().size() != 1) { - return false; - } - const HloUse& use = value.uses()[0]; - - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; - } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && - user->fused_expression_root()->opcode() == HloOpcode::kAdd) { - // Output fusion with kAdd fused root. - - // Check if one operand of kAdd fused root is either kDot, or nested - // kFusion of kind kTransposeDot. - auto* add = user->fused_expression_root(); - auto add_operand_it = - std::find_if(add->operands().begin(), add->operands().end(), - [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kConvolution || - operand->opcode() == HloOpcode::kDot || - (operand->opcode() == HloOpcode::kFusion && - operand->fusion_kind() == - HloInstruction::FusionKind::kTransposeDot); - }); - if (add_operand_it == add->operands().end()) { - return false; - } - auto* matched_add_operand = *add_operand_it; - // Calculate operand index of 'add' operand which was not matched above. - const int64 other_add_operand_index = - matched_add_operand == add->operand(0) ? 1 : 0; - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root (at operand - // index 'other_add_operand_index'). - return use.instruction == user->fused_expression_root() && - use.operand_number == other_add_operand_index; - } - } - if (user->opcode() == HloOpcode::kDynamicUpdateSlice || - user->opcode() == HloOpcode::kWhile) { - // We eliminated other users in BufferLiveness::live_range_strictly_before, - // so here we just need to check that the use is at operand index 0. - std::vector operand_indices = user->OperandIndices(operand); - return operand_indices.size() == 1 && operand_indices[0] == 0; - } - if (user->opcode() == HloOpcode::kCall) { - // Get all uses of value defined by 'operand' at 'operand_index'. - const auto& uses = - dataflow.GetValueDefinedAt(operand, operand_index).uses(); - // Return true iff: - // *) There exists two uses of 'operand'. - // *) One use is by 'user' (caller). - // *) One use is by root instruction of called computation (callee root). - // (Note: we check the root of the called computation, because the - // root result buffer is required to alias with the Call result buffer). - // *) The root instruction of the called computation is element-wise on - // 'operand'. 
- const bool found_caller_use = - std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { - return use.instruction == user; - }) != uses.end(); - auto* callee_root = user->to_apply()->root_instruction(); - const bool found_elementwise_callee_use = - std::find_if( - uses.begin(), uses.end(), [callee_root](const HloUse& use) { - return use.instruction == callee_root && - callee_root->IsElementwiseOnOperand(use.operand_number); - }) != uses.end(); - return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; - } - // Check if 'user' is element-wise. - return user->IsElementwise(); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h deleted file mode 100644 index 28ef9918800..00000000000 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// A collection of utilities on the HLO graph. - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ - -#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" - -namespace xla { - -// Returns true if 'user' cannot possibly use the buffer at 'index' in -// 'operand'. Returns false otherwise. -// -// REQUIRES: 'operand' is an operand of 'user'. -// -// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have -// moved over to the dataflow overload. -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const TuplePointsToAnalysis& points_to_analysis); -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user, - const HloDataflowAnalysis& dataflow); - -// Returns true if 'user' (at 'user_index') can share a buffer with its operand -// 'operand' (at 'operand_index'). Returns false otherwise. -// -// REQUIRES: 'operand' is an operand of 'user'. -// -// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have -// moved over to the dataflow overload. 
-bool CanShareOperandBufferWithUser( - HloInstruction* operand, const ShapeIndex& operand_index, - HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis); -bool CanShareOperandBufferWithUser(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* user, - const ShapeIndex& user_index, - const HloDataflowAnalysis& dataflow); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc deleted file mode 100644 index f8b309488ee..00000000000 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ /dev/null @@ -1,463 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/liveness_util.h" - -#include - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" - -namespace xla { -namespace { - -class PointsToAnalysisTestBase : public HloTestBase { - protected: - void BuildModule(std::unique_ptr computation) { - module_ = CreateNewModule(); - computation_ = module_->AddEntryComputation(std::move(computation)); - } - - void RunAnalysis() { - CHECK_NOTNULL(module_.get()); - points_to_analysis_ = - TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); - } - - void BuildModuleAndRunAnalysis(std::unique_ptr computation) { - BuildModule(std::move(computation)); - RunAnalysis(); - } - - std::unique_ptr module_; - HloComputation* computation_ = nullptr; - std::unique_ptr points_to_analysis_; - std::unique_ptr dataflow_analysis_; -}; - -class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; - -TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { - auto builder = HloComputation::Builder(TestName()); - - Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); - builder.AddInstruction( - HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); - - BuildModuleAndRunAnalysis(builder.Build()); - - // GetTupleElement instructions only access the top-level buffer of their - // operand. 
- EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *points_to_analysis_)); - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_)); - - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *dataflow_analysis_)); - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *dataflow_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *dataflow_analysis_)); - EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *dataflow_analysis_)); -} - -TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); - - // Create a DynamicUpdateSlice instruction of tuple element 1. - auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); - auto dynamic_update_slice = - builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); - builder.AddInstruction( - HloInstruction::CreateTuple({gte0, dynamic_update_slice})); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {dynamic_update_slice, starts, update, gte1}, - HloInstruction::FusionKind::kLoop); - RunAnalysis(); - - // The fusion instruction never uses tuple element 0, but does use element 1. 
- EXPECT_TRUE( - DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_)); - EXPECT_FALSE( - DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_)); - - EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, fusion, *dataflow_analysis_)); - EXPECT_FALSE( - DoesNotUseOperandBuffer(tuple, {1}, fusion, *dataflow_analysis_)); -} - -class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; - -TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { - auto builder = HloComputation::Builder(TestName()); - - Shape shape = ShapeUtil::MakeShape(F32, {8}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); - auto log = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { - auto builder = HloComputation::Builder(TestName()); - - Shape in_shape = ShapeUtil::MakeShape(F32, {8}); - Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); - auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, in_shape, "param0")); - auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, in_shape, "param1")); - auto result = builder.AddInstruction( - HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - *points_to_analysis_)); - EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - *dataflow_analysis_)); - EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { - auto builder = HloComputation::Builder(TestName()); - - Shape shape = ShapeUtil::MakeShape(F32, {8}); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto exp = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); - auto copy = builder.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); - - BuildModuleAndRunAnalysis(builder.Build()); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); - auto gte0 = builder.AddInstruction( - 
HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); - - // Create a DynamicUpdateSlice instruction of tuple element 1. - auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); - auto update = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR1({2.f, 2.f, 2.f}))); - auto dynamic_update_slice = - builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, gte1, update, starts)); - builder.AddInstruction( - HloInstruction::CreateTuple({gte0, dynamic_update_slice})); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {dynamic_update_slice, starts, update, gte1}, - HloInstruction::FusionKind::kLoop); - RunAnalysis(); - - // The fusion instruction can share with tuple element 1. - EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *points_to_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *dataflow_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { - auto builder = HloComputation::Builder(TestName()); - - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - Shape update_shape = ShapeUtil::MakeShape(F32, {4}); - Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - auto update = builder.AddInstruction( - HloInstruction::CreateParameter(1, update_shape, "update")); - auto starts = builder.AddInstruction( - HloInstruction::CreateParameter(2, starts_shape, "starts")); - auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - data_shape, data, update, starts)); - - BuildModuleAndRunAnalysis(builder.Build()); - - // The DynamicUpdateSlice instruction can share with the data operand, but not - // with update or starts. 
- EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *dataflow_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *dataflow_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); - auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape, HloOpcode::kAdd, dot, add_operand)); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {add, dot}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused dot add should be able to share buffer with 'add_operand'. - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto a = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); - auto b = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - auto b_t = builder.AddInstruction( - HloInstruction::CreateTranspose(data_shape, b, {1, 0})); - - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto add_operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto add = builder.AddInstruction(HloInstruction::CreateBinary( - data_shape, HloOpcode::kAdd, dot, add_operand)); - - BuildModule(builder.Build()); - - auto nested_fusion = computation_->CreateFusionInstruction( - {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); - - auto fusion = computation_->CreateFusionInstruction( - {add, nested_fusion}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused transpose-dot-add should be share buffer with 'add_operand'. 
- EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { - auto builder = HloComputation::Builder(TestName()); - Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); - - auto one = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto operand = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape, one, {1})); - - auto reverse = builder.AddInstruction( - HloInstruction::CreateReverse(data_shape, operand, {0, 1})); - - auto two = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); - - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); - - BuildModule(builder.Build()); - auto fusion = computation_->CreateFusionInstruction( - {add, two, reverse}, HloInstruction::FusionKind::kOutput); - RunAnalysis(); - - // Output fused operand->reverse->add cannot alias operand buffer 'operand'. - EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *points_to_analysis_)); - - EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *dataflow_analysis_)); -} - -TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { - Shape data_shape = ShapeUtil::MakeShape(F32, {8}); - - auto make_cond = [this, &data_shape]() { - auto builder = HloComputation::Builder(TestName() + ".Cond"); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); - return builder.Build(); - }; - - auto make_body = [this, &data_shape]() { - auto builder = HloComputation::Builder(TestName() + ".Body"); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); - return builder.Build(); - }; - - module_ = CreateNewModule(); - HloComputation* cond_computation = - module_->AddEmbeddedComputation(make_cond()); - HloComputation* body_computation = - module_->AddEmbeddedComputation(make_body()); - - auto builder = HloComputation::Builder(TestName()); - auto data = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape, "data")); - auto whil = builder.AddInstruction(HloInstruction::CreateWhile( - data_shape, cond_computation, body_computation, data)); - computation_ = module_->AddEntryComputation(builder.Build()); - - RunAnalysis(); - - // The While instruction can share with the data operand. - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); - - EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_)); -} - -// Tests that Call can alias operand buffer if the only use of the operand -// in the called computation is an elementwise instruction. -TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { - Shape shape = ShapeUtil::MakeShape(F32, {8}); - // Build sub-computation with fusion root. 
- auto sub_builder = HloComputation::Builder(TestName() + "_sub"); - auto sub_param = sub_builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "sub_param")); - auto one = sub_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1.0))); - auto ones = sub_builder.AddInstruction( - HloInstruction::CreateBroadcast(shape, one, {1})); - auto add = sub_builder.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); - - module_ = CreateNewModule(); - auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); - sub_computation->CreateFusionInstruction({add, ones}, - HloInstruction::FusionKind::kLoop); - - // Build entry-computation with kCall which calls 'sub_computation'. - auto builder = HloComputation::Builder(TestName()); - - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "param")); - auto reverse = - builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); - auto call = builder.AddInstruction( - HloInstruction::CreateCall(shape, {reverse}, sub_computation)); - computation_ = module_->AddEntryComputation(builder.Build()); - - RunAnalysis(); - - EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, - *points_to_analysis_)); - EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, - *dataflow_analysis_)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index 911b243fe28..b17c9d50450 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -23,7 +23,7 @@ limitations under the License. namespace xla { StatusOr>> LLVMCompiler::Compile( std::vector> modules, - std::vector> stream_execs, + std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) { // Tensorflow tries to enable the following behaviors in all its threads: // diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index d74e81bb7f6..f1c623508c5 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -60,19 +60,18 @@ class LLVMCompiler : public Compiler { // Bring in // StatusOr> RunBackend( // std::unique_ptr module, - // perftools::gputools::StreamExecutor* stream_exec, + // se::StreamExecutor* stream_exec, // DeviceMemoryAllocator* device_allocator) // StatusOr> RunHloPasses( // std::unique_ptr module, - // perftools::gputools::StreamExecutor* stream_exec, + // se::StreamExecutor* stream_exec, // DeviceMemoryAllocator* device_allocator) using Compiler::RunBackend; using Compiler::RunHloPasses; StatusOr>> Compile( std::vector> modules, - std::vector> - stream_execs, + std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) override; protected: diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index bc683a1880b..f172b1d87c8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -151,7 +151,7 @@ Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { Status FusedIrEmitter::FinishVisit(HloInstruction* root) { fused_root_ = root; - return tensorflow::Status::OK(); + return Status::OK(); } FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const { diff --git 
a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 3312a888443..7323abeb207 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -333,18 +333,7 @@ llvm::Value* IrArray::EmitArrayElementAddress( } CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_)); - std::vector actual_index; - bool is_implicit_broadcast = false; - // We perform broadcasting when the operand shape has dimension(s) of size - // 1. In this case we fix the index value for that dimension to zero. This - // effectively broadcasts along this dimension. - for (int64 i = 0; i < index.size(); ++i) { - auto dim = shape_->dimensions(i); - actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); - is_implicit_broadcast |= dim == 1; - } - - if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) { + if (index.LinearValidOnShape(*shape_)) { llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); return ir_builder->CreateInBoundsGEP( @@ -354,6 +343,15 @@ llvm::Value* IrArray::EmitArrayElementAddress( {index.linear()}, llvm_ir::AsStringRef(name)); } + std::vector actual_index; + for (int64 i = 0; i < index.size(); ++i) { + // When dimension i is of size 1, LLVM optimization is able to replace + // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to + // produce better code in some cases. + auto dim = shape_->dimensions(i); + actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]); + } + // "base_ptr_" has the type of "*" // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element // should be computed by diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 06cfb2a36c5..4c3195c29c8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -97,6 +97,10 @@ class IrArray { llvm::Value*& operator[](size_t i) { return multidim()[i]; } void push_back(llvm::Value* value) { multidim().push_back(value); } + void InsertAt(int64 index, llvm::Value* value) { + CHECK_LE(index, size()); + multidim().insert(multidim().begin() + index, value); + } using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 7b227ce2941..497b48ff227 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -36,8 +36,8 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, bool prevent_unrolling, bool prevent_vectorization) - : prefix_(prefix.ToString()), - suffix_(suffix.ToString()), + : prefix_(std::string(prefix)), + suffix_(std::string(suffix)), start_index_(start_index), end_index_(end_index), step_(step), diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 20069ce5a28..d915f95db13 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -174,7 +174,7 @@ class ForLoopNest { : ForLoopNest(/*name=*/"", ir_builder) {} ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) - : name_(name.ToString()), + : name_(std::string(name)), 
outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 3978acc132f..0728ccfff7b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -39,14 +39,13 @@ LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* ir_builder) - : body_emitter_([=](const llvm_ir::IrArray::Index array_index) - -> ::tensorflow::Status { + : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status { // Convert target_element_generator to a BodyEmitter. TF_ASSIGN_OR_RETURN(llvm::Value * target_element, target_element_generator(array_index)); target_array.EmitWriteArrayElement(array_index, target_element, ir_builder); - return tensorflow::Status::OK(); + return Status::OK(); }), shape_(target_array.GetShape()), ir_builder_(ir_builder) {} @@ -124,7 +123,7 @@ std::vector LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { +Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { for (const IrArray::Index& array_index : EmitIndexAndSetExitBasicBlock(loop_name)) { TF_RETURN_IF_ERROR(body_emitter_(array_index)); @@ -135,7 +134,7 @@ tensorflow::Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name) { if (exit_bb_ != nullptr) { ir_builder_->SetInsertPoint(exit_bb_); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 9ff497aecd0..b70d28ecd30 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -38,8 +38,7 @@ using ElementGenerator = // Emits a loop for every element in the given shape. class LoopEmitter { public: - using BodyEmitter = - std::function; + using BodyEmitter = std::function; LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape, llvm::IRBuilder<>* ir_builder); @@ -72,7 +71,7 @@ class LoopEmitter { tensorflow::StringPiece loop_name); // Emits a complete loop nest for every element in the given shape. - tensorflow::Status EmitLoop(tensorflow::StringPiece loop_name = ""); + Status EmitLoop(tensorflow::StringPiece loop_name = ""); protected: // An IR emitter that generates the loop body. diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index 34899b74004..dacc54742c0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -49,22 +49,41 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( for (int64 i = 0; i < rank; ++i) { IrArray::Index dim_index({ir_builder->getInt64(i)}); TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); + llvm::Value* output_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), output_shape.dimensions(i)); + llvm::Value* update_dim_size = llvm::ConstantInt::get( + start_index[i]->getType(), update_shape.dimensions(i)); + + // Clamp the start index so that the update region fits in the operand. 
+ // start_index = clamp(start_index, 0, output_dim_size - update_dim_size) + + // TODO(b/74360564): This is implementation defined behavior, but is + // currently respected by all implementations. Change this if we ever decide + // to officially document different behavior. + llvm::Value* max_bound = + ir_builder->CreateSub(output_dim_size, update_dim_size); + llvm::Value* zero = llvm::ConstantInt::get(start_index[i]->getType(), 0); + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SGE, zero, start_index[i]), + zero, start_index[i]); + + start_index[i] = ir_builder->CreateSelect( + ir_builder->CreateICmp(llvm::ICmpInst::ICMP_SLE, max_bound, + start_index[i]), + max_bound, start_index[i]); } auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { // Calculate output_index, where we'll write the value from update. For // each dimension, // - // output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size. + // output_index[dim] = start_index[dim] + update_index[dim] // IrArray::Index output_index(rank); for (int64 i = 0; i < rank; ++i) { - llvm::Value* dim_size = llvm::ConstantInt::get( - update_index[i]->getType(), output_shape.dimensions(i)); - llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast( + llvm::Value* start_index0 = ir_builder->CreateSExtOrBitCast( start_index[i], update_index[i]->getType()); - output_index[i] = ir_builder->CreateURem( - ir_builder->CreateAdd(start_index0, update_index[i]), dim_size); + output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]); } // Do output[output_index] = update[update_index]. diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 499f280211a..0fa40617386 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -43,13 +43,11 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { /* static */ StatusOr> LocalService::NewService( const ServiceOptions& options) { - perftools::gputools::Platform* platform = options.platform(); + se::Platform* platform = options.platform(); if (platform == nullptr) { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); } diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index 68553bed121..c742d35a7bc 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -15,9 +15,6 @@ limitations under the License.
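To make the semantics of the ops.cc change above concrete: the in-place dynamic-update-slice emitter now clamps each start index into [0, output_dim_size - update_dim_size] so the whole update region fits inside the operand, and then writes at start plus the index within the update, instead of the old wrap-around (modulo) indexing. Below is a minimal standalone sketch of that rule using plain int64 arithmetic rather than LLVM IR values; the function and variable names are illustrative only and are not part of the patch.

#include <algorithm>
#include <cstdint>

// Clamp a dynamic-update-slice start index so the update region fits in the
// operand. Assumes update_dim_size <= output_dim_size, as the op requires.
int64_t ClampStartIndex(int64_t start, int64_t output_dim_size,
                        int64_t update_dim_size) {
  const int64_t max_bound = output_dim_size - update_dim_size;
  return std::min(std::max<int64_t>(start, 0), max_bound);
}

// With clamping there is no wrap-around: the write position is simply the
// clamped start plus the index within the update. For example, with
// output_dim_size = 8, update_dim_size = 3 and a requested start of 10, the
// start clamps to 5 and the update writes elements 5, 6, 7.
int64_t OutputIndex(int64_t clamped_start, int64_t update_index) {
  return clamped_start + update_index;
}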
#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include -#include - #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" @@ -28,43 +25,20 @@ namespace xla { LogicalBuffer::LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id) - : instruction_(instruction), id_(id), color_(kInvalidColor), index_(index) { - const auto& s = shape(); - is_array_ = ShapeUtil::IsArray(s); - is_tuple_ = ShapeUtil::IsTuple(s); -} + : BufferValue(instruction, index, id), + instruction_(instruction), + index_(index) {} + +LogicalBuffer::~LogicalBuffer() {} string LogicalBuffer::ToString() const { + string color_string; + if (has_color()) { + color_string = tensorflow::strings::StrCat(" @", color().value()); + } return tensorflow::strings::StrCat(instruction_->name(), "[", tensorflow::str_util::Join(index_, ","), - "](#", id_, " @", color_.value(), ")"); -} - -std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer) { - out << buffer.ToString(); - return out; -} - -/*static*/ LogicalBufferProto::Location LogicalBuffer::ToLocationProto( - const HloInstruction& instruction, const ShapeIndex& index) { - LogicalBufferProto::Location proto; - proto.set_computation_name(instruction.parent()->name()); - proto.set_instruction_name(instruction.name()); - for (const int64 index_entry : index) { - proto.add_shape_index(index_entry); - } - return proto; -} - -LogicalBufferProto LogicalBuffer::ToProto(const SizeFunction& size_fn) const { - LogicalBufferProto proto; - proto.set_id(id_); - proto.set_size(size_fn(*this)); - LogicalBufferProto::Location proto_location = - ToLocationProto(*instruction_, index_); - proto.mutable_defined_at()->Swap(&proto_location); - proto.set_color(color_.value()); - return proto; + "](#", id(), color_string, ")"); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h index 67b205e289e..f9ba5a55474 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.h +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -16,11 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ -#include -#include #include -#include +#include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -33,133 +31,30 @@ limitations under the License. namespace xla { -// Class describing a contiguous sequence of elements (ie, C array) which form -// the components of Shaped values in XLA. XLA arrays are trivially a -// single LogicalBuffer. Tuple values are made up of more than one -// LogicalBuffer: a LogicalBuffer for the pointers to elements, and a -// LogicalBuffer for each child element. -// -// Every buffer is defined by a particular instruction and most instructions -// define only a single buffer. Instructions which define a single buffer -// include array-shaped instructions such as Add but also includes Tuple-shaped -// instructions such as Tuple. The Tuple instruction defines a single buffer -// which is a vector of pointers to the buffers containing the Tuple -// instruction's operands. 
Though the result of the Tuple instruction includes -// multiple buffers only the top-level buffer (the vector of pointers) is -// defined by the Tuple instruction. The buffers containing the tuple elements -// are defined by earlier instructions, usually the operands of the Tuple -// instruction. -// -// Instructions which construct both the tuple *and* the tuple elements define -// more than one buffer. This includes (at least) tuple-shaped Constant, -// Parameter, Infeed and While instructions. The tuple-shaped instructions do -// not assemble a tuple from existing buffers like the Tuple instruction does, -// but rather define the entire tuple. -// -// Some instructions, such as Bitcast, define no buffers. These instructions -// simply forward buffers from their operands. -// -// The LogicalBuffer object describes which HLO instruction defines a buffer and -// where within that instruction's output shape the buffer is defined. The -// location within the output shape is indicated by LogicalBuffer::index() which -// is defined identically to the index used in -// ShapeUtil::GetSubshape(). Examples: -// -// %add = Add(%foo, %bar) -// %tuple_constant = Constant({1, {42, 43}}) -// -// %add defines a single array-shaped buffer LogicalBuffer(%add, {}) which holds -// the array result of the add operation. The nested-tuple-shaped -// %tuple_constant defines 5 buffers described by the following LogicalBuffer -// objects: -// -// LogicalBuffer(%tuple_constant, {}) // "Top-level" buffer: vector of -// // pointers to LogicalBuffers at -// // indices {0} and {1} -// LogicalBuffer(%tuple_constant, {0}) // Holds value "1" -// LogicalBuffer(%tuple_constant, {1}) // Holds nested tuple: vector of -// // pointers to LogicalBuffers at -// // indices {1, 0} and {1, 1} -// LogicalBuffer(%tuple_constant, {1, 0}) // Holds value "42" -// LogicalBuffer(%tuple_constant, {1, 1}) // Holds value "43" -class LogicalBuffer { +// TuplePointsToAnalysis uses this subclass of BufferValue. +class LogicalBuffer : public BufferValue { public: - TF_LIB_GTL_DEFINE_INT_TYPE(Color, int64); - - // Id is a unique identifier for the LogicalBuffer to facilitate efficient - // collections of LogicalBuffers with stable iteration order. - // LogicalBuffers are typically created and accessed through - // TuplePointsToAnalysis, and points-to analysis assigns each LogicalBuffer a - // unique value. - using Id = int64; - - // Functions which return the size and alignment of a logical buffer in bytes. - using SizeFunction = std::function; - using AlignmentFunction = std::function; - LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id); - - Id id() const { return id_; } + ~LogicalBuffer() override; // Return the instruction that defines the buffer. - HloInstruction* instruction() const { return instruction_; } + HloInstruction* instruction() const override { return instruction_; } // Return the index within the output of the instruction where the buffer is // defined. Index used defined as in ShapeUtil::GetSubshape() - const ShapeIndex& index() const { return index_; } - - // Return the color of the logical buffer. Differently colored buffers can - // not be parts of the same allocation. 
- Color color() const { - CHECK_NE(color_, kInvalidColor) - << "Should not query the color of a buffer that was never colored"; - return color_; - } - - void set_color(Color color) { - CHECK_NE(color, kInvalidColor) - << "Should not set the color of a buffer to the invalid color"; - color_ = color; - } - - bool has_color() const { return color_ != kInvalidColor; } + const ShapeIndex& index() const override { return index_; } // Return the shape of the buffer. This reference points into the shape field // of the instruction defining the buffer. Therefore, the returned shape will // contain the layout of instruction, if any. - const Shape& shape() const { + const Shape& shape() const override { return ShapeUtil::GetSubshape(instruction_->shape(), index_); } - // Returns true if this buffer is the top-level output buffer of the defining - // HLO instruction. This is equivalent to index == {}. - bool IsTopLevel() const { return index_.empty(); } - - // Whether this buffer contains a tuple. - bool IsTuple() const { return is_tuple_; } - - // Whether this buffer contains an array. - bool IsArray() const { return is_array_; } - - // operator< is required for std::set. - bool operator<(const LogicalBuffer& other) const { return id_ < other.id_; } - - string ToString() const; - LogicalBufferProto ToProto(const SizeFunction& size_fn) const; - - // Returns the LogicalBufferProto::Location that serializes the given - // instruction and index. - static LogicalBufferProto::Location ToLocationProto( - const HloInstruction& instruction, const ShapeIndex& index); - - const Color kInvalidColor = Color(-1); + string ToString() const override; private: HloInstruction* instruction_; - Id id_ : 62; - bool is_array_ : 1; - bool is_tuple_ : 1; - Color color_; ShapeIndex index_; // Similar to HLO constructs (HloInstruction, etc), pointers are used for @@ -167,8 +62,6 @@ class LogicalBuffer { TF_DISALLOW_COPY_AND_ASSIGN(LogicalBuffer); }; -std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_ diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 7d8c05fffa4..3a6a7c25f4b 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -53,17 +53,18 @@ NameUniquer::NameUniquer(const string& separator) { } string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { - string root = prefix.empty() ? "name" : prefix.ToString(); - root = GetSanitizedName(root); + string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix)); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. + bool has_numeric_suffix = false; + int64 numeric_suffix = 0; size_t separator_index = root.rfind(separator_); if (separator_index != string::npos && (separator_index > 0) && (separator_index < root.size() - 1)) { string after_suffix = root.substr(separator_index + 1); - int64 numeric_suffix; if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + has_numeric_suffix = true; // Remove numeric suffix from root. 
root = root.substr(0, separator_index); // Update count to at least the numeric suffix value to avoid future @@ -71,11 +72,11 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { generated_names_[root] = std::max(generated_names_[root], numeric_suffix); } } - int64* count = &(generated_names_[root]); if (*count == 0) { *count = 1; - return root; + return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0) + : root; } else { tensorflow::strings::StrAppend(&root, separator_, *count); // Increment lookup under old 'root' name. diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 4258cf16876..2ec255558c4 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -57,11 +57,18 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("foo.55", uniquer.GetUniqueName("foo")); EXPECT_EQ("foo.55.1", uniquer.GetUniqueName("foo.55.1")); EXPECT_EQ("foo.55.2", uniquer.GetUniqueName("foo.55.1")); - EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000")); + EXPECT_EQ("bar.0", uniquer.GetUniqueName("bar.-1000")); EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); } +TEST_F(NameUniquerTest, PrefixHasSuffix) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo.11.0", uniquer.GetUniqueName("foo.11.0")); + EXPECT_EQ("foo.11", uniquer.GetUniqueName("foo.11")); +} + TEST_F(NameUniquerTest, Sanitize) { NameUniquer uniquer("_"); @@ -73,7 +80,7 @@ TEST_F(NameUniquerTest, Sanitize) { EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); // Invalid characters will be replaced with '_'. - EXPECT_EQ("bar", uniquer.GetUniqueName("bar<-1000")); + EXPECT_EQ("bar_0", uniquer.GetUniqueName("bar<-1000")); EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); diff --git a/tensorflow/compiler/xla/service/owning_device_memory.cc b/tensorflow/compiler/xla/service/owning_device_memory.cc new file mode 100644 index 00000000000..c115bc097f3 --- /dev/null +++ b/tensorflow/compiler/xla/service/owning_device_memory.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/owning_device_memory.h" + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" + +namespace xla { + +void OwningDeviceMemory::Free() { + CHECK(allocator_ != nullptr) + << "Can't call Free() on an inactive (i.e. 
moved from, Forget()'ten, " "or Free()'ed) instance."; + auto status = allocator_->Deallocate(device_ordinal_, mem_); + if (!status.ok()) { + LOG(WARNING) << "Deallocating buffer " << mem_.opaque() << " failed."; + } + + allocator_ = nullptr; + mem_ = se::DeviceMemoryBase(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/owning_device_memory.h b/tensorflow/compiler/xla/service/owning_device_memory.h new file mode 100644 index 00000000000..9cf071f0d9d --- /dev/null +++ b/tensorflow/compiler/xla/service/owning_device_memory.h @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// Break circular dependency between this file and device_memory_allocator.h. +class DeviceMemoryAllocator; + +// Owning pointer for memory on a device. +// +// OwningDeviceMemory is an owning pointer like std::unique_ptr, but it can +// point to memory that resides on a "device" (e.g. a GPU). When an +// OwningDeviceMemory goes out of scope, it frees the memory it owns. +// +// We say that an instance of OwningDeviceMemory is "active" if it currently +// owns a (possibly empty) slice of memory on the device. Moving, Forget()'ing, +// Free()'ing, and other actions can deactivate an active object. +// +// Note that we can't simply use stream_executor::ScopedDeviceMemory instead of +// OwningDeviceMemory, because ScopedDeviceMemory frees its pointer via a +// StreamExecutor. This class needs to free via a xla::DeviceMemoryAllocator. +class OwningDeviceMemory { + public: + OwningDeviceMemory() : device_ordinal_(-1), allocator_(nullptr) {} + + explicit OwningDeviceMemory(se::DeviceMemoryBase mem, int device_ordinal, + DeviceMemoryAllocator* allocator) + : mem_(mem), device_ordinal_(device_ordinal), allocator_(allocator) { + CHECK(allocator != nullptr) << "allocator cannot be null."; + } + + OwningDeviceMemory(OwningDeviceMemory&& other) + : mem_(other.mem_), + device_ordinal_(other.device_ordinal_), + allocator_(other.allocator_) { + other.mem_ = se::DeviceMemoryBase(); + other.allocator_ = nullptr; + } + + OwningDeviceMemory& operator=(OwningDeviceMemory&& other) { + if (allocator_ != nullptr) { + Free(); + } + mem_ = other.mem_; + device_ordinal_ = other.device_ordinal_; + allocator_ = other.allocator_; + + other.mem_ = se::DeviceMemoryBase(); + other.allocator_ = nullptr; + return *this; + } + + // Deactivates this instance if it's active. Nop if it's not active.
+ OwningDeviceMemory& operator=(std::nullptr_t) { + if (allocator_ != nullptr) { + Free(); + } + return *this; + } + + ~OwningDeviceMemory() { + if (allocator_ != nullptr) { + Free(); + } + } + + // The returned allocator is nonnull iff this object is active. + DeviceMemoryAllocator* allocator() const { return allocator_; } + + int device_ordinal() const { return device_ordinal_; } + + // Gets the device memory pointer. + const void* opaque() const { return mem_.opaque(); } + void* opaque() { return mem_.opaque(); } + + uint64 size() const { return mem_.size(); } + + // Determines whether this wraps a null pointer. + // + // !is_null() is sufficient but not necessary to imply `this` is active. + bool is_null() const { return mem_.is_null(); } + + se::DeviceMemoryBase AsDeviceMemoryBase() { + return se::DeviceMemoryBase(opaque(), size(), /*is_sub_buffer=*/false); + } + + // Returns the wrapped DeviceMemoryBase without freeing it, and deactivates + // this object. Precondition: `this` is active. + TF_MUST_USE_RESULT se::DeviceMemoryBase Forget() { + CHECK(allocator_ != nullptr) + << "Can't call Forget() on an inactive (i.e. moved from, Forget()'ten, " + "or Free()'ed) instance."; + allocator_ = nullptr; + se::DeviceMemoryBase mem(mem_); + mem_ = se::DeviceMemoryBase(); + return mem; + } + + // Frees the wrapped DeviceMemoryBase and deactivates this object. + // Precondition: `this` is active. + void Free(); + + private: + se::DeviceMemoryBase mem_; + int device_ordinal_; + DeviceMemoryAllocator* allocator_; // Null if this object is inactive. +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_OWNING_DEVICE_MEMORY_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 5d496380772..d3bc47e61e0 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -532,7 +532,7 @@ class ShapePattern { ShapeType, ShapePatternLayoutImpl>> - IsDenseArray(const ::xla::Layout* layout) const { + IsDenseArray() const { return WithLayout(Layout().WithDenseFormat()); } @@ -540,7 +540,7 @@ class ShapePattern { ShapeType, ShapePatternLayoutImpl>> - IsSparseArray(const ::xla::Layout* layout) const { + IsSparseArray() const { return WithLayout(Layout().WithSparseFormat()); } @@ -702,6 +702,30 @@ class HloInstructionPatternOperandImpl { HloInstructionPattern operand_; }; +// An HloInstructionPattern implementation that matches only if the instruction +// is a fusion node with a particular kind. +template +class HloInstructionPatternFusionKindImpl { + public: + explicit constexpr HloInstructionPatternFusionKindImpl( + const Previous& previous, ::xla::HloInstruction::FusionKind kind) + : previous_(previous), kind_(kind) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && + inst->fusion_kind() == kind_; + } + + private: + Previous previous_; + ::xla::HloInstruction::FusionKind kind_; +}; + // A pattern that matches HloInstructions. template class HloInstructionPattern { @@ -807,6 +831,16 @@ class HloInstructionPattern { matched_inst_); } + // Modifies the pattern to match only if the instruction is a fusion node with + // the given kind. 
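Because the OwningDeviceMemory header introduced above describes its "active"/"inactive" states only in prose, a brief usage sketch may help. It is not part of the patch; the helper function is hypothetical, and it assumes the caller already holds a DeviceMemoryAllocator and an se::DeviceMemoryBase obtained from it.

#include <utility>
#include "tensorflow/compiler/xla/service/owning_device_memory.h"

namespace xla {
// Hypothetical illustration of the ownership transitions described in the
// header comment: moving or calling Forget() deactivates an instance, and
// only an active instance frees its buffer when destroyed.
void OwnershipSketch(DeviceMemoryAllocator* allocator,
                     se::DeviceMemoryBase mem) {
  OwningDeviceMemory owned(mem, /*device_ordinal=*/0, allocator);  // active
  OwningDeviceMemory taken = std::move(owned);  // `owned` is now inactive
  // Forget() hands back the raw buffer and deactivates `taken`; the caller is
  // then responsible for deallocating `raw` through the allocator.
  se::DeviceMemoryBase raw = taken.Forget();
  (void)raw;
}  // Nothing is freed here, since neither wrapper is active any more.
}  // namespace xla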
+ constexpr HloInstructionPattern> + WithFusionKind(HloInstruction::FusionKind kind) const { + return HloInstructionPattern>( + HloInstructionPatternFusionKindImpl(impl_, kind), matched_inst_); + } + private: Impl impl_; HloInstructionType** matched_inst_; @@ -879,7 +913,6 @@ XLA_UNOP_PATTERN(Abs) XLA_UNOP_PATTERN(RoundNearestAfz) XLA_UNOP_PATTERN(Bitcast) XLA_UNOP_PATTERN(Broadcast) -XLA_UNOP_PATTERN(BroadcastDimOne) XLA_UNOP_PATTERN(Ceil) XLA_UNOP_PATTERN(Copy) XLA_UNOP_PATTERN(Cos) diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 5291b1437af..204e8c99209 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -67,6 +67,7 @@ TEST(PatternMatcherTest, ScalarShape) { EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar())); EXPECT_EQ(matched_shape, &scalar_shape); EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray())); + EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsDenseArray())); EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple())); EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32))); EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0))); @@ -75,11 +76,13 @@ TEST(PatternMatcherTest, ScalarShape) { match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32))); } -TEST(PatternMatcherTest, ArrayShape) { +TEST(PatternMatcherTest, DenseArrayShape) { auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4}); Shape* matched_shape; EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray())); EXPECT_EQ(matched_shape, &array_shape); + EXPECT_TRUE(Match(&array_shape, match::Shape().IsDenseArray())); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsSparseArray())); EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar())); EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple())); EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32))); @@ -90,6 +93,33 @@ TEST(PatternMatcherTest, ArrayShape) { EXPECT_FALSE(Match(&array_shape, match::Shape().WithLayout( match::Layout(&matched_layout).WithSparseFormat()))); + EXPECT_TRUE(Match(&array_shape, + match::Shape().WithLayout( + match::Layout(&matched_layout).WithDenseFormat()))); + EXPECT_EQ(matched_layout, &array_shape.layout()); +} + +TEST(PatternMatcherTest, SparseArrayShape) { + auto array_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {2, 3, 4}, 10); + Shape* matched_shape; + EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray())); + EXPECT_EQ(matched_shape, &array_shape); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsDenseArray())); + EXPECT_TRUE(Match(&array_shape, match::Shape().IsSparseArray())); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar())); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple())); + EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32))); + EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3))); + EXPECT_FALSE( + Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape()))); + Layout* matched_layout; + EXPECT_FALSE(Match(&array_shape, + match::Shape().WithLayout( + match::Layout(&matched_layout).WithDenseFormat()))); + EXPECT_TRUE(Match(&array_shape, + match::Shape().WithLayout( + match::Layout(&matched_layout).WithSparseFormat()))); + EXPECT_EQ(matched_layout, &array_shape.layout()); } TEST(PatternMatcherTest, TupleShape) { @@ -140,5 +170,28 @@ TEST(PatternMatcherTest, TupleShape) { 
Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape()))); } +TEST(PatternMatcherTest, FusionKind) { + constexpr char kModuleStr[] = R"( + HloModule test_module + + fused_computation { + ROOT fp0 = f32[] parameter(0) + } + + ENTRY while.v11 { + p0 = f32[] parameter(0) + ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_TRUE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kLoop))); + EXPECT_FALSE(Match( + root, match::Op().WithFusionKind(HloInstruction::FusionKind::kInput))); + EXPECT_FALSE(Match(root->operand(0), match::Op().WithFusionKind( + HloInstruction::FusionKind::kLoop))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index aa974ee61a2..7c63c0acc77 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -29,8 +29,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -namespace se = ::perftools::gputools; - namespace xla { using tensorflow::str_util::Lowercase; diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index 69188820a70..571451ba43a 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -34,29 +34,27 @@ class PlatformUtil { // // Note that, even if a platform is present with zero devices, if we *do* have // compilation support for it, it will be returned in this sequence. - static StatusOr> - GetSupportedPlatforms(); + static StatusOr> GetSupportedPlatforms(); // Convenience function which returns the default supported platform for // tests. If exactly one supported platform is present, then this platform is // the default platform. If exactly two platforms are present and one of them // is the interpreter platform, then the other platform is the default // platform. Otherwise returns an error. - static StatusOr GetDefaultPlatform(); + static StatusOr GetDefaultPlatform(); // Convenience function which returns the sole supported platform. If // exactly one supported platform is present, then this platform is the // default platform. Otherwise returns an error. - static StatusOr GetSolePlatform(); + static StatusOr GetSolePlatform(); // Returns the platform according to the given name. Returns error if there is // no such platform. - static StatusOr GetPlatform( - const string& platform_name); + static StatusOr GetPlatform(const string& platform_name); // Returns exactly one platform that does not have given name. Returns error // if there is no such platform, or there are multiple such platforms. - static StatusOr GetPlatformExceptFor( + static StatusOr GetPlatformExceptFor( const string& platform_name); // Returns a vector of StreamExecutors for the given platform. The vector is @@ -64,8 +62,8 @@ class PlatformUtil { // element is nullptr, then the device is present by not supported by XLA. // // If the platform has no visible devices, a not-found error is returned. 
- static StatusOr> - GetStreamExecutors(perftools::gputools::Platform* platform); + static StatusOr> GetStreamExecutors( + se::Platform* platform); private: TF_DISALLOW_COPY_AND_ASSIGN(PlatformUtil); diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index e2c07e38271..688cceff0cd 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -75,7 +75,7 @@ StatusOr ReducePrecisionInsertion::insert_after( return false; } - // Check that we haven't already inserted an equivalant reduce-precision + // Check that we haven't already inserted an equivalent reduce-precision // operation after this instruction. (The zero-user case occurs when this is // the root instruction.) if (instruction->user_count() > 0) { diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 49ec38eb62c..0f26a025bf1 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -155,15 +155,20 @@ HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand, case HloOpcode::kConstant: { if (first_reshape_operand->opcode() == HloOpcode::kReshape) { VLOG(5) << "Adding reshape to kConstant operand"; - return computation->AddInstruction( + HloInstruction* reshape = computation->AddInstruction( HloInstruction::CreateReshape(new_shape, operand)); + operand->SetupDerivedInstruction(reshape); + return reshape; } else { CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose); VLOG(5) << "Adding transpose to kConstant operand"; std::vector inverse_permutation = InversePermutation(first_reshape_operand->dimensions()); - return computation->AddInstruction(HloInstruction::CreateTranspose( - new_shape, operand, inverse_permutation)); + HloInstruction* transpose = + computation->AddInstruction(HloInstruction::CreateTranspose( + new_shape, operand, inverse_permutation)); + operand->SetupDerivedInstruction(transpose); + return transpose; } } case HloOpcode::kRng: { diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 094f7319f46..13e2d3258e3 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -458,57 +458,6 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { EXPECT_EQ(select, computation->root_instruction()); } -// Tree looks like: -// -// param0 [1,128,1] -// | -// reshape [128,1] constant [128,1024] -// \ / -// multiply w/implicit broadcast [128,1024] -// -// The reshape mover would like to sink the reshape below the multiply. -// -// Previously we would attempt to insert a reshape of the constant to [1,128,1] -// (which is unsound, because it has a different number of elements) as -// preparation for sinking the reshape. -// -// To eliminate the unsoundness, we outlaw reshape sinking when one of the -// operands is implicitly broadcast in the elementwise consumer. -// -// TODO(b/37799338) However, it would be possible in this case to do a more -// in-depth analysis to get reshape movement to occur: -// -// 1. Note that the broadcast dimension (logical dimension 1) in the operands -// would map back to logical dimension 2 in the param0 node. -// 2. Match rank of the constant to the param0 node (by prepending a trivial 1 -// dimension). -// 3. 
Reshape to [128,1024] at the root. -// -// But this is not currently done. -TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { - HloComputation::Builder builder(TestName()); - auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0")); - auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {128, 1}), param0)); - Array2D a(128, 1024); - auto literal = Literal::CreateR2FromArray2D(a); - auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(std::move(literal))); - auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( - constant->shape(), HloOpcode::kMultiply, constant, reshape)); - - auto computation = module().AddEntryComputation(builder.Build()); - EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Constant(), op::Reshape(param0))); - - EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); - - EXPECT_THAT(computation->root_instruction(), - op::Multiply(op::Constant(), op::Reshape(param0))); - EXPECT_EQ(multiply, computation->root_instruction()); -} - // Tree looks like this: // // add1 diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 52500e4e790..cb0f76ebe4d 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -54,8 +54,6 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrCat; using ::xla::source_map_util::InvalidParameterArgument; @@ -66,7 +64,7 @@ namespace { // Records the arguments used to invoke a computation in a SessionModule // proto. -tensorflow::Status RecordArguments( +Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, se::StreamExecutor* executor, TransferManager* transfer_manager, SessionModule* module) { @@ -77,33 +75,54 @@ tensorflow::Status RecordArguments( transfer_manager->TransferLiteralFromDevice(executor, *argument)); *module->add_arguments() = literal->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } // Records the result of a computation in a SessionModule proto. -tensorflow::Status RecordResult(const ShapedBuffer& result, - se::StreamExecutor* executor, - TransferManager* transfer_manager, - SessionModule* module) { +Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, + TransferManager* transfer_manager, SessionModule* module) { module->clear_result(); TF_ASSIGN_OR_RETURN( std::unique_ptr literal, transfer_manager->TransferLiteralFromDevice(executor, result)); *module->mutable_result() = literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); +} + +// Records the arguments used to invoke a computation in an HloSnapshot proto. +Status RecordArguments( + const tensorflow::gtl::ArraySlice arguments, + se::StreamExecutor* executor, TransferManager* transfer_manager, + HloSnapshot* module) { + module->clear_arguments(); + for (const ShapedBuffer* argument : arguments) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr literal, + transfer_manager->TransferLiteralFromDevice(executor, *argument)); + *module->add_arguments() = literal->ToProto(); + } + return Status::OK(); +} + +// Records the result of a computation in a HloSnapshot proto. 
+Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, + TransferManager* transfer_manager, HloSnapshot* module) { + module->clear_result(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr literal, + transfer_manager->TransferLiteralFromDevice(executor, result)); + *module->mutable_result() = literal->ToProto(); + return Status::OK(); } } // namespace -ServiceOptions& ServiceOptions::set_platform( - perftools::gputools::Platform* platform) { +ServiceOptions& ServiceOptions::set_platform(se::Platform* platform) { platform_ = platform; return *this; } -perftools::gputools::Platform* ServiceOptions::platform() const { - return platform_; -} +se::Platform* ServiceOptions::platform() const { return platform_; } ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) { number_of_replicas_ = number_of_replicas; @@ -123,7 +142,7 @@ int ServiceOptions::intra_op_parallelism_threads() const { } /* static */ StatusOr> Service::NewService( - perftools::gputools::Platform* platform) { + se::Platform* platform) { ServiceOptions default_options; default_options.set_platform(platform); return NewService(default_options); @@ -131,7 +150,7 @@ int ServiceOptions::intra_op_parallelism_threads() const { /* static */ StatusOr> Service::NewService( const ServiceOptions& options) { - perftools::gputools::Platform* platform = options.platform(); + se::Platform* platform = options.platform(); std::unique_ptr execute_backend; if (platform == nullptr) { TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); @@ -176,8 +195,8 @@ Service::Service(const ServiceOptions& options, } } -tensorflow::Status Service::Computation(const ComputationRequest* arg, - ComputationResponse* result) { +Status Service::Computation(const ComputationRequest* arg, + ComputationResponse* result) { if (arg->name().empty()) { return InvalidArgument("computation request needs a name"); } @@ -187,24 +206,23 @@ tensorflow::Status Service::Computation(const ComputationRequest* arg, VLOG(1) << Printf("Created new computation %s on service %p, name %s", result->computation().ShortDebugString().c_str(), this, arg->name().c_str()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) { +Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) { *result->mutable_channel() = channel_tracker_.NewChannel(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) { +Status Service::Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) { return allocation_tracker_.Unregister(arg->data()); } // Deconstructs a previously-allocated global handle. 
-tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) { +Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) { TF_ASSIGN_OR_RETURN( std::vector elements, allocation_tracker_.DeconstructTuple(arg->tuple_handle())); @@ -212,11 +230,11 @@ tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg, for (auto& element : elements) { *result->add_element_handles() = element; } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const { +Status Service::ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const { if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) { return InvalidArgument( "Shape used to set computation result layout %s is not compatible " @@ -235,8 +253,7 @@ tensorflow::Status Service::ValidateResultShapeWithLayout( StatusOr>> Service::ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice - stream_executors) { + tensorflow::gtl::ArraySlice stream_executors) { CHECK_EQ(options_.number_of_replicas(), stream_executors.size()); std::vector> replicated_arguments; replicated_arguments.resize(options_.number_of_replicas()); @@ -274,8 +291,10 @@ StatusOr> Service::CreateModuleConfig( const ExecutionOptions* execution_options, const UserComputation* user_computation) { auto config = MakeUnique(program_shape); - auto* computation_layout = config->mutable_entry_computation_layout(); - + ComputationLayout* host_computation_layout = + config->mutable_host_entry_computation_layout(); + ComputationLayout* device_computation_layout = + config->mutable_device_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { return InvalidArgument("computation takes %d parameters, but %zu given", program_shape.parameters_size(), @@ -300,9 +319,10 @@ StatusOr> Service::CreateModuleConfig( i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } - TF_RETURN_IF_ERROR( - computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( - *argument_shapes[i])); + TF_RETURN_IF_ERROR(host_computation_layout->mutable_parameter_layout(i) + ->CopyLayoutFromShape(*argument_shapes[i])); + TF_RETURN_IF_ERROR(device_computation_layout->mutable_parameter_layout(i) + ->CopyLayoutFromShape(*argument_shapes[i])); } if (execution_options != nullptr && execution_options->has_shape_with_output_layout()) { @@ -311,10 +331,20 @@ StatusOr> Service::CreateModuleConfig( TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(shape_with_output_layout, program_shape.result())); TF_RETURN_IF_ERROR( - computation_layout->mutable_result_layout()->CopyLayoutFromShape( + host_computation_layout->mutable_result_layout()->CopyLayoutFromShape( + shape_with_output_layout)); + TF_RETURN_IF_ERROR( + device_computation_layout->mutable_result_layout()->CopyLayoutFromShape( shape_with_output_layout)); } else { - computation_layout->mutable_result_layout()->Clear(); + // If the result layout is not set, then choose the default. + // TODO(b/29118294): Allow the compiler to choose a better layout in this + // case. + // TODO(b/78356948): We are forcing the default layout here. 
We should fix + // clients which expect a default layout, to be explicit about it, by + // passing the proper ExecutionOptions with shape_with_output_layout set. + host_computation_layout->mutable_result_layout()->SetToDefaultLayout(); + device_computation_layout->mutable_result_layout()->SetToDefaultLayout(); } config->set_replica_count(options_.number_of_replicas()); @@ -349,8 +379,7 @@ StatusOr> Service::CreateModuleConfig( StatusOr>> Service::BuildExecutables( std::vector versioned_handles, std::vector> module_configs, - Backend* backend, - std::vector> executors, + Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator) { VLOG(1) << Printf("BuildExecutable on service %p", this); @@ -412,11 +441,32 @@ StatusOr>> Service::BuildExecutables( StatusOr>> Service::BuildExecutables( const std::vector& module_protos, std::vector> module_configs, - Backend* backend, - std::vector> executors, + Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator) { VLOG(1) << Printf("BuildExecutable on service %p", this); + // Dump computation proto state if flag is set. + std::vector> hlo_snapshots; + for (int64 i = 0; i < module_protos.size(); ++i) { + const string& directory_path = + module_configs[i]->debug_options().xla_dump_computations_to(); + const string& execution_directory_path = + module_configs[i]->debug_options().xla_dump_executions_to(); + if (directory_path.empty() && execution_directory_path.empty()) { + continue; + } + auto hlo_snapshot = MakeUnique(); + *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; + if (!directory_path.empty()) { + string filename = + Printf("computation_%lld__%s", module_protos[i]->id(), + module_protos[i]->entry_computation_name().c_str()); + TF_RETURN_IF_ERROR( + Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); + hlo_snapshots.push_back(std::move(hlo_snapshot)); + } + } + VLOG(1) << "Computations:"; for (const HloModuleProto* proto : module_protos) { VLOG(1) << proto->name(); @@ -437,9 +487,31 @@ StatusOr>> Service::BuildExecutables( backend->compiler()->Compile(std::move(modules), std::move(executors), device_allocator)); + for (size_t i = 0; i < module_protos.size(); ++i) { + if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { + executables[i]->set_hlo_snapshot(std::move(hlo_snapshots[i])); + } + } + return std::move(executables); } +Status Service::ValidateEntryComputationLayout(HloModule* module) { + const ComputationLayout& on_device = + module->device_entry_computation_layout(); + for (int64 i = 0; i < on_device.parameter_count(); ++i) { + TF_RET_CHECK(ShapeUtil::Equal( + on_device.parameter_shape(i), + execute_backend_->transfer_manager()->HostShapeToDeviceShape( + module->host_entry_computation_layout().parameter_shape(i)))); + } + TF_RET_CHECK(ShapeUtil::Equal( + module->device_entry_computation_layout().result_shape(), + execute_backend_->transfer_manager()->HostShapeToDeviceShape( + module->host_entry_computation_layout().result_shape()))); + return Status::OK(); +} + StatusOr> Service::BuildExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, Backend* backend, @@ -478,6 +550,8 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, device_allocator)); + // Check that on-host and on-device shapes are consistent. 
+ TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, backend->compiler()->RunBackend( @@ -493,7 +567,7 @@ StatusOr> Service::BuildExecutable( StatusOr> Service::BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, Backend* backend, - perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile, + se::StreamExecutor* executor, ExecutionProfile* profile, DeviceMemoryAllocator* device_allocator) { std::shared_ptr executable = compilation_cache_.LookUp(versioned_handle, *module_config); @@ -541,7 +615,7 @@ Service::ExecuteParallelAndRegisterResult( // Streams where the computation are launched, so we can wait on the streams // to complete. std::vector::SmartPtr> streams; - std::vector> timers; + std::vector> timers; // Global data handles for the computation results, one for each computation. std::vector result_handles; @@ -558,15 +632,14 @@ Service::ExecuteParallelAndRegisterResult( // Stream executors for the replicas of the current computation. TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); CHECK_EQ(replicas.size(), arguments[i].size()); - std::vector> result_buffers; + std::vector result_buffers; for (int64 replica = 0; replica < replicas.size(); ++replica) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, backend->BorrowStream(replicas[replica])); streams.push_back(std::move(stream)); if (replica == 0 && profile != nullptr) { - timers.emplace_back( - new perftools::gputools::Timer(streams.back()->parent())); + timers.emplace_back(new se::Timer(streams.back()->parent())); streams.back() ->InitTimer(timers.back().get()) .ThenStartTimer(timers.back().get()); @@ -583,7 +656,6 @@ Service::ExecuteParallelAndRegisterResult( ExecutableRunOptions options; options.set_stream(streams.back().get()); options.set_allocator(backend->memory_allocator()); - options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); @@ -591,7 +663,7 @@ Service::ExecuteParallelAndRegisterResult( backend->StreamBorrower()); // Asynchronously launch the computation. 
- TF_ASSIGN_OR_RETURN(std::unique_ptr result, + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, executables[i]->ExecuteAsyncOnStream( &run_options, arguments[i][replica])); @@ -697,12 +769,12 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_stream(stream.get()); options.set_device_ordinal(stream->parent()->device_ordinal()); options.set_allocator(backend->memory_allocator()); - options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); options.set_device_assignment(&device_assignment); - run_options.emplace_back(options, backend->StreamBorrower(), - backend->inter_op_thread_pool()); + run_options.emplace_back( + options, backend->StreamBorrower(), + /*xla_intra_op_thread_pool=*/backend->eigen_intra_op_thread_pool()); } if (options_.number_of_replicas() == 1) { @@ -727,16 +799,16 @@ StatusOr Service::ExecuteAndRegisterResult( result_tag); } -tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { +Status Service::SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); return computation->SetReturnValue(arg->operand()); } -StatusOr> -Service::GetExecutors(const ExecutionOptions& execution_options, - int64 requests_size, int64 request_index) const { +StatusOr> Service::GetExecutors( + const ExecutionOptions& execution_options, int64 requests_size, + int64 request_index) const { if (execution_options.device_handles().empty()) { return FailedPrecondition( "device handles must be given to execute parallel computations"); @@ -748,7 +820,7 @@ Service::GetExecutors(const ExecutionOptions& execution_options, "handles.", requests_size, request_index, execution_options.device_handles_size()); } - std::vector executors; + std::vector executors; for (const auto& device_handle : execution_options.device_handles()) { TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, device_handle)); @@ -775,12 +847,12 @@ StatusOr>> Service::GetArguments( return replicated_arguments; } -tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) { +Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) { VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); std::vector>> all_arguments; - std::vector> all_executors; + std::vector> all_executors; std::vector versioned_handles; std::vector> module_configs; std::vector computation_names; @@ -836,7 +908,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, CreateModuleConfig(*program_shape, replicated_arguments.front(), request.execution_options(), user_computation)); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " - << module_config->entry_computation_layout().ToString(); + << module_config->host_entry_computation_layout().ToString(); // Adds to the vectors to build and execute the computations after the loop. 
all_arguments.push_back(replicated_arguments); @@ -883,15 +955,15 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, } VLOG(1) << "successfully completed 'execute-parallel' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { +Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) { VLOG(1) << "running execute-graph-parallel request"; std::vector>> all_arguments; - std::vector> all_executors; + std::vector> all_executors; std::vector module_protos; std::vector> module_configs; std::vector computation_names; @@ -939,7 +1011,7 @@ tensorflow::Status Service::ExecuteGraphParallel( /*user_computation=*/nullptr)); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " - << module_config->entry_computation_layout().ToString(); + << module_config->host_entry_computation_layout().ToString(); // Adds to the vectors to build and execute the computations after the loop. all_arguments.push_back(replicated_arguments); @@ -984,11 +1056,11 @@ tensorflow::Status Service::ExecuteGraphParallel( } VLOG(1) << "successfully completed 'execute-graph-parallel' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) { +Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) { const int64 available_device_count = execute_backend_->device_count(); const int64 replica_count = options_.number_of_replicas(); if (replica_count <= 0) { @@ -1008,11 +1080,11 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, *result->add_device_handles() = device_handle; } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteOneToN(const ExecuteRequest* arg, + ExecuteResponse* result) { ExecuteParallelRequest parallel_arg; *parallel_arg.add_requests() = *arg; ExecuteParallelResponse parallel_result; @@ -1020,8 +1092,8 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteRequest* arg, return PickParallelResponse(parallel_result, result); } -tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { ExecuteGraphParallelRequest parallel_arg; *parallel_arg.add_requests() = *arg; ExecuteParallelResponse parallel_result; @@ -1029,7 +1101,7 @@ tensorflow::Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, return PickParallelResponse(parallel_result, result); } -tensorflow::Status Service::PickParallelResponse( +Status Service::PickParallelResponse( const ExecuteParallelResponse& parallel_result, ExecuteResponse* result) { // The "result device" selection is a bit hacky, but better than assuming it // is device 0. 
We have b/76035356 for restructuring the client API to clean @@ -1052,8 +1124,7 @@ tensorflow::Status Service::PickParallelResponse( return Status::OK(); } -tensorflow::Status Service::Execute(const ExecuteRequest* arg, - ExecuteResponse* result) { +Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { VLOG(1) << "running execute request: " << arg->ShortDebugString(); TF_ASSIGN_OR_RETURN(UserComputation * user_computation, @@ -1089,7 +1160,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, arg->execution_options(), user_computation)); VLOG(3) << "Execute created HloModuleConfig computation layout: " - << module_config->entry_computation_layout().ToString(); + << module_config->host_entry_computation_layout().ToString(); TF_ASSIGN_OR_RETURN( std::shared_ptr executable, @@ -1124,7 +1195,7 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, } VLOG(1) << "successfully completed 'execute' request"; - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr> Service::BuildExecutable( @@ -1135,6 +1206,22 @@ StatusOr> Service::BuildExecutable( "BuildExecutable on service %p with serialized module proto: %s", this, module_proto.name().c_str()); + // Dump computation proto state if flag is set. + auto hlo_snapshot = MakeUnique(); + const string& directory_path = + module_config->debug_options().xla_dump_computations_to(); + const string& execution_directory_path = + module_config->debug_options().xla_dump_executions_to(); + if (!directory_path.empty() || !execution_directory_path.empty()) { + *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; + if (!directory_path.empty()) { + string filename = Printf("computation_%lld__%s", module_proto.id(), + module_proto.entry_computation_name().c_str()); + TF_RETURN_IF_ERROR( + Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); + } + } + TF_ASSIGN_OR_RETURN(std::unique_ptr module, HloModule::CreateFromProto(module_proto, *module_config)); @@ -1143,6 +1230,8 @@ StatusOr> Service::BuildExecutable( TF_ASSIGN_OR_RETURN( module, backend->compiler()->RunHloPasses(std::move(module), executor, device_allocator)); + // Check that on-host and on-device shapes are consistent. 
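Illustrative sketch (not part of the patch) of how a client could exercise the snapshot-dumping path added above. It assumes the protobuf-generated setters for the DebugOptions fields this code reads (xla_dump_computations_to, xla_dump_executions_to) and that the client's ExecutionOptions carries a DebugOptions message; the directory is a placeholder.

xla::ExecutionOptions execution_options;
xla::DebugOptions* debug_options = execution_options.mutable_debug_options();
// Dump the HLO proto of every computation compiled under these options.
debug_options->set_xla_dump_computations_to("/tmp/xla_dump");
// Also record arguments/results and dump a full HloSnapshot per execution.
debug_options->set_xla_dump_executions_to("/tmp/xla_dump");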
+ TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, backend->compiler()->RunBackend( @@ -1151,8 +1240,8 @@ StatusOr> Service::BuildExecutable( return std::move(executable); } -tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) { +Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) { VLOG(1) << "running execute-graph request"; if (!arg->has_computation()) { @@ -1185,18 +1274,37 @@ tensorflow::Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, execute_backend_->default_stream_executor(), /*device_allocator=*/nullptr)); + if (executable->dumping_snapshot()) { + executable->hlo_snapshot()->set_execution_platform( + execute_backend_->platform()->Name()); + TF_RETURN_IF_ERROR(RecordArguments( + replicated_arguments.front(), + execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->hlo_snapshot())); + } + TF_ASSIGN_OR_RETURN( *result->mutable_output(), ExecuteAndRegisterResult( executable.get(), replicated_arguments, execute_backend_.get(), "result of " + arg->computation().name(), result->mutable_profile())); + if (executable->dumping_snapshot()) { + TF_ASSIGN_OR_RETURN( + const ShapedBuffer* result_buffer, + allocation_tracker_.ResolveForReplica(result->output(), 0)); + TF_RETURN_IF_ERROR(RecordResult( + *result_buffer, execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->hlo_snapshot())); + TF_RETURN_IF_ERROR(executable->DumpHloSnapshot()); + } + VLOG(1) << "successfully completed 'execute-graph' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { +Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) { VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); TF_ASSIGN_OR_RETURN(UserComputation * user_computation, @@ -1225,7 +1333,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, arg->execution_options(), user_computation)); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " - << module_config->entry_computation_layout().ToString(); + << module_config->host_entry_computation_layout().ToString(); ExecutionProfile profile; @@ -1243,20 +1351,19 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, streams.push_back(std::move(stream)); } - std::vector> result_buffers; + std::vector result_buffers; for (size_t i = 0; i < streams.size(); ++i) { const auto& stream = streams[i]; ExecutableRunOptions options; options.set_stream(stream.get()); options.set_allocator(execute_backend_->memory_allocator()); - options.set_inter_op_thread_pool(execute_backend_->inter_op_thread_pool()); options.set_intra_op_thread_pool( execute_backend_->eigen_intra_op_thread_pool_device()); ServiceExecutableRunOptions service_options( options, execute_backend_->StreamBorrower()); - TF_ASSIGN_OR_RETURN(std::unique_ptr this_result_buffer, + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer this_result_buffer, executable->ExecuteAsyncOnStream( &service_options, replicated_arguments[i])); @@ -1273,11 +1380,11 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, streams.clear(); VLOG(1) << "successfully completed 'execute-async' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status 
Service::WaitForExecution(const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) { +Status Service::WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) { TF_ASSIGN_OR_RETURN(const auto execution, execution_tracker_.Resolve(arg->execution())); @@ -1288,11 +1395,11 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution())); VLOG(1) << "successfully completed 'wait-for-execution' request"; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, - TransferToClientResponse* result) { +Status Service::TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); @@ -1322,7 +1429,7 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, *result->mutable_literal() = result_literal->Relayout(*return_shape)->ToProto(); } - return tensorflow::Status::OK(); + return Status::OK(); } namespace { @@ -1340,8 +1447,8 @@ std::unique_ptr CloneShapedBufferOnDevice( } // namespace -tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, - TransferToServerResponse* result) { +Status Service::TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(arg->literal())); const Shape& shape = literal->shape(); @@ -1356,16 +1463,16 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, } // Allocate memory in each replica and transfer the data to all replicas. 
- std::vector> replicated_buffers; + std::vector replicated_buffers; for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN( - std::unique_ptr shaped_buffer, - execute_backend_->transfer_manager()->AllocateShapedBuffer( + ScopedShapedBuffer shaped_buffer, + execute_backend_->transfer_manager()->AllocateScopedShapedBuffer( shape, execute_backend_->memory_allocator(), executor->device_ordinal())); TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, *literal, *shaped_buffer)); + executor, *literal, shaped_buffer)); replicated_buffers.emplace_back(std::move(shaped_buffer)); } TF_ASSIGN_OR_RETURN(*result->mutable_data(), @@ -1374,11 +1481,11 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, StrCat("TransferToServer literal of shape ", ShapeUtil::HumanString(shape)))); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) { +Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) { const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1407,9 +1514,8 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor, *literal); } -tensorflow::Status Service::TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) { +Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) { const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( @@ -1435,16 +1541,16 @@ tensorflow::Status Service::TransferFromOutfeed( execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( executor, arg->shape_with_layout(), &literal)); *result->mutable_literal() = literal.ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) { +Status Service::ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) { return execute_backend_->ResetDevices(); } -tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) { +Status Service::IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(arg->computation())); @@ -1460,11 +1566,11 @@ tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, user_computation->IsConstant(arg->operand(), arg->num_parameters())); result->set_is_constant(is_constant); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { +Status Service::ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(arg->computation())); @@ -1551,11 +1657,11 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, } *result->mutable_literal() = result_literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::ComputeConstantGraph( - const 
ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { +Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) { if (!arg->has_computation()) { return InvalidArgument("computations may not be empty"); } @@ -1593,20 +1699,18 @@ tensorflow::Status Service::ComputeConstantGraph( } *result->mutable_literal() = result_literal->ToProto(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) { +Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.ResolveForReplica(arg->data(), 0)); *result->mutable_shape() = buffer->on_host_shape(); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { +Status Service::GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); @@ -1616,21 +1720,21 @@ tensorflow::Status Service::GetComputationShape( TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape( versioned_handle.version)); *result->mutable_program_shape() = *program_shape; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { +Status Service::GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); TF_ASSIGN_OR_RETURN(*result->mutable_shape(), computation->GetShape(arg->operand())); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) { +Status Service::GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(arg->computation())); @@ -1656,10 +1760,10 @@ tensorflow::Status Service::GetComputationStats( stats.set_flop_count(analysis.flop_count()); stats.set_transcendental_count(analysis.transcendental_count()); *result->mutable_stats() = stats; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::GetComputationGraphStats( +Status Service::GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { if (!arg->has_computation()) { return InvalidArgument("Computations may not be empty."); @@ -1686,11 +1790,11 @@ tensorflow::Status Service::GetComputationGraphStats( stats.set_flop_count(analysis.flop_count()); stats.set_transcendental_count(analysis.transcendental_count()); *result->mutable_stats() = stats; - return tensorflow::Status::OK(); + return Status::OK(); } template -tensorflow::Status Service::AddInstruction( +Status Service::AddInstruction( const RequestT* arg, ResponseT* result, const std::function(UserComputation*)>& adder) { @@ -1698,10 +1802,10 @@ tensorflow::Status Service::AddInstruction( computation_tracker_.Resolve(arg->computation())); TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation)); - return tensorflow::Status::OK(); + return Status::OK(); } 
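A note on the pervasive tensorflow::Status -> Status rename in the hunks above and below: inside namespace xla the unqualified name already denotes the TensorFlow type, so the rename is purely cosmetic. A minimal sketch of the alias this relies on (assumed to live in XLA's status header):

namespace xla {
using ::tensorflow::Status;  // makes the unqualified spelling valid here
}  // namespace xla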
-tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { +Status Service::Op(const OpRequest* arg, OpResponse* result) { TF_ASSIGN_OR_RETURN(UserComputation * computation, computation_tracker_.Resolve(arg->computation())); StatusOr handle_status; @@ -1923,27 +2027,26 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { if (arg->has_sharding()) { TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding())); } - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { +Status Service::SnapshotComputation(const SnapshotComputationRequest* arg, + SnapshotComputationResponse* result) { TF_ASSIGN_OR_RETURN( std::unique_ptr module, computation_tracker_.SnapshotComputation(arg->computation())); result->set_allocated_module(module.release()); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status Service::LoadComputationSnapshot( +Status Service::LoadComputationSnapshot( const LoadComputationSnapshotRequest* arg, LoadComputationSnapshotResponse* result) { TF_ASSIGN_OR_RETURN(*result->mutable_computation(), computation_tracker_.LoadSessionModule(arg->module())); - return tensorflow::Status::OK(); + return Status::OK(); } DeviceHandle Service::SingleComputationDeviceHandle() const { @@ -1953,9 +2056,9 @@ DeviceHandle Service::SingleComputationDeviceHandle() const { return device_handle; } -StatusOr> Service::Replicas( +StatusOr> Service::Replicas( const Backend& backend, const DeviceHandle& device_handle) const { - std::vector replicas; + std::vector replicas; for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { // From the computation placer, find out the device ids of the replicas for // the given device handle. diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index e399f1ac190..81fbd419578 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -53,8 +53,8 @@ namespace xla { class ServiceOptions { public: // Set the platform backing the service, or nullptr for the default platform. - ServiceOptions& set_platform(perftools::gputools::Platform* platform); - perftools::gputools::Platform* platform() const; + ServiceOptions& set_platform(se::Platform* platform); + se::Platform* platform() const; // Set the number of replicas to use when compiling replicated // programs. @@ -66,7 +66,7 @@ class ServiceOptions { int intra_op_parallelism_threads() const; private: - perftools::gputools::Platform* platform_ = nullptr; + se::Platform* platform_ = nullptr; int number_of_replicas_ = 1; int intra_op_parallelism_threads_ = -1; }; @@ -79,61 +79,58 @@ class Service : public ServiceInterface { public: // Factory method for creating a new Service. static StatusOr> NewService( - perftools::gputools::Platform* platform = nullptr); + se::Platform* platform = nullptr); static StatusOr> NewService( const ServiceOptions& options); // Creates a new computation with the given name. // A unique ComputationHandle is returned. - tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; + Status Computation(const ComputationRequest* arg, + ComputationResponse* result) override; // Unregisters a previously-allocated global handle. // // If the handle given is not currently allocated, a NOT_FOUND status is // returned. 
- tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) override; + Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) override; // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each // element in the tuple. - tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, - DeconstructTupleResponse* result) override; + Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) override; // Modifies the provided computation so that subsequent executions // will compute the provided ComputationDataHandle, rather than the // last expression enqueued on that Computation. - tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; + Status SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) override; // Executes a computation with the provided global data passed as // immutable arguments. Returns global data output and execution timing. - tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) override; + Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; // Executes a computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. // // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) override; + Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. // // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, - ExecuteParallelResponse* result) override; + Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) override; // Requests one or more device handles from the target. // @@ -143,9 +140,8 @@ class Service : public ServiceInterface { // the first set of replicas, and the next R devices to the second set of // replicas, etc. Each returned device handle represents the device with the // replica id 0. - tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, - GetDeviceHandlesResponse* result) override; + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) override; // Asynchronously executes a computation with provided arguments. Invokes // the provided computation with the provided global data passed as @@ -154,38 +150,33 @@ class Service : public ServiceInterface { // (Note: The corresponding function in xla::Client was removed as part of // b/64116060, in an attempt to simplify our API. We're keeping this around // for now in case we want to expose this to clients in a different way.) 
- tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; + Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) override; // Waits until the specified execution is complete and returns the result. // Calling this API multiple times with the same execution handle returns the // method with an error since the execution handle is destroyed after the // first call. - tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, - WaitForExecutionResponse* result) override; + Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) override; // Requests that global data be transferred to the client in literal form. - tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, - TransferToClientResponse* result) override; + Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) override; // Transfers data from a literal provided by the client, into device memory. - tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, - TransferToServerResponse* result) override; + Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) override; // Transfers data from a literal provided by the client, into the Infeed // buffer of the device. - tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, - TransferToInfeedResponse* result) override; + Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) override; // Transfers data from the Outfeed othe device to the literal provided by the // client. - tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) override; + Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; // Resets devices, clearing all existing state on all the devices associated // with this service (including memory allocated on the devices). @@ -196,71 +187,65 @@ class Service : public ServiceInterface { // ResetDevice should be called before an Execution that expect the device to // be in the reset state. For example, if the prior Execution modifies device // state (e.g., architectural state) that the next Execution depends on. - tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) override; + Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) override; // Tests if an expression is a compile-time constant. - tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; + Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) override; // Computes the value of a constant expression. - tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) override; + Status ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) override; + Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) override; // Returns the shape (with layout) of an array associated with a given data // handle. 
- tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) override; + Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) override; // Returns the program shape of the computation associated with the given // handle. - tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; + Status GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* result) override; ///// // Computation-oriented methods. // Enqueues an Op on the computation. - tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override; + Status Op(const OpRequest* arg, OpResponse* result) override; // Retrieves the inferred shape for a value within a computation. - tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; + Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) override; // Retrieves the statistics of a computation. - tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; + Status GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) override; // Retrieves the statistics of a computation. // // TODO(b/74197823): This is a part of a NOT YET ready refactor. - tensorflow::Status GetComputationGraphStats( - const ComputationGraphStatsRequest* arg, - ComputationStatsResponse* result) override; + Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg, + ComputationStatsResponse* result) override; // Snapshots the current state of a computation handle into a serializable // protocol buffer form, so it can be loaded via // LoadComputationSnapshot. - tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; + Status SnapshotComputation(const SnapshotComputationRequest* arg, + SnapshotComputationResponse* result) override; // Loads a computation from a serialized protocol buffer created via // SnapshotComputation. - tensorflow::Status LoadComputationSnapshot( + Status LoadComputationSnapshot( const LoadComputationSnapshotRequest* arg, LoadComputationSnapshotResponse* result) override; // Creates a unique channel handle that can be used for Send/Recv // instructions. - tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) override; + Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) override; // Returns the ComputationTracker of the current service instance. // Only used in unit tests to access user computations from client. @@ -286,7 +271,7 @@ class Service : public ServiceInterface { ExecuteResponse* result); // Prepare the executors for executing parallel. - StatusOr> GetExecutors( + StatusOr> GetExecutors( const ExecutionOptions& execution_options, int64 requests_size, int64 request_index) const; @@ -295,6 +280,9 @@ class Service : public ServiceInterface { const ExecutionOptions& execution_options, tensorflow::gtl::ArraySlice arguments); + // Assert that host- and device-shapes are in a consistent state. 
+ Status ValidateEntryComputationLayout(HloModule* module); + protected: friend class LocalExecutable; @@ -310,8 +298,7 @@ class Service : public ServiceInterface { StatusOr>> ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice - stream_executors); + tensorflow::gtl::ArraySlice stream_executors); // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. @@ -329,7 +316,7 @@ class Service : public ServiceInterface { StatusOr> BuildExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, Backend* backend, - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator = nullptr); // Builds an Executable for the given HLO module proto. @@ -338,7 +325,7 @@ class Service : public ServiceInterface { StatusOr> BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator = nullptr); // Same as BuildExecutable() above, but builds a list of Executables for the @@ -346,14 +333,12 @@ class Service : public ServiceInterface { StatusOr>> BuildExecutables( std::vector versioned_handles, std::vector> module_configs, - Backend* backend, - std::vector> executors, + Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator); StatusOr>> BuildExecutables( const std::vector& module_protos, std::vector> module_configs, - Backend* backend, - std::vector> executors, + Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator); // Similar to BuildExecutable, but look in the compilation cache for the @@ -362,7 +347,7 @@ class Service : public ServiceInterface { StatusOr> BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, std::unique_ptr module_config, Backend* backend, - perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile, + se::StreamExecutor* executor, ExecutionProfile* profile, DeviceMemoryAllocator* device_allocator = nullptr); // Runs the given executable with the given arguments and register the result @@ -389,7 +374,7 @@ class Service : public ServiceInterface { // Convenience function for adding a function to a user computation. template - tensorflow::Status AddInstruction( + Status AddInstruction( const RequestT* arg, ResponseT* result, const std::function(UserComputation*)>& adder); @@ -397,21 +382,19 @@ class Service : public ServiceInterface { // Executes a single computation which has more than one target device. // The N devices are expected to all return an empty tuple, but one, which // will be the result of this computation. - tensorflow::Status ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result); - tensorflow::Status ExecuteOneToN(const ExecuteGraphRequest* arg, - ExecuteResponse* result); + Status ExecuteOneToN(const ExecuteRequest* arg, ExecuteResponse* result); + Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); // Convenience function which checks whether the given shape_with_layout // (presumably passed by the client to set the result layout) is valid for the // given computation result shape. 
- tensorflow::Status ValidateResultShapeWithLayout( - const Shape& shape_with_layout, const Shape& result_shape) const; + Status ValidateResultShapeWithLayout(const Shape& shape_with_layout, + const Shape& result_shape) const; // Returns the stream executors assigned to the replicas represented by the // given device handle. Each device_handle is a virtual replicated device that // represents a set of physical devices for the replicas. - StatusOr> Replicas( + StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; Status MaybeDumpHloModule(const HloModule& module) const; diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 6c1f8feac7e..7f3910cdb03 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -28,7 +28,7 @@ namespace xla { class ServiceExecutableRunOptions { public: using StreamBorrower = - std::function::SmartPtr>(int)>; + std::function::SmartPtr>(int)>; ServiceExecutableRunOptions() : ServiceExecutableRunOptions(ExecutableRunOptions()) {} @@ -45,14 +45,13 @@ class ServiceExecutableRunOptions { ExecutableRunOptions* mutable_run_options() { return &run_options_; } // Delegate to `ExecutableRunOptions` member. - perftools::gputools::Stream* stream() const { return run_options_.stream(); } + se::Stream* stream() const { return run_options_.stream(); } DeviceMemoryAllocator* allocator() const { return run_options_.allocator(); } int device_ordinal() const { return run_options_.device_ordinal(); } // Borrows a stream and returns a smart pointer which returns the stream on // destruction. - StatusOr::SmartPtr> BorrowStream( - int device_ordinal) const { + StatusOr::SmartPtr> BorrowStream(int device_ordinal) const { return borrow_stream_ ? 
borrow_stream_(device_ordinal) : Status(tensorflow::error::UNIMPLEMENTED, "No stream cache"); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 77e12d36024..3500978bdd8 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -52,10 +52,14 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { return UNOP_ABS; case HloOpcode::kCeil: return UNOP_CEIL; + case HloOpcode::kClz: + return UNOP_CLZ; case HloOpcode::kCos: return UNOP_COS; case HloOpcode::kExp: return UNOP_EXP; + case HloOpcode::kExpm1: + return UNOP_EXPM1; case HloOpcode::kFloor: return UNOP_FLOOR; case HloOpcode::kImag: @@ -64,6 +68,8 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { return UNOP_IS_FINITE; case HloOpcode::kLog: return UNOP_LOG; + case HloOpcode::kLog1p: + return UNOP_LOG1P; case HloOpcode::kNot: return UNOP_NOT; case HloOpcode::kNegate: @@ -166,24 +172,24 @@ bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); } -tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape, - tensorflow::StringPiece op_type) { +Status ExpectNotTupleOrOpaque(const Shape& shape, + tensorflow::StringPiece op_type) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("Expected non-tuple argument for %s, but got %s.", - op_type.ToString().c_str(), + std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else if (ShapeUtil::IsOpaque(shape)) { return InvalidArgument("Expected non-opaque argument for %s, but got %s.", - op_type.ToString().c_str(), + std::string(op_type).c_str(), ShapeUtil::HumanString(shape).c_str()); } else { - return tensorflow::Status::OK(); + return Status::OK(); } } -tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, - const Shape& init_value_shape, - const PrimitiveType& input_element_type) { +Status VerifyReducerShape(const ProgramShape& reducer_shape, + const Shape& init_value_shape, + const PrimitiveType& input_element_type) { if (reducer_shape.parameters_size() != 2) { return InvalidArgument( "Reduction function must take 2 parameters, but " @@ -243,7 +249,7 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, ShapeUtil::HumanString(accumulator_shape).c_str()); } - return tensorflow::Status::OK(); + return Status::OK(); } StatusOr InferWindowOutputShape(const Shape& base_shape, @@ -335,7 +341,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case UNOP_COS: case UNOP_SIN: case UNOP_EXP: + case UNOP_EXPM1: case UNOP_LOG: + case UNOP_LOG1P: case UNOP_TANH: if (!ShapeUtil::ElementIsFloating(arg) && !ShapeUtil::ElementIsComplex(arg)) { @@ -360,6 +368,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, arg, primitive_util::ComplexComponentType(arg.element_type())); } return arg; + case UNOP_CLZ: case UNOP_NEGATE: case UNOP_ROUND_NEAREST_AFZ: case UNOP_SIGN: @@ -1209,11 +1218,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm training")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return 
InvalidArgument( @@ -1315,15 +1324,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( scale_shape, "scale input of batch norm inference")); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(offset_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(scale_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape) == - tensorflow::Status::OK()); + Status::OK()); TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(variance_shape) == - tensorflow::Status::OK()); + Status::OK()); if (feature_index >= ShapeUtil::Rank(operand_shape)) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 6e9986165f7..6bacb37206c 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include #include #include @@ -25,11 +24,10 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -namespace se = ::perftools::gputools; - namespace xla { using ::tensorflow::strings::Appendf; @@ -68,6 +66,8 @@ ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) { return *this; } +ShapedBuffer::~ShapedBuffer() {} + void ShapedBuffer::clear() { for (auto& pair : buffers_) { // A default constructed DeviceMemoryBase is a null pointer. @@ -104,18 +104,6 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { return out; } -/* static */ -StatusOr> ScopedShapedBuffer::MakeScoped( - ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator) { - auto scoped_buffer = WrapUnique(new ScopedShapedBuffer( - shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(), - allocator, shaped_buffer->device_ordinal())); - scoped_buffer->buffers_ = shaped_buffer->buffers(); - shaped_buffer->clear(); - - return std::move(scoped_buffer); -} - ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, DeviceMemoryAllocator* allocator, @@ -128,25 +116,41 @@ ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer, DeviceMemoryAllocator* allocator) : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {} +ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s) + : ShapedBuffer(static_cast(s)), allocator_(s.allocator_) { + // Null out s.allocator_ so it doesn't try to free anything in its destructor. + s.allocator_ = nullptr; +} + +ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { + *static_cast(this) = std::move(static_cast(s)); + allocator_ = s.allocator_; + // Null out s.allocator_ so it doesn't try to free anything in its destructor. + s.allocator_ = nullptr; + return *this; +} + ScopedShapedBuffer::~ScopedShapedBuffer() { + // allocator_ will be null if we were moved-from. + if (allocator_ == nullptr) { + return; + } // Deallocate all non-null buffers. 
A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. - std::set deallocated_opaques; + tensorflow::gtl::FlatSet deallocated_ptrs; for (auto& pair : buffers_) { se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && - deallocated_opaques.count(memory_base.opaque()) == 0) { - deallocated_opaques.insert(memory_base.opaque()); - TF_CHECK_OK( - this->allocator_->Deallocate(this->device_ordinal(), &memory_base)); + deallocated_ptrs.insert(memory_base.opaque()).second) { + TF_CHECK_OK(allocator_->Deallocate(device_ordinal(), memory_base)); } } } -std::unique_ptr ScopedShapedBuffer::release() { - auto shaped_buffer = MakeUnique(std::move(*this)); - buffers_ = ShapeTree(); +ShapedBuffer ScopedShapedBuffer::release() { + ShapedBuffer shaped_buffer(static_cast(*this)); + buffers_ = ShapeTree(); return shaped_buffer; } diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index b816df8385e..25b709523b7 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -30,6 +30,8 @@ limitations under the License. namespace xla { +class ScopedShapedBuffer; + // Class which encapsulates a buffer or set of buffers containing data of a // particular XLA shape. class ShapedBuffer { @@ -41,8 +43,19 @@ class ShapedBuffer { // determines the number of device allocations (DeviceMemoryBase) held by the // ShapedBuffer. ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, - const perftools::gputools::Platform* platform, - int device_ordinal); + const se::Platform* platform, int device_ordinal); + + // Movable, but not copyable. + ShapedBuffer(ShapedBuffer&& s); + ShapedBuffer& operator=(ShapedBuffer&&); + ShapedBuffer(const ShapedBuffer&) = delete; + ShapedBuffer& operator=(const ShapedBuffer&) = delete; + + // Prevent (some forms of) accidental object slicing. + ShapedBuffer(const ScopedShapedBuffer&) = delete; + ShapedBuffer& operator=(const ScopedShapedBuffer&) = delete; + + virtual ~ShapedBuffer(); // Returns the shape of the on-host representation of the data held by this // ShapedBuffer. @@ -52,48 +65,36 @@ class ShapedBuffer { // ShapedBuffer. const Shape& on_device_shape() const { return on_device_shape_; } - const perftools::gputools::Platform* platform() const { return platform_; } + const se::Platform* platform() const { return platform_; } int device_ordinal() const { return device_ordinal_; } // Return the root buffer of the shape (shape index {}). - const perftools::gputools::DeviceMemoryBase& root_buffer() const { + const se::DeviceMemoryBase& root_buffer() const { return buffer(/*index=*/{}); } // Returns the buffer at the given shape index where index is defined as in // ShapeUtil::GetSubshape. - const perftools::gputools::DeviceMemoryBase& buffer( - const ShapeIndex& index) const { + const se::DeviceMemoryBase& buffer(const ShapeIndex& index) const { return buffers_.element(index); } // Sets the device memory buffer at the given index. - void set_buffer(const perftools::gputools::DeviceMemoryBase& buffer, - const ShapeIndex& index) { + void set_buffer(const se::DeviceMemoryBase& buffer, const ShapeIndex& index) { *buffers_.mutable_element(index) = buffer; } // Returns the underlying ShapeTree containing all the device addresses in the // ShapedBuffer. 
- const ShapeTree& buffers() const { - return buffers_; - } - ShapeTree& buffers() { - return buffers_; - } + const ShapeTree& buffers() const { return buffers_; } + ShapeTree& buffers() { return buffers_; } // Set all device memory pointers in the object to null. void clear(); string ToString() const; - ShapedBuffer(ShapedBuffer&& s); - ShapedBuffer& operator=(ShapedBuffer&&); - protected: - ShapedBuffer(const ShapedBuffer&) = delete; - ShapedBuffer& operator=(const ShapedBuffer&) = delete; - // The shape of the data when represented on the host. Shape on_host_shape_; @@ -101,13 +102,13 @@ class ShapedBuffer { Shape on_device_shape_; // The platform the memory is allocated on. - const perftools::gputools::Platform* platform_; + const se::Platform* platform_; // The device the memory is allocated on. int device_ordinal_; // The tree of device buffers. Its shape is on_device_shape(). - ShapeTree buffers_; + ShapeTree buffers_; }; std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); @@ -115,41 +116,59 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); // ShapedBuffer derived class which allocates all internal buffers on // construction and deallocates the memory when the object is // destructed. +// +// TODO(timshen): Remove inheritance between ScopedShapedBuffer and +// ShapedBuffer. There should never be a need to consider a ScopedShapedBuffer +// as a ShapedBuffer, because in that case we should just be able to pass around +// our ShapeTree. Inheritance only adds complexity. See +// discussion in cl/192849370. class ScopedShapedBuffer : public ShapedBuffer { public: - // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the - // deallocation of the device memory held in the shaped buffer. All device - // memory pointers in the given ShapedBuffer are set to null. - static StatusOr> MakeScoped( - ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator); - - // Create a ScopedShapedBuffer with null DeviceMemoryBases at each index. - ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, - DeviceMemoryAllocator* allocator, int device_ordinal); + // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index. + explicit ScopedShapedBuffer(const Shape& on_host_shape, + const Shape& on_device_shape, + DeviceMemoryAllocator* allocator, + int device_ordinal); // Create a ScopedShapedBuffer by taking over the memory from the incoming // ShapedBuffer. - ScopedShapedBuffer(ShapedBuffer shaped_buffer, - DeviceMemoryAllocator* allocator); + explicit ScopedShapedBuffer(ShapedBuffer shaped_buffer, + DeviceMemoryAllocator* allocator); + + // Movable, but not copyable. + ScopedShapedBuffer(ScopedShapedBuffer&& s); + ScopedShapedBuffer& operator=(ScopedShapedBuffer&&); + ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; + ScopedShapedBuffer& operator=(const ScopedShapedBuffer&) = delete; + + // All buffers in the shape are deallocated on destruction. + ~ScopedShapedBuffer() override; // Return the allocator used to allocate the device memory held in this // ScopedShapedBuffer. DeviceMemoryAllocator* memory_allocator() const { return allocator_; } - // Release all device memory owned by this ScopedShapedBuffer and - // return the device memory pointers in the form of a - // ShapedBuffer. The returned ShapedBuffer takes over the memory - // from the ScopedShapedBuffer. The resulting ScopedShapedBuffer can - // only be destroyed. 
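Illustrative sketch (not part of the patch) of the ownership model behind the ScopedShapedBuffer changes above: the type is movable but not copyable, its destructor frees each distinct device allocation once, and release() converts it into a plain ShapedBuffer without freeing anything. The function name and control flow are hypothetical; only the ScopedShapedBuffer/ShapedBuffer interface shown in this diff is assumed.

void OwnershipExample(ScopedShapedBuffer buffer,
                      DeviceMemoryAllocator* allocator) {
  // Move into a container; ScopedShapedBuffer has no copy operations.
  std::vector<ScopedShapedBuffer> owned;
  owned.push_back(std::move(buffer));

  // Give up ownership: after release() the element's buffers are cleared and
  // the returned ShapedBuffer must be freed by whoever ends up holding it.
  ShapedBuffer unowned = owned.back().release();

  // Re-own the memory by handing it back to a ScopedShapedBuffer.
  ScopedShapedBuffer reowned(std::move(unowned), allocator);

  // Leaving scope destroys `reowned` (which deallocates the device memory)
  // and the now-empty element in `owned` (which deallocates nothing).
}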
- std::unique_ptr release(); + // Sets the device memory buffer at the given index. + // + // If the given buffer's device memory is non-null, its device_ordinal and + // allocator must match those in `this`. + void set_buffer(OwningDeviceMemory buffer, const ShapeIndex& index) { + if (!buffer.is_null()) { + CHECK_EQ(buffer.device_ordinal(), device_ordinal()); + CHECK_EQ(buffer.allocator(), allocator_); + *buffers_.mutable_element(index) = buffer.Forget(); + } else { + *buffers_.mutable_element(index) = se::DeviceMemoryBase(); + } + } - // All buffers in the shape are deallocated on destruction. - virtual ~ScopedShapedBuffer(); + // Like unique_ptr::release(), creates and returns a regular ShapedBuffer from + // this ScopedShapedBuffer, without freeing any of the associated memory. + // + // It's the caller's job to ensure that the memory contained therein is freed. + TF_MUST_USE_RESULT ShapedBuffer release(); protected: - ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; - void operator=(const ScopedShapedBuffer&) = delete; - DeviceMemoryAllocator* allocator_; }; diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h index a776d745f4e..18e2651abb1 100644 --- a/tensorflow/compiler/xla/service/source_map_util.h +++ b/tensorflow/compiler/xla/service/source_map_util.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { namespace source_map_util { -// Creates an INVALID_ARUGMENT status with the given format string. +// Creates an INVALID_ARGUMENT status with the given format string. // // Also, attempts to extract the OpMetadata for parameter_number on executable // and append it to the status message for source mapping to user code. diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 2f36e2b16e0..c4d01562c4e 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -25,24 +25,20 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" -namespace se = ::perftools::gputools; - namespace xla { /* static */ tensorflow::mutex TransferManager::platform_transfer_manager_mutex_( tensorflow::LINKER_INITIALIZED); -/* static */ std::map* +/* static */ std::map* TransferManager::GetPlatformTransferManagers() { - static auto* r = - new std::map; + static auto* r = new std::map; return r; } Status TransferManager::TransferArrayToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - const perftools::gputools::DeviceMemoryBase& dest) { + se::StreamExecutor* executor, const LiteralSlice& literal, + const se::DeviceMemoryBase& dest) { const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) << "On-device representation of " @@ -61,8 +57,8 @@ Status TransferManager::TransferArrayToDevice( } StatusOr> TransferManager::TransferArrayFromDevice( - perftools::gputools::StreamExecutor* executor, const Shape& shape, - const perftools::gputools::DeviceMemoryBase& source) { + se::StreamExecutor* executor, const Shape& shape, + const se::DeviceMemoryBase& source) { TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) << "Shape " << ShapeUtil::HumanString(shape) << " has a differently shaped representation on-device: " @@ -112,8 +108,7 @@ StatusOr> TransferManager::TransferArrayFromDevice( } Status TransferManager::WriteTupleIndexTables( - perftools::gputools::StreamExecutor* executor, - const ShapedBuffer& device_buffer) { + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { VLOG(2) << "Writing tuple index tables for " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); @@ -180,7 +175,7 @@ Status TransferManager::TransferBufferToDevice( return Status::OK(); } -StatusOr> TransferManager::AllocateShapedBuffer( +StatusOr TransferManager::AllocateScopedShapedBuffer( const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal) { if (!LayoutUtil::HasLayout(on_host_shape)) { @@ -192,31 +187,23 @@ StatusOr> TransferManager::AllocateShapedBuffer( const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape)); - auto shaped_buffer = WrapUnique(new ShapedBuffer( - on_host_shape, on_device_shape, allocator->platform(), device_ordinal)); + ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, allocator, + device_ordinal); // Allocate an appropriate sized buffer for each element in the shape // including the tuple pointer arrays. - for (auto& pair : shaped_buffer->buffers()) { + for (auto& pair : shaped_buffer.buffers()) { const ShapeIndex& index = pair.first; se::DeviceMemoryBase& memory_base = pair.second; const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index); - TF_ASSIGN_OR_RETURN(memory_base, - allocator->Allocate(shaped_buffer->device_ordinal(), + TF_ASSIGN_OR_RETURN(auto memory, + allocator->Allocate(shaped_buffer.device_ordinal(), GetByteSizeRequirement(subshape))); + // Move the allocated buffer into the ScopedShapedBuffer, which owns it. 
+ memory_base = memory.Forget(); } return std::move(shaped_buffer); } -StatusOr> -TransferManager::AllocateScopedShapedBuffer(const Shape& on_host_shape, - DeviceMemoryAllocator* allocator, - int device_ordinal) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr unscoped_buffer, - AllocateShapedBuffer(on_host_shape, allocator, device_ordinal)); - return ScopedShapedBuffer::MakeScoped(unscoped_buffer.get(), allocator); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 9f2b5c4aecf..43a8092b06f 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -42,7 +42,7 @@ class TransferManager { virtual ~TransferManager() {} // Returns the ID of the platform that this transfer manager acts on. - virtual perftools::gputools::Platform::Id PlatformId() const = 0; + virtual se::Platform::Id PlatformId() const = 0; // Returns the shape of the on-device representation for the given shape on // the host. This is intended for use with ShapedBuffer where buffers are @@ -58,48 +58,45 @@ class TransferManager { // DeviceShape(literal_shape) must be compatible, but need not have the same // layout. virtual StatusOr> TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const ShapedBuffer& device_buffer) = 0; + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0; // Transfers the given literal into the previously allocated device memory // represented by the given ShapedBuffer using the given executor. The shape // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, // but need not have the same layout - virtual Status TransferLiteralToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - const ShapedBuffer& device_buffer) = 0; + virtual Status TransferLiteralToDevice(se::StreamExecutor* executor, + const LiteralSlice& literal, + const ShapedBuffer& device_buffer) = 0; // Convenience methods for transferring an array to or from the device at a // known address. This avoids having to construct a ShapedBuffer just to // transfer an array at a known address. - Status TransferArrayToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - const perftools::gputools::DeviceMemoryBase& dest); + Status TransferArrayToDevice(se::StreamExecutor* executor, + const LiteralSlice& literal, + const se::DeviceMemoryBase& dest); StatusOr> TransferArrayFromDevice( - perftools::gputools::StreamExecutor* executor, const Shape& shape, - const perftools::gputools::DeviceMemoryBase& source); + se::StreamExecutor* executor, const Shape& shape, + const se::DeviceMemoryBase& source); // Transfers the given literal into the Infeed interface of the device, // using the given executor. - virtual Status TransferLiteralToInfeed( - perftools::gputools::StreamExecutor* executor, - const Literal& literal) = 0; + virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor, + const LiteralSlice& literal) = 0; // Transfers the given literal from the Outfeed interface of the device, // using the given executor. - virtual Status TransferLiteralFromOutfeed( - perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, - Literal* literal) = 0; + virtual Status TransferLiteralFromOutfeed(se::StreamExecutor* executor, + const Shape& literal_shape, + Literal* literal) = 0; // Resets the devices associated with this transfer manager. 
virtual Status ResetDevices( - tensorflow::gtl::ArraySlice - executor) = 0; + tensorflow::gtl::ArraySlice executor) = 0; // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the // ShapedBuffer is array-shaped this method does nothing. - Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor, + Status WriteTupleIndexTables(se::StreamExecutor* executor, const ShapedBuffer& device_buffer); // Determines the byte size requirement for the given shape on the underlying @@ -107,13 +104,10 @@ class TransferManager { // region for a host-to-device transfer. virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; - // Allocate a ShapedBuffer which can hold data with the given on-host + // Allocates a ScopedShapedBuffer which can hold data with the given on-host // shape. The on-device shape may be different as indicated by // HostShapeToDeviceShape. - StatusOr> AllocateShapedBuffer( - const Shape& on_host_shape, DeviceMemoryAllocator* allocator, - int device_ordinal); - StatusOr> AllocateScopedShapedBuffer( + StatusOr AllocateScopedShapedBuffer( const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal); @@ -127,13 +121,13 @@ class TransferManager { // Precondition: a platform kind must not be registered more than once. typedef std::unique_ptr (*TransferManagerCreationFunction)(); static void RegisterTransferManager( - perftools::gputools::Platform::Id platform_id, + se::Platform::Id platform_id, TransferManagerCreationFunction transfer_manager); // Returns the transfer manager singleton pointer if it is available for the // given platform, or an error status if it is not. static StatusOr GetForPlatform( - const perftools::gputools::Platform* platform); + const se::Platform* platform); protected: // Transfer a memory block of the given size from 'source' buffer to the @@ -143,35 +137,32 @@ class TransferManager { // // source is the source data that must be in the target-dependent layout that // the Infeed HLO used in the computation expects. - virtual Status TransferBufferToInfeed( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source) = 0; + virtual Status TransferBufferToInfeed(se::StreamExecutor* executor, + int64 size, const void* source) = 0; // Transfer a memory block of the given size from the device source into the // 'destination' buffer. // // size is the size to transfer to destination in bytes. - virtual Status TransferBufferFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, int64 size, - void* destination); + virtual Status TransferBufferFromDevice(se::StreamExecutor* executor, + const se::DeviceMemoryBase& source, + int64 size, void* destination); // Transfer a memory block of the given size from 'source' buffer to the given // destination of the device. // // size is the size to transfer from source in bytes. - virtual Status TransferBufferToDevice( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source, perftools::gputools::DeviceMemoryBase* destination); + virtual Status TransferBufferToDevice(se::StreamExecutor* executor, + int64 size, const void* source, + se::DeviceMemoryBase* destination); // Writes the given device-memory pointers in 'elements' to the given region // to construct a tuple index table in the platform-specific tuple // representation. 
virtual Status WriteSingleTupleIndexTable( - perftools::gputools::StreamExecutor* executor, - tensorflow::gtl::ArraySlice - elements, - const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0; + se::StreamExecutor* executor, + tensorflow::gtl::ArraySlice elements, + const Shape& shape, se::DeviceMemoryBase* region) = 0; private: // The mutex that guards the platform-to-transfer manager map. @@ -186,8 +177,7 @@ class TransferManager { }; // Map from platform kind to transfer manager singleton. - static std::map* - GetPlatformTransferManagers(); + static std::map* GetPlatformTransferManagers(); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 83185ac49e9..ba16dc640e2 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -35,7 +35,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( const HloInstruction& dot, const TransposeFolding::TransposableGemmOperandsFn& transposable_gemm_operands) { - if (HloOpcode::kDot != dot.opcode()) { + if (HloOpcode::kDot != dot.opcode() || + dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) { return {}; } @@ -44,6 +45,8 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( auto& operand = *dot.operand(i); if (operand.IsRank2Transpose()) { operand_set.push_back(i); + } else if (ShapeUtil::Rank(operand.shape()) != 2) { + return {}; } } @@ -74,23 +77,39 @@ using InstructionOperandsPair = // Folds the operands of `dot` that are foldable transposes. `computation` is // the parent HLO computation of `dot`. -// -// Returns whether the module is changed. -bool FoldTransposeIntoDot(InstructionOperandsPair pair) { - auto* dot = pair.first; - std::vector instructions_to_fuse(1, dot); - for (const int64 operand_index : pair.second) { - instructions_to_fuse.push_back(dot->mutable_operand(operand_index)); +Status FoldTransposeIntoDot(InstructionOperandsPair pair) { + HloInstruction* dot = pair.first; + + DotDimensionNumbers new_dim_numbers = dot->dot_dimension_numbers(); + HloInstruction* new_lhs = dot->mutable_operand(0); + HloInstruction* new_rhs = dot->mutable_operand(1); + + CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0); + CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0); + CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1); + + for (int64 operand_index : pair.second) { + // We've checked that there aren't any batch dimensions and that the inputs + // are rank 2, and shape inference guarantees that there is exactly one + // contracting dimension. + if (operand_index == 0) { + CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose); + new_dim_numbers.set_lhs_contracting_dimensions( + 0, 1 - new_dim_numbers.lhs_contracting_dimensions(0)); + new_lhs = new_lhs->mutable_operand(0); + } else { + CHECK_EQ(operand_index, 1); + CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose); + new_dim_numbers.set_rhs_contracting_dimensions( + 0, 1 - new_dim_numbers.rhs_contracting_dimensions(0)); + new_rhs = new_rhs->mutable_operand(0); + } } - // Early-exit if no operands are foldable. 
- if (instructions_to_fuse.size() == 1) { - return false; - } - - dot->parent()->CreateFusionInstruction( - instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot); - return true; + std::unique_ptr new_dot = HloInstruction::CreateDot( + dot->shape(), new_lhs, new_rhs, new_dim_numbers); + return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } // Folds the operands of `convolution` that are foldable transposes. @@ -159,6 +178,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); + convolution.SetupDerivedInstruction(new_conv.get()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); @@ -195,7 +215,7 @@ StatusOr TransposeFolding::Run(HloModule* module) { std::make_pair(instruction, operand_indices)); } } - return tensorflow::Status::OK(); + return Status::OK(); }; for (auto* comp : module->MakeNonfusionComputations()) { @@ -204,7 +224,8 @@ StatusOr TransposeFolding::Run(HloModule* module) { bool changed = false; for (InstructionOperandsPair& pair : foldable_dots) { - changed |= FoldTransposeIntoDot(pair); + TF_RETURN_IF_ERROR(FoldTransposeIntoDot(pair)); + changed = true; } for (InstructionOperandsPair& pair : foldable_convolutions) { changed |= FoldTransposeIntoConvolution(pair); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index caa1a111ad8..f73f1227aaf 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -19,11 +19,12 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -31,9 +32,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -54,83 +58,102 @@ class TransposeFoldingTest : public HloTestBase { }; TEST_F(TransposeFoldingTest, FoldDotTranspose) { - auto builder = HloComputation::Builder("entry_computation"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"y")); - HloInstruction* transpose_y = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, - /*rhs=*/transpose_y, dot_dnums)); + string hlo_string = R"( +HloModule FoldDotTranspose - HloModule module("test_module"); - HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); - FoldTranspose(&module); +ENTRY entry_computation { + x = f32[2,3]{1,0} parameter(0) + y = f32[2,3]{1,0} parameter(1) + transpose = f32[3,2]{1,0} transpose(y), dimensions={1,0} + ROOT dot = f32[2,2]{1,0} dot(x, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); - // Instructions after folding: x, y, and the fusion. - std::unordered_set instruction_set( - entry_computation->instructions().begin(), - entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.size()) - << "entry_computation should contain exactly 3 instructions."; - HloInstruction* fusion = *instruction_set.begin(); - EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); + FoldTranspose(module.get()); - // The fusion instruction should contain two parameters, one transpose and - // one dot. 
- EXPECT_EQ(4, fusion->fused_instruction_count()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Parameter(0), op::Parameter(1), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); +} + +TEST_F(TransposeFoldingTest, DontFoldTransposeOfBatchDim) { + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[2,3] parameter(0) + y = f32[3,2] parameter(1) + transpose = f32[2,3] transpose(y), dimensions={1,0} + ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(TransposeFoldingTest, DontFoldTransposeOfRank1Dot) { + string hlo_string = R"( +HloModule FoldDotTranspose + +ENTRY entry_computation { + x = f32[3] parameter(0) + y = f32[3,2] parameter(1) + transpose = f32[2,3] transpose(y), dimensions={1,0} + ROOT dot = f32[2] dot(x, transpose), lhs_batch_dims={}, rhs_batch_dims={0}, lhs_contracting_dims={0}, rhs_contracting_dims={1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TransposeFolding transpose_folding( + [](const HloInstruction& dot, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }, + [](const HloInstruction& convolution, + const TransposeFolding::OperandIndices& candidate_operands) { + return candidate_operands; + }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, transpose_folding.Run(module.get())); + EXPECT_FALSE(changed); } TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { - auto builder = HloComputation::Builder("entry_computation"); - // 2x1 - HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR2({{1}, {2}}))); - // 3x2 - HloInstruction* const1 = - builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); - HloInstruction* transpose0 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0})); - HloInstruction* transpose1 = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( - ShapeUtil::MakeShape(F32, {1, 3}), - /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); + string hlo_string = R"( +HloModule FoldDotTransposeConstant - HloModule module("test_module"); - HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); - FoldTranspose(&module); +ENTRY entry_computation { + constant = f32[2,1]{1,0} constant(f32[2,1] { { 1 }, { 2 } }) + transpose = f32[1,2]{1,0} transpose(constant), dimensions={1,0} + constant.1 = f32[3,2]{1,0} constant(f32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 } }) + transpose.1 = f32[2,3]{1,0} transpose(constant.1), dimensions={1,0} + ROOT dot = f32[1,3]{1,0} dot(transpose, transpose.1), 
lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); - for (auto* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kFusion) { - CHECK_EQ(2, instruction->operand_count()); - EXPECT_EQ(const0, instruction->operand(0)); - EXPECT_EQ(const1, instruction->operand(1)); - } - } + FoldTranspose(module.get()); - // The created fusion instruction should contain two parameters, two - // transposes (one for each parameter) and one dot. - EXPECT_EQ(5, - entry_computation->root_instruction()->fused_instruction_count()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Dot(op::Constant(), op::Constant(), + /*lhs_contracting_dim=*/0, /*rhs_contracting_dim=*/1)); } TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { @@ -149,10 +172,10 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary( add->shape(), HloOpcode::kMultiply, add, sub)); - HloModule module("fuse_with_constant_operands"); + auto module = CreateNewModule("fuse_with_constant_operands"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(mul)); - HloInstruction* call = module.OutlineExpressionFromComputation( + module->AddEntryComputation(builder.Build(mul)); + HloInstruction* call = module->OutlineExpressionFromComputation( {add, sub, mul}, "", entry_computation); EXPECT_EQ(call, entry_computation->root_instruction()); HloComputation* callee_computation = call->to_apply(); @@ -164,50 +187,32 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { EXPECT_EQ(6, callee_computation->instruction_count()); } -TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { - auto builder = HloComputation::Builder("entry_computation"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}), - /*name=*/"y")); - HloInstruction* transpose_y = - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - DotDimensionNumbers dot_dnums; - dot_dnums.add_lhs_contracting_dimensions(1); - dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, - /*rhs=*/transpose_y, dot_dnums)); +TEST_F(TransposeFoldingTest, FoldDotTransposeInCall) { + string hlo_string = R"( +HloModule FoldDotTransposeInCall - HloModule module("test_module"); - HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(dot)); +callee { + name.0 = f32[2,3]{1,0} parameter(0) + name.1 = f32[2,3]{1,0} parameter(1) + transpose.clone = f32[3,2]{1,0} transpose(name.0), dimensions={1,0} + ROOT dot.clone = f32[2,2]{1,0} dot(name.1, transpose.clone), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} - HloInstruction* call = module.OutlineExpressionFromComputation( - {transpose_y, dot}, "outlined", entry_computation); +ENTRY entry_computation { + y = f32[2,3]{1,0} parameter(1) + x = f32[2,3]{1,0} parameter(0) + ROOT call = f32[2,2]{1,0} call(y, x), to_apply=callee +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + FoldTranspose(module.get()); - FoldTranspose(&module); - - // Instructions 
after folding: x, y, and the fusion. - std::unordered_set instruction_set( - entry_computation->instructions().begin(), - entry_computation->instructions().end()); - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(call)) - << "call is not in entry_computation."; - CHECK(instruction_set.empty()) - << "entry_computation should contain exactly 3 instructions."; - HloInstruction* fusion = - call->called_computations().front()->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, fusion->opcode()); - - // The fusion instruction should contain two parameters, one transpose and - // one dot. - EXPECT_EQ(4, fusion->fused_instruction_count()); + const HloComputation* callee = module->GetComputationWithName("callee"); + ASSERT_NE(callee, nullptr); + EXPECT_THAT(callee->root_instruction(), + op::Dot(op::Parameter(1), op::Parameter(0), + /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); } // Test that a two dimension swap of the kernel gets folded into convolution. @@ -222,7 +227,7 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 0, 2, 3})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -240,10 +245,10 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( @@ -275,7 +280,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), y, {1, 3, 0, 2})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -293,10 +298,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. 
std::unordered_set instruction_set( @@ -334,7 +339,7 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { HloInstruction* transpose_x = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 2, 3})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -351,10 +356,10 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( @@ -398,7 +403,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { HloInstruction* transpose_x = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); - auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(); Window window; for (int i = 0; i < 2; ++i) { WindowDimension* dim = window.add_dimensions(); @@ -415,10 +420,10 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); - HloModule module("test_module"); + auto module = CreateNewModule("test_module"); HloComputation* entry_computation = - module.AddEntryComputation(builder.Build(conv)); - FoldTranspose(&module); + module->AddEntryComputation(builder.Build(conv)); + FoldTranspose(module.get()); // Instructions after folding: x, y, and the convolution. std::unordered_set instruction_set( diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 657a8fe09ae..8cb654493ca 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -588,4 +588,201 @@ void TuplePointsToAnalysis::InstructionToString( }); } +bool TuplePointsToAnalysis::DoesNotUseOperandBuffer( + const HloInstruction* operand, const ShapeIndex& index, + const HloInstruction* user) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) { + // GetTupleElement instructions only access the top-level buffer of their + // operand. + return true; + } else if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + auto it = std::find_if( + user->fused_parameters().begin(), user->fused_parameters().end(), + [=](HloInstruction* fused_param) { + return user->operand(fused_param->parameter_number()) == operand; + }); + CHECK(it != user->fused_parameters().end()); + // Iterate through all users of all buffer aliases of the buffer in the + // points-to set of fusion parameter at 'index'. 
+ // Return false if any uses are detected at 'index', returns true otherwise. + const LogicalBuffer* buffer = GetBufferDefinedAt(*it, index).ValueOrDie(); + for (const BufferAlias& alias : GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + alias_user)) { + continue; + } + // Return false: use detected at 'buffer' -> 'alias' -> 'alias_user'. + return false; + } + } + // Return true: found no uses of 'operand' at 'index' in 'user'. + return true; + } + return false; +} + +// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. +// Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) +// where 'user' is a user of an alias of 'instruction' at 'index', and +// 'operand_index' is the operand index at which the alias appears in the +// operand list of 'user'. +std::vector> +TuplePointsToAnalysis::GetAllUsesOfInstructionAtIndex( + HloInstruction* instruction, const ShapeIndex& index) const { + std::vector> uses; + const PointsToSet::BufferList& points_to = + GetPointsToSet(instruction).element(index); + for (const LogicalBuffer* buffer : points_to) { + for (const BufferAlias& alias : GetBufferAliases(*buffer)) { + for (HloInstruction* alias_user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + alias_user)) { + continue; + } + for (int64 op_idx : alias_user->OperandIndices(alias.instruction())) { + uses.emplace_back(alias_user, op_idx); + } + } + } + } + return uses; +} + +// Returns true if there is exactly one use of 'operand' at 'operand_index' +// in 'fusion.fused_instructions', where the singleton use is the fused +// root at operand index 'use_operand_index'. Returns false otherwise. +// +// REQUIRES: 'fusion' opcode is a kFusion instruction. +bool TuplePointsToAnalysis::HasUniqueFusedUseOfOperandAt( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* fusion, const int64 use_operand_index) const { + CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); + // Check that 'operand' is unique in the operand list of 'fusion'. + if (fusion->OperandIndices(operand).size() > 1) { + return false; + } + // Find fusion parameter associated with 'operand'. + const auto& fused_params = fusion->fused_parameters(); + auto fused_param_it = std::find_if( + fused_params.begin(), fused_params.end(), + [&](HloInstruction* fused_param) { + return fusion->operand(fused_param->parameter_number()) == operand; + }); + if (fused_param_it == fused_params.end()) { + return false; + } + auto* fused_param = *fused_param_it; + // Get all uses of 'operand' at 'index' from 'fusion.fused_instructions'. + auto fused_param_uses = + GetAllUsesOfInstructionAtIndex(fused_param, operand_index); + // Return true iff there is exactly one use of 'operand' at 'index', and + // this singleton use is the fused root (at index in 'use_operand_indices'). + return fused_param_uses.size() == 1 && + fused_param_uses[0].first == fusion->fused_expression_root() && + fused_param_uses[0].second == use_operand_index; +} + +// User and operand can share buffers iff both instructions emit the same shape +// and layout, and 'user' meets one of the following qualifications: +// +// (1) Is element-wise. Or... +// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' +// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root +// at operand 0. Or... 
+// (3) Is a kDot -> kAdd output fusion instruction where the only use of +// 'operand' at 'index' in the set 'user.fused_instructions' is a kAdd fused +// root at operand 0 or 1. Or... +// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index +// 0. +// +// (2) and (3) can only be determined if points-to analysis is available. +bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( + HloInstruction* operand, const ShapeIndex& operand_index, + HloInstruction* user, const ShapeIndex& user_index) const { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + const Shape& operand_subshape = + ShapeUtil::GetSubshape(operand->shape(), operand_index); + const Shape& user_subshape = + ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. + if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { + return false; + } + if (user->opcode() == HloOpcode::kFusion) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && + user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is kDot or kConvolution. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot; + }); + if (add_operand_it == add->operands().end()) { + return false; + } + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, + other_add_operand_index); + } + } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; + } + if (user->opcode() == HloOpcode::kCall) { + // TODO(b/62548313): Remove when buffer assignment is module scoped and + // does not assign buffers to calls. + // Find called computation parameter associated with 'operand'. + const std::vector operand_indices = user->OperandIndices(operand); + if (operand_indices.size() > 1) { + return false; + } + CHECK_EQ(1, operand_indices.size()); + auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); + // Get all uses of 'operand' at 'index' in called computation. 
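+  // Each use is a (user, operand_index) pair, as documented on
+  // GetAllUsesOfInstructionAtIndex above.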
+ auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index); + + // Return true iff: + // *) There exists exactly one use of 'operand' in called computation. + // *) The unique use is by the root instruction of called computation. + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + auto* callee_root = user->to_apply()->root_instruction(); + return param_uses.size() == 1 && param_uses[0].first == callee_root && + callee_root->IsElementwiseOnOperand(param_uses[0].second); + } + // Check if 'user' is element-wise. + return user->IsElementwise(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index c3743b15016..1ac71301365 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -256,6 +256,23 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { string ToString() const; + // Returns true if 'user' cannot possibly use the buffer at 'index' in + // 'operand'. Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) const; + + // Returns true if 'user' (at 'user_index') can share a buffer with its + // operand 'operand' (at 'operand_index'). Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index) const; + private: explicit TuplePointsToAnalysis( const HloModule* module, @@ -310,6 +327,13 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { return &per_instruction_[id]; } + std::vector> GetAllUsesOfInstructionAtIndex( + HloInstruction* instruction, const ShapeIndex& index) const; + bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* fusion, + const int64 use_operand_index) const; + // The module this analysis is performed on. 
const HloModule* module_; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index dec446d4dac..f558316b05b 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -805,5 +805,348 @@ TEST_F(FusionPointsToAnalysisTest, FusionParam0TwoUsers) { Run(/*add_additional_gte0_user=*/true); } +class PointsToAnalysisTestBase : public HloTestBase { + protected: + void BuildModule(std::unique_ptr computation) { + module_ = CreateNewModule(); + computation_ = module_->AddEntryComputation(std::move(computation)); + } + + void RunAnalysis() { + CHECK_NOTNULL(module_.get()); + points_to_analysis_ = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + } + + void BuildModuleAndRunAnalysis(std::unique_ptr computation) { + BuildModule(std::move(computation)); + RunAnalysis(); + } + + std::unique_ptr module_; + HloComputation* computation_ = nullptr; + std::unique_ptr points_to_analysis_; +}; + +class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; + +TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { + auto builder = HloComputation::Builder(TestName()); + + Shape elem_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(elem_shape, tuple, 1)); + builder.AddInstruction( + HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // GetTupleElement instructions only access the top-level buffer of their + // operand. + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, gte0)); + EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, gte1)); + EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte0)); + EXPECT_FALSE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {}, gte1)); +} + +TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. 
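+  // (gte0 stays outside the fusion, so only the buffer at tuple index {1} is
+  // read inside the fused computation.)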
+ EXPECT_TRUE(points_to_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE( + points_to_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); +} + +class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { + auto builder = HloComputation::Builder(TestName()); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + auto result = builder.AddInstruction( + HloInstruction::CreateBinary(out_shape, HloOpcode::kEq, param0, param1)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, + result, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, + result, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { + auto builder = HloComputation::Builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {8}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kExp, param)); + auto copy = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(param, {}, exp, {})); + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(exp, {}, copy, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. 
+ auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can share with tuple element 1. + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(tuple, {0}, + fusion, {})); + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(tuple, {1}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape starts_shape = ShapeUtil::MakeShape(S32, {1}); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto update = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "update")); + auto starts = builder.AddInstruction( + HloInstruction::CreateParameter(2, starts_shape, "starts")); + auto dus = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, data, update, starts)); + + BuildModuleAndRunAnalysis(builder.Build()); + + // The DynamicUpdateSlice instruction can share with the data operand, but not + // with update or starts. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(data, {}, dus, {})); + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(update, {}, dus, {})); + EXPECT_FALSE( + points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, dot}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused dot add should be able to share buffer with 'add_operand'. 
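+  // ('add_operand' feeds the non-dot side of the fused kAdd root, which is
+  // exactly the output-fusion case checked via HasUniqueFusedUseOfOperandAt.)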
+ EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser( + add_operand, {}, fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, operand, {0, 1})); + + auto two = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, two, reverse}, HloInstruction::FusionKind::kOutput); + RunAnalysis(); + + // Output fused operand->reverse->add cannot alias operand buffer 'operand'. + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + + auto make_cond = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Cond"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, data, data)); + return builder.Build(); + }; + + auto make_body = [this, &data_shape]() { + auto builder = HloComputation::Builder(TestName() + ".Body"); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data)); + return builder.Build(); + }; + + module_ = CreateNewModule(); + HloComputation* cond_computation = + module_->AddEmbeddedComputation(make_cond()); + HloComputation* body_computation = + module_->AddEmbeddedComputation(make_body()); + + auto builder = HloComputation::Builder(TestName()); + auto data = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "data")); + auto whil = builder.AddInstruction(HloInstruction::CreateWhile( + data_shape, cond_computation, body_computation, data)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + // The While instruction can share with the data operand. + EXPECT_TRUE( + points_to_analysis_->CanShareOperandBufferWithUser(data, {}, whil, {})); +} + +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. 
+ auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(reverse, {}, + call, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc index 113c2e2bd9f..d668855084a 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -69,6 +69,7 @@ StatusOr TupleSimplifier::Run(HloModule* module) { // Tuple // HloInstruction* top_tuple = nullptr; + HloInstruction* first_gte = nullptr; bool can_simplify = true; for (int64 operand_number = 0; operand_number < instruction->operand_count(); ++operand_number) { @@ -78,11 +79,17 @@ StatusOr TupleSimplifier::Run(HloModule* module) { can_simplify = false; break; } - + if (first_gte == nullptr) { + first_gte = operand; + } else if (!first_gte->has_compatible_sharding(operand)) { + can_simplify = false; + break; + } if (top_tuple == nullptr) { top_tuple = operand->mutable_operand(0); if (!ShapeUtil::Compatible(top_tuple->shape(), - instruction->shape())) { + instruction->shape()) || + !instruction->has_compatible_sharding(top_tuple)) { can_simplify = false; break; } @@ -108,15 +115,17 @@ StatusOr TupleSimplifier::Run(HloModule* module) { // | // GTE if (instruction->operand(0)->opcode() == HloOpcode::kTuple) { - changed = true; HloInstruction* element_source = instruction->mutable_operand(0)->mutable_operand( instruction->tuple_index()); - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); - for (HloInstruction* user : element_source->users()) { - if (user->opcode() == HloOpcode::kTuple || - user->opcode() == HloOpcode::kGetTupleElement) { - worklist.push(user); + if (instruction->has_compatible_sharding(element_source)) { + changed = true; + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); + for (HloInstruction* user : element_source->users()) { + if (user->opcode() == HloOpcode::kTuple || + user->opcode() == HloOpcode::kGetTupleElement) { + worklist.push(user); + } } } } diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 532f7fd5bfc..9e62d0acfb9 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -49,10 +49,14 @@ HloOpcode 
UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kAbs; case UNOP_CEIL: return HloOpcode::kCeil; + case UNOP_CLZ: + return HloOpcode::kClz; case UNOP_COS: return HloOpcode::kCos; case UNOP_EXP: return HloOpcode::kExp; + case UNOP_EXPM1: + return HloOpcode::kExpm1; case UNOP_FLOOR: return HloOpcode::kFloor; case UNOP_IMAG: @@ -61,6 +65,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kIsFinite; case UNOP_LOG: return HloOpcode::kLog; + case UNOP_LOG1P: + return HloOpcode::kLog1p; case UNOP_NOT: return HloOpcode::kNot; case UNOP_NEGATE: diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc new file mode 100644 index 00000000000..10fc4958fae --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -0,0 +1,128 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" + +namespace xla { + +// Replaces all uses of old_instr with new_instr except the use at +// `while_body_root` (which must be a tuple instruction) at index `tuple_index`. +// This utility helps us replace an instruction in the while body with a +// constant while still keeping it trivially loop invariant. 
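+// Skipping the use at `tuple_index` of the while body root keeps that element
+// of the loop state flowing through unchanged, so WhileLoopSimplifier can
+// later remove the now-unused state element.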
+static Status ReplaceUsesWhileKeepingLoopInvariance( + HloInstruction* old_instr, HloInstruction* new_instr, + HloInstruction* while_body_root, int64 tuple_index) { + CHECK_EQ(while_body_root->opcode(), HloOpcode::kTuple); + + std::vector users; + users.reserve(old_instr->user_count()); + c_copy(old_instr->users(), std::back_inserter(users)); + + for (auto* user : users) { + for (int64 i = 0, e = user->operand_count(); i < e; i++) { + if (user->operand(i) == old_instr && + !(user == while_body_root && i == tuple_index)) { + TF_RETURN_IF_ERROR(user->ReplaceOperandWith(i, new_instr)); + } + } + } + + return Status::OK(); +} + +StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileBody( + HloInstruction* while_instr) { + HloComputation* while_body = while_instr->while_body(); + + const HloInstruction& init_value = *while_instr->operand(0); + if (init_value.opcode() != HloOpcode::kTuple) { + return false; + } + + bool changed = false; + + for (HloInstruction* invariant_gte : + WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { + int64 index = invariant_gte->tuple_index(); + const HloInstruction& invariant_value = *init_value.operand(index); + if (invariant_value.opcode() == HloOpcode::kConstant) { + auto* constant_instr = + while_body->AddInstruction(invariant_value.Clone(/*suffix=*/".sunk")); + TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( + invariant_gte, constant_instr, while_body->root_instruction(), + index)); + changed = true; + } + } + + return changed; +} + +StatusOr WhileLoopConstantSinking::Run(HloModule* module) { + VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + + bool changed = false; + std::vector while_instrs; + for (auto* comp : module->MakeNonfusionComputations()) { + // Right now we don't particulary care about optimizing while-of-while + // patterns. If/When we do, we'll want to visit the outer while (while_0) + // before we visit the inner while (while_1): + // + // while_1_body(state) { + // val = gte(state, 0) // Loop invariant + // use(val) + // } + // + // while_0_body(state) { + // val = gte(state, 0) // Loop invariant + // while_1 = while(init=tuple(val, ...), body=while_1_body, ...) + // ... + // } + // + // main { + // while_0 = while(init=(constant, ...), body=while_0_body, ...) + // } + // + // This will let us sink the constant into the outer while first and then + // into the inner while in a single run of this pass. + c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); + } + + for (HloInstruction* while_instr : while_instrs) { + // We only sink into while loop bodies, but this can be extended to + // transform conditions as well. + TF_ASSIGN_OR_RETURN(bool result, + TrySinkingConstantsIntoWhileBody(while_instr)); + changed |= result; + } + + if (changed) { + VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + } + + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h new file mode 100644 index 00000000000..21fb8568a84 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// Sinks while loop invariant values that happen to be constants into the while +// loop body. This is probably not a win in isolation but may unlock further +// optimizations like constant folding. +// +// state = (..., const, ...) +// while (pred(state)) { +// (..., v, ...) = state +// use(v) +// state = (..., v, ...) +// } +// +// => +// +// state = (..., const, ...) +// while (pred(state)) { +// (..., v, ...) = state +// use(const) +// state = (..., v, ...) +// } +// +// Note that it leaves the `v` in place to keep that component of the state +// tuple trivially loop invariant. WhileLoopSimplifier will later get rid of +// `v`. +// +// We only sink into while loop bodies, but this can be extended to transform +// conditions as well. +// +// TODO(b/79121449): We should also sink broadcasts of constants. +class WhileLoopConstantSinking : public HloPassInterface { + public: + ~WhileLoopConstantSinking() override = default; + + tensorflow::StringPiece name() const override { + return "while-loop-invariant-code-motion"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr TrySinkingConstantsIntoWhileBody(HloInstruction* while_instr); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_CONSTANT_SINKING_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc new file mode 100644 index 00000000000..0d2288d8ea6 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -0,0 +1,200 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; +using ::testing::_; + +class WhileLoopConstantSinkingTest : public ::testing::Test {}; + +TEST_F(WhileLoopConstantSinkingTest, SinkOneConstant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 + + add.0 = f32[2] add(p_body.0, p_body.1) + ROOT root = (f32[2],f32[2]) tuple(add.0, p_body.1) +} + +condition { + p_cond = (f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + while_init = (f32[2],f32[2]) tuple(const_0, const_1) + ROOT while = (f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(_, op::Constant()), _)); +} + +TEST_F(WhileLoopConstantSinkingTest, KeepConstantsLoopInvariant) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=1 + p_body.2 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_body), index=2 + + add.0 = f32[2] add(p_body.1, p_body.2) + ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_body.1, p_body.2) +} + +condition { + p_cond = (f32[2],f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + const_2 = f32[2] constant({3, 1}) + while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2) + ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(op::Constant(), op::Constant()), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0)))); +} + +TEST_F(WhileLoopConstantSinkingTest, TupleShapedConstants) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[2],(f32[2],f32[2])) parameter(0) + p_b.0 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=0 + p_b.1 = (f32[2],f32[2]) get-tuple-element((f32[2],(f32[2],f32[2])) p_b), index=1 + + p_b.1.1 = f32[2] get-tuple-element(p_b.1), index=0 + + ROOT root = (f32[2],f32[2],f32[2]) tuple(p_b.1.1, p_b.1) +} + +condition { + p_cond = (f32[2],(f32[2],f32[2])) parameter(0) + ROOT 
result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = (f32[2], f32[2]) constant((f32[2], f32[2]) ({2, 1},{3,1})) + while_init = (f32[2],(f32[2],f32[2])) tuple(const_0, const_1) + ROOT while = (f32[2],(f32[2],f32[2])) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::GetTupleElement(op::Constant(), 0), + op::GetTupleElement(op::Parameter(0)))); +} + +TEST_F(WhileLoopConstantSinkingTest, DuplicateGTEs) { + // This test shows that the pass fails to optimize non-canonical IR. + // + // Even though the input IR has a constant value for p_b.2.dup, + // WhileLoopConstantSinking doesn't try to detect this. Instead, it relies on + // prior runs of HLO CSE to have commoned these identical GTE instructions. + + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_b = (f32[2],f32[2],f32[2]) parameter(0) + + p_b.1 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=1 + p_b.2 = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2 + p_b.2.dup = f32[2] get-tuple-element((f32[2],f32[2],f32[2]) p_b), index=2 + + add.0 = f32[2] add(p_b.1, p_b.2.dup) + ROOT root = (f32[2],f32[2],f32[2]) tuple(add.0, p_b.1, p_b.2) +} + +condition { + p_cond = (f32[2],f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + const_2 = f32[2] constant({3, 1}) + while_init = (f32[2],f32[2],f32[2]) tuple(const_0, const_1, const_2) + ROOT while = (f32[2],f32[2],f32[2]) while(while_init), condition=condition, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopConstantSinking{}.Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body"); + EXPECT_THAT(while_body->root_instruction(), + op::Tuple(op::Add(op::Constant(), ::testing::Not(op::Constant())), + op::GetTupleElement(op::Parameter(0)), + op::GetTupleElement(op::Parameter(0)))); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 3ef0cdff675..321fdeb1ea3 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -115,25 +115,6 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { } } -// Populates `gte_set` with the GetTupleElement instructions in `while_body` -// that access elements in the parameter tuple that don't change across -// iterations. Assumes `while_body` is the body computation of the while loop -// in question. 
-static void GatherInvariantGTEs(HloComputation* while_body, - FlatSet* gte_set) { - const HloInstruction::InstructionVector root_operands = - while_body->root_instruction()->operands(); - for (int i = 0; i < root_operands.size(); i++) { - HloInstruction* instr = root_operands[i]; - if (instr->opcode() == HloOpcode::kGetTupleElement && - instr->tuple_index() == i && - instr->operand(0) == while_body->parameter_instruction(0) && - ShapeUtil::IsArray(instr->shape())) { - InsertOrDie(gte_set, instr); - } - } -} - static StatusOr TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); @@ -172,7 +153,13 @@ static StatusOr TryHoistingInvariantInstructionsFromWhileBody( // unhoisted_invariant_instructions -- they can be legally hoisted, but there // is no benefit to hoisting them unless something that uses it is also // hoisted. - GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions); + for (auto* instr : WhileUtil::GetInvariantGTEsForWhileBody(*while_body)) { + if (ShapeUtil::IsArray(instr->shape())) { + // TODO(b/79147885): We should try to generalize this to tuples for + // uniformity's sake, if nothing else. + InsertOrDie(&unhoisted_invariant_instructions, instr); + } + } if (unhoisted_invariant_instructions.empty()) { // There are no obviously loop invariant elements in the state being diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index bd079418432..ed20b36292a 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -244,4 +244,21 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { } return result; } + +/*static*/ std::vector WhileUtil::GetInvariantGTEsForWhileBody( + const HloComputation& while_body) { + std::vector result; + const HloInstruction::InstructionVector root_operands = + while_body.root_instruction()->operands(); + for (int i = 0; i < root_operands.size(); i++) { + HloInstruction* instr = root_operands[i]; + if (instr->opcode() == HloOpcode::kGetTupleElement && + instr->tuple_index() == i && + instr->operand(0) == while_body.parameter_instruction(0)) { + result.push_back(instr); + } + } + return result; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 1688d467426..322d27b88ca 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -74,6 +74,12 @@ class WhileUtil { HloComputation* computation, int32 trip_count, const LoopStateTy& init_values, const LoopBodyGeneratorTy& loop_body_generator); + + // Returns the GetTupleElement instructions in `while_body` that access + // elements in the parameter tuple that don't change across iterations. + // Assumes `while_body` is the body computation of the while loop in question. 
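To make "invariant GTE" concrete, a hedged sketch with invented instruction names (the new while_util_test further down exercises the same property):

  // body {
  //   p      = (s32[], s32[]) parameter(0)
  //   gte.0  = s32[] get-tuple-element(p), index=0
  //   gte.1  = s32[] get-tuple-element(p), index=1
  //   add    = s32[] add(gte.0, gte.1)
  //   ROOT t = (s32[], s32[]) tuple(gte.0, add)
  // }
  //
  // gte.0 is invariant: it reads index 0 of the parameter and is passed back
  // unchanged at index 0 of the root tuple. gte.1 is not invariant, because
  // index 1 of the root tuple is overwritten with `add`.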
+ static std::vector GetInvariantGTEsForWhileBody( + const HloComputation& while_body); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index cf0d0db99bd..974bc542a34 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -126,5 +126,42 @@ TEST(WhileUtilTest, MakeTwoInstructionsLive) { op::GetTupleElement(op::Parameter(0), 3))); } +TEST(WhileUtilTest, GetInvariantGTEsForWhileBody) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT constant = pred[] constant(true) +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + ROOT while = (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(hlo_string)); + + HloComputation* while_body = module->GetComputationWithName("body"); + + ASSERT_NE(while_body, nullptr) + << "Expected exactly one while_body computation"; + + std::vector gte_list = + WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + + ASSERT_EQ(gte_list.size(), 1); + EXPECT_EQ((*gte_list.begin())->name(), "gte.0"); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc index 4f8cdc1e0e7..f5331280ee9 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -46,9 +45,9 @@ class ZeroSizedHloEliminationTest : public HloTestBase { 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {} StatusOr RunZeroSizedElimination() { - HloModule module("zero_sized_elimination_test_module"); - module.AddEntryComputation(builder_.Build()); - return ZeroSizedHloElimination{}.Run(&module); + auto module = CreateNewModule("zero_sized_elimination_test_module"); + module->AddEntryComputation(builder_.Build()); + return ZeroSizedHloElimination{}.Run(module.get()); } HloComputation::Builder builder_; diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 5b44c26b7c7..141347a792c 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -16,8 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_ +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -31,99 +32,93 @@ class ServiceInterface { virtual ~ServiceInterface() = default; // TODO(b/31824348): Convert to use StatusOr. 
- virtual tensorflow::Status TransferToClient( - const TransferToClientRequest* arg, TransferToClientResponse* result) = 0; + virtual Status TransferToClient(const TransferToClientRequest* arg, + TransferToClientResponse* result) = 0; - virtual tensorflow::Status TransferToServer( - const TransferToServerRequest* arg, TransferToServerResponse* result) = 0; + virtual Status TransferToServer(const TransferToServerRequest* arg, + TransferToServerResponse* result) = 0; - virtual tensorflow::Status TransferToInfeed( - const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0; + virtual Status TransferToInfeed(const TransferToInfeedRequest* arg, + TransferToInfeedResponse* result) = 0; - virtual tensorflow::Status TransferFromOutfeed( - const TransferFromOutfeedRequest* arg, - TransferFromOutfeedResponse* result) = 0; + virtual Status TransferFromOutfeed(const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) = 0; - virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, - ResetDeviceResponse* result) = 0; + virtual Status ResetDevice(const ResetDeviceRequest* arg, + ResetDeviceResponse* result) = 0; - virtual tensorflow::Status LoadComputationSnapshot( + virtual Status LoadComputationSnapshot( const LoadComputationSnapshotRequest* request, LoadComputationSnapshotResponse* result) = 0; - virtual tensorflow::Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) = 0; + virtual Status Execute(const ExecuteRequest* arg, + ExecuteResponse* result) = 0; - virtual tensorflow::Status ExecuteGraph(const ExecuteGraphRequest* arg, - ExecuteResponse* result) = 0; + virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, + ExecuteResponse* result) = 0; - virtual tensorflow::Status ExecuteParallel( - const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0; + virtual Status ExecuteParallel(const ExecuteParallelRequest* arg, + ExecuteParallelResponse* result) = 0; - virtual tensorflow::Status ExecuteGraphParallel( - const ExecuteGraphParallelRequest* arg, - ExecuteParallelResponse* result) = 0; + virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, + ExecuteParallelResponse* result) = 0; - virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) = 0; + virtual Status ExecuteAsync(const ExecuteAsyncRequest* arg, + ExecuteAsyncResponse* result) = 0; - virtual tensorflow::Status WaitForExecution( - const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0; + virtual Status WaitForExecution(const WaitForExecutionRequest* arg, + WaitForExecutionResponse* result) = 0; - virtual tensorflow::Status DeconstructTuple( - const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0; + virtual Status DeconstructTuple(const DeconstructTupleRequest* arg, + DeconstructTupleResponse* result) = 0; - virtual tensorflow::Status GetComputationStats( - const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0; + virtual Status GetComputationStats(const ComputationStatsRequest* arg, + ComputationStatsResponse* result) = 0; - virtual tensorflow::Status GetComputationGraphStats( + virtual Status GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) = 0; - virtual tensorflow::Status GetComputationShape( - const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) = 0; + virtual Status GetComputationShape(const GetComputationShapeRequest* arg, + GetComputationShapeResponse* 
result) = 0; - virtual tensorflow::Status GetShape(const GetShapeRequest* arg, - GetShapeResponse* result) = 0; + virtual Status GetShape(const GetShapeRequest* arg, + GetShapeResponse* result) = 0; - virtual tensorflow::Status CreateChannelHandle( - const CreateChannelHandleRequest* arg, - CreateChannelHandleResponse* result) = 0; + virtual Status CreateChannelHandle(const CreateChannelHandleRequest* arg, + CreateChannelHandleResponse* result) = 0; - virtual tensorflow::Status GetDeviceHandles( - const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0; + virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, + GetDeviceHandlesResponse* result) = 0; // Methods used by ComputationBuilder. - virtual tensorflow::Status Computation(const ComputationRequest* arg, - ComputationResponse* result) = 0; + virtual Status Computation(const ComputationRequest* arg, + ComputationResponse* result) = 0; - virtual tensorflow::Status Op(const OpRequest* arg, OpResponse* result) = 0; + virtual Status Op(const OpRequest* arg, OpResponse* result) = 0; - virtual tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) = 0; + virtual Status GetLocalShape(const GetLocalShapeRequest* arg, + GetLocalShapeResponse* result) = 0; - virtual tensorflow::Status SetReturnValue( - const SetReturnValueRequest* arg, SetReturnValueResponse* results) = 0; + virtual Status SetReturnValue(const SetReturnValueRequest* arg, + SetReturnValueResponse* results) = 0; - virtual tensorflow::Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) = 0; + virtual Status IsConstant(const IsConstantRequest* arg, + IsConstantResponse* result) = 0; - virtual tensorflow::Status ComputeConstant( - const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0; + virtual Status ComputeConstant(const ComputeConstantRequest* arg, + ComputeConstantResponse* result) = 0; - virtual tensorflow::Status ComputeConstantGraph( - const ComputeConstantGraphRequest* arg, - ComputeConstantResponse* result) = 0; + virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, + ComputeConstantResponse* result) = 0; // Methods used by Computation. - virtual tensorflow::Status SnapshotComputation( - const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) = 0; + virtual Status SnapshotComputation(const SnapshotComputationRequest* ag, + SnapshotComputationResponse* result) = 0; // Methods used by GlobalData. - virtual tensorflow::Status Unregister(const UnregisterRequest* arg, - UnregisterResponse* result) = 0; + virtual Status Unregister(const UnregisterRequest* arg, + UnregisterResponse* result) = 0; }; } // namespace xla diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 789eba5780d..7ee366b27a8 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -22,24 +22,24 @@ limitations under the License. 
namespace xla { -tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { +Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { if (!ShapeUtil::Compatible(other_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(other_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } shape_ = other_shape; - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { +Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } *to_shape = shape_; - return tensorflow::Status::OK(); + return Status::OK(); } void ShapeLayout::SetToDefaultLayout() { diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 4c83750f3e6..36806da599c 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -40,7 +40,7 @@ class ShapeLayout { // Assigns the layouts in this ShapeLayout to the Layout fields of the given // shape. 'to_shape' and the shape of the ShapeLayout object must be // compatible. - tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; + Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible @@ -48,9 +48,8 @@ class ShapeLayout { bool MatchesLayoutInShape(const Shape& shape) const; // Copies the layout from the given shape into this ShapeLayout. 'other_shape' - // must be compatible with the ShapeLayout's shape, and 'other_shape' must - // have a layout (LayoutUtil::HasLayout). - tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); + // must be compatible with the ShapeLayout's shape. + Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ac7e201bfdc..7a897f6f8f9 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -32,6 +32,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -510,7 +511,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { break; } else if (must_end) { return InvalidArgument("Expected end of tuple; got: \"%s\"", - s->ToString().c_str()); + std::string(*s).c_str()); } shapes.emplace_back(); TF_ASSIGN_OR_RETURN(shapes.back(), ParseShapeStringInternal(s)); @@ -540,7 +541,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { return InvalidArgument( "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - input.c_str(), s->ToString().c_str()); + input.c_str(), std::string(*s).c_str()); } return element; }; @@ -593,7 +594,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } return InvalidArgument("Invalid shape string to parse: \"%s\"", - s->ToString().c_str()); + std::string(*s).c_str()); } } // namespace @@ -602,7 +603,7 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); if (!s.empty()) { return InvalidArgument("Invalid shape string to parse: \"%s\"", - s.ToString().c_str()); + std::string(s).c_str()); } return shape; } @@ -905,10 +906,17 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { std::is_permutation(minor_to_major.begin(), minor_to_major.end(), dims.begin())); } - Shape stripped_shape = - shape.has_layout() ? MakeShapeWithLayout(shape.element_type(), - dimension_sizes, minor_to_major) - : MakeShape(shape.element_type(), dimension_sizes); + Shape stripped_shape; + if (LayoutUtil::IsDenseArray(shape)) { + stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes, + minor_to_major); + } else if (LayoutUtil::IsSparseArray(shape)) { + stripped_shape = + MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes, + shape.layout().max_sparse_elements()); + } else { + stripped_shape = MakeShape(shape.element_type(), dimension_sizes); + } VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape); VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape); @@ -1465,4 +1473,26 @@ std::ostream& operator<<(std::ostream& out, const Shape& shape) { return out; } +/*static*/ size_t ShapeUtil::Hash(const Shape& shape) { + using tensorflow::hash; + using tensorflow::Hash64Combine; + + size_t hash_value = hash()(shape.element_type()); + + if (shape.tuple_shapes().empty()) { + for (int64 dim : shape.dimensions()) { + hash_value = Hash64Combine(hash_value, hash()(dim)); + } + + hash_value = Hash64Combine(hash_value, LayoutUtil::Hash(shape.layout())); + } else { + hash_value = 0; + for (const Shape& subshape : shape.tuple_shapes()) { + hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(subshape)); + } + } + + return hash_value; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 63da9154cfc..82c75f85d83 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -31,6 +31,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -230,7 +231,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that that they have the same element type + // point types; otherwise, checks that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { @@ -649,6 +650,9 @@ class ShapeUtil { .ok()); } + // Compute a hash for `shape`. + static size_t Hash(const Shape& shape); + private: // Validates all of the non-layout properties of the shape -- this is a helper // used by both the layout-optional and layout-required public method. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 13582a2a267..f7675e97da7 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -713,6 +713,16 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } +TEST(ShapeUtilTest, StripDegenerateDimensions) { + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions( + ShapeUtil::MakeShape(F32, {3, 1, 2})), + ShapeUtil::MakeShape(F32, {3, 2}))); + EXPECT_TRUE(ShapeUtil::Equal( + ShapeUtil::StripDegenerateDimensions( + ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)), + ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10))); +} + TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h index 4eb3bf37664..69abb51852a 100644 --- a/tensorflow/compiler/xla/status.h +++ b/tensorflow/compiler/xla/status.h @@ -21,7 +21,7 @@ limitations under the License. namespace xla { -using tensorflow::Status; +using tensorflow::Status; // TENSORFLOW_STATUS_OK } // namespace xla diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h index 641b5e9a6ac..0e1387c9393 100644 --- a/tensorflow/compiler/xla/statusor.h +++ b/tensorflow/compiler/xla/statusor.h @@ -13,13 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// StatusOr is the union of a Status object and a T -// object. StatusOr models the concept of an object that is either a -// usable value, or an error Status explaining why such a value is -// not present. To this end, StatusOr does not allow its Status -// value to be Status::OK. Furthermore, the value of a StatusOr -// must not be null. This is enforced by a debug check in most cases, -// but even when it is not, clients must not set the value to null. +// StatusOr is the union of a Status object and a T object. StatusOr models +// the concept of an object that is either a value, or an error Status +// explaining why such a value is not present. To this end, StatusOr does not +// allow its Status value to be Status::OK. // // The primary use-case for StatusOr is as the return value of a // function which may fail. 
@@ -113,17 +110,19 @@ class StatusOr : private internal_statusor::StatusOrData, StatusOr& operator=(StatusOr&&) = default; // Conversion copy/move constructor, T must be convertible from U. - // TODO(b/62186717): These should not participate in overload resolution if U - // is not convertible to T. - template + template ::value>::type* = nullptr> StatusOr(const StatusOr& other); - template + template ::value>::type* = nullptr> StatusOr(StatusOr&& other); // Conversion copy/move assignment operator, T must be convertible from U. - template + template ::value>::type* = nullptr> StatusOr& operator=(const StatusOr& other); - template + template ::value>::type* = nullptr> StatusOr& operator=(StatusOr&& other); // Constructs a new StatusOr with the given value. After calling this @@ -233,12 +232,14 @@ StatusOr& StatusOr::operator=(Status&& status) { } template -template +template ::value>::type*> inline StatusOr::StatusOr(const StatusOr& other) : Base(static_cast::Base&>(other)) {} template -template +template ::value>::type*> inline StatusOr& StatusOr::operator=(const StatusOr& other) { if (other.ok()) this->Assign(other.ValueOrDie()); @@ -248,12 +249,14 @@ inline StatusOr& StatusOr::operator=(const StatusOr& other) { } template -template +template ::value>::type*> inline StatusOr::StatusOr(StatusOr&& other) : Base(static_cast::Base&&>(other)) {} template -template +template ::value>::type*> inline StatusOr& StatusOr::operator=(StatusOr&& other) { if (other.ok()) { this->Assign(std::move(other).ValueOrDie()); diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index f9d25945bc6..377a618ffbd 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -75,6 +75,14 @@ TEST(StatusOr, ElementType) { static_assert(std::is_same::element_type, char>(), ""); } +TEST(StatusOr, NullPointerStatusOr) { + // As a very special case, null-plain-pointer StatusOr used to be an + // error. Test that it no longer is. + StatusOr null_status(nullptr); + EXPECT_TRUE(null_status.ok()); + EXPECT_EQ(null_status.ValueOrDie(), nullptr); +} + TEST(StatusOr, TestNoDefaultConstructorInitialization) { // Explicitly initialize it with an error code. StatusOr statusor(tensorflow::errors::Cancelled("")); @@ -405,7 +413,7 @@ TEST(StatusOr, TestPointerValueConst) { EXPECT_EQ(&kI, thing.ValueOrDie()); } -// NOTE(tucker): tensorflow::StatusOr does not support this kind +// NOTE(tucker): StatusOr does not support this kind // of resize op. // TEST(StatusOr, StatusOrVectorOfUniquePointerCanResize) { // using EvilType = std::vector>; diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 17bae2e4f61..8918350135f 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -40,13 +40,10 @@ class Literal; namespace testing { namespace internal_status { -inline const ::tensorflow::Status& GetStatus( - const ::tensorflow::Status& status) { - return status; -} +inline const Status& GetStatus(const Status& status) { return status; } template -inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { +inline const Status& GetStatus(const StatusOr& status) { return status.status(); } } // namespace internal_status @@ -57,21 +54,17 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { // The following macros are similar to macros in gmock, but deliberately named // differently in order to avoid conflicts in files which include both. 
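A hedged usage sketch of the macros rewritten just below; the test name is invented, and ShapeUtil::ParseShapeString is used only because it is a convenient StatusOr-returning function touched elsewhere in this same diff:

  TEST(ExampleTest, StatusMacros) {
    // GetStatus() lets the macros accept both Status and StatusOr<T>.
    StatusOr<Shape> shape = ShapeUtil::ParseShapeString("f32[2,3]");
    ASSERT_IS_OK(shape);
    EXPECT_IS_NOT_OK(ShapeUtil::ParseShapeString("not a shape"));
  }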
-// Macros for testing the results of functions that return tensorflow::Status or +// Macros for testing the results of functions that return Status or // StatusOr (for any type T). -#define EXPECT_IS_OK(expression) \ - EXPECT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) -#define EXPECT_IS_NOT_OK(expression) \ - EXPECT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_OK(expression) \ + EXPECT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_NOT_OK(expression) \ + EXPECT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_OK -#define ASSERT_IS_OK(expression) \ - ASSERT_EQ(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_OK(expression) \ + ASSERT_EQ(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_NOT_OK -#define ASSERT_IS_NOT_OK(expression) \ - ASSERT_NE(tensorflow::Status::OK(), \ - xla::testing::internal_status::GetStatus(expression)) +#define ASSERT_IS_NOT_OK(expression) \ + ASSERT_NE(Status::OK(), xla::testing::internal_status::GetStatus(expression)) #endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 1f90a44d8ba..95acfe59edf 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -87,12 +87,12 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla:literal_comparison", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -152,7 +152,8 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", @@ -186,11 +187,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -257,8 +257,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", 
"//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", @@ -285,9 +285,9 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -308,9 +308,10 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -328,8 +329,9 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -369,8 +371,9 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:test_utils", @@ -387,9 +390,9 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -410,8 +413,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -438,10 +439,10 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", 
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -458,9 +459,10 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", @@ -475,9 +477,10 @@ xla_test( ], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -510,10 +513,11 @@ xla_test( tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -532,9 +536,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -550,10 +555,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -574,10 +579,10 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", 
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -600,11 +605,12 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -627,6 +633,7 @@ xla_test( deps = [ ":client_library_test_base", ":literal_test_util", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -668,8 +675,9 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -702,9 +710,6 @@ xla_test( "cpu": [ "--xla_cpu_multi_thread_eigen=false", ], - "cpu_parallel": [ - "--xla_cpu_multi_thread_eigen=false", - ], }, shard_count = 20, tags = ["optonly"], @@ -713,8 +718,9 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -736,8 +742,9 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -758,8 +765,9 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", 
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -801,7 +809,6 @@ xla_test( backend_tags = { # TODO(b/31436974): Fix msan failure. Failed on 2016-09-12. "cpu": ["nomsan"], - "cpu_parallel": ["nomsan"], }, shard_count = 30, deps = [ @@ -810,9 +817,10 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -833,9 +841,10 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -893,11 +902,11 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -921,8 +930,8 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -961,8 +970,9 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1012,8 +1022,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", 
"//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1035,9 +1043,10 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1063,8 +1072,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", @@ -1158,6 +1165,7 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1194,8 +1202,8 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1233,8 +1241,9 @@ xla_test( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1254,8 +1263,9 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1292,8 +1302,9 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", 
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1308,8 +1319,9 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1332,9 +1344,9 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1353,8 +1365,9 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1423,11 +1436,11 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1470,8 +1483,9 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1510,10 +1524,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", 
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1530,8 +1544,9 @@ xla_test( "enable_for_xla_interpreter", ], deps = [ - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1552,8 +1567,6 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", @@ -1574,7 +1587,6 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", @@ -1593,8 +1605,9 @@ xla_test( srcs = ["execution_profile_test.cc"], deps = [ ":client_library_test_base", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1606,8 +1619,9 @@ xla_test( args = ["--xla_hlo_profile"], deps = [ ":client_library_test_base", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], @@ -1625,11 +1639,11 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1710,9 +1724,9 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", 
"//tensorflow/compiler/xla/service:platform_util", @@ -1737,9 +1751,9 @@ xla_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", @@ -1774,9 +1788,9 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1789,7 +1803,10 @@ xla_test( xla_test( name = "local_client_execute_test", + # TODO(b/79375911): Test times out in LLVM at normal size. + size = "large", srcs = ["local_client_execute_test.cc"], + shard_count = 30, tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:literal_util", @@ -1799,9 +1816,9 @@ xla_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", @@ -1853,22 +1870,6 @@ xla_test( ], ) -xla_test( - name = "set_return_value_test", - srcs = ["set_return_value_test.cc"], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - xla_test( name = "reshape_motion_test", srcs = ["reshape_motion_test.cc"], @@ -1882,10 +1883,10 @@ xla_test( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:computation", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1898,6 +1899,7 @@ xla_test( name = "deep_graph_test", srcs = ["deep_graph_test.cc"], deps = [ + 
"//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -1981,7 +1983,8 @@ xla_test( ":local_client_test_base", ":test_utils", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/client/xla_client:xla_computation", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 03c91745b97..36a70649691 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" @@ -214,7 +213,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { } XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector lhs{0xFFFFFFFF, static_cast(-1), @@ -255,7 +254,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { } XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector lhs{static_cast(0x8000000000000000LL), static_cast(0x8000000000000000LL), @@ -1332,7 +1331,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) { // Some Pow cases that can be implemented more efficiently. 
XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1360,7 +1359,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { } XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1385,7 +1384,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) { } XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1410,7 +1409,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) { } XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1435,7 +1434,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) { } XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f}; std::vector values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1460,7 +1459,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) { } XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1492,7 +1491,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) { } XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -1525,7 +1524,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) { } XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f}; @@ -1558,7 +1557,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) { } XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f}; std::vector values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; @@ -2217,6 +2216,24 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, ClzU32s) { + XlaBuilder builder(TestName()); + auto a = builder.ConstantR1( + {0, 1, 0x10, 0x10000, 0x700000, 0x12345678, 0xF2345678}); + builder.Clz(a); + + ComputeAndCompareR1(&builder, {32, 31, 27, 15, 9, 3, 0}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ClzS64s) { + XlaBuilder builder(TestName()); + auto a = + builder.ConstantR1({0, 1, 0x80000000, 0x7FFFFFFFF2345678ul, -1}); + builder.Clz(a); + + ComputeAndCompareR1(&builder, {64, 63, 32, 1, 0}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // a ------ (add) --------- (add) // / / @@ -2348,7 +2365,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, 
Add1DTo2DF32) { XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { // Test broadcasting in Eq comparison. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({42, 73}); auto m = builder.ConstantR2({{42, 73}, {42, 52}}); @@ -2774,7 +2791,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { // Regression test for b/31927799. "slice - y" is fused and requires implicit // broadcast. XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x_literal = Literal::CreateR1({1, 2, 3}); auto y_literal = Literal::CreateR1({4, 5}); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index ec3b46acfec..fcd9ff55e39 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -42,7 +41,7 @@ TEST_F(AxpySimpleTest, AxTenValues) { } XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { - ComputationBuilder builder(client_, "axpy_10"); + XlaBuilder builder("axpy_10"); auto alpha = builder.ConstantR0(3.1415926535); auto x = builder.ConstantR1({}); auto y = builder.ConstantR1({}); @@ -54,7 +53,7 @@ XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) { } TEST_F(AxpySimpleTest, AxpyTenValues) { - ComputationBuilder builder(client_, "axpy_10"); + XlaBuilder builder("axpy_10"); auto alpha = builder.ConstantR0(3.1415926535); auto x = builder.ConstantR1( {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index e4bf1827acf..22c3394e6f3 100644 --- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -18,9 +18,9 @@ limitations under the License. 
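A note on the recurring pattern in the test hunks above: ComputationBuilder was constructed from the client plus a name, while XlaBuilder takes only a name and records the graph locally; the ClientLibraryTestBase helpers (ComputeAndCompareR0/R1, ExecuteAndTransfer) still run the built computation against the test backend, so the assertions are untouched. The newly added ClzU32s/ClzS64s cases exercise a count-leading-zeros op whose expected values are the number of zero bits above the highest set bit (32 for a zero uint32, 31 for 1, 0 once the top bit is set). Below is a minimal sketch of a migrated test, assuming the usual includes and fixture style of these files; the fixture and test names are invented for illustration.

#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

class BuilderMigrationTest : public ClientLibraryTestBase {};

XLA_TEST_F(BuilderMigrationTest, AddTwoScalars) {
  // Before: ComputationBuilder b(client_, TestName());
  // After: the builder no longer needs the client handle.
  XlaBuilder b(TestName());
  auto x = b.ConstantR0<float>(2.0f);
  auto y = b.ConstantR0<float>(1.0f);
  b.Add(x, y);
  // Builds the XlaComputation and executes it via the fixture's client.
  ComputeAndCompareR0<float>(&b, 3.0f, {}, ErrorSpec(0.0001));
}

}  // namespace
}  // namespace xla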
#include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -34,13 +34,13 @@ namespace { class BadRngShapeValidationTest : public ClientLibraryTestBase {}; TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR0(0.0); auto one = builder.ConstantR0(1.0); Shape default_constructed; builder.RngUniform(zero, one, default_constructed); - StatusOr computation = builder.Build(); + StatusOr computation = builder.Build(); EXPECT_FALSE(computation.ok()); LOG(INFO) << "status received: " << computation.status(); EXPECT_THAT(computation.status().error_message(), @@ -48,7 +48,7 @@ TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) { } TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR0(0.0); auto one = builder.ConstantR0(1.0); Shape sans_layout; @@ -57,7 +57,7 @@ TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { builder.RngUniform(zero, one, sans_layout); - StatusOr computation = builder.Build(); + StatusOr computation = builder.Build(); ASSERT_TRUE(computation.ok()); LOG(INFO) << computation.status(); } diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc index b853dfaa15d..ca337e78840 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -19,10 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -38,7 +37,6 @@ limitations under the License. 
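The bad_rng_shape_validation_test.cc hunks above show the other half of the rename: XlaBuilder::Build() returns StatusOr<XlaComputation> rather than StatusOr<Computation>, so failure paths are still surfaced through the StatusOr. A condensed sketch of that check, with element-type template arguments written out and assuming it sits inside a ClientLibraryTestBase test body:

XlaBuilder builder("bad_rng_shape");
auto zero = builder.ConstantR0<float>(0.0f);
auto one = builder.ConstantR0<float>(1.0f);
Shape default_constructed;  // no element type, dimensions, or layout set
builder.RngUniform(zero, one, default_constructed);
// Build() comes back non-ok because the requested RNG result shape is invalid.
StatusOr<XlaComputation> computation = builder.Build();
EXPECT_FALSE(computation.ok());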
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -52,7 +50,7 @@ class Bfloat16Test : public ClientLibraryTestBase { }; XLA_TEST_F(Bfloat16Test, ScalarOperation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR0(static_cast(2.0f)); auto y = builder.ConstantR0(static_cast(1.0f)); builder.Add(x, y); @@ -62,7 +60,7 @@ XLA_TEST_F(Bfloat16Test, ScalarOperation) { } XLA_TEST_F(Bfloat16Test, LogOperation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR0(static_cast(4.0f)); builder.Log(x); @@ -71,7 +69,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) { } XLA_TEST_F(Bfloat16Test, NegateScalarF16) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Neg(builder.ConstantR0(static_cast(2.1f))); ComputeAndCompareR0(&builder, static_cast(-2.1f), {}, @@ -80,7 +78,7 @@ XLA_TEST_F(Bfloat16Test, NegateScalarF16) { XLA_TEST_F(Bfloat16Test, BatchNormTraining) { const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR4FromArray4D( {{{{static_cast(1.f)}, {static_cast(2.f)}}, @@ -117,7 +115,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { XLA_TEST_F(Bfloat16Test, BatchNormGrad) { const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR4FromArray4D( Array4D(2, 2, 2, 1, static_cast(0.0f))); diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index 97fec89b63f..48203b1d40e 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -15,8 +15,8 @@ limitations under the License. 
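For bfloat16_test.cc the change is again mechanical; spelled out with explicit element-type template arguments, the scalar case reduces to the sketch below (the test name is invented, and the loose tolerance reflects bfloat16's 8-bit significand):

XLA_TEST_F(Bfloat16Test, ScalarAddSketch) {
  XlaBuilder builder(TestName());
  auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f));
  auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f));
  builder.Add(x, y);
  // bfloat16 keeps only ~2-3 decimal digits, hence the 1e-3 error spec.
  ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(3.0f), {},
                                ErrorSpec(0.001));
}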
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -32,7 +32,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) { auto alhs = MakeLinspaceArray2D(0.0, 1.0, 32, 4); auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 4); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR2FromArray2D(*alhs); auto rhs = builder.ConstantR2FromArray2D(*arhs); builder.Add(lhs, rhs); @@ -48,7 +48,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_129x129) { auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 129); auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 129); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR2FromArray2D(*alhs); auto rhs = builder.ConstantR2FromArray2D(*arhs); builder.Add(lhs, rhs); @@ -64,7 +64,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_9x5) { auto alhs = MakeLinspaceArray2D(0.0, 1.0, 9, 5); auto arhs = MakeLinspaceArray2D(0.0, 1.0, 9, 1); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR2FromArray2D(*alhs); auto rhs = builder.ConstantR2FromArray2D(*arhs); builder.Add(lhs, rhs); @@ -80,7 +80,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) { auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 257); auto arhs = MakeLinspaceArray2D(0.0, 1.0, 129, 1); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR2FromArray2D(*alhs); auto rhs = builder.ConstantR2FromArray2D(*arhs); builder.Add(lhs, rhs); @@ -93,7 +93,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) { } TEST_F(BinopScalingTest, R0PlusR2F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR0(42.0); auto rhs = builder.ConstantR2({ {1.0, 2.0}, {3.0, 4.0}, @@ -109,7 +109,7 @@ TEST_F(BinopScalingTest, R0PlusR2F32) { } TEST_F(BinopScalingTest, R4PlusR0S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // clang-format off Array4D lhs_array({ {{{1, 2}, diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc index 777ac167a3c..bff60f25ec8 100644 --- a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -34,7 +34,7 @@ namespace { class BitcastConvertTest : public ClientLibraryTestBase { public: - explicit BitcastConvertTest(perftools::gputools::Platform* platform = nullptr) + explicit BitcastConvertTest(se::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 97095f1cc42..34c86e007be 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -19,8 +19,8 
@@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -33,10 +33,8 @@ namespace { class BroadcastSimpleTest : public ClientLibraryTestBase { public: - ComputationDataHandle BuildBinOp(HloOpcode op, - const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs, - ComputationBuilder* builder) { + XlaOp BuildBinOp(HloOpcode op, const XlaOp& lhs, const XlaOp& rhs, + XlaBuilder* builder) { switch (op) { case HloOpcode::kMinimum: { return builder->Min(lhs, rhs); @@ -105,21 +103,21 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { using ::testing::HasSubstr; XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR0(1.5), {}); ComputeAndCompareR0(&b, 1.5, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR0(2.25), {2, 3}); Array2D expected(2, 3, 2.25); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { - ComputationBuilder b(client_, TestName()); - ComputationDataHandle src; + XlaBuilder b(TestName()); + XlaOp src; std::unique_ptr param_data = CreateR0Parameter(2.25f, /*parameter_number=*/0, /*name=*/"src", /*builder=*/&b, /*data_handle=*/&src); @@ -131,21 +129,21 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) { } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR0(2.25), {2, 0}); Array2D expected(2, 0); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR0(2.25), {0, 2}); Array2D expected(0, 2); ComputeAndCompareR2(&b, expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR1({1, 2, 3}), {2}); Array2D expected(2, 3); @@ -160,7 +158,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { // Tests implicit broadcasting of PREDs. 
XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); Array2D x_vals(2, 1); x_vals(0, 0) = true; @@ -171,7 +169,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { y_vals(1, 0, 0) = true; y_vals(1, 1, 0) = true; - ComputationDataHandle x, y; + XlaOp x, y; auto x_data = CreateR2Parameter(x_vals, 0, "x", &b, &x); auto y_data = CreateR3Parameter(y_vals, 1, "y", &b, &y); b.And(x, y, /*broadcast_dimensions=*/{1, 2}); @@ -186,7 +184,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { } XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR1({}), {2}); Array2D expected(2, 0); @@ -194,7 +192,7 @@ XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) { } XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Broadcast(b.ConstantR1({1, 2, 3}), {0}); Array2D expected(0, 3); @@ -209,7 +207,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { // broadcasting (broadcast_dimensions {1, 2}), then is added to the rhs shape // [2, 3, 1]. Degenerate dimension broadcasting then broadcasts the size one // dimensions. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Add(b.ConstantR2({{1.0, 5.0}}), b.ConstantLiteral(*Literal::CreateR3( @@ -247,7 +245,7 @@ class BroadcastR3ImplicitTest XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { const R3ImplicitBroadcastSpec& spec = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape r3_shape, r3_implicit_shape; Array3D r3_array(spec.output_bounds[0], spec.output_bounds[1], @@ -264,8 +262,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input"); auto r3_parameter = builder.Parameter(1, r3_shape, "input"); - ComputationDataHandle op = - BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); + XlaOp op = BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder); Array3D expected_array(spec.output_bounds[0], spec.output_bounds[1], spec.output_bounds[2]); @@ -300,9 +297,9 @@ INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances, // r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1: XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { - ComputationBuilder b(client_, TestName()); - ComputationDataHandle r1h; - ComputationDataHandle r3h; + XlaBuilder b(TestName()); + XlaOp r1h; + XlaOp r3h; Array3D r1d = {{{1}}, {{2}}}; Array3D r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}; @@ -319,7 +316,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}})); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -332,7 +329,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}})); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -345,7 +342,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { - 
ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -358,7 +355,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -371,7 +368,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = b.ConstantLiteral( @@ -385,7 +382,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}}})); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -491,7 +488,7 @@ class BroadcastR2ImplicitTest XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { const R2ImplicitBroadcastSpec& spec = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Operands with degenerate dimensions require implicit broadcasting: Shape r2_shape, r2_implicit_shape1, r2_implicit_shape2; @@ -517,10 +514,9 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { auto r2_implicit_parameter2 = builder.Parameter(2, r2_implicit_shape2, "input2"); - ComputationDataHandle op1 = + XlaOp op1 = BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder); - ComputationDataHandle op2 = - BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); + XlaOp op2 = BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder); Array2D expected_array(spec.output_bounds[0], spec.output_bounds[1]); @@ -547,7 +543,7 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, ::testing::ValuesIn(kR2ImplicitBroadcastTestCases)); XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}})); auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); b.Add(r2, r1); @@ -558,7 +554,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1}, {2}})); auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); b.Add(r2, r1); @@ -569,7 +565,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -582,7 +578,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -595,7 +591,7 @@ 
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); @@ -608,7 +604,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1_0 = b.ConstantR1({1000, 2000}); auto r1_1 = b.ConstantR1({100, 200}); auto r1_2 = b.ConstantR1({10, 20}); @@ -629,7 +625,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { } XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto r1_0 = b.ConstantR1({1000, 2000}); auto r1_1 = b.ConstantR1({100, 200}); auto r1_2 = b.ConstantR1({10, 20}); @@ -652,7 +648,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2]) // results in a shape incompatible with the lhs [2, 3, 1]. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Add(b.ConstantR2({{1.0, 5.0}, {1.0, 5.0}}), b.ConstantLiteral(*Literal::CreateR3( @@ -667,7 +663,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { // Test invalid broadcasting with [1, 2] and [2, 3] inputs. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Add(b.ConstantR2({{1.0, 2.0}}), b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); @@ -680,7 +676,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { // Test invalid broadcasting with [1, 2] and [2, 3] inputs. 
- ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Add(b.ConstantR2({{1.0, 2.0}}), b.ConstantR2({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 6ebbf719183..51b9f0d3e33 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR0(42.0), *result, - error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0(42.0), *result, + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { @@ -62,9 +62,9 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, - error_spec_); + error_spec_)); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { @@ -85,13 +85,13 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - LiteralView::Create(*result, {0}), error_spec_); + LiteralSlice(*result, {0}), error_spec_)); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - LiteralView::Create(*result, {1}), error_spec_); + LiteralSlice(*result, {1}), error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { @@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, - error_spec_); + EXPECT_TRUE( + LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { @@ -125,9 +125,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, - error_spec_); + EXPECT_TRUE( + LiteralTestUtil::Near(*Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), + *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { @@ -142,10 +142,10 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), - *result, error_spec_); + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { @@ -166,8 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } 
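In broadcast_test.cc the removed ExpectNear/ExpectEqual helpers give way to LiteralTestUtil::Near/Equal, whose result the call site asserts with EXPECT_TRUE (an extra message can still be streamed onto the expectation), and LiteralView::Create(*result, {i}) becomes LiteralSlice(*result, {i}) for addressing tuple elements. A small sketch of the new call shape, assuming result and error_spec_ as in the surrounding tests:

// The helper now reports success or failure to the caller instead of
// asserting internally, so the expectation and message live at the call site.
EXPECT_TRUE(LiteralTestUtil::Near(
    *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result, error_spec_))
    << "unexpected broadcast result";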
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -196,8 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -218,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result, - error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array), + *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { @@ -238,8 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { @@ -260,8 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -291,8 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR4FromArray4D(expected), *result, error_spec_)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index eac2eb286c3..53f2c3bfbfc 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -4,7 +4,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins") load("//tensorflow:tensorflow.bzl", "tf_cc_test") -all_backends = ["cpu", "cpu_parallel", "gpu"] + plugins.keys() +all_backends = ["cpu", "gpu"] + plugins.keys() def filter_backends(backends): """Removes "gpu" from a backend list if CUDA is not enabled. @@ -39,10 +39,10 @@ def xla_test(name, **kwargs): """Generates cc_test targets for the given XLA backends. - This rule generates a cc_test target for one or more XLA backends and also - a platform-agnostic cc_library rule. The arguments are identical to cc_test - with two additions: 'backends' and 'backend_args'. 'backends' specifies the - backends to generate tests for ("cpu", "cpu_parallel", "gpu"), and + This rule generates a cc_test target for one or more XLA backends and also a + platform-agnostic cc_library rule. The arguments are identical to cc_test with + two additions: 'backends' and 'backend_args'. 'backends' specifies the + backends to generate tests for ("cpu", "gpu"), and 'backend_args'/'backend_tags' specifies backend-specific args parameters to use when generating the cc_test. 
@@ -90,9 +90,9 @@ def xla_test(name, deps: Dependencies of the target. xla_test_library_deps: If set, the generated test targets will depend on the respective cc_libraries generated by the xla_test_library rule. - backends: A list of backends to generate tests for. Supported - values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will - be generated for all supported backends. + backends: A list of backends to generate tests for. Supported values: "cpu", + "gpu". If this list is empty, the test will be generated for all supported + backends. blacklisted_backends: A list of backends to NOT generate tests for. args: Test arguments for the target. tags: Tags for the target. @@ -128,10 +128,6 @@ def xla_test(name, if backend == "cpu": backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] - elif backend == "cpu_parallel": - backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] - backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] - this_backend_args += ["--xla_backend_extra_options=\"xla_cpu_parallel\""] elif backend == "gpu": backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] @@ -201,7 +197,7 @@ def xla_test_library(name, hdrs: Headers for the target. deps: Dependencies of the target. backends: A list of backends to generate libraries for. - Supported values: "cpu", "cpu_parallel", "gpu". If this list is empty, the + Supported values: "cpu", "gpu". If this list is empty, the library will be generated for all supported backends. """ @@ -210,7 +206,7 @@ def xla_test_library(name, for backend in filter_backends(backends): this_backend_copts = [] - if backend in ["cpu", "cpu_parallel", "gpu"]: + if backend in ["cpu", "gpu"]: backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend] elif backend in plugins: backend_deps = plugins[backend]["deps"] diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 5e42365ae38..5fd33b50c94 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -32,16 +32,16 @@ namespace { class CallOpTest : public ClientLibraryTestBase { protected: - Computation CreateR0F32IdentityComputation() { - ComputationBuilder builder(client_, "Identity"); + XlaComputation CreateR0F32IdentityComputation() { + XlaBuilder builder("Identity"); builder.Parameter(0, r0f32_, "x"); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); return build_status.ConsumeValueOrDie(); } - Computation CreateR1S0F32AdditionComputation() { - ComputationBuilder builder(client_, "Addition"); + XlaComputation CreateR1S0F32AdditionComputation() { + XlaBuilder builder("Addition"); auto x = builder.Parameter(0, r1s0f32_, "x"); auto y = builder.Parameter(1, r1s0f32_, "y"); builder.Add(x, y); @@ -50,8 +50,8 @@ class CallOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR1S2F32AdditionComputation() { - ComputationBuilder builder(client_, "Addition"); + XlaComputation CreateR1S2F32AdditionComputation() { + XlaBuilder builder("Addition"); auto x = builder.Parameter(0, r1s2f32_, "x"); auto y = builder.Parameter(1, r1s2f32_, "y"); builder.Add(x, y); @@ -60,8 +60,8 @@ class CallOpTest : public ClientLibraryTestBase { return build_status.ConsumeValueOrDie(); } - Computation CreateR0F32TupleComputation() { - ComputationBuilder builder(client_, "Tuple"); + XlaComputation CreateR0F32TupleComputation() { + XlaBuilder builder("Tuple"); builder.Tuple({builder.Parameter(0, r0f32_, "x")}); auto build_status = builder.Build(); EXPECT_IS_OK(build_status.status()); @@ -74,8 +74,8 @@ class CallOpTest : public ClientLibraryTestBase { }; XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR0F32IdentityComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR0F32IdentityComputation(); auto constant = builder.ConstantLiteral(*Literal::CreateR0(42.0)); builder.Call(callee, {constant}); @@ -83,8 +83,8 @@ XLA_TEST_F(CallOpTest, CallR0F32IdentityScalar) { } XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR1S0F32AdditionComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR1S0F32AdditionComputation(); auto x = builder.ConstantLiteral(*Literal::CreateR1({})); auto y = builder.ConstantLiteral(*Literal::CreateR1({})); builder.Call(callee, {x, y}); @@ -93,8 +93,8 @@ XLA_TEST_F(CallOpTest, CallR1S0F32AddArray) { } XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR1S2F32AdditionComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR1S2F32AdditionComputation(); auto x = builder.ConstantLiteral(*Literal::CreateR1({1.0f, 2.0f})); auto y = builder.ConstantLiteral(*Literal::CreateR1({2.0f, 3.0f})); builder.Call(callee, {x, y}); @@ -103,23 +103,23 @@ XLA_TEST_F(CallOpTest, CallR1S2F32AddArray) { } XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { - ComputationBuilder builder(client_, "inner"); + XlaBuilder 
builder("inner"); { auto x = builder.Parameter(0, r0f32_, "x"); builder.Add(x, builder.ConstantR0(1.0)); } - TF_ASSERT_OK_AND_ASSIGN(Computation inner, builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation inner, builder.Build()); - ComputationBuilder builder2(client_, "outer"); + XlaBuilder builder2("outer"); { auto x = builder2.Parameter(0, r0f32_, "x"); x = builder2.Call(inner, {x}); x = builder2.Call(inner, {x}); x = builder2.Call(inner, {x}); } - TF_ASSERT_OK_AND_ASSIGN(Computation outer, builder2.Build()); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation outer, builder2.Build()); - ComputationBuilder builder3(client_, "outermost"); + XlaBuilder builder3("outermost"); { auto x = builder3.Parameter(0, r0f32_, "x"); x = builder3.Call(outer, {x}); @@ -134,8 +134,8 @@ XLA_TEST_F(CallOpTest, CallTreeTwoDeepBranchFactorThree) { } XLA_TEST_F(CallOpTest, CallR0F32Tuple) { - ComputationBuilder builder(client_, TestName()); - Computation callee = CreateR0F32TupleComputation(); + XlaBuilder builder(TestName()); + XlaComputation callee = CreateR0F32TupleComputation(); auto elem = Literal::CreateR0(42.0); auto tuple = Literal::MakeTuple({elem.get()}); builder.Call(callee, {builder.ConstantLiteral(*elem)}); diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index f594cc10ac6..660ff0cad56 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -35,7 +35,7 @@ using ::testing::ContainsRegex; class CheckExecutionArityTest : public ClientLibraryTestBase {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { - ComputationBuilder builder(client_, "add_two_params"); + XlaBuilder builder("add_two_params"); auto param_literal = Literal::CreateR1({1.1f, 2.2f}); auto p0 = builder.Parameter(0, param_literal->shape(), "param0"); @@ -75,7 +75,7 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { } XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { - ComputationBuilder builder(client_, "add_two_params"); + XlaBuilder builder("add_two_params"); auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1"); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 312d8f284d3..bf8ed4d9fb0 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -32,8 +31,6 @@ limitations under the License. 
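The call_test.cc hunks above give the clearest picture of the Computation-to-XlaComputation rename: a callee is built with its own XlaBuilder, Build() hands back an XlaComputation, and a second builder composes it through Call. A minimal sketch of that nesting, assuming the fixture's scalar F32 shape r0f32_ and an invented test name:

XLA_TEST_F(CallOpTest, NestedCallSketch) {
  // Callee: x -> x + 1, built with its own builder.
  XlaBuilder inner_builder("inner");
  auto x = inner_builder.Parameter(0, r0f32_, "x");
  inner_builder.Add(x, inner_builder.ConstantR0<float>(1.0f));
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation inner, inner_builder.Build());

  // Caller: applies the callee to a constant operand.
  XlaBuilder builder(TestName());
  builder.Call(inner, {builder.ConstantR0<float>(41.0f)});
  ComputeAndCompareR0<float>(&builder, 42.0f, {}, ErrorSpec(0.0001));
}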
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -59,11 +56,15 @@ se::Platform* GetReferencePlatform() { } // namespace ClientLibraryTestBase::ClientLibraryTestBase( - perftools::gputools::Platform* platform, - const LocalClientOptions& client_options) + se::Platform* platform, const LocalClientOptions& client_options) : client_(GetOrCreateLocalClientOrDie(client_options)), execution_options_(CreateDefaultExecutionOptions()) { CHECK_EQ(platform, client_options.platform()); + + LocalClientOptions ref_options; + ref_options.set_platform(GetReferencePlatform()); + ref_client_ = GetOrCreateLocalClientOrDie(ref_options); + // Disabling constant_folding so that tests (usually written using Constants) // will exercise the intended code paths, instead of being constant folded. // @@ -92,27 +93,13 @@ string ClientLibraryTestBase::TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } -template StatusOr> ClientLibraryTestBase::Execute( - BuilderT* builder, tensorflow::gtl::ArraySlice arguments) { + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { // Build the computation, as a convenience. TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); return client_->Execute(computation, arguments, &execution_options_); } -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_output_layout) { - ExecutionOptions execution_options = execution_options_; - if (shape_with_output_layout != nullptr) { - *execution_options.mutable_shape_with_output_layout() = - *shape_with_output_layout; - } - return client_->ExecuteAndTransfer(computation, arguments, - &execution_options); -} - StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, @@ -126,17 +113,6 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } -template <> -StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_output_layout) { - // Build the computation, as a convenience. 
- TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); -} - -template <> StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout) { @@ -155,22 +131,11 @@ ClientLibraryTestBase::ExecuteAndTransferReference( *execution_options.mutable_shape_with_output_layout() = *shape_with_output_layout; } + execution_options.clear_device_handles(); return ref_client_->ExecuteAndTransfer(computation, arguments, &execution_options); } -std::unique_ptr ClientLibraryTestBase::ExecuteOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - return Execute(builder, arguments).ConsumeValueOrDie(); -} - -std::unique_ptr ClientLibraryTestBase::ExecuteAndTransferOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie(); -} - string ClientLibraryTestBase::ExecuteToString( XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { auto computation_status = builder->Build(); @@ -188,53 +153,32 @@ string ClientLibraryTestBase::ExecuteToString( } } -string ClientLibraryTestBase::ExecuteToString( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - auto computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status().ToString(); - } - auto computation = computation_status.ConsumeValueOrDie(); - - auto result = - client_->ExecuteAndTransfer(computation, arguments, &execution_options_); - if (!result.ok()) { - return result.status().ToString(); - } else { - return result.ValueOrDie()->ToString(); - } -} - void ClientLibraryTestBase::ComputeAndCompareR1( - ComputationBuilder* builder, const tensorflow::core::Bitmap& expected, + XlaBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, shape_with_layout)); } -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout) { EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments, error, shape_with_layout)); } -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( - const xla::Computation& computation, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( + const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output) { @@ -255,12 +199,11 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( "Test with output layout: ", ShapeUtil::HumanStringWithLayout(layout))); } while (std::next_permutation(minor_to_major.begin(), minor_to_major.end())); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status 
-ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( - const xla::Computation& computation, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( + const xla::XlaComputation& computation, const Literal& /*expected*/, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output, @@ -270,8 +213,8 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( // This is a recursive function. It's an std::function instead of a lambda // because it needs to capture itself. The index is the index of the argument // to try all layouts for. - std::function choose; - choose = [&, this](int64 index) -> tensorflow::Status { + std::function choose; + choose = [&, this](int64 index) -> Status { if (index < arguments.size()) { // Try out all layouts for the operand. TF_ASSIGN_OR_RETURN(auto literal, @@ -284,7 +227,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( TF_RETURN_IF_ERROR(choose(index + 1)); arguments_with_layout.pop_back(); layout_strings.pop_back(); - return tensorflow::Status::OK(); + return Status::OK(); } std::vector minor_to_major(ShapeUtil::Rank(literal->shape())); @@ -302,7 +245,7 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( layout_strings.pop_back(); } while ( std::next_permutation(minor_to_major.begin(), minor_to_major.end())); - return tensorflow::Status::OK(); + return Status::OK(); } // Every argument has an assigned layout. @@ -317,34 +260,14 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( tensorflow::strings::StrAppend(&error_message, str, " "); } verify_output(*actual, error_message); - return tensorflow::Status::OK(); + return Status::OK(); }; return choose(0); } -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts( - const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, - tensorflow::gtl::ArraySlice /*arguments*/, - const std::function& /*verify_output*/) { - return Unimplemented("not yet implemented for XlaComputation"); -} - -tensorflow::Status -ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( - const xla::XlaComputation& /*computation*/, const Literal& /*expected*/, - tensorflow::gtl::ArraySlice /*arguments*/, - const std::function& /*verify_output*/, - const Shape* /*output_with_layout*/) { - return Unimplemented("not yet implemented for XlaComputation"); -} - -template -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -371,7 +294,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + converted_expected = Literal::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -385,7 +308,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } auto expect_equal = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, 
actual)) << error_message; }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( @@ -397,13 +320,12 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectEqual(*expected_ptr, *actual); - return tensorflow::Status::OK(); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual)); + return Status::OK(); } -template -tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, +Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments_passed_in, ErrorSpec error, const Shape* shape_with_layout) { std::vector arguments(arguments_passed_in.begin(), @@ -424,7 +346,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( std::unique_ptr converted_expected; Shape layout_shape; if (use_bfloat16_) { - converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + converted_expected = Literal::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; @@ -438,7 +360,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } } auto expect_near = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error)) + << error_message; }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( @@ -450,12 +373,12 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error); - return tensorflow::Status::OK(); + EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error)); + return Status::OK(); } void ClientLibraryTestBase::ComputeAndCompareR1U8( - ComputationBuilder* builder, tensorflow::StringPiece expected, + XlaBuilder* builder, tensorflow::StringPiece expected, tensorflow::gtl::ArraySlice arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -473,9 +396,8 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( EXPECT_EQ(expected, actual->GetR1U8AsString()); } -template void ClientLibraryTestBase::ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -483,12 +405,11 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual)); } -template void ClientLibraryTestBase::ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { auto actual_status = ExecuteAndTransfer(builder, arguments); EXPECT_IS_OK(actual_status.status()); @@ -496,61 +417,7 @@ void 
ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(expected, *actual, error); -} - -void ClientLibraryTestBase::ComputeAndCompare( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments) { - auto status_or_data = ComputeValueAndReference(builder, operand, arguments); - EXPECT_IS_OK(status_or_data); - if (!status_or_data.ok()) { - return; - } - std::unique_ptr reference, result; - std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*reference, *result); -} - -void ClientLibraryTestBase::ComputeAndCompare( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { - auto status_or_data = ComputeValueAndReference(builder, operand, arguments); - EXPECT_IS_OK(status_or_data); - if (!status_or_data.ok()) { - return; - } - std::unique_ptr reference, result; - std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*reference, *result, error); -} - -StatusOr, std::unique_ptr>> -ClientLibraryTestBase::ComputeValueAndReference( - ComputationBuilder* builder, const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments) { - // Transfer the arguments to the executor service. We put the unique_ptr's - // into a vector to keep the data alive on the service until the end of this - // function. - std::vector> argument_data; - for (const auto& arg : arguments) { - TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg)); - argument_data.push_back(std::move(data)); - } - - // Create raw pointers to the GlobalData for the rest of the call stack. - std::vector argument_data_ptr; - std::transform( - argument_data.begin(), argument_data.end(), - std::back_inserter(argument_data_ptr), - [](const std::unique_ptr& data) { return data.get(); }); - - TF_ASSIGN_OR_RETURN( - auto reference, - builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments)); - TF_ASSIGN_OR_RETURN(auto result, - ExecuteAndTransfer(builder, argument_data_ptr)); - return std::make_pair(std::move(reference), std::move(result)); + EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -562,7 +429,7 @@ void ClientLibraryTestBase::ComputeAndCompare( } std::unique_ptr reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*reference, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result)); } void ClientLibraryTestBase::ComputeAndCompare( @@ -575,7 +442,7 @@ void ClientLibraryTestBase::ComputeAndCompare( } std::unique_ptr reference, result; std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*reference, *result, error); + EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error)); } StatusOr, std::unique_ptr>> @@ -616,8 +483,8 @@ ClientLibraryTestBase::ComputeValueAndReference( return std::make_pair(std::move(reference), std::move(result)); } -Computation ClientLibraryTestBase::CreateScalarRelu() { - ComputationBuilder builder(client_, "relu"); +XlaComputation ClientLibraryTestBase::CreateScalarRelu() { + XlaBuilder builder("relu"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? 
BF16 : F32, {}); auto z_value = builder.Parameter(0, shape, "z_value"); auto zero = use_bfloat16_ @@ -629,8 +496,8 @@ Computation ClientLibraryTestBase::CreateScalarRelu() { return computation_status.ConsumeValueOrDie(); } -Computation ClientLibraryTestBase::CreateScalarMax() { - ComputationBuilder builder(client_, "max"); +XlaComputation ClientLibraryTestBase::CreateScalarMax() { + XlaBuilder builder("max"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); auto x = builder.Parameter(0, shape, "x"); auto y = builder.Parameter(1, shape, "y"); @@ -640,8 +507,8 @@ Computation ClientLibraryTestBase::CreateScalarMax() { return computation_status.ConsumeValueOrDie(); } -Computation ClientLibraryTestBase::CreateScalarReluSensitivity() { - ComputationBuilder builder(client_, "relu_sensitivity"); +XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { + XlaBuilder builder("relu_sensitivity"); auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); auto activation = builder.Parameter(0, shape, "activation"); auto backprop = builder.Parameter(1, shape, "backprop"); @@ -682,14 +549,6 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } -ComputationDataHandle ClientLibraryTestBase::AddParam( - const Literal& argument, ComputationBuilder* builder) { - ComputationDataHandle data_handle; - arguments_.push_back(CreateParameterAndTransferLiteral( - arguments_.size(), argument, "", builder, &data_handle)); - return data_handle; -} - XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaBuilder* builder) { XlaOp data_handle; @@ -698,59 +557,39 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, return data_handle; } -ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( - const Literal& literal, ComputationBuilder* builder) { - return builder->ConstantLiteral( - use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); -} - XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { return builder->ConstantLiteral( - use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); + use_bfloat16_ ? 
*Literal::ConvertF32ToBF16(literal) : literal); } -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout); +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, + const Literal& literal, + const string& name, + XlaBuilder* builder, + XlaOp* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_layout); - -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout); - -template void ClientLibraryTestBase::ComputeAndCompareLiteral( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - -template void ClientLibraryTestBase::ComputeAndCompareTuple( - XlaBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - -template StatusOr> ClientLibraryTestBase::Execute( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - -template StatusOr> ClientLibraryTestBase::Execute( - XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = Literal::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index b3212dd2282..0499fec5898 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -25,10 +25,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -64,11 +63,10 @@ std::vector ExpandUseBfloat16( // A client library test establishes an in-process XLA client connection. class ClientLibraryTestBase : public ::testing::Test { protected: - explicit ClientLibraryTestBase( - perftools::gputools::Platform* platform = nullptr); + explicit ClientLibraryTestBase(se::Platform* platform = nullptr); // Creates a new ClientLibraryTestBase with custom client options. - ClientLibraryTestBase(perftools::gputools::Platform* platform, + ClientLibraryTestBase(se::Platform* platform, const LocalClientOptions& client_options); // Returns the name of the test currently being run. @@ -92,21 +90,11 @@ class ClientLibraryTestBase : public ::testing::Test { // Convenience methods for building and running a computation with the member // execution options. Modify execution_options_ in your test if you want to // customize the options. - template StatusOr> Execute( - BuilderT* builder, tensorflow::gtl::ArraySlice arguments); - - // TODO(b/74197823): Remove the template type 'BuilderT' in all methods once - // the migration to XlaBuilder is complete. - - template - StatusOr> ExecuteAndTransfer( - BuilderT* builder, tensorflow::gtl::ArraySlice arguments, - const Shape* shape_with_output_layout = nullptr); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); StatusOr> ExecuteAndTransfer( - const Computation& computation, - tensorflow::gtl::ArraySlice arguments, + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); StatusOr> ExecuteAndTransfer( @@ -122,128 +110,108 @@ class ClientLibraryTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); - // Convenience OrDie variants of above methods. - std::unique_ptr ExecuteOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - std::unique_ptr ExecuteAndTransferOrDie( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); - // Run a computation and return its value as a string. If an error // occurs, then instead return the error as a string. string ExecuteToString(XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments); - string ExecuteToString(ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments); // Convenience methods for building and running a computation, transferring // the result, and comparing it to the expected value(s). Methods are // templated on the native host type which maps to specific XLA types (See - // ComputationBuilder/XlaBuilder for details). For each rank, two forms are + // XlaBuilder for details). For each rank, two forms are // provided: one for floating point types with an ErrorSpec parameter, and one // for integral types without the ErrorSpec parameter. 
- template - void ComputeAndCompareR0(BuilderT* builder, NativeT expected, + template + void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR0(BuilderT* builder, NativeT expected, + template + void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR1(BuilderT* builder, + template + void ComputeAndCompareR1(XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR1(BuilderT* builder, + template + void ComputeAndCompareR1(XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // As above, but uses a bitmap to hold the predicate vector to avoid // deficiencies of vector. - void ComputeAndCompareR1(ComputationBuilder* builder, + void ComputeAndCompareR1(XlaBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, + template + void ComputeAndCompareR2(XlaBuilder* builder, + const Array2D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR2(BuilderT* builder, const Array2D& expected, + template + void ComputeAndCompareR2(XlaBuilder* builder, + const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, + template + void ComputeAndCompareR3(XlaBuilder* builder, + const Array3D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR3(BuilderT* builder, const Array3D& expected, + template + void ComputeAndCompareR3(XlaBuilder* builder, + const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - template - void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, + template + void ComputeAndCompareR4(XlaBuilder* builder, + const Array4D& expected, tensorflow::gtl::ArraySlice arguments); - template - void ComputeAndCompareR4(BuilderT* builder, const Array4D& expected, + template + void ComputeAndCompareR4(XlaBuilder* builder, + const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Build and run the computation and compare the result with the given // literal. shape_with_layout indicates the result layout to request when // calling Execute. - template void ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); - template void ComputeAndCompareLiteral( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // ComputeAndCompare variant which returns an error status. 
- template - tensorflow::Status ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, + Status ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout = nullptr); - template - tensorflow::Status ComputeAndCompareLiteralWithStatus( - BuilderT* builder, const Literal& expected, + Status ComputeAndCompareLiteralWithStatus( + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout = nullptr); // Compare the result of the computation to a strings. In XLA strings are // represented using rank-1 U8 shapes. void ComputeAndCompareR1U8( - ComputationBuilder* builder, tensorflow::StringPiece expected, + XlaBuilder* builder, tensorflow::StringPiece expected, tensorflow::gtl::ArraySlice arguments); // Convenience method for running a built computation, transferring the // result, and comparing it to the expected tuple literal. - template void ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments); - template void ComputeAndCompareTuple( - BuilderT* builder, const Literal& expected, + XlaBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error); - // Convenience method for running a built computation and comparing the result - // with the HloEvaluator. - void ComputeAndCompare(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments); - void ComputeAndCompare(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments, - ErrorSpec error); - // Convenience method for running a built computation and comparing the result // with the reference result. void ComputeAndCompare(XlaBuilder* builder, @@ -253,9 +221,9 @@ class ClientLibraryTestBase : public ::testing::Test { ErrorSpec error); // Create scalar operations for use in reductions. - Computation CreateScalarRelu(); - Computation CreateScalarMax(); - Computation CreateScalarReluSensitivity(); + XlaComputation CreateScalarRelu(); + XlaComputation CreateScalarMax(); + XlaComputation CreateScalarReluSensitivity(); // Special case convenience functions for creating filled arrays. @@ -295,34 +263,25 @@ class ClientLibraryTestBase : public ::testing::Test { // server, then stores into "data_handle" the global handle for that // parameter. When the use_bfloat16 flag is set but the literal has F32 // elements, the literal will be converted to BF16 before being transferred. - template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - BuilderT* builder, HandleT* data_handle); + XlaBuilder* builder, XlaOp* data_handle); // As above, but the caller can specify the device that the literal is // transferred to. If device_handle is nullptr, the literal will be // transferred to the default device. - template std::unique_ptr CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, BuilderT* builder, - HandleT* data_handle); + const DeviceHandle* device_handle, XlaBuilder* builder, + XlaOp* data_handle); // Creates a parameter instruction and sets the value that will be passed to // the computation as specified. 
This function must be used for all parameters // or none and no parameters must be passed when invoking the computation if // using this mechanism. If using this mechanism, then each parameter must be // set exactly once. The first added parameter gets index 0, then 1 and so on. - ComputationDataHandle AddParam(const Literal& argument, - ComputationBuilder* builder); XlaOp AddParam(const Literal& argument, XlaBuilder* builder); - template - ComputationDataHandle AddParam(const Array& argument, - ComputationBuilder* builder) { - return AddParam(*Literal::CreateFromArray(argument), builder); - } template XlaOp AddParam(const Array& argument, XlaBuilder* builder) { return AddParam(*Literal::CreateFromArray(argument), builder); @@ -331,18 +290,11 @@ class ClientLibraryTestBase : public ::testing::Test { // Creates a constant instruction with the given literal. When the // use_bfloat16 flag is set but the literal has F32 elements, the elements // will be converted to BF16s. - ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, - ComputationBuilder* builder); XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder); // Creates a constant instruction with the given array. When the use_bfloat16 // flag is set but the array has float elements, the elements will be // converted to bfloat16s. - template - ComputationDataHandle CreateConstantFromArray(const Array& array, - ComputationBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); - } template XlaOp CreateConstantFromArray(const Array& array, @@ -351,13 +303,6 @@ class ClientLibraryTestBase : public ::testing::Test { } // Same as CreateConstantFromArray, but for scalars. - template - ComputationDataHandle CreateConstantFromScalar(NativeT value, - ComputationBuilder* builder) { - return CreateConstantFromLiteral(*Literal::CreateR0(value), - builder); - } - template XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) { return CreateConstantFromLiteral(*Literal::CreateR0(value), @@ -372,12 +317,12 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR0Parameter(NativeT value, int64 parameter_number, const string& name, - BuilderT* builder, - HandleT* data_handle); + XlaBuilder* builder, + XlaOp* data_handle); // Creates a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. @@ -387,10 +332,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_2d" and then stores to "data_handle" the global handle for that @@ -401,10 +346,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. 
- template + template std::unique_ptr CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Creates a parameter instruction that wraps the given constant array // "array_3d" and then stores to "data_handle" the global handle for that @@ -415,10 +360,10 @@ class ClientLibraryTestBase : public ::testing::Test { // // When the use_bfloat16 flag is set but NativeT is float, the data will be // converted to bfloat16. - template + template std::unique_ptr CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle); + const string& name, XlaBuilder* builder, XlaOp* data_handle); // Getter and setter for the use_bfloat16 flag, which indicates whether to run // tests with all float-type input/output converted to bfloat16. @@ -433,40 +378,18 @@ class ClientLibraryTestBase : public ::testing::Test { ExecutionOptions execution_options_; private: - // Build and run the computation with all permutations of output layouts. - tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( - const xla::Computation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const std::function& verify_output); - // Build and run the computation with all permutations of layouts of all input - // arguments. - tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( - const xla::Computation& computation, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, - const std::function& verify_output, - const Shape* output_with_layout = nullptr); - - tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts( + Status ComputeAndCompareLiteralWithAllOutputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output); - tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts( + Status ComputeAndCompareLiteralWithAllInputLayouts( const xla::XlaComputation& computation, const Literal& expected, tensorflow::gtl::ArraySlice arguments, const std::function& verify_output, const Shape* output_with_layout = nullptr); - // Executes the computation and calculates the expected reference value using - // the HloEvaluator. Returns two literals in the order of (expected, actual). - StatusOr, std::unique_ptr>> - ComputeValueAndReference(ComputationBuilder* builder, - const ComputationDataHandle& operand, - tensorflow::gtl::ArraySlice arguments); - // Executes the computation and calculates the expected reference value using // the reference client. Returns two literals in the order of (expected, // actual). 
@@ -482,9 +405,9 @@ class ClientLibraryTestBase : public ::testing::Test { std::vector> arguments_; }; -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - BuilderT* builder, NativeT expected, + XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR0(expected); @@ -492,9 +415,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR0( - BuilderT* builder, NativeT expected, + XlaBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -508,9 +431,9 @@ void ClientLibraryTestBase::ComputeAndCompareR0( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - BuilderT* builder, tensorflow::gtl::ArraySlice expected, + XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR1(expected); @@ -518,9 +441,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR1( - BuilderT* builder, tensorflow::gtl::ArraySlice expected, + XlaBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -534,9 +457,9 @@ void ClientLibraryTestBase::ComputeAndCompareR1( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - BuilderT* builder, const Array2D& expected, + XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR2FromArray2D(expected); @@ -544,9 +467,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR2( - BuilderT* builder, const Array2D& expected, + XlaBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -560,9 +483,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - BuilderT* builder, const Array3D& expected, + XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR3FromArray3D(expected); @@ -570,9 +493,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR3( - BuilderT* builder, const Array3D& expected, + XlaBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -586,9 +509,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( arguments, error); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - BuilderT* builder, const Array4D& expected, + XlaBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = Literal::CreateR4FromArray4D(expected); @@ -596,9 +519,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments); } -template +template void ClientLibraryTestBase::ComputeAndCompareR4( - BuilderT* builder, const Array4D& expected, + XlaBuilder* builder, const Array4D& 
expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || @@ -612,13 +535,13 @@ void ClientLibraryTestBase::ComputeAndCompareR4( arguments, error); } -template +template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, - BuilderT* builder, HandleT* data_handle) { + XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -626,13 +549,13 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -640,13 +563,13 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -654,13 +577,13 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( return data; } -template +template std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, - const string& name, BuilderT* builder, HandleT* data_handle) { + const string& name, XlaBuilder* builder, XlaOp* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); if (use_bfloat16_ && literal->shape().element_type() == F32) { - literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + literal = Literal::ConvertF32ToBF16(*literal); } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -693,37 +616,6 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( return result; } -template -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number, - const Literal& literal, - const string& name, - BuilderT* builder, - HandleT* data_handle) { - return CreateParameterAndTransferLiteral(parameter_number, literal, name, - nullptr, builder, data_handle); -} - -template -std::unique_ptr -ClientLibraryTestBase::CreateParameterAndTransferLiteral( - int64 parameter_number, const Literal& literal, const string& name, - const DeviceHandle* device_handle, BuilderT* builder, - HandleT* data_handle) { - const Literal* param_literal = &literal; - std::unique_ptr converted_literal; - if (use_bfloat16_) { - converted_literal = 
LiteralTestUtil::ConvertF32ToBF16(literal); - param_literal = converted_literal.get(); - } - std::unique_ptr data = - client_->TransferToServer(*param_literal, device_handle) - .ConsumeValueOrDie(); - *data_handle = - builder->Parameter(parameter_number, param_literal->shape(), name); - return data; -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 32e2f2c0848..08671cf6244 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" @@ -39,7 +38,7 @@ namespace { class ClientTest : public ClientLibraryTestBase {}; XLA_TEST_F(ClientTest, ExecuteWithLayout) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); std::vector> layouts = {{0, 1}, {1, 0}}; for (const std::vector& execute_layout : layouts) { @@ -63,15 +62,15 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) { TF_ASSERT_OK_AND_ASSIGN( auto computed, client_->Transfer(*data, &expected_literal->shape())); - LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), - computed->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( + expected_literal->shape(), computed->shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } } XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), b.ConstantR2({{10, 20}, {30, 40}})}); @@ -92,9 +91,9 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - LiteralView::Create(*result, {0})); + LiteralSlice(*result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - LiteralView::Create(*result, {1})); + LiteralSlice(*result, {1})); EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); @@ -109,8 +108,7 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { /*minor_to_major=*/{1, 0}))); } -XLA_TEST_F(ClientTest, - DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { +XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) { XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); @@ -144,7 +142,7 @@ XLA_TEST_F(ClientTest, auto result_literal, client_->Transfer(*results[0], &expected_result->shape())); - LiteralTestUtil::ExpectEqual(*expected_result, *result_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 0f780fa87ef..50a00696486 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -17,10 +17,10 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -39,7 +39,7 @@ namespace { class CompilationCacheTest : public ClientLibraryTestBase { public: void ExecuteComputationR0F32( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, float expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; @@ -49,13 +49,13 @@ class CompilationCacheTest : public ClientLibraryTestBase { /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*Literal::CreateR0(expected_result), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR0(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } void ExecuteComputationR2F32( - const Computation& computation, + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, std::initializer_list> expected_result, bool expect_cache_hit) { @@ -66,25 +66,28 @@ class CompilationCacheTest : public ClientLibraryTestBase { .ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data_handle).ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*Literal::CreateR2(expected_result), - *result, error_spec_); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2(expected_result), *result, error_spec_)); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } ErrorSpec error_spec_{0.0001}; }; -XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) { - ComputationBuilder builder(client_, TestName()); +// TODO(b/74197823): Disabled because there is no cache in the new design. +XLA_TEST_F(CompilationCacheTest, DISABLED_ComputationCalledMultipleTimes) { + XlaBuilder builder(TestName()); builder.Neg(builder.ConstantR0(42.0)); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { +// TODO(b/74197823): Disabled because there is no cache in the new design. 
+XLA_TEST_F(CompilationCacheTest, + DISABLED_ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = client_->TransferToServer(*Literal::CreateR0(42.0f)) .ConsumeValueOrDie(); @@ -95,9 +98,9 @@ XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { client_->TransferToServer(*Literal::CreateR0(456.0f)) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param")); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation, {data_42.get()}, -42.0, /*expect_cache_hit=*/false); @@ -109,19 +112,20 @@ XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, MultipleComputations) { - ComputationBuilder builder_neg(client_, TestName() + "_neg"); +// TODO(b/74197823): Disabled because there is no cache in the new design. +XLA_TEST_F(CompilationCacheTest, DISABLED_MultipleComputations) { + XlaBuilder builder_neg(TestName() + "_neg"); builder_neg.Neg(builder_neg.ConstantR0(42.0)); - Computation computation_neg = builder_neg.Build().ConsumeValueOrDie(); + XlaComputation computation_neg = builder_neg.Build().ConsumeValueOrDie(); - ComputationBuilder builder_exp(client_, TestName() + "_exp"); + XlaBuilder builder_exp(TestName() + "_exp"); builder_exp.Exp(builder_exp.ConstantR0(1.0)); - Computation computation_exp = builder_exp.Build().ConsumeValueOrDie(); + XlaComputation computation_exp = builder_exp.Build().ConsumeValueOrDie(); - ComputationBuilder builder_add(client_, TestName() + "_add"); + XlaBuilder builder_add(TestName() + "_add"); builder_add.Add(builder_add.ConstantR0(2.0), builder_add.ConstantR0(3.0)); - Computation computation_add = builder_add.Build().ConsumeValueOrDie(); + XlaComputation computation_add = builder_add.Build().ConsumeValueOrDie(); ExecuteComputationR0F32(computation_neg, {}, -42.0, /*expect_cache_hit=*/false); @@ -133,7 +137,8 @@ XLA_TEST_F(CompilationCacheTest, MultipleComputations) { /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { +// TODO(b/74197823): Disabled because there is no cache in the new design. +XLA_TEST_F(CompilationCacheTest, DISABLED_DifferentParameterLayouts) { // Create two GlobalData arrays with the same shape but different // layouts. Use these arrays as parameters to a simple computation. If the // layout of the array changes then computation should be recompiled (cache @@ -148,9 +153,9 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { auto colmaj_handle = client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); ExecuteComputationR2F32(computation, {colmaj_handle.get()}, {{1.0f, 2.0f}, {3.0f, 4.0f}}, @@ -169,32 +174,5 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { /*expect_cache_hit=*/true); } -XLA_TEST_F(CompilationCacheTest, MutatedComputation) { - // Build a computation, execute it, then mutate it. The mutated computation - // should not be in the cache until it is run once. 
This must be done through - // the stub interface because Computations built from ComputationBuilder are - // immutable. - ComputationBuilder builder(client_, TestName()); - auto neg = builder.Neg(builder.ConstantR0(42.0)); - Computation computation = builder.Build().ConsumeValueOrDie(); - - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true); - - BinaryOpRequest request; - request.set_binop(BINOP_ADD); - *request.mutable_lhs() = neg; - *request.mutable_rhs() = neg; - OpRequest op_request; - *op_request.mutable_computation() = computation.handle(); - *op_request.mutable_binary_op_request() = request; - OpResponse response; - tensorflow::Status s = client_->stub()->Op(&op_request, &response); - ASSERT_TRUE(s.ok()); - - ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/false); - ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/true); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index c15d808f1dd..ba22530f1cf 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -18,8 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" @@ -47,16 +45,14 @@ ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly}; class ComputeConstantTest : public ::testing::Test { public: - explicit ComputeConstantTest( - perftools::gputools::Platform* platform = nullptr) + explicit ComputeConstantTest(se::Platform* platform = nullptr) : platform_(platform) {} string TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } - Client* ClientOrDie(::perftools::gputools::Platform* platform, - ClientType client_type) { + Client* ClientOrDie(se::Platform* platform, ClientType client_type) { if (client_type == ClientType::kLocal) { StatusOr result = ClientLibrary::GetOrCreateLocalClient(platform); @@ -90,24 +86,13 @@ class ComputeConstantTest : public ::testing::Test { return literal->Get({}); } - template - StatusOr ComputeConstantScalar( - Client* client, const ComputationDataHandle& operand, - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice parameters = {}) { - TF_ASSIGN_OR_RETURN(auto literal, - builder->ComputeConstant( - operand, /*output_layout=*/nullptr, parameters)); - return literal->Get({}); - } - bool IsConstant(const XlaOp& operand, XlaBuilder* builder) { StatusOr result = builder->IsConstant(operand); EXPECT_TRUE(result.ok()) << result.status(); return result.ok() ? 
result.ValueOrDie() : false; } - perftools::gputools::Platform* platform_; + se::Platform* platform_; }; TEST_F(ComputeConstantTest, ScalarInt32Literal) { @@ -152,26 +137,6 @@ TEST_F(ComputeConstantTest, ScalarRng) { } } -TEST_F(ComputeConstantTest, Param) { - for (ClientType client_type : client_types) { - Client* client = ClientOrDie(platform_, client_type); - ComputationBuilder b(client, TestName()); - auto param = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "lhs"); - auto computation = b.Add(param, b.ConstantR0(1.5f)); - - std::vector arguments; - arguments.push_back(std::move(*Literal::CreateR0(42.5f))); - TF_ASSERT_OK_AND_ASSIGN(bool is_constant, - b.IsConstant(computation, arguments.size())); - EXPECT_TRUE(is_constant); - - TF_ASSERT_OK_AND_ASSIGN( - auto value, - ComputeConstantScalar(client, computation, &b, arguments)); - EXPECT_EQ(value, 44.0f); - } -} - TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); @@ -243,7 +208,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR1({4, 6}); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -257,7 +222,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) { TF_ASSERT_OK_AND_ASSIGN(auto computed, ComputeConstantLiteral(client, computation, &b)); std::unique_ptr expected_literal = Literal::CreateR0(5); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } @@ -279,9 +244,9 @@ XLA_TEST_F(ComputeConstantTest, Layout) { std::unique_ptr expected_literal = Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(layout)); - LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), - computed->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed); + ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts( + expected_literal->shape(), computed->shape())); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed)); } } } diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 35aa3f6d696..916ffadbc79 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -21,12 +21,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -39,7 +38,7 @@ class ConstantsTest : public ClientLibraryTestBase { }; TEST_F(ConstantsTest, ZeroCellF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({}); ComputeAndCompareR1(&builder, {}, {}, error_spec_); @@ -48,7 +47,7 @@ TEST_F(ConstantsTest, ZeroCellF32) { TEST_F(ConstantsTest, OneCellF32) { std::vector constant = {2.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); @@ -57,7 +56,7 @@ TEST_F(ConstantsTest, OneCellF32) { TEST_F(ConstantsTest, OneCellS32) { std::vector constant = {2}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}); @@ -66,7 +65,7 @@ TEST_F(ConstantsTest, OneCellS32) { TEST_F(ConstantsTest, OneCellU32) { std::vector constant = {2}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}); @@ -75,7 +74,7 @@ TEST_F(ConstantsTest, OneCellU32) { TEST_F(ConstantsTest, EightCells) { std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); @@ -85,14 +84,14 @@ TEST_F(ConstantsTest, SixteenCells) { std::vector constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1(constant); ComputeAndCompareR1(&builder, constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR2FromArray2D(Array2D(0, 2)); ComputeAndCompareR2(&builder, Array2D(0, 2), {}, error_spec_); @@ -102,14 +101,14 @@ TEST_F(ConstantsTest, Small_2x2) { std::unique_ptr> constant = MakeLinspaceArray2D(100.0, 200.0, 2, 2); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR2FromArray2D(*constant); ComputeAndCompareR2(&builder, *constant, {}, error_spec_); } TEST_F(ConstantsTest, Empty_3x0x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto constant = builder.ConstantLiteral( *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); @@ -117,7 +116,7 @@ TEST_F(ConstantsTest, Empty_3x0x2) { } TEST_F(ConstantsTest, Small_2x2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array3D array3d({ // x0 x1 {{1.f, 2.f}, // y0 @@ -145,13 +144,13 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { Literal::CreateR4FromArray4D(input_array); { - ComputationBuilder 
builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantLiteral(*input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR4FromArray4D(input_array); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } @@ -159,17 +158,18 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. TEST_F(ConstantsTest, DISABLED_TupleConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantLiteral( *Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), Literal::CreateR1({2.0, 42}).get()})); - std::unique_ptr result = ExecuteAndTransferOrDie(&builder, {}); + std::unique_ptr result = + ExecuteAndTransfer(&builder, {}).ConsumeValueOrDie(); LiteralTestUtil::ExpectR2Near( - {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_); + {{1.0}, {2.0}}, LiteralSlice(*result, {0}), error_spec_); LiteralTestUtil::ExpectR1Near( - {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_); + {2.0, 42.0}, LiteralSlice(*result, {1}), error_spec_); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 0842a8918bc..4ef0a77884c 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -36,7 +36,7 @@ namespace { class ConvertTest : public ClientLibraryTestBase { public: - explicit ConvertTest(perftools::gputools::Platform* platform = nullptr) + explicit ConvertTest(se::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); @@ -44,7 +44,7 @@ class ConvertTest : public ClientLibraryTestBase { }; TEST_F(ConvertTest, ConvertR1S32ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42, 64}); builder.ConvertElementType(a, S32); @@ -53,7 +53,7 @@ TEST_F(ConvertTest, ConvertR1S32ToR1S32) { } TEST_F(ConvertTest, ConvertR1F32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.0f, 64.0f}); builder.ConvertElementType(a, F32); @@ -62,7 +62,7 @@ TEST_F(ConvertTest, ConvertR1F32ToR1F32) { } TEST_F(ConvertTest, ConvertR1S32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42, 64}); builder.ConvertElementType(a, F32); @@ -71,7 +71,7 @@ TEST_F(ConvertTest, ConvertR1S32ToR1F32) { } TEST_F(ConvertTest, ConvertR1PREDToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({true, false, true}); builder.ConvertElementType(a, S32); @@ -80,7 +80,7 @@ TEST_F(ConvertTest, ConvertR1PREDToR1S32) { } TEST_F(ConvertTest, ConvertR1PREDToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); 
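Nearly every hunk in these test files is the same one-line migration: `ComputationBuilder` took the client as its first constructor argument, while `XlaBuilder` takes only a name and is bound to a client later, when the computation is built and executed. A before/after sketch using the `ClientLibraryTestBase` members seen above:

```cpp
// Old client API (removed by this change):
//   ComputationBuilder builder(client_, TestName());
// New client API: no Client* at construction time.
XlaBuilder builder(TestName());
builder.ConstantR1<float>({2.0f});
// The client is supplied implicitly by the test helper when it compiles and
// runs the computation.
ComputeAndCompareR1<float>(&builder, {2.0f}, /*arguments=*/{}, error_spec_);
```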
auto a = builder.ConstantR1({true, false, true}); builder.ConvertElementType(a, F32); @@ -89,7 +89,7 @@ TEST_F(ConvertTest, ConvertR1PREDToR1F32) { } XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); builder.ConvertElementType(a, F32); @@ -98,7 +98,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) { } TEST_F(ConvertTest, ConvertR1F32ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({42.6, 64.4}); builder.ConvertElementType(a, S32); @@ -107,7 +107,7 @@ TEST_F(ConvertTest, ConvertR1F32ToR1S32) { } XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector arg{ -9223371216516022272, -2, @@ -160,7 +160,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { } XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000000, 0x80000001, 0x80000002, 0x80000003, 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; @@ -179,7 +179,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { } XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector arg{0.0f, 1.0f, 16777216.0f, 16777218.0f, 2147483647.0f, 4294967040.0f}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); @@ -197,7 +197,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1U32) { } XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); @@ -214,7 +214,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { } XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector arg{0, 1, 0x1000, -1, -0x1000}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); @@ -231,7 +231,7 @@ XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { } XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Test cases from compiler_rt library. 
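The constant-fed conversion tests in this file all follow the same three steps: build a constant, call `ConvertElementType` with a target `PrimitiveType`, then compare against the expected vector. A sketch of the S32-to-F32 case under the new builder, restating the pattern above:

```cpp
XlaBuilder builder(TestName());
auto a = builder.ConstantR1<int32>({42, 64});
builder.ConvertElementType(a, F32);  // elementwise S32 -> F32 cast

std::vector<float> expected = {42.0f, 64.0f};
ComputeAndCompareR1<float>(&builder, expected, /*arguments=*/{});
```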
std::vector arg{0.0f, 0.5f, @@ -268,7 +268,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32, 64}); builder.ConvertElementType(a, F32); @@ -277,7 +277,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32, 64}); builder.ConvertElementType(a, S32); @@ -286,7 +286,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) { } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32, 64}); builder.ConvertElementType(a, U32); @@ -295,7 +295,7 @@ XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) { } XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32.0f, 64.0f}); builder.ConvertElementType(a, F64); @@ -304,7 +304,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) { } XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({32.0, 64.0}); builder.ConvertElementType(a, F32); @@ -313,7 +313,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) { } TEST_F(ConvertTest, ConvertS32Extremes) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1( {std::numeric_limits::min(), std::numeric_limits::max()}); builder.ConvertElementType(a, F32); @@ -325,7 +325,7 @@ TEST_F(ConvertTest, ConvertS32Extremes) { } TEST_F(ConvertTest, ConvertMapToS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); b->ConvertElementType(param, S32); @@ -337,7 +337,7 @@ TEST_F(ConvertTest, ConvertMapToS32) { } TEST_F(ConvertTest, ConvertMapToF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto b = builder.CreateSubBuilder("convert"); auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); b->ConvertElementType(param, F32); @@ -354,7 +354,7 @@ TEST_F(ConvertTest, ConvertMapToF32) { // input -> convert -> reshape // the new convert should have the same element type as the old convert. 
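For the wider-range conversions the input is fed as a parameter rather than a constant, so the values reach the runtime conversion path instead of being constant-folded. A hedged sketch of that pattern (the `arg` values here are illustrative, not taken from the patch):

```cpp
XlaBuilder builder(TestName());
std::vector<uint32> arg{0, 1, 0x1000, 0x7fffffff, 0xFFFFFFFF};
std::unique_ptr<Literal> arg_literal = Literal::CreateR1<uint32>({arg});
auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param");
std::unique_ptr<GlobalData> arg_data =
    client_->TransferToServer(*arg_literal).ConsumeValueOrDie();

builder.ConvertElementType(arg_param, S64);

// Expected values computed on the host with the same widening cast.
std::vector<int64> expected;
for (uint32 v : arg) {
  expected.push_back(static_cast<int64>(v));
}
ComputeAndCompareR1<int64>(&builder, expected, {arg_data.get()});
```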
TEST_F(ConvertTest, ConvertReshape) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR1({42}); auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); builder.ConvertElementType(reshape, F32); @@ -393,7 +393,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { std::unique_ptr dot_lhs_handle, client_->TransferToServer(*Literal::CreateR1(input))); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConvertElementType( builder.Parameter( 0, ShapeUtil::MakeShape(F16, {static_cast(input.size())}), @@ -413,7 +413,7 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { std::unique_ptr dot_lhs_handle, client_->TransferToServer(*Literal::CreateR1(input))); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConvertElementType( builder.Parameter( 0, ShapeUtil::MakeShape(F32, {static_cast(input.size())}), @@ -424,28 +424,28 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { } XLA_TEST_F(ConvertTest, ConvertC64ToC64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector x = {{42.0f, 64.0f}}; builder.ConvertElementType(builder.ConstantR1(x), C64); ComputeAndCompareR1(&builder, x, {}, ErrorSpec(0.0001)); } XLA_TEST_F(ConvertTest, ConvertS64S64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector x = {{-42, 64}}; builder.ConvertElementType(builder.ConstantR1(x), S64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64U64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector x = {{42, 64}}; builder.ConvertElementType(builder.ConstantR1(x), U64); ComputeAndCompareR1(&builder, x, {}); } XLA_TEST_F(ConvertTest, ConvertU64S64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector unsigned_x = {{42, UINT64_MAX}}; builder.ConvertElementType(builder.ConstantR1(unsigned_x), S64); std::vector signed_x = {{42, -1}}; @@ -453,7 +453,7 @@ XLA_TEST_F(ConvertTest, ConvertU64S64) { } XLA_TEST_F(ConvertTest, ConvertS64U64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector signed_x = {{42, -1, INT64_MIN}}; builder.ConvertElementType(builder.ConstantR1(signed_x), U64); std::vector unsigned_x = { diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 896b34fb6e2..b5a42e30598 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,13 +34,35 @@ limitations under the License. 
namespace xla { namespace { +StatusOr CreateConvDimensionNumbers( + int64 input_batch, int64 input_feature, int64 input_first_spatial, + int64 input_second_spatial, int64 output_batch, int64 output_feature, + int64 output_first_spatial, int64 output_second_spatial, + int64 kernel_output_feature, int64 kernel_input_feature, + int64 kernel_first_spatial, int64 kernel_second_spatial) { + ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_input_batch_dimension(input_batch); + dimension_numbers.set_input_feature_dimension(input_feature); + dimension_numbers.add_input_spatial_dimensions(input_first_spatial); + dimension_numbers.add_input_spatial_dimensions(input_second_spatial); + dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); + dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); + dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); + dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); + dimension_numbers.set_output_batch_dimension(output_batch); + dimension_numbers.set_output_feature_dimension(output_feature); + dimension_numbers.add_output_spatial_dimensions(output_first_spatial); + dimension_numbers.add_output_spatial_dimensions(output_second_spatial); + TF_RETURN_IF_ERROR(XlaBuilder::Validate(dimension_numbers)); + return dimension_numbers; +} + class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {}; // Tests the convolution operation with invalid input dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0, - 1, 2, 3); + CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("input are not unique")); @@ -49,8 +71,7 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { // Tests the convolution operation with invalid weight dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0, - 2, 2, 3); + CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0, 2, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("weight are not unique")); @@ -59,8 +80,7 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { // Tests the convolution operation with invalid output dimension numbers. 
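The new file-local `CreateConvDimensionNumbers` helper replaces the removed `ComputationBuilder::CreateConvDimensionNumbers` factory: it fills a `ConvolutionDimensionNumbers` proto and runs it through `XlaBuilder::Validate`. A usage sketch with illustrative dimension indices, mirroring the invalid-input test above:

```cpp
// All twelve dimension indices are supplied explicitly; validation fails when
// the indices within the input, output, or kernel group collide.
StatusOr<ConvolutionDimensionNumbers> valid = CreateConvDimensionNumbers(
    /*input=*/0, 1, 2, 3, /*output=*/0, 1, 2, 3, /*kernel=*/0, 1, 2, 3);
EXPECT_TRUE(valid.ok());

// Duplicate input spatial dimension (2 appears twice), as in the test above.
StatusOr<ConvolutionDimensionNumbers> invalid =
    CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3);
EXPECT_FALSE(invalid.ok());
```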
TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0, - 1, 2, 3); + CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0, 1, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("output are not unique")); @@ -76,14 +96,14 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(*input_array); auto weight = builder.Parameter(0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight"); auto conv1 = builder.Conv(input, weight, {1, 1}, Padding::kValid); ConvolutionDimensionNumbers dim_nums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(); + XlaBuilder::CreateDefaultConvDimensionNumbers(); // Swap batch_dimension and feature_dimension. int64 old_input_batch_dim = dim_nums.input_batch_dimension(); int64 old_output_batch_dim = dim_nums.output_batch_dimension(); diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 9c1145def8c..fea850dc135 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -25,9 +25,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -52,7 +52,7 @@ class ConvolutionVariantsTest : public ClientLibraryTestBase { }; XLA_TEST_F(ConvolutionVariantsTest, Minimal) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const Array4D input_array(1, 1, 1, 1, {2}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -67,7 +67,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Minimal) { } XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const Array4D input_array(5, 1, 1, 1, {1, 2, 3, 4, 5}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -82,7 +82,7 @@ XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) { } XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(2, 1, 3, 4); input_array.FillWithMultiples(1); @@ -99,7 +99,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) { } XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 2, 1, 1, {10, 1}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -114,7 +114,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 2, {1, 2}); auto input = 
builder.ConstantR4FromArray4D(input_array); @@ -129,7 +129,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -144,7 +144,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -159,7 +159,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -174,7 +174,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -189,7 +189,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array( 2, 2, 2, 3, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, // plane 0 @@ -210,7 +210,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -225,7 +225,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -240,7 +240,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 4, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -255,7 +255,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -270,7 +270,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -285,7 +285,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder 
builder(TestName()); Array4D input_array(1, 1, 1, 1, {1}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -300,7 +300,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { } XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -315,7 +315,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { } XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -333,7 +333,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 2, 1, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -348,7 +348,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -363,7 +363,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(1, 1, 1, 3, {1, 2, 3}); auto input = builder.ConstantR4FromArray4D(input_array); @@ -378,7 +378,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(64); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -398,7 +398,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(16 * 1 * 1 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -419,7 +419,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); constexpr int bs = 16; constexpr int kx = 2; @@ -450,7 +450,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); constexpr int kx = 2; constexpr int ky = 2; @@ -482,7 +482,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(16, 1, 8, 8); for (int i0 = 0; i0 < 16; ++i0) { @@ -510,7 +510,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { - ComputationBuilder builder(client_, TestName()); + 
XlaBuilder builder(TestName()); std::vector input_data(2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -536,7 +536,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(2 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -562,7 +562,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(32 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -602,7 +602,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D input_array(16, 16, 1, 1); Array4D filter_array(16, 16, 1, 1); @@ -628,7 +628,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { } XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 4 * 6); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -640,14 +640,14 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { builder.ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{2, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 2, 2, {3924, 4257, 5922, 6255}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -659,14 +659,14 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { builder.ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 8, {10, 2, 20, 3, 30, 4, 40, 5}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 3 * 4); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -682,8 +682,7 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { builder.ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{2, 1}, /*padding=*/{{1, 0}, {0, 0}}, /*lhs_dilation=*/{3, 2}, - /*rhs_dilation=*/{}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + /*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 3, 5, {204, 40, 406, 60, 608, // @@ -693,7 +692,7 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { } XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -705,14 
+704,14 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { builder.ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, -1}}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 2, {23, 34}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -724,14 +723,14 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { builder.ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, 2}}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 5, {23, 34, 45, 50, 0}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -743,14 +742,14 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { builder.ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {2, -1}}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D expected(1, 1, 1, 5, {0, 1, 12, 23, 34}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -763,7 +762,7 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {3, 2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); // input: // [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5] @@ -775,7 +774,7 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { ComputeAndCompareR4(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -788,7 +787,7 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-3, -2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); // input: // [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5] @@ -821,7 +820,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) { Array4D input_array(bs, iz, iy, ix, input_data); Array4D filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(input_array); auto 
filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -854,7 +853,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) { Array4D input_array(bs, iz, iy, ix, input_data); Array4D filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -887,7 +886,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) { Array4D input_array(bs, iz, iy, ix, input_data); Array4D filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -920,7 +919,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { Array4D input_array(bs, iz, iy, ix, input_data); Array4D filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -954,7 +953,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Array4D input_array(bs, iz, iy, ix, input_data); Array4D filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -966,7 +965,7 @@ XLA_TEST_F(ConvolutionVariantsTest, } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1010,7 +1009,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1054,7 +1053,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1095,7 +1094,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector input_data(1 * 2 * 3 * 2); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1147,7 +1146,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { // BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1) XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) { - ComputationBuilder builder(client_, 
TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D( Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); @@ -1166,19 +1165,18 @@ XLA_TEST_F(ConvolutionVariantsTest, // BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1)) XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D( Array4D(1, 1, 1, 1, /*values=*/{1})); auto weights = builder.ConstantR4FromArray4D( Array4D(1, 1, 1, 3, /*values=*/{1, 10, 100})); auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvGeneralDilated( - gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {0, 3}}, - /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + builder.ConvGeneralDilated(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {0, 3}}, + /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); ComputeAndCompareR4(&builder, {{{{100, 0}}}}, {}, error_spec_); } @@ -1187,7 +1185,7 @@ XLA_TEST_F(ConvolutionVariantsTest, // into // BackwardInputConv([1], [1,10,100], padding=(1,1)) XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D( Array4D(1, 1, 1, 1, /*values=*/{1})); @@ -1208,7 +1206,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { // However, XLA:GPU doesn't actually fuse it because PadInsertion doesn't // support negative padding on backward convolution yet (b/32744257). XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D( Array4D(1, 1, 1, 3, /*values=*/{1, 2, 3})); @@ -1224,7 +1222,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0 // gradients: 100,10,1 -dilate-> 100,0,10,0,1 @@ -1240,7 +1238,7 @@ XLA_TEST_F(ConvolutionVariantsTest, /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {1, 2}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); builder.Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{24, 130, 240}}}}, {}, error_spec_); @@ -1248,7 +1246,7 @@ XLA_TEST_F(ConvolutionVariantsTest, XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingGreaterThanHighPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4 // gradients: 100,10,1 -dilate-> 100,0,10,0,1 @@ -1266,14 +1264,14 @@ XLA_TEST_F(ConvolutionVariantsTest, /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {2, 0}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); builder.Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{13, 24}}}}, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { - 
ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4,0 // gradients: 100,10,1 -dilate-> 100,0,10,0,1 @@ -1293,14 +1291,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {2, 1}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); builder.Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4(&builder, {{{{13, 24, 130}}}}, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR3FromArray3D( Array3D(1, 1, 1, /*value=*/1)); @@ -1314,26 +1312,26 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto activations = builder.ConstantR3FromArray3D(Array3D({{{1, 2, 3, 4}}})); auto gradients = builder.ConstantR3FromArray3D(Array3D({{{100, 10, 1}}})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1}, - /*padding=*/{{2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers( - /*num_spatial_dims=*/1)); + auto forward_conv = + builder.ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1}, + /*padding=*/{{2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, + XlaBuilder::CreateDefaultConvDimensionNumbers( + /*num_spatial_dims=*/1)); builder.Transpose(forward_conv, {0, 1, 2}); ComputeAndCompareR3(&builder, {{{13, 24, 130}}}, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients_flat = Literal::CreateR1({1}); auto gradients_literal = @@ -1357,7 +1355,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto activations_flat = Literal::CreateR1({1, 2, 3, 4}); auto activations_literal = @@ -1378,7 +1376,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { /*window_strides=*/{1, 1, 1}, /*padding=*/{{0, 0}, {0, 0}, {2, 1}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers( + XlaBuilder::CreateDefaultConvDimensionNumbers( /*num_spatial_dims=*/3)); builder.Transpose(forward_conv, {0, 1, 2, 3, 4}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index ece7c3b05e7..2b3390ca98c 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
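The backward-convolution cases above are phrased as forward `ConvGeneralDilated` calls: backward-input mirrors the weights with `Rev`, and backward-filter dilates the gradients (rhs) by the forward stride. A sketch of the 1D backward-filter form, restating `BackwardFilterEvenPadding1D` from the hunk above with the new `XlaBuilder::CreateDefaultConvDimensionNumbers`:

```cpp
XlaBuilder builder(TestName());
auto activations =
    builder.ConstantR3FromArray3D<float>(Array3D<float>({{{1, 2, 3, 4}}}));
auto gradients =
    builder.ConstantR3FromArray3D<float>(Array3D<float>({{{100, 10, 1}}}));

// Backward-filter convolution expressed as a forward conv: the gradient is
// rhs-dilated by the forward stride, and padding is chosen to line it up.
auto forward_conv = builder.ConvGeneralDilated(
    activations, gradients,
    /*window_strides=*/{1},
    /*padding=*/{{2, 1}},
    /*lhs_dilation=*/{}, /*rhs_dilation=*/{2},
    XlaBuilder::CreateDefaultConvDimensionNumbers(/*num_spatial_dims=*/1));
builder.Transpose(forward_conv, {0, 1, 2});

ComputeAndCompareR3<float>(&builder, {{{13, 24, 130}}}, {}, error_spec_);
```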
#include #include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -48,7 +49,7 @@ class CopyOpTest : public HloTestBase { module->AddEntryComputation(std::move(computation)); std::unique_ptr result = ExecuteAndTransfer(std::move(module), {}); - LiteralTestUtil::ExpectEqual(literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result)); } void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); @@ -246,13 +247,13 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { Shape out_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {1, 0}); auto empty = Literal::CreateFromShape(in_shape); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto param0 = builder.Parameter(0, in_shape, "input"); auto input_data = client_->TransferToServer(*empty).ConsumeValueOrDie(); auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(*empty, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index fe5621e8dc2..bfe688e20d1 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -36,9 +36,8 @@ class DeallocationTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. 
std::unique_ptr ExecuteAndCheckTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - Computation computation = builder->Build().ConsumeValueOrDie(); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) .ConsumeValueOrDie(); @@ -48,7 +47,7 @@ class DeallocationTest : public ClientLibraryTestBase { }; TEST_F(DeallocationTest, DeallocateScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -66,7 +65,7 @@ TEST_F(DeallocationTest, DeallocateScalar) { } TEST_F(DeallocationTest, DeallocateVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -79,7 +78,7 @@ TEST_F(DeallocationTest, DeallocateVector) { } TEST_F(DeallocationTest, DeallocateEmptyVector) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -92,7 +91,7 @@ TEST_F(DeallocationTest, DeallocateEmptyVector) { } XLA_TEST_F(DeallocationTest, DeallocateTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Tuple({builder.ConstantR0(42.0), builder.ConstantR1({1.0, 2.0, 3.0})}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -106,7 +105,7 @@ XLA_TEST_F(DeallocationTest, DeallocateTuple) { } XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto element = builder.ConstantR0(42.0); auto inner_tuple = builder.Tuple({builder.ConstantR0(42.0), element}); builder.Tuple({element, inner_tuple, element}); @@ -121,7 +120,7 @@ XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) { } XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto inner_tuple = builder.Tuple({builder.ConstantR0(42.0), builder.ConstantR1({1.0, 2.0, 3.0})}); diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index 3ab0ea4ad48..12789fe6653 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -42,9 +42,8 @@ class DeconstructTupleTest : public ClientLibraryTestBase { // Build and execute the given computation then verify the results can be // transferred from the device successfully. 
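The `ExecuteAndCheckTransfer` helpers in these files show the other half of the migration: `XlaBuilder::Build()` now yields an `XlaComputation` (wrapped in a `StatusOr`), which `Client::Execute` consumes exactly as it did the old `Computation`. A minimal sketch of that flow, assuming `Client::Transfer` for the device-to-host copy (the helper's full body is not shown in this hunk):

```cpp
XlaBuilder builder(TestName());
builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});

// Build() returns StatusOr<XlaComputation>; unwrap and execute on the client.
XlaComputation computation = builder.Build().ConsumeValueOrDie();
std::unique_ptr<GlobalData> global_data =
    client_->Execute(computation, /*arguments=*/{}, &execution_options_)
        .ConsumeValueOrDie();

// The result stays on the device until it is transferred back (or the handle
// is deallocated, which is what these tests go on to verify).
std::unique_ptr<Literal> result =
    client_->Transfer(*global_data).ConsumeValueOrDie();
```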
std::unique_ptr ExecuteAndCheckTransfer( - ComputationBuilder* builder, - tensorflow::gtl::ArraySlice arguments) { - Computation computation = builder->Build().ConsumeValueOrDie(); + XlaBuilder* builder, tensorflow::gtl::ArraySlice arguments) { + XlaComputation computation = builder->Build().ConsumeValueOrDie(); auto global_data = client_->Execute(computation, arguments, &execution_options_) .ConsumeValueOrDie(); @@ -54,7 +53,7 @@ class DeconstructTupleTest : public ClientLibraryTestBase { }; TEST_F(DeconstructTupleTest, DeconstructTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2}); @@ -73,7 +72,7 @@ TEST_F(DeconstructTupleTest, DeconstructTuple) { } TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2}); @@ -103,7 +102,7 @@ TEST_F(DeconstructTupleTest, DeconstructTupleTwice) { } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2, const2, const1}); @@ -129,7 +128,7 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) { } TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({const1, const2, const1}); @@ -159,7 +158,7 @@ TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) { } TEST_F(DeconstructTupleTest, DeconstructNonTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto global_data = ExecuteAndCheckTransfer(&builder, {}); @@ -170,7 +169,7 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { } XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = @@ -186,7 +185,7 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { } XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto const1 = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); auto const2 = builder.ConstantR1({2.0, 4.0, 6.0, 8.0}); builder.Tuple({builder.Tuple({const1, const2}), const1}); diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc index 1da7a96fe23..085a5105aca 100644 --- a/tensorflow/compiler/xla/tests/deep_graph_test.cc +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" namespace xla { @@ -22,12 +23,12 @@ TEST_F(ClientLibraryTestBase, DeepGraph) { // intended to track, we need to set kDepth to 20000. // Unfortunately, setting it that high causes the test to time out. const int kDepth = 200; - ComputationBuilder b(client_, TestName()); - ComputationDataHandle x; - ComputationDataHandle y; + XlaBuilder b(TestName()); + XlaOp x; + XlaOp y; auto x_data = CreateR0Parameter(3, 0, "x", &b, &x); auto y_data = CreateR0Parameter(1, 1, "y", &b, &y); - ComputationDataHandle z = x; + XlaOp z = x; for (int i = 0; i < kDepth; ++i) { z = b.Add(z, y); } diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index c4031dfee59..0fd846cef80 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -51,21 +51,20 @@ using TypesF16F32F64 = ::testing::Types; using TypesF16F32F64CF64 = ::testing::Types; #elif !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) && \ - defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) && \ + defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) && \ defined(XLA_BACKEND_DOES_NOT_SUPPORT_COMPLEX) using TypesF16F32 = ::testing::Types; using TypesF16F32F64 = ::testing::Types; -using TypesF16F32F64CF64 = - ::testing::Types; +using TypesF16F32F64CF64 = ::testing::Types; #else #error "Situation not handled yet" #endif // Check that we can safely pass an input tuple's elements to a dot operation. 
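In the deep-graph test the opaque `ComputationDataHandle` gives way to `XlaOp`, the value type returned by every `XlaBuilder` method. A shortened sketch of the chained-add pattern above, with a small depth for illustration (the real test uses `kDepth = 200`):

```cpp
XlaBuilder b(TestName());
XlaOp x;
XlaOp y;
// CreateR0Parameter transfers the scalar to the device and hands back the
// parameter's XlaOp through the out-argument.
auto x_data = CreateR0Parameter<int32>(3, /*parameter_number=*/0, "x", &b, &x);
auto y_data = CreateR0Parameter<int32>(1, /*parameter_number=*/1, "y", &b, &y);

XlaOp z = x;
for (int i = 0; i < 3; ++i) {
  z = b.Add(z, y);  // builds a chain of adds; depth is 200 in the real test
}
// 3 + 3 * 1 = 6
ComputeAndCompareR0<int32>(&b, /*expected=*/6, {x_data.get(), y_data.get()});
```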
-TEST_F(DotOperationTest, DotOfInputTupleElem) { - ComputationBuilder builder(client_, TestName()); +XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) { + XlaBuilder builder(TestName()); - ComputationDataHandle param; + XlaOp param; auto param_data = CreateParameterAndTransferLiteral( 0, *Literal::MakeTuple({Literal::CreateR2({{1, 2}, {3, 4}}).get(), @@ -86,7 +85,7 @@ TYPED_TEST_CASE(DotOperationTest_F16F32F64CF64, TypesF16F32F64CF64); XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, ZeroElementVectorDot) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR1({}); auto rhs = builder.ConstantR1({}); @@ -102,7 +101,7 @@ TYPED_TEST_CASE(DotOperationTest_F16F32F64, TypesF16F32F64); XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D({{3.0f, 4.0f}}); auto rhs = builder.ConstantFromArray({3.0f, 4.0f}); auto result = builder.Dot(lhs, rhs); @@ -113,7 +112,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TrivialMatrixVectorDot) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR1({static_cast(2.0f)}); auto rhs = builder.ConstantR1({static_cast(3.0f)}); auto result = builder.Dot(lhs, rhs); @@ -124,7 +123,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, OneElementVectorDot) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, VectorDot) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantFromArray({1.0f, 2.5f, 42.0f}); auto rhs = builder.ConstantFromArray({11.0f, -1.0f, 0.5f}); auto result = builder.Dot(lhs, rhs); @@ -139,7 +138,7 @@ std::vector MinorToMajorForIsRowMajor(bool row_major) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto result = builder.Dot(lhs, rhs); @@ -150,7 +149,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x0) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 2)); auto rhs = builder.ConstantR2FromArray2D( {{7.0f, 8.0f, 9.0f}, {42.0f, 77.0f, 101.0f}}); @@ -162,7 +161,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_0x2_2x3) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D( {{7.0f, 8.0f}, {9.0f, 42.0f}, {77.0f, 101.0f}}); auto rhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); @@ -174,7 +173,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_3x2_2x0) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs = builder.ConstantR2FromArray2D(Array2D(2, 0)); auto rhs = 
builder.ConstantR2FromArray2D(Array2D(0, 2)); auto result = builder.Dot(lhs, rhs); @@ -185,7 +184,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, Dot_2x0_0x2) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto param0 = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 4}), "arg0"); auto param1 = @@ -230,7 +229,7 @@ class SquareMatrixDot : public DotOperationTest { LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(rhs_row_major)))) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"), @@ -315,7 +314,7 @@ void ParametricDotTest::TestImpl() { addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( builder.Parameter(0, @@ -491,7 +490,7 @@ class NonsquareMatrixDot : public DotOperationTest { MinorToMajorForIsRowMajor(rhs_row_major)))) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"), @@ -523,7 +522,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { LayoutUtil::MakeLayout({1, 0}))) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), @@ -538,7 +537,7 @@ XLA_TEST_F(DotOperationTest, MatrixVectorC64) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, ConcurrentMatMult) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto matrix1 = builder.ConstantR2FromArray2D({{1.0f, 2.0f}, {3.0f, 4.0f}}); auto matrix2 = builder.ConstantR2FromArray2D({{5.0f, 6.0f}, {7.0f, 8.0f}}); auto matrix12 = builder.Dot(matrix1, matrix2); @@ -559,7 +558,7 @@ TYPED_TEST_CASE(DotOperationTestForBatchMatMul, TypesF16F32F64); // sync-dependent on bitcasts' operands. XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2, 2}), "x"); auto y = @@ -569,7 +568,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2}); // Slice batches into individual matrices and multiply them. - std::vector out_slices; + std::vector out_slices; for (int i = 0; i < 4; ++i) { // Slice off individual matrices and reshape to 2D tensors. 
auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2}, {1, 1, 1}); @@ -615,7 +614,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) { XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) { using T = TypeParam; - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2, 2}), "x"); auto y = @@ -677,7 +676,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) { MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto lhs_arg = builder.Parameter( 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}), @@ -713,7 +712,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, new Array2D({{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}})); - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs_arg_0"); @@ -761,7 +760,7 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, {4.0f, 3.0f}, {2.0f, 1.0f}})); - ComputationBuilder builder(this->client_, this->TestName()); + XlaBuilder builder(this->TestName()); auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShapeWithType({2, 2}), "lhs_arg_0"); @@ -799,5 +798,250 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, this->error_spec_); } +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSClassicMM) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{96.0, 105.0, 114.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 105, 96}, {96, 105, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; 
+ dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{105.0}, {105.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstRHSReverseMM)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{105.0, 105.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER( + DotOfGatherOptimizationWithConstLHSReverseMM)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{114, 96}, {105, 105}, {96, 114}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{96.0}, {105.0}, {114.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
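The DotOfGather tests added above pair a DynamicSlice of a constant with DotGeneral. A minimal sketch of the classic-matmul case, with the element type assumed to be float here since it is not visible in this extract; the numbers follow the comment in DotOfGatherOptimizationWithConstRHSClassicMM:

// Slice row 1 out of the constant LHS, then contract LHS dim 1 against RHS
// dim 0, i.e. an ordinary (1,6) x (6,3) matmul.
void BuildDotOfGatherSketch(xla::XlaBuilder* builder) {
  auto lhs = builder->ConstantR2<float>(
      {{1, 2, 3, 4, 5, 6}, {6, 5, 4, 3, 2, 1}});  // 2x6
  auto rhs = builder->ConstantR2<float>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9},
                                         {9, 8, 7}, {6, 5, 4}, {3, 2, 1}});  // 6x3
  auto starts = builder->ConstantR1<int32_t>({1, 0});
  auto row = builder->DynamicSlice(lhs, starts, /*slice_sizes=*/{1, 6});
  xla::DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  builder->DotGeneral(row, rhs, dnums);
  // Row {6,5,4,3,2,1} times the 6x3 RHS is {96, 105, 114}: the second row of
  // the full product {{114, 105, 96}, {96, 105, 114}} quoted in the test.
}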
+XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSRows)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{126.0, 129.0, 132.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSRows)))) { + std::unique_ptr> constant_lhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0}, + {9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0}, + {3.0, 2.0, 1.0}})); + // Dot result to slice from: {{132, 129, 126}, {126, 129, 132}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({0, 1}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {6, 1}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{129.0}, {129.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. 
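The Rows variants above contract dimension 0 of both operands. A small host-side check of the constants in DotOfGatherOptimizationWithConstRHSRows (my own arithmetic, shown only to make that contraction convention concrete):

#include <array>

// Column 1 of the 6x2 LHS is {2, 4, 6, 5, 3, 1}; contracting it against
// dimension 0 of the 6x3 RHS reproduces the expected {{126, 129, 132}}.
std::array<double, 3> ConstRhsRowsExpected() {
  const double lhs_col1[6] = {2, 4, 6, 5, 3, 1};
  const double rhs[6][3] = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9},
                            {9, 8, 7}, {6, 5, 4}, {3, 2, 1}};
  std::array<double, 3> out = {0, 0, 0};
  for (int j = 0; j < 3; ++j) {
    for (int k = 0; k < 6; ++k) {
      out[j] += lhs_col1[k] * rhs[k][j];
    }
  }
  return out;  // {126, 129, 132}
}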
+XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstRHSCols)))) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(lhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(dynamic_slice, rhs_constant, dot_dnums); + + Array2D expected({{56.0, 168.0, 91.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} + +// TODO (b/69062148) Enable when Dot implements general contracting dimensions. +XLA_TEST_F(DotOperationTest, + DISABLED_ON_CPU(DISABLED_ON_GPU( + DISABLED_ON_INTERPRETER(DotOfGatherOptimizationWithConstLHSCols)))) { + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, + {7.0, 8.0, 9.0, 9.0, 8.0, 7.0}, + {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + // Dot result to slice from: {{91, 168, 56}, {56, 168, 91}} + + XlaBuilder builder(TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto start_constant = builder.ConstantR1({1, 0}); + auto dynamic_slice = + builder.DynamicSlice(rhs_constant, start_constant, {1, 6}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(1); + auto result = builder.DotGeneral(lhs_constant, dynamic_slice, dot_dnums); + + Array2D expected({{168.0}, {168.0}}); + ComputeAndCompareR2(&builder, expected, {}, error_spec_); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 5f00c340028..49f3a10d227 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -35,8 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -55,9 +53,9 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR1Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1}); + void TestR1OOB() { + // Slice at dimension boundaries, but with out of bounds indices. + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {4, 5, 6, 7}); } template @@ -80,10 +78,10 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR2Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR2OOB() { + // Slice at dimension boundaries, but with out of bounds indices. 
RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3}, - {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}}); + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); } template @@ -108,11 +106,11 @@ class DynamicSliceTest : public ClientLibraryTestBase { } template - void TestR3Wrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestR3OOB() { + // Slice at dimension boundaries, but with out of bounds indices. RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, - {2, 1, 2}, {{{6, 5}}, {{12, 11}}}); + {2, 1, 2}, {{{5, 6}}, {{11, 12}}}); } template @@ -201,19 +199,19 @@ class DynamicSliceTest : public ClientLibraryTestBase { XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R1OOB) { TestR1OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R2OOB) { TestR2OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } -XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap(); } +XLA_TEST_F(DynamicSliceTest, Int32R3OOB) { TestR3OOB(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } @@ -334,17 +332,17 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void TestWrap() { - // Slice at dimension boundaries, but with sizes that cause indices to wrap. + void TestOOB() { + // // Slice at dimension boundaries, but with out of bounds indices. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6}, - {10, 1, 2, 3, 4, 5, 8, 9}); + {0, 1, 2, 3, 4, 8, 9, 10}); // R2 Shape: [3, 3] RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2}, - {{1, 2, 3}, {4, 5, 6}, {11, 8, 10}}); + {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}}); // R3 Shape: [2, 3, 2] RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, - {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}}); + {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}}); } template @@ -363,9 +361,9 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { ->Convert(primitive_util::NativeToPrimitiveType()) .ValueOrDie()); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Initialize and transfer dynamic slice start indices parameter. - ComputationDataHandle starts; + XlaOp starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. @@ -472,33 +470,25 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void RunR3Contiguous(std::vector operand_shape, int32 index, int32 size) { -#ifdef XLA_TEST_BACKEND_CPU_PARALLEL - // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. 
- if (std::is_same::value) { - return; - } -#endif - const int32 kSeq = operand_shape[0]; const int32 kBatch = operand_shape[1]; const int32 kDim = operand_shape[2]; Array3D input_values(kSeq, kBatch, kDim); Array3D update_values(size, kBatch, kDim); Array3D expected_values(kSeq, kBatch, kDim); + index = std::min(std::max(0, index), kSeq - size); input_values.FillIota(static_cast(0)); T value = static_cast(10); update_values.FillIota(static_cast(value)); // TODO(b/34128753) Expected values may vary depending on backend when - // the update wraps. According to documentation, the results are technically - // implementation specific where the update is out of bounds, and hence - // we don't really know what to pass into ComputeAndCompareR3. + // the indices are out of bounds. expected_values.FillIota(static_cast(0)); for (int i = 0; i < size; i++) { for (int j = 0; j < kBatch; j++) { for (int k = 0; k < kDim; k++) { - expected_values((index + i) % kSeq, j, k) = value++; + expected_values(index + i, j, k) = value++; } } } @@ -541,35 +531,25 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0(); } // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. -XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R1BF16)) { - TestR1(); -} +XLA_TEST_F(DynamicUpdateSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } -// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. -XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R2BF16)) { - TestR2(); -} +XLA_TEST_F(DynamicUpdateSliceTest, Int32R2BF16) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2(); } -// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. -XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R3BF16)) { - TestR3(); -} +XLA_TEST_F(DynamicUpdateSliceTest, Int32R3BF16) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } -XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32WrapBF16)) { - TestWrap(); -} -XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap(); } -XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOBBF16) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int64OOB) { TestOOB(); } +XLA_TEST_F(DynamicUpdateSliceTest, UInt64OOB) { TestOOB(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { // Slice at dimension start. @@ -632,37 +612,37 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { // Tests for simple R3 case where the update is contiguous (i.e. the minor // two dimensions are not sliced). XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { - // Single element, no wrap. + // Single element, index in-bounds std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) { - // Single element, no wrap. 
+ // Single element, index in-bounds std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { - // Multiple element, no wrap. + // Multiples element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) { - // Multiple element, no wrap. + // Multiples element, index in-bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOB) { + // Multiple element, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } -XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) { - // Multiple element, wrapping. +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleOOBBF16) { + // Multiple element, index out of bounds. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } @@ -737,11 +717,11 @@ void BM_DynamicSlice(int num_iters) { auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *start_indices_literal, *buffer)); + executors[device_ordinal], *start_indices_literal, buffer)); std::unique_ptr executable = client - ->Compile(computation, {&buffer->on_host_shape()}, + ->Compile(computation, {&buffer.on_host_shape()}, ExecutableBuildOptions()) .ConsumeValueOrDie(); @@ -750,14 +730,14 @@ void BM_DynamicSlice(int num_iters) { options.set_allocator(&allocator); const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({buffer.get()}, options); + auto result = executable->Run({&buffer}, options); ASSERT_TRUE(result.ok()); } // Run benchmark. tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({buffer.get()}, options); + auto result = executable->Run({&buffer}, options); ASSERT_TRUE(result.ok()); } } diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc index 644cbbf40f2..a6ba6db5d3b 100644 --- a/tensorflow/compiler/xla/tests/execution_profile_test.cc +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
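Stepping back to the dynamic_ops_test.cc hunks above, the renamed OOB tests encode the new out-of-bounds behaviour: a start index that would run past the end is expected to be clamped so the whole slice stays inside the operand, rather than wrapping. A one-line sketch of that clamping rule, matching the index adjustment added to RunR3Contiguous (the helper name is my own):

#include <algorithm>
#include <cstdint>

// Clamp a start index so that [start, start + size) stays inside a dimension
// of length dim_size. For example, start 6 with size 4 in a length-8 R1
// operand clamps to 4, which is why TestR1OOB now expects {4, 5, 6, 7}.
int64_t ClampSliceStart(int64_t start, int64_t size, int64_t dim_size) {
  return std::min(std::max<int64_t>(0, start), dim_size - size);
}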
==============================================================================*/ -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/platform/test.h" @@ -24,8 +25,7 @@ namespace { class ExecutionProfileTest : public ClientLibraryTestBase {}; -XLA_TEST_F(ExecutionProfileTest, - DISABLED_ON_CPU_PARALLEL(ExecuteWithExecutionProfile)) { +XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) { Shape shape = ShapeUtil::MakeShape(F32, {256, 256}); TF_ASSERT_OK_AND_ASSIGN( @@ -33,9 +33,9 @@ XLA_TEST_F(ExecutionProfileTest, client_->TransferToServer( *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256))); - ComputationBuilder b(client_, TestName() + ".add"); + XlaBuilder b(TestName() + ".add"); b.Dot(b.Parameter(0, shape, "param_0"), b.Parameter(1, shape, "param_1")); - TF_ASSERT_OK_AND_ASSIGN(Computation dot_product, b.Build()); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation dot_product, b.Build()); ExecutionProfile execution_profile; TF_ASSERT_OK_AND_ASSIGN( diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc index b28fe0c15a8..0a37e4d4236 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -35,7 +36,7 @@ class ExhaustiveF32ElementwiseOpTest int64 input_size = end - begin; LOG(INFO) << "Checking range [" << begin << ", " << end << ")"; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr input_literal = Literal::CreateFromDimensions(F32, {input_size}); @@ -78,9 +79,7 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { #endif ExhaustivelyTestF32Op( - [](ComputationBuilder* builder, const ComputationDataHandle& input) { - builder->Log(input); - }, + [](XlaBuilder* builder, const XlaOp& input) { builder->Log(input); }, std::log, known_incorrect_range); } @@ -96,17 +95,13 @@ XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { #endif ExhaustivelyTestF32Op( - [](ComputationBuilder* builder, const ComputationDataHandle& input) { - builder->Exp(input); - }, + [](XlaBuilder* builder, const XlaOp& input) { builder->Exp(input); }, std::exp, known_incorrect_range); } XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) { ExhaustivelyTestF32Op( - [](ComputationBuilder* builder, const ComputationDataHandle& input) { - builder->Tanh(input); - }, + [](XlaBuilder* builder, const XlaOp& input) { builder->Tanh(input); }, std::tanh, /*known_incorrect_range=*/{0, 0}); } diff --git a/tensorflow/compiler/xla/tests/filecheck.cc b/tensorflow/compiler/xla/tests/filecheck.cc index a5f6872c46c..93d1c921c4a 100644 --- a/tensorflow/compiler/xla/tests/filecheck.cc +++ 
b/tensorflow/compiler/xla/tests/filecheck.cc @@ -38,7 +38,7 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, pattern_path, pattern)); // Invoke FileCheck to check whether input matches `pattern`. - const char* file_check_path_suffix = "external/llvm/FileCheck"; + const char* file_check_path_suffix = "org_tensorflow/external/llvm/FileCheck"; string file_check_path; if (const char* test_srcdir = getenv("TEST_SRCDIR")) { file_check_path = JoinPath(test_srcdir, file_check_path_suffix); @@ -66,6 +66,11 @@ StatusOr RunFileCheck(const string& input, const string& pattern) { // the error message generated by FileCheck and the inputs. bool succeeded = (exit_status == 0); if (!succeeded) { + LOG(WARNING) << "Tried to execute FileCheck at " << file_check_path; + if (!env->FileExists(file_check_path).ok()) { + LOG(WARNING) << "NOTE: FileCheck binary does not exist!"; + } + LOG(WARNING) << "FileCheck error: " << standard_error; LOG(WARNING) << "FileCheck input was:"; XLA_LOG_LINES(tensorflow::WARNING, input); diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index e75a41acacc..71eb914a8e5 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -41,7 +41,7 @@ class FloorCeilTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice expected, Function f) { LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ") << "}"; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto c = builder.ConstantR1(input); if (f == kCeil) { builder.Ceil(c); @@ -54,7 +54,7 @@ class FloorCeilTest : public ClientLibraryTestBase { void TestR0F32(float input, float expected, Function f) { LOG(INFO) << "input: " << expected; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto c = builder.ConstantR0(input); if (f == kCeil) { builder.Ceil(c); diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index f2aaf6621c1..73f029b59bc 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/platform/test.h" @@ -27,7 +27,7 @@ namespace { class FmaxSimpleTest : public ClientLibraryTestBase {}; TEST_F(FmaxSimpleTest, FmaxTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}); auto y = builder.ConstantR1( diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index a292eab1d19..e6f79b5ac55 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -25,8 +25,7 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -50,8 +49,6 @@ limitations under the License. using tensorflow::gtl::ArraySlice; -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -121,9 +118,9 @@ class FusionTest : public HloTestBase { auto expected = Literal::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { - LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4))); } else { - LiteralTestUtil::ExpectEqual(*expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } } @@ -224,9 +221,9 @@ XLA_TEST_F(FusionTest, Test) { const4, reshape3, add2, const1, const0}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*Literal::CreateR2({{0.5}, {2.72}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), - ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2({{0.5}, {2.72}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } // Test whether we emit appropriate code for parameters of fusion instructions. 
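Throughout fusion_test.cc the ExpectEqual/ExpectNear helpers give way to the value-returning Equal/Near wrapped in EXPECT_TRUE, so the failure is reported at the call site. A minimal sketch of the new style, with throwaway literals standing in for the real expected/actual pair:

// Old: LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4));
// New: the comparison returns a result the test asserts on itself.
auto expected = xla::Literal::CreateR2<float>({{0.5}, {2.72}});
auto actual = xla::Literal::CreateR2<float>({{0.5}, {2.72}});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual));
EXPECT_TRUE(xla::LiteralTestUtil::Near(*expected, *actual, xla::ErrorSpec(1e-4)));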
@@ -250,9 +247,9 @@ XLA_TEST_F(FusionTest, Parameter) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*Literal::CreateR2({{-1.0, 0.0, 1.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), - ErrorSpec(1e-4)); + EXPECT_TRUE(LiteralTestUtil::Near( + *Literal::CreateR2({{-1.0, 0.0, 1.0}}), + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, RandomizedParallelPartition) { @@ -310,9 +307,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear( + EXPECT_TRUE(LiteralTestUtil::Near( *Literal::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), - *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); + *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4))); } XLA_TEST_F(FusionTest, ReshapeToScalar) { @@ -325,8 +322,9 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(5), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(5), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { @@ -339,9 +337,9 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { @@ -354,9 +352,9 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_1by1by1_) { @@ -369,8 +367,9 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__1by1by1) { @@ -383,8 +382,9 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR3({{{7}}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR3({{{7}}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape__) { @@ -397,8 +397,9 @@ XLA_TEST_F(FusionTest, Reshape__) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, 
HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(7), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { @@ -411,9 +412,9 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_2by3) { @@ -426,9 +427,9 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 4}, {2, 5}, {3, 6}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Transpose_3by3) { @@ -441,9 +442,9 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Reverse) { @@ -457,8 +458,9 @@ XLA_TEST_F(FusionTest, Reverse) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({3, 2, 1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({3, 2, 1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReverseNegate) { @@ -474,8 +476,9 @@ XLA_TEST_F(FusionTest, ReverseNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-3, -2, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-3, -2, -1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, BroadcastNegate) { @@ -491,8 +494,9 @@ XLA_TEST_F(FusionTest, BroadcastNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-1, -1}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-1, -1}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, SliceNegate) { @@ -508,8 +512,9 @@ XLA_TEST_F(FusionTest, SliceNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-1, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-1, -3}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DynamicSliceNegate) { @@ -529,8 +534,9 @@ XLA_TEST_F(FusionTest, 
DynamicSliceNegate) { /*instructions_to_fuse=*/{negate3, dynamic_slice2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-2, -3}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({-2, -3}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, ReshapeNegate) { @@ -546,8 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR2({{-1, -2}, {-3, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -2}, {-3, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } // TODO(b/64070202): Investigate failure. @@ -564,8 +571,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR2({{-1, -3}, {-2, -4}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR2({{-1, -3}, {-2, -4}}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } std::unique_ptr MakeReduceTestComputation() { @@ -594,8 +602,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(15), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(15), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { @@ -615,8 +624,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*Literal::CreateR0(-15), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR0(-15), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { @@ -664,9 +674,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual( + EXPECT_TRUE(LiteralTestUtil::Equal( *Literal::CreateR2({{462, 2145}, {24871, 62491}}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + *ExecuteAndTransfer(std::move(hlo_module), {}))); } // When a constant (or other op) which has multiple users is imported @@ -677,21 +687,20 @@ XLA_TEST_F(FusionTest, SharedConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0}))); + HloInstruction::CreateConstant(Literal::CreateR1({0}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, const0)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add1)); + ShapeUtil::MakeShape(S32, 
{1}), HloOpcode::kAdd, const1, add1)); auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2)); + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add2)); auto add4 = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3)); + ShapeUtil::MakeShape(S32, {1}), HloOpcode::kAdd, const1, add3)); hlo_module->AddEntryComputation(builder.Build()) - ->CreateFusionInstruction( - {add4, add3, add2, add1, const1}, - HloInstruction::FusionKind::kLoop); + ->CreateFusionInstruction({add4, add3, add2, add1, const1}, + HloInstruction::FusionKind::kLoop); HloComputation* entry_comp = hlo_module->entry_computation(); @@ -701,8 +710,9 @@ XLA_TEST_F(FusionTest, SharedConstant) { // fused instruction contains the constant(2), the parameter, and 4 adds EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); - LiteralTestUtil::ExpectEqual(*Literal::CreateR1({8}), - *ExecuteAndTransfer(std::move(hlo_module), {})); + EXPECT_TRUE( + LiteralTestUtil::Equal(*Literal::CreateR1({8}), + *ExecuteAndTransfer(std::move(hlo_module), {}))); } XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D(HloOpcode::kAdd); } @@ -781,7 +791,7 @@ void BM_ParallelFusion(int num_iters) { const int64 param2_dim1 = 1024; // Create computation. - ComputationBuilder builder(client, "ParallelFusion"); + XlaBuilder builder("ParallelFusion"); Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1}); auto param0 = builder.Parameter(0, shape0, "param0"); Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1}); @@ -796,19 +806,19 @@ void BM_ParallelFusion(int num_iters) { // Transfer literals to device. auto param0_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); - std::unique_ptr buffer0 = + ScopedShapedBuffer buffer0 = client->LiteralToShapedBuffer(*param0_literal, device_ordinal) .ConsumeValueOrDie(); auto param1_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); - std::unique_ptr buffer1 = + ScopedShapedBuffer buffer1 = client->LiteralToShapedBuffer(*param1_literal, device_ordinal) .ConsumeValueOrDie(); auto param2_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); - std::unique_ptr buffer2 = + ScopedShapedBuffer buffer2 = client->LiteralToShapedBuffer(*param2_literal, device_ordinal) .ConsumeValueOrDie(); @@ -816,8 +826,8 @@ void BM_ParallelFusion(int num_iters) { std::unique_ptr executable = client ->Compile(computation, - {&buffer0->on_host_shape(), &buffer1->on_host_shape(), - &buffer2->on_host_shape()}, + {&buffer0.on_host_shape(), &buffer1.on_host_shape(), + &buffer2.on_host_shape()}, ExecutableBuildOptions()) .ConsumeValueOrDie(); @@ -838,8 +848,7 @@ void BM_ParallelFusion(int num_iters) { // Run some warm-up executions. 
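The same benchmark also shows the ScopedShapedBuffer change: LiteralToShapedBuffer now hands back a value rather than a std::unique_ptr, so members are reached with '.' and LocalExecutable::Run takes plain pointers. A condensed sketch of the calling pattern, assuming the client, literal, device_ordinal and options objects set up earlier in BM_ParallelFusion:

// Before: std::unique_ptr<ScopedShapedBuffer> buffer0 = ...;
//         executable->Run({buffer0.get()}, options);
xla::ScopedShapedBuffer buffer0 =
    client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
        .ConsumeValueOrDie();
auto executable = client
                      ->Compile(computation, {&buffer0.on_host_shape()},
                                xla::ExecutableBuildOptions())
                      .ConsumeValueOrDie();
auto result = executable->Run({&buffer0}, options);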
const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { - auto result = - executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); + auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options); ASSERT_TRUE(result.ok()); } @@ -852,8 +861,7 @@ void BM_ParallelFusion(int num_iters) { tensorflow::testing::UseRealTime(); tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = - executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); + auto result = executable->Run({&buffer0, &buffer1, &buffer2}, options); ASSERT_TRUE(result.ok()); } } diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 90496d55e60..4854c649c15 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -399,12 +399,187 @@ ENTRY main { RunTest(hlo_text, operand.get(), gather_indices.get()); } +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherV2) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherV2 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[3,2] gather(operand, indices), + output_window_dims={0}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[3,2] broadcast(one), dimensions={} + ROOT result = s32[3,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({0, 2}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherMultipleBatchDims) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,3,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=2, + window_bounds={3, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,3,2] broadcast(one), dimensions={} + ROOT result = s32[2,3,2]{2,1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 2}, {2, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNdMultipleBatchDims) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNdMultipleBatchDims + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=2, + window_bounds={1, 1} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR3({{{0, 2}, {2, 1}}, {{1, 2}, {2, 0}}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedTensorFlowGatherNd) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = 
s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=1, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, + FusedTensorFlowGatherNdNonDefaultIndexVectorDim) { + const string hlo_text = R"( +HloModule FusedTensorFlowGatherNd + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,2] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0,1}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1,2} + one = s32[] constant(1) + one_broadcasted = s32[2,2] broadcast(one), dimensions={} + ROOT result = s32[2,2]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR3({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{0, 0}, {1, 0}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedDynamicSlice) { + const char* hlo_text = R"( +HloModule FusedDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + gather = s32[1,1] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[1,1] broadcast(one), dimensions={} + ROOT result = s32[1,1]{1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = Literal::CreateR1({1, 1}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + +XLA_TEST_F(GatherOperationTest, FusedBatchDynamicSlice) { + const string hlo_text = R"( +HloModule FusedBatchDynamicSlice + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + gather = s32[2,1,1] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={}, + gather_dims_to_operand_dims={0,1}, + index_vector_dim=0, + window_bounds={1,1} + one = s32[] constant(1) + one_broadcasted = s32[2,1,1] broadcast(one), dimensions={} + ROOT result = s32[2,1,1]{2,1,0} add(gather, one_broadcasted) +} +)"; + std::unique_ptr operand = + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr gather_indices = + Literal::CreateR2({{2, 1}, {1, 1}}); + RunTest(hlo_text, operand.get(), gather_indices.get()); +} + class GatherClientLibraryTest : public ClientLibraryTestBase {}; -// TODO(b/30671675): Asynchronous execution on stream is not yet supported on -// GPU and CPU_PARALLEL. -XLA_TEST_F(GatherClientLibraryTest, - DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(Basic))) { +XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) { // We create this HLO, but using the XlaBuilder API. 
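The FusedTensorFlowGatherV2 test added above reads like tf.gather along axis 1: each index in {0, 2} selects a whole column of the 3x3 operand (window_bounds {3, 1}), and the broadcasted one is then added. A host-side illustration of that reading; the concrete result is my own working, since the test itself compares against the reference backend rather than a literal:

#include <cstdio>

// Gather columns {0, 2} of a 3x3 operand, then add 1.
int main() {
  const int operand[3][3] = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
  const int indices[2] = {0, 2};
  int result[3][2];
  for (int row = 0; row < 3; ++row) {
    for (int i = 0; i < 2; ++i) {
      result[row][i] = operand[row][indices[i]] + 1;
    }
  }
  // result: {{2, 4}, {5, 7}, {8, 10}}
  for (int row = 0; row < 3; ++row) {
    std::printf("%d %d\n", result[row][0], result[row][1]);
  }
  return 0;
}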
// // ENTRY main { @@ -454,8 +629,8 @@ XLA_TEST_F(GatherClientLibraryTest, client_->ExecuteParallel(computation_instances)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result_literal, client_->Transfer(*(result_data[0]))); - LiteralTestUtil::ExpectEqual( - *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result_literal, *Literal::CreateR2({{1, 2, 3}, {7, 8, 9}}))); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc index ec2f49d43bd..76bf47845ca 100644 --- a/tensorflow/compiler/xla/tests/half_test.cc +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -16,8 +16,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -39,7 +38,7 @@ class HalfTestBase : public ClientLibraryTestBase { }; using UnaryBuildFuncTy = - std::function; + std::function; struct UnaryOpTestParam { std::function compute_func; @@ -51,8 +50,8 @@ class UnaryOpTest : public HalfTestBase, XLA_TEST_P(UnaryOpTest, Ops) { std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); @@ -79,30 +78,21 @@ half round_imp(half value) { INSTANTIATE_TEST_CASE_P( half, UnaryOpTest, - ::testing::Values(UnaryOpTestParam{[](half x) { return abs(x); }, - &ComputationBuilder::Abs}, - UnaryOpTestParam{[](half x) { return round_imp(x); }, - &ComputationBuilder::Round}, - UnaryOpTestParam{[](half x) { return ceil(x); }, - &ComputationBuilder::Ceil}, - UnaryOpTestParam{[](half x) { return cos(x); }, - &ComputationBuilder::Cos}, - UnaryOpTestParam{[](half x) { return exp(x); }, - &ComputationBuilder::Exp}, - UnaryOpTestParam{[](half x) { return floor(x); }, - &ComputationBuilder::Floor}, - UnaryOpTestParam{[](half x) { return log(x); }, - &ComputationBuilder::Log}, - UnaryOpTestParam{[](half x) { return -x; }, - &ComputationBuilder::Neg}, - UnaryOpTestParam{[](half x) { return sign_imp(x); }, - &ComputationBuilder::Sign}, - UnaryOpTestParam{[](half x) { return sin(x); }, - &ComputationBuilder::Sin}, - UnaryOpTestParam{[](half x) { return tanh(x); }, - &ComputationBuilder::Tanh} + ::testing::Values( + UnaryOpTestParam{[](half x) { return abs(x); }, &XlaBuilder::Abs}, + UnaryOpTestParam{[](half x) { return round_imp(x); }, + &XlaBuilder::Round}, + UnaryOpTestParam{[](half x) { return ceil(x); }, &XlaBuilder::Ceil}, + UnaryOpTestParam{[](half x) { return cos(x); }, &XlaBuilder::Cos}, + UnaryOpTestParam{[](half x) { return exp(x); }, &XlaBuilder::Exp}, + UnaryOpTestParam{[](half x) { return floor(x); }, &XlaBuilder::Floor}, + UnaryOpTestParam{[](half x) { return log(x); }, &XlaBuilder::Log}, + UnaryOpTestParam{[](half x) { return -x; }, &XlaBuilder::Neg}, + UnaryOpTestParam{[](half x) { return sign_imp(x); }, &XlaBuilder::Sign}, + UnaryOpTestParam{[](half x) { return sin(x); }, &XlaBuilder::Sin}, + UnaryOpTestParam{[](half x) { return tanh(x); }, &XlaBuilder::Tanh} - )); + )); struct UnaryPredTestParam { std::function compute_func; @@ -115,8 +105,8 @@ class UnaryPredTest : 
public HalfTestBase, XLA_TEST_P(UnaryPredTest, Ops) { std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); @@ -136,11 +126,11 @@ XLA_TEST_P(UnaryPredTest, Ops) { INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, ::testing::Values(UnaryPredTestParam{ [](half x) { return isfinite(x); }, - &ComputationBuilder::IsFinite})); + &XlaBuilder::IsFinite})); using BinaryBuildFuncTy = std::function)>; + xla::XlaBuilder*, const xla::XlaOp& x, const xla::XlaOp& y, + tensorflow::gtl::ArraySlice)>; struct BinaryOpTestParam { std::function compute_func; @@ -153,12 +143,12 @@ class BinaryOpTest : public HalfTestBase, XLA_TEST_P(BinaryOpTest, Ops) { std::vector x({half(1.0), half(2.0), half(3.0), half(-4.0)}); std::vector y({half(0.4), half(-0.3), half(0.2), half(0.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); - ComputationDataHandle y_opnd; + XlaOp y_opnd; auto y_data = CreateR1Parameter(y, /*parameter_number=*/1, "y", &builder, &y_opnd); @@ -184,21 +174,21 @@ INSTANTIATE_TEST_CASE_P( half, BinaryOpTest, ::testing::Values( BinaryOpTestParam{[](half x, half y) { return x + y; }, - &ComputationBuilder::Add}, + &XlaBuilder::Add}, BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); }, - &ComputationBuilder::Atan2}, + &XlaBuilder::Atan2}, BinaryOpTestParam{[](half x, half y) { return x / y; }, - &ComputationBuilder::Div}, + &XlaBuilder::Div}, BinaryOpTestParam{[](half x, half y) { return max(x, y); }, - &ComputationBuilder::Max}, + &XlaBuilder::Max}, BinaryOpTestParam{[](half x, half y) { return min(x, y); }, - &ComputationBuilder::Min}, + &XlaBuilder::Min}, BinaryOpTestParam{[](half x, half y) { return x * y; }, - &ComputationBuilder::Mul}, + &XlaBuilder::Mul}, BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, - &ComputationBuilder::Pow}, + &XlaBuilder::Pow}, BinaryOpTestParam{[](half x, half y) { return x - y; }, - &ComputationBuilder::Sub} + &XlaBuilder::Sub} )); @@ -214,12 +204,12 @@ class BinaryPredTest XLA_TEST_P(BinaryPredTest, Ops) { std::vector x({half(1.0), half(2.0), half(0.2), half(-4.0)}); std::vector y({half(0.4), half(-0.3), half(0.2), half(0.1)}); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle x_opnd; + XlaBuilder builder(TestName()); + XlaOp x_opnd; auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", &builder, &x_opnd); - ComputationDataHandle y_opnd; + XlaOp y_opnd; auto y_data = CreateR1Parameter(y, /*parameter_number=*/1, "y", &builder, &y_opnd); @@ -239,17 +229,17 @@ XLA_TEST_P(BinaryPredTest, Ops) { INSTANTIATE_TEST_CASE_P( half, BinaryPredTest, ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; }, - &ComputationBuilder::Eq}, + &XlaBuilder::Eq}, BinaryPredTestParam{[](half x, half y) { return x != y; }, - &ComputationBuilder::Ne}, + &XlaBuilder::Ne}, BinaryPredTestParam{[](half x, half y) { return x >= y; }, - &ComputationBuilder::Ge}, + &XlaBuilder::Ge}, BinaryPredTestParam{[](half x, half y) { return x > y; }, - &ComputationBuilder::Gt}, + &XlaBuilder::Gt}, BinaryPredTestParam{[](half x, half y) { return x <= y; }, - &ComputationBuilder::Le}, + &XlaBuilder::Le}, BinaryPredTestParam{[](half x, half y) { return x < 
y; }, - &ComputationBuilder::Lt} + &XlaBuilder::Lt} )); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 21f71fc91bb..12598579c70 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -35,8 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -95,11 +93,13 @@ HloTestBase::HloTestBase(se::Platform* test_platform, } /* static */ -std::unique_ptr HloTestBase::CreateNewModule() { +std::unique_ptr HloTestBase::CreateNewModule(const string& name) { HloModuleConfig config; - config.set_debug_options(GetDebugOptionsForTest()); - return MakeUnique(TestName(), VersionedComputationHandle(), - config); + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_max_kernel_unroll_factor(1); + config.set_debug_options(debug_options); + + return MakeUnique(name, VersionedComputationHandle(), config); } /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { @@ -115,11 +115,13 @@ StatusOr> HloTestBase::Execute( return test_runner_.Execute(std::move(module), arguments); } -StatusOr> HloTestBase::ExecuteNoHloPasses( +std::unique_ptr HloTestBase::ExecuteNoHloPasses( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments) { - return test_runner_.Execute(std::move(module), arguments, - /*run_hlo_passes=*/false); + return test_runner_ + .Execute(std::move(module), arguments, + /*run_hlo_passes=*/false) + .ValueOrDie(); } std::unique_ptr HloTestBase::ExecuteAndTransfer( diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 3e8e2360bb3..9539ae06801 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -76,8 +76,7 @@ class HloTestBase : public ::testing::Test { // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. - HloTestBase(::perftools::gputools::Platform* test_platform, - ::perftools::gputools::Platform* reference_platform); + HloTestBase(se::Platform* test_platform, se::Platform* reference_platform); ~HloTestBase() override {} @@ -86,7 +85,8 @@ class HloTestBase : public ::testing::Test { // options from command-line flags. If you want a fresh HloModule object and // then add HloComputations to it, it's recommended to use this method in your // tests. - static std::unique_ptr CreateNewModule(); + static std::unique_ptr CreateNewModule( + const string& name = TestName()); // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in @@ -100,7 +100,7 @@ class HloTestBase : public ::testing::Test { // Same as above, except the module will be executed without running any HLO // passes on it. - StatusOr> ExecuteNoHloPasses( + std::unique_ptr ExecuteNoHloPasses( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments); @@ -177,9 +177,13 @@ class HloTestBase : public ::testing::Test { // 'layout'. 
void ForceParameterLayout(HloModule* module, int64 param_no, const Layout& layout) { - ASSERT_LT(param_no, - module->mutable_entry_computation_layout()->parameter_count()); - module->mutable_entry_computation_layout() + ASSERT_LT( + param_no, + module->mutable_host_entry_computation_layout()->parameter_count()); + module->mutable_host_entry_computation_layout() + ->mutable_parameter_layout(param_no) + ->ResetLayout(layout); + module->mutable_device_entry_computation_layout() ->mutable_parameter_layout(param_no) ->ResetLayout(layout); } @@ -187,7 +191,10 @@ class HloTestBase : public ::testing::Test { // Convenience method to force the layout of the computation result in a // module. The result layout of 'module' is set to 'layout'. void ForceResultLayout(HloModule* module, const Layout& layout) { - module->mutable_entry_computation_layout() + module->mutable_host_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(layout); + module->mutable_device_entry_computation_layout() ->mutable_result_layout() ->ResetLayout(layout); } @@ -195,7 +202,10 @@ class HloTestBase : public ::testing::Test { // Convenience method to clear the layout of the computation result in // 'module'. void ForceClearResultLayout(HloModule* module) { - module->mutable_entry_computation_layout() + module->mutable_host_entry_computation_layout() + ->mutable_result_layout() + ->Clear(); + module->mutable_device_entry_computation_layout() ->mutable_result_layout() ->Clear(); } diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 81630df34c5..cde1dcd9cd1 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -15,818 +15,93 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include -#include -#include - -#include "tensorflow/compiler/xla/index_util.h" -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" namespace xla { -/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( - const Shape& expected, const Shape& actual) { - if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { - return ::testing::AssertionFailure() - << "tupleness-mismatch! 
want: " << ShapeUtil::HumanString(expected) - << " got: " << ShapeUtil::HumanString(actual); - } - if (ShapeUtil::IsTuple(expected)) { - if (ShapeUtil::TupleElementCount(expected) != - ShapeUtil::TupleElementCount(actual)) { - return ::testing::AssertionFailure() - << "want tuple element count: " - << ShapeUtil::TupleElementCount(expected) - << " got tuple element count: " - << ShapeUtil::TupleElementCount(actual); - } - for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - ::testing::AssertionResult result = - EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)) - << "mismatch in tuple index " << i; - if (!result) { - return result; - } - } - } else { - if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { - return ::testing::AssertionFailure() - << "want rank of: " << ShapeUtil::HumanString(expected) - << " got rank of: " << ShapeUtil::HumanString(actual); - } - if (expected.element_type() != actual.element_type()) { - return ::testing::AssertionFailure() - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - } - if (expected.dimensions_size() != actual.dimensions_size()) { - return ::testing::AssertionFailure() - << "want dimensions_size " << expected.dimensions_size() - << " got dimensions_size " << actual.dimensions_size(); - } - for (int i = 0; i < expected.dimensions_size(); ++i) { - if (expected.dimensions(i) != actual.dimensions(i)) { - return ::testing::AssertionFailure() - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); - } - } - } - return ::testing::AssertionSuccess(); -} - -/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, - const Shape& actual) { - ASSERT_TRUE(EqualShapes(expected, actual)); -} - -/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts( - const Shape& expected, const Shape& actual) { - ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); -} - namespace { -// Return a literal with all arrays of type FromNativeT converted to type -// ToNativeT in the given literal. -template -std::unique_ptr ConvertType(const Literal& literal) { - // First construct shape of the result. - Shape result_shape(literal.shape()); - ShapeUtil::ForEachMutableSubshape( - &result_shape, [](Shape* subshape, const ShapeIndex&) { - if (subshape->element_type() == - primitive_util::NativeToPrimitiveType()) { - subshape->set_element_type( - primitive_util::NativeToPrimitiveType()); - } - }); - auto result = MakeUnique(result_shape); - - // Then copy over the data from 'literal' converting FromNativeT values to - // ToNativeT values as necessary. - ShapeUtil::ForEachSubshape( - literal.shape(), - [&](const Shape& subshape, const ShapeIndex& shape_index) { - if (ShapeUtil::IsArray(subshape)) { - if (subshape.element_type() == - primitive_util::NativeToPrimitiveType()) { - auto src = literal.data(shape_index); - auto dest = result->data(shape_index); - for (int64 i = 0; i < src.size(); ++i) { - dest[i] = static_cast(src[i]); - } - } else { - TF_CHECK_OK(result->CopyFrom(literal, - /*dest_shape_index=*/shape_index, - /*src_shape_index=*/shape_index)); - } - } - }); - return result; +// Writes the given literal to a file in the test temporary directory. 
+void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { + auto get_hostname = [] { + char hostname[1024]; + gethostname(hostname, sizeof hostname); + hostname[sizeof hostname - 1] = 0; + return string(hostname); + }; + int64 now_usec = tensorflow::Env::Default()->NowMicros(); + string filename = tensorflow::io::JoinPath( + tensorflow::testing::TmpDir(), + tensorflow::strings::Printf("tempfile-%s-%llx-%s", get_hostname().c_str(), + now_usec, name.c_str())); + TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), filename, + literal.ToProto())); + LOG(ERROR) << "wrote to " << name << " file: " << filename; } -} // namespace - -/* static */ std::unique_ptr LiteralTestUtil::ConvertBF16ToF32( - const Literal& literal) { - return ConvertType(literal); +// Callback helper that dumps literals to temporary files in the event of a +// miscomparison. +void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, + const LiteralSlice& mismatches) { + LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " " + << literal_comparison::ToStringTruncated(expected); + LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " " + << literal_comparison::ToStringTruncated(actual); + LOG(INFO) << "Dumping literals to temp files..."; + WriteLiteralToTempFile(expected, "expected"); + WriteLiteralToTempFile(actual, "actual"); + WriteLiteralToTempFile(mismatches, "mismatches"); } -/* static */ std::unique_ptr LiteralTestUtil::ConvertF32ToBF16( - const Literal& literal) { - return ConvertType(literal); -} - -namespace { - -string Hostname() { - char hostname[1024]; - gethostname(hostname, sizeof hostname); - hostname[sizeof hostname - 1] = 0; - return string(hostname); -} - -// Helper function for comparing a floating point type, FloatT, bitwise equal -// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT -// -- on miscompare, a nice error message is given in the AssertionFailure. -template -::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { - auto ulhs = tensorflow::bit_cast(lhs); - auto urhs = tensorflow::bit_cast(rhs); - auto lhs_double = static_cast(lhs); - auto rhs_double = static_cast(rhs); - if (ulhs != urhs) { - return ::testing::AssertionFailure() << tensorflow::strings::Printf( - "floating values are not bitwise-equal; and equality testing " - "was requested: %s=%g=%a vs %s=%g=%a", - tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs)) - .c_str(), - lhs_double, lhs_double, - tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs)) - .c_str(), - rhs_double, rhs_double); - } - return ::testing::AssertionSuccess(); -} - -// Templated comparator that specializes for float equality comparison with the -// bitwise helper above (this is the un-specialized fallback, to just use the -// default gunit implementation). -template -::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) { - if (lhs == rhs) { +::testing::AssertionResult StatusToAssertion(const Status& s) { + if (s.ok()) { return ::testing::AssertionSuccess(); } - ::testing::Message msg; - msg << "Expected equality of these values:"; - msg << "\n " << lhs; - msg << "\n " << rhs; - - return ::testing::AssertionFailure() << msg; -} - -// Specializations for floating types that do bitwise comparisons when equality -// comparison is requested. 
-template <> -::testing::AssertionResult CompareEqual(bfloat16 lhs, bfloat16 rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(Eigen::half lhs, - Eigen::half rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(float lhs, float rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(double lhs, double rhs) { - return CompareFloatsBitwiseEqual(lhs, rhs); -} -template <> -::testing::AssertionResult CompareEqual(complex64 lhs, - complex64 rhs) { - auto res = CompareEqual(lhs.real(), rhs.real()); - if (!res) { - return res; - } - return CompareEqual(lhs.imag(), rhs.imag()); -} - -// A recursive function which iterates through every index of expected and -// actual literal and compares their values elementwise. Returns true if all -// elements are equal. -template -bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, - tensorflow::gtl::MutableArraySlice multi_index, - int64 dimension) { - if (dimension == expected.shape().dimensions_size()) { - NativeT expected_value = expected.Get(multi_index); - NativeT actual_value = actual.Get(multi_index); - ::testing::AssertionResult result = - CompareEqual(expected_value, actual_value); - return result; // Defines implicit coersion to bool. - } - - bool all_match = true; - for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { - multi_index[dimension] = i; - all_match = all_match && ExpectLiteralsEqual( - expected, actual, multi_index, dimension + 1); - } - return all_match; + return ::testing::AssertionFailure() << s.error_message(); } } // namespace -/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected, - const Literal& actual, - const string& message) { - EXPECT_TRUE(Equal(expected, actual)) - << "expected:\n" - << expected.ToString() << "\n\tvs actual:\n" - << actual.ToString() - << (message.empty() - ? 
"" - : tensorflow::strings::StrCat("\nmessage: ", message)); +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( + const Shape& expected, const Shape& actual) { + return StatusToAssertion(literal_comparison::EqualShapes(expected, actual)); } -/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected, - const Literal& actual) { - EXPECT_FALSE(Equal(expected, actual)); +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapesAndLayouts( + const Shape& expected, const Shape& actual) { + if (expected.ShortDebugString() != actual.ShortDebugString()) { + return ::testing::AssertionFailure() + << "want: " << expected.ShortDebugString() + << " got: " << actual.ShortDebugString(); + } + return ::testing::AssertionSuccess(); } /* static */ ::testing::AssertionResult LiteralTestUtil::Equal( - const Literal& expected, const Literal& actual) { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, expected.ToString()); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, actual.ToString()); - - AssertEqualShapes(expected.shape(), actual.shape()); - std::vector multi_index(expected.shape().dimensions_size(), 0); - bool match = false; - switch (expected.shape().element_type()) { - case PRED: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U8: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case S32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case S64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case U64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case BF16: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F16: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F32: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case F64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case C64: - match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); - break; - case TUPLE: { - bool tuple_match = true; - for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - SCOPED_TRACE(tensorflow::strings::StrCat( - "Tuple index ", i, " in ", - ShapeUtil::HumanString(expected.shape()))); - - // Create LiteralViews of the expected and actual elements. - auto result = Equal(LiteralView::Create(expected, {i}), - LiteralView::Create(actual, {i})); - tuple_match = tuple_match ? !!result : false; - } - match = tuple_match; - break; - } - default: - LOG(FATAL) - << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: " - << PrimitiveType_Name(expected.shape().element_type()); - } - ::testing::AssertionResult result = ::testing::AssertionSuccess(); - if (!match) { - result = ::testing::AssertionFailure() - << "expected: " << expected.ToString() - << "\nactual: " << actual.ToString(); - VLOG(1) << result.message(); - } - return result; + const LiteralSlice& expected, const LiteralSlice& actual) { + return StatusToAssertion(literal_comparison::Equal(expected, actual)); } -namespace { - -// Helper class for comparing floating-point literals within an error bound. -class NearComparator { - public: - explicit NearComparator(ErrorSpec error) : error_(error) {} - - // Compares the two literals elementwise. EXPECTs each pair of elements to be - // within the error bound. 
Emits useful log messages and dumps literals to - // temporary files on failure. Returns true if literals match. - bool ExpectNear(const Literal& expected, const Literal& actual) { - VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, TruncateHugeLiteral(expected)); - VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, TruncateHugeLiteral(actual)); - - // If the shapes mismatch, we simply fail the expectation instead of - // printing out data, as it's a type error rather than a value error. - ::testing::AssertionResult equal_shapes = - LiteralTestUtil::EqualShapes(expected.shape(), actual.shape()); - if (!equal_shapes) { - EXPECT_TRUE(equal_shapes); - return false; - } - - // Set up members used during the comparison. - num_miscompares_ = 0; - abs_diff_sum_ = 0.0; - abs_expected_sum_ = 0.0; - abs_diff_miscompare_sum_ = 0.0; - abs_expected_miscompare_sum_ = 0.0; - max_rel_err_ = 0.0; - max_abs_err_ = 0.0; - first_linear_index_ = -1; - last_linear_index_ = -1; - max_rel_linear_index_ = -1; - max_abs_linear_index_ = -1; - miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED)); - miscompares_.PopulateWithValue(false); - multi_index_.resize(expected.shape().dimensions_size(), 0); - - switch (expected.shape().element_type()) { - case BF16: - ExpectLiteralsNear(expected, actual, 0); - break; - case F16: - ExpectLiteralsNear(expected, actual, 0); - break; - case F32: - ExpectLiteralsNear(expected, actual, 0); - break; - case F64: - ExpectLiteralsNear(expected, actual, 0); - break; - case C64: - ExpectLiteralsNear(expected, actual, 0); - break; - default: - LOG(FATAL) << "Unsupported primitive type in near comparator: " - << PrimitiveType_Name(expected.shape().element_type()) - << ". Must be floating-point type."; - } - - if (num_miscompares_ > 0) { - if (!VLOG_IS_ON(1)) { - LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) - << " " << TruncateHugeLiteral(expected); - LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) - << " " << TruncateHugeLiteral(actual); - LOG(INFO) << "Dumping literals to temp files..."; - WriteLiteralToTempFile(expected, "expected"); - WriteLiteralToTempFile(actual, "actual"); - WriteLiteralToTempFile(miscompares_, "miscompares"); - } - EXPECT_TRUE(num_miscompares_ == 0) - << "\nmax relative mismatch at index " - << LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex( - actual.shape(), max_rel_linear_index_)) - << "\nmaximum relative error " << max_rel_err_ - << "\nmax absolute mismatch at index " - << LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex( - actual.shape(), max_abs_linear_index_)) - << "\nmaximum absolute error " << max_abs_err_ - << "\nfirst mismatch at index " - << LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex( - actual.shape(), first_linear_index_)) - << "\nlast mismatch at index " - << LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex( - actual.shape(), last_linear_index_)) - << "\ntotal absolute error " << abs_diff_sum_ - << "\ntotal absolute error of miscompares " - << abs_diff_miscompare_sum_ << "\ntotal relative error " - << (abs_diff_sum_ / abs_expected_sum_) - << "\ntotal relative error of miscompares " - << (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_) - << "\nfailure count " << num_miscompares_; - } - return num_miscompares_ == 0; - } - - private: - template - bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { - if (relaxed_nans) { - return 
!std::isnan(expected) && std::isnan(actual); - } else { - return std::isnan(expected) != std::isnan(actual); - } - } - - template - void ExpectNear(NativeT expected, NativeT actual, - const ::testing::Message& message) { - EXPECT_NEAR(expected, actual, error_.abs) - << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" - << message; - } - - // EXPECTs that the two given scalar values are within the error bound. Keeps - // track of how many mismatches have occurred to keep the size of the output - // manageable. - template - bool ExpectValuesNear(NativeT expected, NativeT actual) { - if (expected == actual) { - return true; - } - - const float abs_diff = std::abs(actual - expected); - const float rel_err = abs_diff / std::abs(expected); - const bool nan_mismatch = - NanMismatch(expected, actual, error_.relaxed_nans); - const bool mismatch = - (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel)); - return !mismatch; - } - - // Assumes that expected vs actual fail ExpectValuesNear. - template - void UpdateAndLogMiscompares(const NativeT expected, const NativeT actual, - const Shape& shape, const int64 linear_index) { - const float abs_diff = std::abs(actual - expected); - const float rel_err = abs_diff / std::abs(expected); - abs_diff_sum_ += abs_diff; - abs_expected_sum_ += std::abs(expected); - if (rel_err > max_rel_err_ || std::isnan(rel_err)) { - max_rel_err_ = rel_err; - max_rel_linear_index_ = linear_index; - } - if (abs_diff > max_abs_err_ || std::isnan(abs_diff)) { - max_abs_err_ = abs_diff; - max_abs_linear_index_ = linear_index; - } - if (VLOG_IS_ON(10)) { - VLOG(10) << tensorflow::strings::Printf( - "index %s abs_diff %f rel_err %f", - LiteralTestUtil::MultiIndexAsString( - IndexUtil::LinearIndexToMultidimensionalIndex(shape, - linear_index)) - .c_str(), - abs_diff, rel_err); - } - abs_diff_miscompare_sum_ += abs_diff; - abs_expected_miscompare_sum_ += std::abs(expected); - const int64 kMaxFailures = 2; - if (num_miscompares_ < kMaxFailures) { - const auto multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index); - ::testing::Message msg; - msg << "mismatch at index " - << LiteralTestUtil::MultiIndexAsString(multi_index) << " abs diff " - << abs_diff << " rel err " << rel_err << " failure #" - << num_miscompares_; - ExpectNear(expected, actual, msg); - } else if (num_miscompares_ == kMaxFailures) { - LOG(ERROR) << "reached max 'loud' failure count; silently proceeding..."; - } - if (num_miscompares_ == 0) { - first_linear_index_ = linear_index; - } - num_miscompares_++; - last_linear_index_ = linear_index; - miscompares_.data()[linear_index] = true; - } - - // Recursive function which compares the two given literals elementwise. - template - void ExpectLiteralsNear(const Literal& expected, const Literal& actual, - int64 dimension) { - // Fast path optimization for the case were layouts match. 
- if (LayoutUtil::Equal(actual.shape().layout(), expected.shape().layout())) { - tensorflow::gtl::ArraySlice expected_data = - expected.data(); - tensorflow::gtl::ArraySlice actual_data = - actual.data(); - const int64 len = expected_data.size(); - for (int64 i = 0; i < len; ++i) { - const bool near = ExpectValuesNear(expected_data[i], actual_data[i]); - if (!near) { - UpdateAndLogMiscompares(expected_data[i], actual_data[i], - actual.shape(), i); - } - } - return; - } - - if (dimension == expected.shape().dimensions_size()) { - bool near = ExpectValuesNear(expected.Get(multi_index_), - actual.Get(multi_index_)); - if (!near) { - UpdateAndLogMiscompares( - expected.Get(multi_index_), - actual.Get(multi_index_), actual.shape(), - IndexUtil::MultidimensionalIndexToLinearIndex(actual.shape(), - multi_index_)); - } - } else { - for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { - multi_index_[dimension] = i; - ExpectLiteralsNear(expected, actual, dimension + 1); - } - } - } - - // Writes the given literal to a file in the test temporary directory. - void WriteLiteralToTempFile(const Literal& literal, const string& name) { - int64 now_usec = tensorflow::Env::Default()->NowMicros(); - string filename = tensorflow::io::JoinPath( - tensorflow::testing::TmpDir(), - tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(), - now_usec, name.c_str())); - TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(), - filename, literal.ToProto())); - LOG(ERROR) << "wrote to " << name << " file: " << filename; - } - - // Gets the total element count. For tuples, this is not the count of tuple - // elements, but the sum of elements of each tuple element. - int64 RecursiveElementCount(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { - const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); - int64 total = 0; - for (int64 i = 0; i < tuple_elements; ++i) { - total += - RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); - } - return total; - } else { - return ShapeUtil::ElementsIn(shape); - } - } - - // Calling ToString on a literal with over 100 million elements takes around - // 3 minutes. The utility of printing a literal with >1000 elements is - // questionable, especially when writing the Literal proto to disk is orders - // of magnitude faster. - string TruncateHugeLiteral(const Literal& literal) { - return RecursiveElementCount(literal.shape()) < 1000 - ? literal.ToString() - : "[TRUNCATED, Literal with more than 1000 values]"; - } - - ErrorSpec error_; - - // Number of element miscomparisons encountered so far. - int64 num_miscompares_; - - // A Literal containing which elements did not match in the expected and - // actual literals. miscompares_ contains PREDs and is of the same sizes as - // the comparison literals. - Literal miscompares_; - - // A multidimensional index used when performing the recursive comparison. - std::vector multi_index_; - - // Aggregated Statistics on input. 
- double abs_diff_sum_; - double abs_expected_sum_; - double abs_diff_miscompare_sum_; - double abs_expected_miscompare_sum_; - float max_rel_err_; - float max_abs_err_; - int64 first_linear_index_; - int64 last_linear_index_; - int64 max_rel_linear_index_; - int64 max_abs_linear_index_; -}; - -template <> -bool NearComparator::NanMismatch(complex64 expected, - complex64 actual, - bool relaxed_nans) { - return NanMismatch(expected.real(), actual.real(), relaxed_nans) || - NanMismatch(expected.imag(), actual.imag(), relaxed_nans); -} - -template <> -void NearComparator::ExpectNear(complex64 expected, complex64 actual, - const ::testing::Message& message) { - EXPECT_NEAR(expected.real(), actual.real(), error_.abs) - << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" - << message; - EXPECT_NEAR(expected.imag(), actual.imag(), error_.abs) - << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" - << message; -} - -template <> -bool NearComparator::ExpectValuesNear(bfloat16 expected, - bfloat16 actual) { - return ExpectValuesNear(static_cast(expected), - static_cast(actual)); -} - -template <> -bool NearComparator::ExpectValuesNear(half expected, half actual) { - return ExpectValuesNear(static_cast(std::move(expected)), - static_cast(std::move(actual))); -} - -template <> -void NearComparator::UpdateAndLogMiscompares( - const bfloat16 expected, const bfloat16 actual, const Shape& shape, - const int64 linear_index) { - UpdateAndLogMiscompares(static_cast(expected), - static_cast(actual), shape, linear_index); -} - -template <> -void NearComparator::UpdateAndLogMiscompares(half expected, half actual, - const Shape& shape, - const int64 linear_index) { - UpdateAndLogMiscompares(static_cast(std::move(expected)), - static_cast(std::move(actual)), shape, - linear_index); -} - -} // namespace - /* static */ ::testing::AssertionResult LiteralTestUtil::Near( - const Literal& expected, const Literal& actual, const ErrorSpec& error) { - ::testing::AssertionResult err = - EqualShapes(expected.shape(), actual.shape()); - if (!err) { - return err; - } - - if (ShapeUtil::IsTuple(expected.shape())) { - for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { - SCOPED_TRACE(tensorflow::strings::StrCat( - "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); - const auto expected_element = LiteralView::Create(expected, {i}); - const auto actual_element = LiteralView::Create(actual, {i}); - - ::testing::AssertionResult res = - Near(expected_element, actual_element, error); - if (err && !res) { - err = res; - } - } - return err; - } - - if (ShapeUtil::ElementIsFloating(expected.shape()) || - ShapeUtil::ElementIsComplex(expected.shape())) { - NearComparator comparator(error); - return comparator.ExpectNear(expected, actual) - ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure() << "values were not near"; - } - - return Equal(expected, actual); + const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error_spec, bool detailed_message) { + return StatusToAssertion(literal_comparison::Near( + expected, actual, error_spec, detailed_message, &OnMiscompare)); } -/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, - const Literal& actual, - const ErrorSpec& error, - const string& message) { - EXPECT_TRUE(Near(expected, actual, error)) - << (message.empty() - ? 
"" - : tensorflow::strings::StrCat("\nmessage: ", message)); -} - -/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( - const Literal& expected, const Literal& actual, +/* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( + const LiteralSlice& expected, const LiteralSlice& actual, const tensorflow::gtl::optional& error) { if (error.has_value()) { VLOG(1) << "Expects near"; - return Near(expected, actual, *error); + return StatusToAssertion(literal_comparison::Near( + expected, actual, *error, /*detailed_message=*/false, &OnMiscompare)); } VLOG(1) << "Expects equal"; - return Equal(expected, actual); -} - -/*static*/ void LiteralTestUtil::ExpectNearOrEqual( - const Literal& expected, const Literal& actual, - const tensorflow::gtl::optional& error) { - EXPECT_TRUE(NearOrEqual(expected, actual, error)); -} - -/* static */ string LiteralTestUtil::MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index) { - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(multi_index, ","), "}"); -} - -/* static */ std::unique_ptr LiteralTestUtil::Reshape( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, const Literal& literal) { - int64 new_num_elements = 1; - for (int64 i = 0; i < new_dimensions.size(); ++i) { - new_num_elements *= new_dimensions[i]; - } - CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); - CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - - auto new_literal = MakeUnique( - ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); - - // Create a new shape with the given minor-to-major layout. This shape is used - // solely for converting linear address to multi-dimensional addresses when - // writing elements to the new literal. - Shape shape_with_layout = new_literal->shape(); - *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); - - // Copy data into new literal, element-by-element. 
- for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { - std::vector from_multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); - std::vector to_multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); - switch (literal.shape().element_type()) { - case PRED: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U8: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case S32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case U64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case S64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case F32: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case F64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - case C64: - new_literal->Set(to_multi_index, - literal.Get(from_multi_index)); - break; - default: - LOG(FATAL) << "Unhandled primitive element type: " - << PrimitiveType_Name(literal.shape().element_type()); - } - } - - return new_literal; + return StatusToAssertion(literal_comparison::Equal(expected, actual)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 7b757a4bd7e..d1b8a6cf0b2 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/error_spec.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -38,279 +39,190 @@ limitations under the License. namespace xla { -// Structure describing permissible absolute and relative error bounds. -struct ErrorSpec { - explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) - : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} - - float abs; // Absolute error bound. - float rel; // Relative error bound. - - // If relaxed_nans is true then any result is valid if we are expecting NaNs. - // In effect, this allows the tested operation to produce incorrect results - // for inputs outside its mathematical domain. - bool relaxed_nans; -}; - // Utility class for making expectations/assertions related to XLA literals. class LiteralTestUtil { public: // Asserts that the given shapes have the same rank, dimension sizes, and // primitive types. - static ::testing::AssertionResult EqualShapes(const Shape& expected, - const Shape& actual); - static void AssertEqualShapes(const Shape& expected, const Shape& actual); + static ::testing::AssertionResult EqualShapes( + const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; // Asserts that the provided shapes are equal as defined in AssertEqualShapes // and that they have the same layout. 
- static void AssertEqualShapesAndLayouts(const Shape& expected, - const Shape& actual); + static ::testing::AssertionResult EqualShapesAndLayouts( + const Shape& expected, const Shape& actual) TF_MUST_USE_RESULT; - // If the given literal's data type is bfloat16, converts it to a float - // literal; otherwise, returns a copy of it. If the literal is a tuple, - // recursively converts its elements. - static std::unique_ptr ConvertBF16ToF32(const Literal& bf16_literal); - - // If the given literal's data type is float, converts it to a bfloat16 - // literal; otherwise, returns a copy of it. If the literal is a tuple, - // recursively converts its elements. - static std::unique_ptr ConvertF32ToBF16(const Literal& f32_literal); - - // Asserts that the expected and actual literals are (bitwise) equal for all - // elements in the literal. Also, asserts that the rank, dimensions sizes, and - // primitive type are equal. - static ::testing::AssertionResult Equal( - const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; - - // Expects that expected and actual are Equal. - static void ExpectEqual(const Literal& expected, const Literal& actual, - const string& message = ""); - - // Expects that expected and actual are Not Equal. - static void ExpectNotEqual(const Literal& expected, const Literal& actual); + static ::testing::AssertionResult Equal(const LiteralSlice& expected, + const LiteralSlice& actual) + TF_MUST_USE_RESULT; // Asserts the given literal are (bitwise) equal to given expected values. template - static void ExpectR0Equal(NativeT expected, const Literal& actual); + static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual); + template static void ExpectR1Equal(tensorflow::gtl::ArraySlice expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR2Equal( std::initializer_list> expected, - const Literal& actual); + const LiteralSlice& actual); + template static void ExpectR3Equal( std::initializer_list< std::initializer_list>> expected, - const Literal& actual); + const LiteralSlice& actual); // Asserts the given literal are (bitwise) equal to given array. template static void ExpectR2EqualArray2D(const Array2D& expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR3EqualArray3D(const Array3D& expected, - const Literal& actual); + const LiteralSlice& actual); template static void ExpectR4EqualArray4D(const Array4D& expected, - const Literal& actual); + const LiteralSlice& actual); - // Asserts that the expected and actual literals are within the given error - // bound for all elements. Also, asserts that the rank, dimensions sizes, and - // bounds are equivalent. + // Decorates literal_comparison::Near() with an AssertionResult return type. // - // Tuples are matched recursively. When comparing tensors of - // non-floating-point type, checks for exact equality, ignoring the ErroSpec. - // - // If the shape of the literals is neither a complex/floating-point tensor nor - // a tuple which contains a complex/floating-point tensor, Near() is - // equivalent to Equal(). We don't raise an error in this case, because we - // want to allow callers to call Near() even if they have no preconceptions - // about the shapes being compared. + // See comment on literal_comparison::Near(). static ::testing::AssertionResult Near( - const Literal& expected, const Literal& actual, - const ErrorSpec& error) TF_MUST_USE_RESULT; - - // Expects expected and actual to be Near with the given error. 
- static void ExpectNear(const Literal& expected, const Literal& actual, - const ErrorSpec& error, const string& message = ""); + const LiteralSlice& expected, const LiteralSlice& actual, + const ErrorSpec& error_spec, + bool detailed_message = false) TF_MUST_USE_RESULT; // Asserts the given literal are within the given error bound of the given // expected values. Only supported for floating point values. template - static void ExpectR0Near(NativeT expected, const Literal& actual, + static void ExpectR0Near(NativeT expected, const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR1Near(tensorflow::gtl::ArraySlice expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR3Near( std::initializer_list< std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR4Near( std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error); + const LiteralSlice& actual, const ErrorSpec& error); // Asserts the given literal are within the given error bound to the given // array. Only supported for floating point values. template static void ExpectR2NearArray2D(const Array2D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR3NearArray3D(const Array3D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); + template static void ExpectR4NearArray4D(const Array4D& expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error); // If the error spec is given, returns whether the expected and the actual are // within the error bound; otherwise, returns whether they are equal. Tuples // will be compared recursively. static ::testing::AssertionResult NearOrEqual( - const Literal& expected, const Literal& actual, + const LiteralSlice& expected, const LiteralSlice& actual, const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; - // If the error spec is given, expects the expected and the actual to be near; - // otherwise, expects them to be equal. Tuples will be compared recursively. - static void ExpectNearOrEqual( - const Literal& expected, const Literal& actual, - const tensorflow::gtl::optional& error); - - // Returns a multi-dimensional index as a string. For example: '{7, 8}' will - // be returned for a 2-dimensional index with dimension 0 index equal to 7, - // dimension 1 equal to 8. - static string MultiIndexAsString( - tensorflow::gtl::ArraySlice multi_index); - - // Creates a literal with a new shape with the given new dimensions using the - // data in the given input literal. For reshaping purposes the (flat) data - // buffer of the input literal is assumed to have the given minor_to_major - // layout order. - static std::unique_ptr Reshape( - tensorflow::gtl::ArraySlice new_dimensions, - tensorflow::gtl::ArraySlice minor_to_major, - const Literal& literal); - - // Creates a literal with the supplied shape, and uses the provided value - // generator to populate the literal's values. - // Returns the new literal object, or an error Status if failed. 
- template < - PrimitiveType type, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, - const std::function)>& generator); - - // Creates a literal with the supplied shape, and initializes the literal - // values using a normal distribution with given mean and stddev standard - // deviation, and using the engine as entropy generator. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, typename E, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, E* engine, T mean, T stddev); - - // Creates a literal with the supplied shape, and initializes the literal - // values using a normal distribution with given mean and stddev standard - // deviation. - // Returns the new literal object, or an error Status if failed. - template < - PrimitiveType type, - typename T = typename primitive_util::PrimitiveTypeToNative::type> - static StatusOr> CreateRandomLiteral( - const Shape& shape, T mean, T stddev); - private: TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil); }; template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR0(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR0(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR1Equal( - tensorflow::gtl::ArraySlice expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR1(expected), actual); + tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR1(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR2(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR2(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3Equal( std::initializer_list>> expected, - const Literal& actual) { - ExpectEqual(*Literal::CreateR3(expected), actual); + const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR3(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( - const Array2D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual); + const Array2D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( - const Array3D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual); + const Array3D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( - const Array4D& expected, const Literal& actual) { - ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual); + const Array4D& expected, const LiteralSlice& actual) { + EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual)); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, - const Literal& actual, + const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR0(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR0(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR1Near( - 
tensorflow::gtl::ArraySlice expected, const Literal& actual, + tensorflow::gtl::ArraySlice expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR1(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR1(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR2(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR2(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3Near( std::initializer_list>> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR3(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR3(expected), actual, error)); } template @@ -318,63 +230,29 @@ template std::initializer_list>>> expected, - const Literal& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR4(expected), actual, error); + const LiteralSlice& actual, const ErrorSpec& error) { + EXPECT_TRUE(Near(*Literal::CreateR4(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( - const Array2D& expected, const Literal& actual, + const Array2D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( - const Array3D& expected, const Literal& actual, + const Array3D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error); + EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error)); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( - const Array4D& expected, const Literal& actual, + const Array4D& expected, const LiteralSlice& actual, const ErrorSpec& error) { - ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral( - const Shape& shape, - const std::function)>& generator) { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - TF_RET_CHECK(shape.element_type() == type); - std::unique_ptr literal = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(literal.get()->Populate( - [&](tensorflow::gtl::ArraySlice indexes) { - return generator(indexes); - })); - return std::move(literal); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean, - T stddev) { - using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - std::normal_distribution generator(mean, stddev); - return CreateRandomLiteral( - shape, [&](tensorflow::gtl::ArraySlice /*indexes*/) { - return generator(*engine); - }); -} - -template -/* static */ StatusOr> -LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) { - std::minstd_rand0 engine; - return CreateRandomLiteral(shape, &engine, mean, stddev); + EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error)); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 3a421f84582..bbac7285aef 100644 --- 
a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -34,7 +34,7 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { std::unique_ptr literal = Literal::MakeTuple({ Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), }); - LiteralTestUtil::ExpectEqual(*literal, *literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal)); } TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { @@ -89,7 +89,7 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { EXPECT_EQ("2", literal->ToString()); } else if (result.find("actual") != string::npos) { EXPECT_EQ("4", literal->ToString()); - } else if (result.find("miscompares") != string::npos) { + } else if (result.find("mismatches") != string::npos) { EXPECT_EQ("true", literal->ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; @@ -97,6 +97,15 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { } } +TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) { + auto expected = Literal::CreateR1({1, 2, 3}); + auto actual = Literal::CreateR1({4, 5, 6}); + ::testing::AssertionResult result = + LiteralTestUtil::Equal(*expected, *actual); + EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}")); + EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}")); +} + TEST(LiteralTestUtilTest, NearComparatorR1) { auto a = Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 7e92439c494..2f46ee0be21 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -43,7 +43,7 @@ class LLVMCompilerTest : public ::testing::Test { ~LLVMCompilerTest() override {} protected: - using Platform = ::perftools::gputools::Platform; + using Platform = se::Platform; explicit LLVMCompilerTest(string platform_name) : platform_name_(std::move(platform_name)) {} @@ -95,7 +95,7 @@ class LLVMCompilerTest : public ::testing::Test { modules.push_back(hlo_module->Clone()); modules.push_back(std::move(hlo_module)); - std::vector> executors; + std::vector> executors; executors.push_back({backend_->default_stream_executor()}); executors.push_back({backend_->default_stream_executor()}); diff --git a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc index 3023df47cda..2c45f19c090 100644 --- a/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc +++ b/tensorflow/compiler/xla/tests/llvm_irgen_test_base.cc @@ -62,8 +62,8 @@ void LLVMIRGenTestBase::CompileAheadOfTimeAndVerifyIr( std::unique_ptr hlo_module, const AotCompilationOptions& options, const string& pattern, bool match_optimized_ir) { SetIrHook(match_optimized_ir); - ASSERT_TRUE( - CompileToAotCompilationResult(std::move(hlo_module), options).ok()); + TF_ASSERT_OK( + CompileToAotCompilationResult(std::move(hlo_module), options).status()); ResetIrHook(); StatusOr filecheck_result = RunFileCheck(ir_, pattern); diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index 3d30ceeaf1b..f21f83992ff 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -15,9 +15,8 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" @@ -25,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -37,7 +37,7 @@ class LocalClientAllocationTest : public LocalClientTestBase { }; XLA_TEST_F(LocalClientAllocationTest, AddVectors) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0.0f, 1.0f, 2.0f}); auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); @@ -53,7 +53,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { // deallocation happen on the right allocator. ExecutableRunOptions options; options.set_allocator(allocator); - std::unique_ptr result = + tensorflow::gtl::optional result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), options); @@ -66,14 +66,14 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { // Deallocate result and verify that deallocate was called once. int64 deallocation_count_before = allocator_->deallocation_count(); - result = nullptr; + result.reset(); EXPECT_EQ(deallocation_count_before + 1, allocator_->deallocation_count()); } XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { // Run a computation on every device on the system. Verify that allocation // occurs on the proper device. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0.0f, 1.0f, 2.0f}); auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); @@ -92,7 +92,7 @@ XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) { computation, {}, ExecutableBuildOptions().set_device_ordinal(d), ExecutableRunOptions().set_device_ordinal(d).set_allocator(allocator)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); // At least one allocation should have been performed when executing the // computation. diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 3704ddd8010..a366afe8262 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -21,7 +21,8 @@ limitations under the License. #include "llvm/ADT/Triple.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/types.h" @@ -29,27 +30,31 @@ limitations under the License. 
#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" +namespace { + using xla::string; -xla::Computation Doubler(xla::Client* client) { - xla::ComputationBuilder builder(client, "doubler"); +xla::XlaComputation Doubler() { + xla::XlaBuilder builder("doubler"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); auto x = builder.Parameter(0, r0f32, "x"); builder.Mul(x, builder.ConstantR0(2.0)); return std::move(builder.Build().ValueOrDie()); } +} // namespace + int main(int argc, char** argv) { tensorflow::port::InitMain(argv[0], &argc, &argv); auto client = xla::ClientLibrary::GetOrCreateCompileOnlyClient().ValueOrDie(); - xla::ComputationBuilder builder(client, "aot_test_helper"); + xla::XlaBuilder builder("aot_test_helper"); auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); auto opaque_param = builder.Parameter(0, opaque_shape, "x"); auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {}); auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32); - builder.Call(Doubler(client), {sum}); + builder.Call(Doubler(), {sum}); if (argc != 2) { LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU"; @@ -71,8 +76,8 @@ int main(int argc, char** argv) { llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string)); - xla::Computation computation = builder.Build().ConsumeValueOrDie(); - xla::CompileOnlyClient::AotComputationInstance instance{ + xla::XlaComputation computation = builder.Build().ConsumeValueOrDie(); + xla::CompileOnlyClient::AotXlaComputationInstance instance{ &computation, /*argument_layouts=*/{&opaque_shape}, &r0f32}; xla::cpu::CpuAotCompilationOptions options( diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 2462ea39f91..96858c00d6b 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -18,9 +18,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -43,8 +42,6 @@ limitations under the License. 
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -56,61 +53,57 @@ class LocalClientExecuteTest : public LocalClientTestBase { }; XLA_TEST_F(LocalClientExecuteTest, Constant) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto y = builder.ConstantR0(123.0f); - std::unique_ptr result = + ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - - LiteralTestUtil::ExpectR0Near(123.f, *ShapedBufferToLiteral(*result), + LiteralTestUtil::ExpectR0Near(123.f, *ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddScalars) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto y = builder.ConstantR0(123.0f); builder.Add(x, y); auto x_value = LiteralToShapedBuffer(*Literal::CreateR0(42.0f)); - std::unique_ptr result = - ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_value.get()}); - - LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(*result), + ScopedShapedBuffer result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_value}); + LiteralTestUtil::ExpectR0Near(165.f, *ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "x"); auto y = builder.ConstantR1({}); builder.Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({})); - std::unique_ptr result = - ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_array.get()}); - - LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(*result), + ScopedShapedBuffer result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); + LiteralTestUtil::ExpectR1Near({}, *ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddVectors) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); - std::unique_ptr result = - ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_array.get()}); - + ScopedShapedBuffer result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array}); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); @@ -118,18 +111,17 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); ExecutionProfile profile; - std::unique_ptr result = ExecuteLocallyOrDie( - builder.Build().ValueOrDie(), {x_array.get()}, - DefaultExecutableBuildOptions(), + ScopedShapedBuffer result = ExecuteLocallyOrDie( + builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions(), 
DefaultExecutableRunOptions().set_execution_profile(&profile)); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); EXPECT_GT(profile.compute_and_transfer_time_ns(), 0); } XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); builder.Add(x, y); @@ -138,31 +130,31 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { // Create x as a col-major array. auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); - EXPECT_TRUE(LayoutUtil::Equal(x_array->on_device_shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(x_array.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); - EXPECT_TRUE(LayoutUtil::Equal(y_array->on_device_shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(y_array.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); - std::unique_ptr result_colmaj = - ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); + ScopedShapedBuffer result_colmaj = + ExecuteLocallyOrDie(computation, {&x_array, &y_array}); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(*result_colmaj), + *ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with the parameter values in a different order. - std::unique_ptr result_param_swap = - ExecuteLocallyOrDie(computation, {y_array.get(), x_array.get()}); + ScopedShapedBuffer result_param_swap = + ExecuteLocallyOrDie(computation, {&y_array, &x_array}); LiteralTestUtil::ExpectR2Near( {{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(*result_param_swap), error_spec_); + *ShapedBufferToLiteral(result_param_swap), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); builder.Add(x, y); @@ -174,32 +166,32 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); // Run with col-major result layout. - std::unique_ptr result_colmaj = ExecuteLocallyOrDie( - computation, {x_array.get(), y_array.get()}, + ScopedShapedBuffer result_colmaj = ExecuteLocallyOrDie( + computation, {&x_array, &y_array}, DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->on_device_shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_colmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(*result_colmaj), + *ShapedBufferToLiteral(result_colmaj), error_spec_); // Run with row-major result layout. 
- std::unique_ptr result_rowmaj = ExecuteLocallyOrDie( - computation, {x_array.get(), y_array.get()}, + ScopedShapedBuffer result_rowmaj = ExecuteLocallyOrDie( + computation, {&x_array, &y_array}, DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->on_device_shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj.on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, - *ShapedBufferToLiteral(*result_rowmaj), + *ShapedBufferToLiteral(result_rowmaj), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, TupleResult) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); builder.Tuple({x, y, x}); @@ -210,24 +202,24 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { auto y_array = LiteralToShapedBuffer( *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); + ScopedShapedBuffer result = + ExecuteLocallyOrDie(computation, {&x_array, &y_array}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); - EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->on_host_shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_EQ(3, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {1})); + LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); auto inner_tuple = builder.Tuple({x, y, x}); @@ -239,29 +231,29 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { auto y_array = LiteralToShapedBuffer( *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); + ScopedShapedBuffer result = + ExecuteLocallyOrDie(computation, {&x_array, &y_array}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 
2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 0})); + LiteralSlice(*result_literal, {0, 0})); LiteralTestUtil::ExpectR2Equal( {{10.0f, 20.0f}, {30.0f, 40.0f}}, - LiteralView::Create(*result_literal, {0, 1})); + LiteralSlice(*result_literal, {0, 1})); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f}, {3.0f, 4.0f}}, - LiteralView::Create(*result_literal, {0, 2})); + LiteralSlice(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { // Verify setting the result layout of a computation with a tuple output. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); builder.Tuple({x, y}); @@ -276,15 +268,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{1, 0})}); options.set_result_layout(shape_with_layout); - std::unique_ptr result = ExecuteLocallyOrDie( - builder.Build().ValueOrDie(), {array.get(), array.get()}, options, - DefaultExecutableRunOptions()); + ScopedShapedBuffer result = + ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&array, &array}, + options, DefaultExecutableRunOptions()); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -298,7 +290,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { // Computation adds the respective array and vector elements from each tuple // argument and returns the results as a tuple. 
- ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, tuple_shape0, "x"); auto y = builder.Parameter(1, tuple_shape1, "y"); auto x_0 = builder.GetTupleElement(x, 0); @@ -320,18 +312,18 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { auto x_buffer = LiteralToShapedBuffer(*x_literal); auto y_buffer = LiteralToShapedBuffer(*y_literal); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {x_buffer.get(), y_buffer.get()}); + ScopedShapedBuffer result = + ExecuteLocallyOrDie(computation, {&x_buffer, &y_buffer}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result.on_host_shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result.on_host_shape())); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( {{56.0f, 46.0f}, {36.0f, 26.0f}}, - LiteralView::Create(*result_literal, {0})); + LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1})); + {40.0f, 71.0f, 117.0f}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -345,7 +337,7 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { // Computation negates the array element and sums the two vector elements in // the nested tuple. The resulting array and vector are returned as a tuple. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto param = builder.Parameter(0, nested_tuple_shape, "param"); auto inner_tuple = builder.GetTupleElement(param, 0); auto inner_array = builder.GetTupleElement(inner_tuple, 0); @@ -365,14 +357,13 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { Literal::CreateR1({222.0, -2.0, 10.0}).get()}); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {arg_buffer.get()}); + ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR2Equal( - {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0})); + {{-1.0, -2.0}, {-3.0, -4}}, LiteralSlice(*result_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1})); + {264.0, 73.0, 133.0}, LiteralSlice(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -384,7 +375,7 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { const Shape tuple_shape = ShapeUtil::MakeTupleShape({array_shape, array_shape}); - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto param = builder.Parameter(0, tuple_shape, "param"); auto element_0 = builder.GetTupleElement(param, 0); auto element_1 = builder.GetTupleElement(param, 1); @@ -396,22 +387,20 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { Literal::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); - std::unique_ptr result_0 = - ExecuteLocallyOrDie(computation, {arg_buffer.get()}); - std::unique_ptr result_0_literal = ShapedBufferToLiteral(*result_0); + 
ScopedShapedBuffer result_0 = ExecuteLocallyOrDie(computation, {&arg_buffer}); + std::unique_ptr result_0_literal = ShapedBufferToLiteral(result_0); LiteralTestUtil::ExpectR2Equal( {{-1.0, -2.0}, {-3.0, -4.0}}, - LiteralView::Create(*result_0_literal, {0})); + LiteralSlice(*result_0_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1})); + {{22.0, 6.0}, {8.0, 10}}, LiteralSlice(*result_0_literal, {1})); - std::unique_ptr result_1 = - ExecuteLocallyOrDie(computation, {result_0.get()}); - std::unique_ptr result_1_literal = ShapedBufferToLiteral(*result_1); + ScopedShapedBuffer result_1 = ExecuteLocallyOrDie(computation, {&result_0}); + std::unique_ptr result_1_literal = ShapedBufferToLiteral(result_1); LiteralTestUtil::ExpectR2Equal( - {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0})); + {{1.0, 2.0}, {3.0, 4.0}}, LiteralSlice(*result_1_literal, {0})); LiteralTestUtil::ExpectR2Equal( - {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1})); + {{44.0, 12.0}, {16.0, 20}}, LiteralSlice(*result_1_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -430,11 +419,11 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { std::vector element_shapes(kElementCount, element_shape); const Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto param = builder.Parameter(0, tuple_shape, "param"); // Add each element's tuple index value to every element. - std::vector result_elements; + std::vector result_elements; for (int i = 0; i < kElementCount; ++i) { auto element = builder.GetTupleElement(param, i); result_elements.push_back( @@ -453,21 +442,17 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { Literal::MakeTupleOwned(std::move(arg_elements)); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {arg_buffer.get()}); - - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}), + {2.0f * i, 0.0f}, LiteralSlice(*result_literal, {i}), error_spec_); } } -// TODO(b/66968986): Test times out on CPU parallel backend. Disabled -// 2017-09-26. -XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) { +XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) { // Construct and run a computation which takes a two-level nested tuple // parameter with a large fanout. const int kFanout = 40; @@ -479,15 +464,15 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) { std::vector inner_tuple_shapes(kFanout, inner_tuple_shape); const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes); - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto param = builder.Parameter(0, tuple_shape, "param"); // The computation increments each leaf value by an amount equal to the leaf's // ordinal position in a traversal of the tuple. 
- std::vector result_elements; + std::vector result_elements; for (int i = 0; i < kFanout; ++i) { auto outer_element = builder.GetTupleElement(param, i); - std::vector inner_result_elements; + std::vector inner_result_elements; for (int j = 0; j < kFanout; ++j) { auto inner_element = builder.GetTupleElement(outer_element, j); inner_result_elements.push_back(builder.Add( @@ -511,14 +496,13 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) { auto arg_literal = Literal::MakeTupleOwned(std::move(outer_tuple_elements)); auto arg_buffer = LiteralToShapedBuffer(*arg_literal); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {arg_buffer.get()}); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}), + i + j + i * kFanout + j, LiteralSlice(*result_literal, {i, j}), error_spec_); } } @@ -535,7 +519,7 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { shape = ShapeUtil::MakeTupleShape({shape}); } - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto element = builder.Parameter(0, shape, "param"); for (int i = 0; i < kTupleDepth; ++i) { element = builder.GetTupleElement(element, 0); @@ -556,21 +540,20 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { } auto arg_buffer = LiteralToShapedBuffer(*arg_literal); - std::unique_ptr result = - ExecuteLocallyOrDie(computation, {arg_buffer.get()}); - std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + ScopedShapedBuffer result = ExecuteLocallyOrDie(computation, {&arg_buffer}); + std::unique_ptr result_literal = ShapedBufferToLiteral(result); ShapeIndex index; for (int i = 0; i < kTupleDepth; ++i) { index.push_back(0); } LiteralTestUtil::ExpectR0Equal( - 165.0, LiteralView::Create(*result_literal, index)); + 165.0, LiteralSlice(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { // Test passing in an invalid number of arguments. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {3}), "y"); builder.Add(x, y); @@ -578,7 +561,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({1.0f, 2.0f, 3.0f})); auto execute_status = - ExecuteLocally(builder.Build().ValueOrDie(), {x_array.get()}); + ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); EXPECT_FALSE(execute_status.ok()); EXPECT_THAT(execute_status.status().error_message(), @@ -587,14 +570,14 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { // Test passing in an argument with the wrong shape. 
- ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); builder.Neg(x); auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = - ExecuteLocally(builder.Build().ValueOrDie(), {x_array.get()}); + ExecuteLocally(builder.Build().ValueOrDie(), {&x_array}); EXPECT_FALSE(execute_status.ok()); EXPECT_THAT(execute_status.status().error_message(), @@ -604,14 +587,14 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { // Test passing in an invalid result layout parameter. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); builder.Neg(x); auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally( - builder.Build().ValueOrDie(), {x_array.get()}, + builder.Build().ValueOrDie(), {&x_array}, DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{1, 2, 3, 4}, @@ -627,7 +610,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { // Try to run a trivial computation on every device on the system. If a // specific device is not supported, check that the right error is returned. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0f); auto computation = builder.Build().ConsumeValueOrDie(); for (int d = 0; d < local_client_->device_count(); ++d) { @@ -644,9 +627,9 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { computation, {}, DefaultExecutableBuildOptions().set_device_ordinal(d), DefaultExecutableRunOptions().set_device_ordinal(d)); - EXPECT_EQ(d, result->device_ordinal()); + EXPECT_EQ(d, result.device_ordinal()); LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(*result)); + *ShapedBufferToLiteral(result)); } } } @@ -654,7 +637,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) { XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { // Try running computations on devices with device ordinal values which do not // exist. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0f); auto computation = builder.Build().ConsumeValueOrDie(); @@ -671,7 +654,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) { XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { // Run a computation on a specific stream on each device on the system. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0f); auto computation = builder.Build().ConsumeValueOrDie(); @@ -689,9 +672,9 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnStream) { DefaultExecutableRunOptions().set_stream(&stream)); // As a check to verify that the computation ran of the device associated // with the stream. This is a weak check, but stronger verification is hard. 
- EXPECT_EQ(d, result->device_ordinal()); + EXPECT_EQ(d, result.device_ordinal()); LiteralTestUtil::ExpectR0Equal(42.0f, - *ShapedBufferToLiteral(*result)); + *ShapedBufferToLiteral(result)); } } @@ -707,7 +690,7 @@ XLA_TEST_F(LocalClientExecuteTest, se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).ValueOrDie()); wrong_stream.Init(); - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0f); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(), @@ -724,7 +707,7 @@ XLA_TEST_F(LocalClientExecuteTest, .ValueOrDie(); TestAllocator allocator(wrong_platform); - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto y = builder.ConstantR0(123.0f); auto execute_status = ExecuteLocally( @@ -737,7 +720,7 @@ XLA_TEST_F(LocalClientExecuteTest, XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) { // Try to run a computation on a stream that has not been initialized. - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(42.0f); LOG(INFO) << "default device = " << local_client_->default_device_ordinal(); @@ -757,7 +740,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) { } XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); std::initializer_list vec1 = {1.f, 2.f, 3.f}; std::initializer_list vec2 = {2.f, 4.f, 6.f}; @@ -767,17 +750,17 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { {builder.ConstantR1(vec2), builder.ConstantR1(vec1)}); builder.Select(builder.ConstantR0(false), tuple12, tuple21); - std::unique_ptr result = + ScopedShapedBuffer result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); - std::unique_ptr tuple_literal = ShapedBufferToLiteral(*result); + std::unique_ptr tuple_literal = ShapedBufferToLiteral(result); LiteralTestUtil::ExpectR1Equal( - {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0})); + {2.0f, 4.0f, 6.0f}, LiteralSlice(*tuple_literal, {0})); LiteralTestUtil::ExpectR1Equal( - {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1})); + {1.0f, 2.0f, 3.0f}, LiteralSlice(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); @@ -793,12 +776,12 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); - std::unique_ptr result = - executable->Run({x_array.get()}, DefaultExecutableRunOptions()) + ScopedShapedBuffer result = + executable->Run({&x_array}, DefaultExecutableRunOptions()) .ConsumeValueOrDie(); LiteralTestUtil::ExpectR1Near( - {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); + {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(result), error_spec_); } XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { @@ -811,7 +794,7 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { literal, local_client_->default_device_ordinal(), allocator_)); TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, - local_client_->ShapedBufferToLiteral(*shaped_buffer)); + local_client_->ShapedBufferToLiteral(shaped_buffer)); EXPECT_EQ(literal, *transferred_literal); }; 
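The hunks above migrate these client tests from ComputationBuilder to XlaBuilder, replace std::unique_ptr<ScopedShapedBuffer> with a value-typed ScopedShapedBuffer, and route comparisons through LiteralTestUtil helpers that return ::testing::AssertionResult. A minimal sketch of the resulting test idiom, assuming only the post-change signatures visible in this patch (the test name MyAddVectors is hypothetical):

XLA_TEST_F(LocalClientExecuteTest, MyAddVectors) {
  // XlaBuilder takes only a computation name; no Client* is needed anymore.
  XlaBuilder builder(TestName());
  auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
  auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
  builder.Add(x, y);

  // ScopedShapedBuffer is held by value here; arguments are passed by pointer.
  ScopedShapedBuffer x_array =
      LiteralToShapedBuffer(*Literal::CreateR1<float>({0.0f, 1.0f, 2.0f}));
  ScopedShapedBuffer result =
      ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {&x_array});

  // Near() returns ::testing::AssertionResult, so it composes with EXPECT_TRUE.
  EXPECT_TRUE(LiteralTestUtil::Near(
      *Literal::CreateR1<float>({2.0f, 4.0f, 6.0f}),
      *ShapedBufferToLiteral(result), error_spec_));
}

Passing arguments as {&x_array} rather than {x_array.get()} follows directly from ScopedShapedBuffer no longer being wrapped in a unique_ptr.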
@@ -851,7 +834,7 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { literal, local_client_->default_device_ordinal(), allocator_)); TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, - local_client_->ShapedBufferToLiteral(*shaped_buffer)); + local_client_->ShapedBufferToLiteral(shaped_buffer)); EXPECT_EQ(literal, *transferred_literal); }; @@ -867,9 +850,8 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { // TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel. // 2017-10-18. -XLA_TEST_F(LocalClientExecuteTest, - DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(InfeedOutfeedTest))) { - ComputationBuilder builder(local_client_, TestName()); +XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) { + XlaBuilder builder(TestName()); const Shape shape = ShapeUtil::MakeShape(F32, {3}); auto in = builder.Infeed(shape); auto constant = builder.ConstantR1({1.0f, 2.0f, 3.0f}); @@ -907,7 +889,7 @@ void BM_LocalClientOverhead(int num_iters) { int device_ordinal = client->default_device_ordinal(); // Use a tiny add operation as the computation. - ComputationBuilder builder(client, "Add"); + XlaBuilder builder("Add"); auto shape = ShapeUtil::MakeShape(F32, {2, 3}); auto x = builder.Parameter(0, shape, "x"); builder.Add(x, x); @@ -919,12 +901,12 @@ void BM_LocalClientOverhead(int num_iters) { .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *literal, *buffer)); + executors[device_ordinal], *literal, buffer)); const int kWarmups = 2; auto executable_status = client->Compile( - computation, {&buffer->on_host_shape()}, ExecutableBuildOptions()); + computation, {&buffer.on_host_shape()}, ExecutableBuildOptions()); ASSERT_IS_OK(executable_status); std::unique_ptr executable = executable_status.ConsumeValueOrDie(); @@ -936,13 +918,13 @@ void BM_LocalClientOverhead(int num_iters) { run_options.set_allocator(&allocator).set_stream(&stream); for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({buffer.get()}, run_options); + auto result = executable->Run({&buffer}, run_options); ASSERT_IS_OK(result); } tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({buffer.get()}, run_options); + auto result = executable->Run({&buffer}, run_options); ASSERT_IS_OK(result); } } diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 96b976d25d7..88797a7d0a7 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -27,7 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -35,19 +35,20 @@ namespace xla { /* static */ TestAllocator* LocalClientTestBase::allocator_; -StatusOr TestAllocator::Allocate( - int device_ordinal, uint64 size, bool retry_on_failure) { +StatusOr TestAllocator::Allocate(int device_ordinal, + uint64 size, + bool retry_on_failure) { VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")"; { tensorflow::mutex_lock lock(count_mutex_); allocation_count_++; device_allocation_count_[device_ordinal]++; } - return StreamExecutorMemoryAllocator::Allocate(device_ordinal, size); + return StreamExecutorMemoryAllocator::Allocate(device_ordinal, size, + retry_on_failure); } -tensorflow::Status TestAllocator::Deallocate( - int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) { +Status TestAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase mem) { VLOG(2) << "Deallocate(" << device_ordinal << ")"; { tensorflow::mutex_lock lock(count_mutex_); @@ -88,7 +89,7 @@ int64 TestAllocator::deallocation_count(int device_ordinal) const { } /* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator( - perftools::gputools::Platform* platform) { + se::Platform* platform) { static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); tensorflow::mutex_lock lock(mu); @@ -115,8 +116,7 @@ struct LocalClientTestBase::EigenThreadPoolWrapper { std::unique_ptr device; }; -LocalClientTestBase::LocalClientTestBase( - perftools::gputools::Platform* platform) +LocalClientTestBase::LocalClientTestBase(se::Platform* platform) : local_client_( ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()), thread_pool_wrapper_(new EigenThreadPoolWrapper()) { @@ -128,7 +128,7 @@ LocalClientTestBase::LocalClientTestBase( LocalClientTestBase::~LocalClientTestBase() {} -std::unique_ptr LocalClientTestBase::LiteralToShapedBuffer( +ScopedShapedBuffer LocalClientTestBase::LiteralToShapedBuffer( const Literal& literal) { return local_client_ ->LiteralToShapedBuffer(literal, local_client_->default_device_ordinal()) @@ -148,23 +148,21 @@ ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions() ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const { ExecutableRunOptions run_options; - run_options.set_inter_op_thread_pool( - local_client_->backend().inter_op_thread_pool()); run_options.set_intra_op_thread_pool(thread_pool_wrapper_->device.get()); run_options.set_allocator(GetOrCreateAllocator(local_client_->platform())); return run_options; } -std::unique_ptr LocalClientTestBase::ExecuteLocallyOrDie( - const Computation& computation, +ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()) .ConsumeValueOrDie(); } -std::unique_ptr LocalClientTestBase::ExecuteLocallyOrDie( - const Computation& computation, +ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { @@ -172,17 +170,15 @@ std::unique_ptr 
LocalClientTestBase::ExecuteLocallyOrDie( .ConsumeValueOrDie(); } -StatusOr> -LocalClientTestBase::ExecuteLocally( - const Computation& computation, +StatusOr LocalClientTestBase::ExecuteLocally( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments) { return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(), DefaultExecutableRunOptions()); } -StatusOr> -LocalClientTestBase::ExecuteLocally( - const Computation& computation, +StatusOr LocalClientTestBase::ExecuteLocally( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options) { diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index f0c73f04f6e..258226523d8 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -21,8 +21,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -41,15 +41,14 @@ namespace xla { class TestAllocator : public StreamExecutorMemoryAllocator { public: - explicit TestAllocator(perftools::gputools::Platform* platform) + explicit TestAllocator(se::Platform* platform) : StreamExecutorMemoryAllocator( platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) { } - StatusOr Allocate( - int device_ordinal, uint64 size, bool retry_on_failure) override; - tensorflow::Status Deallocate( - int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override; + StatusOr Allocate(int device_ordinal, uint64 size, + bool retry_on_failure) override; + Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override; // Return the number of allocations that have been performed. int64 allocation_count() const; @@ -75,18 +74,15 @@ class TestAllocator : public StreamExecutorMemoryAllocator { class LocalClientTestBase : public ::testing::Test { protected: struct EigenThreadPoolWrapper; - explicit LocalClientTestBase( - perftools::gputools::Platform* platform = nullptr); + explicit LocalClientTestBase(se::Platform* platform = nullptr); virtual ~LocalClientTestBase(); - static TestAllocator* GetOrCreateAllocator( - perftools::gputools::Platform* platform); + static TestAllocator* GetOrCreateAllocator(se::Platform* platform); // Copy the given literal onto the default device and return a // ScopedShapedBuffer. Convenience wrapper around // LocalClient::LiteralToShapedBuffer. - std::unique_ptr LiteralToShapedBuffer( - const Literal& literal); + ScopedShapedBuffer LiteralToShapedBuffer(const Literal& literal); // Construct and return a literal containing the array represented by // shaped_buffer. @@ -95,20 +91,20 @@ class LocalClientTestBase : public ::testing::Test { // Execute the given computation on the local client. With and without // options. 
- StatusOr> ExecuteLocally( - const Computation& computation, + StatusOr ExecuteLocally( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments); - StatusOr> ExecuteLocally( - const Computation& computation, + StatusOr ExecuteLocally( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); - std::unique_ptr ExecuteLocallyOrDie( - const Computation& computation, + ScopedShapedBuffer ExecuteLocallyOrDie( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments); - std::unique_ptr ExecuteLocallyOrDie( - const Computation& computation, + ScopedShapedBuffer ExecuteLocallyOrDie( + const XlaComputation& computation, tensorflow::gtl::ArraySlice arguments, const ExecutableBuildOptions& build_options, const ExecutableRunOptions& run_options); @@ -128,7 +124,7 @@ class LocalClientTestBase : public ::testing::Test { // of the process. So make the allocator static. static TestAllocator* allocator_; - perftools::gputools::StreamExecutor* stream_executor_; + se::StreamExecutor* stream_executor_; TransferManager* transfer_manager_; LocalClient* local_client_; diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index 174d433a9e1..c0c02e584c2 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -29,7 +29,7 @@ namespace { class LogTest : public ClientLibraryTestBase {}; XLA_TEST_F(LogTest, LogZeroValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR3FromArray3D(Array3D(3, 0, 0)); builder.Log(x); @@ -41,7 +41,7 @@ TEST_F(LogTest, LogTenValues) { std::vector input = {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0}; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1(input); builder.Log(x); diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index efe6cc67872..7df45bebebd 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -16,8 +16,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -41,7 +39,7 @@ namespace { class MapTest : public ClientLibraryTestBase { public: - explicit MapTest(perftools::gputools::Platform* platform = nullptr) + explicit MapTest(se::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); @@ -341,48 +339,6 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { ComputeAndCompareR0(&builder, 73.0, {}, ErrorSpec(0.01f)); } -TEST_F(MapTest, VersionedEmbeddedComputation) { - // Build a computation X, use it in a map, then add an additional operation to - // computation X and use it again in a different map. Verify that the proper - // versions of computation X are used in each of the maps. - - // Create a (embedded) computation which adds one to its parameter argument. - ComputationBuilder embedded_builder(client_, "EmbeddedComputation"); - auto param_0 = - embedded_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); - auto constant_one = embedded_builder.ConstantR0(1.0); - auto adder_to_one = embedded_builder.Add(param_0, constant_one); - auto computation_status = embedded_builder.Build(); - ASSERT_IS_OK(computation_status.status()); - auto embedded_computation = computation_status.ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto constant_vector = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto map_plus_1 = builder.Map({constant_vector}, embedded_computation, {0}); - - // Add another Add(1) operation to the existing embedded computation. This - // requires using the stub interface because the ComputationBuilder does not - // allow modification to the XlaComputation objects after they have been - // built. - BinaryOpRequest request; - request.set_binop(BINOP_ADD); - *request.mutable_lhs() = adder_to_one; - *request.mutable_rhs() = constant_one; - OpRequest op_request; - *op_request.mutable_computation() = embedded_computation.handle(); - *op_request.mutable_binary_op_request() = request; - OpResponse response; - tensorflow::Status s = client_->stub()->Op(&op_request, &response); - ASSERT_TRUE(s.ok()); - - auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation, {0}); - - // The original vector has Add(1) applied to it with a map, followed by - // Add(1+1) resulting in a net Add(3). - ComputeAndCompareR1(&builder, {4.0, 5.0, 6.0, 7.0}, {}, - ErrorSpec(0.01f)); -} - TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index c42f71388ba..27fd36e06ac 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -18,9 +18,9 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -51,16 +51,11 @@ class MatOpsSimpleTest : public ClientLibraryTestBase {}; template class MatOpsSimpleTest_F16F32 : public MatOpsSimpleTest {}; -// TODO(bixia): This test for F16 failed on GPU 02-25-2018. -#ifdef XLA_TEST_BACKEND_GPU -TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, ::testing::Types); -#else TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, TypesF16F32); -#endif XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { using T = TypeParam; - ComputationBuilder builder(this->client_, "exp_2x2"); + XlaBuilder builder("exp_2x2"); auto data = builder.ConstantR2FromArray2D({ {1.0f, 0.0f}, // row 0 {-1.0f, 0.5f}, // row 1 @@ -71,16 +66,15 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) { Literal::CreateR2FromArray2D({{2.71828f, 1.00000f}, // row 0 {0.36788f, 1.64872f}}); // row 1 - this->template ComputeAndCompareLiteral(&builder, *expected, {}, - ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { using T = TypeParam; - Computation add_half; + XlaComputation add_half; { // add_half(x) = x + 0.5 - ComputationBuilder builder(this->client_, "add_half"); + XlaBuilder builder("add_half"); auto x_value = builder.Parameter(0, ShapeUtil::MakeShapeWithType({}), "x_value"); auto half = builder.ConstantR0(static_cast(0.5)); @@ -90,7 +84,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { add_half = computation_status.ConsumeValueOrDie(); } - ComputationBuilder builder(this->client_, "map_2x2"); + XlaBuilder builder("map_2x2"); auto data = builder.ConstantR2FromArray2D({ {1.0f, 0.0f}, // row 0 {-1.0f, 0.5f}, // row 1 @@ -100,13 +94,12 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) { std::unique_ptr expected = Literal::CreateR2FromArray2D({{1.5f, 0.5f}, // row 0 {-0.5f, 1.0f}}); // row 1 - this->template ComputeAndCompareLiteral(&builder, *expected, {}, - ErrorSpec(1e-5)); + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { using T = TypeParam; - ComputationBuilder builder(this->client_, "max_2x2"); + XlaBuilder builder("max_2x2"); auto lhs = builder.ConstantR2FromArray2D({ {7.0f, 2.0f}, // row 0 {3.0f, -4.0f}, // row 1 @@ -120,8 +113,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) { std::unique_ptr expected = Literal::CreateR2FromArray2D({{7.0f, 6.0f}, // row 0 {3.0f, -4.0f}}); // row 1 - this->template ComputeAndCompareLiteral(&builder, *expected, {}, - ErrorSpec(1e-6)); + this->ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); } struct TestLinspaceMaxParam { @@ -143,8 +135,7 @@ class TestLinspaceMaxParametric MakeLinspaceArray2D(from, to, rows, cols); auto arhs = MakeUnique>(rows, cols, static_cast(1.0f)); - ComputationBuilder builder( - client_, + XlaBuilder builder( tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols)); auto lhs = builder.ConstantR2FromArray2D(*alhs); auto rhs = builder.ConstantR2FromArray2D(*arhs); @@ -171,11 
+162,8 @@ string PrintTestLinspaceMaxParam( } #ifndef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 -// TODO(bixia): This test failed on GPU 02-25-2018 -#ifdef XLA_TEST_BACKEND_CPU XLA_TEST_P(TestLinspaceMaxParametric, TestF16) { TestImpl(); } #endif -#endif XLA_TEST_P(TestLinspaceMaxParametric, TestF32) { TestImpl(); } INSTANTIATE_TEST_CASE_P( @@ -219,7 +207,7 @@ class MatOpsDotAddTest client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); auto lhs_mat_arg = lhs_arg; if (transpose) { diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc index 11c0bf7a5a5..0791a71aacf 100644 --- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -32,7 +32,7 @@ namespace { class SliceTest : public ClientLibraryTestBase {}; XLA_TEST_F(SliceTest, Slice2D) { - ComputationBuilder builder(client_, "slice_2d"); + XlaBuilder builder("slice_2d"); auto original = builder.ConstantR2( {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}}); builder.Slice(original, {2, 1}, {4, 3}, {1, 1}); @@ -42,7 +42,7 @@ XLA_TEST_F(SliceTest, Slice2D) { } XLA_TEST_F(SliceTest, Slice3D) { - ComputationBuilder builder(client_, "slice_3d"); + XlaBuilder builder("slice_3d"); Array3D array_3d( {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}); auto original = builder.ConstantR3FromArray3D(array_3d); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 0a603f4954b..ec7ca20bdf2 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
 #include 
 #include 
-#include "tensorflow/compiler/xla/client/computation_builder.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
@@ -108,7 +107,7 @@ class MultiOutputFusionTest : public HloTestBase {
     expect.PopulateWithValue(size * 1.5f * 3.5f);
     auto actual = ExecuteAndTransfer(
         std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1});
-    LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
+    EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
   }
 
   void RunTest1D(bool manual_fusion, int size) {
@@ -168,7 +167,7 @@ class MultiOutputFusionTest : public HloTestBase {
     Literal expect = std::move(*Literal::CreateR1({size * 1.5f * 3.5f}));
     auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
-    LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
+    EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
   }
 };
 
@@ -211,5 +210,68 @@ XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) {
       *result, *Literal::MakeTupleOwned(Literal::CreateR0(42))));
 }
 
+XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFusion) {
+  const char* testcase = R"(
+    HloModule m
+
+    fused_computation {
+      p = f32[4] parameter(0)
+      multiply = f32[4] multiply(p, p)
+      less-than = pred[4] less-than(p, multiply)
+      ROOT tuple = (pred[4], f32[4]) tuple(less-than, multiply)
+    }
+
+    ENTRY PredFloatMOF {
+      p0 = f32[4] parameter(0)
+      fusion = (pred[4], f32[4]) fusion(p0), kind=kLoop, calls=fused_computation
+      gte0 = pred[4] get-tuple-element(fusion), index=0
+      gte1 = f32[4] get-tuple-element(fusion), index=1
+      const = f32[4] constant({0, 0, 0, 0})
+      ROOT select = f32[4] select(gte0, gte1, const)
+    })";
+  auto module =
+      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+          .ValueOrDie();
+  auto param = Literal::CreateR1({1.0, 2.0, 3.0, -1.0});
+  TF_ASSERT_OK_AND_ASSIGN(auto result,
+                          Execute(std::move(module), {param.get()}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      *result, *Literal::CreateR1({0.0, 4.0, 9.0, 1.0})));
+}
+
+XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) {
+  const char* testcase = R"(
+    HloModule m
+
+    fused_computation {
+      p = f32[] parameter(0)
+      multiply = f32[] multiply(p, p)
+      less-than = pred[] less-than(p, multiply)
+      ROOT tuple = (pred[], f32[]) tuple(less-than, multiply)
+    }
+
+    map_computation {
+      p0 = f32[] parameter(0)
+      fusion = (pred[], f32[]) fusion(p0), kind=kLoop, calls=fused_computation
+      gte0 = pred[] get-tuple-element(fusion), index=0
+      gte1 = f32[] get-tuple-element(fusion), index=1
+      const = f32[] constant(0)
+      ROOT select = f32[] select(gte0, gte1, const)
+    }
+
+    ENTRY MapMOF {
+      p1 = f32[3] parameter(0)
+      ROOT map = f32[3] map(p1), to_apply=map_computation
+    })";
+  auto module =
+      HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
+          .ValueOrDie();
+  auto param = Literal::CreateR1({1.0, 2.0, 3.0});
+  TF_ASSERT_OK_AND_ASSIGN(auto result,
+                          Execute(std::move(module), {param.get()}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      *result, *Literal::CreateR1({0.0, 4.0, 9.0})));
+}
+
 } // namespace
 } // namespace xla
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index bb7e800df84..838f1b4e2f0 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -19,10 +19,10 @@ limitations under the License.
#include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -41,7 +41,7 @@ namespace { class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR0(3.14159f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -53,7 +53,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { } XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -65,7 +65,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { } XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = @@ -78,7 +78,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { } XLA_TEST_F(ParamsTest, ConstantR1U8Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); string str("hello world"); std::unique_ptr param0_literal = Literal::CreateR1U8(str); std::unique_ptr param0_data = @@ -91,7 +91,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { } XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = @@ -104,7 +104,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { } XLA_TEST_F(ParamsTest, ConstantR2F32Param) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = @@ -119,7 +119,7 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { } XLA_TEST_F(ParamsTest, TwoParameters) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = @@ -156,19 +156,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) { std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2"); - auto computation = builder.Build().ConsumeValueOrDie(); + auto computation_status = builder.Build(); - auto execute_status = client_->Execute(computation, {data.get(), data.get()}, - /*execution_options=*/nullptr, - /*execution_profile=*/nullptr); - ASSERT_EQ(execute_status.status().code(), - tensorflow::error::FAILED_PRECONDITION); + ASSERT_NE(computation_status.status(), Status::OK()); } XLA_TEST_F(ParamsTest, 
UnusedParameter) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = @@ -188,7 +184,7 @@ XLA_TEST_F(ParamsTest, UnusedParameter) { XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // Build a computation with a couple unused parameters which are used in an // unused expression. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = @@ -214,12 +210,12 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { } XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); constexpr int size = 8 * 128 * 2; std::vector init_value = {{0, 1}}; init_value.resize(size); - ComputationDataHandle sum_handle = builder.ConstantR1(init_value); + XlaOp sum_handle = builder.ConstantR1(init_value); std::vector sum = {{0, 1}}; sum.resize(size); @@ -237,8 +233,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::unique_ptr literal = Literal::CreateR1(sum_value); param_data_owner.push_back( client_->TransferToServer(*literal).ConsumeValueOrDie()); - ComputationDataHandle param = - builder.Parameter(i, literal->shape(), "param"); + XlaOp param = builder.Parameter(i, literal->shape(), "param"); sum_handle = builder.Add(sum_handle, param); } @@ -262,10 +257,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { // compilation. XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(ThreeThousandParameters))) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector> param_data_owner; - ComputationDataHandle sum_handle = builder.ConstantR0(0.0f); + XlaOp sum_handle = builder.ConstantR0(0.0f); float target = 0.0; constexpr int kParamCount = 3000; for (int i = 0; i < kParamCount; ++i) { @@ -273,8 +268,7 @@ XLA_TEST_F(ParamsTest, std::unique_ptr literal = Literal::CreateR0(i); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - ComputationDataHandle param = - builder.Parameter(i, literal->shape(), "param"); + XlaOp param = builder.Parameter(i, literal->shape(), "param"); sum_handle = builder.Add(sum_handle, param); } @@ -294,25 +288,24 @@ XLA_TEST_F(ParamsTest, // compilation. XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( ThreeThousandParametersAndOutputElements))) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector> param_data_owner; - ComputationDataHandle sum_handle = builder.ConstantR1({0, 0}); + XlaOp sum_handle = builder.ConstantR1({0, 0}); int32 target = 0; constexpr int kParamCount = 3000; - std::vector params; + std::vector params; for (int i = 0; i < kParamCount; ++i) { target += i; std::unique_ptr literal = Literal::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - ComputationDataHandle param = - builder.Parameter(i, literal->shape(), "param"); + XlaOp param = builder.Parameter(i, literal->shape(), "param"); params.push_back(param); sum_handle = builder.Add(sum_handle, param); } - std::vector outputs; + std::vector outputs; for (int i = 0; i < kParamCount; ++i) { outputs.push_back(builder.Add(params[i], sum_handle)); } @@ -353,18 +346,17 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( // 2017-12-12. 
XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(ManyParametersIntoWhileLoop))) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector> param_data_owner; constexpr int kParamCount = 1900; - std::vector params; + std::vector params; std::vector parameter_shapes; for (int i = 0; i < kParamCount; ++i) { std::unique_ptr literal = Literal::CreateR1({i, i}); param_data_owner.push_back( std::move(client_->TransferToServer(*literal)).ValueOrDie()); - ComputationDataHandle param = - builder.Parameter(i, literal->shape(), "param"); + XlaOp param = builder.Parameter(i, literal->shape(), "param"); params.push_back(param); parameter_shapes.push_back(literal->shape()); } @@ -374,7 +366,7 @@ XLA_TEST_F(ParamsTest, std::unique_ptr bool_literal = Literal::CreateR0(false); param_data_owner.push_back( std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); - ComputationDataHandle bool_param = + XlaOp bool_param = builder.Parameter(kParamCount, bool_literal->shape(), "bool_param"); params.push_back(bool_param); parameter_shapes.push_back(bool_literal->shape()); @@ -383,9 +375,9 @@ XLA_TEST_F(ParamsTest, // Create a computation for the condition: while(bool_param). Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes); - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client_, "condition"); + XlaBuilder builder("condition"); auto condition_parameter = builder.Parameter(0, while_shape, "condition_parameter"); builder.GetTupleElement(condition_parameter, kParamCount); @@ -394,11 +386,11 @@ XLA_TEST_F(ParamsTest, // Create a computation for the body. // Add {1, 1} to the each tuple element. - Computation body; + XlaComputation body; { - ComputationBuilder builder(client_, "body"); + XlaBuilder builder("body"); auto body_parameter = builder.Parameter(0, while_shape, "body_parameter"); - std::vector updates; + std::vector updates; for (int i = 0; i < kParamCount; ++i) { auto add = builder.Add(builder.GetTupleElement(body_parameter, i), builder.ConstantR1({1, 1})); @@ -413,7 +405,7 @@ XLA_TEST_F(ParamsTest, auto loop = builder.While(condition, body, init); - std::vector outputs; + std::vector outputs; for (int i = 0; i < kParamCount; ++i) { outputs.push_back(builder.GetTupleElement(loop, i)); } @@ -437,7 +429,7 @@ XLA_TEST_F(ParamsTest, #endif XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3}); Shape tuple_shape = ShapeUtil::MakeTupleShape({r1f32_3, r1f32_3}); @@ -464,7 +456,7 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { std::unique_ptr literal = Literal::CreateR2WithLayout( {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Parameter(0, literal->shape(), "input"); std::unique_ptr data = @@ -476,7 +468,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { std::unique_ptr literal = Literal::CreateR2WithLayout( {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Parameter(0, literal->shape(), "input"); std::unique_ptr data = @@ -501,7 +493,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { ASSERT_EQ(2, literal->Get({0, 1})); } // Use the original shape in building the computation. 
- ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.Parameter(0, original, "input"); // Use the slice operator to get an off-diagonal element. builder.Slice(input, {0, 1}, {1, 2}, {1, 1}); diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 10e44b274a8..77159efb26f 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -29,63 +29,62 @@ namespace { class PredTest : public ClientLibraryTestBase { protected: - void TestCompare(bool lhs, bool rhs, bool expected, - ComputationDataHandle (ComputationBuilder::*op)( - const ComputationDataHandle&, - const ComputationDataHandle&, - tensorflow::gtl::ArraySlice)) { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle lhs_op = builder.ConstantR0(lhs); - ComputationDataHandle rhs_op = builder.ConstantR0(rhs); - ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {}); + void TestCompare( + bool lhs, bool rhs, bool expected, + XlaOp (XlaBuilder::*op)(const xla::XlaOp&, const xla::XlaOp&, + tensorflow::gtl::ArraySlice)) { + XlaBuilder builder(TestName()); + XlaOp lhs_op = builder.ConstantR0(lhs); + XlaOp rhs_op = builder.ConstantR0(rhs); + XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; TEST_F(PredTest, ConstantR0PredTrue) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR0(true); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, ConstantR0PredFalse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR0(false); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, ConstantR0PredCompareEq) { - TestCompare(true, false, false, &ComputationBuilder::Eq); + TestCompare(true, false, false, &XlaBuilder::Eq); } TEST_F(PredTest, ConstantR0PredCompareNe) { - TestCompare(true, false, true, &ComputationBuilder::Ne); + TestCompare(true, false, true, &XlaBuilder::Ne); } TEST_F(PredTest, ConstantR0PredCompareLe) { - TestCompare(true, false, false, &ComputationBuilder::Le); + TestCompare(true, false, false, &XlaBuilder::Le); } TEST_F(PredTest, ConstantR0PredCompareLt) { - TestCompare(true, false, false, &ComputationBuilder::Lt); + TestCompare(true, false, false, &XlaBuilder::Lt); } TEST_F(PredTest, ConstantR0PredCompareGe) { - TestCompare(true, false, true, &ComputationBuilder::Ge); + TestCompare(true, false, true, &XlaBuilder::Ge); } TEST_F(PredTest, ConstantR0PredCompareGt) { - TestCompare(true, false, true, &ComputationBuilder::Gt); + TestCompare(true, false, true, &XlaBuilder::Gt); } TEST_F(PredTest, ConstantR1Pred) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({true, false, false, true}); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } TEST_F(PredTest, ConstantR2Pred) { - ComputationBuilder 
builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{false, true, true}, {true, false, false}}); const string expected = R"(pred[2,3] { @@ -96,28 +95,28 @@ TEST_F(PredTest, ConstantR2Pred) { } TEST_F(PredTest, AnyR1True) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({true, false}); TF_ASSERT_OK(Any(a, &builder).status()); ComputeAndCompareR0(&builder, true, {}); } TEST_F(PredTest, AnyR1False) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({false, false}); TF_ASSERT_OK(Any(a, &builder).status()); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, AnyR1VacuouslyFalse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR1({}); TF_ASSERT_OK(Any(a, &builder).status()); ComputeAndCompareR0(&builder, false, {}); } TEST_F(PredTest, AnyR2True) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({ {false, false, false}, {false, false, false}, @@ -128,7 +127,7 @@ TEST_F(PredTest, AnyR2True) { } TEST_F(PredTest, AnyR2False) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({ {false, false, false}, {false, false, false}, diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 6aafb9fa6cb..1a2de6937c3 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -52,13 +52,14 @@ class PrngTest : public ClientLibraryTestBase { template std::unique_ptr PrngTest::UniformTest( T a, T b, tensorflow::gtl::ArraySlice dims, int64 seed) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.RngUniform( builder.ConstantR0(a), builder.ConstantR0(b), ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dims)); SetSeed(seed); - auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); + auto actual = + ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); actual->EachCell([=](tensorflow::gtl::ArraySlice, T value) { EXPECT_LE(a, value); @@ -81,8 +82,7 @@ XLA_TEST_F(PrngTest, LargeU01) { UniformTest(0, 1, {0x100, 0x100}); } XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest(5, 24, {12}); } // TODO(b/71543667): Fix Rng ops on LLVM backends. -XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( - DISABLED_ON_CPU(ScalarBF16Tests)))) { +XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) { for (int64 seed = 0; seed < 100; ++seed) { // The largest negative number smaller than zero in bf16 that's not // denormalized. @@ -105,8 +105,7 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( } // TODO(b/71543667): Fix Rng ops on LLVM backends. 
-XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU( - DISABLED_ON_CPU_PARALLEL(ScalarBF16CountTests)))) { +XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) { // There are 3 BF16 values in the range of [32.25, 33): 32.25, 32.5, 32.75, // they should get similar counts. bfloat16 low = static_cast(32.25); @@ -141,13 +140,14 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, int64 seed) { int32 sample_size = range_size * expected_count; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.RngUniform(builder.ConstantR0(0), builder.ConstantR0(range_size), ShapeUtil::MakeShape(S32, {sample_size})); SetSeed(seed); - auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); + auto actual = + ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); std::vector counts(range_size, 0); actual->EachCell([&counts](tensorflow::gtl::ArraySlice, int32 value) { ++counts[value]; }); @@ -182,16 +182,15 @@ XLA_TEST_F(PrngTest, Uniformity256) { XLA_TEST_F(PrngTest, MapUsingRng) { // Build a x -> (x + U[0,1)) computation. - auto build_sum_rng = [this](ComputationBuilder& builder) { + auto build_sum_rng = [this](XlaBuilder& builder) { auto b = builder.CreateSubBuilder("sum_with_rng"); auto x = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "input"); - b->Add(x, - b->RngUniform(b->ConstantR0(0), b->ConstantR0(1), - ShapeUtil::MakeShape(F32, {}))); + b->Add(x, b->RngUniform(b->ConstantR0(0), b->ConstantR0(1), + ShapeUtil::MakeShape(F32, {}))); return b->BuildAndNoteError(); }; - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr param0_literal = Literal::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr param0_data, @@ -226,7 +225,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) { XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { // Build a U[0,1) computation. auto build_computation = [this]() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.RngUniform(builder.ConstantR0(0), builder.ConstantR0(1), ShapeUtil::MakeShape(F32, {10})); @@ -274,32 +273,32 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { &execution_options_)); } - LiteralTestUtil::ExpectEqual(*result1, *result2); - LiteralTestUtil::ExpectEqual(*result1, *result3); - LiteralTestUtil::ExpectNotEqual(*result1, *result4); - LiteralTestUtil::ExpectNotEqual(*result4, *result5); - LiteralTestUtil::ExpectNotEqual(*result5, *result6); + EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2)); + EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5)); + EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6)); } XLA_TEST_F(PrngTest, TenValuesN01) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.RngNormal(builder.ConstantR0(0), builder.ConstantR0(1), ShapeUtil::MakeShape(F32, {10})); SetSeed(42); - ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); + ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); // TODO(b/25995601): Test that resultant values are reasonable } XLA_TEST_F(PrngTest, RngUniformCrash) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // This used to crash XLA during LLVM IR generation for CPUs. 
auto rng_uniform = builder.RngUniform(builder.ConstantR0(0), builder.ConstantR0(1000 * 1000), ShapeUtil::MakeShape(S32, {})); SetSeed(0); - ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); + ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie(); } } // namespace diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc index 212512207cf..f95e7564834 100644 --- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -30,13 +30,13 @@ namespace { class QueryInferredShapeTest : public ClientLibraryTestBase {}; TEST_F(QueryInferredShapeTest, OnePlusOneShape) { - ComputationBuilder builder(client_, "one_plus_one"); + XlaBuilder builder("one_plus_one"); auto one = builder.ConstantR0(1.0); auto result = builder.Add(one, one); - StatusOr> shape_status = builder.GetShape(result); + StatusOr shape_status = builder.GetShape(result); ASSERT_IS_OK(shape_status.status()); auto shape = shape_status.ConsumeValueOrDie(); - ASSERT_TRUE(ShapeUtil::Equal(*shape, ShapeUtil::MakeShape(F32, {}))); + ASSERT_TRUE(ShapeUtil::Equal(shape, ShapeUtil::MakeShape(F32, {}))); } } // namespace diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 768beec15e7..d671d40456a 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -34,8 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -52,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -59,10 +58,9 @@ limitations under the License. namespace xla { namespace { -using FuncGeneratorForType = Computation (*)(PrimitiveType, - ComputationBuilder*); +using FuncGeneratorForType = XlaComputation (*)(PrimitiveType, XlaBuilder*); -using FuncGenerator = Computation (*)(ComputationBuilder*); +using FuncGenerator = XlaComputation (*)(XlaBuilder*); class ReduceTest : public ClientLibraryTestBase { protected: @@ -88,8 +86,8 @@ class ReduceTest : public ClientLibraryTestBase { // Runs an R1 => R0 reduction test with the given number of elements. 
void RunR1ToR0Test(int64 element_count) { - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -118,13 +116,13 @@ class ReduceTest : public ClientLibraryTestBase { void RunR1ToR0PredTest(bool and_reduce, tensorflow::gtl::ArraySlice input_data) { const int element_count = input_data.size(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count}); auto input_par = builder.Parameter(0, input_shape, "input"); auto pred_values = builder.Eq(input_par, builder.ConstantR1(element_count, 1)); - ComputationDataHandle init_value; - Computation reduce; + XlaOp init_value; + XlaComputation reduce; if (and_reduce) { init_value = builder.ConstantR0(true); reduce = CreateScalarAndComputation(&builder); @@ -156,13 +154,13 @@ class ReduceTest : public ClientLibraryTestBase { template void RunR2ToR1PredTest(bool and_reduce, int64 rows, int64 minor = 1, int64 major = 0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); auto input_pred = builder.Eq(input, builder.ConstantR0(1)); - ComputationDataHandle init_value; - Computation reduce_op; + XlaOp init_value; + XlaComputation reduce_op; if (and_reduce) { init_value = builder.ConstantR0(true); reduce_op = CreateScalarAndComputation(&builder); @@ -201,8 +199,8 @@ class ReduceTest : public ClientLibraryTestBase { // Runs an R2 => R0 reduction test with the given number of (rows, cols). void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) { - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -229,8 +227,8 @@ class ReduceTest : public ClientLibraryTestBase { // Runs an R2 => R1 reduction test with the given number of (rows, cols). 
void RunR2ToR1Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) { - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -260,7 +258,7 @@ class ReduceTest : public ClientLibraryTestBase { template void ComputeAndCompareGeneric( typename std::enable_if::value, - ComputationBuilder>::type* builder, + XlaBuilder>::type* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { ComputeAndCompareR1(builder, expected, arguments, @@ -270,7 +268,7 @@ class ReduceTest : public ClientLibraryTestBase { template void ComputeAndCompareGeneric( typename std::enable_if::value, - ComputationBuilder>::type* builder, + XlaBuilder>::type* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { ComputeAndCompareR1(builder, expected, arguments); @@ -278,15 +276,15 @@ class ReduceTest : public ClientLibraryTestBase { template void RunVectorizedReduceTestForType( - const std::function& + const std::function& reduction_function_generator, const std::function& reference_reduction_function, const NativeT& initial_value) { const int rows = 64, cols = 128; const int minor = 1, major = 0; - ComputationBuilder builder(client_, TestName()); - Computation reduction_function = reduction_function_generator(&builder); + XlaBuilder builder(TestName()); + XlaComputation reduction_function = reduction_function_generator(&builder); const Shape input_shape = ShapeUtil::MakeShape( xla::primitive_util::NativeToPrimitiveType(), {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); @@ -321,7 +319,7 @@ class ReduceTest : public ClientLibraryTestBase { } void RunVectorizedReduceTest( - const std::function& + const std::function& reduction_function_generator_for_type, const std::function& reference_reduction_function_for_floats, @@ -333,21 +331,21 @@ class ReduceTest : public ClientLibraryTestBase { uint32 unsigned_int_identity) { // Float version RunVectorizedReduceTestForType( - [&](ComputationBuilder* builder) { + [&](XlaBuilder* builder) { return reduction_function_generator_for_type(F32, builder); }, reference_reduction_function_for_floats, floating_point_identity); // Signed int version RunVectorizedReduceTestForType( - [&](ComputationBuilder* builder) { + [&](XlaBuilder* builder) { return reduction_function_generator_for_type(S32, builder); }, reference_reduction_function_for_ints, signed_int_identity); // Unsigned int version RunVectorizedReduceTestForType( - [&](ComputationBuilder* builder) { + [&](XlaBuilder* builder) { return reduction_function_generator_for_type(U32, builder); }, reference_reduction_function_for_uints, unsigned_int_identity); @@ -441,8 +439,8 @@ XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) { XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { const int64 rows = 111, cols = 50; - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -472,8 +470,8 @@ 
XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { const int64 rows = 111, cols = 50; - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -521,8 +519,8 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { const int64 rows = 111, cols = 50; - ComputationBuilder builder(client_, TestName()); - Computation add_f32 = CreateScalarAddComputation(F32, &builder); + XlaBuilder builder(TestName()); + XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder); const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); @@ -568,7 +566,7 @@ void PrintTo(const BoundsLayout& spec, std::ostream* os) { // Add-reduces a broadcasted scalar matrix among dimension 1 and 0. XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto add = CreateScalarAddComputation(F32, &builder); auto scalar = builder.ConstantR0(42.0); auto broadcasted = builder.Broadcast(scalar, {500, 500}); @@ -580,7 +578,7 @@ XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { // Max-reduces a broadcasted scalar matrix among dimension 1 and 0. XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto max = CreateScalarMaxComputation(F32, &builder); auto scalar = builder.ConstantR0(42.0); auto broadcasted = builder.Broadcast(scalar, {500, 500}); @@ -592,7 +590,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) { // Max-reduces a matrix among dimension 1 and 0. XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto max = CreateScalarMaxComputation(F32, &builder); Array2D input(300, 250); input.FillRandom(214.0f); @@ -607,7 +605,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { // Min-reduces matrix among dimension 1 and 0. XLA_TEST_F(ReduceTest, MinReduce2DToR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto min = CreateScalarMinComputation(F32, &builder); Array2D input(150, 130); input.FillRandom(214.0f); @@ -622,7 +620,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) { } XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input({{1}, {2}}); auto min = CreateScalarMinComputation(U32, &builder); auto input_literal = Literal::CreateR2FromArray2D(input); @@ -635,7 +633,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) { } XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array2D input({{1}, {2}}); auto max = CreateScalarMaxComputation(U32, &builder); auto input_literal = Literal::CreateR2FromArray2D(input); @@ -649,7 +647,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) { // Reduces a matrix among dimension 1. 
XLA_TEST_F(ReduceTest, Reduce2DAmong1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {1}); @@ -660,7 +658,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) { XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar). - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1}); @@ -670,7 +668,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) { // Tests 2D matrix ReduceToRow operation. XLA_TEST_F(ReduceTest, Reduce2DAmongY) { - ComputationBuilder builder(client_, "reduce_among_y"); + XlaBuilder builder("reduce_among_y"); auto m = builder.ConstantLiteral(*literal_2d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); @@ -680,7 +678,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) { } XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {1, 2}); @@ -690,7 +688,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) { } XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1}); @@ -700,7 +698,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) { } XLA_TEST_F(ReduceTest, ReduceR3ToR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {0, 1, 2}); @@ -710,7 +708,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) { } XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {0}); @@ -725,7 +723,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) { } XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {1}); @@ -742,7 +740,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) { } XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto m = builder.ConstantLiteral(*literal_3d_); auto add = CreateScalarAddComputation(F32, &builder); builder.Reduce(m, builder.ConstantR0(0.0f), add, {2}); @@ -816,7 +814,7 @@ class ReduceR3ToR2Test : public ReduceTest, public ::testing::WithParamInterface {}; XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const auto& bounds = GetParam().bounds; Array3D input_array(bounds[0], bounds[1], 
bounds[2]); // input_array.FillRandom(3.14f, 0.05);
@@ -830,7 +828,7 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
   auto input_activations =
       builder.Parameter(0, input_literal->shape(), "input");
-  Computation add = CreateScalarAddComputation(F32, &builder);
+  XlaComputation add = CreateScalarAddComputation(F32, &builder);
   auto sum = builder.Reduce(input_activations, builder.ConstantR0(0.0f), add,
                             GetParam().reduce_dims);
 
@@ -870,8 +868,8 @@ INSTANTIATE_TEST_CASE_P(
 // IrEmitterUnnested::EmitInitializer() for the Reduce operator. Failed on
 // 2017-07-26.
 XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) {
-  ComputationBuilder builder(client_, TestName());
-  Computation max_f32 = CreateScalarMaxComputation(F32, &builder);
+  XlaBuilder builder(TestName());
+  XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder);
 
   auto a = builder.ConstantR0(2.0f);
   auto a2 = builder.Abs(a);
@@ -898,8 +896,8 @@ class ReduceInitializerTest : public ReduceTest {
  protected:
   template
   void DoTest(T initializer, int num_elems) {
-    ComputationBuilder builder(client_, TestName());
-    Computation max_fn = CreateScalarMaxComputation(
+    XlaBuilder builder(TestName());
+    XlaComputation max_fn = CreateScalarMaxComputation(
         primitive_util::NativeToPrimitiveType(), &builder);
 
     auto init = builder.ConstantR0(initializer);
@@ -934,5 +932,36 @@ XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) {
   DoTest(1234556789123, 1024);
 }
 
+// Test the operational semantic that the init value is passed on the lhs for
+// reduces. Can be tested by performing an "identity" reduce (that simply
+// returns one of the parameters). In this case, we return the rhs, which for
+// a 1D array with one element, should not be the init value.
+XLA_TEST_F(ReduceTest, ReduceIdentity) {
+  XlaBuilder builder(TestName());
+  Shape single_float = ShapeUtil::MakeShape(F32, {});
+  builder.Parameter(0, single_float, "lhs-unused");
+  builder.Parameter(1, single_float, "rhs-used");
+  auto computation_status = builder.Build();
+  TF_ASSERT_OK(computation_status.status());
+
+  Shape operand_shape = ShapeUtil::MakeShape(F32, {1});
+  builder.Reduce(builder.Parameter(0, operand_shape, "operand"),
+                 builder.Parameter(1, single_float, "init"),
+                 computation_status.ValueOrDie(), {0});
+
+  float operand[] = {42.0f};
+  float init = 58.5f;
+  float expected = 42.0f;
+  std::unique_ptr input_literal = Literal::CreateR1(operand);
+  std::unique_ptr input_global_data =
+      client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+  std::unique_ptr input_literal2 = Literal::CreateR0(init);
+  std::unique_ptr input_global_data2 =
+      client_->TransferToServer(*input_literal2).ConsumeValueOrDie();
+  ComputeAndCompareR0(
+      &builder, expected, {input_global_data.get(), input_global_data2.get()},
+      ErrorSpec(0.0001));
+}
+
 } // namespace
 } // namespace xla
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 6a054a5dd39..266760e8202 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -356,12 +356,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
   std::vector input_dims(6, 8);
   auto shape = ShapeUtil::MakeShape(F32, input_dims);
 
-  std::unique_ptr arg_literal = Literal::CreateFromShape(shape);
-  auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> float {
-    return 1.0f;
-  };
-  TF_EXPECT_OK(arg_literal->Populate(generator));
-
+  auto arg_literal = MakeUnique(shape);
+  arg_literal->PopulateWithValue(1.0f);
   const auto input = CreateConstantFromLiteral(*arg_literal, &builder_);
 
   Padding padding = Padding::kValid;
@@ -371,13 +367,8 @@ XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) {
   std::vector output_dims = {6, 8, 6, 6, 8, 8};
   Shape result_shape =
       ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout);
-  std::unique_ptr expected = Literal::CreateFromShape(result_shape);
-  auto out_generator =
-      [&](tensorflow::gtl::ArraySlice indexes) -> float {
-    return 27.0f;
-  };
-  TF_EXPECT_OK(expected->Populate(out_generator));
-
+  auto expected = MakeUnique(result_shape);
+  expected->PopulateWithValue(27.0f);
   ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec());
 }
 
@@ -861,8 +852,7 @@ INSTANTIATE_TEST_CASE_P(
 class R4ReduceWindowAnyDimsTest : public R4ReduceWindowTest {};
 
 // TODO(b/72234705): Fix the test cases failed on CPU and GPU.
-XLA_TEST_P(R4ReduceWindowAnyDimsTest,
-           DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
+XLA_TEST_P(R4ReduceWindowAnyDimsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt))) {
   DoIt();
 }
 
@@ -1151,7 +1141,7 @@ class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {};
 
 // TODO(b/72234705): Fix the test cases failed on CPU and GPU.
 XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test,
-           DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
+           DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt))) {
   DoIt();
 }
 
@@ -1349,7 +1339,7 @@ INSTANTIATE_TEST_CASE_P(
 class ReduceWindowTextTest : public HloTestBase {};
 
 TEST_F(ReduceWindowTextTest, R2General256x384) {
-  const string& hlo_string = R"(
+  const string hlo_string = R"(
 HloModule R2Window
 mul {
   lhs = f32[] parameter(0)
@@ -1366,7 +1356,7 @@ ENTRY R2Window {
 }
 
 TEST_F(ReduceWindowTextTest, R2General256x384Layout01) {
-  const string& hlo_string = R"(
+  const string hlo_string = R"(
 HloModule R2Window
 mul {
   lhs = f32[] parameter(0)
@@ -1383,7 +1373,7 @@ ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window=
 }
 
 TEST_F(ReduceWindowTextTest, R2General2x5) {
-  const string& hlo_string = R"(
+  const string hlo_string = R"(
 HloModule R2Window
 mul {
   lhs = f32[] parameter(0)
@@ -1400,7 +1390,7 @@ ENTRY R2Window {
 }
 
 TEST_F(ReduceWindowTextTest, R2EffectiveScalar) {
-  const string& hlo_string = R"(
+  const string hlo_string = R"(
 HloModule R2Window
 mul {
   lhs = f32[] parameter(0)
@@ -1418,7 +1408,7 @@ ENTRY R2Window {
 }
 
 TEST_F(ReduceWindowTextTest, R3EffectiveScalar) {
-  const string& hlo_string = R"(
+  const string hlo_string = R"(
 HloModule R3Window
 mul {
   lhs = f32[] parameter(0)
@@ -1435,5 +1425,41 @@ ENTRY R3Window {
   EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001}));
 }
 
+TEST_F(HloTestBase, ReduceWindowIdentity) {
+  const string hlo_string = R"(
+HloModule ReduceWindowIdentity
+identity.pad_to_reduce_window {
+  param0 = f32[] parameter(0)
+  ROOT param1 = f32[] parameter(1)
+}
+ENTRY reduce-window-identity {
+  operand = f32[1,32,64]{2,1,0} parameter(0)
+  constant.4466 = f32[] constant(0)
+  ROOT reduce-window = f32[1,33,64]{2,1,0} reduce-window(operand, constant.4466), window={size=1x1x1 pad=0_0x1_0x0_0}, to_apply=identity.pad_to_reduce_window
+}
+
+)";
+  EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt));
+}
+
+TEST_F(HloTestBase, ReduceWindowS32) {
+  const string hlo_string = R"(
+HloModule reduce-window
+
+%identity.pad_to_reduce_window (param0: s32[], param1: s32[]) -> s32[] {
+  %param0 = s32[] parameter(0)
+  ROOT %param1 = s32[] parameter(1)
+}
+
+ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1:
s32[]) -> s32[82,8] { + %parameter.0 = s32[81,8]{1,0} parameter(0) + %parameter.1 = s32[] parameter(1) + ROOT %reduce-window = s32[82,8]{1,0} reduce-window(s32[81,8]{1,0} %parameter.0, s32[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 6d063ffc363..36d763b0f7f 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -15,13 +15,13 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -38,17 +38,17 @@ class ReplayTest : public ClientLibraryTestBase {}; TEST_F(ReplayTest, TwoPlusTwoReplay) { // Make 2+2 computation. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto two = builder.ConstantR0(2); builder.Add(two, two); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); // Serialize it out. - std::unique_ptr module = + std::unique_ptr module = computation.Snapshot().ConsumeValueOrDie(); // Replay it. - Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); + XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); // Check signature is the same. std::unique_ptr original_shape = @@ -69,18 +69,18 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Make computation. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y"); builder.Add(x, y); - Computation computation = builder.Build().ConsumeValueOrDie(); + XlaComputation computation = builder.Build().ConsumeValueOrDie(); // Serialize it out. - std::unique_ptr module = + std::unique_ptr module = computation.Snapshot().ConsumeValueOrDie(); // Replay it. - Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); + XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); // Check signature is the same. std::unique_ptr original_shape = @@ -109,24 +109,24 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { TEST_F(ReplayTest, MapPlusTwoOverR1) { // As above, but with map(+2) over some constant array. 
- ComputationBuilder plus_two_builder(client_, "plus two"); + XlaBuilder plus_two_builder("plus two"); auto input = plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input"); plus_two_builder.Add(input, plus_two_builder.ConstantR0(2)); - Computation plus_two = plus_two_builder.Build().ConsumeValueOrDie(); + XlaComputation plus_two = plus_two_builder.Build().ConsumeValueOrDie(); - ComputationBuilder mapper_builder(client_, TestName()); + XlaBuilder mapper_builder(TestName()); auto original = mapper_builder.ConstantR1({1, 2, 3}); mapper_builder.Map({original}, plus_two, {0}); - Computation computation = mapper_builder.Build().ConsumeValueOrDie(); + XlaComputation computation = mapper_builder.Build().ConsumeValueOrDie(); // Serialize it out. - std::unique_ptr module = + std::unique_ptr module = computation.Snapshot().ConsumeValueOrDie(); // Replay it. - Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); + XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie(); // Check signature is the same. std::unique_ptr original_shape = @@ -135,10 +135,6 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { client_->GetComputationShape(replayed).ConsumeValueOrDie(); ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); - // Destroy the originals. - computation.Reset(); - plus_two.Reset(); - // Run it. std::unique_ptr literal = client_ diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index e045e164e2e..da1b588ec41 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -20,10 +20,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -34,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -45,7 +43,7 @@ namespace { using ReshapeMotionTest = ClientLibraryTestBase; TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.ConstantR2({{2, 3, 5}, {7, 11, 13}}); auto b = builder.ConstantR2({{17, 19}, {23, 29}, {31, 37}}); auto c = builder.Reshape(a, {6}); diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index d7462d581b8..a4580cd71d4 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -656,9 +656,9 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { std::unique_ptr expected = Literal::CreateR2FromArray2D(expected_array); if (use_bfloat16()) { - expected = LiteralTestUtil::ConvertF32ToBF16(*expected); + expected = Literal::ConvertF32ToBF16(*expected); } - LiteralTestUtil::ExpectEqual(*expected, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual)); } XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { @@ -731,7 +731,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = - LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); + Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -753,7 +753,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = - LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); + Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, zero_error_spec_); } @@ -817,7 +817,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. 
if (use_bfloat16()) { - auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal); + auto expected = Literal::ConvertF32ToBF16(*input_literal); EXPECT_EQ(expected->data(), output_literal->data()); } else { EXPECT_EQ(input_literal->data(), output_literal->data()); @@ -886,7 +886,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -915,7 +915,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -944,7 +944,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -974,7 +974,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal) ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape @@ -1003,7 +1003,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { /*new_sizes=*/new_bounds); std::unique_ptr expected = - LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) + Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal) ->Relayout(input_literal->shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 6959c95502c..e7bd142dc9d 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -114,7 +114,7 @@ class ReverseTest : public ClientLibraryTestBase {}; // Tests the reverse operation on a 4D U8 array on dimension 0 and 3. XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Input shape is U8[1x2x3x4]. // clang-format off Array4D input({{ @@ -144,7 +144,7 @@ XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { // Tests the reverse operation on a 4D float array on dimension 0 and 1. TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) { - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); // Input shape is float[4x3x2x1]. 
// clang-format off Array4D input({ diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 8cbfcc6f5c4..7cfca781acd 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -100,7 +100,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); - LiteralTestUtil::ExpectEqual(*round_tripped, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { @@ -135,7 +135,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); - LiteralTestUtil::ExpectEqual(*round_tripped, *actual); + EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index 32db45f8a66..f334a8c1318 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -41,7 +41,7 @@ class RoundTripTransferTest : public ClientLibraryTestBase { client_->TransferToServer(original).ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data).ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqual(original, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(original, *result)); } }; diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 0c88bef69df..308d3fc78a5 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -17,9 +17,10 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -43,83 +44,80 @@ class ScalarComputationsTest : public ClientLibraryTestBase { protected: // A template for building and running a binary comparison test. 
template - void TestCompare(NativeT lhs, NativeT rhs, bool expected, - ComputationDataHandle (ComputationBuilder::*op)( - const ComputationDataHandle&, - const ComputationDataHandle&, - tensorflow::gtl::ArraySlice)) { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle lhs_op = builder.ConstantR0(lhs); - ComputationDataHandle rhs_op = builder.ConstantR0(rhs); - ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {}); + void TestCompare( + NativeT lhs, NativeT rhs, bool expected, + XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, + tensorflow::gtl::ArraySlice)) { + XlaBuilder builder(TestName()); + XlaOp lhs_op = builder.ConstantR0(lhs); + XlaOp rhs_op = builder.ConstantR0(rhs); + XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } template void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected, - ComputationDataHandle (ComputationBuilder::*op)( - const ComputationDataHandle&, - const ComputationDataHandle&, - tensorflow::gtl::ArraySlice)) { - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle lhs_op = builder.ConstantR0(lhs); - ComputationDataHandle rhs_op = builder.ConstantR0(rhs); - ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {}); + XlaOp (XlaBuilder::*op)(const XlaOp&, const XlaOp&, + tensorflow::gtl::ArraySlice)) { + XlaBuilder builder(TestName()); + XlaOp lhs_op = builder.ConstantR0(lhs); + XlaOp rhs_op = builder.ConstantR0(rhs); + XlaOp result = (builder.*op)(lhs_op, rhs_op, {}); ComputeAndCompareR0(&builder, expected, {}); } }; XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.ConstantR0(2.1f); ComputeAndCompareR0(&builder, 2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Neg(builder.ConstantR0(2.1f)); ComputeAndCompareR0(&builder, -2.1f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, NegateScalarS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Neg(builder.ConstantR0(2)); ComputeAndCompareR0(&builder, -2, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); ComputeAndCompareR0(&builder, 7.6f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(2), builder.ConstantR0(5)); ComputeAndCompareR0(&builder, 7, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); ComputeAndCompareR0(&builder, 92, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(35), builder.ConstantR0(57)); ComputeAndCompareR0(&builder, 92, {}); } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const uint64 a = static_cast(1) << 63; const uint64 b = a + 1; builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); @@ -128,7 +126,7 @@ XLA_TEST_F(ScalarComputationsTest, 
AddTwoScalarsU64) { } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const int64 a = static_cast(1) << 62; const int64 b = a - 1; builder.Add(builder.ConstantR0(a), builder.ConstantR0(b)); @@ -137,7 +135,7 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) { } XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Add(builder.ConstantR0(0.25), builder.ConstantR0(3.5)); @@ -145,21 +143,21 @@ XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) { } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Sub(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)); ComputeAndCompareR0(&builder, -3.4f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Sub(builder.ConstantR0(2), builder.ConstantR0(5)); ComputeAndCompareR0(&builder, -3, {}); } XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a"); builder.ConvertElementType(a, F32); @@ -172,7 +170,7 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { } XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul(builder.Mul(builder.ConstantR0(2.1f), builder.ConstantR0(5.5f)), builder.ConstantR0(0.5f)); @@ -191,7 +189,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsS32) { for (int32 x : data) { for (int32 y : data) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); // Signed integer overflow is undefined behavior in C++. 
Convert the input @@ -210,7 +208,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { for (uint32 x : data) { for (uint32 y : data) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul(builder.ConstantR0(x), builder.ConstantR0(y)); uint32 expected = x * y; @@ -220,7 +218,7 @@ XLA_TEST_F(ScalarComputationsTest, MulTwoScalarsU32) { } XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Mul( builder.Mul(builder.ConstantR0(2), builder.ConstantR0(5)), builder.ConstantR0(1)); @@ -229,7 +227,7 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { } XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::unique_ptr a_literal = Literal::CreateR0(2.1f); std::unique_ptr b_literal = Literal::CreateR0(5.5f); std::unique_ptr c_literal = Literal::CreateR0(0.5f); @@ -241,9 +239,9 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { std::unique_ptr c_data = client_->TransferToServer(*c_literal).ConsumeValueOrDie(); - ComputationDataHandle a = builder.Parameter(0, a_literal->shape(), "a"); - ComputationDataHandle b = builder.Parameter(1, b_literal->shape(), "b"); - ComputationDataHandle c = builder.Parameter(2, c_literal->shape(), "c"); + XlaOp a = builder.Parameter(0, a_literal->shape(), "a"); + XlaOp b = builder.Parameter(1, b_literal->shape(), "b"); + XlaOp c = builder.Parameter(2, c_literal->shape(), "c"); builder.Mul(builder.Mul(a, b), c); ComputeAndCompareR0(&builder, 5.775f, @@ -252,14 +250,14 @@ XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { } XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Div(builder.ConstantR0(5.0f), builder.ConstantR0(2.5f)); ComputeAndCompareR0(&builder, 2.0f, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Rem(builder.ConstantR0(2.5f), builder.ConstantR0(5.0f)); ComputeAndCompareR0(&builder, 2.5f, {}, error_spec_); @@ -282,7 +280,7 @@ class DivS32Test : public ClientLibraryTestBase, XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Div(builder.ConstantR0(p.dividend), builder.ConstantR0(p.divisor)); @@ -291,7 +289,7 @@ XLA_TEST_P(DivS32Test, DivideTwoScalarsS32) { XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Rem(builder.ConstantR0(p.dividend), builder.ConstantR0(p.divisor)); @@ -300,9 +298,9 @@ XLA_TEST_P(DivS32Test, RemainderTwoScalarsS32) { XLA_TEST_P(DivS32Test, DivideTwoScalarsNonConstS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividendd = CreateR0Parameter(p.dividend, 0, "dividend", &builder, &dividend); auto divisord = @@ -315,9 +313,9 @@ XLA_TEST_P(DivS32Test, RemainderTwoScalarsNonConstDivisorS32) { DivS32Params p = GetParam(); - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle
dividend; - ComputationDataHandle divisor; + XlaBuilder builder(TestName()); + XlaOp dividend; + XlaOp divisor; auto dividendd = CreateR0Parameter(p.dividend, 0, "dividend", &builder, &dividend); auto divisord = @@ -364,13 +362,13 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX}; // clang-format on - Computation div_computation; + XlaComputation div_computation; { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle dividend = + XlaOp dividend = builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); - ComputationDataHandle divisor = + XlaOp divisor = builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); builder.Div(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(div_computation, builder.Build()); @@ -392,7 +390,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) { &execution_options_) .ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(dividend / divisor); - LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } } @@ -405,13 +403,13 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { 0, 1, 2, 17, 101, 3333, 0x7FFFFFFF, 0x80000000, UINT32_MAX - 1, UINT32_MAX}; // clang-format on - Computation rem_computation; + XlaComputation rem_computation; { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle dividend = + XlaOp dividend = builder.Parameter(0, ShapeUtil::MakeShape(U32, {}), "dividend"); - ComputationDataHandle divisor = + XlaOp divisor = builder.Parameter(1, ShapeUtil::MakeShape(U32, {}), "divisor"); builder.Rem(dividend, divisor); TF_ASSERT_OK_AND_ASSIGN(rem_computation, builder.Build()); @@ -433,14 +431,14 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) { &execution_options_) .ConsumeValueOrDie(); auto expected_literal = Literal::CreateR0(dividend % divisor); - LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } } } } XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); builder.Rem(x, builder.ConstantR0(80000)); @@ -450,7 +448,7 @@ XLA_TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { } XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // This verifies 0xFFFFFFFE / 2 = 0x7FFFFFFF. If XLA incorrectly treated U32 // as S32, it would output -2 / 2 = -1 (0xFFFFFFFF).
builder.Div(builder.ConstantR0(0xFFFFFFFE), @@ -460,7 +458,7 @@ XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) { } XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Rem(builder.ConstantR0(11), builder.ConstantR0(3)); ComputeAndCompareR0(&builder, 2, {}); @@ -469,7 +467,7 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { XLA_TEST_F(ScalarComputationsTest, AndBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x && y, {}); @@ -480,7 +478,7 @@ XLA_TEST_F(ScalarComputationsTest, AndBool) { XLA_TEST_F(ScalarComputationsTest, AndS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x & y, {}); @@ -491,7 +489,7 @@ XLA_TEST_F(ScalarComputationsTest, AndS32) { XLA_TEST_F(ScalarComputationsTest, AndU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x & y, {}); @@ -502,7 +500,7 @@ XLA_TEST_F(ScalarComputationsTest, AndU32) { XLA_TEST_F(ScalarComputationsTest, OrBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x || y, {}); @@ -513,7 +511,7 @@ XLA_TEST_F(ScalarComputationsTest, OrBool) { XLA_TEST_F(ScalarComputationsTest, OrS32) { for (int32 x : {0, 8}) { for (int32 y : {1, -16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x | y, {}); @@ -524,7 +522,7 @@ XLA_TEST_F(ScalarComputationsTest, OrS32) { XLA_TEST_F(ScalarComputationsTest, OrU32) { for (uint32 x : {0, 8}) { for (uint32 y : {1, 16}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x | y, {}); @@ -534,7 +532,7 @@ XLA_TEST_F(ScalarComputationsTest, OrU32) { XLA_TEST_F(ScalarComputationsTest, NotBool) { for (bool x : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, !x, {}); @@ -543,7 +541,7 @@ XLA_TEST_F(ScalarComputationsTest, NotBool) { XLA_TEST_F(ScalarComputationsTest, NotS32) { for (int32 x : {-1, 0, 1}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, ~x, {}); @@ -552,7 +550,7 @@ XLA_TEST_F(ScalarComputationsTest, NotS32) { XLA_TEST_F(ScalarComputationsTest, NotU32) { for (uint32 x : {0, 1, 2}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, ~x, {}); @@ -560,7 +558,7 @@ XLA_TEST_F(ScalarComputationsTest, NotU32) { } XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { - ComputationBuilder builder(client_, TestName()); + 
XlaBuilder builder(TestName()); builder.Select(builder.ConstantR0(true), // The predicate. builder.ConstantR0(123.0f), // The value on true. builder.ConstantR0(42.0f)); // The value on false. @@ -569,7 +567,7 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { } XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Select(builder.ConstantR0(false), // The predicate. builder.ConstantR0(123.0f), // The value on true. builder.ConstantR0(42.0f)); // The value on false. @@ -580,7 +578,7 @@ XLA_TEST_F(ScalarComputationsTest, SelectScalarFalse) { // This test is an explicit version of what is happening in the following // templatized comparison tests. XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Gt(builder.ConstantR0(2.0f), builder.ConstantR0(1.0f)); ComputeAndCompareR0(&builder, true, {}); @@ -588,157 +586,156 @@ XLA_TEST_F(ScalarComputationsTest, CompareGtScalar) { // S32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqS32Greater) { - TestCompare(2, 1, false, &ComputationBuilder::Eq); + TestCompare(2, 1, false, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareEqS32Equal) { - TestCompare(3, 3, true, &ComputationBuilder::Eq); + TestCompare(3, 3, true, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeS32) { - TestCompare(2, 1, true, &ComputationBuilder::Ne); + TestCompare(2, 1, true, &XlaBuilder::Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeS32) { - TestCompare(2, 1, true, &ComputationBuilder::Ge); + TestCompare(2, 1, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtS32) { - TestCompare(1, 5, false, &ComputationBuilder::Gt); + TestCompare(1, 5, false, &XlaBuilder::Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeS32) { - TestCompare(2, 1, false, &ComputationBuilder::Le); + TestCompare(2, 1, false, &XlaBuilder::Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtS32) { - TestCompare(9, 7, false, &ComputationBuilder::Lt); + TestCompare(9, 7, false, &XlaBuilder::Lt); TestCompare(std::numeric_limits::min(), - std::numeric_limits::max(), true, - &ComputationBuilder::Lt); + std::numeric_limits::max(), true, &XlaBuilder::Lt); } // U32 comparisons. 
XLA_TEST_F(ScalarComputationsTest, CompareEqU32False) { - TestCompare(2, 1, false, &ComputationBuilder::Eq); + TestCompare(2, 1, false, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeU32) { - TestCompare(2, 1, true, &ComputationBuilder::Ne); + TestCompare(2, 1, true, &XlaBuilder::Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Greater) { - TestCompare(2, 1, true, &ComputationBuilder::Ge); + TestCompare(2, 1, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeU32Equal) { - TestCompare(3, 3, true, &ComputationBuilder::Ge); + TestCompare(3, 3, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtU32) { - TestCompare(1, 5, false, &ComputationBuilder::Gt); - TestCompare(5, 5, false, &ComputationBuilder::Gt); - TestCompare(5, 1, true, &ComputationBuilder::Gt); + TestCompare(1, 5, false, &XlaBuilder::Gt); + TestCompare(5, 5, false, &XlaBuilder::Gt); + TestCompare(5, 1, true, &XlaBuilder::Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeU32) { - TestCompare(2, 1, false, &ComputationBuilder::Le); + TestCompare(2, 1, false, &XlaBuilder::Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtU32) { - TestCompare(9, 7, false, &ComputationBuilder::Lt); + TestCompare(9, 7, false, &XlaBuilder::Lt); TestCompare(0, std::numeric_limits::max(), true, - &ComputationBuilder::Lt); + &XlaBuilder::Lt); } // F32 comparisons. XLA_TEST_F(ScalarComputationsTest, CompareEqF32False) { - TestCompare(2.0, 1.3, false, &ComputationBuilder::Eq); + TestCompare(2.0, 1.3, false, &XlaBuilder::Eq); } XLA_TEST_F(ScalarComputationsTest, CompareNeF32) { - TestCompare(2.0, 1.3, true, &ComputationBuilder::Ne); + TestCompare(2.0, 1.3, true, &XlaBuilder::Ne); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Greater) { - TestCompare(2.0, 1.9, true, &ComputationBuilder::Ge); + TestCompare(2.0, 1.9, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32Equal) { - TestCompare(3.5, 3.5, true, &ComputationBuilder::Ge); + TestCompare(3.5, 3.5, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGtF32) { - TestCompare(1.0, 5.2, false, &ComputationBuilder::Gt); + TestCompare(1.0, 5.2, false, &XlaBuilder::Gt); } XLA_TEST_F(ScalarComputationsTest, CompareLeF32) { - TestCompare(2.0, 1.2, false, &ComputationBuilder::Le); + TestCompare(2.0, 1.2, false, &XlaBuilder::Le); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32) { - TestCompare(9.0, 7.2, false, &ComputationBuilder::Lt); + TestCompare(9.0, 7.2, false, &XlaBuilder::Lt); } // F32 comparisons with exceptional values. The test names encode the // left/right operands at the end, and use Minf and Mzero for -inf and -0.0. XLA_TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) { - TestCompare(-INFINITY, -0.0, true, &ComputationBuilder::Lt); + TestCompare(-INFINITY, -0.0, true, &XlaBuilder::Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. - TestCompare(-0.0, 0.0, false, &ComputationBuilder::Lt); + TestCompare(-0.0, 0.0, false, &XlaBuilder::Lt); } XLA_TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) { - TestCompare(0.0, INFINITY, true, &ComputationBuilder::Lt); + TestCompare(0.0, INFINITY, true, &XlaBuilder::Lt); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) { - TestCompare(-INFINITY, -0.0, false, &ComputationBuilder::Ge); + TestCompare(-INFINITY, -0.0, false, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) { // Comparisons of 0.0 to -0.0 consider them equal in IEEE 754. 
- TestCompare(-0.0, 0.0, true, &ComputationBuilder::Ge); + TestCompare(-0.0, 0.0, true, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) { - TestCompare(0.0, INFINITY, false, &ComputationBuilder::Ge); + TestCompare(0.0, INFINITY, false, &XlaBuilder::Ge); } XLA_TEST_F(ScalarComputationsTest, ExpScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Exp(builder.ConstantR0(2.0f)); ComputeAndCompareR0(&builder, 7.3890562, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, LogScalar) { - ComputationBuilder builder(client_, "log"); + XlaBuilder builder("log"); builder.Log(builder.ConstantR0(2.0f)); ComputeAndCompareR0(&builder, 0.6931471, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Tanh(builder.ConstantR0(2.0f)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Tanh(builder.ConstantR0(2.0)); ComputeAndCompareR0(&builder, 0.96402758, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, PowScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Pow(builder.ConstantR0(2.0f), builder.ConstantR0(3.0f)); ComputeAndCompareR0(&builder, 8.0, {}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(-1), // The lower bound. builder.ConstantR0(5), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -747,7 +744,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(-1), // The lower bound. builder.ConstantR0(2), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -756,7 +753,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(-1), // The lower bound. builder.ConstantR0(-5), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -765,7 +762,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(1), // The lower bound. builder.ConstantR0(5), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -774,7 +771,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(1), // The lower bound. builder.ConstantR0(2), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -783,7 +780,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(1), // The lower bound. 
builder.ConstantR0(0), // The operand to be clamped. builder.ConstantR0(3)); // The upper bound. @@ -792,7 +789,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(5.0f), // The operand to be clamped. builder.ConstantR0(3.0f)); // The upper bound. @@ -801,7 +798,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(2.5f), // The operand to be clamped. builder.ConstantR0(3.0f)); // The upper bound. @@ -810,7 +807,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { } XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(-5.0f), // The operand to be clamped. builder.ConstantR0(3.0f)); // The upper bound. @@ -819,70 +816,70 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { } XLA_TEST_F(ScalarComputationsTest, MinS32Above) { - TestMinMax(10, 3, 3, &ComputationBuilder::Min); + TestMinMax(10, 3, 3, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MinS32Below) { - TestMinMax(-100, 3, -100, &ComputationBuilder::Min); + TestMinMax(-100, 3, -100, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MaxS32Above) { - TestMinMax(10, 3, 10, &ComputationBuilder::Max); + TestMinMax(10, 3, 10, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MaxS32Below) { - TestMinMax(-100, 3, 3, &ComputationBuilder::Max); + TestMinMax(-100, 3, 3, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MinU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, 3, &ComputationBuilder::Min); + TestMinMax(large, 3, 3, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MinU32Below) { - TestMinMax(0, 5, 0, &ComputationBuilder::Min); + TestMinMax(0, 5, 0, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MaxU32Above) { const uint32 large = std::numeric_limits::max(); - TestMinMax(large, 3, large, &ComputationBuilder::Max); + TestMinMax(large, 3, large, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MaxU32Below) { - TestMinMax(0, 5, 5, &ComputationBuilder::Max); + TestMinMax(0, 5, 5, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MinF32Above) { - TestMinMax(10.1f, 3.1f, 3.1f, &ComputationBuilder::Min); + TestMinMax(10.1f, 3.1f, 3.1f, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MinF32Below) { - TestMinMax(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min); + TestMinMax(-100.1f, 3.1f, -100.1f, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MinPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &ComputationBuilder::Min); - TestMinMax(-3.1f, NAN, NAN, &ComputationBuilder::Min); + TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Min); + TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Min); } XLA_TEST_F(ScalarComputationsTest, MaxF32Above) { - TestMinMax(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max); + TestMinMax(10.1f, 3.1f, 10.1f, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MaxF32Below) { - TestMinMax(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max); + TestMinMax(-100.1f, 3.1f, 3.1f, 
&XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, MaxPropagatesNan) { SetFastMathDisabled(true); - TestMinMax(NAN, 3.1f, NAN, &ComputationBuilder::Max); - TestMinMax(-3.1f, NAN, NAN, &ComputationBuilder::Max); + TestMinMax(NAN, 3.1f, NAN, &XlaBuilder::Max); + TestMinMax(-3.1f, NAN, NAN, &XlaBuilder::Max); } XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Div( b.Sub(b.Mul(b.ConstantR0(1), b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), @@ -895,7 +892,7 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) { XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { // Compute the expression 1 * (3 - 1) * (7 + 0) - 4. - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); b.Sub(b.Mul(b.ConstantR0(1), b.Mul(b.Sub(b.ConstantR0(3), b.ConstantR0(1)), b.Add(b.ConstantR0(7), b.ConstantR0(0)))), @@ -905,21 +902,20 @@ XLA_TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { } XLA_TEST_F(ScalarComputationsTest, SqrtF320) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Literal zero_literal = Literal::Zero(PrimitiveType::F32); std::unique_ptr zero_data = client_->TransferToServer(zero_literal).ConsumeValueOrDie(); - ComputationDataHandle zero = - builder.Parameter(0, zero_literal.shape(), "zero"); + XlaOp zero = builder.Parameter(0, zero_literal.shape(), "zero"); builder.SqrtF32(zero); ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); } XLA_TEST_F(ScalarComputationsTest, RoundScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); builder.Round(builder.ConstantR0(1.4f)); ComputeAndCompareR0(&builder, 1.0f, {}, error_spec_); diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 009e7d24c5c..72707f22444 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -16,13 +16,12 @@ limitations under the License. 
#include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -35,7 +34,7 @@ class SelectTest : public ClientLibraryTestBase { }; TEST_F(SelectTest, SelectScalarF32True) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto on_true = builder.ConstantR0(123.0f); auto on_false = builder.ConstantR0(42.0f); @@ -45,7 +44,7 @@ TEST_F(SelectTest, SelectScalarF32True) { } TEST_F(SelectTest, SelectScalarS32True) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto on_true = builder.ConstantR0(-42); auto on_false = builder.ConstantR0(42); @@ -55,7 +54,7 @@ TEST_F(SelectTest, SelectScalarS32True) { } TEST_F(SelectTest, SelectScalarF32False) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto on_true = builder.ConstantR0(123.0f); auto on_false = builder.ConstantR0(42.0f); @@ -65,7 +64,7 @@ TEST_F(SelectTest, SelectScalarF32False) { } XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR1({}); auto on_true = builder.ConstantR1({}); auto on_false = builder.ConstantR1({}); @@ -75,7 +74,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) { } TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR1({false, true, false, true, false}); auto on_true = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); auto on_false = builder.ConstantR1({10.0f, 5.0f, 1.0f, 10.0f, -6.0f}); @@ -88,7 +87,7 @@ TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) { XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector // is not a constant, but rather the result of comparing two other vectors. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v1 = builder.ConstantR1({}); auto v2 = builder.ConstantR1({}); auto cmp = builder.Eq(v1, v2); @@ -102,7 +101,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is // not a constant, but rather the result of comparing two other vectors. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v1 = builder.ConstantR1({1, 2, 3, 4, 5}); auto v2 = builder.ConstantR1({9, 2, 9, 4, 9}); auto cmp = builder.Eq(v1, v2); @@ -116,7 +115,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s. 
- ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v1 = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); auto v2 = builder.ConstantR1({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f}); auto cmp = builder.Gt(v1, v2); @@ -131,9 +130,9 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { // Selects among two R1F32s, which come from parameters. v1 and v2 are // compared, and selection between them happens based on a gt-comparison mask. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); - ComputationDataHandle v1, v2; + XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter( {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); @@ -151,7 +150,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { // Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the // data size passed in and out is large. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Number of floats in the data passed into and out of the computation. constexpr int datalen = 15 * 1000; @@ -174,7 +173,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { expected_vec.push_back(larger); } - ComputationDataHandle v1, v2; + XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter(v1vec, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); @@ -192,7 +191,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) { TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { // "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to // select between two R1F32s. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1, -1, 2, -2}); auto s = builder.ConstantR0(0); auto cmp = builder.Gt(v, s); @@ -209,7 +208,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) { TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { // "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to // select between two R1F32s. 
- ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto v = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f}); auto s = builder.ConstantR0(2.5f); auto cmp = builder.Gt(v, s); @@ -225,7 +224,7 @@ TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) { XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { for (bool which : {false, true}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(which); auto on_true = builder.ConstantR1({}); auto on_false = builder.ConstantR1({}); @@ -236,7 +235,7 @@ XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) { } TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(true); auto on_true = builder.ConstantR1({-2.5f, 25.5f}); auto on_false = builder.ConstantR1({10.0f, 5.0f}); @@ -246,7 +245,7 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) { } TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto pred = builder.ConstantR0(false); auto on_true = builder.ConstantR1({-2.5f, 25.5f}); auto on_false = builder.ConstantR1({10.0f, 5.0f}); diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc deleted file mode 100644 index 29f79ec28a1..00000000000 --- a/tensorflow/compiler/xla/tests/set_return_value_test.cc +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/tests/client_library_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/test.h" - -namespace xla { -namespace { - -class SetReturnValueTest : public ClientLibraryTestBase {}; - -TEST_F(SetReturnValueTest, NoSetValue) { - ComputationBuilder builder(client_, "no_set_value"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - - std::vector expected = {1.0, 3.0, 4.0, 0.0, -1.0, - 5.0, 6.0, -2.0, -3.0, 7.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValue) { - ComputationBuilder builder(client_, "set_value"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValueAndModify) { - ComputationBuilder builder(client_, "set_value_and_modify"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - auto aaax = builder.Add(alpha, aax); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) { - ComputationBuilder builder(client_, "set_value_multiple_times_and_modify"); - auto alpha = builder.ConstantR0(1.0); - auto x = builder.ConstantR1( - {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0}); - auto ax = builder.Add(alpha, x); - auto aax = builder.Add(alpha, ax); - auto builder_status = builder.SetReturnValue(aax); - EXPECT_TRUE(builder_status.ok()); - auto aaax = builder.Add(alpha, aax); - builder_status = builder.SetReturnValue(ax); - EXPECT_TRUE(builder_status.ok()); - auto aaaax = builder.Add(alpha, aaax); - - std::vector expected = {0.0, 2.0, 3.0, -1.0, -2.0, - 4.0, 5.0, -3.0, -4.0, 6.0}; - - ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index e2d406f66d9..7ca99a91635 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -34,7 +34,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #define DISABLED_ON_CPU(X) X -#define DISABLED_ON_CPU_PARALLEL(X) X #define DISABLED_ON_GPU(X) X #define DISABLED_ON_INTERPRETER(X) X @@ -51,13 +50,6 @@ limitations under the License. 
# define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X) #endif // XLA_TEST_BACKEND_CPU -#ifdef XLA_TEST_BACKEND_CPU_PARALLEL -# undef DISABLED_ON_CPU -# define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X) -# undef DISABLED_ON_CPU_PARALLEL -# define DISABLED_ON_CPU_PARALLEL(X) XLA_TEST_PASTE(DISABLED_, X) -#endif // XLA_TEST_BACKEND_CPU_PARALLEL - #ifdef XLA_TEST_BACKEND_GPU # undef DISABLED_ON_GPU # define DISABLED_ON_GPU(X) XLA_TEST_PASTE(DISABLED_, X) diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index cda1989fad6..de186513880 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -107,7 +107,7 @@ StatusOr> MakeFakeLiteralInternal( } return Literal::MakeTupleOwned(std::move(elements)); } - std::unique_ptr literal = Literal::CreateFromShape(shape); + auto literal = MakeUnique(shape); switch (shape.element_type()) { case BF16: PopulateWithRandomFloatingPointData(literal.get(), engine); @@ -339,8 +339,7 @@ StatusOr>> MakeFakeArguments( return std::move(arguments); } -Status VerifyHloModule(const perftools::gputools::Platform& platform, - HloModule* const module, bool allow_mixed_precision) { +Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision) { return HloVerifier(allow_mixed_precision).Run(module).status(); } diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index b5ab779574f..f483cdebea5 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -68,8 +68,7 @@ StatusOr>> MakeFakeArguments( // Check that a given module satisfies various constraints before trying to // execute it. -Status VerifyHloModule(const perftools::gputools::Platform& platform, - HloModule* const module, +Status VerifyHloModule(HloModule* const module, bool allow_mixed_precision = false); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index e8efc6e2a83..59afd28a80c 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -28,7 +28,7 @@ namespace { class TestUtilsTest : public LocalClientTestBase {}; XLA_TEST_F(TestUtilsTest, UnusedParam) { - ComputationBuilder builder(local_client_, TestName()); + XlaBuilder builder(TestName()); // Make the reduction lambda. 
Shape single_float = ShapeUtil::MakeShape(F32, {}); builder.Parameter(0, single_float, "unused"); diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index 268ba338f2e..0063e7ad415 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -45,7 +45,7 @@ class TransferManagerTest : public LocalClientTestBase { ~TransferManagerTest() override = default; - std::unique_ptr AllocateDeviceBuffer(const Shape& shape) { + ScopedShapedBuffer AllocateDeviceBuffer(const Shape& shape) { return transfer_manager_ ->AllocateScopedShapedBuffer( shape, GetOrCreateAllocator(local_client_->platform()), @@ -64,10 +64,10 @@ XLA_TEST_F(TransferManagerTest, TransferR0U32) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); LiteralTestUtil::ExpectR0Equal(42, *result); } @@ -80,10 +80,10 @@ XLA_TEST_F(TransferManagerTest, TransferR1F32) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, *result); @@ -98,10 +98,10 @@ XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); LiteralTestUtil::ExpectR1Equal(test_vector, *result); } @@ -114,10 +114,10 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); EXPECT_EQ(result->GetR1U8AsString(), test_string); } @@ -130,10 +130,10 @@ XLA_TEST_F(TransferManagerTest, TransferR2F32) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); LiteralTestUtil::ExpectR2Equal( {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); @@ -150,10 +150,10 @@ XLA_TEST_F(TransferManagerTest, // Round trip literal through device. Set the on-device layout to something // different than the literal layout. 
ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); EXPECT_FALSE( LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); @@ -170,12 +170,12 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { @@ -184,12 +184,12 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { @@ -204,12 +204,12 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValue) { @@ -219,12 +219,12 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) { // Round trip literal through device. ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { @@ -238,12 +238,12 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { // Round trip literal through device. 
ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, *literal, *device_buffer)); + stream_executor_, *literal, device_buffer)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, transfer_manager_->TransferLiteralFromDevice( - stream_executor_, *device_buffer)); + stream_executor_, device_buffer)); - LiteralTestUtil::ExpectEqual(*literal, *result); + EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index fe5a1778a2c..fe1e3da7eca 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -16,14 +16,13 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -38,7 +37,7 @@ class TransposeTest : public ClientLibraryTestBase { }; XLA_TEST_F(TransposeTest, Transpose0x0) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 0)); auto result = builder.Transpose(lhs, {1, 0}); @@ -46,7 +45,7 @@ XLA_TEST_F(TransposeTest, Transpose0x0) { } XLA_TEST_F(TransposeTest, Transpose0x42) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D(Array2D(0, 42)); auto result = builder.Transpose(lhs, {1, 0}); @@ -54,7 +53,7 @@ XLA_TEST_F(TransposeTest, Transpose0x42) { } XLA_TEST_F(TransposeTest, Transpose7x0) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2FromArray2D(Array2D(7, 0)); auto result = builder.Transpose(lhs, {1, 0}); @@ -62,7 +61,7 @@ XLA_TEST_F(TransposeTest, Transpose7x0) { } TEST_F(TransposeTest, Transpose2x2) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto lhs = builder.ConstantR2({ {1.0, 2.0}, {3.0, 4.0}, }); @@ -74,7 +73,7 @@ TEST_F(TransposeTest, Transpose2x2) { } XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D(Array3D(0, 2, 3)); auto result = builder.Transpose(operand, {1, 2, 0}); @@ -82,7 +81,7 @@ XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) { } TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); auto result = builder.Transpose(operand, {1, 2, 0}); @@ -92,7 +91,7 @@ TEST_F(TransposeTest, Transpose1x2x3_2x3x1) { } TEST_F(TransposeTest, Transpose1x2x3_3x2x1) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); auto result = builder.Transpose(operand, {2, 1, 0}); @@ -102,7 +101,7 @@ TEST_F(TransposeTest, 
Transpose1x2x3_3x2x1) { } TEST_F(TransposeTest, Transpose1x2x3_1x2x3) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto operand = builder.ConstantR3FromArray3D({{{1, 2, 3}, {4, 5, 6}}}); auto result = builder.Transpose(operand, {0, 1, 2}); @@ -116,7 +115,7 @@ TEST_F(TransposeTest, MultiTranspose3x2) { Array2D transposed({{1.0f, 3.0f, 5.0f}, {2.0f, 4.0f, 6.0f}}); for (int transposes = 0; transposes <= 10; ++transposes) { - ComputationBuilder builder(client_, "Transpose"); + XlaBuilder builder("Transpose"); auto computed = builder.ConstantR2FromArray2D(input); for (int i = 0; i < transposes; ++i) { computed = builder.Transpose(computed, {1, 0}); @@ -130,7 +129,7 @@ TEST_F(TransposeTest, MultiTranspose3x2) { TEST_F(TransposeTest, Small_1x1) { auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1); - ComputationBuilder builder(client_, "transpose_1x1"); + XlaBuilder builder("transpose_1x1"); auto operand = builder.ConstantR2FromArray2D(*aoperand); builder.Transpose(operand, {1, 0}); @@ -142,7 +141,7 @@ TEST_F(TransposeTest, Small_1x1) { TEST_F(TransposeTest, Small_2x2) { auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2); - ComputationBuilder builder(client_, "transpose_2x2"); + XlaBuilder builder("transpose_2x2"); auto operand = builder.ConstantR2FromArray2D(*aoperand); builder.Transpose(operand, {1, 0}); @@ -162,7 +161,7 @@ void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) { } } - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto operand = builder.ConstantR3FromArray3D(aoperand); builder.Transpose(operand, {0, 2, 1}); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 098be6d7aab..41189231b90 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -17,8 +17,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" @@ -269,7 +267,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { ComputeAndCompareR2(&builder, expected, {}, error_spec_); } -XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { +XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) { // Tests a selection between tuples with "false" path taken. XlaBuilder builder(TestName()); @@ -287,13 +285,13 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { } XLA_TEST_F(TupleTest, TuplesInAMap) { - Computation tuple_computation; + XlaComputation tuple_computation; { // tuple_computation(x) = 100 * min(x, x^2) + max(x, x^2) using tuples. // // Need to put a select in there to prevent HLO-level optimizations from // optimizing out the tuples. 
- ComputationBuilder b(client_, "sort_square"); + XlaBuilder b("sort_square"); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); auto x2 = b.Mul(x, x); auto x_smaller_tuple = b.Tuple({x, x2}); @@ -307,13 +305,13 @@ XLA_TEST_F(TupleTest, TuplesInAMap) { tuple_computation = computation_status.ConsumeValueOrDie(); } - ComputationBuilder b(client_, TestName()); + XlaBuilder b(TestName()); auto input = b.ConstantR1({-1.0f, 1.0f, 2.1f}); b.Map({input}, tuple_computation, {0}); ComputeAndCompareR1(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_); } -XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { +XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) { // Tests a selection between tuples with "true" path taken. XlaBuilder builder(TestName()); @@ -350,7 +348,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) { } // Cascaded selects between tuple types. -XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) { +XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) { // // vec1 vec2 vec2 vec1 // | | | | @@ -390,8 +388,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) { ComputeAndCompareR1(&builder, {3.f, 6.f, 9.f}, {}, error_spec_); } -XLA_TEST_F(TupleTest, - DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) { +XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) { // Similar to SelectBetweenTuples, but the constants are shared between the // input tuples. XlaBuilder builder(TestName()); @@ -498,7 +495,7 @@ XLA_TEST_F(TupleTest, ComplexTuples) { auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, {{1011, 2022}, {3031, 4042}}, {{10011, 20022}, {30031, 40042}}}); - auto prod = Literal::CreateFromShape(sum->shape()); + auto prod = MakeUnique(sum->shape()); ASSERT_TRUE(prod->Populate( [&sum](tensorflow::gtl::ArraySlice indexes) { return sum->Get(indexes) * @@ -516,10 +513,8 @@ XLA_TEST_F(TupleTest, ComplexTuples) { class TupleHloTest : public HloTestBase {}; -// Disabled on CPU parallel because that's broken and will be removed soon. // Disabled on the interpreter because bitcast doesn't exist on the interpreter. -TEST_F(TupleHloTest, - DISABLED_ON_INTERPRETER(DISABLED_ON_CPU_PARALLEL(BitcastAfterGTE))) { +XLA_TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) { const char* testcase = R"( HloModule m @@ -535,8 +530,7 @@ TEST_F(TupleHloTest, HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) .ValueOrDie(); auto param = Literal::MakeTupleOwned(Literal::CreateR1({1, 2, 3})); - TF_ASSERT_OK_AND_ASSIGN(auto result, - ExecuteNoHloPasses(std::move(module), {param.get()})); + auto result = ExecuteNoHloPasses(std::move(module), {param.get()}); EXPECT_TRUE(LiteralTestUtil::Equal( *result, *Literal::MakeTupleOwned(Literal::CreateR2({{1, 2, 3}})))); diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 835e2d7e559..c3abe22797f 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -16,9 +16,9 @@ limitations under the License. 
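Two smaller changes in the tuple_test.cc hunks above are worth calling out: helper computations are now XlaComputation values produced by XlaBuilder::Build(), and the BitcastAfterGTE test consumes the result of ExecuteNoHloPasses directly instead of unwrapping a StatusOr. A hedged sketch of the latter, assuming float element types where the diff rendering dropped the template arguments:

```cpp
// Run the parsed module without HLO passes and compare the tuple result.
auto module =
    HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest())
        .ValueOrDie();
auto param = Literal::MakeTupleOwned(Literal::CreateR1<float>({1, 2, 3}));
// ExecuteNoHloPasses now returns the literal directly, so no
// TF_ASSERT_OK_AND_ASSIGN wrapper is needed.
auto result = ExecuteNoHloPasses(std::move(module), {param.get()});
EXPECT_TRUE(LiteralTestUtil::Equal(
    *result, *Literal::MakeTupleOwned(Literal::CreateR2<float>({{1, 2, 3}}))));
```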
#include #include -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -37,7 +37,7 @@ class UnaryOpTest : public ClientLibraryTestBase { } template void AbsSize0TestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({}); auto abs = builder.Abs(arg); @@ -50,7 +50,7 @@ class UnaryOpTest : public ClientLibraryTestBase { template void AbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({-2, 25, 0, -123, inf(), -inf()}); auto abs = builder.Abs(arg); @@ -59,7 +59,7 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {-2, 25, 0, static_cast(-0.0), -123, inf(), -inf()}); auto sign = builder.Sign(arg); @@ -69,7 +69,7 @@ class UnaryOpTest : public ClientLibraryTestBase { template void SignAbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({-2, 25, 0, -123}); auto sign = builder.Sign(arg); auto abs = builder.Abs(arg); @@ -84,9 +84,14 @@ int UnaryOpTest::inf() { return 2147483647; } +template <> +int64 UnaryOpTest::inf() { + return 0x7FFFFFFFFFFFFFFFl; +} + template <> void UnaryOpTest::AbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({{-2, 0}, {0, 25}, {0, 0}, @@ -102,7 +107,7 @@ void UnaryOpTest::AbsTestHelper() { template <> void UnaryOpTest::SignTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); auto sign = builder.Sign(arg); @@ -114,7 +119,7 @@ void UnaryOpTest::SignTestHelper() { template <> void UnaryOpTest::SignAbsTestHelper() { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); auto sign = builder.Sign(arg); @@ -139,7 +144,7 @@ XLA_TEST_F(UnaryOpTest, AbsTestR1) { } XLA_TEST_F(UnaryOpTest, AbsTestR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto argi = builder.ConstantR0(-5); auto absi = builder.Abs(argi); auto argf = builder.ConstantR0(-3.0f); @@ -155,7 +160,7 @@ XLA_TEST_F(UnaryOpTest, AbsTestR0) { } XLA_TEST_F(UnaryOpTest, SignTestR0) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto argi = builder.ConstantR0(-5); auto sgni = builder.Sign(argi); // -1 auto argf = builder.ConstantR0(-4.0f); @@ -176,6 +181,7 @@ XLA_TEST_F(UnaryOpTest, SignTestR0) { XLA_TEST_F(UnaryOpTest, SignTestR1) { SignTestHelper(); + SignTestHelper(); SignTestHelper(); SignTestHelper(); } @@ -187,7 +193,7 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { } XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {2, 25, 0, 123, std::numeric_limits::max()}); 
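The new int64 coverage in unary_op_test.cc hinges on the inf() specialization added above: for integral types it returns the largest representable value so the Abs/Sign helpers have a sentinel to exercise. The sketch restores the <int64> specialization argument that the diff rendering dropped:

```cpp
// inf() for signed 64-bit ints: the maximum int64, used as the "infinite"
// sentinel in AbsTestHelper/SignTestHelper.
template <>
int64 UnaryOpTest::inf<int64>() {
  return 0x7FFFFFFFFFFFFFFFl;
}

// SignTestR1 then gains a 64-bit instantiation alongside the existing ones:
//   SignTestHelper<int64>();
```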
auto abs = builder.Abs(arg); @@ -197,7 +203,7 @@ XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { } XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR1( {2, 25, 0, 123, std::numeric_limits::max()}); auto sign = builder.Sign(arg); @@ -206,7 +212,7 @@ XLA_TEST_F(UnaryOpTest, UnsignedSignTestR1) { } XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto arg = builder.ConstantR2({{1.0, -2.0}, {-3.0, 4.0}}); auto sign = builder.Sign(arg); auto abs = builder.Abs(arg); @@ -216,7 +222,7 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 1}); auto rhs = builder.ConstantR1({1, 1}); builder.ConvertElementType(builder.Eq(lhs, rhs), S32); @@ -225,7 +231,7 @@ XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { } XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto lhs = builder.ConstantR1({0, 1}); auto rhs = builder.ConstantR1({1, 1}); builder.ConvertElementType(builder.Eq(lhs, rhs), F32); diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index 32ba067a10d..82d301983fc 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -33,9 +33,9 @@ namespace { class VecOpsReduceTest : public ClientLibraryTestBase { public: - VecOpsReduceTest() : builder_(client_, TestName()) {} + VecOpsReduceTest() : builder_(TestName()) {} - ComputationDataHandle BuildSampleConstantCube() { + XlaOp BuildSampleConstantCube() { // clang-format off Array3D x3d({ {{1.0, 2.0, 3.0}, // | dim 1 // } plane 0 in dim 0 @@ -49,7 +49,7 @@ class VecOpsReduceTest : public ClientLibraryTestBase { return builder_.ConstantR3FromArray3D(x3d); } - ComputationBuilder builder_; + XlaBuilder builder_; ErrorSpec errspec_{1e-3, 0}; }; diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index b52c718814d..5cce7a2bf82 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -18,11 +18,11 @@ limitations under the License. 
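The VecOpsReduceTest fixture above swaps both its member builder and its handle type; a condensed sketch of the migrated fixture follows. The Array3D fill constructor and the <float> arguments are assumptions on my part (the real test spells out the cube's contents, preserved in the hunk above):

```cpp
class VecOpsReduceTest : public ClientLibraryTestBase {
 public:
  // XlaBuilder takes only the test name; no Client* is threaded through.
  VecOpsReduceTest() : builder_(TestName()) {}

  // Op handles are XlaOp now instead of ComputationDataHandle.
  XlaOp BuildSampleConstantCube() {
    Array3D<float> x3d(2, 3, 3, 1.0f);  // placeholder contents
    return builder_.ConstantR3FromArray3D<float>(x3d);
  }

  XlaBuilder builder_;
  ErrorSpec errspec_{1e-3, 0};
};
```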
#include #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -39,7 +39,7 @@ namespace { class VecOpsSimpleTest : public ClientLibraryTestBase { public: - explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr) + explicit VecOpsSimpleTest(se::Platform* platform = nullptr) : ClientLibraryTestBase(platform) { mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); mutable_debug_options()->add_xla_disable_hlo_passes("inline"); @@ -49,7 +49,7 @@ class VecOpsSimpleTest : public ClientLibraryTestBase { }; XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto exp = builder.Exp(x); @@ -63,7 +63,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) { XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) { for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector exponents; exponents.reserve(count); for (int i = 0; i < count; ++i) { @@ -84,7 +84,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) { } XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D exponents(2, 2, 2, 2); std::vector exponents_vector; @@ -106,7 +106,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) { } XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); builder.Neg(x); @@ -117,7 +117,7 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) { } XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({2, -2, 12, -4, 5, 20, -15, 0, -2, 1}); builder.Neg(x); @@ -126,7 +126,7 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) { } XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {0, 1, 42, static_cast(-1), static_cast(-12)}); builder.Neg(x); @@ -136,7 +136,7 @@ XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) { } XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); builder.SquareF32(x); @@ -147,7 +147,7 @@ XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) { } XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); builder.ReciprocalF32(x); @@ -159,7 +159,7 @@ XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) { } 
XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({0.0, -0.0}); auto exp = builder.SqrtF32(x); @@ -167,7 +167,7 @@ XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) { } XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345}); auto exp = builder.SqrtF32(x); @@ -176,7 +176,7 @@ XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) { } XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345}); auto exp = builder.Pow(x, builder.ConstantR0(-.5f)); @@ -188,7 +188,7 @@ XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) { } XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto add = CreateScalarAddComputation(F32, &builder); auto x = builder.ConstantR1( @@ -203,7 +203,7 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { } XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR1( @@ -218,8 +218,8 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) { XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { // Similar to MaxTenValues, except that the inputs come from params rather // than constants. - ComputationBuilder builder(client_, TestName()); - ComputationDataHandle v1, v2; + XlaBuilder builder(TestName()); + XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter( {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); @@ -236,7 +236,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) { XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { // Similar to MaxTenValuesFromParams, except that the data size passed in and // out is large. - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // Number of floats in the data passed into and out of the computation. 
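MaxTenValuesFromParams above shows the parameter-plumbing side of the migration: the out-handles passed to CreateR1Parameter are now XlaOp. A sketch with illustrative values for the second parameter (the diff truncates them) and the <float>/<GlobalData> template arguments assumed:

```cpp
XlaBuilder builder("MaxTenValuesFromParams");
XlaOp v1, v2;  // previously ComputationDataHandle
std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
    {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
    /*builder=*/&builder, /*data_handle=*/&v1);
std::unique_ptr<GlobalData> param1_data = CreateR1Parameter<float>(
    {20.0f, 30.0f, 40.0f, 50.0f}, /*parameter_number=*/1, /*name=*/"v2",
    /*builder=*/&builder, /*data_handle=*/&v2);  // illustrative values
builder.Max(v1, v2);
```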
constexpr int datalen = 15 * 1000; @@ -259,7 +259,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { expected_vec.push_back(larger); } - ComputationDataHandle v1, v2; + XlaOp v1, v2; std::unique_ptr param0_data = CreateR1Parameter(v1vec, /*parameter_number=*/0, /*name=*/"v1", /*builder=*/&builder, /*data_handle=*/&v1); @@ -274,7 +274,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) { } XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR0(0); @@ -286,7 +286,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) { } XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1( {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR1( @@ -299,7 +299,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinTenValues) { } XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR0(0); auto one = builder.ConstantR0(1); auto x = builder.ConstantR1( @@ -312,7 +312,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) { } XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR0(0); auto one = builder.ConstantR0(1); auto x = builder.ConstantR1( @@ -325,7 +325,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) { } XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto zero = builder.ConstantR1({0.0f, 0.0f}); auto one = builder.ConstantR1({1.0f, 1.0f}); auto x = builder.ConstantR1({2.1, -2.6}); @@ -336,7 +336,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) { } XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto one = builder.ConstantR0(1); auto two = builder.ConstantR0(2); auto x = builder.ConstantR1( @@ -348,11 +348,22 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) { + XlaBuilder builder(TestName()); + auto zero = builder.ConstantR0(0); + auto one = builder.ConstantR0(10); + auto x = builder.ConstantR1({-3, 3, 9, 13}); + auto clamp = builder.Clamp(zero, x, one); + + std::vector expected = {0, 3, 9, 10}; + ComputeAndCompareR1(&builder, expected, {}); +} + XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { - Computation add_half; + XlaComputation add_half; { // add_half(x) = x + 0.5 - ComputationBuilder builder(client_, "add_half"); + XlaBuilder builder("add_half"); auto x_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value"); auto half = builder.ConstantR0(0.5); @@ -362,10 +373,10 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { add_half = computation_status.ConsumeValueOrDie(); } - Computation clamp; + XlaComputation clamp; { // clamp(y) = clamp<0,5>(y) - ComputationBuilder builder(client_, "clamp"); + XlaBuilder builder("clamp"); auto y_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value"); auto zero = builder.ConstantR0(0.0); @@ -375,10 +386,10 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { clamp = 
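The newly added ClampValuesConstantS64 test above is small enough to restate with the <int64> template arguments made explicit (the diff rendering strips them); it pins down Clamp's (min, operand, max) argument order for integer types. The `ten` name below is a readability rename of the test's `one` variable.

```cpp
XlaBuilder builder("ClampValuesConstantS64");
auto zero = builder.ConstantR0<int64>(0);
auto ten = builder.ConstantR0<int64>(10);
auto x = builder.ConstantR1<int64>({-3, 3, 9, 13});
builder.Clamp(zero, x, ten);  // clamp each element into [0, 10]
// Expected result: {0, 3, 9, 10}.
```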
computation_status.ConsumeValueOrDie(); } - Computation mult_relu_add; + XlaComputation mult_relu_add; { // mult_relu_add(z) = clamp(add_half(2 * max(z, 0))) - ComputationBuilder builder(client_, "mult_relu_add"); + XlaBuilder builder("mult_relu_add"); auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); auto zero = builder.ConstantR0(0.0); @@ -392,7 +403,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { mult_relu_add = computation_status.ConsumeValueOrDie(); } - ComputationBuilder builder(client_, "map10"); + XlaBuilder builder("map10"); { auto x = builder.ConstantR1( {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); @@ -405,7 +416,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { } XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({-5, -4, -3, -2, -1, 0, 1, 2, 3, 4}); auto y = builder.ConstantR0(3); builder.Rem(x, y); @@ -415,7 +426,7 @@ XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) { } XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({false, true}); auto y = builder.ConstantR1({true, false}); builder.Eq(x, y); @@ -425,7 +436,7 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) { } XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto x = builder.ConstantR1({false, true}); auto y = builder.ConstantR1({true, false}); builder.Ne(x, y); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 89ce2ce797f..c463f3eac55 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -37,8 +37,6 @@ limitations under the License. 
#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" -namespace se = ::perftools::gputools; - namespace xla { namespace { @@ -959,22 +957,21 @@ TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileWithPrngScalarResult)) { TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Parameter(0, element_shape, "param"); auto t = outer.Tuple({p, outer.ConstantR1({1, 1})}); - TF_ASSERT_OK_AND_ASSIGN(const std::unique_ptr tuple_shape, - outer.GetShape(t)); + TF_ASSERT_OK_AND_ASSIGN(Shape tuple_shape, outer.GetShape(t)); - ComputationBuilder cond(client_, "cond"); - auto cond_t = cond.Parameter(0, *tuple_shape, "t"); + XlaBuilder cond("cond"); + auto cond_t = cond.Parameter(0, tuple_shape, "t"); TF_ASSERT_OK(Any(cond.Eq(cond.GetTupleElement(cond_t, 0), cond.ConstantR1({42, 42})), &cond) .status()); - ComputationBuilder body(client_, "body"); - auto body_t = body.Parameter(0, *tuple_shape, "t"); + XlaBuilder body("body"); + auto body_t = body.Parameter(0, tuple_shape, "t"); auto e = body.GetTupleElement(body_t, 1); body.Tuple({e, e}); @@ -995,15 +992,15 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Parameter(0, element_shape, "param"); - ComputationBuilder cond(client_, "cond"); + XlaBuilder cond("cond"); auto cond_t = cond.Parameter(0, element_shape, "t"); TF_ASSERT_OK( Any(cond.Eq(cond_t, cond.ConstantR1({42, 42})), &cond).status()); - ComputationBuilder body(client_, "body"); + XlaBuilder body("body"); auto body_t = body.Parameter(0, element_shape, "t"); auto e = body.Broadcast(body.ConstantR0(1.0), {2}); @@ -1021,14 +1018,14 @@ TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Parameter(0, element_shape, "param"); - ComputationBuilder cond(client_, "cond"); + XlaBuilder cond("cond"); auto cond_t = cond.Parameter(0, element_shape, "t"); cond.Eq(cond_t, cond.ConstantR0(42)); - ComputationBuilder body(client_, "body"); + XlaBuilder body("body"); auto body_t = body.Parameter(0, element_shape, "t"); auto tuple = body.Tuple({body_t, body.Add(body_t, body.ConstantR0(1))}); @@ -1057,23 +1054,23 @@ TEST_F(WhileTest, WhileWithMixedTupleElements) { auto result_shape = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); - ComputationBuilder outer(client_, "outer"); + XlaBuilder outer("outer"); auto p = outer.Tuple({outer.ConstantR0(0), outer.Parameter(0, ShapeUtil::MakeShape(S32, {}), "t")}); - ComputationBuilder cond(client_, "cond"); + XlaBuilder cond("cond"); auto params = cond.Parameter(0, result_shape, "prev"); auto cond_t = cond.Add(cond.GetTupleElement(params, 1), cond.GetTupleElement(params, 0)); cond.Lt(cond_t, cond.ConstantR0(30)); - ComputationBuilder body(client_, "body"); + XlaBuilder body("body"); auto body_t = body.Parameter(0, result_shape, "t"); auto tuple = body.Tuple( - {body.Add(body.GetTupleElement(params, 0), body.ConstantR0(1)), - body.Add(body.GetTupleElement(params, 1), body.ConstantR0(1))}); + {body.Add(body.GetTupleElement(body_t, 0), 
body.ConstantR0(1)), + body.Add(body.GetTupleElement(body_t, 1), body.ConstantR0(1))}); TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); @@ -1323,10 +1320,6 @@ void BM_WhileLoop(int num_iters) { } } -// TODO(b/32470510): Benchmark fails on parallel CPU backend. -#ifndef XLA_TEST_BACKEND_CPU_PARALLEL BENCHMARK(BM_WhileLoop); -#endif - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index ff3418a128e..3c9a01653c6 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -17,8 +17,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -34,7 +35,7 @@ limitations under the License. namespace xla { namespace { -namespace se = ::perftools::gputools; + namespace gtl = ::tensorflow::gtl; class HloProfileTest : public ClientLibraryTestBase {}; @@ -83,8 +84,8 @@ Status ParseOneProfileOutputLine( string match_percentage = "\\d+\\.\\d\\d%"; string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; string match_usecs = "([0-9.]+) usec"; - string match_flops = "([^ ]+)"; - string match_trops = "([^ ]+)"; + string match_flops = "([^ ]*)"; + string match_trops = "([^ ]*)"; string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; @@ -119,7 +120,7 @@ Status ParseOneProfileOutputLine( // Returns void so that we can ASSERT. 
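The regex tweak in ParseOneProfileOutputLine above changes the flops and trops capture groups from one-or-more to zero-or-more characters, presumably so profile lines that leave those columns empty still match; nothing else in the pattern moves:

```cpp
string match_flops = "([^ ]*)";  // was "([^ ]+)": the column may now be empty
string match_trops = "([^ ]*)";  // was "([^ ]+)"
```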
void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, - const Computation& computation, + const XlaComputation& computation, const Shape& lhs_arg_shape, const Shape& rhs_arg_shape) { LocalService* service = ClientLibrary::GetXlaService(client->platform()); @@ -129,18 +130,18 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, auto* transfer_manager = backend->transfer_manager(); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr lhs_arg, + ScopedShapedBuffer lhs_arg, transfer_manager->AllocateScopedShapedBuffer( lhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(lhs_arg_shape), *lhs_arg)); + executor, *Literal::CreateFromShape(lhs_arg_shape), lhs_arg)); TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr rhs_arg, + ScopedShapedBuffer rhs_arg, transfer_manager->AllocateScopedShapedBuffer( rhs_arg_shape, allocator, backend->default_device_ordinal())); TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( - executor, *Literal::CreateFromShape(rhs_arg_shape), *rhs_arg)); + executor, *Literal::CreateFromShape(rhs_arg_shape), rhs_arg)); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr local_executable, @@ -165,7 +166,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, backend->eigen_intra_op_thread_pool()); TF_ASSERT_OK_AND_ASSIGN( auto execution_result, - executable->ExecuteOnStream(&run_options, {lhs_arg.get(), rhs_arg.get()}, + executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg}, &hlo_execution_profile)); (void)execution_result; @@ -175,8 +176,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, XLA_VLOG_LINES(4, *profile_output); } -// TODO(b/71364943): This test exposes a bug in the parallel CPU backend. -XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { +XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { const int64 m = 256, k = 256, n = 256; Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); Shape rhs_shape = ShapeUtil::MakeShape(F32, {m, k}); @@ -186,7 +186,7 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { TF_ASSERT_OK_AND_ASSIGN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(platform)); - ComputationBuilder builder(client, TestName()); + XlaBuilder builder(TestName()); auto result = builder.Tanh(builder.Add( builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); @@ -239,12 +239,9 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { EXPECT_TRUE(HasTrops(tanh_profile)); } -// TODO(b/71364943): This test exposes a bug in the parallel CPU backend. -// // TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo // instructions "interior" to while nodes. 
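ExecuteAndFetchProfile above tracks the ScopedShapedBuffer ownership change: AllocateScopedShapedBuffer is assumed to return the buffer by value (inside a StatusOr), TransferLiteralToDevice takes it by reference, and ExecuteOnStream receives plain pointers. A condensed sketch, showing one argument in full:

```cpp
// Allocate, fill, and execute with a value-typed ScopedShapedBuffer.
TF_ASSERT_OK_AND_ASSIGN(
    ScopedShapedBuffer lhs_arg,
    transfer_manager->AllocateScopedShapedBuffer(
        lhs_arg_shape, allocator, backend->default_device_ordinal()));
TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice(
    executor, *Literal::CreateFromShape(lhs_arg_shape), lhs_arg));
// rhs_arg is prepared the same way, then both are passed by address:
TF_ASSERT_OK_AND_ASSIGN(
    auto execution_result,
    executable->ExecuteOnStream(&run_options, {&lhs_arg, &rhs_arg},
                                &hlo_execution_profile));
```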
-XLA_TEST_F(HloProfileTest, - DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(ProfileWhileComputation))) { +XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) { const int64 size = 256; Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size}); Shape while_result_shape = @@ -255,18 +252,18 @@ XLA_TEST_F(HloProfileTest, TF_ASSERT_OK_AND_ASSIGN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(platform)); - Computation condition; + XlaComputation condition; { - ComputationBuilder builder(client, "condition"); + XlaBuilder builder("condition"); auto state = builder.Parameter(0, while_result_shape, "state"); auto iteration = builder.GetTupleElement(state, 0); builder.Gt(builder.ConstantR0(5), iteration); TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); } - Computation body; + XlaComputation body; { - ComputationBuilder builder(client, "body"); + XlaBuilder builder("body"); auto state = builder.Parameter(0, while_result_shape, "state"); auto matrix = builder.GetTupleElement(state, 1); auto next_iteration = builder.Add(builder.GetTupleElement(state, 0), @@ -275,7 +272,7 @@ XLA_TEST_F(HloProfileTest, TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); } - ComputationBuilder builder(client, TestName()); + XlaBuilder builder(TestName()); auto initial_while_state = builder.Tuple({builder.ConstantR0(0), builder.Parameter(0, matrix_shape, "initial_value")}); diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 44f874cd2ae..56702feab9a 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -42,7 +42,7 @@ StatusOr> TextLiteralReader::ReadPath( << "TextLiteralReader no longer supports reading .gz files"; std::unique_ptr file; Status s = - tensorflow::Env::Default()->NewRandomAccessFile(path.ToString(), &file); + tensorflow::Env::Default()->NewRandomAccessFile(std::string(path), &file); if (!s.ok()) { return s; } @@ -92,7 +92,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { tensorflow::StringPiece sp(shape_string); if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) { - string tmp = sp.ToString(); + string tmp = std::string(sp); shape_string = tmp; } TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string)); @@ -124,10 +124,10 @@ StatusOr> TextLiteralReader::ReadAllLines() { line.c_str()); } float value; - if (!tensorflow::strings::safe_strtof(value_string.ToString().c_str(), + if (!tensorflow::strings::safe_strtof(std::string(value_string).c_str(), &value)) { return InvalidArgument("could not parse value as float: \"%s\"", - value_string.ToString().c_str()); + std::string(value_string).c_str()); } SplitByDelimToStringPieces(coordinates_string, ',', &coordinates); coordinate_values.clear(); @@ -136,7 +136,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) { return InvalidArgument( "could not parse coordinate member as int64: \"%s\"", - piece.ToString().c_str()); + std::string(piece).c_str()); } coordinate_values.push_back(coordinate_value); } diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index 3fee467594d..373c0d2d8d8 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -30,10 +30,10 @@ limitations under the License. 
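The text_literal_reader.cc hunk at the end of the block above is part of a broader StringPiece cleanup: explicit std::string construction replaces StringPiece::ToString() at every call site. A minimal sketch of the file-open path; OpenLiteralFile is a hypothetical wrapper, and the snippet assumes it sits inside namespace xla (where Status aliases tensorflow::Status):

```cpp
Status OpenLiteralFile(tensorflow::StringPiece path,
                       std::unique_ptr<tensorflow::RandomAccessFile>* file) {
  // std::string(path) replaces the old path.ToString().
  return tensorflow::Env::Default()->NewRandomAccessFile(std::string(path),
                                                         file);
}
```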
namespace xla { -/* static */ tensorflow::Status TextLiteralWriter::WriteToPath( +/* static */ Status TextLiteralWriter::WriteToPath( const Literal& literal, tensorflow::StringPiece path) { std::unique_ptr f; - auto s = tensorflow::Env::Default()->NewWritableFile(path.ToString(), &f); + auto s = tensorflow::Env::Default()->NewWritableFile(std::string(path), &f); if (!s.ok()) { return s; } @@ -43,7 +43,7 @@ namespace xla { return s; } - tensorflow::Status status; + Status status; tensorflow::WritableFile* f_ptr = f.get(); literal.EachCellAsString( [f_ptr, &status](tensorflow::gtl::ArraySlice indices, diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h index 7375493f430..0a1235b5e04 100644 --- a/tensorflow/compiler/xla/text_literal_writer.h +++ b/tensorflow/compiler/xla/text_literal_writer.h @@ -37,8 +37,8 @@ namespace xla { // This should be readable by xla::TextLiteralReader. class TextLiteralWriter { public: - static tensorflow::Status WriteToPath(const Literal& literal, - tensorflow::StringPiece path); + static Status WriteToPath(const Literal& literal, + tensorflow::StringPiece path); private: TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 0bc4045a549..415cf9c16a2 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -36,11 +36,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -63,10 +62,9 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -84,11 +82,10 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -164,12 +161,11 @@ tf_cc_binary( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -183,12 +179,11 @@ tf_cc_binary( 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) @@ -201,13 +196,12 @@ tf_cc_binary( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", - "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc index 21ae8583d7c..befb5545377 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc @@ -17,7 +17,7 @@ limitations under the License. // // Dumps a graphviz URL for a snapshot computation to the command line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The GraphViz URL is placed into the log stderr, whereas computation @@ -30,11 +30,10 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -49,10 +48,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); ComputationStats stats = diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index b82f1c81c84..cfb8f37487d 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -21,11 +21,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -66,16 +65,16 @@ void RealMain(tensorflow::gtl::ArraySlice args) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); std::unique_ptr program_shape = client->GetComputationShape(computation).ConsumeValueOrDie(); @@ -89,8 +88,7 @@ void RealMain(tensorflow::gtl::ArraySlice args) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 05c0fdf97d2..b815bbf854b 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -19,11 +19,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/computation_tracker.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -40,16 +39,16 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { LocalService* local_service = ClientLibrary::GetXlaService(client->platform()); for (char* arg : args) { - SessionModule session_module; + HloSnapshot snapshot; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); + &snapshot)); + auto computation_status = client->LoadSnapshot(snapshot); if (!computation_status.ok()) { fprintf(stderr, "could not load snapshot for %s: %s\n", arg, computation_status.status().ToString().c_str()); continue; } - Computation computation = computation_status.ConsumeValueOrDie(); + XlaComputation computation = computation_status.ConsumeValueOrDie(); if (compile) { std::unique_ptr program_shape = @@ -65,8 +64,7 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { build_options.set_device_ordinal(0); build_options.set_result_layout(program_shape->result()); StatusOr> executable = - local_service->CompileExecutable(computation.handle(), layouts, - build_options); + local_service->CompileExecutable(computation, layouts, build_options); const HloModule& module = executable.ValueOrDie()->module(); @@ -74,13 +72,11 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { local_service->backend().platform()->Name().c_str(), module.ToString(HloPrintOptions::ShortParsable()).c_str()); } else { - const ComputationTracker& tracker = local_service->computation_tracker(); - UserComputation* user_computation = - tracker.Resolve(computation.handle()).ConsumeValueOrDie(); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); + auto config = HloModule::CreateModuleConfigFromProto(computation.proto(), + DebugOptions()) + .ConsumeValueOrDie(); std::unique_ptr module = - tracker.BuildHloModule(versioned_handle, HloModuleConfig()) + HloModule::CreateFromProto(computation.proto(), config) .ConsumeValueOrDie(); fprintf(stdout, "%s\n", diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 51f90b07c66..a5dce20456c 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -28,11 +28,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -48,10 +47,11 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + XlaComputation computation = + client->LoadSnapshot(module).ConsumeValueOrDie(); DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); debug_options.set_xla_generate_hlo_graph(".*"); debug_options.set_xla_hlo_dump_as_graphdef(true); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index fc0e4444521..350db126535 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -230,7 +230,7 @@ TokKind HloLexer::LexIdentifier() { } } - str_val_ = identifier.ToString(); + str_val_ = std::string(identifier); return TokKind::kIdent; } diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index b2f122982ad..d0e7af88442 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -30,6 +30,7 @@ namespace { using tensorflow::StringPiece; using tensorflow::gtl::optional; +using tensorflow::str_util::Join; using tensorflow::str_util::Split; using tensorflow::str_util::SplitAndParseAsInts; using tensorflow::strings::Printf; @@ -53,7 +54,7 @@ class HloParser { std::unique_ptr ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + string GetError() const { return Join(error_, "\n"); } private: // ParseXXX returns false if an error occurred. @@ -242,10 +243,10 @@ bool HloParser::Error(LocTy loc, StringPiece msg) { std::vector error_lines; error_lines.push_back( StrCat("was parsing ", line, ":", col, ": error: ", msg)); - error_lines.push_back(lexer_.GetLine(loc).ToString()); + error_lines.push_back(std::string(lexer_.GetLine(loc))); error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + error_.push_back(Join(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } @@ -303,18 +304,20 @@ bool HloParser::ParseComputations() { // set the layouts to what the hlo text says. 
for (int p = 0; p < computation->num_parameters(); p++) { const Shape& param_shape = computation->parameter_instruction(p)->shape(); - if (param_shape.has_layout()) { - module_->mutable_entry_computation_layout() - ->mutable_parameter_layout(p) - ->ResetLayout(param_shape.layout()); - } + TF_CHECK_OK(module_->mutable_host_entry_computation_layout() + ->mutable_parameter_layout(p) + ->CopyLayoutFromShape(param_shape)); + TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + ->mutable_parameter_layout(p) + ->CopyLayoutFromShape(param_shape)); } const Shape& result_shape = computation->root_instruction()->shape(); - if (result_shape.has_layout()) { - module_->mutable_entry_computation_layout() - ->mutable_result_layout() - ->ResetLayout(result_shape.layout()); - } + TF_CHECK_OK(module_->mutable_host_entry_computation_layout() + ->mutable_result_layout() + ->CopyLayoutFromShape(result_shape)); + TF_CHECK_OK(module_->mutable_device_entry_computation_layout() + ->mutable_result_layout() + ->CopyLayoutFromShape(result_shape)); } return true; @@ -437,6 +440,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; + optional backend_config; + attrs["backend_config"] = {/*required=*/false, AttrTy::kString, + &backend_config}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -470,13 +477,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kRoundNearestAfz: case HloOpcode::kBitcast: case HloOpcode::kCeil: + case HloOpcode::kClz: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kExpm1: case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: + case HloOpcode::kLog1p: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kReal: @@ -724,15 +734,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, operands[0], *broadcast_dimensions)); break; } - case HloOpcode::kBroadcastDimOne: { - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction( - HloInstruction::CreateBroadcastDimOne(shape, operands[0])); - break; - } case HloOpcode::kConcatenate: { optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, @@ -1099,8 +1100,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_name(name); - // Add common attrs (sharding, control predecessors) to the instruction, if - // they were seen. + // Add shared attributes like metadata to the instruction, if they were seen. 
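Per the hunk above, the parser now unconditionally mirrors each parsed parameter layout (and, symmetrically, the result layout) into both the host and the device entry computation layouts via CopyLayoutFromShape, instead of resetting a single layout only when the shape carries one. A sketch of the per-parameter half, taken from the loop shown above:

```cpp
// Inside the loop over entry-computation parameters.
const Shape& param_shape = computation->parameter_instruction(p)->shape();
TF_CHECK_OK(module_->mutable_host_entry_computation_layout()
                ->mutable_parameter_layout(p)
                ->CopyLayoutFromShape(param_shape));
TF_CHECK_OK(module_->mutable_device_entry_computation_layout()
                ->mutable_parameter_layout(p)
                ->CopyLayoutFromShape(param_shape));
```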
if (sharding) { instruction->set_sharding( HloSharding::FromProto(sharding.value()).ValueOrDie()); @@ -1117,6 +1117,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (metadata) { instruction->set_metadata(*metadata); } + if (backend_config) { + instruction->set_backend_config(std::move(*backend_config)); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1494,11 +1497,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - tensorflow::str_util::Join( - elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); - }), + Join(elems_seen_until_dim, ",", + [](string* out, const int64& num_elems) { + tensorflow::strings::StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1686,7 +1688,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", tensorflow::str_util::Join(index, ", "), "]")); + ": [", Join(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -1854,7 +1856,19 @@ bool HloParser::ParseAttributeHelper( } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { - return Error(loc, Printf("unexpected attribute %s", name.c_str())); + string allowed_attrs; + if (attrs.empty()) { + allowed_attrs = "No attributes are allowed here."; + } else { + allowed_attrs = StrCat( + "Allowed attributes: ", + Join(attrs, ", ", + [&](string* out, const std::pair& kv) { + StrAppend(out, kv.first); + })); + } + return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), + allowed_attrs.c_str())); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 57684b58346..131aded95ab 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -57,18 +57,6 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) } -)" -}, -// broadcast size-one dimensions -{ -"BroadcastDimOne", -R"(HloModule broadcast_dim_one_module - -ENTRY %broadcast-dim-one () -> f32[2,2] { - %constant = f32[1,2]{1,0} constant(f32[1,2] { { 1.1, 2.2 } }) - ROOT %broadcast-dim-one = f32[2,2]{1,0} broadcast-dim-one(f32[1,2]{1,0} %constant) -} - )" }, // pred constant @@ -77,7 +65,7 @@ ENTRY %broadcast-dim-one () -> f32[2,2] { R"(HloModule constant_pred_module ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar" } )" @@ -93,13 +81,14 @@ ENTRY %constant_s32 () -> s32[] { )" }, -// f32 constant, but the value is not a decimal +// f32 constant, but the value is not a decimal and there is a backend +// configuration { "ConstantF32", R"(HloModule ConstantF32_module ENTRY %ConstantF32.v4 () -> f32[] { - ROOT %constant = f32[] constant(42) + ROOT %constant = f32[] constant(42), backend_config="this is a configuration" } )" @@ 
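Together with the attribute registration earlier in the hunk, the block above wires the new backend_config string through the parser: it is declared as an optional attribute and, if present, moved onto the freshly built instruction alongside sharding, control predecessors, and metadata. The same hunk also makes the "unexpected attribute" error list the attributes that are allowed at that position. The two backend_config pieces, side by side:

```cpp
// Declared with the other per-instruction attributes:
optional<string> backend_config;
attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
                           &backend_config};

// Applied after the instruction is constructed, as above:
if (backend_config) {
  instruction->set_backend_config(std::move(*backend_config));
}
```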
-949,13 +938,13 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, TEST_F(HloParserTest, Empty) { const string original = ""; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, Garbage) { const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOpcode) { @@ -969,7 +958,7 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongShape) { @@ -981,7 +970,7 @@ ENTRY %blabla (x: g32[]) -> g32[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, WrongOperandsSize) { @@ -994,7 +983,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, OperandNotFound) { @@ -1005,7 +994,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { } )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, MoreConstants) { @@ -1025,6 +1014,19 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { // but the constant names will not be exactly the same. } +TEST_F(HloParserTest, ConfigurationField) { + const string original = R"(HloModule AModule +ENTRY %configuration_test() -> s32[] { + %constant = s32[] constant(42), backend_config="foo bar" +})"; + auto result = Parse(original); + TF_ASSERT_OK(result.status()); + EXPECT_EQ("foo bar", result.ValueOrDie() + ->entry_computation() + ->root_instruction() + ->backend_config()); +} + TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { const string original = R"(HloModule some_2_module @@ -1034,7 +1036,7 @@ ENTRY %some_2 () -> f32[2] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 1, but sees larger"); } @@ -1048,7 +1050,7 @@ ENTRY %some_2x3 () -> f32[2,3] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 2, but sees 1"); } @@ -1062,7 +1064,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects 3 elements in the [0]th element"); } @@ -1077,7 +1079,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { )"; auto result = Parse(original); - EXPECT_NE(tensorflow::Status::OK(), result.status()); + EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "is out of range for literal's primitive type F16"); } @@ -1104,7 +1106,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT 
%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} } )"; @@ -1150,7 +1152,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { )"; ExpectHasSubstr(Parse(original).status().error_message(), - "unexpected attribute calls"); + "unexpected attribute \"calls\""); } TEST_F(HloParserTest, MissingAttribute) { @@ -1251,7 +1253,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { auto module = Parse(original); TF_ASSERT_OK(module.status()); - auto program_layout = module.ValueOrDie()->entry_computation_layout(); + auto program_layout = module.ValueOrDie()->host_entry_computation_layout(); ASSERT_EQ(program_layout.parameter_count(), 1); auto param_layout = program_layout.parameter_layout(0).layout(); auto result_layout = program_layout.result_layout().layout(); diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 62a353ad09a..df0501386c1 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -17,7 +17,7 @@ limitations under the License. // // Replays computations and shows the results on the command line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // Computations that require arguments can be replayed using fake data by @@ -36,13 +36,12 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -75,9 +74,10 @@ struct Options { // // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; // otherwise, no infeed is performed. 
-StatusOr> ReplayComputation( - const SessionModule& module, Client* client, const Options& opts) { - TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module)); +StatusOr> ReplayComputation(const HloSnapshot& module, + Client* client, + const Options& opts) { + TF_ASSIGN_OR_RETURN(auto computation, client->LoadSnapshot(module)); std::vector> arguments; if (opts.use_fake_data) { @@ -153,10 +153,15 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; for (char* arg : args) { - SessionModule module; - TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); + HloSnapshot snapshot; + auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot); + if (!status.ok()) { + fprintf(stderr, "%s: is not HloSnapshot: %s.\n", arg, + status.ToString().c_str()); + continue; + } StatusOr> result_status = - ReplayComputation(module, client, opts); + ReplayComputation(snapshot, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); @@ -166,14 +171,15 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { std::unique_ptr result = result_status.ConsumeValueOrDie(); if (result != nullptr) { - fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), + fprintf(stdout, "%s: %s :: %s:%s\n", arg, + snapshot.hlo().hlo_module().name().c_str(), ShapeUtil::HumanString(result->shape()).c_str(), result->ToString().c_str()); - if (module.has_result()) { + if (snapshot.has_result()) { std::unique_ptr literal = - Literal::CreateFromProto(module.result()).ConsumeValueOrDie(); + Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(module.result().shape()).c_str(), + ShapeUtil::HumanString(snapshot.result().shape()).c_str(), literal->ToString().c_str()); } } diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc index 1f3340cbc6a..4e53fafcc97 100644 --- a/tensorflow/compiler/xla/tools/show_signature.cc +++ b/tensorflow/compiler/xla/tools/show_signature.cc @@ -18,7 +18,7 @@ limitations under the License. // Shows the signature (ProgramShape) of binary snapshot proto(s) on the command // line. // -// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// some_binary_snapshot_proto is obtained by serializing the HloSnapshot from // ServiceInterface::SnapshotComputation to disk. // // The output format is: @@ -31,9 +31,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -49,13 +48,14 @@ namespace tools { void RealMain(tensorflow::gtl::ArraySlice args) { Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule module; + HloSnapshot module; TF_CHECK_OK( tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); - Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + auto computation = client->LoadSnapshot(module).ConsumeValueOrDie(); std::unique_ptr shape = client->GetComputationShape(computation).ConsumeValueOrDie(); - fprintf(stdout, "%s: %s :: %s\n", arg, module.entry().name().c_str(), + fprintf(stdout, "%s: %s :: %s\n", arg, + module.hlo().hlo_module().name().c_str(), ShapeUtil::HumanString(*shape).c_str()); } } diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 9fa4297523b..b645acb700b 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -46,4 +46,10 @@ using ::Eigen::half; } // namespace xla +// Alias namespace ::stream_executor as ::xla::se. +namespace stream_executor {} +namespace xla { +namespace se = ::stream_executor; +} // namespace xla + #endif // TENSORFLOW_COMPILER_XLA_TYPES_H_ diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 2da9f9ed6f4..be33bd6dd13 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -528,6 +528,16 @@ bool IsInt32(T x) { // value is implementation-defined." return static_cast(x) == x; } + +template +Status EraseElementFromVector(std::vector* container, const T& value) { + // c_find returns a const_iterator which does not seem to work on gcc 4.8.4, + // and this breaks the ubuntu/xla_gpu build bot. + auto it = std::find(container->begin(), container->end(), value); + TF_RET_CHECK(it != container->end()); + container->erase(it); + return Status::OK(); +} } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 93284b80f9e..f11123ca248 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -199,6 +199,9 @@ bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) { int64 DilatedBound(int64 bound, int64 dilation) { CHECK_GE(bound, 0); CHECK_GE(dilation, 1); + if (bound == 0) { + return 0; + } // Suppose the array has three entries 123 and the dilation factor is 4. Then // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except @@ -212,7 +215,7 @@ int64 StridedBound(int64 bound, int64 window_size, int64 stride) { CHECK_GE(bound, 0); CHECK_GE(stride, 1); - if (window_size > bound) { + if (bound == 0 || window_size > bound) { return 0; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index f18d53c6089..b895ac045c3 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -134,6 +134,8 @@ enum Format { // example, Convert) are ignored. 
// // See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange message Layout { // The method used to store the data in memory. The format determines which of // the other fields are used by the layout. @@ -159,9 +161,12 @@ message Layout { // memory. This field must be unset unless the format is SPARSE. int64 max_sparse_elements = 5; - // Important: if any field is added, be sure to modify ShapeUtil::Equal() - // appropriately to account for the new field. + // Important: if any field is added, be sure to modify ShapeUtil::Equal() and + // LayoutUtil::Hash appropriately to account for the new field. } +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc) // A shape describes the number of dimensions in the array, the size of each // dimension, and the primitive component type. @@ -170,6 +175,8 @@ message Layout { // defined. // // See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange message Shape { reserved 1; reserved "rank"; @@ -190,9 +197,12 @@ message Shape { // The layout used to back this shape. Layout layout = 5; - // Important: if any field is added, be sure to modify ShapeUtil::Equal() and - // ShapeUtil::Compatible() appropriately to account for the new field. + // Important: if any field is added, be sure to modify ShapeUtil::Equal(), + // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for + // the new field. } +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc) // Shape of the parameters and output of a computation (like a traditional // function signature). @@ -801,6 +811,15 @@ enum UnaryOperation { // Elementwise, extract real component of complex x. UNOP_IMAG = 16; + + // Elementwise, computes clz(x). + UNOP_CLZ = 17; + + // Elementwise, computes exp(x)-1. + UNOP_EXPM1 = 18; + + // Elementwise, computes log(x+1). 
+ UNOP_LOG1P = 19; } message UnaryOpRequest { diff --git a/tensorflow/compiler/xla/xlalogo.png b/tensorflow/compiler/xla/xlalogo.png new file mode 100644 index 00000000000..7a0a295953d Binary files /dev/null and b/tensorflow/compiler/xla/xlalogo.png differ diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 9bef0d8b61e..0f9c80404ad 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -25,11 +25,13 @@ py_library( "//tensorflow/contrib/batching:batch_py", "//tensorflow/contrib/bayesflow:bayesflow_py", "//tensorflow/contrib/boosted_trees:init_py", + "//tensorflow/contrib/checkpoint/python:checkpoint", "//tensorflow/contrib/cloud:cloud_py", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", - "//tensorflow/contrib/coder:coder_ops_py", + "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/contrib/constrained_optimization", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", @@ -69,6 +71,7 @@ py_library( "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/meta_graph_transform", "//tensorflow/contrib/metrics:metrics_py", + "//tensorflow/contrib/mixed_precision:mixed_precision", "//tensorflow/contrib/model_pruning", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", @@ -77,6 +80,7 @@ py_library( "//tensorflow/contrib/optimizer_v2:optimizer_v2_py", "//tensorflow/contrib/periodic_resample:init_py", "//tensorflow/contrib/predictor", + "//tensorflow/contrib/proto", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", "//tensorflow/contrib/autograph", @@ -86,6 +90,7 @@ py_library( "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/resampler:resampler_py", "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/contrib/rpc", "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/contrib/seq2seq:seq2seq_py", "//tensorflow/contrib/signal:signal_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index aaddb06fa0c..9aad772f0ac 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -24,10 +24,12 @@ import os # Add projects here, they will show up under tf.contrib. 
from tensorflow.contrib import batching from tensorflow.contrib import bayesflow +from tensorflow.contrib import checkpoint from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder from tensorflow.contrib import compiler +from tensorflow.contrib import constrained_optimization from tensorflow.contrib import copy_graph from tensorflow.contrib import crf from tensorflow.contrib import cudnn_rnn @@ -58,18 +60,20 @@ from tensorflow.contrib import lookup from tensorflow.contrib import losses from tensorflow.contrib import memory_stats from tensorflow.contrib import metrics +from tensorflow.contrib import mixed_precision from tensorflow.contrib import model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt from tensorflow.contrib import periodic_resample from tensorflow.contrib import predictor +from tensorflow.contrib import proto from tensorflow.contrib import quantization from tensorflow.contrib import quantize -from tensorflow.contrib import recurrent from tensorflow.contrib import reduce_slice_ops from tensorflow.contrib import resampler from tensorflow.contrib import rnn +from tensorflow.contrib import rpc from tensorflow.contrib import saved_model from tensorflow.contrib import seq2seq from tensorflow.contrib import signal @@ -92,6 +96,7 @@ if os.name != "nt": from tensorflow.contrib.lite.python import lite from tensorflow.contrib.optimizer_v2 import optimizer_v2_symbols as optimizer_v2 from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field +from tensorflow.contrib.recurrent.python import recurrent_api as recurrent from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs from tensorflow.contrib.summary import summary diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 8add2aacff1..159d985db5c 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -18,10 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import math -import re from tensorflow.contrib import nccl +from tensorflow.python.framework import device as device_lib from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -659,21 +660,20 @@ def _split_by_task(devices, values): num_devices = len(devices) if num_devices != len(values): raise ValueError("len(devices) must equal len(values)") - pattern = re.compile(r"/task:(\d+)/") - per_task_devices = [] - per_task_values = [] + per_task_devices = collections.OrderedDict() + per_task_values = collections.OrderedDict() for d in range(num_devices): - m = pattern.search(devices[d]) - if m: - index = int(m.group(1)) - while index >= len(per_task_devices): - per_task_devices.append([]) - per_task_values.append([]) - per_task_devices[index].append(devices[d]) - per_task_values[index].append(values[d]) - else: + d_spec = device_lib.DeviceSpec.from_string(devices[d]) + if not hasattr(d_spec, "task") or d_spec.task is None: assert False, "failed to parse device %s" % devices[d] - return (per_task_devices, per_task_values) + index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task) + if index not in per_task_devices: + per_task_devices[index] = [] + 
per_task_values[index] = [] + per_task_devices[index].append(devices[d]) + per_task_values[index].append(values[d]) + + return (list(per_task_devices.values()), list(per_task_values.values())) def build_nccl_all_reduce(input_tensors, red_op, un_op=None): diff --git a/tensorflow/contrib/android/BUILD b/tensorflow/contrib/android/BUILD index 60306ebdc6c..c10179ba8b2 100644 --- a/tensorflow/contrib/android/BUILD +++ b/tensorflow/contrib/android/BUILD @@ -72,7 +72,7 @@ cc_binary( "-s", "-Wl,--gc-sections", "-Wl,--version-script", # This line must be directly followed by LINKER_SCRIPT. - LINKER_SCRIPT, + "$(location {})".format(LINKER_SCRIPT), ]), linkshared = 1, linkstatic = 1, diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md index 7e84f237dc9..0ba99c396fc 100644 --- a/tensorflow/contrib/autograph/README.md +++ b/tensorflow/contrib/autograph/README.md @@ -1,4 +1,122 @@ -# Autograph +# AutoGraph -A compiler for generating TensorFlow numeric and control flow ops from Python -code. +IMPORTANT: AutoGraph is pre-alpha, under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! + +AutoGraph is a Python to TensorFlow compiler. + +With AutoGraph, you can write [Eager style](https://www.tensorflow.org/programmers_guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. + +For example, this Python function: + +``` +def f(x): + if x < 0: + x = -x + return x +``` + +would be converted to this: + +``` +def graph_mode_f(x): + with tf.name_scope('f'): + + def if_true(): + with tf.name_scope('if_true'): + x_1, = x, + x_1 = tf.negative(x_1) + return x_1, + + def if_false(): + with tf.name_scope('if_false'): + x_1, = x, + return x_1, + x = ag__.utils.run_cond(tf.greater(x, 0), if_true, if_false) + return x +``` + +so you can use it like an op: + +``` +with tf.Graph().as_default(): + x = tf.constant(-1.0) + + converted_f = autograph.to_graph(f) + y = converted_f(x) + + with tf.Session() as sess: + print(sess.run(y)) + # Output: 1 +``` + +# Getting started + +Use AutoGraph in one of the following ways, described below: + + 1. Annotations (simpler) + 2. Functional API (more flexible) + +To get started, install the latest nightly TensorFlow build: + +```shell +pip install -U tf-nightly +``` + +Then import the `autograph` module from `tf.contrib`: + +``` +from tensorflow.contrib import autograph as ag +``` + +### Interactive demo notebooks + +For more extensive examples, check out these interactive notebooks: + + * [RNN trained using Keras and Estimators](https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb) + * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb) + +## Using with annotations + +Annotating a function or class with `@convert` converts it in place: + +``` +@ag.convert() +def f(x): + if x < 0: + x = -x + return x +``` + +... 
so that it always outputs TensorFlow code: + +``` +with tf.Graph().as_default(): + x = tf.constant(-1) + + y = f(x) + + with tf.Session() as sess: + print(sess.run(y)) + # Output: 1 +``` + +## Using the functional API + +The functional API allows you to convert an existing function, class or object after it was defined: + +``` +converted_f = ag.to_graph(f) + +print(converted_f(tf.constant(-1))) +# Output: Tensor + +print(f(-1)) +# Output: 1 +``` + +You can use the functional API to inspect the generated code as well: + +``` +print(ag.to_code(f)) +# Output: +``` diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index a39f44b21aa..3386c4eca4b 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -21,6 +21,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +# TODO(mdan): Bring only the relevant symbols to the top level. from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.impl.api import convert from tensorflow.contrib.autograph.impl.api import converted_call diff --git a/tensorflow/contrib/autograph/converters/asserts.py b/tensorflow/contrib/autograph/converters/asserts.py index f011a97ade9..3b0db677ce5 100644 --- a/tensorflow/contrib/autograph/converters/asserts.py +++ b/tensorflow/contrib/autograph/converters/asserts.py @@ -27,15 +27,13 @@ from tensorflow.contrib.autograph.pyct import transformer class AssertsTransformer(transformer.Base): """Transforms Print nodes to Call so they can be handled as functions.""" - # pylint:disable=invalid-name - def visit_Assert(self, node): self.generic_visit(node) # Note: The lone tf.Assert call will be wrapped with control_dependencies # by side_effect_guards. template = """ - tf.Assert(test, [msg]) + tf.Assert(test, (msg,)) """ if node.msg is None: @@ -44,9 +42,7 @@ class AssertsTransformer(transformer.Base): elif isinstance(node.msg, gast.Str): return templates.replace(template, test=node.test, msg=node.msg) else: - raise NotImplementedError('Can only convert string messages for now.') - - # pylint:enable=invalid-name + raise NotImplementedError('can only convert string messages for now.') def transform(node, context): diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 62115d4005c..35877224b87 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -26,96 +26,107 @@ from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno -class BreakCanonicalizationTransformer(transformer.Base): +# Tags for local state. +BREAK_USED = 'break_used' +CONTROL_VAR_NAME = 'control_var_name' + + +class BreakStatementTransformer(transformer.Base): """Canonicalizes break statements into additional conditionals.""" - def __init__(self, context): - super(BreakCanonicalizationTransformer, self).__init__(context) - # This is a stack structure, to correctly process nested loops. 
- # Each item is a list [break_used, break_variable_name] - self.break_uses = [] - - def _create_break_check(self): - template = """ - (not var_name) - """ - expr, = templates.replace(template, var_name=self.break_uses[-1][1]) - return expr.value - - def _create_break_trigger(self): - template = """ - var_name = True - """ - block = templates.replace(template, var_name=self.break_uses[-1][1]) - block.append(gast.Continue()) - return block - - def _create_break_init(self): - template = """ - var_name = False - """ - assign, = templates.replace(template, var_name=self.break_uses[-1][1]) - return assign - - # TODO(mdan): Surely the transformer supports this better? - def _manual_visit_list(self, block): - new_block = [] - for n in block: - new_n = self.visit(n) - if isinstance(new_n, list): - new_block.extend(new_n) - else: - new_block.append(new_n) - return new_block - - def visit_While(self, node): - self.generic_visit(node.test) - scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - - break_var = self.context.namer.new_symbol('break_requested', - scope.referenced) - self.break_uses.append([False, break_var]) - node.body = self._manual_visit_list(node.body) - if self.break_uses[-1][0]: - node.test = gast.BoolOp(gast.And(), [ - node.test, - gast.UnaryOp(gast.Not(), gast.Name(break_var, gast.Load(), None)) - ]) - final_nodes = [self._create_break_init(), node] - else: - final_nodes = node - self.break_uses.pop() - - for n in node.orelse: - self.generic_visit(n) - return final_nodes - - def visit_For(self, node): - self.generic_visit(node.target) - self.generic_visit(node.iter) - scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - - break_var = self.context.namer.new_symbol('break_requested', - scope.referenced) - self.break_uses.append([False, break_var]) - node.body = self._manual_visit_list(node.body) - if self.break_uses[-1][0]: - extra_cond = templates.replace_as_expression( - 'not var_name', var_name=break_var) - anno.setanno(node, 'extra_cond', extra_cond) - final_nodes = [self._create_break_init(), node] - else: - final_nodes = node - self.break_uses.pop() - - for n in node.orelse: - self.generic_visit(n) - return final_nodes + def _track_body(self, nodes, break_var): + self.enter_local_scope() + self.set_local(CONTROL_VAR_NAME, break_var) + nodes = self.visit_block(nodes) + break_used = self.get_local(BREAK_USED, False) + self.exit_local_scope() + return nodes, break_used def visit_Break(self, node): - self.break_uses[-1][0] = True - return self._create_break_trigger() + self.set_local(BREAK_USED, True) + var_name = self.get_local(CONTROL_VAR_NAME) + # TODO(mdan): This will fail when expanded inside a top-level else block. + template = """ + var_name = True + continue + """ + return templates.replace(template, var_name=var_name) + + def _guard_if_present(self, block, var_name): + """Prevents the block from executing if var_name is set.""" + + # If we don't have statements that immediately depend on the break + # we still need to make sure that the break variable remains + # used, in case the break becomes useful in later stages of transformation. + # Not having this broke the break_in_inner_loop test. 
+ if not block: + block = [gast.Pass()] + template = """ + if not var_name: + block + """ + node = templates.replace( + template, + var_name=var_name, + block=block) + return node + + def visit_While(self, node): + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + break_var = self.context.namer.new_symbol('break__', scope.referenced) + + node.test = self.visit(node.test) + node.body, break_used = self._track_body(node.body, break_var) + # A break in the else clause applies to the containing scope. + node.orelse = self.visit_block(node.orelse) + + if break_used: + template = """ + var_name = False + while test and not var_name: + body + else: + orelse + """ + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + node = templates.replace( + template, + var_name=break_var, + test=node.test, + body=node.body, + orelse=self._guard_if_present(node.orelse, break_var)) + + return node + + def visit_For(self, node): + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + break_var = self.context.namer.new_symbol('break__', scope.referenced) + + node.target = self.visit(node.target) + node.iter = self.visit(node.iter) + node.body, break_used = self._track_body(node.body, break_var) + # A break in the else clause applies to the containing scope. + node.orelse = self.visit_block(node.orelse) + + if break_used: + node.orelse = self._guard_if_present(node.orelse, break_var) + template = """ + var_name = False + for_stmt + """ + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + node = templates.replace( + template, + var_name=break_var, + for_stmt=node) + extra_test = templates.replace_as_expression( + 'not var_name', var_name=break_var) + anno.setanno(node[1], 'extra_test', extra_test) + + return node def transform(node, context): - return BreakCanonicalizationTransformer(context).visit(node) + return BreakStatementTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/break_statements_test.py b/tensorflow/contrib/autograph/converters/break_statements_test.py index dd4914a022f..1af59e9b526 100644 --- a/tensorflow/contrib/autograph/converters/break_statements_test.py +++ b/tensorflow/contrib/autograph/converters/break_statements_test.py @@ -25,7 +25,7 @@ from tensorflow.python.platform import test class BreakCanonicalizationTest(converter_test_base.TestCase): - def test_basic_break(self): + def test_basic_while(self): def test_fn(x): v = [] @@ -40,13 +40,11 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): node = break_statements.transform(node, self.ctx) with self.compiled(node) as result: - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + self.assertEqual([], result.test_fn(0)) + self.assertEqual([], result.test_fn(1)) + self.assertEqual([3], result.test_fn(4)) - def test_basic_break_for_loop(self): + def test_basic_for(self): def test_fn(a): v = [] @@ -57,30 +55,18 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): v.append(x) return v - # The break is incompletely canonicalized for for loops. Everything is - # in place except for the condition verification. 
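[Editorial sketch] To make the `while` template above concrete, here is a hand-written, framework-free illustration of the rewrite for a loop of the shape exercised by the tests above. The control-variable name `break__` mirrors the symbol the namer picks in the converter, but everything here is illustrative rather than the converter's literal output:

```
def with_break(x):
  v = []
  while x > 0:
    x -= 1
    if x % 2 == 0:
      break
    v.append(x)
  return v

def canonicalized(x):
  # Sketch of the transformed form: the break becomes a control variable that
  # both triggers a `continue` and is added to the loop condition.
  v = []
  break__ = False
  while x > 0 and not break__:
    x -= 1
    if x % 2 == 0:
      break__ = True
      continue
    v.append(x)
  return v

assert with_break(0) == canonicalized(0) == []
assert with_break(1) == canonicalized(1) == []
assert with_break(4) == canonicalized(4) == [3]
```

The assertions match the expectations in the updated `test_basic_while` case ([], [], and [3] for inputs 0, 1 and 4).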
- def test_equiv_fn(a): - v = [] - for x in a: - x -= 1 - if x % 2 == 0: - continue - v.append(x) - return v - node = self.parse_and_analyze(test_fn, {}) node = break_statements.transform(node, self.ctx) with self.compiled(node) as result: - # The break is incompletely canonicalized. Everything is in place, but - # the loop does not break. - self.assertEqual(test_equiv_fn([]), result.test_fn([])) - self.assertEqual(test_equiv_fn([1]), result.test_fn([1])) - self.assertEqual(test_equiv_fn([2]), result.test_fn([2])) - self.assertEqual( - test_equiv_fn([1, 2, 3, 4]), result.test_fn([1, 2, 3, 4])) + # The break is incompletely canonicalized. The loop will not interrupt, + # but the section following the break will be skipped. + self.assertEqual([], result.test_fn([])) + self.assertEqual([3, 3], result.test_fn([4, 4])) + self.assertEqual([3], result.test_fn([4, 5])) + self.assertEqual([3], result.test_fn([5, 4])) - def test_continue_deeply_nested(self): + def test_deeply_nested(self): def test_fn(x): v = [] @@ -93,7 +79,7 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): u.append(x) else: w.append(x) - continue + break v.append(x) return v, u, w @@ -101,11 +87,60 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): node = break_statements.transform(node, self.ctx) with self.compiled(node) as result: - self.assertEqual(test_fn(0), result.test_fn(0)) - self.assertEqual(test_fn(1), result.test_fn(1)) - self.assertEqual(test_fn(2), result.test_fn(2)) - self.assertEqual(test_fn(3), result.test_fn(3)) - self.assertEqual(test_fn(4), result.test_fn(4)) + self.assertEqual(([], [], []), result.test_fn(0)) + self.assertEqual(([2, 1], [2], [0]), result.test_fn(3)) + self.assertEqual(([10, 9, 8, 7], [10, 8], [6]), result.test_fn(11)) + + def test_nested_loops(self): + + def test_fn(x): + v = [] + u = [] + while x > 0: + x -= 1 + y = x + while y > 0: + y -= 1 + if y % 2 == 0: + break + u.append(y) + if x == 0: + break + v.append(x) + return v, u + + node = self.parse_and_analyze(test_fn, {}) + node = break_statements.transform(node, self.ctx) + + with self.compiled(node) as result: + self.assertEqual(([], []), result.test_fn(0)) + self.assertEqual(([1], []), result.test_fn(2)) + self.assertEqual(([2, 1], [1]), result.test_fn(3)) + self.assertEqual(([4, 3, 2, 1], [3, 1]), result.test_fn(5)) + + def test_loop_else(self): + + def test_fn(x): + v = [] + u = [] + while x > 0: + x -= 1 + y = x + while y > 1: + break + else: + u.append(y) + break + v.append(x) + return v, u + + node = self.parse_and_analyze(test_fn, {}) + node = break_statements.transform(node, self.ctx) + + with self.compiled(node) as result: + self.assertEqual(([], []), result.test_fn(0)) + self.assertEqual(([], [1]), result.test_fn(2)) + self.assertEqual(([2], [1]), result.test_fn(3)) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index 0349ce29ceb..317711a866f 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -34,24 +34,24 @@ class BuiltinFunctionTransformer(transformer.Base): def __init__(self, context): super(BuiltinFunctionTransformer, self).__init__(context) - # pylint:disable=invalid-name - def _convert_builtin(self, node): template = """ - autograph_utils.dynamic_builtin(func, args) + ag__.utils.dynamic_builtin(func, args) """ return templates.replace(template, func=node.func, args=node.args)[0].value 
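[Editorial sketch] For orientation on what the `dynamic_builtin` template produces: a whitelisted builtin call such as `len(items)` is rewritten at the call site into `ag__.utils.dynamic_builtin(len, items)`, which falls back to the ordinary Python builtin for plain Python values (and, for tensors, dispatches to the corresponding TF op). A minimal, hypothetical stand-in that captures just the fallback behaviour, not the real helper:

```
def dynamic_builtin_sketch(func, *args):
  # Hypothetical stand-in for ag__.utils.dynamic_builtin; the real helper also
  # maps len/range/xrange over tensors to the corresponding TensorFlow ops.
  return func(*args)

def count(items):
  return len(items)  # the converter routes this through dynamic_builtin

assert dynamic_builtin_sketch(len, [0, 0, 0]) == count([0, 0, 0]) == 3
```

This is consistent with the test above, where the converted function still returns 3 for the plain Python list [0, 0, 0].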
def _convert_print(self, node): template = """ - autograph_utils.dynamic_print(args) + ag__.utils.dynamic_print(args) """ return templates.replace(template, args=node.args)[0].value def visit_Call(self, node): self.generic_visit(node) # TODO(mdan): This won't work if the function was hidden. - if isinstance(node.func, gast.Name) and node.func.id in ('len', 'range'): + # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead. + if (isinstance(node.func, gast.Name) and + node.func.id in ('len', 'range', 'xrange')): return self._convert_builtin(node) # Print needs to be handled separately because it can be read as statement. if isinstance(node.func, gast.Name) and node.func.id == 'print': @@ -70,8 +70,6 @@ class BuiltinFunctionTransformer(transformer.Base): function_call = templates.replace(template, fname='print', args=args)[0] return self.visit(function_call) - # pylint:enable=invalid-name - def transform(node, context): return BuiltinFunctionTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py index ac7e756c47c..30272409df3 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -26,8 +26,6 @@ from tensorflow.contrib.autograph.converters import builtin_functions from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -49,7 +47,7 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): self.assertEqual(3, result.test_fn([0, 0, 0])) - def test_print_with_op(self): + def test_print(self): def test_fn(a): print(a) @@ -57,14 +55,12 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {'print': print}) node = builtin_functions.transform(node, self.ctx) - # Note: it's relevant not to include script_ops.py_func here, to verify - # that tf.Print is used. - with self.compiled(node, logging_ops.Print) as result: + with self.compiled(node) as result: with self.test_session() as sess: try: out_capturer = six.StringIO() sys.stdout = out_capturer - result.test_fn('a') + result.test_fn(constant_op.constant('a')) sess.run(sess.graph.get_operations()) self.assertEqual(out_capturer.getvalue(), 'a\n') finally: @@ -72,41 +68,19 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): def test_print_with_op_multiple_values(self): - def test_fn(a, b): - print(a, b) - - node = self.parse_and_analyze(test_fn, {'print': print}) - node = builtin_functions.transform(node, self.ctx) - - # Note: it's relevant not to include script_ops.py_func here, to verify - # that tf.Print is used. - with self.compiled(node, logging_ops.Print) as result: - with self.test_session() as sess: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn('a', 1) - sess.run(sess.graph.get_operations()) - self.assertEqual(out_capturer.getvalue(), 'a 1\n') - finally: - sys.stdout = sys.__stdout__ - - def test_print_with_py_func(self): - def test_fn(a, b, c): print(a, b, c) node = self.parse_and_analyze(test_fn, {'print': print}) node = builtin_functions.transform(node, self.ctx) - # Note: it's relevant not to include logging_ops.Print here, to verify - # that py_func is used. 
- with self.compiled(node, script_ops.py_func) as result: + with self.compiled(node) as result: with self.test_session() as sess: try: out_capturer = six.StringIO() sys.stdout = out_capturer - result.test_fn('a', 1, [2, 3]) + result.test_fn( + constant_op.constant('a'), constant_op.constant(1), [2, 3]) sess.run(sess.graph.get_operations()) self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n') finally: diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index b9088026c1e..554f0471d44 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -198,7 +198,7 @@ class CallTreeTransformer(transformer.Base): def _wrap_to_py_func_no_return(self, node): # TODO(mdan): Properly handle varargs, etc. template = """ - autograph_utils.wrap_py_func(func, None, (args,), kwargs, True) + ag__.utils.wrap_py_func(func, None, (args,), kwargs, True) """ return templates.replace( template, @@ -209,7 +209,7 @@ class CallTreeTransformer(transformer.Base): def _wrap_to_py_func_single_return(self, node, dtype): # TODO(mdan): Properly handle varargs, etc. template = """ - autograph_utils.wrap_py_func(func, dtype, (args,), kwargs, False) + ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False) """ return templates.replace_as_expression( template, @@ -237,7 +237,7 @@ class CallTreeTransformer(transformer.Base): # Before we could convert all the time though, we'd need a reasonable # caching mechanism. template = """ - autograph_api.converted_call(func, True, False, {}, args) + ag__.converted_call(func, True, False, {}, args) """ call_expr = templates.replace(template, func=node.func, args=node.args) new_call = call_expr[0].value @@ -245,8 +245,6 @@ class CallTreeTransformer(transformer.Base): new_call.keywords = node.keywords return new_call - # pylint:disable=invalid-name - def visit_Expr(self, node): if isinstance(node.value, gast.Call): if anno.hasanno(node.value.func, 'live_val'): @@ -294,15 +292,17 @@ class CallTreeTransformer(transformer.Base): raise NotImplementedError( 'py_func with return values (unknown function)') else: - if self.context.recursive: + if ast_util.matches(node, 'super(_)'): + # super() calls are preserved. The class conversion mechanism will + # ensure that they return the correct value. + pass + elif self.context.recursive: node = self._insert_dynamic_conversion(node) else: # Unresolved functions are allowed in non-recursive mode. pass return node - # pylint:enable=invalid-name - def transform(node, context, uncompiled_modules, nocompile_decorators): """Transform function call to the compiled counterparts. diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index 55a28e8ac30..d7ddbe8a04f 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Handles control flow statements: while, if.""" +"""Handles control flow statements: while, for, if.""" from __future__ import absolute_import from __future__ import division @@ -25,6 +25,7 @@ from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer +from tensorflow.contrib.autograph.pyct.static_analysis import cfg from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno @@ -47,9 +48,6 @@ class SymbolNamer(object): class ControlFlowTransformer(transformer.Base): """Transforms control flow structures like loops an conditionals.""" - def __init__(self, context): - super(ControlFlowTransformer, self).__init__(context) - def _create_cond_branch(self, body_name, aliased_orig_names, aliased_new_names, body, returns): if aliased_orig_names: @@ -78,7 +76,7 @@ class ControlFlowTransformer(transformer.Base): def _create_cond_expr(self, results, test, body_name, orelse_name): if results is not None: template = """ - results = autograph_utils.run_cond(test, body_name, orelse_name) + results = ag__.utils.run_cond(test, body_name, orelse_name) """ return templates.replace( template, @@ -88,7 +86,7 @@ class ControlFlowTransformer(transformer.Base): orelse_name=orelse_name) else: template = """ - autograph_utils.run_cond(test, body_name, orelse_name) + ag__.utils.run_cond(test, body_name, orelse_name) """ return templates.replace( template, test=test, body_name=body_name, orelse_name=orelse_name) @@ -98,30 +96,63 @@ class ControlFlowTransformer(transformer.Base): body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE) orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE) + body_defs = body_scope.created | body_scope.modified + orelse_defs = orelse_scope.created | orelse_scope.modified + live = anno.getanno(node, 'live_out') - if body_scope.created - orelse_scope.created: - raise ValueError( - 'The if branch creates new symbols that the else branch does not.') - if orelse_scope.created - body_scope.created: - raise ValueError( - 'The else branch creates new symbols that the if branch does not.') + # We'll need to check if we're closing over variables that are defined + # elsewhere in the function + # NOTE: we can only detect syntactic closure in the scope + # of the code passed in. If the AutoGraph'd function itself closes + # over other variables, this analysis won't take that into account. + defined = anno.getanno(node, 'defined_in') - modified = tuple(body_scope.modified | orelse_scope.modified) - all_referenced = body_scope.referenced | orelse_scope.referenced + # We only need to return variables that are + # - modified by one or both branches + # - live (or has a live parent) at the end of the conditional + modified = [] + for def_ in body_defs | orelse_defs: + def_with_parents = set((def_,)) | def_.support_set + if live & def_with_parents: + modified.append(def_) + + # We need to check if live created variables are balanced + # in both branches + created = live & (body_scope.created | orelse_scope.created) + + # The if statement is illegal if there are variables that are created, + # that are also live, but both branches don't create them. 
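[Editorial sketch] Stepping back from the liveness bookkeeping described here, the overall shape of the rewrite is: each branch becomes a function that returns the modified live variables, and the `if` statement is replaced by a call to the cond helper. A framework-free sketch, using a plain-Python `run_cond` stand-in instead of `ag__.utils.run_cond` (which dispatches to `tf.cond` when the predicate is a tensor):

```
def run_cond(pred, if_true, if_false):
  # Plain-Python stand-in for ag__.utils.run_cond.
  return if_true() if pred else if_false()

# Original user code:
#   if n > 0:
#     n = 3
#   return n
def converted(n):
  def if_true():
    n_1 = n    # the modified variable is aliased inside the branch function
    n_1 = 3
    return n_1

  def if_false():
    return n   # the untouched branch still returns the live variable

  n = run_cond(n > 0, if_true, if_false)
  return n

assert converted(2) == 3
assert converted(-3) == -3
```

The asserted values match the new `test_imbalanced_aliasing` case below (3 for an input of 2, -3 for an input of -3).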
+ if created: + if created != (body_scope.created & live): + raise ValueError( + 'The main branch does not create all live symbols that the else ' + 'branch does.') + if created != (orelse_scope.created & live): + raise ValueError( + 'The else branch does not create all live symbols that the main ' + 'branch does.') # Alias the closure variables inside the conditional functions # to avoid errors caused by the local variables created in the branch # functions. - need_alias = ( - (body_scope.modified | orelse_scope.modified) - - (body_scope.created | orelse_scope.created)) - aliased_orig_names = tuple(need_alias) - aliased_new_names = tuple( - self.context.namer.new_symbol(s.ssf(), all_referenced) - for s in aliased_orig_names) - alias_map = dict(zip(aliased_orig_names, aliased_new_names)) - node_body = ast_util.rename_symbols(node.body, alias_map) - node_orelse = ast_util.rename_symbols(node.orelse, alias_map) + # We will alias variables independently for body and orelse scope, + # because different branches might write different variables. + aliased_body_orig_names = tuple(body_scope.modified - body_scope.created) + aliased_orelse_orig_names = tuple(orelse_scope.modified - + orelse_scope.created) + aliased_body_new_names = tuple( + self.context.namer.new_symbol(s.ssf(), body_scope.referenced) + for s in aliased_body_orig_names) + aliased_orelse_new_names = tuple( + self.context.namer.new_symbol(s.ssf(), orelse_scope.referenced) + for s in aliased_orelse_orig_names) + + alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names)) + alias_orelse_map = dict( + zip(aliased_orelse_orig_names, aliased_orelse_new_names)) + + node_body = ast_util.rename_symbols(node.body, alias_body_map) + node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map) if not modified: # When the cond would return no value, we leave the cond called without @@ -134,26 +165,47 @@ class ControlFlowTransformer(transformer.Base): else: results = gast.Tuple([s.ast() for s in modified], None) - body_name = self.context.namer.new_symbol('if_true', all_referenced) - orelse_name = self.context.namer.new_symbol('if_false', all_referenced) + body_name = self.context.namer.new_symbol('if_true', body_scope.referenced) + orelse_name = self.context.namer.new_symbol('if_false', + orelse_scope.referenced) if modified: - body_returns = tuple( - alias_map[s] if s in aliased_orig_names else s for s in modified) + + def build_returns(aliased_names, alias_map, scope): + """Builds list of return variables for a branch of a conditional.""" + returns = [] + for s in modified: + if s in aliased_names: + returns.append(alias_map[s]) + else: + if s not in scope.created | defined: + raise ValueError( + 'Attempting to return variable "%s" from the true branch of ' + 'a conditional, but it was not closed over, or created in ' + 'this branch.' 
% str(s)) + else: + returns.append(s) + return tuple(returns) + + body_returns = build_returns(aliased_body_orig_names, alias_body_map, + body_scope) + orelse_returns = build_returns(aliased_orelse_orig_names, + alias_orelse_map, orelse_scope) + else: - body_returns = templates.replace('tf.ones(())')[0].value + body_returns = orelse_returns = templates.replace('tf.ones(())')[0].value body_def = self._create_cond_branch( body_name, - aliased_orig_names=tuple(aliased_orig_names), - aliased_new_names=tuple(aliased_new_names), + aliased_orig_names=tuple(aliased_body_orig_names), + aliased_new_names=tuple(aliased_body_new_names), body=node_body, returns=body_returns) orelse_def = self._create_cond_branch( orelse_name, - aliased_orig_names=tuple(aliased_orig_names), - aliased_new_names=tuple(aliased_new_names), + aliased_orig_names=tuple(aliased_orelse_orig_names), + aliased_new_names=tuple(aliased_orelse_new_names), body=node_orelse, - returns=body_returns) + returns=orelse_returns) cond_expr = self._create_cond_expr(results, node.test, body_name, orelse_name) @@ -207,7 +259,7 @@ class ControlFlowTransformer(transformer.Base): def body_name(state_ssf): body return state_ssf, - state_ast_tuple = __ops.while_loop( + state_ast_tuple = ag__.while_stmt( test_name, body_name, (state,), (extra_deps,)) """ node = templates.replace( @@ -252,31 +304,31 @@ class ControlFlowTransformer(transformer.Base): state_ast_tuple = gast.Tuple([n.ast() for n in state], None) node_body = ast_util.rename_symbols(node.body, ssf_map) - if anno.hasanno(node, 'extra_cond'): - extra_cond = anno.getanno(node, 'extra_cond') - extra_cond = ast_util.rename_symbols(extra_cond, ssf_map) + if anno.hasanno(node, 'extra_test'): + extra_test = anno.getanno(node, 'extra_test') + extra_test = ast_util.rename_symbols(extra_test, ssf_map) else: - extra_cond = parser.parse_expression('True') + extra_test = parser.parse_expression('True') template = """ - def extra_cond_name(state_ssf): - return extra_cond_expr + def extra_test_name(state_ssf): + return extra_test_expr def body_name(iterate, state_ssf): body return state_ssf, - state_ast_tuple = __ops.for_loop( - iterated, extra_cond_name, body_name, (state,)) + state_ast_tuple = ag__.for_stmt( + iter_, extra_test_name, body_name, (state,)) """ node = templates.replace( template, state=state, state_ssf=state_ssf, state_ast_tuple=state_ast_tuple, - iterated=node.iter, + iter_=node.iter, iterate=node.target, - extra_cond_name=self.context.namer.new_symbol('extra_cond', + extra_test_name=self.context.namer.new_symbol('extra_test', all_referenced), - extra_cond_expr=extra_cond, + extra_test_expr=extra_test, body_name=self.context.namer.new_symbol('loop_body', all_referenced), body=node_body) @@ -284,6 +336,7 @@ class ControlFlowTransformer(transformer.Base): def transform(node, context): - t = ControlFlowTransformer(context) - node = t.visit(node) + cfg.run_analyses(node, cfg.Liveness(context)) + cfg.run_analyses(node, cfg.Defined(context)) + node = ControlFlowTransformer(context).visit(node) return node diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index c5610b16b4e..1a863590f97 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -22,6 +22,7 @@ from tensorflow.contrib.autograph.converters import control_flow from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.python.framework 
import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test @@ -95,6 +96,91 @@ class ControlFlowTest(converter_test_base.TestCase): with self.test_session() as sess: self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) + def test_imbalanced_aliasing(self): + + def test_fn(n): + if n > 0: + n = 3 + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond) as result: + with self.test_session() as sess: + self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(2)))) + self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + + def test_ignore_unread_variable(self): + + def test_fn(n): + b = 3 # pylint: disable=unused-variable + if n > 0: + b = 4 + return n + + node = self.parse_and_analyze(test_fn, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(3)))) + self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3)))) + + def test_handle_temp_variable(self): + + def test_fn_using_temp(x, y, w): + if x < y: + z = x + y + else: + w = 2 + tmp = w + z = x - tmp + return z, w + + node = self.parse_and_analyze(test_fn_using_temp, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + z, w = sess.run( + result.test_fn_using_temp( + constant_op.constant(-3), constant_op.constant(3), + constant_op.constant(3))) + self.assertEqual(0, z) + self.assertEqual(3, w) + z, w = sess.run( + result.test_fn_using_temp( + constant_op.constant(3), constant_op.constant(-3), + constant_op.constant(3))) + self.assertEqual(1, z) + self.assertEqual(2, w) + + def test_fn_ignoring_temp(x, y, w): + if x < y: + z = x + y + else: + w = 2 + tmp = w + z = x - tmp + return z + + node = self.parse_and_analyze(test_fn_ignoring_temp, {}) + node = control_flow.transform(node, self.ctx) + + with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result: + with self.test_session() as sess: + z = sess.run( + result.test_fn_ignoring_temp( + constant_op.constant(-3), constant_op.constant(3), + constant_op.constant(3))) + self.assertEqual(0, z) + z = sess.run( + result.test_fn_ignoring_temp( + constant_op.constant(3), constant_op.constant(-3), + constant_op.constant(3))) + self.assertEqual(1, z) + def test_simple_for(self): def test_fn(l): diff --git a/tensorflow/contrib/autograph/converters/converter_test_base.py b/tensorflow/contrib/autograph/converters/converter_test_base.py index 6f75e9a529b..41c2e71702e 100644 --- a/tensorflow/contrib/autograph/converters/converter_test_base.py +++ b/tensorflow/contrib/autograph/converters/converter_test_base.py @@ -35,14 +35,17 @@ from tensorflow.python.platform import test class FakeNamer(object): + """A fake namer that uses a global counter to generate unique names.""" + + def __init__(self): + self.i = 0 def new_symbol(self, name_root, used): - i = 0 while True: - name = '%s%d' % (name_root, i) + self.i += 1 + name = '%s%d' % (name_root, self.i) if name not in used: return name - i += 1 def compiled_function_name(self, original_fqn, @@ -76,9 +79,10 @@ class TestCase(test.TestCase): try: result, 
source = compiler.ast_to_object(node) result.tf = self.make_fake_mod('fake_tf', *symbols) - result.autograph_utils = utils - result.autograph_api = self.make_fake_mod('fake_api', converted_call) - result.__dict__['__ops'] = operators + fake_ag = self.make_fake_mod('fake_ag', converted_call) + fake_ag.__dict__.update(operators.__dict__) + fake_ag.__dict__['utils'] = utils + result.__dict__['ag__'] = fake_ag yield result except Exception: # pylint:disable=broad-except if source is None: diff --git a/tensorflow/contrib/autograph/converters/ifexp.py b/tensorflow/contrib/autograph/converters/ifexp.py index bb0c0a36a78..616d222762e 100644 --- a/tensorflow/contrib/autograph/converters/ifexp.py +++ b/tensorflow/contrib/autograph/converters/ifexp.py @@ -27,7 +27,7 @@ class IfExp(transformer.Base): def visit_IfExp(self, node): template = """ - autograph_utils.run_cond(test, lambda: (body,), lambda: (orelse,)) + ag__.utils.run_cond(test, lambda: (body,), lambda: (orelse,)) """ desugared_ifexp = templates.replace_as_expression( template, test=node.test, body=node.body, orelse=node.orelse) diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py index 234a0a7487d..b49521b2c32 100644 --- a/tensorflow/contrib/autograph/converters/lists.py +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -45,7 +45,7 @@ class ListTransformer(transformer.Base): if not anno.hasanno(node, 'element_type'): raise NotImplementedError( 'type inference for empty lists is not yet supported; ' - 'use utils.set_element_type(, ) to continue') + 'use set_element_type(, ) to continue') dtype = anno.getanno(node, 'element_type') if not isinstance(dtype, dtypes.DType): # TODO(mdan): Allow non-TF dtypes? @@ -74,7 +74,7 @@ class ListTransformer(transformer.Base): if qn.qn[-1] == 'append' and (len(call_node.args) == 1): template = """ - target = autograph_utils.dynamic_list_append(target, element) + target = ag__.utils.dynamic_list_append(target, element) """ node = templates.replace( template, @@ -82,23 +82,33 @@ class ListTransformer(transformer.Base): element=call_node.args[0]) return node + def _replace_list_constructors(self, targets, values): + for target in targets: + if (isinstance(target, (gast.Tuple, gast.List)) and + isinstance(values, (gast.Tuple, gast.List))): + n_targets = len(target.elts) + for i in range(n_targets): + target_el, value_el = target.elts[i], values.elts[i] + values.elts[i] = self._replace_list_constructors( + (target_el,), value_el) + return values + if isinstance(values, gast.List): + if values.elts: + return self._pre_populated_list(values) + else: + return self._empty_list(values) + return values + def visit_Assign(self, node): node = self.generic_visit(node) # Only convert lists when they are assigned to a variable, e.g.: # l = [] - # TODO(mdan): This rule should be improved. - if len(node.targets) != 1: - return node - if not isinstance(node.value, gast.List): - return node - if not isinstance(node.value.ctx, gast.Load): - return node - - if node.value.elts: - node.value = self._pre_populated_list(node.value) - else: - node.value = self._empty_list(node.value) + # TODO(mdan): A similar pattern exists in type_info.py + # We should add a generic "unpack_assignment" function to the base + # transformer, that has the same effect as applying some logic to the SSA + # form. 
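+    # For example, `l = []`, `l, m = [], []` and `[l, m] = [], []` each have
+    # their list constructors replaced individually, so per-target element
+    # type annotations keep working (see the new cases in lists_test.py
+    # below).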
+ node.value = self._replace_list_constructors(node.targets, node.value) return node diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py index 749ba143473..74c6dc64f19 100644 --- a/tensorflow/contrib/autograph/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -45,7 +45,51 @@ class ListTest(converter_test_base.TestCase): result.utils = utils result.dtypes = dtypes with self.test_session() as sess: - self.assertEqual(test_fn(), sess.run(result.test_fn().stack())) + self.assertAllEqual([1], sess.run(result.test_fn().stack())) + + def test_empty_annotated_lists_unpacked(self): + + def test_fn(): + l, m = [], [] + utils.set_element_type(l, dtypes.int32) + utils.set_element_type(m, dtypes.int32) + l.append(1) + m.append(2) + return l, m + + node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = lists.transform(node, self.ctx) + + with self.compiled(node, tensor_array_ops.TensorArray, + dtypes.int32) as result: + result.utils = utils + result.dtypes = dtypes + with self.test_session() as sess: + res_l, res_m = result.test_fn() + self.assertEqual([1], sess.run(res_l.stack())) + self.assertEqual([2], sess.run(res_m.stack())) + + def test_empty_annotated_lists_list_unpacked(self): + + def test_fn(): + [l, m] = [], [] + utils.set_element_type(l, dtypes.int32) + utils.set_element_type(m, dtypes.int32) + l.append(1) + m.append(2) + return l, m + + node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = lists.transform(node, self.ctx) + + with self.compiled(node, tensor_array_ops.TensorArray, + dtypes.int32) as result: + result.utils = utils + result.dtypes = dtypes + with self.test_session() as sess: + res_l, res_m = result.test_fn() + self.assertEqual([1], sess.run(res_l.stack())) + self.assertEqual([2], sess.run(res_m.stack())) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards.py b/tensorflow/contrib/autograph/converters/side_effect_guards.py index 1c1293d2c41..3bcb2d3c42c 100644 --- a/tensorflow/contrib/autograph/converters/side_effect_guards.py +++ b/tensorflow/contrib/autograph/converters/side_effect_guards.py @@ -160,8 +160,8 @@ class SideEffectGuardTransformer(transformer.Base): [alias_map.get(s, s).ast() for s in guarded_args], None) template = """ - with autograph_utils.control_dependency_on_returns(call): - aliased_guarded_args = autograph_utils.alias_tensors(guarded_args) + with ag__.utils.control_dependency_on_returns(call): + aliased_guarded_args = ag__.utils.alias_tensors(guarded_args) """ control_deps_guard = templates.replace( template, @@ -172,7 +172,7 @@ class SideEffectGuardTransformer(transformer.Base): alias_map = {} template = """ - with autograph_utils.control_dependency_on_returns(call): + with ag__.utils.control_dependency_on_returns(call): pass """ control_deps_guard = templates.replace(template, call=node.value)[-1] diff --git a/tensorflow/contrib/autograph/examples/notebooks/rnn_colorbot_estimator.ipynb b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb similarity index 50% rename from tensorflow/contrib/autograph/examples/notebooks/rnn_colorbot_estimator.ipynb rename to tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb index 7f5e4d4ac12..324b23c24b5 100644 --- a/tensorflow/contrib/autograph/examples/notebooks/rnn_colorbot_estimator.ipynb +++ 
b/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb @@ -62,7 +62,7 @@ } }, "source": [ - "# Case study: building an RNN\n" + "# Case study: training a custom RNN, using Keras and Estimators\n" ] }, { @@ -118,6 +118,16 @@ " length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)\n", " return rgb, chars, length\n", "\n", + "\n", + "def set_static_batch_shape(batch_size):\n", + " def apply(rgb, chars, length):\n", + " rgb.set_shape((batch_size, None))\n", + " chars.set_shape((batch_size, None, 256))\n", + " length.set_shape((batch_size,))\n", + " return rgb, chars, length\n", + " return apply\n", + "\n", + "\n", "def load_dataset(data_dir, url, batch_size, training=True):\n", " \"\"\"Loads the colors data at path into a tf.PaddedDataset.\"\"\"\n", " path = tf.keras.utils.get_file(os.path.basename(url), url, cache_dir=data_dir)\n", @@ -129,7 +139,10 @@ " if training:\n", " dataset = dataset.shuffle(buffer_size=3000)\n", " dataset = dataset.padded_batch(\n", - " batch_size, padded_shapes=([None], [None, None], []))\n", + " batch_size, padded_shapes=((None,), (None, 256), ()))\n", + " # To simplify the model code, we statically set as many of the shapes that we\n", + " # know.\n", + " dataset = dataset.map(set_static_batch_shape(batch_size))\n", " return dataset" ] }, @@ -145,7 +158,8 @@ "source": [ "To show the use of control flow, we write the RNN loop by hand, rather than using a pre-built RNN model.\n", "\n", - "Note how we write the model code in Eager style, with regular `if` and `while` statements. Then, we annotate the functions with `@autograph.convert` to have them automatically compiled to run in graph mode." + "Note how we write the model code in Eager style, with regular `if` and `while` statements. Then, we annotate the functions with `@autograph.convert` to have them automatically compiled to run in graph mode.\n", + "We use Keras to define the model, and we will train it using Estimators." ] }, { @@ -166,70 +180,72 @@ }, "outputs": [], "source": [ - "class RnnColorbot(object):\n", - " \"\"\"Holds the parameters of the colorbot model.\"\"\"\n", + "@autograph.convert()\n", + "class RnnColorbot(tf.keras.Model):\n", + " \"\"\"RNN Colorbot model.\"\"\"\n", "\n", " def __init__(self):\n", + " super(RnnColorbot, self).__init__()\n", " self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256)\n", " self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128)\n", " self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n", "\n", + "\n", + " def _rnn_layer(self, chars, cell, batch_size, training):\n", + " \"\"\"A single RNN layer.\n", + "\n", + " Args:\n", + " chars: A Tensor of shape (max_sequence_length, batch_size, input_size)\n", + " cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", + " batch_size: Int, the batch size to use\n", + " training: Boolean, whether the layer is used for training\n", + "\n", + " Returns:\n", + " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", + " \"\"\"\n", + " hidden_outputs = []\n", + " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", + " state, output = cell.zero_state(batch_size, tf.float32)\n", + " for ch in chars:\n", + " cell_output, (state, output) = cell.call(ch, (state, output))\n", + " hidden_outputs.append(cell_output)\n", + " hidden_outputs = hidden_outputs.stack()\n", + " if training:\n", + " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", + " return hidden_outputs\n", + "\n", + " def build(self, _):\n", + " \"\"\"Creates the model variables. 
See keras.Model.build().\"\"\"\n", " self.lower_cell.build(tf.TensorShape((None, 256)))\n", " self.upper_cell.build(tf.TensorShape((None, 256)))\n", - " self.relu_layer.build(tf.TensorShape((None, 128)))\n", + " self.relu_layer.build(tf.TensorShape((None, 128))) \n", + " self.built = True\n", "\n", "\n", - "def rnn_layer(chars, cell, batch_size, training):\n", - " \"\"\"A simple RNN layer.\n", - " \n", - " Args:\n", - " chars: A Tensor of shape (max_sequence_length, batch_size, input_size)\n", - " cell: An object of type tf.contrib.rnn.LSTMBlockCell\n", - " batch_size: Int, the batch size to use\n", - " training: Boolean, whether the layer is used for training\n", + " def call(self, inputs, training=False):\n", + " \"\"\"The RNN model code. Uses Eager and \n", "\n", - " Returns:\n", - " A Tensor of shape (max_sequence_length, batch_size, output_size).\n", - " \"\"\"\n", - " hidden_outputs = []\n", - " autograph.utils.set_element_type(hidden_outputs, tf.float32)\n", - " state, output = cell.zero_state(batch_size, tf.float32)\n", - " for ch in chars:\n", - " cell_output, (state, output) = cell.call(ch, (state, output))\n", - " hidden_outputs.append(cell_output)\n", - " hidden_outputs = hidden_outputs.stack()\n", - " if training:\n", - " hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n", - " return hidden_outputs\n", + " The model consists of two RNN layers (made by lower_cell and upper_cell),\n", + " followed by a fully connected layer with ReLU activation.\n", "\n", + " Args:\n", + " inputs: A tuple (chars, length)\n", + " training: Boolean, whether the layer is used for training\n", "\n", - "@autograph.convert(recursive=True)\n", - "def model(inputs, colorbot, batch_size, training):\n", - " \"\"\"RNNColorbot model.\n", - " \n", - " The model consists of two RNN layers (made by lower_cell and upper_cell),\n", - " followed by a fully connected layer with ReLU activation.\n", - " \n", - " Args:\n", - " inputs: A tuple (chars, length)\n", - " colorbot: An object of type RnnColorbot\n", - " batch_size: Int, the batch size to use\n", - " training: Boolean, whether the layer is used for training\n", - " \n", - " Returns:\n", - " A Tensor of shape (batch_size, 3) - the model predictions.\n", - " \"\"\"\n", - " (chars, length) = inputs\n", - " seq = tf.transpose(chars, [1, 0, 2])\n", - " seq.set_shape((None, batch_size, 256))\n", + " Returns:\n", + " A Tensor of shape (batch_size, 3) - the model predictions.\n", + " \"\"\"\n", + " chars, length = inputs\n", + " batch_size = chars.shape[0]\n", + " seq = tf.transpose(chars, (1, 0, 2))\n", "\n", - " seq = rnn_layer(seq, colorbot.lower_cell, batch_size, training)\n", - " seq = rnn_layer(seq, colorbot.upper_cell, batch_size, training)\n", + " seq = self._rnn_layer(seq, self.lower_cell, batch_size, training)\n", + " seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n", "\n", - " # Grab just the end-of-sequence from each output.\n", - " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", - " sequence_ends = tf.gather_nd(seq, indices)\n", - " return colorbot.relu_layer(sequence_ends)\n", + " # Grab just the end-of-sequence from each output.\n", + " indices = tf.stack([length - 1, range(batch_size)], axis=1)\n", + " sequence_ends = tf.gather_nd(seq, indices)\n", + " return self.relu_layer(sequence_ends)\n", "\n", "@autograph.convert()\n", "def loss_fn(labels, predictions):\n", @@ -246,9 +262,9 @@ } }, "source": [ - "We will now create the model function for the estimator.\n", + "We will now create the model function for the 
custom Estimator.\n", "\n", - "In the model function, we simply call the converted functions that we defined above - that's it!" + "In the model function, we simply use the model class we defined above - that's it!" ] }, { @@ -275,14 +291,12 @@ " sequence_length = features['sequence_length']\n", " inputs = (chars, sequence_length)\n", "\n", - " # Create the model components.\n", - " # Simply calling the AutoGraph-ed functions and objects just works!\n", + " # Create the model. Simply using the AutoGraph-ed class just works!\n", " colorbot = RnnColorbot()\n", - " \n", - " batch_size = params['batch_size']\n", + " colorbot.build(None)\n", "\n", " if mode == tf.estimator.ModeKeys.TRAIN:\n", - " predictions = model(inputs, colorbot, batch_size, training=True)\n", + " predictions = colorbot(inputs, training=True)\n", " loss = loss_fn(labels, predictions)\n", "\n", " learning_rate = params['learning_rate']\n", @@ -292,14 +306,13 @@ " return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)\n", "\n", " elif mode == tf.estimator.ModeKeys.EVAL:\n", - " predictions = model(inputs, colorbot, batch_size, training=False)\n", + " predictions = colorbot(inputs)\n", " loss = loss_fn(labels, predictions)\n", "\n", " return tf.estimator.EstimatorSpec(mode, loss=loss)\n", - " \n", + "\n", " elif mode == tf.estimator.ModeKeys.PREDICT:\n", - " # For prediction, we expect single tensors.\n", - " predictions = model(inputs, colorbot, 1, training=False)\n", + " predictions = colorbot(inputs)\n", "\n", " predictions = tf.minimum(predictions, 1.0)\n", " return tf.estimator.EstimatorSpec(mode, predictions=predictions)" @@ -368,7 +381,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 7, "metadata": { "colab": { "autoexec": { @@ -379,9 +392,9 @@ }, "colab_type": "code", "executionInfo": { - "elapsed": 10064, + "elapsed": 10604, "status": "ok", - "timestamp": 1523580419240, + "timestamp": 1524095272039, "user": { "displayName": "", "photoUrl": "", @@ -390,7 +403,7 @@ "user_tz": 240 }, "id": "2pg1AfbxBJQq", - "outputId": "41894b16-3d3a-4e30-f6e4-5a9c837a2210", + "outputId": "9c924b4f-06e1-4538-976c-a3e1ddac5660", "slideshow": { "slide_type": "-" } @@ -400,7 +413,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Eval loss at step 100: 0.0665446\n" + "Eval loss at step 100: 0.0674834\n" ] } ], @@ -444,7 +457,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": 8, "metadata": { "colab": { "autoexec": { @@ -455,9 +468,9 @@ }, "colab_type": "code", "executionInfo": { - "elapsed": 31286, + "elapsed": 7990, "status": "ok", - "timestamp": 1523580450579, + "timestamp": 1524095280105, "user": { "displayName": "", "photoUrl": "", @@ -466,7 +479,7 @@ "user_tz": 240 }, "id": "dxHex2tUN_10", - "outputId": "b3dc558d-b800-4e9b-e60e-3441124e80d8", + "outputId": "2b889e5a-b9ed-4645-bf03-d98f26c72101", "slideshow": { "slide_type": "slide" } @@ -478,7 +491,7 @@ "\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f4112527e90\u003e" + "\u003cIPython.core.display.HTML at 0x7f3f36aa6cd0\u003e" ] }, "metadata": { @@ -494,7 +507,7 @@ "\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f4112527f10\u003e" + "\u003cIPython.core.display.HTML at 0x7f3eca67f7d0\u003e" ] }, "metadata": { @@ -510,7 +523,7 @@ "\u003cdiv id=\"id1\"\u003e\u003c/div\u003e" ], 
"text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f4112527f50\u003e" + "\u003cIPython.core.display.HTML at 0x7f3eca67f8d0\u003e" ] }, "metadata": { @@ -523,11 +536,11 @@ { "data": { "application/javascript": [ - "window[\"2c60f474-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"initialSelection\": 0, \"location\": \"top\", \"contentHeight\": [\"initial\"], \"borderColor\": [\"#a7a7a7\"], \"contentBorder\": [\"0px\"], \"tabNames\": [\"RNN Colorbot\"], \"elementId\": \"id1\"});\n", - "//# sourceURL=js_a0db480422" + "window[\"e8ddfa22-4362-11e8-91ec-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id1\", \"borderColor\": [\"#a7a7a7\"], \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0});\n", + "//# sourceURL=js_71b9087b6d" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd1d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67f950\u003e" ] }, "metadata": { @@ -540,11 +553,11 @@ { "data": { "application/javascript": [ - "window[\"2c60f475-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_d2a46ea291" + "window[\"e8ddfa23-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_e390445f33" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd0d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" ] }, "metadata": { @@ -557,11 +570,11 @@ { "data": { "application/javascript": [ - "window[\"2c60f476-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_0a8262c6e9" + "window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_241dd76d85" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd390\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" ] }, "metadata": { @@ -575,11 +588,11 @@ { "data": { "application/javascript": [ - "window[\"2c60f477-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_e32f85ccd2" + "window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_60c64e3d50" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd490\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" ] }, "metadata": { @@ -593,11 +606,11 @@ { "data": { "application/javascript": [ - "window[\"2c60f478-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"2c60f477-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_eaee748b21" + "window[\"e8ddfa26-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa25-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_14ea437cbd" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd550\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" ] }, "metadata": { @@ -611,11 +624,11 @@ { "data": { "application/javascript": [ - "window[\"2c60f479-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_2befe06587" + "window[\"e8ddfa27-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_09294c2226" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f4112527f10\u003e" + "\u003cIPython.core.display.Javascript 
at 0x7f3eca67fcd0\u003e" ] }, "metadata": { @@ -629,11 +642,11 @@ { "data": { "application/javascript": [ - "window[\"354d7b1a-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"2c60f476-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_8ec4aeeb25" + "window[\"ec965514-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"e8ddfa24-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_e5e8266997" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd690\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67fe10\u003e" ] }, "metadata": { @@ -647,11 +660,11 @@ { "data": { "application/javascript": [ - "window[\"354d7b1b-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_9f9f4574f1" + "window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_07a097f0ee" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd350\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67fc90\u003e" ] }, "metadata": { @@ -665,11 +678,11 @@ { "data": { "application/javascript": [ - "window[\"354d7b1c-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_bcccd8f300" + "window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_790d669ca8" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd6d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67f8d0\u003e" ] }, "metadata": { @@ -683,11 +696,11 @@ { "data": { "application/javascript": [ - "window[\"354d7b1d-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"354d7b1c-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_2c056cee72" + "window[\"ec965517-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965516-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_d30df771f0" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd490\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67fd90\u003e" ] }, "metadata": { @@ -701,11 +714,11 @@ { "data": { "application/javascript": [ - "window[\"354d7b1e-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_c853c3f58b" + "window[\"ec965518-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_8a43a2da4b" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd610\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67fc50\u003e" ] }, "metadata": { @@ -718,369 +731,9 @@ }, { "data": { - "application/javascript": [ - "window[\"354d7b1f-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"354d7b1b-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_e5730ab00d" - ], + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACMBJREFUeJzt3F+I1XX+x/G32zjiFERUpgaFd2JBzOg5joX4h0SiMgmM\n/uhVGIlgFBlERGB3hUEkhkRdtDfRP1ACL6KpLBqcguxCjEAkmGamQcSohFHzsxe7O6zssvsydtff\n+ns8rs758j3f8z7fiyef7/k3o7XWCiDwh4s9APC/QzCAmGAAMcEAYoIBxAQDiAkGF8XTTz9d3W63\n7rvvvhoZGakVK1Zc7JEICMYlbvXq1TU8PHyxxzjPV199VcPDw/XZZ5/V22+/XVVVM2bMuMhTkRAM\n/qt+++23+uGHH+r666+vWbNmXexxuECCcQl76qmnanx8vLZs2VIDAwP1+uuv1zfffFP3339/dTqd\nWr9+fY2MjEzvv2nTpnr55ZfrgQceqIGBgXr44Yfr5MmTVVV1+vTp2r59ey1durQ6nU5t2LChTpw4\nUVVVk5OTtWXLllq6dGmtXbu23nnnnelj7tq1q7Zt21bbt2+vJUuW1HvvvVfPPvtsHTp0qAYGBmrX\nrl1/N/fRo0dr06ZN1el06u67766hoaGqqhodHa1OpzO93zPPPFO33nrr9P3t27fXm2+++e89iZyv\ncUlbtWpVGx4ebq21NjEx0brdbjtw4EBrrbUvvviidbvdduLEidZaaxs3bmxr1qxp33//fZuammob\nN25sO3fubK219tZbb7VHH320TU1NtXPnzrXDhw+3X375pbXW2kMPPdR27NjRTp8+3Y4cOdIGBwen\nn/OVV15pN910U/voo49aa61NTU21999/vz344IPTMx48eLCtWLGitdbamTNn2po1a9qePXvamTNn\n2vDwcOvv72/Hjh2bfj2HDx9urbW2du3advvtt7ejR4+21lpbuXJlO3LkyH/qVNJas8L4f6D95edC\n+/btq5UrV9by5curqmrZsmV1880316effjq977333ls33HBD9fb21h133FFHjhypqqqenp46efJk\nHTt2rGbMmFGLFi2qyy+/vCYmJurrr7+uJ598smbOnFkLFy6sDRs21N69e6eP2d/fX6tXr66qqt7e\n3n8666FDh+rUqVP1yCOPVE9PTw0ODtaqVavqgw8+qKqqJUuW1MjISB0/fryqqtauXVtffvlljY6O\n1q+//loLFy78N501/pGeiz0A/z1jY2O1f//++vjjj6vqzyE5e/ZsLVu2bHqfa665Zvr27Nmz69Sp\nU1VVdc8999TExEQ98cQT9fPPP9e6devq8ccfr8nJybryyitr9uzZ04+bP39+HT58ePr+3Llz4xkn\nJydr3rx5522bP39+TU5OVlVVp9OpoaGhuu6666rb7Va32629e/dWb29vLV68+ALOBr+HYFzi/vbT\nh3nz5tX69etrx44dF3ycnp6e2rp1a23durXGxsZq8+bNtWDBgrrtttvqp59+qlOnTlVfX19VVY2P\nj9ecOXP+4Qz/ypw5c2p8fPy8bWNjY7VgwYKqqup2u/Xiiy/WvHnzqtPp1MDAQD333HPV29tb3W73\ngl8XF8YlySXu2muvrdHR0aqqWrduXQ0NDdXnn39e586dq6mpqRoZGakff/zxXx7n4MGD9d1339W5\nc+eqr6+venp66rLLLqu5c+dWf39/vfTSS3X69On69ttv6913361169b9rnlvueWW6uvrq9dee63O\nnj1bBw8erE8++aTuvPPOqqq68cYba9asWbVv377qdDp1xRVX1NVXX10ffvjheW+I8p8hGJe4zZs3\n1+7du6vb7db+/ftr9+7dtWfPnlq2bFmtWrWq3njjjen3OP7ZSuD48eO1bdu2Wrx4cd111121dOnS\n6Sjs3LmzRkdHa/ny5bVt27Z67LHHzrvMuRAzZ86sV199tQ4cOFCDg4P1/PPP1wsvvDC9wqj68yrj\nqquumr7U+WsoFi1a9Luek9yM1vyBDpCxwgBiggHEBAOICQYQ+z/7PYzjf/QRGVxM12z68u+2WWEA\nMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE\nBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhAT\nDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAm\nGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhg\nADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIB\nxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQ\nEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBM\nMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHB\nAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQD\niAkGEBMMIDajtdYu9hDA/wYrDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4j9CY2LTAbbRbWuAAAAAElFTkSuQmCC\n", "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41127a2050\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"354d7b20-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# 
sourceURL=js_a897ef7e24" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41127a2250\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"354d7b21-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_565fa3d154" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f4113124d90\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"354d7b22-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"354d7b21-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_222e0dc6af" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f4113124c10\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"354d7b23-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_831db7458f" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f4113124310\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3803fab4-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"354d7b20-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_adb576c6eb" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f990850\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3803fab5-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_9418f2d32f" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f990850\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3803fab6-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_3fad25f306" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f4112527ed0\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3803fab7-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3803fab6-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_45b9340e7b" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f990c90\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3803fab8-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_bec9896d44" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f990a10\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - 
"window[\"3803fab9-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3803fab5-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_460b91ad4a" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41b21d3a10\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3803faba-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_7dedd0b037" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41b21d3890\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3803fabb-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_4b1c977dc7" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41b21d3bd0\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3803fabc-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3803fabb-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_d64fedfcf9" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41b21d3410\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3803fabd-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_3e8c929c3f" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41b21d3c50\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3b9b986c-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3803faba-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_9f9cf2b76f" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd590\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3b9b986d-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_b402e6b587" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41b21d3d90\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3b9b986e-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_9b7d66db72" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41b21d3b10\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3b9b986f-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3b9b986e-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_11ec213a3f" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41b21d3950\u003e" - ] - }, - "metadata": { - 
"tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "application/javascript": [ - "window[\"3b9b9870-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_9c055e4bc0" - ], - "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41b21d3850\u003e" - ] - }, - "metadata": { - "tags": [ - "id1_content_0", - "outputarea_id1" - ] - }, - "output_type": "display_data" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAACMRJREFUeJzt3F+IlfW+x/Gvp3FECyIqU4PCO7EgZnQtnUJ0JJGoTDoY\n/dGrMBJhosggIgK7KwwiMdxRF11F/0AJvIisLBqcguxCjEAkmNQGcRvVwIzm71zsc4Yje7P3x9h7\nz97u1+tqrYdnPeu7nos3v2f9m9FaawUQ+K/pHgD49yEYQEwwgJhgADHBAGKCAcQEg2nx9NNPV7fb\nrfvuu69GRkZq5cqV0z0SAcG4xK1evbqGh4ene4wLfPXVVzU8PFyfffZZvf3221VVNWPGjGmeioRg\n8E/122+/1Q8//FDXX399zZo1a7rH4SIJxiXsqaeeqhMnTtSWLVuqv7+/Xn/99frmm2/q/vvvr06n\nU+vXr6+RkZGp/Tdt2lQvv/xyPfDAA9Xf318PP/xwnTlzpqqqJicna9u2bbVs2bLqdDq1YcOGOn36\ndFVVjY2N1ZYtW2rZsmW1du3aeuedd6aOuXPnzhoaGqpt27bV0qVL67333qtnn322Dh06VP39/bVz\n584/m/vo0aO1adOm6nQ6dffdd9f+/furqmp0dLQ6nc7Ufs8880zdeuutU/e3bdtWb7755t/3JHKh\nxiVtcHCwDQ8Pt9ZaO3nyZOt2u+3AgQOttda++OKL1u122+nTp1trrW3cuLGtWbOmff/9921iYqJt\n3Lix7dixo7XW2ltvvdUeffTRNjEx0c6fP98OHz7cfvnll9Zaaw899FDbvn17m5ycbEeOHGnLly+f\nes5XXnml3XTTTe2jjz5qrbU2MTHR3n///fbggw9OzXjw4MG2cuXK1lprZ8+ebWvWrGm7d+9uZ8+e\nbcPDw62vr68dO3Zs6vUcPny4tdba2rVr2+23396OHj3aWmtt1apV7ciRI/+oU0lrzQrjP0D7358L\n7d27t1atWlUrVqyoqqqBgYG6+eab69NPP53a9957760bbrihent764477qgjR45UVVVPT0+dOXOm\njh07VjNmzKjFixfX5ZdfXidPnqyvv/66nnzyyZo5c2YtWrSoNmzYUHv27Jk6Zl9fX61evbqqqnp7\ne//qrIcOHarx8fF65JFHqqenp5YvX16Dg4P1wQcfVFXV0qVLa2RkpE6dOlVVVWvXrq0vv/yyRkdH\n69dff61Fixb9nc4af0nPdA/AP8/x48dr37599fHHH1fVn0Jy7ty5GhgYmNrnmmuumbo9e/bsGh8f\nr6qqe+65p06ePFlPPPFE/fzzz7Vu3bp6/PHHa2xsrK688sqaPXv21OMWLFhQhw8fnro/b968eMax\nsbGaP3/+BdsWLFhQY2NjVVXV6XRq//79dd1111W3261ut1t79uyp3t7eWrJkyUWcDX4PwbjE/f9P\nH+bPn1/r16+v7du3X/Rxenp6auvWrbV169Y6fvx4bd68uRYuXFi33XZb/fTTTzU+Pl5z5sypqqoT\nJ07U3Llz/+IMf8vcuXPrxIkTF2w7fvx4LVy4sKqqut1uvfjiizV//vzqdDrV399fzz33XPX29la3\n273o18XFcUlyibv22mtrdHS0qqrWrVtX+/fvr88//7zOnz9fExMTNTIyUj/++OPfPM7Bgwfru+++\nq/Pnz9ecOXOqp6enLrvsspo3b1719fXVSy+9VJOTk/Xtt9/Wu+++W+vWrftd895yyy01Z86ceu21\n1+rcuXN18ODB+uSTT+rOO++sqqobb7yxZs2aVXv37q1Op1NXXHFFXX311fXhhx9e8IYo/xiCcYnb\nvHlz7dq1q7rdbu3bt6927dpVu3fvroGBgRocHKw33nhj6j2Ov7YSOHXqVA0NDdWSJUvqrrvuqmXL\nlk1FYceOHTU6OlorVqyooaGheuyxxy64zLkYM2fOrFdffbUOHDhQy5cvr+eff75eeOGFqRVG1Z9W\nGVddddXUpc7/hWLx4sW/6znJzWjNH+gAGSsMICYYQEwwgJhgALF/2e9h/PEP/z3dI8B/tKseee/P\ntllhADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEA\nYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOI\nCQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAm\nGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhg\nADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIB\nxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQ\nEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBM\nMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHB\nAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQD\niAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHEBAOICQYQEwwg\nJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICY\nYAAxwQBiggHEBAOICQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEwwgJhgADHBAGKC\nA
cQEA4gJBhATDCA2o7XWpnsI4N+DFQYQEwwgJhhATDCAmGAAMcEAYoIBxAQDiAkGEBMMICYYQEww\ngJhgADHBAGKCAcQEA4gJBhATDCAmGEBMMICYYAAxwQBiggHE/gfh60wGjfc7LQAAAABJRU5ErkJg\ngg==\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f4113124310\u003e" + "\u003cmatplotlib.figure.Figure at 0x7f3ecc00bf10\u003e" ] }, "metadata": { @@ -1095,11 +748,11 @@ { "data": { "application/javascript": [ - "window[\"3b9b9871-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3b9b986d-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_ba6a061307" + "window[\"ec965519-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec965515-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_893ad561f4" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd890\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3f31b55c90\u003e" ] }, "metadata": { @@ -1113,11 +766,11 @@ { "data": { "application/javascript": [ - "window[\"3b9b9872-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", - "//# sourceURL=js_83e3496927" + "window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n", + "//# sourceURL=js_2d99e0ac17" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd590\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67fe50\u003e" ] }, "metadata": { @@ -1131,11 +784,11 @@ { "data": { "application/javascript": [ - "window[\"3b9b9873-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", - "//# sourceURL=js_f437bab20d" + "window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"] = document.querySelector(\"#id1_content_0\");\n", + "//# sourceURL=js_5c19462e32" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41127a22d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3f31b55dd0\u003e" ] }, "metadata": { @@ -1149,11 +802,11 @@ { "data": { "application/javascript": [ - "window[\"3b9b9874-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3b9b9873-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_93aa63450e" + "window[\"ec96551c-4362-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"ec96551b-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_b9c8b7567b" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41127a2b90\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3f31b55a50\u003e" ] }, "metadata": { @@ -1167,11 +820,11 @@ { "data": { "application/javascript": [ - "window[\"3b9b9875-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", - "//# sourceURL=js_aca189bea5" + "window[\"ec96551d-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"id1\"].setSelectedTabIndex(0);\n", + "//# sourceURL=js_fd05186348" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd4d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3f31b55810\u003e" ] }, "metadata": { @@ -1185,10 +838,10 @@ { "data": { "text/html": [ - "\u003cdiv class=id_100313201 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" + "\u003cdiv class=id_888646481 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e" ], "text/plain": [ - "\u003cIPython.core.display.HTML at 0x7f410f990a90\u003e" + "\u003cIPython.core.display.HTML at 0x7f3f32414810\u003e" ] }, 
"metadata": { @@ -1203,11 +856,11 @@ { "data": { "application/javascript": [ - "window[\"3b9b9876-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_100313201 span\");\n", - "//# sourceURL=js_5df1fe383e" + "window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", + "//# sourceURL=js_efef96e882" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f410f8fd490\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" ] }, "metadata": { @@ -1222,11 +875,11 @@ { "data": { "application/javascript": [ - "window[\"3b9b9877-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"3b9b9876-3eb4-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", - "//# sourceURL=js_c62c7174ad" + "window[\"ec96551f-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ec96551e-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_6eca889864" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41127a2390\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3eca67f990\u003e" ] }, "metadata": { @@ -1241,11 +894,11 @@ { "data": { "application/javascript": [ - "window[\"3ed76584-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_100313201 input\");\n", - "//# sourceURL=js_2e2201ddc4" + "window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 input\");\n", + "//# sourceURL=js_f02070cc60" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41127a2810\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3f31b553d0\u003e" ] }, "metadata": { @@ -1260,11 +913,11 @@ { "data": { "application/javascript": [ - "window[\"3ed76585-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"3ed76584-3eb4-11e8-91ec-c8d3ffb5fbe0\"].remove();\n", - "//# sourceURL=js_288e5283d6" + "window[\"ed8ea973-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea972-4362-11e8-91ec-c8d3ffb5fbe0\"].remove();\n", + "//# sourceURL=js_ed9faba660" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41127a26d0\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3f31a95450\u003e" ] }, "metadata": { @@ -1279,11 +932,11 @@ { "data": { "application/javascript": [ - "window[\"3ed76586-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_100313201 span\");\n", - "//# sourceURL=js_2f31d19cde" + "window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"] = jQuery(\".id_888646481 span\");\n", + "//# sourceURL=js_f3458d7074" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41127a2fd0\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3f31a95250\u003e" ] }, "metadata": { @@ -1298,11 +951,11 @@ { "data": { "application/javascript": [ - "window[\"3ed76587-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = window[\"3ed76586-3eb4-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", - "//# sourceURL=js_2fbbcda050" + "window[\"ed8ea975-4362-11e8-91ec-c8d3ffb5fbe0\"] = window[\"ed8ea974-4362-11e8-91ec-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n", + "//# sourceURL=js_3ffd97bd6f" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f4112527e90\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3f31a953d0\u003e" ] }, "metadata": { @@ -1317,11 +970,11 @@ { "data": { "application/javascript": [ - "window[\"3ed76588-3eb4-11e8-91ec-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"3b9b9872-3eb4-11e8-91ec-c8d3ffb5fbe0\"]);\n", - "//# sourceURL=js_f94d975cf3" + "window[\"ed8ea976-4362-11e8-91ec-c8d3ffb5fbe0\"] = 
google.colab.output.setActiveOutputArea(window[\"ec96551a-4362-11e8-91ec-c8d3ffb5fbe0\"]);\n", + "//# sourceURL=js_7f73e8bcca" ], "text/plain": [ - "\u003cIPython.core.display.Javascript at 0x7f41127a2fd0\u003e" + "\u003cIPython.core.display.Javascript at 0x7f3f31b55710\u003e" ] }, "metadata": { @@ -1337,7 +990,7 @@ "def predict_input_fn(color_name):\n", " \"\"\"An input function for prediction.\"\"\"\n", " _, chars, sequence_length = parse(color_name)\n", - " \n", + "\n", " # We create a batch of a single element.\n", " features = {\n", " 'chars': tf.expand_dims(chars, 0),\n", @@ -1385,7 +1038,11 @@ "colab": { "collapsed_sections": [], "default_view": {}, - "name": "RNN Colorbot using Estimators", + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "name": "RNN Colorbot using Keras and Estimators", "provenance": [ { "file_id": "1CtzefX39ffFibX_BqE6cRbT0UW_DdVKl", diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index 3c3130c7702..24f87b2c14d 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -241,7 +241,7 @@ def to_graph(e, module = gast.Module([]) for import_line in config.COMPILED_IMPORT_STATEMENTS: module.body.extend(parser.parse_str(import_line).body) - for dep in conversion_map.dependency_cache.values(): + for dep in reversed(conversion_map.dependency_cache.values()): module.body.append(dep) compiled_node, compiled_src = compiler.ast_to_object(module) diff --git a/tensorflow/contrib/autograph/impl/config.py b/tensorflow/contrib/autograph/impl/config.py index 2600088595a..878bb7e12f2 100644 --- a/tensorflow/contrib/autograph/impl/config.py +++ b/tensorflow/contrib/autograph/impl/config.py @@ -33,7 +33,7 @@ DEFAULT_UNCOMPILED_MODULES = set(( (utils.__name__,), # All of tensorflow's subpackages. Unlike the root tf module, they don't - # have well-known names. Not refering to the module directly to avoid + # have well-known names. Not referring to the module directly to avoid # circular imports. ( utils.__name__[:-len('.contrib.autograph.utils')],), diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index bcf31b8961e..55a30dc1279 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -18,6 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import imp + import gast from tensorflow.contrib.autograph import operators @@ -37,6 +40,7 @@ from tensorflow.contrib.autograph.converters import side_effect_guards from tensorflow.contrib.autograph.converters import single_return from tensorflow.contrib.autograph.impl import config from tensorflow.contrib.autograph.impl import naming +from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.contrib.autograph.pyct import parser @@ -79,7 +83,9 @@ class ConversionMap(object): self.recursive = recursive self.nocompile_decorators = nocompile_decorators self.partial_types = partial_types if partial_types else () - self.dependency_cache = {} + # Required to output dependencies in discovery order, which should match + # the reverse dependency order. 
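+    # (An OrderedDict iterates in insertion order, so to_graph can walk the
+    # cached values in reverse and emit each converted callee before the
+    # functions that call it.)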
+ self.dependency_cache = collections.OrderedDict() self.additional_imports = set() self.name_map = {} self.api_module = api_module @@ -152,7 +158,16 @@ def entity_to_graph(o, conversion_map, arg_values, arg_types): if tf_inspect.isclass(o): node, name, ns = class_to_graph(o, conversion_map) elif tf_inspect.isfunction(o): - node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types) + # TODO(mdan): This is not a reliable mechanism. + # The most reliable way is to check the source code, the AST will contain + # a Lambda node instead of a FunctionDef + if o.__name__ == '': + raise NotImplementedError( + 'lambda functions are not yet supported; declare the function' + ' using def instead: %s' % o) + else: + node, name, ns = function_to_graph(o, conversion_map, arg_values, + arg_types) elif tf_inspect.ismethod(o): node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types) else: @@ -190,6 +205,9 @@ def class_to_graph(c, conversion_map): class_namespace = {} for _, m in members: + # Only convert the members that are directly defined by the class. + if inspect_utils.getdefiningclass(m, c) is not c: + continue node, _, namespace = function_to_graph( m, conversion_map=conversion_map, @@ -203,12 +221,49 @@ def class_to_graph(c, conversion_map): converted_members[m] = node namer = conversion_map.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) - node = gast.ClassDef( - class_name, - bases=[], - keywords=[], - body=list(converted_members.values()), - decorator_list=[]) + + # TODO(mdan): This needs to be explained more thoroughly. + # Process any base classes: if the sueprclass if of a whitelisted type, an + # absolute import line is generated. Otherwise, it is marked for conversion + # (as a side effect of the call to namer.compiled_class_name() followed by + # conversion_map.update_name_map(namer)). + output_nodes = [] + renames = {} + bases = [] + for base in c.__bases__: + if isinstance(object, base): + bases.append('object') + continue + if is_whitelisted_for_graph(base): + alias = namer.new_symbol(base.__name__, ()) + output_nodes.append( + gast.ImportFrom( + module=base.__module__, + names=[gast.alias(name=base.__name__, asname=alias)], + level=0)) + else: + # This will trigger a conversion into a class with this name. + alias = namer.compiled_class_name(base.__name__, base) + bases.append(alias) + renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) + conversion_map.update_name_map(namer) + + # Generate the definition of the converted class. + output_nodes.append( + gast.ClassDef( + class_name, + bases=bases, + keywords=[], + body=list(converted_members.values()), + decorator_list=[])) + node = gast.Module(output_nodes) + + # Make a final pass to replace references to the class or its base classes. + # Most commonly, this occurs when making super().__init__() calls. + # TODO(mdan): Making direct references to superclass' superclass will fail. + node = qual_names.resolve(node) + renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name) + node = ast_util.rename_symbols(node, renames) return node, class_name, class_namespace @@ -220,13 +275,24 @@ def _add_reserved_symbol(namespace, name, entity): raise ValueError('The name "%s" is reserved and may not be used.' % name) +ag_internal = None + + def _add_self_references(namespace, api_module): - # Manually add the utils namespace which may be used from generated code. 
- _add_reserved_symbol(namespace, 'autograph_utils', utils) - _add_reserved_symbol(namespace, '__ops', operators) - # We also make reference to the api module for dynamic conversion, but - # to avoid circular references we don't import it here. - _add_reserved_symbol(namespace, 'autograph_api', api_module) + """Adds namespace references to the module that exposes the api itself.""" + global ag_internal + if ag_internal is None: + # Craft a module that exposes parts of the external API as well as certain + # internal modules. + ag_internal = imp.new_module('autograph') + ag_internal.converted_call = api_module.converted_call + ag_internal.utils = utils + # TODO(mdan): Add safeguards against name clashes. + # We don't want to create a submodule because we want the operators to be + # accessible as ag__. + ag_internal.__dict__.update(operators.__dict__) + + _add_reserved_symbol(namespace, 'ag__', ag_internal) def function_to_graph(f, conversion_map, arg_values, arg_types, @@ -312,6 +378,8 @@ def node_to_graph(node, ctx, nocompile_decorators): node = ifexp.transform(node, ctx) node, deps = decorators.transform(node, nocompile_decorators) node = break_statements.transform(node, ctx) + node = _static_analysis_pass(node, ctx) + node = asserts.transform(node, ctx) # Note: sequencing continue canonicalization before for loop one avoids diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index 962009c71f5..bc61498b542 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -21,13 +21,18 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.impl import api from tensorflow.contrib.autograph.impl import conversion from tensorflow.python.framework import constant_op +from tensorflow.python.keras.engine import training from tensorflow.python.platform import test class ConversionTest(test.TestCase): + def _simple_conversion_map(self): + return conversion.ConversionMap(True, (), (), api) + def test_is_whitelisted_for_graph(self): def test_fn(): @@ -39,7 +44,7 @@ class ConversionTest(test.TestCase): def test_entity_to_graph_unsupported_types(self): with self.assertRaises(ValueError): - conversion_map = conversion.ConversionMap(True, (), (), None) + conversion_map = self._simple_conversion_map() conversion.entity_to_graph('dummy', conversion_map, None, None) def test_entity_to_graph_callable(self): @@ -47,7 +52,7 @@ class ConversionTest(test.TestCase): def f(a): return a + b - conversion_map = conversion.ConversionMap(True, (), (), None) + conversion_map = self._simple_conversion_map() ast, name, ns = conversion.entity_to_graph(f, conversion_map, None, None) self.assertTrue(isinstance(ast, gast.FunctionDef), ast) self.assertEqual('tf__f', name) @@ -61,7 +66,7 @@ class ConversionTest(test.TestCase): def f(a): return g(a) - conversion_map = conversion.ConversionMap(True, (), (), None) + conversion_map = self._simple_conversion_map() conversion.entity_to_graph(f, conversion_map, None, None) self.assertTrue(f in conversion_map.dependency_cache) @@ -74,6 +79,87 @@ class ConversionTest(test.TestCase): conversion_map.dependency_cache[f].body[0].body[0].value.func.id) self.assertEqual('tf__g', conversion_map.dependency_cache[g].name) + def test_entity_to_graph_class_hierarchy(self): + + class TestBase(object): + + def __init__(self, x='base'): + self.x = x + + def foo(self): + return self.x + + def 
bar(self): + return self.x + + class TestSubclass(TestBase): + + def __init__(self, y): + super(TestSubclass, self).__init__('sub') + self.y = y + + def foo(self): + return self.y + + def baz(self): + return self.y + + conversion_map = self._simple_conversion_map() + conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + + self.assertTrue(TestBase in conversion_map.dependency_cache) + self.assertTrue(TestSubclass in conversion_map.dependency_cache) + self.assertEqual('TfTestBase', + conversion_map.dependency_cache[TestBase].body[-1].name) + self.assertEqual( + 'TfTestSubclass', + conversion_map.dependency_cache[TestSubclass].body[-1].name) + + def test_entity_to_graph_class_hierarchy_whitelisted(self): + + class TestSubclass(training.Model): + + def __init__(self, y): + super(TestSubclass, self).__init__() + self.built = False + + def call(self, x): + return 3 * x + + conversion_map = self._simple_conversion_map() + conversion.entity_to_graph(TestSubclass, conversion_map, None, None) + + self.assertTrue(TestSubclass in conversion_map.dependency_cache) + self.assertFalse(training.Model in conversion_map.dependency_cache) + self.assertEqual( + 'Model', + conversion_map.dependency_cache[TestSubclass].body[0].names[0].name) + self.assertEqual( + 'TfTestSubclass', + conversion_map.dependency_cache[TestSubclass].body[-1].name) + + def test_entity_to_graph_lambda(self): + f = lambda a: a + + with self.assertRaises(NotImplementedError): + conversion_map = self._simple_conversion_map() + conversion.entity_to_graph(f, conversion_map, None, None) + + def test_ag_module_cached(self): + def callee(): + return range(3) + + def caller(a): + return a() + + conversion_map = self._simple_conversion_map() + _, _, callee_ns = conversion.entity_to_graph( + callee, conversion_map, None, None) + _, _, caller_ns = conversion.entity_to_graph( + caller, conversion_map, None, None) + + self.assertTrue(callee_ns['ag__'] is caller_ns['ag__']) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/impl/naming.py b/tensorflow/contrib/autograph/impl/naming.py index 1facaa0ca0e..b1d3f76be77 100644 --- a/tensorflow/contrib/autograph/impl/naming.py +++ b/tensorflow/contrib/autograph/impl/naming.py @@ -62,8 +62,6 @@ class Namer(object): n += 1 new_name = '%s_%d' % (new_name_root, n) - if live_entity is not None: - self.renamed_calls[live_entity] = new_name self.generated_names.add(new_name) if live_entity is not None: self.renamed_calls[live_entity] = new_name diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 4c624685751..18bfec5d9c6 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -21,11 +21,25 @@ py_library( srcs = [ "__init__.py", "control_flow.py", + "data_structures.py", + "dispatch_context.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ "//tensorflow/contrib/autograph/utils", + "//tensorflow/python:tensor_array_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +py_test( + name = "data_structures_test", + srcs = ["data_structures_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index 04b4734551d..38b761d97d5 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py 
@@ -19,11 +19,19 @@ conditionals and loops, implemented in functional form, using for example closures for the body. """ +# Naming conventions: +# * operator names match the name usually used for the respective Python +# idiom; examples: for_stmt, list_append +# * operator arguments match either of: +# - the corresponding Python AST attribute (e.g. the condition of an if +# statement is called test) if the operator represents an AST construct +# - the names used in the Python docs, if the operator is a function (e.g. +# list_ and x for append, see +# https://docs.python.org/3.7/tutorial/datastructures.html) + from __future__ import absolute_import from __future__ import division from __future__ import print_function -# TODO(mdan): Add a container for implementation-specific toggles (throughout). - -from tensorflow.contrib.autograph.operators.control_flow import for_loop -from tensorflow.contrib.autograph.operators.control_flow import while_loop +from tensorflow.contrib.autograph.operators.control_flow import for_stmt +from tensorflow.contrib.autograph.operators.control_flow import while_stmt diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 81ae64f1109..671c9ccc13e 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -26,40 +26,54 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_math_ops -def for_loop(iterated, extra_cond, loop_body, init_state): +def for_stmt(iter_, extra_test, body, init_state): """Functional form of a for statement. - The loop operates on a so-called state, which includes all symbols that are - variant across loop iterations, excluding the iterate. In what follows we - refer to state as either a tuple of entities that represent an actual state, - or a list of arguments of the corresponding types. + The loop operates on a state, which includes all symbols that are + variant across loop iterations, excluding the iterate as well as the + variables local to the loop. + + For example, given the loop below that calculates the geometric and + arithmetic means of some numbers: + + geo_mean = 1 + arith_mean = 0 + for i in range(n): + a = numbers[i] + geo_mean *= a + arith_mean += a + + The state is represented by the variables geo_mean and arith_mean. The + argument init_state may then contain the tuple (1, 0); the body will + take geo_mean and arith_mean as arguments and will return a tuple + with the new values for geo_mean and arith_mean, respectively. Args: - iterated: The entity being iterated over. - extra_cond: Callable with the state as arguments, and boolean return type. + iter_: The entity being iterated over. + extra_test: Callable with the state as arguments, and boolean return type. An additional loop condition. - loop_body: Callable with the iterate and the state as arguments, and + body: Callable with the iterate and the state as arguments, and state as return type. The actual loop body. init_state: Tuple containing the initial state. Returns: Tuple containing the final state.
""" - if tensor_util.is_tensor(iterated): - return _known_len_for_loop(iterated, extra_cond, loop_body, init_state) - elif isinstance(iterated, dataset_ops.Dataset): - return _dataset_for_loop(iterated, extra_cond, loop_body, init_state) + if tensor_util.is_tensor(iter_): + return _known_len_for_stmt(iter_, extra_test, body, init_state) + elif isinstance(iter_, dataset_ops.Dataset): + return _dataset_for_stmt(iter_, extra_test, body, init_state) else: - return _py_for_loop(iterated, extra_cond, loop_body, init_state) + return _py_for_stmt(iter_, extra_test, body, init_state) -def _py_for_loop(iterated, extra_cond, loop_body, init_state): - """Overload of for_loop that executes a Python for loop.""" +def _py_for_stmt(iter_, extra_test, body, init_state): + """Overload of for_stmt that executes a Python for loop.""" state = init_state - for iterate in iterated: - if not extra_cond(*state): + for target in iter_: + if not extra_test(*state): break - state = loop_body(iterate, *state) + state = body(target, *state) # TODO(mdan): Remove this special case. if len(state) == 1: @@ -67,23 +81,23 @@ def _py_for_loop(iterated, extra_cond, loop_body, init_state): return state -def _known_len_for_loop(iterated, extra_cond, loop_body, init_state): - """Overload of for_loop that iterates over objects that define a length.""" - n = builtins.dynamic_len(iterated) +def _known_len_for_stmt(iter_, extra_test, body, init_state): + """Overload of for_stmt that iterates over objects that define a length.""" + n = builtins.dynamic_len(iter_) def while_body(iterate_index, *state): - iterate = iterated[iterate_index] - new_state = loop_body(iterate, *state) + iterate = iter_[iterate_index] + new_state = body(iterate, *state) return (iterate_index + 1,) + new_state def while_cond(iterate_index, *state): - return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state)) + return gen_math_ops.logical_and(iterate_index < n, extra_test(*state)) - results = while_loop( + results = while_stmt( while_cond, while_body, init_state=(0,) + init_state, - extra_deps=(iterated,), + extra_deps=(iter_,), opts=dict(maximum_iterations=n)) # Dropping the iteration index because it's not syntactically visible. 
results = results[1:] @@ -94,8 +108,8 @@ def _known_len_for_loop(iterated, extra_cond, loop_body, init_state): return results -def _dataset_for_loop(ds, extra_cond, loop_body, init_state): - """Overload of for_loop that iterates over TF Datasets.""" +def _dataset_for_stmt(ds, extra_test, body, init_state): + """Overload of for_stmt that iterates over TF Datasets.""" # Because Datsets only expose get_next, in the style of Python iterators, # we are forced to unpack the loop as: # @@ -114,15 +128,15 @@ def _dataset_for_loop(ds, extra_cond, loop_body, init_state): epoch_number, iterate = iterator.get_next() def while_body(epoch_number, iterate, *state): - new_state = loop_body(iterate, *state) + new_state = body(iterate, *state) epoch_number, iterate = iterator.get_next() return (epoch_number, iterate) + new_state def while_cond(epoch_number, iterate, *state): del iterate - return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state)) + return gen_math_ops.logical_and(epoch_number < 1, extra_test(*state)) - results = while_loop( + results = while_stmt( while_cond, while_body, init_state=(epoch_number, iterate) + init_state, @@ -137,7 +151,7 @@ def _dataset_for_loop(ds, extra_cond, loop_body, init_state): return results -def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None): +def while_stmt(test, body, init_state, extra_deps, opts=None): """Functional form of a while statement. The loop operates on a so-called state, which includes all symbols that are @@ -146,13 +160,13 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None): of the corresponding types. Args: - loop_cond: Callable with the state as arguments, and boolean return type. + test: Callable with the state as arguments, and boolean return type. The loop condition. - loop_body: Callable with the state as arguments, and state as return type. + body: Callable with the state as arguments, and state as return type. The actual loop body. init_state: Tuple containing the initial state. extra_deps: Tuple containing additional entities on which the loop may - depend, such as loop invariants referenced by loop_cond. Used + depend, such as loop invariants referenced by test. Used exclusively for dispatch control. opts: Optional dict of extra loop parameters. @@ -160,25 +174,54 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None): Tuple containing the final state. """ # TODO(mdan): Consider adding a generic mechanism for dynamic dispatch. - # That could be somethins as simple as a collection of dispatch rules, with + # That could be something as simple as a collection of dispatch rules, with # some prioritization. 
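Before the dispatch below, a short sketch (mirroring the updated tests later in this patch, not part of the change) of the pure-Python path: since neither the state nor extra_deps contain tensors, _py_while_stmt is selected.

from tensorflow.contrib.autograph.operators import control_flow

n = 5
i, s = control_flow.while_stmt(
    test=lambda i, s: i < n,           # loop condition over the state
    body=lambda i, s: (i + 1, s + i),  # returns the updated state
    init_state=(0, 0),
    extra_deps=(n,))
print(i, s)  # 5 10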
if any(tensor_util.is_tensor(v) for v in init_state + extra_deps): - return _tf_while_loop(loop_cond, loop_body, init_state, opts) + return _tf_while_stmt(test, body, init_state, opts) else: - return _py_while_loop(loop_cond, loop_body, init_state, opts) + return _py_while_stmt(test, body, init_state, opts) -def _tf_while_loop(loop_cond, loop_body, init_state, opts): - """Overload of while_loop that stages a TF while_loop.""" +def _tf_while_stmt(test, body, init_state, opts): + """Overload of while_stmt that stages a TF while_stmt.""" if opts is None: opts = {} - return control_flow_ops.while_loop(loop_cond, loop_body, init_state, **opts) + return control_flow_ops.while_loop(test, body, init_state, **opts) -def _py_while_loop(loop_cond, loop_body, init_state, opts): - """Overload of while_loop that executes a Python while loop.""" +def _py_while_stmt(test, body, init_state, opts): + """Overload of while_stmt that executes a Python while loop.""" del opts state = init_state - while loop_cond(*state): - state = loop_body(*state) + while test(*state): + state = body(*state) return state + + +def if_stmt(cond, body, orelse): + """Functional form of an if statement. + + Args: + cond: Boolean. + body: Callable with no arguments, and outputs of the positive (if) branch + as return type. + orelse: Callable with no arguments, and outputs of the negative (else) + branch as return type. + + Returns: + Tuple containing the statement outputs. + """ + if tensor_util.is_tensor(cond): + return _tf_if_stmt(cond, body, orelse) + else: + return _py_if_stmt(cond, body, orelse) + + +def _tf_if_stmt(cond, body, orelse): + """Overload of if_stmt that stages a TF cond.""" + return control_flow_ops.cond(cond, body, orelse) + + +def _py_if_stmt(cond, body, orelse): + """Overload of if_stmt that executes a Python if statement.""" + return body() if cond else orelse() diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py index 9112b1627fc..b14d7edba38 100644 --- a/tensorflow/contrib/autograph/operators/control_flow_test.py +++ b/tensorflow/contrib/autograph/operators/control_flow_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph import operators +from tensorflow.contrib.autograph.operators import control_flow from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,28 +29,28 @@ from tensorflow.python.platform import test class ForLoopTest(test.TestCase): def test_tensor(self): - s = operators.for_loop( + s = control_flow.for_stmt( constant_op.constant([1, 2, 3, 4]), - extra_cond=lambda s: True, - loop_body=lambda i, s: (s + i,), + extra_test=lambda s: True, + body=lambda i, s: (s + i,), init_state=(0,)) with self.test_session() as sess: self.assertEqual((10,), sess.run(s)) def test_python(self): - s = operators.for_loop( + s = control_flow.for_stmt( range(5), - extra_cond=lambda s: True, - loop_body=lambda i, s: (s + i,), + extra_test=lambda s: True, + body=lambda i, s: (s + i,), init_state=(0,)) self.assertEqual(10, s) def test_dataset(self): to_int32 = lambda i: math_ops.cast(i, dtypes.int32) - s = operators.for_loop( + s = control_flow.for_stmt( dataset_ops.Dataset.range(5).map(to_int32), - extra_cond=lambda s: True, - loop_body=lambda i, s: (s + i,), + extra_test=lambda s: True, + body=lambda i, s: (s + i,), 
init_state=(0,)) with self.test_session() as sess: self.assertEqual((10,), sess.run(s)) @@ -60,9 +60,9 @@ class WhileLoopTest(test.TestCase): def test_tensor(self): n = constant_op.constant(5) - results = operators.while_loop( - loop_cond=lambda i, s: i < n, - loop_body=lambda i, s: (i + 1, s + i,), + results = control_flow.while_stmt( + test=lambda i, s: i < n, + body=lambda i, s: (i + 1, s + i,), init_state=(0, 0), extra_deps=(n,)) with self.test_session() as sess: @@ -70,13 +70,30 @@ class WhileLoopTest(test.TestCase): def test_python(self): n = 5 - results = operators.while_loop( - loop_cond=lambda i, s: i < n, - loop_body=lambda i, s: (i + 1, s + i), + results = control_flow.while_stmt( + test=lambda i, s: i < n, + body=lambda i, s: (i + 1, s + i), init_state=(0, 0), extra_deps=(n,)) self.assertEqual((5, 10), results) +class IfStmtTest(test.TestCase): + + def test_tensor(self): + def test_if_stmt(cond): + return control_flow.if_stmt( + cond=cond, + body=lambda: 1, + orelse=lambda: -1) + with self.test_session() as sess: + self.assertEqual(1, sess.run(test_if_stmt(constant_op.constant(True)))) + self.assertEqual(-1, sess.run(test_if_stmt(constant_op.constant(False)))) + + def test_python(self): + self.assertEqual(1, control_flow.if_stmt(True, lambda: 1, lambda: -1)) + self.assertEqual(-1, control_flow.if_stmt(False, lambda: 1, lambda: -1)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py new file mode 100644 index 00000000000..c862306baa9 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/data_structures.py @@ -0,0 +1,56 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators specific to data structures: list append, subscripts, etc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import tensor_array_ops + +# TODO(mdan): Add support for TensorList once functional. +# TODO(mdan): Add primitives for empty list, list with elements. + + +def append(target, element): + """The list append function. + + Note: it is unspecified where target will be mutated or not. If target is + a TensorFlow entity, it will not be typically mutated. If target is a plain + list, it will be. In general, if the target is mutated then the return value + should point to the original entity. + + Args: + target: An entity that supports append semantics. + element: The element to append. + + Returns: + Same as target, after the append was performed. 
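The dispatch that follows picks between a TensorArray write and an in-place Python list append; a usage sketch (mine, mirroring the new data_structures_test below):

from tensorflow.contrib.autograph.operators import data_structures
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import tensor_array_ops

# Plain Python list: mutated in place, and also returned for uniformity.
l = data_structures.append([], 1)
l = data_structures.append(l, 2)   # l == [1, 2]

# TensorArray: never mutated; every call returns a new TensorArray value.
ta = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
ta = data_structures.append(ta, 1)
ta = data_structures.append(ta, 2)  # ta.stack() evaluates to [1, 2] in a session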
+ """ + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_append(target, element) + else: + return _py_append(target, element) + + +def _tf_tensorarray_append(target, element): + """Overload of append that stages a TensorArray write at the last position.""" + return target.write(target.size(), element) + + +def _py_append(target, element): + """Overload of append that executes a Python list append.""" + target.append(element) + return target diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py new file mode 100644 index 00000000000..577d28c34da --- /dev/null +++ b/tensorflow/contrib/autograph/operators/data_structures_test.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for data_structures module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import data_structures +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test + + +class AppendTest(test.TestCase): + + def test_tf_tensorarray(self): + l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) + l1 = data_structures.append(l, 1) + l2 = data_structures.append(l1, 2) + with self.test_session() as sess: + self.assertAllEqual(sess.run(l1.stack()), [1]) + self.assertAllEqual(sess.run(l2.stack()), [1, 2]) + + def test_python(self): + l = [] + self.assertAllEqual(data_structures.append(l, 1), [1]) + self.assertAllEqual(data_structures.append(l, 2), [1, 2]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/applications/densenet/__init__.py b/tensorflow/contrib/autograph/operators/dispatch_context.py similarity index 60% rename from tensorflow/python/keras/applications/densenet/__init__.py rename to tensorflow/contrib/autograph/operators/dispatch_context.py index 6b8ea839207..097002465bd 100644 --- a/tensorflow/python/keras/applications/densenet/__init__.py +++ b/tensorflow/contrib/autograph/operators/dispatch_context.py @@ -12,18 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""DenseNet Keras applications.""" +"""Structures that allow uniform control over the dispatch process.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.densenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet121 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet169 -from tensorflow.python.keras._impl.keras.applications.densenet import DenseNet201 -from tensorflow.python.keras._impl.keras.applications.densenet import preprocess_input +import collections -del absolute_import -del division -del print_function + +# TODO(mdan): This is where macro override controls fit. + + +class DispatchContext(collections.namedtuple( + 'DispatchContext', + ('options',))): + """Allows passing additional parameters to the specific implementations. + + Attributes: + options: Optional dict of extra arguments that may be required by specific + implementations. + """ + + def option(self, name): + return self.options[name] + + +NO_CTX = DispatchContext(options={}) diff --git a/tensorflow/contrib/autograph/pyct/ast_util.py b/tensorflow/contrib/autograph/pyct/ast_util.py index 4a70bab4402..c4f82d11708 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util.py +++ b/tensorflow/contrib/autograph/pyct/ast_util.py @@ -23,10 +23,11 @@ import ast import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser class CleanCopier(gast.NodeVisitor): - """Copy AST nodes. + """Copies AST nodes. The copied nodes will ignore almost all fields that are prefixed by '__'. Exceptions make some annotations. @@ -106,3 +107,79 @@ def keywords_to_dict(keywords): keys.append(gast.Str(kw.arg)) values.append(kw.value) return gast.Dict(keys=keys, values=values) + + +class PatternMatcher(gast.NodeVisitor): + """Matches a node against a pattern represented by a node. + + The pattern may contain wildcards represented by the symbol '_'. + """ + + def __init__(self, pattern): + self.pattern = pattern + self.pattern_stack = [] + self.matches = True + + def compare_and_visit(self, node, pattern): + self.pattern_stack.append(self.pattern) + self.pattern = pattern + self.generic_visit(node) + self.pattern = self.pattern_stack.pop() + + def no_match(self): + self.matches = False + return False + + def is_wildcard(self, p): + if isinstance(p, (list, tuple)) and len(p) == 1: + p, = p + if isinstance(p, gast.Name) and p.id == '_': + return True + if p == '_': + return True + return False + + def generic_visit(self, node): + if not self.matches: + return + + pattern = self.pattern + for f in node._fields: + if f.startswith('__'): + continue + + if not hasattr(node, f): + if hasattr(pattern, f) and getattr(pattern, f): + return self.no_match() + else: + continue + if not hasattr(pattern, f): + return self.no_match() + + v = getattr(node, f) + p = getattr(pattern, f) + + if self.is_wildcard(p): + continue + if isinstance(v, (list, tuple)): + if not isinstance(p, (list, tuple)) or len(v) != len(p): + return self.no_match() + for v_item, p_item in zip(v, p): + self.compare_and_visit(v_item, p_item) + elif isinstance(v, (gast.AST, ast.AST)): + if not isinstance(v, type(p)) and not isinstance(p, type(v)): + return self.no_match() + self.compare_and_visit(v, p) + else: + # Assume everything else is a value type. 
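The module-level matches() helper defined right after this class wraps the visitor; roughly, usage looks like the sketch below (string patterns are parsed with parser.parse_expression, and '_' acts as a wildcard):

from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import parser

node = parser.parse_expression('super(Foo, self).__init__(name)')
print(ast_util.matches(node, 'super(_).__init__(_)'))  # True: '_' matches any arguments
print(ast_util.matches(node, 'super(_).bar(_)'))       # False: the attribute name differs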
+ if v != p: + return self.no_match() + + +def matches(node, pattern): + if isinstance(pattern, str): + pattern = parser.parse_expression(pattern) + matcher = PatternMatcher(pattern) + matcher.visit(node) + return matcher.matches + diff --git a/tensorflow/contrib/autograph/pyct/ast_util_test.py b/tensorflow/contrib/autograph/pyct/ast_util_test.py index 8faf92c705d..3afa04a5068 100644 --- a/tensorflow/contrib/autograph/pyct/ast_util_test.py +++ b/tensorflow/contrib/autograph/pyct/ast_util_test.py @@ -85,7 +85,33 @@ class AstUtilTest(test.TestCase): output.body += (ast.Assign([ast.Name(id='d', ctx=ast.Store())], d),) result, _ = compiler.ast_to_object(output) self.assertDictEqual(result.d, {'a': 3, 'c': 1, 'd': 'e'}) - print(d) + + def assertMatch(self, target_str, pattern_str): + node = parser.parse_expression(target_str) + pattern = parser.parse_expression(pattern_str) + self.assertTrue(ast_util.matches(node, pattern)) + + def assertNoMatch(self, target_str, pattern_str): + node = parser.parse_expression(target_str) + pattern = parser.parse_expression(pattern_str) + self.assertFalse(ast_util.matches(node, pattern)) + + def test_matches_symbols(self): + self.assertMatch('foo', '_') + self.assertNoMatch('foo()', '_') + self.assertMatch('foo + bar', 'foo + _') + self.assertNoMatch('bar + bar', 'foo + _') + self.assertNoMatch('foo - bar', 'foo + _') + + def test_matches_function_args(self): + self.assertMatch('super(Foo, self).__init__(arg1, arg2)', + 'super(_).__init__(_)') + self.assertMatch('super().__init__()', 'super(_).__init__(_)') + self.assertNoMatch('super(Foo, self).bar(arg1, arg2)', + 'super(_).__init__(_)') + self.assertMatch('super(Foo, self).__init__()', 'super(Foo, _).__init__(_)') + self.assertNoMatch('super(Foo, self).__init__()', + 'super(Bar, _).__init__(_)') if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils.py b/tensorflow/contrib/autograph/pyct/inspect_utils.py index 63361cc4f25..eef74599a7d 100644 --- a/tensorflow/contrib/autograph/pyct/inspect_utils.py +++ b/tensorflow/contrib/autograph/pyct/inspect_utils.py @@ -63,16 +63,27 @@ def getnamespace(f): return namespace +def _get_unbound_function(m): + # TODO(mdan): Figure out why six.get_unbound_function fails in some cases. + # The failure case is for tf.keras.Model. + if hasattr(m, 'im_func'): + return m.im_func + return m + + def getdefiningclass(m, owner_class): """Resolves the class (e.g. one of the superclasses) that defined a method.""" - m = six.get_unbound_function(m) - last_defining = owner_class - for superclass in tf_inspect.getmro(owner_class): + # Normalize bound functions to their respective unbound versions. 
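The rewritten resolver (its body continues just below) looks one level up, at owner_class.__bases__; a small sketch of the expected behavior, mirroring the updated test (Base and Sub are illustrative names):

from tensorflow.contrib.autograph.pyct import inspect_utils

class Base(object):

  def bar(self):
    pass

  @classmethod
  def class_method(cls):
    pass

class Sub(Base):

  def foo(self):
    pass

# Methods resolve to the class that actually defines them.
assert inspect_utils.getdefiningclass(Sub.foo, Sub) is Sub
assert inspect_utils.getdefiningclass(Sub.bar, Sub) is Base
assert inspect_utils.getdefiningclass(Sub.class_method, Sub) is Base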
+ m = _get_unbound_function(m) + for superclass in owner_class.__bases__: if hasattr(superclass, m.__name__): superclass_m = getattr(superclass, m.__name__) - if six.get_unbound_function(superclass_m) == m: - last_defining = superclass - return last_defining + if _get_unbound_function(superclass_m) is m: + return superclass + elif hasattr(m, '__self__') and m.__self__ == owner_class: + # Python 3 class methods only work this way it seems :S + return superclass + return owner_class def getmethodclass(m): diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py index cf841dae814..1a212f676a6 100644 --- a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py +++ b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py @@ -243,6 +243,10 @@ class InspectUtilsTest(test.TestCase): def bar(self): pass + @classmethod + def class_method(cls): + pass + class Subclass(Superclass): def foo(self): @@ -257,6 +261,9 @@ class InspectUtilsTest(test.TestCase): inspect_utils.getdefiningclass(Subclass.bar, Subclass) is Superclass) self.assertTrue( inspect_utils.getdefiningclass(Subclass.baz, Subclass) is Subclass) + self.assertTrue( + inspect_utils.getdefiningclass(Subclass.class_method, Subclass) is + Superclass) def test_isbuiltin(self): self.assertTrue(inspect_utils.isbuiltin(range)) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD index 83f3bafc421..8064a967cd3 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/BUILD +++ b/tensorflow/contrib/autograph/pyct/static_analysis/BUILD @@ -19,6 +19,7 @@ py_library( srcs = [ "activity.py", "annos.py", + "cfg.py", "live_values.py", "type_info.py", ], @@ -43,6 +44,19 @@ py_test( ], ) +py_test( + name = "cfg_test", + srcs = ["cfg_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":static_analysis", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + py_test( name = "live_values_test", srcs = ["live_values_test.py"], diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py index b6817e9d75b..4d7b0cbb7b8 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py @@ -23,11 +23,12 @@ import copy import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import qual_names from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.contrib.autograph.pyct.qual_names import QN from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Add support for PY3 (e.g. Param vs arg). +# TODO(alexbw): Ignore named literals (e.g. None) class Scope(object): @@ -43,16 +44,20 @@ class Scope(object): used: identifiers referenced in this scope """ - def __init__(self, parent, isolated=True): + def __init__(self, parent, isolated=True, add_unknown_symbols=False): """Create a new scope. Args: parent: A Scope or None. isolated: Whether the scope is isolated, that is, whether variables created in this scope should be visible to the parent scope. + add_unknown_symbols: Whether to handle attributed and subscripts + without having first seen the base name. + E.g., analyzing the statement 'x.y = z' without first having seen 'x'. 
""" self.isolated = isolated self.parent = parent + self.add_unknown_symbols = add_unknown_symbols self.modified = set() self.created = set() self.used = set() @@ -133,18 +138,22 @@ class Scope(object): def mark_param(self, name): self.params.add(name) - def mark_creation(self, name): + def mark_creation(self, name, writes_create_symbol=False): + """Mark a qualified name as created.""" if name.is_composite(): parent = name.parent - if self.has(parent): - # This is considered mutation of the parent, not creation. - # TODO(mdan): Is that really so? + if not writes_create_symbol: return else: - raise ValueError('Unknown symbol "%s".' % parent) + if not self.has(parent): + if self.add_unknown_symbols: + self.mark_read(parent) + else: + raise ValueError('Unknown symbol "%s".' % parent) self.created.add(name) def mark_write(self, name): + """Marks the given symbol as modified in the current scope.""" self.modified.add(name) if self.isolated: self.mark_creation(name) @@ -163,28 +172,61 @@ class Scope(object): class ActivityAnalyzer(transformer.Base): - """Annotates nodes with local scope information. See Scope.""" + """Annotates nodes with local scope information. - def __init__(self, context, parent_scope): + See Scope. + + The use of this class requires that qual_names.resolve() has been called on + the node. This class will ignore nodes have not been + annotated with their qualified names. + """ + + def __init__(self, context, parent_scope=None, add_unknown_symbols=False): super(ActivityAnalyzer, self).__init__(context) - self.scope = Scope(parent_scope) + self.scope = Scope(parent_scope, None, add_unknown_symbols) self._in_return_statement = False + self._in_aug_assign = False - def _track_symbol(self, node): - # This can happen when we have an attribute (or subscript) on a function - # call. Example: a().b + @property + def _in_constructor(self): + if len(self.enclosing_entities) > 1: + innermost = self.enclosing_entities[-1] + parent = self.enclosing_entities[-2] + return isinstance(parent, gast.ClassDef) and innermost.name == '__init__' + return False + + def _node_sets_self_attribute(self, node): + if anno.hasanno(node, anno.Basic.QN): + qn = anno.getanno(node, anno.Basic.QN) + # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'. + if qn.has_attr and qn.parent.qn == ('self',): + return True + return False + + def _track_symbol(self, + node, + composite_writes_alter_parent=False, + writes_create_symbol=False): + # A QN may be missing when we have an attribute (or subscript) on a function + # call. Example: a().b if not anno.hasanno(node, anno.Basic.QN): return qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Store): self.scope.mark_write(qn) + if qn.is_composite and composite_writes_alter_parent: + self.scope.mark_write(qn.parent) + if writes_create_symbol: + self.scope.mark_creation(qn, writes_create_symbol=True) + if self._in_aug_assign: + self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Load): self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Param): # Param contexts appear in function defs, so they have the meaning of # defining a variable. - # TODO(mdan): This bay be incorrect with nested functions. + # TODO(mdan): This may be incorrect with nested functions. # For nested functions, we'll have to add the notion of hiding args from # the parent scope, not writing to them. 
self.scope.mark_creation(qn) @@ -200,6 +242,14 @@ class ActivityAnalyzer(transformer.Base): if self._in_return_statement: self.scope.mark_returned(qn) + def visit_AugAssign(self, node): + # Special rules for AugAssign. In Assign, the target is only written, + # but in AugAssig (e.g. a += b), the target is both read and written. + self._in_aug_assign = True + self.generic_visit(node) + self._in_aug_assign = False + return node + def visit_Name(self, node): self.generic_visit(node) self._track_symbol(node) @@ -207,7 +257,18 @@ class ActivityAnalyzer(transformer.Base): def visit_Attribute(self, node): self.generic_visit(node) - self._track_symbol(node) + if self._in_constructor and self._node_sets_self_attribute(node): + self._track_symbol( + node, composite_writes_alter_parent=True, writes_create_symbol=True) + else: + self._track_symbol(node) + return node + + def visit_Subscript(self, node): + self.generic_visit(node) + # Subscript writes (e.g. a[b] = "value") are considered to modify + # both the element itself (a[b]) and its parent (a). + self._track_symbol(node, composite_writes_alter_parent=True) return node def visit_Print(self, node): @@ -262,7 +323,7 @@ class ActivityAnalyzer(transformer.Base): def visit_FunctionDef(self, node): if self.scope: - qn = QN(node.name) + qn = qual_names.QN(node.name) self.scope.mark_write(qn) current_scope = self.scope body_scope = Scope(current_scope, isolated=True) @@ -322,5 +383,32 @@ class ActivityAnalyzer(transformer.Base): return node +def get_read(node, context): + """Return the variable names as QNs (qual_names.py) read by this statement.""" + analyzer = ActivityAnalyzer(context, None, True) + analyzer.visit(node) + return analyzer.scope.used + + +def get_updated(node, context): + """Return the variable names created or mutated by this statement. + + This function considers assign statements, augmented assign statements, and + the targets of for loops, as well as function arguments. + For example, `x[0] = 2` will return `x`, `x, y = 3, 4` will return `x` and + `y`, `for i in range(x)` will return `i`, etc. + Args: + node: An AST node + context: An EntityContext instance + + Returns: + A set of variable names (QNs, see qual_names.py) of all the variables + created or mutated. 
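A standalone sketch of the two helpers (the EntityContext construction copies the new cfg_test helper further below; the empty namespace and arg_types are assumptions made only for this illustration):

from tensorflow.contrib.autograph.pyct import context
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
from tensorflow.contrib.autograph.pyct.static_analysis import activity

def f(x, y):
  z = f(x, y)
  return z

node, source = parser.parse_entity(f)
node = qual_names.resolve(node)
ctx = context.EntityContext(
    namer=None, source_code=source, source_file=None, namespace={},
    arg_values=None, arg_types={}, owner_type=None, recursive=True)

stmt = node.body[0].body[0]              # the statement `z = f(x, y)`
print(activity.get_read(stmt, ctx))      # reads: f, x, y
print(activity.get_updated(stmt, ctx))   # writes: z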
+ """ + analyzer = ActivityAnalyzer(context, None, True) + analyzer.visit(node) + return analyzer.scope.created | analyzer.scope.modified + + def resolve(node, context, parent_scope=None): return ActivityAnalyzer(context, parent_scope).visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py index 65e1a8f0ea2..fdbd349af9d 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py @@ -123,7 +123,7 @@ class ActivityAnalyzerTest(test.TestCase): recursive=True) node = qual_names.resolve(node) node = activity.resolve(node, ctx) - return node + return node, ctx def test_local_markers(self): @@ -133,7 +133,7 @@ class ActivityAnalyzerTest(test.TestCase): b -= 1 return b - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) self.assertFalse( anno.getanno(node.body[0].body[0].value, NodeAnno.IS_LOCAL)) # c in b = c @@ -144,10 +144,22 @@ class ActivityAnalyzerTest(test.TestCase): anno.getanno(node.body[0].body[2].value, NodeAnno.IS_LOCAL)) # b in return b + def assertSymbolSetsAre(self, expected, actual, name): + expected = set(expected) + actual = set(str(s) for s in actual) + self.assertSetEqual( + expected, actual, 'for symbol set: %s\n' + ' Expected: %s\n' + ' Got: %s\n' + ' Missing: %s\n' + ' Extra: %s\n' % (name.upper(), expected, actual, + expected - actual, actual - expected)) + def assertScopeIsRmc(self, scope, used, modified, created): - self.assertItemsEqual(used, tuple(str(s) for s in scope.used)) - self.assertItemsEqual(modified, tuple(str(s) for s in scope.modified)) - self.assertItemsEqual(created, tuple(str(s) for s in scope.created)) + """Assert the scope contains specific used, modified & created variables.""" + self.assertSymbolSetsAre(used, scope.used, 'read') + self.assertSymbolSetsAre(modified, scope.modified, 'modified') + self.assertSymbolSetsAre(created, scope.created, 'created') def test_print_statement(self): @@ -157,7 +169,7 @@ class ActivityAnalyzerTest(test.TestCase): print(a, b) return c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) print_node = node.body[0].body[2] if isinstance(print_node, gast.Print): # Python 2 @@ -172,7 +184,7 @@ class ActivityAnalyzerTest(test.TestCase): # arguments. self.assertScopeIsRmc(print_args_scope, ('a', 'b'), (), ()) - def test_call(self): + def test_call_args(self): def test_fn(a): b = 0 @@ -180,13 +192,64 @@ class ActivityAnalyzerTest(test.TestCase): foo(a, b) # pylint:disable=undefined-variable return c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) call_node = node.body[0].body[2].value # We basically need to detect which variables are captured by the call # arguments. 
self.assertScopeIsRmc( anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ()) + def test_call_args_attributes(self): + + def foo(*_): + pass + + def test_fn(a): + a.c = 0 + foo(a.b, a.c) + return a.d + + node, _ = self._parse_and_analyze(test_fn) + call_node = node.body[0].body[1].value + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), + ('a', 'a.b', 'a.c'), + (), + (), + ) + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent, + ('a', 'a.b', 'a.c', 'a.d', 'foo'), + ('a.c',), + ('a',), + ) + + def test_call_args_subscripts(self): + + def foo(*_): + pass + + def test_fn(a): + b = 1 + c = 2 + foo(a[0], a[b]) + return a[c] + + node, _ = self._parse_and_analyze(test_fn) + call_node = node.body[0].body[2].value + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE), + ('a', 'a[0]', 'a[b]', 'b'), + (), + (), + ) + self.assertScopeIsRmc( + anno.getanno(call_node, NodeAnno.ARGS_SCOPE).parent, + ('a', 'a[0]', 'a[b]', 'a[c]', 'b', 'c', 'foo'), + ('b', 'c'), + ('a', 'b', 'c'), + ) + def test_while(self): def test_fn(a): @@ -196,7 +259,7 @@ class ActivityAnalyzerTest(test.TestCase): b -= 1 return b, c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) while_node = node.body[0].body[1] self.assertScopeIsRmc( anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), @@ -216,7 +279,7 @@ class ActivityAnalyzerTest(test.TestCase): b -= 1 return b, c - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) for_node = node.body[0].body[1] self.assertScopeIsRmc( anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',)) @@ -237,7 +300,7 @@ class ActivityAnalyzerTest(test.TestCase): u = -y return z, u - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) if_node = node.body[0].body[0] self.assertScopeIsRmc( anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'), @@ -253,7 +316,72 @@ class ActivityAnalyzerTest(test.TestCase): anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'), ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) - def test_nested_if_else_creation(self): + def test_if_attributes(self): + + def test_fn(a): + if a > 0: + a.b = -a.c + d = 2 * a + else: + a.b = a.c + d = 1 + return d + + node, _ = self._parse_and_analyze(test_fn) + if_node = node.body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), + ('a', 'a.c'), + ('a.b', 'd'), + ('d',), + ) + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), + ('a', 'a.c'), + ('a.b', 'd'), + ('d',), + ) + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, + ('a', 'a.c', 'd'), + ('a.b', 'd'), + ('a', 'd'), + ) + + def test_if_subscripts(self): + + def test_fn(a, b, c, e): + if a > 0: + a[b] = -a[c] + d = 2 * a + else: + a[0] = e + d = 1 + return d + + node, _ = self._parse_and_analyze(test_fn) + if_node = node.body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.BODY_SCOPE), + ('a', 'b', 'c', 'a[c]'), + ('a', 'a[b]', 'd'), + ('d',), + ) + # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"? 
+ self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), + ('a', 'e'), + ('a', 'a[0]', 'd'), + ('d',), + ) + self.assertScopeIsRmc( + anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, + ('a', 'b', 'c', 'd', 'e', 'a[c]'), + ('a', 'd', 'a[b]', 'a[0]'), + ('a', 'b', 'c', 'd', 'e'), + ) + + def test_nested_if(self): def test_fn(b): if b > 0: @@ -263,7 +391,7 @@ class ActivityAnalyzerTest(test.TestCase): a = b * b return a - node = self._parse_and_analyze(test_fn) + node, _ = self._parse_and_analyze(test_fn) inner_if_node = node.body[0].body[0].body[0] self.assertScopeIsRmc( anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',), @@ -272,7 +400,7 @@ class ActivityAnalyzerTest(test.TestCase): anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',), ('a',)) - def test_function_def(self): + def test_nested_function(self): def test_fn(a): @@ -286,45 +414,152 @@ class ActivityAnalyzerTest(test.TestCase): b -= f(i) return b, c - node = self._parse_and_analyze(test_fn) - fndef_node = node.body[0].body[0] + node, _ = self._parse_and_analyze(test_fn) + fn_def_node = node.body[0].body[0] self.assertScopeIsRmc( - anno.getanno(fndef_node, + anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE).parent, ('b', 'i', 'f', 'c', 'a'), ('f', 'b', 'c', 'i'), ('f', 'a', 'b', 'c', 'i')) self.assertScopeIsRmc( - anno.getanno(fndef_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), ( + anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), ( 'x', 'y', )) - def test_call_with_composite_names(self): + def test_constructor_attributes(self): - def foo(*_): - pass + class TestClass(object): + + def __init__(self, a): + self.b = a + self.b.c = 1 + + node, _ = self._parse_and_analyze(TestClass) + init_node = node.body[0].body[0] + self.assertScopeIsRmc( + anno.getanno(init_node, NodeAnno.BODY_SCOPE), + ('self', 'a', 'self.b'), + ('self', 'self.b', 'self.b.c'), + ('self', 'a', 'self.b'), + ) + + def test_aug_assign_subscripts(self): def test_fn(a): - foo(a.b, a.c) - if a > 0: - a.b = 2 - else: - d = 2 - d.e = a.c - f = d.e + 1 - a.c = f + a[0] += 1 - node = self._parse_and_analyze(test_fn) - call_node = node.body[0].body[0].value + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] self.assertScopeIsRmc( - anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), (), - ()) - if_node = node.body[0].body[1] + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('a', 'a[0]'), + ('a', 'a[0]'), + ('a',), + ) + + def test_return_vars_are_read(self): + + def test_fn(a, b, c): # pylint: disable=unused-argument + return c + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] self.assertScopeIsRmc( - anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a',), ('a.b',), ()) + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('c',), + (), + ( + 'a', + 'b', + 'c', + ), + ) + + def test_aug_assign(self): + + def test_fn(a, b): + a += b + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] self.assertScopeIsRmc( - anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), - ('a', 'a.c', 'd', 'd.e', 'f'), ('a.c', 'd', 'd.e', 'f'), ('d', 'f')) + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('a', 'b'), + ('a'), + ('a', 'b'), + ) + + def test_aug_assign_rvalues(self): + + a = dict(bar=3) + + def foo(): + return a + + def test_fn(x): + foo()['bar'] += x + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), + ('foo', 'x'), + (), + ('x',), + ) + + def test_params_created(self): + + def 
test_fn(a, b): # pylint: disable=unused-argument + return b + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + self.assertScopeIsRmc( + anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (('')), + (('a', 'b'))) + + def test_get_read(self): + + def test_fn(x, y): + z = test_fn(x, y) + return z + + node, ctx = self._parse_and_analyze(test_fn) + node = node.body[0].body[0] + read_vars = activity.get_read(node, ctx) + self.assertEqual(read_vars, set(map(qual_names.QN, ('test_fn', 'x', 'y')))) + + def test_fn2(x, y, z): + z += test_fn2(x, y, z) + return z + + node, ctx = self._parse_and_analyze(test_fn2) + node = node.body[0].body[0] + read_vars = activity.get_read(node, ctx) + self.assertEqual(read_vars, + set(map(qual_names.QN, ('test_fn2', 'x', 'y', 'z')))) + + def test_get_updated(self): + + def test_fn(x, y): + z = test_fn(x, y) + return z + + node, ctx = self._parse_and_analyze(test_fn) + node = node.body[0].body[0] + updated_vars = activity.get_updated(node, ctx) + self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) + + def test_fn2(x, y, z): + z += test_fn2(x, y, z) + return z + + node, ctx = self._parse_and_analyze(test_fn2) + node = node.body[0].body[0] + updated_vars = activity.get_updated(node, ctx) + self.assertEqual(updated_vars, set(map(qual_names.QN, ('z')))) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py new file mode 100644 index 00000000000..ad97fdfa8e7 --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg.py @@ -0,0 +1,445 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow graph analysis. + +Given a Python AST we construct a control flow graph, with edges both to the +next and previous statements (so it can easily walk the graph both ways). Its +nodes contain the AST of the statements. It can then perform forward or backward +analysis on this CFG. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import namedtuple +import functools +import operator + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct.static_analysis import activity + + +class CfgNode(object): + """A node in the CFG.""" + __slots__ = ['next', 'value', 'prev'] + + def __init__(self, value): + self.next = set() + self.prev = set() + self.value = value + + +class Cfg(namedtuple('Cfg', ['entry', 'exit'])): + """A Control Flow Graph. + + Each statement is represented as a node. For control flow statements such + as conditionals and loops the conditional itself is a node which either + branches or cycles, respectively. + Attributes: + entry: The entry node, which contains the `gast.arguments` node of the + function definition. + exit: The exit node. 
This node is special because it has no value (i.e. no + corresponding AST node). This is because Python functions can have + multiple return statements. + """ + pass + + +class CfgBuilder(gast.NodeVisitor): + """Construct a control flow graph. + + Construct a CFG starting from a FunctionDef node. + Usage: + cfg_obj = CfgBuilder().build_cfg(fndef_node) + """ + + def __init__(self): + # The current leaves of the CFG + self.current_leaves = [] + # TODO(alexbw): generalize to break, return, continue, yield, etc. + # A stack of lists, tracking continue statements + self.continue_ = [] + # A stack of lists tracking break nodes + self.break_ = [] + + def set_current_leaves(self, cfg_node): + """Link this cfg_node to the current leaves. + + This is the central function for building the CFG. It links the current + head cfg_nodes to the passed cfg_node. It then resets the head to the + passed cfg_node. + + Args: + cfg_node: A CfgNode instance. + """ + for head in self.current_leaves: + head.next.add(cfg_node) + # While we're linking the CFG forward, add backlinks + cfg_node.prev.add(head) + self.current_leaves = [cfg_node] + + def build_cfg(self, node): + """Build a CFG for a function. + + Implementation of building a CFG for dataflow analysis. See, e.g.: + https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf + + Args: + node: A function definition the body of which to analyze. + Returns: + A CFG object. + Raises: + TypeError: If the input is not a function definition. + """ + if not isinstance(node, gast.FunctionDef): + raise TypeError('input must be a function definition') + entry_cfg_node = CfgNode(node.args) + self.current_leaves = [entry_cfg_node] + self.visit_statements(node.body) + exit_cfg_node = CfgNode(None) + self.set_current_leaves(exit_cfg_node) + return Cfg(entry_cfg_node, exit_cfg_node) + + def visit_statements(self, nodes): + for node in nodes: + # Check for control flow + if isinstance(node, (gast.For, gast.While, gast.If, gast.Try, gast.Break, + gast.Continue, gast.With)): + self.visit(node) + else: + expr = CfgNode(node) + self.set_current_leaves(expr) + + def generic_visit(self, node): + raise ValueError('unknown control flow') + + def visit_If(self, node): + # TODO(alexbw): change this to use immutable tuples instead of lists + # The current head will hold the conditional + test = CfgNode(node.test) + self.set_current_leaves(test) + # Handle the body + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves = [test] + # Handle the orelse + self.visit_statements(node.orelse) + self.current_leaves.extend(body_exit) + + def visit_While(self, node): + test = CfgNode(node.test) + self.set_current_leaves(test) + # Start a new level of nesting + self.break_.append([]) + self.continue_.append([]) + # Handle the body + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(test) + # Handle the orelse + self.visit_statements(node.orelse) + # The break statements and the test go to the next node + self.current_leaves.extend(self.break_.pop()) + # Body and orelse statements can reach out of the loop + self.current_leaves.extend(body_exit) + + def visit_For(self, node): + iter_ = CfgNode(node.iter) + self.set_current_leaves(iter_) + self.break_.append([]) + self.continue_.append([]) + self.visit_statements(node.body) + body_exit = self.current_leaves + self.current_leaves.extend(self.continue_.pop()) + self.set_current_leaves(iter_) + # Handle the 
orelse + self.visit_statements(node.orelse) + # The break statements and the test go to the next node + self.current_leaves.extend(self.break_.pop()) + # Body and orelse statements can reach out of the loop + self.current_leaves.extend(body_exit) + + def visit_Break(self, node): + self.break_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Continue(self, node): + self.continue_[-1].extend(self.current_leaves) + self.current_leaves[:] = [] + + def visit_Try(self, node): + self.visit_statements(node.body) + body = self.current_leaves + handlers = [] + for handler in node.handlers: + self.current_leaves = body[:] + self.visit_statements(handler.body) + handlers.extend(self.current_leaves) + self.current_leaves = body + self.visit_statements(node.orelse) + self.current_leaves = handlers + self.current_leaves + self.visit_statements(node.finalbody) + + def visit_With(self, node): + for item in node.items: + self.set_current_leaves(CfgNode(item)) + self.visit_statements(node.body) + + +# TODO(alexbw): once CFG analysis occurs at a block level, +# this extra class will not be necessary +class PropagateAnalysis(gast.NodeVisitor): + """Port analysis annotations from statements to their enclosing blocks.""" + + def __init__(self, analysis): + self.transfer_fn = analysis.transfer_fn + self.in_label = analysis.in_label + self.out_label = analysis.out_label + super(PropagateAnalysis, self).__init__() + + def visit_If(self, node): + # Depth-first. + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + incoming |= anno.getanno(node.test, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + outgoing |= anno.getanno(node.test, self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + def visit_For(self, node): + self.generic_visit(node) + incoming = set(anno.getanno(node.body[0], self.in_label)) + incoming -= set((anno.getanno(node.target, anno.Basic.QN),)) + outgoing = anno.getanno(node.body[-1], self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, frozenset(incoming)) + anno.setanno(node, self.out_label, outgoing) + + def visit_While(self, node): + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + incoming |= anno.getanno(node.test, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + if node.orelse: + orelse_outgoing = anno.getanno(node.orelse[-1], self.out_label) + outgoing = self.transfer_fn(outgoing, orelse_outgoing) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + def visit_With(self, node): + self.generic_visit(node) + incoming = anno.getanno(node.body[0], self.in_label) + for item in node.items: + incoming |= anno.getanno(item, self.in_label) + outgoing = anno.getanno(node.body[-1], self.out_label) + anno.setanno(node, self.in_label, incoming) + anno.setanno(node, self.out_label, outgoing) + + +# TODO(alexbw): Abstract the CFG walking machinery into a superclass +# which is parameterized on which fields it selects when walking. +# TODO(alexbw): Abstract the application of dataflow analysis +class Forward(object): + """Forward analysis on CFG. + + Args: + label: A name for this analysis e.g. 
'active' for activity analysis. The AST + nodes in the CFG will be given annotations 'name_in', 'name_out', + 'name_gen' and 'name_kill' which contain the incoming values, outgoing + values, values generated by the statement, and values deleted by the + statement respectively. + transfer_fn: Either the AND or OR operator. If the AND operator is used it + turns into forward must analysis (i.e. a value will only be carried + forward if it appears on all incoming paths). The OR operator means that + forward may analysis is done (i.e. the union of incoming values will be + taken). + """ + + def __init__(self, label, context, transfer_fn=operator.or_): + self.transfer_fn = transfer_fn + self.context = context + self.out_label = label + '_out' + self.in_label = label + '_in' + self.gen_label = label + '_gen' + self.kill_label = label + '_kill' + + # TODO(alexbw): see if we can simplify by visiting breadth-first + def visit(self, node): + """Depth-first walking the CFG, applying dataflow information propagtion.""" + # node.value is None only for the exit CfgNode. + if not node.value: + return + + if anno.hasanno(node.value, self.out_label): + before = hash(anno.getanno(node.value, self.out_label)) + else: + before = None + preds = [ + anno.getanno(pred.value, self.out_label) + for pred in node.prev + if anno.hasanno(pred.value, self.out_label) + ] + if preds: + incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0]) + else: + incoming = frozenset() + anno.setanno(node.value, self.in_label, incoming) + gen, kill = self.get_gen_kill(node, incoming) + anno.setanno(node.value, self.gen_label, gen) + anno.setanno(node.value, self.kill_label, kill) + anno.setanno(node.value, self.out_label, (incoming - kill) | gen) + + if hash(anno.getanno(node.value, self.out_label)) != before: + for succ in node.next: + self.visit(succ) + + def get_gen_kill(self, cfg_node, incoming): + """Calculate Gen and Kill properties of a CFG node in dataflow analysis. + + A function which takes the CFG node as well as a set of incoming + values. It must return a set of newly generated values by the statement as + well as a set of deleted (killed) values. + + Args: + cfg_node: A CfgNode instance. + incoming: + """ + raise NotImplementedError() + + +class Backward(Forward): + """Backward analysis on CFG.""" + + def visit(self, cfg_node): + # cfg_node.value is None for the exit node, which will be visited only once + if not cfg_node.value: + for pred in cfg_node.prev: + self.visit(pred) + return + + if anno.hasanno(cfg_node.value, self.in_label): + before = hash(anno.getanno(cfg_node.value, self.in_label)) + else: + before = None + succs = [ + anno.getanno(succ.value, self.in_label) + for succ in cfg_node.next + if anno.hasanno(succ.value, self.in_label) + ] + if succs: + incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0]) + else: + incoming = frozenset() + anno.setanno(cfg_node.value, self.out_label, incoming) + gen, kill = self.get_gen_kill(cfg_node, incoming) + anno.setanno(cfg_node.value, self.gen_label, gen) + anno.setanno(cfg_node.value, self.kill_label, kill) + anno.setanno(cfg_node.value, self.in_label, (incoming - kill) | gen) + if hash(anno.getanno(cfg_node.value, self.in_label)) != before: + for pred in cfg_node.prev: + self.visit(pred) + + +def run_analyses(node, analyses): + """Perform dataflow analysis on all functions within an AST. + + Args: + node: An AST node on which to run dataflow analysis. 
+ analyses: Either an instance of the Forward or Backward dataflow analysis + class, or a list or tuple of them. + + Returns: + node: The node, but now with annotations on the AST nodes containing the + results of the dataflow analyses. + """ + if not isinstance(analyses, (tuple, list)): + analyses = (analyses,) + for analysis in analyses: + if not isinstance(analysis, (Forward, Backward)): + raise TypeError('not a valid forward analysis object') + + for child_node in gast.walk(node): + if isinstance(child_node, gast.FunctionDef): + cfg_obj = CfgBuilder().build_cfg(child_node) + for analysis in analyses: + if isinstance(analysis, Backward): + analysis.visit(cfg_obj.exit) + elif isinstance(analysis, Forward): + analysis.visit(cfg_obj.entry) + for analysis in analyses: + PropagateAnalysis(analysis).visit(node) + return node + + +class Liveness(Backward): + """Perform a liveness analysis. + + Each statement is annotated with a set of variables that may be used + later in the program. + """ + + def __init__(self, context): + super(Liveness, self).__init__('live', context) + + def get_gen_kill(self, node, _): + # A variable's parents are live if it is live + # e.g. x is live if x.y is live. This means gen needs to return + # all parents of a variable (if it's an Attribute or Subscript). + # This doesn't apply to kill (e.g. del x.y doesn't affect liveness of x) + gen = activity.get_read(node.value, self.context) + gen = functools.reduce(lambda left, right: left | right.support_set, gen, + gen) + kill = activity.get_updated(node.value, self.context) + return gen, kill + + +class ReachingDefinitions(Forward): + """Perform reaching definition analysis. + + Each statement is annotated with a set of (variable, definition) pairs. + """ + + def __init__(self, context): + super(ReachingDefinitions, self).__init__('definitions', context) + + def get_gen_kill(self, node, incoming): + definitions = activity.get_updated(node.value, self.context) + gen = frozenset((id_, node.value) for id_ in definitions) + kill = frozenset(def_ for def_ in incoming if def_[0] in definitions) + return gen, kill + + +class Defined(Forward): + """Perform defined variable analysis. + + Each statement is annotated with a set of variables which are guaranteed to + be defined at that point. + """ + + def __init__(self, context): + super(Defined, self).__init__('defined', context, transfer_fn=operator.and_) + + def get_gen_kill(self, node, _): + gen = activity.get_updated(node.value, self.context) + return gen, frozenset() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py new file mode 100644 index 00000000000..fc07fa3447b --- /dev/null +++ b/tensorflow/contrib/autograph/pyct/static_analysis/cfg_test.py @@ -0,0 +1,306 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for cfg module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import context +from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct.static_analysis import cfg +from tensorflow.python.platform import test + + +class CFGTest(test.TestCase): + + def _parse_and_analyze(self, test_fn, namespace, arg_types=None): + arg_types = arg_types or {} + node, source = parser.parse_entity(test_fn) + ctx = context.EntityContext( + namer=None, + source_code=source, + source_file=None, + namespace=namespace, + arg_values=None, + arg_types=arg_types, + owner_type=None, + recursive=True) + node = qual_names.resolve(node) + return node, ctx + + def _check_anno_matches(self, node, anno_name, var_names): + if isinstance(var_names, str): + var_names = (var_names,) + qual_vars = set() + for var_name in var_names: + if isinstance(var_name, str): + if '[' in var_name or ']' in var_name: + raise ValueError('Annotation matching not supported with subscript.') + if '.' not in var_name: + qual_vars.add(qual_names.QN(var_name)) + else: + attrs = var_name.split('.') + this_qn = functools.reduce(qual_names.QN, attrs[1:], + qual_names.QN(attrs[0])) + qual_vars.add(this_qn) + self.assertEqual(anno.getanno(node, anno_name), qual_vars) + + def test_reaching(self): + + def f(x): + print(x) + while True: + x = x + x = x + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) + body = node.body[0].body + # Only the argument reaches the expression + def_in = anno.getanno(body[0], 'definitions_in') + # One element, x, from arguments + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.arguments,))) + + while_body = body[1].body + def_in = anno.getanno(while_body[0], 'definitions_in') + # One definition, two possible sources. 
+ # - One from an assignment (if the loop is entered) + # - The other from the arguments (if loop is not entered) + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def_in = anno.getanno(while_body[1], 'definitions_in') + # If we've reached this line, the only reaching definition of x is the + # Assign node in previous line + self.assertEqual(set(type(d[1]) for d in def_in), set((gast.Assign,))) + + def_in = anno.getanno(body[2], 'definitions_in') + # Same situation as while_body[0] + self.assertEqual( + set(type(d[1]) for d in def_in), set((gast.arguments, gast.Assign))) + + def test_defined(self): + + def f(x): + if x: + y = 2 # pylint: disable=unused-variable + return x + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + # only x is for sure defined at the end + self._check_anno_matches(body[1], 'defined_in', 'x') + # at the end of the if body both x and y are defined + if_body = body[0].body + self._check_anno_matches(if_body[0], 'defined_out', ('x', 'y')) + + def _get_live_annotated_fnbody(self, f): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Liveness(ctx)) + body = node.body[0].body + return body + + def test_live_straightline(self): + + def f1(x): + a = g(x) # pylint: disable=undefined-variable + b = h(a) # pylint: disable=undefined-variable, unused-variable + return x + + body = self._get_live_annotated_fnbody(f1) + self._check_anno_matches(body[1], 'live_in', ('a', 'h', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + self._check_anno_matches(body[0], 'live_in', ('g', 'h', 'x')) + self._check_anno_matches(body[2], 'live_out', ()) + + def test_live_stacked_conds_with_else(self): + + def f2(x, a): # pylint: disable=unused-argument + if a > 0: # x should not be live + x = 0 + if a > 1: + x = 1 + else: + x = 2 + + body = self._get_live_annotated_fnbody(f2) + self._check_anno_matches(body[0], 'live_in', ('a')) + self._check_anno_matches(body[1], 'live_in', ('a')) + + def test_live_stacked_conds(self): + + def f3(x, a): + if a > 0: # x and a should be live + x = 0 + if a > 1: # x and a should be live_in + x = 1 + return x # x should be live + + body = self._get_live_annotated_fnbody(f3) + self._check_anno_matches(body[0], 'live_in', ('a', 'x')) + self._check_anno_matches(body[1], 'live_in', ('a', 'x')) + self._check_anno_matches(body[2], 'live_in', ('x')) + + def test_live_possibly_unused_cond(self): + + def f4(x, a): + if a > 0: # x should be live + x = 0 + x += 1 + + body = self._get_live_annotated_fnbody(f4) + self._check_anno_matches(body[0], 'live_in', ('x', 'a')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + def test_live_attribute_in_cond(self): + + def f5(x, a): + if a > 0: # x.y should be live + x.y = 0 + return x.y + + body = self._get_live_annotated_fnbody(f5) + self._check_anno_matches(body[0], 'live_in', ('x', 'x.y', 'a')) + + def test_live_noop(self): + + def f6(x): + return x # should this cause x.* to be live? 
+ + body = self._get_live_annotated_fnbody(f6) + self._check_anno_matches(body[0], 'live_in', ('x')) + + def test_live_loop(self): + + def f7(x, n): + for i in range(n): + x += i + return x + + body = self._get_live_annotated_fnbody(f7) + self._check_anno_matches(body[0], 'live_in', ('x', 'n', 'range')) + self._check_anno_matches(body[1], 'live_in', ('x')) + + def test_live_context_manager(self): + + def f8(x, f): + with f: + x += 1 + + body = self._get_live_annotated_fnbody(f8) + self._check_anno_matches(body[0], 'live_in', ('f', 'x')) + + def test_node_equality(self): + node_a = gast.parse('y = x').body[0] + node_b = gast.parse('y = x').body[0] + self.assertNotEqual(node_a, node_b) + + def test_nested_functions_defined(self): + + def f(x): + y = x * 2 + + def g(z): + return z + y + + return g(x) + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('g', 'x', 'y')))) + + # TODO(alexbw): CFG analysis doesn't currently cross FunctionDef boundaries. + # NOTE: 'z' is easy to find, but 'y' is not identified as + # defined, because CFG analysis is applied with each function separately. + # fndef_body = body[1].body + # self.assertEqual( + # anno.getanno(fndef_body[0], 'defined_in'), + # frozenset(map(qual_names.QN, ('z', 'y')))) + + def test_nested_functions_dont_leak_definitions(self): + + def f(x): + print(x) + + def g(): + y = 2 + return y + + return g() # y is not defined here + + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.Defined(ctx)) + body = node.body[0].body + self.assertEqual( + anno.getanno(body[2], 'defined_in'), + frozenset(map(qual_names.QN, ('x', 'g')))) + + def test_loop_else(self): + + # Disabling useless-else-on-loop error, because 'break' and 'continue' + # canonicalization are a separate analysis pass, and here we test + # the CFG analysis in isolation. + def for_orelse(x): + y = 0 + for i in range(len(x)): + x += i + else: # pylint: disable=useless-else-on-loop + y = 1 + return x, y + + def while_orelse(x, i): + y = 0 + while x < 10: + x += i + else: # pylint: disable=useless-else-on-loop + y = 1 + return x, y + + for f in (for_orelse, while_orelse): + node, ctx = self._parse_and_analyze(f, {}) + cfg.run_analyses(node, cfg.ReachingDefinitions(ctx)) + body = node.body[0].body + return_node = body[-1] + reaching_defs = anno.getanno(return_node, 'definitions_in') + + # Y could be defined by Assign(Num(0)) or Assign(Num(1)) + # X could be defined as an argument or an AugAssign. 
+ y_defs = [node for var, node in reaching_defs if str(var) == 'y'] + x_defs = [node for var, node in reaching_defs if str(var) == 'x'] + + self.assertEqual(set((gast.Assign,)), set(type(def_) for def_ in y_defs)) + self.assertEqual(set((0, 1)), set(def_.value.n for def_ in y_defs)) + self.assertEqual(len(y_defs), 2) + self.assertEqual( + set((gast.arguments, gast.AugAssign)), + set(type(def_) for def_ in x_defs)) + self.assertEqual(len(x_defs), 2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index 203aa3c3d18..c00946f9c41 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -48,6 +48,9 @@ from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect +# TODO(mdan): Remove the duplication between this and activity.py. +# In particular, the symbol definitions we track here could as well be tracked +# there because they follow the same rules for visibility. class Scope(object): """Tracks symbol value references. @@ -99,20 +102,16 @@ class TypeInfoResolver(transformer.Base): def __init__(self, context): super(TypeInfoResolver, self).__init__(context) self.scope = Scope(None) - self.function_level = 0 def visit_FunctionDef(self, node): self.scope = Scope(self.scope) - self.function_level += 1 - self.generic_visit(node) - self.function_level -= 1 + node = self.generic_visit(node) self.scope = self.scope.parent return node def _visit_block(self, block): self.scope = Scope(self.scope) - for i, n in enumerate(block): - block[i] = self.generic_visit(n) + block = self.visit_block(block) self.scope = self.scope.parent return block @@ -137,7 +136,7 @@ class TypeInfoResolver(transformer.Base): def _process_function_arg(self, arg_name): str_name = str(arg_name) - if self.function_level == 1 and str_name in self.context.arg_types: + if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types: # Forge a node to hold the type information, so that method calls on # it can resolve the type. type_holder = arg_name.ast() @@ -168,16 +167,8 @@ class TypeInfoResolver(transformer.Base): anno.getanno(definition, 'element_type')) return node - def _process_tuple_assignment(self, source, t): - for i, e in enumerate(t.elts): - if isinstance(e, gast.Tuple): - self._process_tuple_assignment(source, e) - else: - self.scope.setval( - anno.getanno(e, anno.Basic.QN), - gast.Subscript(source, gast.Index(i), ctx=gast.Store())) - def _process_variable_assignment(self, source, targets): + # Special case: constructors. if isinstance(source, gast.Call): func = source.func if anno.hasanno(func, 'live_val'): @@ -190,15 +181,25 @@ class TypeInfoResolver(transformer.Base): # We can have a whitelist of no-side-effects constructors. # We can also step inside the constructor and further analyze. - for t in targets: - if isinstance(t, gast.Tuple): - # need to recurse on the case of assigning nested tuples, - # ex. a, (b, c) = f() - self._process_tuple_assignment(source, t) - elif isinstance(t, (gast.Name, gast.Attribute)): - self.scope.setval(anno.getanno(t, anno.Basic.QN), source) + # Multiple targets mean multiple assignment. + for target in targets: + # Tuple target means unpacking. + if isinstance(target, (gast.Tuple, gast.List)): + for i, target_item in enumerate(target.elts): + # Two cases here: + # 1. Static unpacking, e.g. a, b = c, d + # 2. 
Dynamic unpacking, e.g. a, b = c + # The former case is optimized away. + if isinstance(source, (gast.Tuple, gast.List)): + source_item = source.elts[i] + else: + source_item = gast.Subscript(source, gast.Index(i), ctx=None) + self._process_variable_assignment(source_item, (target_item,)) + elif isinstance(target, (gast.Name, gast.Attribute)): + target_symbol = anno.getanno(target, anno.Basic.QN) + self.scope.setval(target_symbol, source) else: - raise ValueError('Dont know how to handle assignment to %s' % t) + raise ValueError('assignment target has unknown type: %s' % target) def visit_With(self, node): for wi in node.items: @@ -218,19 +219,26 @@ class TypeInfoResolver(transformer.Base): # type that it specified. if (anno.getanno(node.func, 'live_val') is self.context.type_annotation_func): - # Expecting the actual type to be the second argument. + if len(node.args) != 2: raise ValueError('"%s" must have exactly two parameters' % self.context.type_annotation_func) - if not anno.hasanno(node.args[0], anno.Basic.QN): + target_arg, type_arg = node.args + if not anno.hasanno(target_arg, anno.Basic.QN): raise ValueError('the first argument of "%s" must by a symbol' % self.context.type_annotation_func) - if not anno.hasanno(node.args[1], 'live_val'): - raise ValueError( - 'the second argument of "%s" must be statically resolvable' % - self.context.type_annotation_func) - target_symbol = anno.getanno(node.args[0], anno.Basic.QN) - element_type = anno.getanno(node.args[1], 'live_val') + if isinstance(type_arg, gast.Str): + element_type = type_arg.s + elif isinstance(type_arg, gast.Num): + element_type = type_arg.n + else: + if not anno.hasanno(type_arg, 'live_val'): + raise ValueError( + 'the second argument of "%s" must be statically resolvable' % + self.context.type_annotation_func) + element_type = anno.getanno(type_arg, 'live_val') + + target_symbol = anno.getanno(target_arg, anno.Basic.QN) # Find the definition of this symbol and annotate it with the given # data type. That in turn will cause future uses of the symbol # to receive the same type annotation. 
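The recursive unpacking rule added to `_process_variable_assignment` above is easier to follow in isolation. The following is a minimal standalone sketch that uses the standard `ast` module instead of `gast` and the real `TypeInfoResolver`/`Scope` machinery; the helper and its names are illustrative only and are not part of the TensorFlow code.

```python
import ast


def resolve_assignment(target, source, bindings):
  # Nested tuple/list targets recurse. Each element of the target is paired
  # either with the matching literal element of the source (static unpacking,
  # e.g. a, b = c, d) or with a placeholder standing for source[i]
  # (dynamic unpacking, e.g. a, b = c).
  if isinstance(target, (ast.Tuple, ast.List)):
    for i, target_item in enumerate(target.elts):
      if isinstance(source, (ast.Tuple, ast.List)):
        source_item = source.elts[i]          # static unpacking
      else:
        source_item = ('getitem', source, i)  # dynamic unpacking
      resolve_assignment(target_item, source_item, bindings)
  elif isinstance(target, ast.Name):
    bindings[target.id] = source


tree = ast.parse('a, (b, c) = make_pair(), (1, 2)')
assign = tree.body[0]
bindings = {}
resolve_assignment(assign.targets[0], assign.value, bindings)
print(sorted(bindings))  # ['a', 'b', 'c']
```

The design point mirrored from the diff is that a literal tuple or list on the right-hand side is matched element by element, while any other expression is modeled as an index into the source value.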
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index c0de4a60430..46b7701624a 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -196,23 +196,46 @@ class TypeInfoResolverTest(test.TestCase): f_ref = node.body[0].body[1].value self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) - def test_nested_assignment(self): + def test_nested_unpacking(self): - def test_fn(foo): - a, (b, c) = foo + class Foo(object): + pass + + class Bar(object): + pass + + def test_fn(): + a, (b, c) = (Foo(), (Bar(), Foo())) return a, b, c - node = self._parse_and_analyze(test_fn, {'foo': (1, 2, 3)}) - lhs = node.body[0].body[1].value.elts - a = lhs[0] - b = lhs[1] - c = lhs[2] - # TODO(mdan): change these once we have the live values propagating - # correctly + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar}) + a, b, c = node.body[0].body[1].value.elts + self.assertEquals(Foo, anno.getanno(a, 'type')) + self.assertEquals(Bar, anno.getanno(b, 'type')) + self.assertEquals(Foo, anno.getanno(c, 'type')) self.assertFalse(anno.hasanno(a, 'live_val')) self.assertFalse(anno.hasanno(b, 'live_val')) self.assertFalse(anno.hasanno(c, 'live_val')) + def test_inner_scope(self): + + def test_fn(): + a = [] + utils.set_element_type(a, 1) + for _ in a: + b = [] + utils.set_element_type(b, 2) + return a, b + + node = self._parse_and_analyze(test_fn, {'utils': utils}) + a, b = node.body[0].body[2].body[2].value.elts + self.assertEquals(1, anno.getanno(a, 'element_type')) + self.assertEquals(2, anno.getanno(b, 'element_type')) + self.assertFalse(anno.hasanno(a, 'type')) + self.assertFalse(anno.hasanno(b, 'type')) + self.assertFalse(anno.hasanno(a, 'live_val')) + self.assertFalse(anno.hasanno(b, 'live_val')) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index b38d52c5b2c..4db6cc0adfa 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -40,7 +40,13 @@ def try_ast_to_source(node): class Base(gast.NodeTransformer): - """Base class for specialized transformers.""" + """Base class for specialized transformers. + + Scope-local state tracking: to keep state across nodes, at the level of + (possibly nested) scopes, use enter/exit_local_scope and set/get_local. + You must call enter/exit_local_scope manually, but the transformer detects + when they are not properly paired. + """ def __init__(self, context): """Initialize the transformer. Subclasses should call this. @@ -53,20 +59,55 @@ class Base(gast.NodeTransformer): self.context = context self._enclosing_entities = [] + # A stack that allows keeping mutable, scope-local state where scopes may be + # nested. For example, it can be used to track the usage of break + # statements in each loop, where loops may be nested. 
+ self._local_scope_state = [] + self.enter_local_scope() + @property def enclosing_entities(self): return tuple(self._enclosing_entities) + @property + def local_scope_level(self): + return len(self._local_scope_state) + + def enter_local_scope(self): + self._local_scope_state.append({}) + + def exit_local_scope(self): + return self._local_scope_state.pop() + + def set_local(self, name, value): + self._local_scope_state[-1][name] = value + + def get_local(self, name, default=None): + return self._local_scope_state[-1].get(name, default) + def debug_print(self, node): """Helper method useful for debugging.""" if __debug__: print(pretty_printer.fmt(node)) return node + def visit_block(self, nodes): + """Helper equivalent to generic_visit, but for node lists.""" + results = [] + for node in nodes: + replacement = self.visit(node) + if replacement: + if isinstance(replacement, (list, tuple)): + results.extend(replacement) + else: + results.append(replacement) + return results + def visit(self, node): source_code = self.context.source_code source_file = self.context.source_file did_enter_function = False + local_scope_state_size = len(self._local_scope_state) try: if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)): @@ -97,3 +138,10 @@ class Base(gast.NodeTransformer): finally: if did_enter_function: self._enclosing_entities.pop() + + if local_scope_state_size != len(self._local_scope_state): + raise AssertionError( + 'Inconsistent local scope stack. Before entering node %s, the' + ' stack had length %d, after exit it has length %d. This' + ' indicates enter_local_scope and exit_local_scope are not' + ' well paired.' % (node, local_scope_state_size, + len(self._local_scope_state))) diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py index 57f1c31ef65..f96b0dc3775 100644 --- a/tensorflow/contrib/autograph/pyct/transformer_test.py +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -27,6 +27,17 @@ from tensorflow.python.platform import test class TransformerTest(test.TestCase): + def _context_for_nodetesting(self): + return context.EntityContext( + namer=None, + source_code=None, + source_file=None, + namespace=None, + arg_values=None, + arg_types=None, + owner_type=None, + recursive=False) + def test_entity_scope_tracking(self): class TestTransformer(transformer.Base): @@ -42,16 +53,7 @@ class TransformerTest(test.TestCase): anno.setanno(node, 'enclosing_entities', self.enclosing_entities) return self.generic_visit(node) - tr = TestTransformer( - context.EntityContext( - namer=None, - source_code=None, - source_file=None, - namespace=None, - arg_values=None, - arg_types=None, - owner_type=None, - recursive=False)) + tr = TestTransformer(self._context_for_nodetesting()) def test_function(): a = 0 @@ -92,6 +94,86 @@ class TransformerTest(test.TestCase): inner_function, lambda_node), anno.getanno(lambda_expr, 'enclosing_entities')) + def test_statement_info_stack(self): + + class TestTransformer(transformer.Base): + + # Extract all string constants from the block.
+ def visit_Str(self, node): + self.set_local('string', self.get_local('string', default='') + node.s) + return self.generic_visit(node) + + def _annotate_result(self, node): + self.enter_local_scope() + node = self.generic_visit(node) + anno.setanno(node, 'test', self.get_local('string')) + self.exit_local_scope() + return node + + def visit_While(self, node): + return self._annotate_result(node) + + def visit_For(self, node): + return self._annotate_result(node) + + tr = TestTransformer(self._context_for_nodetesting()) + + def test_function(a): + """Docstring.""" + assert a == 'This should not be counted' + for i in range(3): + _ = 'a' + if i > 2: + return 'b' + else: + _ = 'c' + while True: + raise '1' + return 'nor this' + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + + for_node = node.body[0].body[2] + while_node = for_node.body[1].orelse[1] + + self.assertFalse(anno.hasanno(for_node, 'string')) + self.assertEqual('abc', anno.getanno(for_node, 'test')) + self.assertFalse(anno.hasanno(while_node, 'string')) + self.assertEqual('1', anno.getanno(while_node, 'test')) + + def test_statement_info_stack_checks_integrity(self): + + class TestTransformer(transformer.Base): + + def visit_If(self, node): + self.enter_local_scope() + return self.generic_visit(node) + + def visit_For(self, node): + node = self.generic_visit(node) + self.exit_local_scope() + return node + + tr = TestTransformer(self._context_for_nodetesting()) + + def no_exit(a): + if a > 0: + print(a) + return None + + node, _ = parser.parse_entity(no_exit) + with self.assertRaises(AssertionError): + tr.visit(node) + + def no_entry(a): + for _ in a: + print(a) + + node, _ = parser.parse_entity(no_entry) + with self.assertRaises(AssertionError): + tr.visit(node) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index 0a0e72d70e9..211e8eaee90 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -28,24 +28,17 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops -from tensorflow.python.util import tf_inspect def dynamic_builtin(f, *args, **kwargs): """Converts a builtin function call inline.""" - # Some built-ins may be objects. - if not tf_inspect.isbuiltin(f) and f not in (range,): - return f(*args, **kwargs) - if f is len: return dynamic_len(*args, **kwargs) if six.PY2 and f is xrange: return dynamic_range(*args, **kwargs) if f is range: return dynamic_range(*args, **kwargs) - - raise NotImplementedError( - 'The "%s" builtin is not yet supported.' % f.__name__) + raise ValueError('%s is not supported' % f) def dynamic_len(list_or_tensor): @@ -98,9 +91,15 @@ def dynamic_print(*values): if all(map(is_tf_print_compatible, values)): return logging_ops.Print(1, values) - def flushed_print(*vals): + def print_wrapper(*vals): + if six.PY3: + # TensorFlow doesn't seem to generate Unicode when passing strings to + # py_func. This causes the print to add a "b'" wrapper to the output, + # which is probably never what you want. + vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals) print(*vals) + # The flush helps avoid garbled output in IPython. 
sys.stdout.flush() return py_func.wrap_py_func( - flushed_print, None, values, use_dummy_return=True) + print_wrapper, None, values, use_dummy_return=True) diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py index d9f7913d89a..163e6984079 100644 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -76,8 +76,9 @@ class BuiltinsTest(test.TestCase): def range(x): # pylint:disable=redefined-builtin return x - # Functions that just have the names of builtins are ignored. - self.assertEqual(builtins.dynamic_builtin(range, 1), 1) + # Functions that just have the names of builtins are rejected. + with self.assertRaises(ValueError): + self.assertEqual(builtins.dynamic_builtin(range, 1), 1) if six.PY2: self.assertListEqual( list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2]) diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index fac7aff29f7..e22f978dde6 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -250,7 +250,7 @@ class BatchOpsTest(test.TestCase): def testUnbatchGrad(self): """Tests that batch and unbatch are differentiable.""" with self.test_session() as sess: - inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, batch_timeout_micros=36000000, grad_timeout_micros=1000000, diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py index d193a8459d0..032b859d469 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo_impl.py @@ -44,15 +44,13 @@ def expectation_importance_sampler(f, n=None, seed=None, name='expectation_importance_sampler'): - r"""Monte Carlo estimate of `\\(E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]\\)`. + r"""Monte Carlo estimate of \\(E_p[f(Z)] = E_q[f(Z) p(Z) / q(Z)]\\). - With `\\(p(z) := exp^{log_p(z)}\\)`, this `Op` returns + With \\(p(z) := exp^{log_p(z)}\\), this `Op` returns - ``` \\(n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ], z_i ~ q,\\) \\(\approx E_q[ f(Z) p(Z) / q(Z) ]\\) \\(= E_p[f(Z)]\\) - ``` This integral is done in log-space with max-subtraction to better handle the often extreme values that `f(z) p(z) / q(z)` can take on. @@ -121,14 +119,12 @@ def expectation_importance_sampler_logspace( name='expectation_importance_sampler_logspace'): r"""Importance sampling with a positive function, in log-space. - With `\\(p(z) := exp^{log_p(z)}\\)`, and `\\(f(z) = exp{log_f(z)}\\)`, + With \\(p(z) := exp^{log_p(z)}\\), and \\(f(z) = exp{log_f(z)}\\), this `Op` returns - ``` \\(Log[ n^{-1} sum_{i=1}^n [ f(z_i) p(z_i) / q(z_i) ] ], z_i ~ q,\\) \\(\approx Log[ E_q[ f(Z) p(Z) / q(Z) ] ]\\) \\(= Log[E_p[f(Z)]]\\) - ``` This integral is done in log-space with max-subtraction to better handle the often extreme values that `f(z) p(z) / q(z)` can take on. @@ -196,13 +192,11 @@ def _logspace_mean(log_values): def expectation(f, samples, log_prob=None, use_reparametrization=True, axis=0, keep_dims=False, name=None): - """Computes the Monte-Carlo approximation of `\\(E_p[f(X)]\\)`. + """Computes the Monte-Carlo approximation of \\(E_p[f(X)]\\). 
This function computes the Monte-Carlo approximation of an expectation, i.e., - ```none \\(E_p[f(X)] \approx= m^{-1} sum_i^m f(x_j), x_j\ ~iid\ p(X)\\) - ``` where: @@ -216,8 +210,8 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, parameterless distribution (e.g., `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and expectation, i.e., - `grad[ Avg{ \\(s_i : i=1...n\\) } ] = Avg{ grad[\\(s_i\\)] : i=1...n }` where - `S_n = Avg{\\(s_i\\)}` and `\\(s_i = f(x_i), x_i ~ p\\)`. + grad[ Avg{ \\(s_i : i=1...n\\) } ] = Avg{ grad[\\(s_i\\)] : i=1...n } where + S_n = Avg{\\(s_i\\)}` and `\\(s_i = f(x_i), x_i ~ p\\). However, if p is not reparameterized, TensorFlow's gradient will be incorrect since the chain-rule stops at samples of non-reparameterized distributions. @@ -296,7 +290,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, Args: f: Python callable which can return `f(samples)`. samples: `Tensor` of samples used to form the Monte-Carlo approximation of - `\\(E_p[f(X)]\\)`. A batch of samples should be indexed by `axis` + \\(E_p[f(X)]\\). A batch of samples should be indexed by `axis` dimensions. log_prob: Python callable which can return `log_prob(samples)`. Must correspond to the natural-logarithm of the pdf/pmf of each sample. Only @@ -317,7 +311,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, Returns: approx_expectation: `Tensor` corresponding to the Monte-Carlo approximation - of `\\(E_p[f(X)]\\)`. + of \\(E_p[f(X)]\\). Raises: ValueError: if `f` is not a Python `callable`. @@ -329,7 +323,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, if not callable(f): raise ValueError('`f` must be a callable function.') if use_reparametrization: - return math_ops.reduce_mean(f(samples), axis=axis, keep_dims=keep_dims) + return math_ops.reduce_mean(f(samples), axis=axis, keepdims=keep_dims) else: if not callable(log_prob): raise ValueError('`log_prob` must be a callable function.') @@ -349,7 +343,7 @@ def expectation(f, samples, log_prob=None, use_reparametrization=True, # "Is there a floating point value of x, for which x-x == 0 is false?" # http://stackoverflow.com/q/2686644 fx += stop(fx) * (logpx - stop(logpx)) # Add zeros_like(logpx). - return math_ops.reduce_mean(fx, axis=axis, keep_dims=keep_dims) + return math_ops.reduce_mean(fx, axis=axis, keepdims=keep_dims) def _sample_mean(values): diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 44a8ffaf4b2..04e32267cc4 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -422,6 +422,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats(*gradients_t, *hessians_t, bucket_idx); } present_gradient_stats *= normalizer_ratio; + GradientStats not_present = + root_gradient_stats - present_gradient_stats; + // If there was (almost) no sparsity, fix the default direction to LEFT. + bool fixed_default_direction = not_present.IsAlmostZero(); GradientStats left_gradient_stats; for (int64 element_idx = start_index; element_idx < end_index; @@ -441,6 +445,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // backward pass gradients. 
GradientStats right_gradient_stats = present_gradient_stats - left_gradient_stats; + { NodeStats left_stats_default_left = ComputeNodeStats(root_gradient_stats - right_gradient_stats); @@ -457,7 +462,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { best_dimension_idx = dimension_id; } } - { + // Consider calculating the default direction only when there were + // enough missing examples. + if (!fixed_default_direction) { NodeStats left_stats_default_right = ComputeNodeStats(left_gradient_stats); NodeStats right_stats_default_right = diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 7df514cd207..f06b73c00d0 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -417,9 +417,18 @@ class SparseSplitHandler(InequalitySplitHandler): return (are_splits_ready, partition_ids, gains, split_infos) -@function.Defun(dtypes.bool, dtypes.bool, dtypes.float32, dtypes.float32, - dtypes.int32, dtypes.float32, dtypes.float32, dtypes.float32, - dtypes.float32, dtypes.float32) +@function.Defun( + dtypes.bool, + dtypes.bool, + dtypes.float32, + dtypes.float32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + noinline=True) def dense_make_stats_update(is_active, are_buckets_ready, float_column, quantile_buckets, example_partition_ids, gradients, hessians, weights, empty_gradients, empty_hessians): @@ -452,9 +461,20 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, gradients, hessians) -@function.Defun(dtypes.bool, dtypes.bool, dtypes.int64, dtypes.float32, - dtypes.int64, dtypes.float32, dtypes.int32, dtypes.float32, - dtypes.float32, dtypes.float32, dtypes.float32, dtypes.float32) +@function.Defun( + dtypes.bool, + dtypes.bool, + dtypes.int64, + dtypes.float32, + dtypes.int64, + dtypes.float32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + noinline=True) def sparse_make_stats_update( is_active, are_buckets_ready, sparse_column_indices, sparse_column_values, sparse_column_shape, quantile_buckets, example_partition_ids, gradients, @@ -481,11 +501,18 @@ def sparse_make_stats_update( example_partition_ids) # Compute aggregate stats for each partition. + # Since unsorted_segment_sum can be numerically unstable, use 64bit + # operation. + gradients64 = math_ops.cast(gradients, dtypes.float64) + hessians64 = math_ops.cast(hessians, dtypes.float64) per_partition_gradients = math_ops.unsorted_segment_sum( - gradients, mapped_partitions, array_ops.size(unique_partitions)) + gradients64, mapped_partitions, array_ops.size(unique_partitions)) per_partition_hessians = math_ops.unsorted_segment_sum( - hessians, mapped_partitions, array_ops.size(unique_partitions)) - + hessians64, mapped_partitions, array_ops.size(unique_partitions)) + per_partition_gradients = math_ops.cast(per_partition_gradients, + dtypes.float32) + per_partition_hessians = math_ops.cast(per_partition_hessians, + dtypes.float32) # Prepend a bias feature per partition that accumulates the stats for all # examples in that partition. 
bias_feature_ids = array_ops.fill( diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream_test.cc b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream_test.cc index 4481c0d0e44..67ac9bf387a 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream_test.cc @@ -138,6 +138,12 @@ void GenerateOneValue(int32 worker_id, int64 max_elements, double *total_weight, stream->Finalize(); } +void GenerateOneZeroWeightedValue(int32 worker_id, int64 max_elements, + double *total_weight, Stream *stream) { + stream->PushEntry(10, 0); + stream->Finalize(); +} + TEST(WeightedQuantilesStreamTest, OneValue) { const double eps = 0.01; const int64 max_elements = 1 << 16; @@ -145,6 +151,13 @@ TEST(WeightedQuantilesStreamTest, OneValue) { {10.0, 10.0, 10.0, 10.0, 10.0}, 1e-2); } +TEST(WeightedQuantilesStreamTest, OneZeroWeightValue) { + const double eps = 0.01; + const int64 max_elements = 1 << 16; + TestSingleWorkerStreams(eps, max_elements, GenerateOneZeroWeightedValue, {}, + 1e-2); +} + TEST(WeightedQuantilesStreamTest, FixedUniform) { const double eps = 0.01; const int64 max_elements = 1 << 16; diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index aec232f3cbb..7576856dc3a 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -235,6 +235,11 @@ class WeightedQuantilesSummary { // The resulting boundaries are guaranteed to both contain at least // num_boundaries unique elements and maintain approximation bounds. std::vector GenerateBoundaries(int64 num_boundaries) const { + std::vector output; + if (entries_.empty()) { + return output; + } + // Generate soft compressed summary. WeightedQuantilesSummary compressed_summary; @@ -246,7 +251,6 @@ class WeightedQuantilesSummary { compressed_summary.Compress(num_boundaries, compression_eps); // Return boundaries. - std::vector output; output.reserve(compressed_summary.entries_.size()); for (const auto& entry : compressed_summary.entries_) { output.push_back(entry.value); @@ -260,6 +264,9 @@ class WeightedQuantilesSummary { // full rank queries O(nlogn). 
std::vector GenerateQuantiles(int64 num_quantiles) const { std::vector output; + if (entries_.empty()) { + return output; + } num_quantiles = std::max(num_quantiles, 2LL); output.reserve(num_quantiles + 1); diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index 28834ef55bf..5cd37ec67ec 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import random + from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import split_info_pb2 from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops @@ -399,6 +401,65 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertAllClose(0.6, split_node.split.threshold) + def testMakeSparseSplitDefaultDirectionIsStable(self): + """Tests default direction is stable when no sparsity.""" + random.seed(1123) + for _ in range(50): + with self.test_session() as sess: + grad = random.random() + hessian = random.random() + # The data looks like the following (divide by the num of steps 2). + # Gradients | Partition | bucket ID | + # (grad, hessian) | 0 | -1 | + # And then 100 buckets of + # (grad/100, hessian/100), so there is no sparsity. + n_buckets = 100 + + # 1 for the overall sum, and 100 buckets. + partition_ids = array_ops.constant( + [0] * (n_buckets + 1), dtype=dtypes.int32) + # We have only 1 dimension in our sparse feature column. + + bucket_ids = [-1] + [n for n in range(100)] + bucket_ids = array_ops.constant(bucket_ids, dtype=dtypes.int64) + dimension_ids = array_ops.constant( + [0] * (n_buckets + 1), dtype=dtypes.int64) + bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1) + + gradients = [grad] + [grad / n_buckets] * n_buckets + gradients = array_ops.constant(gradients) + hessians = [hessian] + [hessian / n_buckets] * n_buckets + hessians = array_ops.constant(hessians) + + boundaries = [x * 1 for x in range(n_buckets + 1)] + bucket_boundaries = array_ops.constant(boundaries, dtype=dtypes.float32) + + partitions, gains, splits = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=2, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + bucket_boundaries=bucket_boundaries, + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + feature_column_group_id=0, + bias_feature_id=-1, + class_id=-1, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + partitions, gains, splits = (sess.run([partitions, gains, splits])) + self.assertAllEqual([0], partitions) + self.assertEqual(1, len(splits)) + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + self.assertTrue( + split_info.split_node.HasField( + 'sparse_float_binary_split_default_left')) + def testMakeMulticlassSparseSplit(self): """Tests split handler op.""" with self.test_session() as sess: diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 1b184d296b3..50cc00afdcc 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ 
b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -187,7 +187,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): stamp_token: Expected current token. next_stamp_token: Next value for the token. Returns: - A list of quantiles or approximate boundaries. + The flush operation. """ return gen_quantile_ops.quantile_accumulator_flush( quantile_accumulator_handle=self._quantile_accumulator_handle, diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 4bde7f3e33d..e53d86ec612 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -369,7 +369,7 @@ class GradientBoostedDecisionTreeModel(object): Returns: a dictionary of prediction results - ENSEMBLE_STAMP, PREDICTION, PARTITION_IDS, - NUM_LAYER_ATTEMPTED, NUM_TREES_ATTEMPED. + NUM_LAYER_ATTEMPTED, NUM_TREES_ATTEMPTED. """ ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, ensemble_stamp) @@ -970,10 +970,8 @@ class GradientBoostedDecisionTreeModel(object): # Stack all the inputs to one tensor per type. # This is a workaround for the slowness of graph building in tf.cond. # See (b/36554864). - split_sizes = array_ops.stack([ - array_ops.shape(partition_id)[0] - for partition_id in partition_ids_list - ]) + split_sizes = array_ops.reshape( + array_ops.shape_n(partition_ids_list), [-1]) partition_ids = array_ops.concat(partition_ids_list, axis=0) gains = array_ops.concat(gains_list, axis=0) split_infos = array_ops.concat(split_info_list, axis=0) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 17dcb49f476..f9c22283b7f 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -45,7 +45,7 @@ from tensorflow.python.platform import googletest def _squared_loss(label, unused_weights, predictions): """Unweighted loss implementation.""" loss = math_ops.reduce_sum( - math_ops.square(predictions - label), 1, keep_dims=True) + math_ops.square(predictions - label), 1, keepdims=True) return loss diff --git a/tensorflow/contrib/checkpoint/README.md b/tensorflow/contrib/checkpoint/README.md new file mode 100644 index 00000000000..d35c5bae3b7 --- /dev/null +++ b/tensorflow/contrib/checkpoint/README.md @@ -0,0 +1,2 @@ +Tools for working with object-based checkpoints produced by +`tf.train.Checkpoint`. diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py new file mode 100644 index 00000000000..af8df72618b --- /dev/null +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -0,0 +1,43 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tools for working with object-based checkpoints. + +Visualization and inspection: +@@dot_graph_from_checkpoint +@@object_metadata + +Creating and managing dependencies: +@@Checkpointable +@@CheckpointableObjectGraph +@@NoDependency +@@split_dependency +@@UniqueNameTracker +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker +from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency +from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint +from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph +from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.base import NoDependency +from tensorflow.python.training.checkpointable.util import object_metadata + +from tensorflow.python.util.all_util import remove_undocumented + +remove_undocumented(module_name=__name__) diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD new file mode 100644 index 00000000000..53f4e97f993 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -0,0 +1,84 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "checkpoint", + srcs_version = "PY2AND3", + deps = [ + ":containers", + ":split_dependency", + ":visualize", + ], +) + +py_library( + name = "containers", + srcs = ["containers.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = ["//tensorflow/python/training/checkpointable:base"], +) + +py_test( + name = "containers_test", + srcs = ["containers_test.py"], + deps = [ + ":containers", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", + "//tensorflow/python/training/checkpointable:base", + "@six_archive//:six", + ], +) + +py_library( + name = "split_dependency", + srcs = ["split_dependency.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:training", + ], +) + +py_test( + name = "split_dependency_test", + srcs = ["split_dependency_test.py"], + deps = [ + ":split_dependency", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "visualize", + srcs = ["visualize.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:pywrap_tensorflow", + ], +) + +py_test( + name = "visualize_test", + srcs = ["visualize_test.py"], + deps = [ + ":visualize", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", + "//tensorflow/python/eager:test", + ], +) diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py new file mode 100644 index 00000000000..9807abae1f5 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -0,0 +1,77 @@ 
+"""Checkpointable data structures.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.training.checkpointable import base as checkpointable_lib + + +class UniqueNameTracker(checkpointable_lib.CheckpointableBase): + """Adds dependencies on checkpointable objects with name hints. + + Useful for creating dependencies with locally unique names. + + Example usage: + ```python + class SlotManager(tf.contrib.checkpoint.Checkpointable): + + def __init__(self): + # Create a dependency named "slotdeps" on the container. + self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x" + slots.append(slotdeps.track(tfe.Variable(4.), "y")) + slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1" + ``` + """ + + def __init__(self): + self._maybe_initialize_checkpointable() + self._name_counts = {} + + def track(self, checkpointable, base_name): + """Add a dependency on `checkpointable`. + + Args: + checkpointable: An object to add a checkpoint dependency on. + base_name: A name hint, which is uniquified to determine the dependency + name. + Returns: + `checkpointable`, for chaining. + Raises: + ValueError: If `checkpointable` is not a checkpointable object. + """ + + if not isinstance(checkpointable, checkpointable_lib.CheckpointableBase): + raise ValueError( + ("Expected a checkpointable value, got %s which does not inherit " + "from CheckpointableBase.") % (checkpointable,)) + + def _format_name(prefix, number): + if number > 0: + return "%s_%d" % (prefix, number) + else: + return prefix + + count = self._name_counts.get(base_name, 0) + candidate = _format_name(base_name, count) + while self._lookup_dependency(candidate) is not None: + count += 1 + candidate = _format_name(base_name, count) + self._name_counts[base_name] = count + 1 + return self._track_checkpointable(checkpointable, name=candidate) diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py new file mode 100644 index 00000000000..851a8005885 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import six + +from tensorflow.contrib.checkpoint.python import containers +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils + + +class UniqueNameTrackerTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testNames(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + + x1 = resource_variable_ops.ResourceVariable(2.) + x2 = resource_variable_ops.ResourceVariable(3.) + x3 = resource_variable_ops.ResourceVariable(4.) + y = resource_variable_ops.ResourceVariable(5.) + slots = containers.UniqueNameTracker() + slots.track(x1, "x") + slots.track(x2, "x") + slots.track(x3, "x_1") + slots.track(y, "y") + self.evaluate((x1.initializer, x2.initializer, x3.initializer, + y.initializer)) + save_root = checkpointable_utils.Checkpoint(slots=slots) + save_path = save_root.save(checkpoint_prefix) + + restore_slots = checkpointable.Checkpointable() + restore_root = checkpointable_utils.Checkpoint( + slots=restore_slots) + status = restore_root.restore(save_path) + restore_slots.x = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.x_1_1 = resource_variable_ops.ResourceVariable(0.) + restore_slots.y = resource_variable_ops.ResourceVariable(0.) 
+ status.assert_consumed().run_restore_ops() + self.assertEqual(2., self.evaluate(restore_slots.x)) + self.assertEqual(3., self.evaluate(restore_slots.x_1)) + self.assertEqual(4., self.evaluate(restore_slots.x_1_1)) + self.assertEqual(5., self.evaluate(restore_slots.y)) + + @test_util.run_in_graph_and_eager_modes() + def testExample(self): + class SlotManager(checkpointable.Checkpointable): + + def __init__(self): + self.slotdeps = containers.UniqueNameTracker() + slotdeps = self.slotdeps + slots = [] + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(3.), "x")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(4.), "y")) + slots.append(slotdeps.track( + resource_variable_ops.ResourceVariable(5.), "x")) + self.slots = slots + + manager = SlotManager() + self.evaluate([v.initializer for v in manager.slots]) + checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpoint.save(checkpoint_prefix) + metadata = checkpointable_utils.object_metadata(save_path) + dependency_names = [] + for node in metadata.nodes: + for child in node.children: + dependency_names.append(child.local_name) + six.assertCountEqual( + self, + dependency_names, + ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/checkpoint/python/split_dependency.py similarity index 95% rename from tensorflow/contrib/eager/python/checkpointable_utils.py rename to tensorflow/contrib/checkpoint/python/split_dependency.py index 30c4103c5aa..7e77453f3d8 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency.py @@ -1,4 +1,4 @@ -"""Utilities for working with Checkpointable objects.""" +"""Utility for creating multiple dependencies with synchronized save/restore.""" # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,8 +20,8 @@ from __future__ import print_function import functools from tensorflow.python.ops import control_flow_ops -from tensorflow.python.training import checkpointable as core_checkpointable from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training.checkpointable import base as checkpointable class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): @@ -43,7 +43,7 @@ class _CallbackSaveable(saver_lib.BaseSaverBuilder.SaveableObject): return self._restore_callback(tensor) -class _SplitDependency(core_checkpointable.CheckpointableBase): +class _SplitDependency(checkpointable.CheckpointableBase): """Looks like a regular variable while synchronizing save/restores.""" def __init__(self, save_buffer, restore_buffer, name, dtype, num_components, @@ -83,7 +83,7 @@ class _SplitDependency(core_checkpointable.CheckpointableBase): def _gather_saveables_for_checkpoint(self): """Looks to Checkpointable like a regular variable.""" return { - core_checkpointable.VARIABLE_VALUE_KEY: + checkpointable.VARIABLE_VALUE_KEY: functools.partial(_CallbackSaveable, dtype=self._dtype, save_callback=self._save, diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py similarity index 91% rename from tensorflow/contrib/eager/python/checkpointable_utils_test.py rename to tensorflow/contrib/checkpoint/python/split_dependency_test.py index da04199aaad..69dc0b9be2d 100644 --- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py +++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py @@ -18,13 +18,13 @@ from __future__ import print_function import os -from tensorflow.contrib.eager.python import checkpointable_utils as contrib_checkpointable_utils +from tensorflow.contrib.checkpoint.python import split_dependency from tensorflow.python.eager import test from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training import checkpointable -from tensorflow.python.training import checkpointable_utils +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils def _split_variable_closure(variable): @@ -47,7 +47,7 @@ class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase): def __init__(self): self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.]) - split_dependencies = contrib_checkpointable_utils.split_dependency( + split_dependencies = split_dependency.split_dependency( component_names=("first_half", "second_half"), component_dtypes=(self.combined.dtype,) * 2, fill_save_buffer_fn=_split_variable_closure( @@ -73,7 +73,7 @@ class OnlyOneDep(checkpointable.Checkpointable): class SplitTests(test.TestCase): - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + @test_util.run_in_graph_and_eager_modes() def testSaveRestoreSplitDep(self): save_checkpoint = checkpointable_utils.Checkpoint( dep=SaveTensorSlicesAsDeps()) diff --git a/tensorflow/contrib/checkpoint/python/visualize.py b/tensorflow/contrib/checkpoint/python/visualize.py new file mode 100644 index 00000000000..bac071c4cff --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/visualize.py @@ -0,0 +1,99 @@ +"""Utilities for visualizing dependency graphs.""" +# Copyright 2018 The TensorFlow 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import util as checkpointable_utils + + +def dot_graph_from_checkpoint(save_path): + r"""Visualizes an object-based checkpoint (from `tf.train.Checkpoint`). + + Useful for inspecting checkpoints and debugging loading issues. + + Example usage from Python (requires pydot): + ```python + import tensorflow as tf + import pydot + + dot_string = tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt') + parsed, = pydot.graph_from_dot_data(dot_string) + parsed.write_svg('/tmp/tensorflow/visualized_checkpoint.svg') + ``` + + Example command line usage: + ```sh + python -c "import tensorflow as tf;\ + print(tf.contrib.checkpoint.dot_graph_from_checkpoint('/path/to/ckpt'))"\ + | dot -Tsvg > /tmp/tensorflow/checkpoint_viz.svg + ``` + + Args: + save_path: The checkpoint prefix, as returned by `tf.train.Checkpoint.save` + or `tf.train.latest_checkpoint`. + Returns: + A graph in DOT format as a string. 
+ """ + reader = pywrap_tensorflow.NewCheckpointReader(save_path) + object_graph = checkpointable_utils.object_metadata(save_path) + shape_map = reader.get_variable_to_shape_map() + dtype_map = reader.get_variable_to_dtype_map() + graph = 'digraph {\n' + def _escape(name): + return name.replace('"', '\\"') + slot_ids = set() + for node in object_graph.nodes: + for slot_reference in node.slot_variables: + slot_ids.add(slot_reference.slot_variable_node_id) + for node_id, node in enumerate(object_graph.nodes): + if (len(node.attributes) == 1 + and node.attributes[0].name == checkpointable.VARIABLE_VALUE_KEY): + if node_id in slot_ids: + color = 'orange' + tooltip_prefix = 'Slot variable' + else: + color = 'blue' + tooltip_prefix = 'Variable' + attribute = node.attributes[0] + graph += ('N_%d [shape=point label="" color=%s width=.25' + ' tooltip="%s %s shape=%s %s"]\n') % ( + node_id, + color, + tooltip_prefix, + _escape(attribute.full_name), + shape_map[attribute.checkpoint_key], + dtype_map[attribute.checkpoint_key].name) + elif node.slot_variables: + graph += ('N_%d [shape=point label="" width=.25 color=red,' + 'tooltip="Optimizer"]\n') % node_id + else: + graph += 'N_%d [shape=point label="" width=.25]\n' % node_id + for reference in node.children: + graph += 'N_%d -> N_%d [label="%s"]\n' % ( + node_id, reference.node_id, _escape(reference.local_name)) + for slot_reference in node.slot_variables: + graph += 'N_%d -> N_%d [label="%s" style=dotted]\n' % ( + node_id, + slot_reference.slot_variable_node_id, + _escape(slot_reference.slot_name)) + graph += 'N_%d -> N_%d [style=dotted]\n' % ( + slot_reference.original_variable_node_id, + slot_reference.slot_variable_node_id) + graph += '}\n' + return graph diff --git a/tensorflow/contrib/checkpoint/python/visualize_test.py b/tensorflow/contrib/checkpoint/python/visualize_test.py new file mode 100644 index 00000000000..583e3bc4428 --- /dev/null +++ b/tensorflow/contrib/checkpoint/python/visualize_test.py @@ -0,0 +1,97 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +from tensorflow.contrib.checkpoint.python import visualize + +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.keras.engine import training +from tensorflow.python.keras.layers import core +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.training import adam +from tensorflow.python.training.checkpointable import util as checkpointable_utils + +try: + import pydot # pylint: disable=g-import-not-at-top +except ImportError: + pydot = None + + +class MyModel(training.Model): + """A concrete Model for testing.""" + + def __init__(self): + super(MyModel, self).__init__() + self._named_dense = core.Dense(1, use_bias=True) + self._second = core.Dense(1, use_bias=False) + + def call(self, values): + ret = self._second(self._named_dense(values)) + return ret + + +class DotGraphTests(test.TestCase): + + def testMakeDotGraph(self): + with context.eager_mode(): + input_value = constant_op.constant([[3.]]) + model = MyModel() + optimizer = adam.AdamOptimizer(0.001) + optimizer_step = resource_variable_ops.ResourceVariable(12) + save_checkpoint = checkpointable_utils.Checkpoint( + optimizer=optimizer, model=model, optimizer_step=optimizer_step) + optimizer.minimize(functools.partial(model, input_value)) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') + save_path = save_checkpoint.save(checkpoint_prefix) + prefix = save_checkpoint.save(save_path) + + dot_graph_string = visualize.dot_graph_from_checkpoint(prefix) + + # The remainder of this test is more-or-less optional since it's so + # dependent on pydot/platform/Python versions. + if pydot is None: + self.skipTest('pydot is required for the remainder of this test.') + try: + parsed, = pydot.graph_from_dot_data(dot_graph_string) + except NameError as e: + if "name 'dot_parser' is not defined" in str(e): + self.skipTest("pydot isn't working") + else: + raise + # Check that the graph isn't completely trivial + self.assertEqual( + '"model"', + parsed.obj_dict['edges'][('N_0', 'N_1')][0]['attributes']['label']) + image_path = os.path.join(self.get_temp_dir(), 'saved.svg') + try: + parsed.write_svg(image_path) + except Exception as e: # pylint: disable=broad-except + # For some reason PyDot's "dot not available" error is an Exception, not + # something more specific. 
+ if '"dot" not found in path' in str(e): + self.skipTest("pydot won't save SVGs (dot not available)") + else: + raise + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 5a2771229d9..8ede28602fd 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -36,6 +36,7 @@ except ImportError: _GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' +_DEFAULT_ENV_VARIABLE = 'TPU_NAME' class TPUClusterResolver(ClusterResolver): @@ -70,6 +71,12 @@ class TPUClusterResolver(ClusterResolver): def _gkeMaster(): return os.environ[_GKE_ENV_VARIABLE].split(',')[0] + @staticmethod + def _envVarFallback(): + if _DEFAULT_ENV_VARIABLE in os.environ: + return os.environ[_DEFAULT_ENV_VARIABLE] + return None + def __init__(self, tpu=None, zone=None, @@ -123,8 +130,11 @@ class TPUClusterResolver(ClusterResolver): in_gke = self._inGke() # When using GKE with Cloud TPUs, the env variable will be set. - if tpu is None and in_gke: - tpu = self._gkeMaster() + if tpu is None: + if in_gke: + tpu = self._gkeMaster() + else: + tpu = self._envVarFallback() self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name @@ -245,7 +255,7 @@ class TPUClusterResolver(ClusterResolver): else: if not self._tpu.startswith(compat.as_bytes('grpc://')): # Case 3. - return server_lib.ClusterSpec({}) + return None # Case 2. cluster_spec = {self._job_name: [self._tpu[len( compat.as_bytes('grpc://')):]]} diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index dff7a03b684..5b3f9be5a11 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -356,8 +356,7 @@ class TPUClusterResolverTest(test.TestCase): tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/foo/bar') self.assertEqual( compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master()) - self.assertEqual( - server_lib.ClusterSpec({}), tpu_cluster_resolver.cluster_spec()) + self.assertEqual(None, tpu_cluster_resolver.cluster_spec()) def testGkeEnvironment(self): os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index a7944ea74ae..0708d6b7b9f 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -31,10 +31,14 @@ option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF) option(tensorflow_BUILD_MORE_PYTHON_TESTS "Build more python unit tests for contrib packages" OFF) option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF) option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON) -option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions") option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) option(tensorflow_DISABLE_EIGEN_FORCEINLINE "Disable forceinline, to speed up build on windows." 
OFF) +# SIMD, MKL and MKLDNN options +option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions" OFF) +option(tensorflow_ENABLE_MKL_SUPPORT "Enable Intel MKL support" OFF) +option(tensorflow_ENABLE_MKLDNN_SUPPORT "Enable Intel MKLDNN support, requires MKL enabled" OFF) + # GPU, CUDA and cuDNN options option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) set(tensorflow_CUDA_VERSION "9.0" CACHE STRING "CUDA version to build against") @@ -80,7 +84,7 @@ if (NOT WIN32) option(systemlib_ALL "Turn on every possible systemlib_* options" OFF) if (systemlib_ALL) - set (systmelib_ZLIB ON) + set (systemlib_ZLIB ON) endif (systemlib_ALL) endif() @@ -124,8 +128,16 @@ endif() add_definitions(-DEIGEN_AVOID_STL_ARRAY) if(WIN32) + if(CMAKE_SIZEOF_VOID_P EQUAL 8) + # 64 bits + add_definitions(-DWIN64) + elseif(CMAKE_SIZEOF_VOID_P EQUAL 4) + # 32 bits + # temporary fix for #18241 + add_definitions(-DEIGEN_DEFAULT_DENSE_INDEX_TYPE=std::int64_t) + endif() add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11) - add_definitions(-DWIN32 -DOS_WIN -D_MBCS -DWIN64 -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS) + add_definitions(-DWIN32 -DOS_WIN -D_MBCS -DWIN32_LEAN_AND_MEAN -DNOGDI -DPLATFORM_WINDOWS) add_definitions(-DTENSORFLOW_USE_EIGEN_THREADPOOL -DEIGEN_HAS_C99_MATH) add_definitions(-DTF_COMPILE_LIBRARY) add_definitions(/bigobj /nologo /EHsc /GF /MP /Gm-) @@ -160,14 +172,24 @@ if (tensorflow_OPTIMIZE_FOR_NATIVE_ARCH) endif() endif() +include(CheckCXXCompilerFlag) + +# OpenMP Support +CHECK_CXX_COMPILER_FLAG("-fopenmp" GCC_OPENMP_SUPPORT) +if (GCC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") +endif() +CHECK_CXX_COMPILER_FLAG("/openmp" MSVC_OPENMP_SUPPORT) +if (MSVC_OPENMP_SUPPORT) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /openmp") +endif() + # MSVC SIMD instructions if (tensorflow_WIN_CPU_SIMD_OPTIONS) if (WIN32) - CHECK_CXX_COMPILER_FLAG("${tensorflow_WIN_CPU_SIMD_OPTIONS}" COMPILER_OPT_WIN_CPU_SIMD_SUPPORTED) + CHECK_CXX_COMPILER_FLAG(${tensorflow_WIN_CPU_SIMD_OPTIONS} COMPILER_OPT_WIN_CPU_SIMD_SUPPORTED) if(COMPILER_OPT_WIN_CPU_SIMD_SUPPORTED) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${tensorflow_WIN_CPU_SIMD_OPTIONS}") - else() - message(FATAL_ERROR "${tensorflow_WIN_CPU_SIMD_OPTIONS} not supported") endif() endif() endif() @@ -193,6 +215,7 @@ include(protobuf) include(re2) include(cub) include(sqlite) +include(double_conversion) if (tensorflow_BUILD_CC_TESTS) include(googletest) endif() @@ -213,6 +236,7 @@ set(tensorflow_EXTERNAL_LIBRARIES ${protobuf_STATIC_LIBRARIES} ${re2_STATIC_LIBRARIES} ${sqlite_STATIC_LIBRARIES} + ${double_conversion_STATIC_LIBRARIES} ) if (systemlib_ZLIB) @@ -240,6 +264,7 @@ set(tensorflow_EXTERNAL_DEPENDENCIES fft2d re2 sqlite_copy_headers_to_destination + double_conversion ) include_directories( @@ -262,6 +287,7 @@ include_directories( ${PROTOBUF_INCLUDE_DIRS} ${re2_INCLUDE_DIR} ${sqlite_INCLUDE_DIR} + ${double_conversion_INCLUDE_DIR} ) if(tensorflow_ENABLE_SSL_SUPPORT) @@ -298,6 +324,49 @@ if(HAIKU) list(APPEND tensorflow_EXTERNAL_LIBRARIES network) endif() +# MKL Support +if (tensorflow_ENABLE_MKL_SUPPORT) + add_definitions(-DINTEL_MKL -DEIGEN_USE_VML) + if (WIN32) + find_path(MKL_HOME_PLATFORM mkl + PATHS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ + $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ + PATH_SUFFIXES windows) + set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) + set(MKL_LINK_DIRS + ${MKL_HOME_PLATFORM}/mkl/lib/intel64 + ${MKL_HOME_PLATFORM}/tbb/lib/intel64/vc_mt + ${MKL_HOME_PLATFORM}/compiler/lib/intel64 
+ ${MKL_HOME_PLATFORM}/mkl/tools/builder/lib) + set(MKL_REDIST_DLL_DIRS + ${MKL_HOME_PLATFORM}/redist/intel64/mkl + ${MKL_HOME_PLATFORM}/redist/intel64/tbb/vc_mt + ${MKL_HOME_PLATFORM}/redist/intel64/compiler) + list(APPEND tensorflow_EXTERNAL_LIBRARIES + mkl_intel_lp64_dll mkl_sequential_dll mkl_core_dll mkl_rt mkl_cdll_intel64) + endif() + if (UNIX) + # Fix me: complete the path on linux + find_path(MKL_HOME_PLATFORM mkl + HINTS ${MKL_HOME} ${MKL_HOME}/../ ${MKL_HOME}/../../ + $ENV{MKLROOT} $ENV{MKLROOT}/../ $ENV{MKLROOT}/../../ + PATH_SUFFIXES linux) + set(MKL_INCLUDE_DIRS ${MKL_HOME_PLATFORM}/mkl/include) + set(MKL_LINK_DIRS) # incomplete + set(MKL_REDIST_SO_DIRS) # incomplete + endif() + include_directories(${MKL_INCLUDE_DIRS}) + link_directories(${MKL_LINK_DIRS}) + if (tensorflow_ENABLE_MKLDNN_SUPPORT) + include(mkldnn) + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES}) + list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn) + include_directories(${mkldnn_INCLUDE_DIRS}) + else (tensorflow_ENABLE_MKLDNN_SUPPORT) + add_definitions(-DINTEL_MKL_ML) + endif() +endif (tensorflow_ENABLE_MKL_SUPPORT) + if (tensorflow_ENABLE_GPU) if (NOT WIN32) # Default install paths for cuda libraries in Linux @@ -409,6 +478,10 @@ if (tensorflow_ENABLE_GPU) include_directories(${tensorflow_source_dir}/third_party/gpus) # add cuda libraries to tensorflow_EXTERNAL_LIBRARIES list(APPEND tensorflow_EXTERNAL_LIBRARIES ${CUDA_LIBRARIES}) + if(NOT WIN32) + # add gomp to tensorflow_EXTERNAL_LIBRARIES, needed by libcusolver.so + list(APPEND tensorflow_EXTERNAL_LIBRARIES gomp) + endif() # NOTE(mrry): Update these flags when the version of CUDA or cuDNN used # in the default build is upgraded. diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index fe83bb32046..0b79f718d48 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -128,6 +128,18 @@ Step-by-step Windows build D:\local\cuda\bin ``` + * When building with MKL support after installing [MKL](https://software.intel.com/en-us/mkl) from Intel, append its bin directories to your PATH environment variable. + + In case TensorFlow fails to find the MKL DLLs during initialization, check your PATH environment variable. + It should contain the directory of the MKL DLLs. For example: + + ``` + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler + D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt + ``` + + * We assume that `cmake` and `git` are installed and in your `%PATH%`. If for example `cmake` is not in your path and it is installed in `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory @@ -166,7 +178,15 @@ Step-by-step Windows build More? -Dtensorflow_ENABLE_GPU=ON ^ More? -DCUDNN_HOME="D:\...\cudnn" ``` + To build with MKL support, add "^" at the end of the last line above, followed by: + + ``` + More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^ + More? -DMKL_HOME="D:\...\compilers_and_libraries" + ``` + To enable SIMD instructions with MSVC, such as AVX and SSE, define them as follows: + ``` More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX ``` @@ -226,6 +246,7 @@ Step-by-step Windows build ``` ctest -C RelWithDebInfo ``` + * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables Python tests on several major packages. This option is only valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`. 
After building the python wheel, you need to install the new wheel before running the tests. @@ -234,6 +255,12 @@ Step-by-step Windows build ctest -C RelWithDebInfo ``` + * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL support. If MKL is enabled, you need to install the [Intel Math Kernel Library](https://software.intel.com/en-us/mkl). + CMake will expect the location of MKL to be passed as -DMKL_HOME=<path where you installed MKL>. + + * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support. + + 4. Invoke MSBuild to build TensorFlow. To build the C++ example program, which will be created as a `.exe` @@ -251,6 +278,7 @@ Step-by-step Windows build D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj ``` + Linux Continuous Integration build ================================== diff --git a/tensorflow/contrib/cmake/external/double_conversion.cmake b/tensorflow/contrib/cmake/external/double_conversion.cmake new file mode 100644 index 00000000000..527ccdc8d88 --- /dev/null +++ b/tensorflow/contrib/cmake/external/double_conversion.cmake @@ -0,0 +1,54 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +include (ExternalProject) + +set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion) +set(double_conversion_URL https://github.com/google/double-conversion.git) +set(double_conversion_TAG 5664746) +set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR}) +set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so) +set(double_conversion_INCLUDES ${double_conversion_BUILD}) + +if(WIN32) + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/$(Configuration)/double-conversion.lib) +else() + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.a) +endif() + +set(double_conversion_HEADERS + "${double_conversion_INCLUDE_DIR}/double-conversion/bignum-dtoa.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/cached-powers.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/double-conversion.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/fixed-dtoa.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/strtod.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/bignum.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/diy-fp.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/fast-dtoa.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/ieee.h" + "${double_conversion_INCLUDE_DIR}/double-conversion/utils.h" +) + +ExternalProject_Add(double_conversion + PREFIX double_conversion + GIT_REPOSITORY ${double_conversion_URL} + GIT_TAG ${double_conversion_TAG} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 + INSTALL_COMMAND "" + CMAKE_CACHE_ARGS + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON +) diff --git a/tensorflow/contrib/cmake/external/gemmlowp.cmake b/tensorflow/contrib/cmake/external/gemmlowp.cmake index a235442dc5c..cdaa6b73b93 100644 --- a/tensorflow/contrib/cmake/external/gemmlowp.cmake +++ b/tensorflow/contrib/cmake/external/gemmlowp.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(gemmlowp_URL https://github.com/google/gemmlowp/archive/6a2a90822e8546fc2bfa7044de0faf1c1cb4862f.zip) -set(gemmlowp_HASH SHA256=3447948d219f3270383766bbe08942888c0eb4e0ca6663c0e0548502ec5bb77d) +set(gemmlowp_URL https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip) +set(gemmlowp_HASH SHA256=b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658) set(gemmlowp_BUILD ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) set(gemmlowp_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 35c2a294ecf..693dc7cd673 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 09386db3939cae1ac12e5f09b735adfa8958c68e) +set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/external/mkldnn.cmake 
b/tensorflow/contrib/cmake/external/mkldnn.cmake new file mode 100644 index 00000000000..a639fdee367 --- /dev/null +++ b/tensorflow/contrib/cmake/external/mkldnn.cmake @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +include (ExternalProject) + +set(mkldnn_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/include) +set(mkldnn_URL https://github.com/01org/mkl-dnn.git) +set(mkldnn_BUILD ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src) +set(mkldnn_TAG 3063b2e4c943983f6bf5f2fb9a490d4a998cd291) + +if(WIN32) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/Release/mkldnn.lib) + else() + set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/mkldnn.lib) + endif() +else() + set(mkldnn_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/mkldnn/src/mkldnn/src/libmkldnn.a) +endif() + +ExternalProject_Add(mkldnn + PREFIX mkldnn + GIT_REPOSITORY ${mkldnn_URL} + GIT_TAG ${mkldnn_TAG} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${mkldnn_STATIC_LIBRARIES} + INSTALL_COMMAND "" + CMAKE_CACHE_ARGS + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -DMKLINC:STRING=${MKL_INCLUDE_DIRS} +) diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 6cd66a65990..ad2af01bc00 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -15,32 +15,33 @@ include (ExternalProject) set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive) -set(png_URL https://storage.googleapis.com/libpng-public-archive/libpng-1.2.53.tar.gz) -set(png_HASH SHA256=e05c9056d7f323088fd7824d8c6acc03a4a758c4b4916715924edc5dd3223a72) +set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz) +set(png_HASH SHA256=e45ce5f68b1d80e2cb9a2b601605b374bdf51e1798ef1c2c2bd62131dfcf9eef) set(png_BUILD ${CMAKE_BINARY_DIR}/png/src/png) set(png_INSTALL ${CMAKE_BINARY_DIR}/png/install) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(png_STATIC_LIBRARIES - debug ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib - optimized ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + debug ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_staticd.lib + optimized ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_static.lib) else() if(CMAKE_BUILD_TYPE EQUAL Debug) set(png_STATIC_LIBRARIES - ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib) + ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_staticd.lib) else() set(png_STATIC_LIBRARIES - ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + ${CMAKE_BINARY_DIR}/png/install/lib/libpng16_static.lib) endif() endif() else() - set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng12.a) + set(png_STATIC_LIBRARIES 
${CMAKE_BINARY_DIR}/png/install/lib/libpng16.a) endif() set(png_HEADERS - "${png_INSTALL}/include/libpng12/png.h" - "${png_INSTALL}/include/libpng12/pngconf.h" + "${png_INSTALL}/include/libpng16/png.h" + "${png_INSTALL}/include/libpng16/pngconf.h" + "${png_INSTALL}/include/libpng16/pnglibconf.h" ) ExternalProject_Add(png diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake b/tensorflow/contrib/cmake/external/sqlite.cmake index 57c4ae76517..7f835d2d519 100644 --- a/tensorflow/contrib/cmake/external/sqlite.cmake +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -15,8 +15,8 @@ include (ExternalProject) set(sqlite_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/sqlite) -set(sqlite_URL https://mirror.bazel.build/www.sqlite.org/2017/sqlite-amalgamation-3200000.zip) -set(sqlite_HASH SHA256=208780b3616f9de0aeb50822b7a8f5482f6515193859e91ed61637be6ad74fd4) +set(sqlite_URL https://mirror.bazel.build/www.sqlite.org/2018/sqlite-amalgamation-3230100.zip) +set(sqlite_HASH SHA256=4239a1f69e5721d07d9a374eb84d594225229e54be4ee628da2995f4315d8dfc) set(sqlite_BUILD ${CMAKE_CURRENT_BINARY_DIR}/sqlite/src/sqlite) set(sqlite_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/sqlite/install) diff --git a/tensorflow/contrib/cmake/external/zlib.cmake b/tensorflow/contrib/cmake/external/zlib.cmake index 116d4230939..8942f3eecf0 100644 --- a/tensorflow/contrib/cmake/external/zlib.cmake +++ b/tensorflow/contrib/cmake/external/zlib.cmake @@ -31,7 +31,8 @@ else (systemlib_ZLIB) set(ZLIB_URL https://github.com/madler/zlib) set(ZLIB_BUILD ${CMAKE_CURRENT_BINARY_DIR}/zlib/src/zlib) set(ZLIB_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/zlib/install) - set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d) + # Match zlib version in tensorflow/workspace.bzl + set(ZLIB_TAG v1.2.11) if(WIN32) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 91839194c7c..fece56c4127 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -32,52 +32,13 @@ tensorflow/python/feature_column tensorflow/python/framework tensorflow/python/grappler tensorflow/python/keras -tensorflow/python/keras/activations tensorflow/python/keras/applications -tensorflow/python/keras/applications/densenet -tensorflow/python/keras/applications/inception_resnet_v2 -tensorflow/python/keras/applications/inception_v3 -tensorflow/python/keras/applications/mobilenet -tensorflow/python/keras/applications/nasnet -tensorflow/python/keras/applications/resnet50 -tensorflow/python/keras/applications/vgg16 -tensorflow/python/keras/applications/vgg19 -tensorflow/python/keras/applications/xception -tensorflow/python/keras/backend -tensorflow/python/keras/callbacks -tensorflow/python/keras/constraints tensorflow/python/keras/datasets -tensorflow/python/keras/datasets/boston_housing -tensorflow/python/keras/datasets/cifar10 -tensorflow/python/keras/datasets/cifar100 -tensorflow/python/keras/datasets/fashion_mnist -tensorflow/python/keras/datasets/imdb -tensorflow/python/keras/datasets/mnist -tensorflow/python/keras/datasets/reuters -tensorflow/python/keras/estimator -tensorflow/python/keras/initializers +tensorflow/python/keras/engine tensorflow/python/keras/layers -tensorflow/python/keras/losses -tensorflow/python/keras/metrics -tensorflow/python/keras/models -tensorflow/python/keras/optimizers tensorflow/python/keras/preprocessing -tensorflow/python/keras/preprocessing/image -tensorflow/python/keras/preprocessing/sequence 
-tensorflow/python/keras/preprocessing/text -tensorflow/python/keras/regularizers tensorflow/python/keras/utils tensorflow/python/keras/wrappers -tensorflow/python/keras/wrappers/scikit_learn -tensorflow/python/keras/_impl -tensorflow/python/keras/_impl/keras -tensorflow/python/keras/_impl/keras/applications -tensorflow/python/keras/_impl/keras/datasets -tensorflow/python/keras/_impl/keras/engine -tensorflow/python/keras/_impl/keras/layers -tensorflow/python/keras/_impl/keras/preprocessing -tensorflow/python/keras/_impl/keras/utils -tensorflow/python/keras/_impl/keras/wrappers tensorflow/python/kernel_tests tensorflow/python/kernel_tests/boosted_trees tensorflow/python/kernel_tests/distributions @@ -100,6 +61,7 @@ tensorflow/python/summary tensorflow/python/summary/writer tensorflow/python/tools tensorflow/python/training +tensorflow/python/training/checkpointable tensorflow/python/user_ops tensorflow/python/util tensorflow/python/util/protobuf @@ -129,7 +91,13 @@ tensorflow/contrib/boosted_trees/kernels tensorflow/contrib/boosted_trees/ops tensorflow/contrib/boosted_trees/proto tensorflow/contrib/boosted_trees/python +tensorflow/contrib/boosted_trees/python/kernel_tests tensorflow/contrib/boosted_trees/python/ops +tensorflow/contrib/boosted_trees/python/training +tensorflow/contrib/boosted_trees/python/training/functions +tensorflow/contrib/boosted_trees/python/utils +tensorflow/contrib/checkpoint +tensorflow/contrib/checkpoint/python tensorflow/contrib/cloud tensorflow/contrib/cloud/kernels tensorflow/contrib/cloud/ops @@ -142,8 +110,11 @@ tensorflow/contrib/coder tensorflow/contrib/coder/kernels tensorflow/contrib/coder/ops tensorflow/contrib/coder/python +tensorflow/contrib/coder/python/layers tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler +tensorflow/contrib/constrained_optimization +tensorflow/contrib/constrained_optimization/python tensorflow/contrib/copy_graph tensorflow/contrib/copy_graph/python tensorflow/contrib/copy_graph/python/util @@ -324,6 +295,8 @@ tensorflow/contrib/metrics tensorflow/contrib/metrics/python tensorflow/contrib/metrics/python/metrics tensorflow/contrib/metrics/python/ops +tensorflow/contrib/mixed_precision +tensorflow/contrib/mixed_precision/python tensorflow/contrib/mpi_collectives/python tensorflow/contrib/mpi_collectives/python/ops tensorflow/contrib/model_pruning diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt index d63c41db844..cf1ee2ad76f 100644 --- a/tensorflow/contrib/cmake/python_protos.txt +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -11,7 +11,6 @@ tensorflow/contrib/mpi tensorflow/contrib/mpi_collectives tensorflow/contrib/session_bundle tensorflow/contrib/tensor_forest/proto -tensorflow/contrib/tensorboard/graph_explorer/proto tensorflow/contrib/tensorboard/plugins/projector tensorflow/contrib/tensorboard/plugins/trace tensorflow/contrib/tpu/proto diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index c6a15f2ca07..310fe58e055 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -22,8 +22,6 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" - "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc" 
"${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h" "${tensorflow_source_dir}/tensorflow/c/tf_status_helper.cc" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index f7cb186c7ca..b47c32f1c48 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -276,7 +276,7 @@ add_custom_command(OUTPUT __force_rebuild COMMAND ${CMAKE_COMMAND} -E echo) add_custom_command(OUTPUT ${VERSION_INFO_CC} COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py - ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} + ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE} DEPENDS __force_rebuild) set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc) diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index ed018b4fed8..2d76bf530a2 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -63,10 +63,13 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/training_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/csv_dataset_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc" @@ -176,6 +179,16 @@ if(WIN32) "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc" ) list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_windows_exclude_srcs}) +else(WIN32) + if(tensorflow_ENABLE_GPU) + file(GLOB_RECURSE tf_core_kernels_gpu_exclude_srcs + # temporarily disable nccl as it needs to be ported with gpu + "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc" + ) + list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_gpu_exclude_srcs}) + endif(tensorflow_ENABLE_GPU) endif(WIN32) file(GLOB_RECURSE tf_core_gpu_kernels_srcs diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 1c3206f1a26..8d24a7ae38f 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -244,13 +244,11 @@ add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD # tf_python_op_gen_main library ######################################################## set(tf_python_op_gen_main_srcs - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h" - 
"${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" - "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" - "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_main.cc" ) add_library(tf_python_op_gen_main OBJECT ${tf_python_op_gen_main_srcs}) @@ -330,8 +328,10 @@ GENERATE_PYTHON_OP_LIB("ctc_ops") GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops") GENERATE_PYTHON_OP_LIB("data_flow_ops") GENERATE_PYTHON_OP_LIB("dataset_ops") -GENERATE_PYTHON_OP_LIB("decode_proto_ops") -GENERATE_PYTHON_OP_LIB("encode_proto_ops") +GENERATE_PYTHON_OP_LIB("decode_proto_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_decode_proto_op.py) +GENERATE_PYTHON_OP_LIB("encode_proto_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/proto/python/ops/gen_encode_proto_op.py) GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") @@ -345,7 +345,8 @@ GENERATE_PYTHON_OP_LIB("random_ops") GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py) GENERATE_PYTHON_OP_LIB("resource_variable_ops") -GENERATE_PYTHON_OP_LIB("rpc_ops") +GENERATE_PYTHON_OP_LIB("rpc_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rpc/python/ops/gen_rpc_op.py) GENERATE_PYTHON_OP_LIB("script_ops") GENERATE_PYTHON_OP_LIB("sdca_ops") GENERATE_PYTHON_OP_LIB("set_ops") @@ -461,12 +462,12 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.cc" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.h" - "${tensorflow_source_dir}/tensorflow/python/eager/python_eager_op_gen.cc" "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.h" "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.h" + "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen_internal.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.h" @@ -551,12 +552,13 @@ if(WIN32) set(pywrap_tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow.def") endif() set_source_files_properties(${pywrap_tensorflow_deffile} PROPERTIES GENERATED TRUE) - + math(EXPR tensorflow_target_bitness "${CMAKE_SIZEOF_VOID_P}*8") add_custom_command(TARGET pywrap_tensorflow_internal_static POST_BUILD COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tools/create_def_file.py --input "${pywrap_tensorflow_internal_static_dependencies}" --output "${pywrap_tensorflow_deffile}" --target 
_pywrap_tensorflow_internal.pyd + --bitness "${tensorflow_target_bitness}" BYPRODUCTS ${pywrap_tensorflow_deffile} # Required for Ninja ) endif(WIN32) diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 9738bbeb9ae..38f40452b53 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -52,12 +52,13 @@ if(WIN32) set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/tensorflow.def") endif() set_source_files_properties(${tensorflow_deffile} PROPERTIES GENERATED TRUE) - + math(EXPR tensorflow_target_bitness "${CMAKE_SIZEOF_VOID_P}*8") add_custom_command(TARGET tensorflow_static POST_BUILD COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tools/create_def_file.py --input "${tensorflow_static_dependencies}" --output "${tensorflow_deffile}" --target tensorflow.dll + --bitness "${tensorflow_target_bitness}" ) endif(WIN32) diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake index 91ca33f4c4d..9a37b681194 100644 --- a/tensorflow/contrib/cmake/tf_stream_executor.cmake +++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake @@ -64,7 +64,15 @@ file(GLOB tf_stream_executor_srcs if (tensorflow_ENABLE_GPU) file(GLOB tf_stream_executor_gpu_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*.cc" + "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h" + "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc" ) + if (NOT tensorflow_BUILD_CC_TESTS) + file(GLOB tf_stream_executor_gpu_tests + "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*_test.cc" + ) + list(REMOVE_ITEM tf_stream_executor_gpu_srcs ${tf_stream_executor_gpu_tests}) + endif() list(APPEND tf_stream_executor_srcs ${tf_stream_executor_gpu_srcs}) endif() diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 92f2ab6dea8..5942ff3363a 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -267,6 +267,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py" + # Flaky on Windows cpu with py36 (b/73556968) + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/sparse_reshape_op_test.py" # Windows file management related issues. "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" # training tests diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index 53c2285699a..cffe069aa35 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -63,7 +63,7 @@ INCLUDE_RE = re.compile(r"^(TF_\w*)$|" r"^(TFE_\w*)$|" r"tensorflow::|" r"functor::|" - r"nsync_|" + r"\?nsync_|" r"perftools::gputools") # We want to identify data members explicitly in the DEF file, so that no one @@ -87,6 +87,7 @@ def get_args(): required=True) parser.add_argument("--output", help="output deffile", required=True) parser.add_argument("--target", help="name of the target", required=True) + parser.add_argument("--bitness", help="build target bitness", required=True) args = parser.parse_args() return args @@ -125,7 +126,10 @@ def main(): # Header for the def file. 
def_fp.write("LIBRARY " + args.target + "\n") def_fp.write("EXPORTS\n") - def_fp.write("\t ??1OpDef@tensorflow@@UEAA@XZ\n") + if args.bitness == "64": + def_fp.write("\t??1OpDef@tensorflow@@UEAA@XZ\n") + else: + def_fp.write("\t??1OpDef@tensorflow@@UAE@XZ\n") # Each symbols returned by undname matches the same position in candidates. # We compare on undname but use the decorated name from candidates. diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD index 9ca4ce8a9c7..a2c6e413039 100644 --- a/tensorflow/contrib/coder/BUILD +++ b/tensorflow/contrib/coder/BUILD @@ -1,5 +1,5 @@ # Description: -# Contains entropy coding related modules. +# Contains tools related to data compression. package(default_visibility = [ "//learning/brain:__subpackages__", @@ -54,19 +54,27 @@ tf_gen_op_libs( ], ) +cc_library( + name = "range_coder_ops_util", + srcs = ["kernels/range_coder_ops_util.cc"], + hdrs = ["kernels/range_coder_ops_util.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + tf_kernel_library( name = "range_coder_ops", srcs = [ "kernels/range_coder_ops.cc", - "kernels/range_coder_ops_util.cc", - ], - hdrs = [ - "kernels/range_coder_ops_util.h", ], visibility = ["//visibility:public"], deps = [ ":coder_ops_op_lib", ":range_coder", + ":range_coder_ops_util", "//tensorflow/core:framework", "//tensorflow/core:lib", ], @@ -152,10 +160,21 @@ tf_gen_op_wrapper_py( deps = [":coder_ops_op_lib"], ) +py_library( + name = "coder_py", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":coder_ops_py", + ":entropybottleneck_py", + ], +) + tf_custom_op_py_library( name = "coder_ops_py", srcs = [ - "__init__.py", "python/ops/coder_ops.py", ], dso = [ @@ -186,3 +205,44 @@ tf_py_test( ], main = "python/ops/coder_ops_test.py", ) + +py_library( + name = "entropybottleneck_py", + srcs = [ + "python/layers/entropybottleneck.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":coder_ops_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:functional_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/keras:engine", + "//third_party/py/numpy", + ], +) + +tf_py_test( + name = "entropybottleneck_py_test", + srcs = [ + "python/layers/entropybottleneck_test.py", + ], + additional_deps = [ + ":entropybottleneck_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:variables", + "//tensorflow/python:training", + ], + main = "python/layers/entropybottleneck_test.py", +) diff --git a/tensorflow/contrib/coder/__init__.py b/tensorflow/contrib/coder/__init__.py index b7e663e6f13..99b8ac7595e 100644 --- a/tensorflow/contrib/coder/__init__.py +++ b/tensorflow/contrib/coder/__init__.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Entropy code operations.""" +"""Data compression tools.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import +from tensorflow.contrib.coder.python.layers.entropybottleneck import * from tensorflow.contrib.coder.python.ops.coder_ops import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc index c787e8edede..bd5272ee6f2 100644 --- a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc +++ b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc @@ -16,6 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include +#include #include #include #include @@ -79,8 +80,8 @@ class PmfToCdfOp : public OpKernel { } private: - struct Item { - Item(int32* p, double mass) : pointer(p), mass(mass) { + struct PenaltyItem { + PenaltyItem(int32* p, double mass) : pointer(p), mass(mass) { penalty = ComputeNextPenalty(); } @@ -90,7 +91,7 @@ class PmfToCdfOp : public OpKernel { penalty = ComputeNextPenalty(); } - friend bool operator<(const Item& lhs, const Item& rhs) { + friend bool operator<(const PenaltyItem& lhs, const PenaltyItem& rhs) { return lhs.penalty < rhs.penalty; } @@ -106,6 +107,34 @@ class PmfToCdfOp : public OpKernel { double penalty; }; + struct GainItem { + GainItem(int32* p, double mass) : pointer(p), mass(mass) { + gain = ComputeNextGain(); + } + + void Increase() { + CHECK_GT(*pointer, 0); + ++*pointer; + gain = ComputeNextGain(); + } + + friend bool operator>(const GainItem& lhs, const GainItem& rhs) { + return lhs.gain > rhs.gain; + } + + double ComputeNextGain() { + // Never increment zero value to non-zero value. + if (*pointer < 1) { + return -std::numeric_limits::infinity(); + } + return mass * (std::log2(*pointer + 1) - std::log2(*pointer)); + } + + int32* pointer; + double mass; + double gain; + }; + void PerShard(gtl::ArraySlice pmf, gtl::MutableArraySlice cdf) const { CHECK_EQ(pmf.size(), cdf.size()); @@ -121,7 +150,7 @@ class PmfToCdfOp : public OpKernel { int32 sum = std::accumulate(cdf.begin(), cdf.end(), 0); if (sum > normalizer) { - std::vector queue; + std::vector queue; queue.reserve(cdf.size()); for (int i = 0; i < cdf.size(); ++i) { queue.emplace_back(&cdf[i], pmf[i]); @@ -132,9 +161,26 @@ class PmfToCdfOp : public OpKernel { queue[0].Decrease(); // Performs a linear search because this find_if is likely to return // iterator very close to the begin. - auto iter = - std::find_if(std::next(queue.begin()), queue.end(), - [&queue](const Item& rhs) { return queue[0] < rhs; }); + auto iter = std::find_if( + std::next(queue.begin()), queue.end(), + [&queue](const PenaltyItem& rhs) { return queue[0] < rhs; }); + std::rotate(queue.begin(), std::next(queue.begin()), iter); + } + } else if (sum < normalizer) { + std::vector queue; + queue.reserve(cdf.size()); + for (int i = 0; i < cdf.size(); ++i) { + queue.emplace_back(&cdf[i], pmf[i]); + } + + std::sort(queue.begin(), queue.end(), std::greater()); + while (sum++ < normalizer) { + queue[0].Increase(); + // Performs a linear search because this find_if is likely to return + // iterator very close to the begin. 
+ auto iter = std::find_if( + std::next(queue.begin()), queue.end(), + [&queue](const GainItem& rhs) { return queue[0] > rhs; }); std::rotate(queue.begin(), std::next(queue.begin()), iter); } } diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc index c70e38faab7..3408f6b519a 100644 --- a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc +++ b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc @@ -82,7 +82,7 @@ class PmfToQuantizedCdfOpTest : public OpsTestBase { EXPECT_GT(diff, 0); } - EXPECT_LE(cdf_slice(cdf_slice.size() - 1), normalizer); + EXPECT_EQ(cdf_slice(cdf_slice.size() - 1), normalizer); } } }; @@ -98,6 +98,8 @@ TEST_F(PmfToQuantizedCdfOpTest, UnderSum) { GenerateData(&rand, {&matrix(i, 0), n}); } + pmf.flat() = pmf.flat() * 0.85f; + constexpr int kPrecision = 10; SetupOp(kPrecision, &pmf); TF_ASSERT_OK(RunOpKernel()); @@ -115,7 +117,7 @@ TEST_F(PmfToQuantizedCdfOpTest, OverSum) { matrix.setZero(); const std::size_t n = matrix.dimension(1) / 2; - random::PhiloxRandom gen; + random::PhiloxRandom gen(random::New64(), random::New64()); random::SimplePhilox rand(&gen); for (int64 i = 0; i < matrix.dimension(0); ++i) { GenerateData(&rand, {&matrix(i, 0), n}); diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc index ae4d9d2836a..81b36ca902b 100644 --- a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" diff --git a/tensorflow/contrib/coder/ops/coder_ops.cc b/tensorflow/contrib/coder/ops/coder_ops.cc index 9bb171298f8..a185e07913f 100644 --- a/tensorflow/contrib/coder/ops/coder_ops.cc +++ b/tensorflow/contrib/coder/ops/coder_ops.cc @@ -77,7 +77,7 @@ are incorrect. For this reason, the range coder uses integer arithmetics and avoids using any floating point operations internally, and `cdf` should contain integers representing quantized probability mass rather than floating points. -data: An int32 tensor. +data: An int16 tensor. cdf: An int32 tensor representing the CDF's of `data`. Each integer is divided by `2^precision` to represent a fraction. encoded: A range-coded scalar string. @@ -112,7 +112,7 @@ potential performance issues, the decoder does not return error status. encoded: A scalar string tensor from RangeEncode. shape: An int32 1-D tensor representing the shape of the data encoded by RangeEncode. -decoded: An int32 tensor with shape equal to `shape`. +decoded: An int16 tensor with shape equal to `shape`. precision: The number of bits for probability quantization. Must be <= 16, and must match the precision used by RangeEncode that produced `encoded`. )doc"); @@ -138,14 +138,12 @@ platforms. For entropy encoders and decoders to have the same quantized CDF on different platforms, the quantized CDF should be produced once and saved, then the saved quantized CDF should be used everywhere. -After quantization, if PMF sums to less than or equal to 2^precision, then this -is equivalent to cumsum over the last dimension. 
This op makes no effort to make -the sum close to 2^precision when the sum is already <= 2^precision. +After quantization, if PMF does not sum to 2^precision, then some values of PMF +are increased or decreased to adjust the sum to equal to 2^precision. -After quantization, if PMF sums to greater than 2^precision, then some values of -PMF is decreased to keep the sum no more than 2^precision. - -Note that the input PMF is pre-quantization. +Note that the input PMF is pre-quantization. The input PMF is not normalized +by this op prior to quantization. Therefore the user is responsible for +normalizing PMF if necessary. )doc"); // clang-format on } // namespace tensorflow diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py new file mode 100644 index 00000000000..0fbe3081af0 --- /dev/null +++ b/tensorflow/contrib/coder/python/layers/entropybottleneck.py @@ -0,0 +1,697 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Entropy bottleneck layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.coder.python.ops import coder_ops + +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import engine +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary + + +class EntropyBottleneck(engine.Layer): + """Entropy bottleneck layer. + + This layer can be used to model the entropy (the amount of information + conveyed) of the tensor passing through it. During training, this can be used + to impose a (soft) entropy constraint on its activations, limiting the amount + of information flowing through the layer. Note that this is distinct from + other types of bottlenecks, which reduce the dimensionality of the space, for + example. Dimensionality reduction does not limit the amount of information, + and does not enable efficient data compression per se. + + After training, this layer can be used to compress any input tensor to a + string, which may be written to a file, and to decompress a file which it + previously generated back to a reconstructed tensor (possibly on a different + machine having access to the same model checkpoint). 
The entropies estimated + during training or evaluation are approximately equal to the average length of + the strings in bits. + + The layer implements a flexible probability density model to estimate entropy, + which is described in the appendix of the paper (please cite the paper if you + use this code for scientific work): + + "Variational image compression with a scale hyperprior" + + Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston + + https://arxiv.org/abs/1802.01436 + + The layer assumes that the input tensor is at least 2D, with a batch dimension + at the beginning and a channel dimension as specified by `data_format`. The + layer trains an independent probability density model for each channel, but + assumes that across all other dimensions, the inputs are i.i.d. (independent + and identically distributed). Because the entropy (and hence, average + codelength) is a function of the densities, this assumption may have a direct + effect on the compression performance. + + Because data compression always involves discretization, the outputs of the + layer are generally only approximations of its inputs. During training, + discretization is modeled using additive uniform noise to ensure + differentiability. The entropies computed during training are differential + entropies. During evaluation, the data is actually quantized, and the + entropies are discrete (Shannon entropies). To make sure the approximated + tensor values are good enough for practical purposes, the training phase must + be used to balance the quality of the approximation with the entropy, by + adding an entropy term to the training loss, as in the following example. + + Here, we use the entropy bottleneck to compress the latent representation of + an autoencoder. The data vectors `x` in this case are 4D tensors in + `'channels_last'` format (for example, 16x16 pixel grayscale images). + + The layer always produces exactly one auxiliary loss and one update op which + are only significant for compression and decompression. To use the compression + feature, the auxiliary loss must be minimized during or after training. After + that, the update op must be executed at least once. Here, we simply attach + them to the main training step. + + Training: + ``` + # Build autoencoder. + x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) + y = forward_transform(x) + entropy_bottleneck = EntropyBottleneck() + y_, likelihoods = entropy_bottleneck(y, training=True) + x_ = backward_transform(y_) + + # Information content (= predicted codelength) in bits of each batch element + # (note that taking the natural logarithm and dividing by `log(2)` is + # equivalent to taking base-2 logarithms): + bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) + + # Squared difference of each batch element: + squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) + + # The loss is a weighted sum of mean squared error and entropy (average + # information content), where the weight controls the trade-off between + # approximation error and entropy. + main_loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) + + # Minimize loss and auxiliary loss, and execute update op. + main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) + main_step = optimizer.minimize(main_loss) + # 1e-2 is a good starting point for the learning rate of the auxiliary loss, + # assuming Adam is used. 
+ aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-2) + aux_step = optimizer.minimize(entropy_bottleneck.losses[0]) + step = tf.group(main_step, aux_step, entropy_bottleneck.updates[0]) + ``` + + Evaluation: + ``` + # Build autoencoder. + x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) + y = forward_transform(x) + y_, likelihoods = EntropyBottleneck()(y, training=False) + x_ = backward_transform(y_) + + # Information content (= predicted codelength) in bits of each batch element: + bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2) + + # Squared difference of each batch element: + squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3)) + + # The loss is a weighted sum of mean squared error and entropy (average + # information content), where the weight controls the trade-off between + # approximation error and entropy. + loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits) + ``` + + To be able to compress the bottleneck tensor and decompress it in a different + session, or on a different machine, you need three items: + - The compressed representations stored as strings. + - The shape of the bottleneck for these string representations as a `Tensor`, + as well as the number of channels of the bottleneck at graph construction + time. + - The checkpoint of the trained model that was used for compression. Note: + It is crucial that the auxiliary loss produced by this layer is minimized + during or after training, and that the update op is run after training and + minimization of the auxiliary loss, but *before* the checkpoint is saved. + + Compression: + ``` + x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1]) + y = forward_transform(x) + strings = EntropyBottleneck().compress(y) + shape = tf.shape(y)[1:] + ``` + + Decompression: + ``` + strings = tf.placeholder(tf.string, shape=[None]) + shape = tf.placeholder(tf.int32, shape=[3]) + entropy_bottleneck = EntropyBottleneck(dtype=tf.float32) + y_ = entropy_bottleneck.decompress(strings, shape, channels=5) + x_ = backward_transform(y_) + ``` + Here, we assumed that the tensor produced by the forward transform has 5 + channels. + + The above four use cases can also be implemented within the same session (i.e. + on the same `EntropyBottleneck` instance), for testing purposes, etc., by + calling the object more than once. + + Arguments: + init_scale: Float. A scaling factor determining the initial width of the + probability densities. This should be chosen big enough so that the + range of values of the layer inputs roughly falls within the interval + [`-init_scale`, `init_scale`] at the beginning of training. + filters: An iterable of ints, giving the number of filters at each layer of + the density model. Generally, the more filters and layers, the more + expressive is the density model in terms of modeling more complicated + distributions of the layer inputs. For details, refer to the paper + referenced above. The default is `[3, 3, 3]`, which should be sufficient + for most practical purposes. + tail_mass: Float, between 0 and 1. The bottleneck layer automatically + determines the range of input values that should be represented based on + their frequency of occurrence. Values occurring in the tails of the + distributions will be clipped to that range during compression. + `tail_mass` determines the amount of probability mass in the tails which + is cut off in the worst case. 
For example, the default value of `1e-9` + means that at most 1 in a billion input samples will be clipped to the + range. + optimize_integer_offset: Boolean. Typically, the input values of this layer + are floats, which means that quantization during evaluation can be + performed with an arbitrary offset. By default, the layer determines that + offset automatically. In special situations, such as when it is known that + the layer will receive only full integer values during evaluation, it can + be desirable to set this argument to `False` instead, in order to always + quantize to full integer values. + likelihood_bound: Float. If positive, the returned likelihood values are + ensured to be greater than or equal to this value. This prevents very + large gradients with a typical entropy loss (defaults to 1e-9). + range_coder_precision: Integer, between 1 and 16. The precision of the range + coder used for compression and decompression. This trades off computation + speed with compression efficiency, where 16 is the slowest but most + efficient setting. Choosing lower values may increase the average + codelength slightly compared to the estimated entropies. + data_format: Either `'channels_first'` or `'channels_last'` (default). + trainable: Boolean. Whether the layer should be trained. + name: String. The name of the layer. + dtype: Default dtype of the layer's parameters (default of `None` means use + the type of the first input). + + Read-only properties: + init_scale: See above. + filters: See above. + tail_mass: See above. + optimize_integer_offset: See above. + likelihood_bound: See above. + range_coder_precision: See above. + data_format: See above. + name: String. See above. + dtype: See above. + trainable_variables: List of trainable variables. + non_trainable_variables: List of non-trainable variables. + variables: List of all variables of this layer, trainable and non-trainable. + updates: List of update ops of this layer. Always contains exactly one + update op, which must be run once after the last training step, before + `compress` or `decompress` is used. + losses: List of losses added by this layer. Always contains exactly one + auxiliary loss, which must be added to the training loss. + + Mutable properties: + trainable: Boolean. Whether the layer should be trained. + input_spec: Optional `InputSpec` object specifying the constraints on inputs + that can be accepted by the layer. 
+ """ + + def __init__(self, init_scale=10, filters=(3, 3, 3), tail_mass=1e-9, + optimize_integer_offset=True, likelihood_bound=1e-9, + range_coder_precision=16, data_format="channels_last", **kwargs): + super(EntropyBottleneck, self).__init__(**kwargs) + self._init_scale = float(init_scale) + self._filters = tuple(int(f) for f in filters) + self._tail_mass = float(tail_mass) + if not 0 < self.tail_mass < 1: + raise ValueError( + "`tail_mass` must be between 0 and 1, got {}.".format(self.tail_mass)) + self._optimize_integer_offset = bool(optimize_integer_offset) + self._likelihood_bound = float(likelihood_bound) + self._range_coder_precision = int(range_coder_precision) + self._data_format = data_format + self._channel_axis(2) # trigger ValueError early + self.input_spec = engine.InputSpec(min_ndim=2) + + @property + def init_scale(self): + return self._init_scale + + @property + def filters(self): + return self._filters + + @property + def tail_mass(self): + return self._tail_mass + + @property + def optimize_integer_offset(self): + return self._optimize_integer_offset + + @property + def likelihood_bound(self): + return self._likelihood_bound + + @property + def range_coder_precision(self): + return self._range_coder_precision + + @property + def data_format(self): + return self._data_format + + def _channel_axis(self, ndim): + try: + return {"channels_first": 1, "channels_last": ndim - 1}[self.data_format] + except KeyError: + raise ValueError("Unsupported `data_format` for {} layer: {}.".format( + self.__class__.__name__, self.data_format)) + + def _logits_cumulative(self, inputs, stop_gradient): + """Evaluate logits of the cumulative densities. + + Args: + inputs: The values at which to evaluate the cumulative densities, expected + to be a `Tensor` of shape `(channels, 1, batch)`. + stop_gradient: Boolean. Whether to add `array_ops.stop_gradient` calls so + that the gradient of the output with respect to the density model + parameters is disconnected (the gradient with respect to `inputs` is + left untouched). + + Returns: + A `Tensor` of the same shape as `inputs`, containing the logits of the + cumulative densities evaluated at the given inputs. + """ + logits = inputs + + for i in range(len(self.filters) + 1): + matrix = self._matrices[i] + if stop_gradient: + matrix = array_ops.stop_gradient(matrix) + logits = math_ops.matmul(matrix, logits) + + bias = self._biases[i] + if stop_gradient: + bias = array_ops.stop_gradient(bias) + logits += bias + + if i < len(self._factors): + factor = self._factors[i] + if stop_gradient: + factor = array_ops.stop_gradient(factor) + logits += factor * math_ops.tanh(logits) + + return logits + + def build(self, input_shape): + """Builds the layer. + + Creates the variables for the network modeling the densities, creates the + auxiliary loss estimating the median and tail quantiles of the densities, + and then uses that to create the probability mass functions and the update + op that produces the discrete cumulative density functions used by the range + coder. + + Args: + input_shape: Shape of the input tensor, used to get the number of + channels. + + Raises: + ValueError: if `input_shape` doesn't specify the length of the channel + dimension. 
+ """ + input_shape = tensor_shape.TensorShape(input_shape) + channel_axis = self._channel_axis(input_shape.ndims) + channels = input_shape[channel_axis].value + if channels is None: + raise ValueError("The channel dimension of the inputs must be defined.") + self.input_spec = engine.InputSpec( + ndim=input_shape.ndims, axes={channel_axis: channels}) + filters = (1,) + self.filters + (1,) + scale = self.init_scale ** (1 / (len(self.filters) + 1)) + + # Create variables. + self._matrices = [] + self._biases = [] + self._factors = [] + for i in range(len(self.filters) + 1): + init = np.log(np.expm1(1 / scale / filters[i + 1])) + matrix = self.add_variable( + "matrix_{}".format(i), dtype=self.dtype, + shape=(channels, filters[i + 1], filters[i]), + initializer=init_ops.Constant(init)) + matrix = nn.softplus(matrix) + self._matrices.append(matrix) + + bias = self.add_variable( + "bias_{}".format(i), dtype=self.dtype, + shape=(channels, filters[i + 1], 1), + initializer=init_ops.RandomUniform(-.5, .5)) + self._biases.append(bias) + + if i < len(self.filters): + factor = self.add_variable( + "factor_{}".format(i), dtype=self.dtype, + shape=(channels, filters[i + 1], 1), + initializer=init_ops.Zeros()) + factor = math_ops.tanh(factor) + self._factors.append(factor) + + # To figure out what range of the densities to sample, we need to compute + # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we + # can't take inverses of the cumulative directly, we make it an optimization + # problem: + # `quantiles = argmin(|logit(cumulative) - target|)` + # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`. + # Taking the logit (inverse of sigmoid) of the cumulative makes the + # representation of the right target more numerically stable. + + # Numerically stable way of computing logits of `tail_mass / 2` + # and `1 - tail_mass / 2`. + target = np.log(2 / self.tail_mass - 1) + # Compute lower and upper tail quantile as well as median. + target = constant_op.constant([-target, 0, target], dtype=self.dtype) + + def quantiles_initializer(shape, dtype=None, partition_info=None): + del partition_info # unused + assert tuple(shape[1:]) == (1, 3) + init = constant_op.constant( + [[[-self.init_scale, 0, self.init_scale]]], dtype=dtype) + return array_ops.tile(init, (shape[0], 1, 1)) + + quantiles = self.add_variable( + "quantiles", shape=(channels, 1, 3), dtype=self.dtype, + initializer=quantiles_initializer) + logits = self._logits_cumulative(quantiles, stop_gradient=True) + loss = math_ops.reduce_sum(abs(logits - target)) + self.add_loss(loss, inputs=None) + + # Save medians for `call`, `compress`, and `decompress`. + self._medians = quantiles[:, :, 1:2] + if not self.optimize_integer_offset: + self._medians = math_ops.round(self._medians) + + # Largest distance observed between lower tail quantile and median, + # or between median and upper tail quantile. + minima = math_ops.reduce_max(self._medians - quantiles[:, :, 0:1]) + maxima = math_ops.reduce_max(quantiles[:, :, 2:3] - self._medians) + minmax = math_ops.maximum(minima, maxima) + minmax = math_ops.ceil(minmax) + minmax = math_ops.maximum(minmax, 1) + + # Sample the density up to `minmax` around the median. + samples = math_ops.range(-minmax, minmax + 1, dtype=self.dtype) + samples += self._medians + + half = constant_op.constant(.5, dtype=self.dtype) + # We strip the sigmoid from the end here, so we can use the special rule + # below to only compute differences in the left tail of the sigmoid. 
+ # This increases numerical stability (see explanation in `call`). + lower = self._logits_cumulative(samples - half, stop_gradient=True) + upper = self._logits_cumulative(samples + half, stop_gradient=True) + # Flip signs if we can move more towards the left tail of the sigmoid. + sign = -math_ops.sign(math_ops.add_n([lower, upper])) + pmf = abs(math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) + # Add tail masses to first and last bin of pmf, as we clip values for + # compression, meaning that out-of-range values get mapped to these bins. + pmf = array_ops.concat([ + math_ops.add_n([pmf[:, 0, :1], math_ops.sigmoid(lower[:, 0, :1])]), + pmf[:, 0, 1:-1], + math_ops.add_n([pmf[:, 0, -1:], math_ops.sigmoid(-upper[:, 0, -1:])]), + ], axis=-1) + self._pmf = pmf + + cdf = coder_ops.pmf_to_quantized_cdf( + pmf, precision=self.range_coder_precision) + def cdf_getter(*args, **kwargs): + del args, kwargs # ignored + return variable_scope.get_variable( + "quantized_cdf", dtype=dtypes.int32, initializer=cdf, + trainable=False, validate_shape=False, collections=()) + # Need to provide a fake shape here since add_variable insists on it. + self._quantized_cdf = self.add_variable( + "quantized_cdf", shape=(channels, 1), dtype=dtypes.int32, + getter=cdf_getter, trainable=False) + + update_op = state_ops.assign( + self._quantized_cdf, cdf, validate_shape=False) + self.add_update(update_op, inputs=None) + + super(EntropyBottleneck, self).build(input_shape) + + def call(self, inputs, training): + """Pass a tensor through the bottleneck. + + Args: + inputs: The tensor to be passed through the bottleneck. + training: Boolean. If `True`, returns a differentiable approximation of + the inputs, and their likelihoods under the modeled probability + densities. If `False`, returns the quantized inputs and their + likelihoods under the corresponding probability mass function. These + quantities can't be used for training, as they are not differentiable, + but represent actual compression more closely. + + Returns: + values: `Tensor` with the same shape as `inputs` containing the perturbed + or quantized input values. + likelihood: `Tensor` with the same shape as `inputs` containing the + likelihood of `values` under the modeled probability distributions. + + Raises: + ValueError: if `inputs` has different `dtype` or number of channels than + a previous set of inputs the model was invoked with earlier. + """ + inputs = ops.convert_to_tensor(inputs) + ndim = self.input_spec.ndim + channel_axis = self._channel_axis(ndim) + half = constant_op.constant(.5, dtype=self.dtype) + + # Convert to (channels, 1, batch) format by commuting channels to front + # and then collapsing. + order = list(range(ndim)) + order.pop(channel_axis) + order.insert(0, channel_axis) + values = array_ops.transpose(inputs, order) + shape = array_ops.shape(values) + values = array_ops.reshape(values, (shape[0], 1, -1)) + + # Add noise or quantize. + if training: + noise = random_ops.random_uniform(array_ops.shape(values), -half, half) + values = math_ops.add_n([values, noise]) + elif self.optimize_integer_offset: + values = math_ops.round(values - self._medians) + self._medians + else: + values = math_ops.round(values) + + # Evaluate densities. + # We can use the special rule below to only compute differences in the left + # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1 + # for large x, 0 for small x. 
Subtracting two numbers close to 0 can be done + # with much higher precision than subtracting two numbers close to 1. + lower = self._logits_cumulative(values - half, stop_gradient=False) + upper = self._logits_cumulative(values + half, stop_gradient=False) + # Flip signs if we can move more towards the left tail of the sigmoid. + sign = -math_ops.sign(math_ops.add_n([lower, upper])) + sign = array_ops.stop_gradient(sign) + likelihood = abs( + math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower)) + if self.likelihood_bound > 0: + likelihood_bound = constant_op.constant( + self.likelihood_bound, dtype=self.dtype) + # TODO(jballe): Override gradients. + likelihood = math_ops.maximum(likelihood, likelihood_bound) + + # Convert back to input tensor shape. + order = list(range(1, ndim)) + order.insert(channel_axis, 0) + values = array_ops.reshape(values, shape) + values = array_ops.transpose(values, order) + likelihood = array_ops.reshape(likelihood, shape) + likelihood = array_ops.transpose(likelihood, order) + + if not context.executing_eagerly(): + values_shape, likelihood_shape = self.compute_output_shape(inputs.shape) + values.set_shape(values_shape) + likelihood.set_shape(likelihood_shape) + + return values, likelihood + + def compress(self, inputs): + """Compress inputs and store their binary representations into strings. + + Args: + inputs: `Tensor` with values to be compressed. + + Returns: + String `Tensor` vector containing the compressed representation of each + batch element of `inputs`. + """ + with ops.name_scope(self._name_scope()): + inputs = ops.convert_to_tensor(inputs) + if not self.built: + # Check input assumptions set before layer building, e.g. input rank. + self._assert_input_compatibility(inputs) + if self.dtype is None: + self._dtype = inputs.dtype.base_dtype.name + self.build(inputs.shape) + + # Check input assumptions set after layer building, e.g. input shape. + if not context.executing_eagerly(): + self._assert_input_compatibility(inputs) + + ndim = self.input_spec.ndim + channel_axis = self._channel_axis(ndim) + # Tuple of slices for expanding dimensions of tensors below. + slices = ndim * [None] + [slice(None)] + slices[channel_axis] = slice(None) + slices = tuple(slices) + + # Expand dimensions of CDF to input dimensions, keeping the channels along + # the right dimension. + cdf = self._quantized_cdf[slices[1:]] + num_levels = array_ops.shape(cdf)[-1] - 1 + + # Bring inputs to the right range by centering the range on the medians. + half = constant_op.constant(.5, dtype=self.dtype) + medians = array_ops.squeeze(self._medians, [1, 2]) + offsets = (math_ops.cast(num_levels // 2, self.dtype) + half) - medians + # Expand offsets to input dimensions and add to inputs. + values = inputs + offsets[slices[:-1]] + + # Clip to range and cast to integers. Because we have added .5 above, and + # all values are positive, the cast effectively implements rounding. + values = math_ops.maximum(values, half) + values = math_ops.minimum( + values, math_ops.cast(num_levels, self.dtype) - half) + values = math_ops.cast(values, dtypes.int16) + + def loop_body(tensor): + return coder_ops.range_encode( + tensor, cdf, precision=self.range_coder_precision) + strings = functional_ops.map_fn( + loop_body, values, dtype=dtypes.string, back_prop=False) + + if not context.executing_eagerly(): + strings.set_shape(inputs.shape[:1]) + + return strings + + def decompress(self, strings, shape, channels=None): + """Decompress values from their compressed string representations. 
+ + Args: + strings: A string `Tensor` vector containing the compressed data. + shape: A `Tensor` vector of int32 type. Contains the shape of the tensor + to be decompressed, excluding the batch dimension. + channels: Integer. Specifies the number of channels statically. Needs only + be set if the layer hasn't been built yet (i.e., this is the first input + it receives). + + Returns: + The decompressed `Tensor`. Its shape will be equal to `shape` prepended + with the batch dimension from `strings`. + + Raises: + ValueError: If the length of `shape` isn't available at graph construction + time. + """ + with ops.name_scope(self._name_scope()): + strings = ops.convert_to_tensor(strings) + shape = ops.convert_to_tensor(shape) + if self.built: + ndim = self.input_spec.ndim + channel_axis = self._channel_axis(ndim) + if channels is None: + channels = self.input_spec.axes[channel_axis] + else: + if not (shape.shape.is_fully_defined() and shape.shape.ndims == 1): + raise ValueError("`shape` must be a vector with known length.") + ndim = shape.shape[0].value + 1 + channel_axis = self._channel_axis(ndim) + input_shape = ndim * [None] + input_shape[channel_axis] = channels + self.build(input_shape) + + # Tuple of slices for expanding dimensions of tensors below. + slices = ndim * [None] + [slice(None)] + slices[channel_axis] = slice(None) + slices = tuple(slices) + + # Expand dimensions of CDF to input dimensions, keeping the channels along + # the right dimension. + cdf = self._quantized_cdf[slices[1:]] + num_levels = array_ops.shape(cdf)[-1] - 1 + + def loop_body(string): + return coder_ops.range_decode( + string, shape, cdf, precision=self.range_coder_precision) + outputs = functional_ops.map_fn( + loop_body, strings, dtype=dtypes.int16, back_prop=False) + outputs = math_ops.cast(outputs, self.dtype) + + medians = array_ops.squeeze(self._medians, [1, 2]) + offsets = math_ops.cast(num_levels // 2, self.dtype) - medians + outputs -= offsets[slices[:-1]] + + if not context.executing_eagerly(): + outputs_shape = ndim * [None] + outputs_shape[0] = strings.shape[0] + outputs_shape[channel_axis] = channels + outputs.set_shape(outputs_shape) + + return outputs + + def visualize(self): + """Multi-channel visualization of densities as images. + + Creates and returns an image summary visualizing the current probabilty + density estimates. The image contains one row for each channel. Within each + row, the pixel intensities are proportional to probability values, and each + row is centered on the median of the corresponding distribution. + + Returns: + The created image summary. + """ + with ops.name_scope(self._name_scope()): + image = self._pmf + image *= 255 / math_ops.reduce_max(image, axis=1, keepdims=True) + image = math_ops.cast(image + .5, dtypes.uint8) + image = image[None, :, :, None] + return summary.image("pmf", image, max_outputs=1) + + def compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + return input_shape, input_shape diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py b/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py new file mode 100644 index 00000000000..798b0234ebc --- /dev/null +++ b/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests of EntropyBottleneck class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.coder.python.layers import entropybottleneck + +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent + + +class EntropyBottleneckTest(test.TestCase): + + def test_noise(self): + # Tests that the noise added is uniform noise between -0.5 and 0.5. + inputs = array_ops.placeholder(dtypes.float32, (None, 1)) + layer = entropybottleneck.EntropyBottleneck() + noisy, _ = layer(inputs, training=True) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + values = np.linspace(-50, 50, 100)[:, None] + noisy, = sess.run([noisy], {inputs: values}) + self.assertFalse(np.allclose(values, noisy, rtol=0, atol=.49)) + self.assertAllClose(values, noisy, rtol=0, atol=.5) + + def test_quantization(self): + # Tests that inputs are quantized to full integer values, even after + # quantiles have been updated. + inputs = array_ops.placeholder(dtypes.float32, (None, 1)) + layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=False) + quantized, _ = layer(inputs, training=False) + opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) + self.assertTrue(len(layer.losses) == 1) + step = opt.minimize(layer.losses[0]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run(step) + values = np.linspace(-50, 50, 100)[:, None] + quantized, = sess.run([quantized], {inputs: values}) + self.assertAllClose(np.around(values), quantized, rtol=0, atol=1e-6) + + def test_quantization_optimized_offset(self): + # Tests that inputs are not quantized to full integer values after quantiles + # have been updated. However, the difference between input and output should + # be between -0.5 and 0.5, and the offset must be consistent. 
+ inputs = array_ops.placeholder(dtypes.float32, (None, 1)) + layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=True) + quantized, _ = layer(inputs, training=False) + opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) + self.assertTrue(len(layer.losses) == 1) + step = opt.minimize(layer.losses[0]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run(step) + values = np.linspace(-50, 50, 100)[:, None] + quantized, = sess.run([quantized], {inputs: values}) + self.assertAllClose(values, quantized, rtol=0, atol=.5) + diff = np.ravel(np.around(values) - quantized) % 1 + self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) + self.assertNotEqual(diff[0], 0) + + def test_codec(self): + # Tests that inputs are compressed and decompressed correctly, and quantized + # to full integer values, even after quantiles have been updated. + inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_last", init_scale=60, + optimize_integer_offset=False) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) + self.assertTrue(len(layer.losses) == 1) + step = opt.minimize(layer.losses[0]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run(step) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = np.linspace(-50, 50, 100)[None, :, None] + decoded, = sess.run([decoded], {inputs: values}) + self.assertAllClose(np.around(values), decoded, rtol=0, atol=1e-6) + + def test_codec_optimized_offset(self): + # Tests that inputs are compressed and decompressed correctly, and not + # quantized to full integer values after quantiles have been updated. + # However, the difference between input and output should be between -0.5 + # and 0.5, and the offset must be consistent. + inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_last", init_scale=60, + optimize_integer_offset=True) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + opt = gradient_descent.GradientDescentOptimizer(learning_rate=1) + self.assertTrue(len(layer.losses) == 1) + step = opt.minimize(layer.losses[0]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run(step) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = np.linspace(-50, 50, 100)[None, :, None] + decoded, = sess.run([decoded], {inputs: values}) + self.assertAllClose(values, decoded, rtol=0, atol=.5) + diff = np.ravel(np.around(values) - decoded) % 1 + self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) + self.assertNotEqual(diff[0], 0) + + def test_codec_clipping(self): + # Tests that inputs are compressed and decompressed correctly, and clipped + # to the expected range. 
+ inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_last", init_scale=40) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = np.linspace(-50, 50, 100)[None, :, None] + decoded, = sess.run([decoded], {inputs: values}) + expected = np.clip(np.around(values), -40, 40) + self.assertAllClose(expected, decoded, rtol=0, atol=1e-6) + + def test_channels_last(self): + # Test the layer with more than one channel and multiple input dimensions, + # with the channels in the last dimension. + inputs = array_ops.placeholder(dtypes.float32, (None, None, None, 2)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_last", init_scale=50) + noisy, _ = layer(inputs, training=True) + quantized, _ = layer(inputs, training=False) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = 5 * np.random.normal(size=(7, 5, 3, 2)) + noisy, quantized, decoded = sess.run( + [noisy, quantized, decoded], {inputs: values}) + self.assertAllClose(values, noisy, rtol=0, atol=.5) + self.assertAllClose(values, quantized, rtol=0, atol=.5) + self.assertAllClose(values, decoded, rtol=0, atol=.5) + + def test_channels_first(self): + # Test the layer with more than one channel and multiple input dimensions, + # with the channel dimension right after the batch dimension. + inputs = array_ops.placeholder(dtypes.float32, (None, 3, None, None)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_first", init_scale=50) + noisy, _ = layer(inputs, training=True) + quantized, _ = layer(inputs, training=False) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = 5 * np.random.normal(size=(2, 3, 5, 7)) + noisy, quantized, decoded = sess.run( + [noisy, quantized, decoded], {inputs: values}) + self.assertAllClose(values, noisy, rtol=0, atol=.5) + self.assertAllClose(values, quantized, rtol=0, atol=.5) + self.assertAllClose(values, decoded, rtol=0, atol=.5) + + def test_compress(self): + # Test compression and decompression, and produce test data for + # `test_decompress`. If you set the constant at the end to `True`, this test + # will fail and the log will contain the new test data. 
+ inputs = array_ops.placeholder(dtypes.float32, (2, 3, 10)) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_first", filters=(), init_scale=2) + bitstrings = layer.compress(inputs) + decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + values = 5 * np.random.uniform(size=(2, 3, 10)) - 2.5 + bitstrings, quantized_cdf, decoded = sess.run( + [bitstrings, layer._quantized_cdf, decoded], {inputs: values}) + self.assertAllClose(values, decoded, rtol=0, atol=.5) + # Set this constant to `True` to log new test data for `test_decompress`. + if False: # pylint:disable=using-constant-test + assert False, (bitstrings, quantized_cdf, decoded) + + # Data generated by `test_compress`. + # pylint:disable=g-inconsistent-quotes,bad-whitespace + bitstrings = np.array([ + b'\x1e\xbag}\xc2\xdaN\x8b\xbd.', + b'\x8dF\xf0%\x1cv\xccllW' + ], dtype=object) + + quantized_cdf = np.array([ + [ 0, 15636, 22324, 30145, 38278, 65536], + [ 0, 19482, 26927, 35052, 42904, 65535], + [ 0, 21093, 28769, 36919, 44578, 65536] + ], dtype=np.int32) + + expected = np.array([ + [[-2., 1., 0., -2., -1., -2., -2., -2., 2., -1.], + [ 1., 2., 1., 0., -2., -2., 1., 2., 0., 1.], + [ 2., 0., -2., 2., 0., -1., -2., 0., 2., 0.]], + [[ 1., 2., 0., -1., 1., 2., 1., 1., 2., -2.], + [ 2., -1., -1., 0., -1., 2., 0., 2., -2., 2.], + [ 2., -2., -2., -1., -2., 1., -2., 0., 0., 0.]] + ], dtype=np.float32) + # pylint:enable=g-inconsistent-quotes,bad-whitespace + + def test_decompress(self): + # Test that decompression of values compressed with a previous version + # works, i.e. that the file format doesn't change across revisions. + bitstrings = array_ops.placeholder(dtypes.string) + input_shape = array_ops.placeholder(dtypes.int32) + quantized_cdf = array_ops.placeholder(dtypes.int32) + layer = entropybottleneck.EntropyBottleneck( + data_format="channels_first", filters=(), dtype=dtypes.float32) + layer.build(self.expected.shape) + layer._quantized_cdf = quantized_cdf + decoded = layer.decompress(bitstrings, input_shape[1:]) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + decoded, = sess.run([decoded], { + bitstrings: self.bitstrings, input_shape: self.expected.shape, + quantized_cdf: self.quantized_cdf}) + self.assertAllClose(self.expected, decoded, rtol=0, atol=1e-6) + + def test_build_decompress(self): + # Test that layer can be built when `decompress` is the first call to it. + bitstrings = array_ops.placeholder(dtypes.string) + input_shape = array_ops.placeholder(dtypes.int32, shape=[3]) + layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) + layer.decompress(bitstrings, input_shape[1:], channels=5) + self.assertTrue(layer.built) + + def test_pmf_normalization(self): + # Test that probability mass functions are normalized correctly. + layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) + layer.build((None, 10)) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + pmf, = sess.run([layer._pmf]) + self.assertAllClose(np.ones(10), np.sum(pmf, axis=-1), rtol=0, atol=1e-6) + + def test_visualize(self): + # Test that summary op can be constructed. 
+ layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32) + layer.build((None, 10)) + summary = layer.visualize() + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + sess.run([summary]) + + def test_normalization(self): + # Test that densities are normalized correctly. + inputs = array_ops.placeholder(dtypes.float32, (None, 1)) + layer = entropybottleneck.EntropyBottleneck(filters=(2,)) + _, likelihood = layer(inputs, training=True) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + x = np.repeat(np.arange(-200, 201), 1000)[:, None] + likelihood, = sess.run([likelihood], {inputs: x}) + self.assertEqual(x.shape, likelihood.shape) + integral = np.sum(likelihood) * .001 + self.assertAllClose(1, integral, rtol=0, atol=1e-4) + + def test_entropy_estimates(self): + # Test that entropy estimates match actual range coding. + inputs = array_ops.placeholder(dtypes.float32, (1, None, 1)) + layer = entropybottleneck.EntropyBottleneck( + filters=(2, 3), data_format="channels_last") + _, likelihood = layer(inputs, training=True) + diff_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) + _, likelihood = layer(inputs, training=False) + disc_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2) + bitstrings = layer.compress(inputs) + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + self.assertTrue(len(layer.updates) == 1) + sess.run(layer.updates[0]) + diff_entropy, disc_entropy, bitstrings = sess.run( + [diff_entropy, disc_entropy, bitstrings], + {inputs: np.random.normal(size=(1, 10000, 1))}) + codelength = 8 * sum(len(bitstring) for bitstring in bitstrings) + self.assertAllClose(diff_entropy, disc_entropy, rtol=5e-3, atol=0) + self.assertAllClose(disc_entropy, codelength, rtol=5e-3, atol=0) + self.assertGreater(codelength, disc_entropy) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 29a593f6bcf..a56a01b1635 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -24,7 +24,6 @@ from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.framework import test_util from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -170,12 +169,11 @@ class JITTest(test.TestCase): self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s) -@test_util.with_c_api class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): with self.test_session(): - x = constant_op.constant([[3]]) + x = constant_op.constant([[3.]]) y_nc = math_ops.matmul(x, x, name="not_compiled") with jit.experimental_jit_scope(): y_c = math_ops.matmul(y_nc, y_nc, name="compiled") @@ -200,11 +198,11 @@ class CompilationEnabledInGradientTest(test.TestCase): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(): # XlaScope 0 - a1 = constant_op.constant([[1]]) + a1 = constant_op.constant([[1.]]) a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(): # XlaScope 1 - a2 = constant_op.constant([[1]]) + a2 = constant_op.constant([[1.]]) a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) @@ -222,11 +220,11 @@ class 
CompilationEnabledInGradientTest(test.TestCase): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 0 - a1 = constant_op.constant([[1]]) + a1 = constant_op.constant([[1.]]) a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 1 - a2 = constant_op.constant([[1]]) + a2 = constant_op.constant([[1.]]) a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) diff --git a/tensorflow/contrib/constrained_optimization/BUILD b/tensorflow/contrib/constrained_optimization/BUILD new file mode 100644 index 00000000000..619153df67c --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/BUILD @@ -0,0 +1,91 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +# Transitive dependencies of this target will be included in the pip package. +py_library( + name = "constrained_optimization_pip", + deps = [ + ":constrained_optimization", + ":test_util", + ], +) + +py_library( + name = "constrained_optimization", + srcs = [ + "__init__.py", + "python/candidates.py", + "python/constrained_minimization_problem.py", + "python/constrained_optimizer.py", + "python/external_regret_optimizer.py", + "python/swap_regret_optimizer.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:standard_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_test( + name = "candidates_test", + srcs = ["python/candidates_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +# NOTE: This library can't be "testonly" since it needs to be included in the +# pip package. +py_library( + name = "test_util", + srcs = ["python/test_util.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + "//tensorflow/python:dtypes", + "//tensorflow/python:standard_ops", + ], +) + +py_test( + name = "external_regret_optimizer_test", + srcs = ["python/external_regret_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + ":test_util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:standard_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + +py_test( + name = "swap_regret_optimizer_test", + srcs = ["python/swap_regret_optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":constrained_optimization", + ":test_util", + "//tensorflow/python:client_testlib", + "//tensorflow/python:standard_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md new file mode 100644 index 00000000000..c65a150464e --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/README.md @@ -0,0 +1,345 @@ + + +# ConstrainedOptimization (TFCO) + +TFCO is a library for optimizing inequality-constrained problems in TensorFlow. +Both the objective function and the constraints are represented as Tensors, +giving users the maximum amount of flexibility in specifying their optimization +problems. 
+ +This flexibility makes optimization considerably more difficult: on a non-convex +problem, if one uses the "standard" approach of introducing a Lagrange +multiplier for each constraint, and then jointly maximizing over the Lagrange +multipliers and minimizing over the model parameters, then a stable stationary +point might not even *exist*. Hence, in some cases, oscillation, instead of +convergence, is inevitable. + +Thankfully, it turns out that even if, over the course of optimization, no +*particular* iterate does a good job of minimizing the objective while +satisfying the constraints, the *sequence* of iterates, on average, usually +will. This observation suggests the following approach: at training time, we'll +periodically snapshot the model state during optimization; then, at evaluation +time, each time we're given a new example to evaluate, we'll sample one of the +saved snapshots uniformly at random, and apply it to the example. This +*stochastic model* will generally perform well, both with respect to the +objective function, and the constraints. + +In fact, we can do better: it's possible to post-process the set of snapshots to +find a distribution over at most $$m+1$$ snapshots, where $$m$$ is the number of +constraints, that will be at least as good (and will usually be much better) +than the (much larger) uniform distribution described above. If you're unable or +unwilling to use a stochastic model at all, then you can instead use a heuristic +to choose the single best snapshot. + +For full details, motivation, and theoretical results on the approach taken by +this library, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +which will be referred to as [CoJiSr18] throughout the remainder of this +document. + +### Proxy Constraints + +Imagine that we want to constrain the recall of a binary classifier to be at +least 90%. Since the recall is proportional to the number of true positive +classifications, which itself is a sum of indicator functions, this constraint +is non-differentible, and therefore cannot be used in a problem that will be +optimized using a (stochastic) gradient-based algorithm. + +For this and similar problems, TFCO supports so-called *proxy constraints*, +which are (at least semi-differentiable) approximations of the original +constraints. For example, one could create a proxy recall function by replacing +the indicator functions with sigmoids. During optimization, each proxy +constraint function will be penalized, with the magnitude of the penalty being +chosen to satisfy the corresponding *original* (non-proxy) constraint. + +On a problem including proxy constraints—even a convex problem—the +Lagrangian approach discussed above isn't guaranteed to work. However, a +different algorithm, based on minimizing *swap regret*, does work. Aside from +this difference, the recommended procedure for optimizing a proxy-constrained +problem remains the same: periodically snapshot the model during optimization, +and then either find the best $$m+1$$-sized distribution, or heuristically +choose the single best snapshot. + +## Components + +* [constrained_minimization_problem](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py): + contains the `ConstrainedMinimizationProblem` interface. 
Your own + constrained optimization problems should be represented using + implementations of this interface. + +* [constrained_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py): + contains the `ConstrainedOptimizer` interface, which is similar to (but + different from) `tf.train.Optimizer`, with the main difference being that + `ConstrainedOptimizer`s are given `ConstrainedMinimizationProblem`s to + optimize, and perform constrained optimization. + + * [external_regret_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py): + contains the `AdditiveExternalRegretOptimizer` implementation, which is + a `ConstrainedOptimizer` implementing the Lagrangian approach discussed + above (with additive updates to the Lagrange multipliers). You should + use this optimizer for problems *without* proxy constraints. It may also + work for problems with proxy constraints, but we recommend using a swap + regret optimizer, instead. + + This optimizer is most similar to Algorithm 3 in Appendix C.3 of + [CoJiSr18], and is discussed in Section 3. The two differences are that + it uses proxy constraints (if they're provided) in the update of the + model parameters, and uses `tf.train.Optimizer`s, instead of SGD, for + the "inner" updates. + + * [swap_regret_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py): + contains the `AdditiveSwapRegretOptimizer` and + `MultiplicativeSwapRegretOptimizer` implementations, which are + `ConstrainedOptimizer`s implementing the swap-regret minimization + approach mentioned above (with additive or multiplicative updates, + respectively, to the parameters associated with the + constraints—these parameters are not Lagrange multipliers, but + play a similar role). You should use one of these optimizers (we suggest + `MultiplicativeSwapRegretOptimizer`) for problems *with* proxy + constraints. + + The `MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2 + in Section 4 of [CoJiSr18], with the difference being that it uses + `tf.train.Optimizer`s, instead of SGD, for the "inner" updates. The + `AdditiveSwapRegretOptimizer` differs further in that it performs + additive (instead of multiplicative) updates of the stochastic matrix. + +* [candidates](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/candidates.py): + contains two functions, `find_best_candidate_distribution` and + `find_best_candidate_index`. Both of these functions are given a set of + candidate solutions to a constrained optimization problem, from which the + former finds the best distribution over at most $$m+1$$ candidates, and the + latter heuristically finds the single best candidate. As discussed above, + the set of candidates will typically be model snapshots saved periodically + during optimization. Both of these functions require that scipy be + installed. + + The `find_best_candidate_distribution` function implements the approach + described in Lemma 3 of [CoJiSr18], while `find_best_candidate_index` + implements the heuristic used for hyperparameter search in the experiments + of Section 5.2. + +## Convex Example with Proxy Constraints + +This is a simple example of recall-constrained optimization on simulated data: +we will try to find a classifier that minimizes the average hinge loss while +constraining recall to be at least 90%. 
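
Written out as a constrained minimization (a restatement of the goal just described, using signed labels $$y_i \in \{-1,+1\}$$ and the linear model $$f(x) = w^T x - t$$ that appears later in this example), the problem is roughly:

$$
\min_{w,\,t}\;\; \frac{1}{n}\sum_{i=1}^{n}\max\bigl(0,\; 1 - y_i\,(w^\top x_i - t)\bigr)
\qquad\text{subject to}\qquad
\frac{\sum_{i:\,y_i=+1}\mathbf{1}\bigl[w^\top x_i - t > 0\bigr]}{\bigl|\{i:\,y_i=+1\}\bigr|}\;\ge\;0.9
$$

The indicator function inside the recall constraint is exactly what makes a proxy constraint necessary: the constraint itself is kept in this non-differentiable form, while the proxy replaces the indicator with a hinge-style surrogate.
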
+ +We'll start with the required imports—notice the definition of `tfco`: + +```python +import math +import numpy as np +import tensorflow as tf + +tfco = tf.contrib.constrained_optimization +``` + +We'll now create an implementation of the `ConstrainedMinimizationProblem` class +for this problem. The constructor takes three parameters: a Tensor containing +the classification labels (0 or 1) for every training example, another Tensor +containing the model's predictions on every training example (sometimes called +the "logits"), and the lower bound on recall that will be enforced using a +constraint. + +This implementation will contain both constraints *and* proxy constraints: the +former represents the constraint that the true recall (defined in terms of the +*number* of true positives) be at least `recall_lower_bound`, while the latter +represents the same constraint, but on a hinge approximation of the recall. + +```python +class ExampleProblem(tfco.ConstrainedMinimizationProblem): + + def __init__(self, labels, predictions, recall_lower_bound): + self._labels = labels + self._predictions = predictions + self._recall_lower_bound = recall_lower_bound + # The number of positively-labeled examples. + self._positive_count = tf.reduce_sum(self._labels) + + @property + def objective(self): + return tf.losses.hinge_loss(labels=self._labels, logits=self._predictions) + + @property + def constraints(self): + true_positives = self._labels * tf.to_float(self._predictions > 0) + true_positive_count = tf.reduce_sum(true_positives) + recall = true_positive_count / self._positive_count + # The constraint is (recall >= self._recall_lower_bound), which we convert + # to (self._recall_lower_bound - recall <= 0) because + # ConstrainedMinimizationProblems must always provide their constraints in + # the form (tensor <= 0). + # + # The result of this function should be a tensor, with each element being + # a quantity that is constrained to be nonpositive. We only have one + # constraint, so we return a one-element tensor. + return self._recall_lower_bound - recall + + @property + def proxy_constraints(self): + # Use 1 - hinge since we're SUBTRACTING recall in the constraint function, + # and we want the proxy constraint function to be convex. + true_positives = self._labels * tf.minimum(1.0, self._predictions) + true_positive_count = tf.reduce_sum(true_positives) + recall = true_positive_count / self._positive_count + # Please see the corresponding comment in the constraints property. + return self._recall_lower_bound - recall +``` + +We'll now create a simple simulated dataset by sampling 1000 random +10-dimensional feature vectors from a Gaussian, finding their labels using a +random "ground truth" linear model, and then adding noise by randomly flipping +200 labels. + +```python +# Create a simulated 10-dimensional training dataset consisting of 1000 labeled +# examples, of which 800 are labeled correctly and 200 are mislabeled. +num_examples = 1000 +num_mislabeled_examples = 200 +dimension = 10 +# We will constrain the recall to be at least 90%. +recall_lower_bound = 0.9 + +# Create random "ground truth" parameters to a linear model. +ground_truth_weights = np.random.normal(size=dimension) / math.sqrt(dimension) +ground_truth_threshold = 0 + +# Generate a random set of features for each example. +features = np.random.normal(size=(num_examples, dimension)).astype( + np.float32) / math.sqrt(dimension) +# Compute the labels from these features given the ground truth linear model. 
+labels = (np.matmul(features, ground_truth_weights) > + ground_truth_threshold).astype(np.float32) +# Add noise by randomly flipping num_mislabeled_examples labels. +mislabeled_indices = np.random.choice( + num_examples, num_mislabeled_examples, replace=False) +labels[mislabeled_indices] = 1 - labels[mislabeled_indices] +``` + +We're now ready to construct our model, and the corresponding optimization +problem. We'll use a linear model of the form $$f(x) = w^T x - t$$, where $$w$$ +is the `weights`, and $$t$$ is the `threshold`. The `problem` variable will hold +an instance of the `ExampleProblem` class we created earlier. + +```python +# Create variables containing the model parameters. +weights = tf.Variable(tf.zeros(dimension), dtype=tf.float32, name="weights") +threshold = tf.Variable(0.0, dtype=tf.float32, name="threshold") + +# Create the optimization problem. +constant_labels = tf.constant(labels, dtype=tf.float32) +constant_features = tf.constant(features, dtype=tf.float32) +predictions = tf.tensordot(constant_features, weights, axes=(1, 0)) - threshold +problem = ExampleProblem( + labels=constant_labels, + predictions=predictions, + recall_lower_bound=recall_lower_bound, +) +``` + +We're almost ready to train our model, but first we'll create a couple of +functions to measure its performance. We're interested in two quantities: the +average hinge loss (which we seek to minimize), and the recall (which we +constrain). + +```python +def average_hinge_loss(labels, predictions): + num_examples, = np.shape(labels) + signed_labels = (labels * 2) - 1 + total_hinge_loss = np.sum(np.maximum(0.0, 1.0 - signed_labels * predictions)) + return total_hinge_loss / num_examples + +def recall(labels, predictions): + positive_count = np.sum(labels) + true_positives = labels * (predictions > 0) + true_positive_count = np.sum(true_positives) + return true_positive_count / positive_count +``` + +As was mentioned earlier, external regret optimizers suffice for problems +without proxy constraints, but swap regret optimizers are recommended for +problems *with* proxy constraints. Since this problem contains proxy +constraints, we use the `MultiplicativeSwapRegretOptimizer`. + +For this problem, the constraint is fairly easy to satisfy, so we can use the +same "inner" optimizer (an `AdagradOptimizer` with a learning rate of 1) for +optimization of both the model parameters (`weights` and `threshold`), and the +internal parameters associated with the constraints (these are the analogues of +the Lagrange multipliers used by the `MultiplicativeSwapRegretOptimizer`). For +more difficult problems, it will often be necessary to use different optimizers, +with different learning rates (presumably found via a hyperparameter search): to +accomplish this, pass *both* the `optimizer` and `constraint_optimizer` +parameters to `MultiplicativeSwapRegretOptimizer`'s constructor. + +Since this is a convex problem (both the objective and proxy constraint +functions are convex), we can just take the last iterate. Periodic snapshotting, +and the use of the `find_best_candidate_distribution` or +`find_best_candidate_index` functions, is generally only necessary for +non-convex problems (and even then, it isn't *always* necessary). 
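As a purely illustrative note on the optimizer configuration described above: using a separate `constraint_optimizer` is just a matter of passing a second `tf.train.Optimizer` to the constructor. The learning rates here are invented, not recommendations:

```python
# Hypothetical two-optimizer configuration (made-up learning rates).
optimizer = tfco.MultiplicativeSwapRegretOptimizer(
    optimizer=tf.train.AdagradOptimizer(learning_rate=0.1),
    constraint_optimizer=tf.train.AdagradOptimizer(learning_rate=0.01))
```

Likewise, although we won't need them for this convex example, here is a minimal sketch of how the two candidate-selection functions might be used on a handful of saved snapshots. The numbers are invented; in practice `objective_vector` and `constraints_matrix` would hold the objective values and constraint violations of each snapshot (with `constraints_matrix[i, j]` being the violation of the $$i$$th constraint by the $$j$$th snapshot), ideally measured on a held-out dataset:

```python
import numpy as np
import tensorflow as tf

tfco = tf.contrib.constrained_optimization

# Hypothetical measurements for n=4 snapshots under m=2 constraints
# (nonpositive entries mean that the constraint is satisfied).
objective_vector = np.array([0.52, 0.61, 0.48, 0.57])
constraints_matrix = np.array([[0.03, -0.01, 0.07, -0.02],
                               [-0.05, 0.02, 0.01, -0.04]])

# A distribution over at most m+1 snapshots that minimizes the expected
# objective while violating the constraints (in expectation) by as little as
# possible. Requires scipy.
distribution = tfco.find_best_candidate_distribution(objective_vector,
                                                     constraints_matrix)

# Alternatively, heuristically pick the single best snapshot. Also requires
# scipy.
best_index = tfco.find_best_candidate_index(objective_vector,
                                            constraints_matrix)
```

With those asides out of the way, we can now train the constrained model: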
+ +```python +with tf.Session() as session: + optimizer = tfco.MultiplicativeSwapRegretOptimizer( + optimizer=tf.train.AdagradOptimizer(learning_rate=1.0)) + train_op = optimizer.minimize(problem) + + session.run(tf.global_variables_initializer()) + for ii in xrange(1000): + session.run(train_op) + + trained_weights, trained_threshold = session.run((weights, threshold)) + +trained_predictions = np.matmul(features, trained_weights) - trained_threshold +print("Constrained average hinge loss = %f" % average_hinge_loss( + labels, trained_predictions)) +print("Constrained recall = %f" % recall(labels, trained_predictions)) +``` + +Running the above code gives the following output (due to the randomness of the +dataset, you'll get a different result when you run it): + +```none +Constrained average hinge loss = 0.710019 +Constrained recall = 0.899811 +``` + +As we hoped, the recall is extremely close to 90%—and, thanks to the use +of proxy constraints, this is the *true* recall, not a hinge approximation. + +For comparison, let's try optimizing the same problem *without* the recall +constraint: + +```python +with tf.Session() as session: + optimizer = tf.train.AdagradOptimizer(learning_rate=1.0) + # For optimizing the unconstrained problem, we just minimize the "objective" + # portion of the minimization problem. + train_op = optimizer.minimize(problem.objective) + + session.run(tf.global_variables_initializer()) + for ii in xrange(1000): + session.run(train_op) + + trained_weights, trained_threshold = session.run((weights, threshold)) + +trained_predictions = np.matmul(features, trained_weights) - trained_threshold +print("Unconstrained average hinge loss = %f" % average_hinge_loss( + labels, trained_predictions)) +print("Unconstrained recall = %f" % recall(labels, trained_predictions)) +``` + +This code gives the following output (again, you'll get a different answer, +since the dataset is random): + +```none +Unconstrained average hinge loss = 0.627271 +Unconstrained recall = 0.793951 +``` + +Because there is no constraint, the unconstrained problem does a better job of +minimizing the average hinge loss, but naturally doesn't approach 90% recall. diff --git a/tensorflow/contrib/constrained_optimization/__init__.py b/tensorflow/contrib/constrained_optimization/__init__.py new file mode 100644 index 00000000000..1e49ba9f179 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""A library for performing constrained optimization in TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.constrained_optimization.python.candidates import * +from tensorflow.contrib.constrained_optimization.python.constrained_minimization_problem import * +from tensorflow.contrib.constrained_optimization.python.constrained_optimizer import * +from tensorflow.contrib.constrained_optimization.python.external_regret_optimizer import * +from tensorflow.contrib.constrained_optimization.python.swap_regret_optimizer import * +# pylint: enable=wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "AdditiveExternalRegretOptimizer", + "AdditiveSwapRegretOptimizer", + "ConstrainedMinimizationProblem", + "ConstrainedOptimizer", + "find_best_candidate_distribution", + "find_best_candidate_index", + "MultiplicativeSwapRegretOptimizer", +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/constrained_optimization/python/candidates.py b/tensorflow/contrib/constrained_optimization/python/candidates.py new file mode 100644 index 00000000000..ac86a6741be --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/candidates.py @@ -0,0 +1,319 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Code for optimizing over a set of candidate solutions. + +The functions in this file deal with the constrained problem: + +> minimize f(w) +> s.t. g_i(w) <= 0 for all i in {0,1,...,m-1} + +Here, f(w) is the "objective function", and g_i(w) is the ith (of m) "constraint +function". Given the values of the objective and constraint functions for a set +of n "candidate solutions" {w_0,w_1,...,w_{n-1}} (for a total of n objective +function values, and n*m constraint function values), the +`find_best_candidate_distribution` function finds the best DISTRIBUTION over +these candidates, while `find_best_candidate_index' heuristically finds the +single best candidate. + +Both of these functions have dependencies on `scipy`, so if you want to call +them, then you must make sure that `scipy` is available. The imports are +performed inside the functions themselves, so if they're not actually called, +then `scipy` is not needed. + +For more specifics, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +The `find_best_candidate_distribution` function implements the approach +described in Lemma 3, while `find_best_candidate_index` implements the heuristic +used for hyperparameter search in the experiments of Section 5.2. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + + +def _find_best_candidate_distribution_helper(objective_vector, + constraints_matrix, + maximum_violation=0.0): + """Finds a distribution minimizing an objective subject to constraints. + + This function deals with the constrained problem: + + > minimize f(w) + > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1} + + Here, f(w) is the "objective function", and g_i(w) is the ith (of m) + "constraint function". Given a set of n "candidate solutions" + {w_0,w_1,...,w_{n-1}}, this function finds a distribution over these n + candidates that, in expectation, minimizes the objective while violating + the constraints by no more than `maximum_violation`. If no such distribution + exists, it returns an error (using Go-style error reporting). + + The `objective_vector` parameter should be a numpy array with shape (n,), for + which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a + numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j). + + This function will return a distribution for which at most m+1 probabilities, + and often fewer, are nonzero. + + Args: + objective_vector: numpy array of shape (n,), where n is the number of + "candidate solutions". Contains the objective function values. + constraints_matrix: numpy array of shape (m,n), where m is the number of + constraints and n is the number of "candidate solutions". Contains the + constraint violation magnitudes. + maximum_violation: nonnegative float, the maximum amount by which any + constraint may be violated, in expectation. + + Returns: + A pair (`result`, `message`), exactly one of which is None. If `message` is + None, then the `result` contains the optimal distribution as a numpy array + of shape (n,). If `result` is None, then `message` contains an error + message. + + Raises: + ValueError: If `objective_vector` and `constraints_matrix` have inconsistent + shapes, or if `maximum_violation` is negative. + ImportError: If we're unable to import `scipy.optimize`. + """ + if maximum_violation < 0.0: + raise ValueError("maximum_violation must be nonnegative") + + mm, nn = np.shape(constraints_matrix) + if (nn,) != np.shape(objective_vector): + raise ValueError( + "objective_vector must have shape (n,), and constraints_matrix (m, n)," + " where n is the number of candidates, and m is the number of " + "constraints") + + # We import scipy inline, instead of at the top of the file, so that a scipy + # dependency is only introduced if either find_best_candidate_distribution() + # or find_best_candidate_index() are actually called. + import scipy.optimize # pylint: disable=g-import-not-at-top + + # Feasibility (within maximum_violation) constraints. + a_ub = constraints_matrix + b_ub = np.full((mm, 1), maximum_violation) + # Sum-to-one constraint. + a_eq = np.ones((1, nn)) + b_eq = np.ones((1, 1)) + # Nonnegativity constraints. + bounds = (0, None) + + result = scipy.optimize.linprog( + objective_vector, + A_ub=a_ub, + b_ub=b_ub, + A_eq=a_eq, + b_eq=b_eq, + bounds=bounds) + # Go-style error reporting. We don't raise on error, since + # find_best_candidate_distribution() needs to handle the failure case, and we + # shouldn't use exceptions as flow-control. 
+ if not result.success: + return (None, result.message) + else: + return (result.x, None) + + +def find_best_candidate_distribution(objective_vector, + constraints_matrix, + epsilon=0.0): + """Finds a distribution minimizing an objective subject to constraints. + + This function deals with the constrained problem: + + > minimize f(w) + > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1} + + Here, f(w) is the "objective function", and g_i(w) is the ith (of m) + "constraint function". Given a set of n "candidate solutions" + {w_0,w_1,...,w_{n-1}}, this function finds a distribution over these n + candidates that, in expectation, minimizes the objective while violating + the constraints by the smallest possible amount (with the amount being found + via bisection search). + + The `objective_vector` parameter should be a numpy array with shape (n,), for + which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a + numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j). + + This function will return a distribution for which at most m+1 probabilities, + and often fewer, are nonzero. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + This function implements the approach described in Lemma 3. + + Args: + objective_vector: numpy array of shape (n,), where n is the number of + "candidate solutions". Contains the objective function values. + constraints_matrix: numpy array of shape (m,n), where m is the number of + constraints and n is the number of "candidate solutions". Contains the + constraint violation magnitudes. + epsilon: nonnegative float, the threshold at which to terminate the binary + search while searching for the minimal expected constraint violation + magnitude. + + Returns: + The optimal distribution, as a numpy array of shape (n,). + + Raises: + ValueError: If `objective_vector` and `constraints_matrix` have inconsistent + shapes, or if `epsilon` is negative. + ImportError: If we're unable to import `scipy.optimize`. + """ + if epsilon < 0.0: + raise ValueError("epsilon must be nonnegative") + + # If there is a feasible solution (i.e. with maximum_violation=0), then that's + # what we'll return. + pp, _ = _find_best_candidate_distribution_helper(objective_vector, + constraints_matrix) + if pp is not None: + return pp + + # The bound is the minimum over all candidates, of the maximum per-candidate + # constraint violation. + lower = 0.0 + upper = np.min(np.amax(constraints_matrix, axis=0)) + best_pp, _ = _find_best_candidate_distribution_helper( + objective_vector, constraints_matrix, maximum_violation=upper) + assert best_pp is not None + + # Throughout this loop, a maximum_violation of "lower" is not achievable, + # but a maximum_violation of "upper" is achiveable. + while True: + middle = 0.5 * (lower + upper) + if (middle - lower <= epsilon) or (upper - middle <= epsilon): + break + else: + pp, _ = _find_best_candidate_distribution_helper( + objective_vector, constraints_matrix, maximum_violation=middle) + if pp is None: + lower = middle + else: + best_pp = pp + upper = middle + + return best_pp + + +def find_best_candidate_index(objective_vector, + constraints_matrix, + rank_objectives=False): + """Heuristically finds the best candidate solution to a constrained problem. + + This function deals with the constrained problem: + + > minimize f(w) + > s.t. 
g_i(w) <= 0 for all i in {0,1,...,m-1} + + Here, f(w) is the "objective function", and g_i(w) is the ith (of m) + "constraint function". Given a set of n "candidate solutions" + {w_0,w_1,...,w_{n-1}}, this function finds the "best" solution according + to the following heuristic: + + 1. Across all models, the ith constraint violations (i.e. max{0, g_i(0)}) + are ranked, as are the objectives (if rank_objectives=True). + 2. Each model is then associated its MAXIMUM rank across all m constraints + (and the objective, if rank_objectives=True). + 3. The model with the minimal maximum rank is then identified. Ties are + broken using the objective function value. + 4. The index of this "best" model is returned. + + The `objective_vector` parameter should be a numpy array with shape (n,), for + which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a + numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j). + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + This function implements the heuristic used for hyperparameter search in the + experiments of Section 5.2. + + Args: + objective_vector: numpy array of shape (n,), where n is the number of + "candidate solutions". Contains the objective function values. + constraints_matrix: numpy array of shape (m,n), where m is the number of + constraints and n is the number of "candidate solutions". Contains the + constraint violation magnitudes. + rank_objectives: bool, whether the objective function values should be + included in the initial ranking step. If True, both the objective and + constraints will be ranked. If False, only the constraints will be ranked. + In either case, the objective function values will be used for + tiebreaking. + + Returns: + The index (in {0,1,...,n-1}) of the "best" model according to the above + heuristic. + + Raises: + ValueError: If `objective_vector` and `constraints_matrix` have inconsistent + shapes. + ImportError: If we're unable to import `scipy.stats`. + """ + mm, nn = np.shape(constraints_matrix) + if (nn,) != np.shape(objective_vector): + raise ValueError( + "objective_vector must have shape (n,), and constraints_matrix (m, n)," + " where n is the number of candidates, and m is the number of " + "constraints") + + # We import scipy inline, instead of at the top of the file, so that a scipy + # dependency is only introduced if either find_best_candidate_distribution() + # or find_best_candidate_index() are actually called. + import scipy.stats # pylint: disable=g-import-not-at-top + + if rank_objectives: + maximum_ranks = scipy.stats.rankdata(objective_vector, method="min") + else: + maximum_ranks = np.zeros(nn, dtype=np.int64) + for ii in xrange(mm): + # Take the maximum of the constraint functions with zero, since we want to + # rank the magnitude of constraint *violations*. If the constraint is + # satisfied, then we don't care how much it's satisfied by (as a result, we + # we expect all models satisfying a constraint to be tied at rank 1). 
+ ranks = scipy.stats.rankdata( + np.maximum(0.0, constraints_matrix[ii, :]), method="min") + maximum_ranks = np.maximum(maximum_ranks, ranks) + + best_index = None + best_rank = float("Inf") + best_objective = float("Inf") + for ii in xrange(nn): + if maximum_ranks[ii] < best_rank: + best_index = ii + best_rank = maximum_ranks[ii] + best_objective = objective_vector[ii] + elif (maximum_ranks[ii] == best_rank) and (objective_vector[ii] <= + best_objective): + best_index = ii + best_objective = objective_vector[ii] + + return best_index diff --git a/tensorflow/contrib/constrained_optimization/python/candidates_test.py b/tensorflow/contrib/constrained_optimization/python/candidates_test.py new file mode 100644 index 00000000000..a4c49d48bc5 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/candidates_test.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for constrained_optimization.python.candidates.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.constrained_optimization.python import candidates +from tensorflow.python.platform import test + + +class CandidatesTest(test.TestCase): + + def test_inconsistent_shapes_for_best_distribution(self): + """An error is raised when parameters have inconsistent shapes.""" + objective_vector = np.array([1, 2, 3]) + constraints_matrix = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) + with self.assertRaises(ValueError): + _ = candidates.find_best_candidate_distribution(objective_vector, + constraints_matrix) + + def test_inconsistent_shapes_for_best_index(self): + """An error is raised when parameters have inconsistent shapes.""" + objective_vector = np.array([1, 2, 3]) + constraints_matrix = np.array([[1, 2, 3, 4], [5, 6, 7, 8]]) + with self.assertRaises(ValueError): + _ = candidates.find_best_candidate_index(objective_vector, + constraints_matrix) + + def test_best_distribution(self): + """Distribution should match known solution.""" + objective_vector = np.array( + [0.03053309, -0.06667082, 0.88355145, 0.46529806]) + constraints_matrix = np.array( + [[-0.60164551, 0.36676229, 0.7856454, -0.8441711], + [0.00371592, -0.16392108, -0.59778071, -0.56908492]]) + distribution = candidates.find_best_candidate_distribution( + objective_vector, constraints_matrix) + # Verify that the solution is a probability distribution. + self.assertTrue(np.all(distribution >= 0)) + self.assertAlmostEqual(np.sum(distribution), 1.0) + # Verify that the solution satisfies the constraints. + maximum_constraint_violation = np.amax( + np.dot(constraints_matrix, distribution)) + self.assertLessEqual(maximum_constraint_violation, 0) + # Verify that the solution matches that which we expect. 
+ expected_distribution = np.array([0.37872711, 0.62127289, 0, 0]) + self.assertAllClose(expected_distribution, distribution, rtol=0, atol=1e-6) + + def test_best_index_rank_objectives_true(self): + """Index should match known solution.""" + # Objective ranks = [2, 1, 4, 3]. + objective_vector = np.array( + [0.03053309, -0.06667082, 0.88355145, 0.46529806]) + # Constraint ranks = [[1, 3, 4, 1], [4, 1, 1, 1]]. + constraints_matrix = np.array( + [[-0.60164551, 0.36676229, 0.7856454, -0.8441711], + [0.00371592, -0.16392108, -0.59778071, -0.56908492]]) + # Maximum ranks = [4, 3, 4, 3]. + index = candidates.find_best_candidate_index( + objective_vector, constraints_matrix, rank_objectives=True) + self.assertEqual(1, index) + + def test_best_index_rank_objectives_false(self): + """Index should match known solution.""" + # Objective ranks = [2, 1, 4, 3]. + objective_vector = np.array( + [0.03053309, -0.06667082, 0.88355145, 0.46529806]) + # Constraint ranks = [[1, 3, 4, 1], [4, 1, 1, 1]]. + constraints_matrix = np.array( + [[-0.60164551, 0.36676229, 0.7856454, -0.8441711], + [0.00371592, -0.16392108, -0.59778071, -0.56908492]]) + # Maximum ranks = [4, 3, 4, 1]. + index = candidates.find_best_candidate_index( + objective_vector, constraints_matrix, rank_objectives=False) + self.assertEqual(3, index) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py new file mode 100644 index 00000000000..70813fb2179 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines abstract class for `ConstrainedMinimizationProblem`s. + +A ConstrainedMinimizationProblem consists of an objective function to minimize, +and a set of constraint functions that are constrained to be nonpositive. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + + +@six.add_metaclass(abc.ABCMeta) +class ConstrainedMinimizationProblem(object): + """Abstract class representing a `ConstrainedMinimizationProblem`. + + A ConstrainedMinimizationProblem consists of an objective function to + minimize, and a set of constraint functions that are constrained to be + nonpositive. + + In addition to the constraint functions, there may (optionally) be proxy + constraint functions: a ConstrainedOptimizer will attempt to penalize these + proxy constraint functions so as to satisfy the (non-proxy) constraints. Proxy + constraints could be used if the constraints functions are difficult or + impossible to optimize (e.g. 
if they're piecewise constant), in which case the + proxy constraints should be some approximation of the original constraints + that is well-enough behaved to permit successful optimization. + """ + + @abc.abstractproperty + def objective(self): + """Returns the objective function. + + Returns: + A 0d tensor that should be minimized. + """ + pass + + @property + def num_constraints(self): + """Returns the number of constraints. + + Returns: + An int containing the number of constraints. + + Raises: + ValueError: If the constraints (or proxy_constraints, if present) do not + have fully-known shapes, OR if proxy_constraints are present, and the + shapes of constraints and proxy_constraints are fully-known, but they're + different. + """ + constraints_shape = self.constraints.get_shape() + if self.proxy_constraints is None: + proxy_constraints_shape = constraints_shape + else: + proxy_constraints_shape = self.proxy_constraints.get_shape() + + if (constraints_shape is None or proxy_constraints_shape is None or + any([ii is None for ii in constraints_shape.as_list()]) or + any([ii is None for ii in proxy_constraints_shape.as_list()])): + raise ValueError( + "constraints and proxy_constraints must have fully-known shapes") + if constraints_shape != proxy_constraints_shape: + raise ValueError( + "constraints and proxy_constraints must have the same shape") + + size = 1 + for ii in constraints_shape.as_list(): + size *= ii + return int(size) + + @abc.abstractproperty + def constraints(self): + """Returns the vector of constraint functions. + + Letting g_i be the ith element of the constraints vector, the ith constraint + will be g_i <= 0. + + Returns: + A tensor of constraint functions. + """ + pass + + # This is a property, instead of an abstract property, since it doesn't need + # to be overridden: if proxy_constraints returns None, then there are no + # proxy constraints. + @property + def proxy_constraints(self): + """Returns the optional vector of proxy constraint functions. + + The difference between `constraints` and `proxy_constraints` is that, when + proxy constraints are present, the `constraints` are merely EVALUATED during + optimization, whereas the `proxy_constraints` are DIFFERENTIATED. If there + are no proxy constraints, then the `constraints` are both evaluated and + differentiated. + + For example, if we want to impose constraints on step functions, then we + could use these functions for `constraints`. However, because a step + function has zero gradient almost everywhere, we can't differentiate these + functions, so we would take `proxy_constraints` to be some differentiable + approximation of `constraints`. + + Returns: + A tensor of proxy constraint functions. + """ + return None diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py new file mode 100644 index 00000000000..80555453661 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py @@ -0,0 +1,208 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines base class for `ConstrainedOptimizer`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.training import optimizer as train_optimizer + + +@six.add_metaclass(abc.ABCMeta) +class ConstrainedOptimizer(object): + """Base class representing a constrained optimizer. + + A ConstrainedOptimizer wraps a tf.train.Optimizer (or more than one), and + applies it to a ConstrainedMinimizationProblem. Unlike a tf.train.Optimizer, + which takes a tensor to minimize as a parameter to its minimize() method, a + constrained optimizer instead takes a ConstrainedMinimizationProblem. + """ + + def __init__(self, optimizer): + """Constructs a new `ConstrainedOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the + ConstraintedMinimizationProblem. + + Returns: + A new `ConstrainedOptimizer`. + """ + self._optimizer = optimizer + + @property + def optimizer(self): + """Returns the `tf.train.Optimizer` used for optimization.""" + return self._optimizer + + def minimize_unconstrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the unconstrained problem. + + Unlike `minimize_constrained`, this function ignores the `constraints` (and + `proxy_constraints`) portion of the minimization problem entirely, and only + minimizes `objective`. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + return self.optimizer.minimize( + minimization_problem.objective, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + @abc.abstractmethod + def minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. 
+ + Unlike `minimize_unconstrained`, this function attempts to find a solution + that minimizes the `objective` portion of the minimization problem while + satisfying the `constraints` portion. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + pass + + def minimize(self, + minimization_problem, + unconstrained_steps=None, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. + + This method combines the functionality of `minimize_unconstrained` and + `minimize_constrained`. If global_step < unconstrained_steps, it will + perform an unconstrained update, and if global_step >= unconstrained_steps, + it will perform a constrained update. + + The reason for this functionality is that it may be best to initialize the + constrained optimizer with an approximate optimum of the unconstrained + problem. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + unconstrained_steps: int, number of steps for which we should perform + unconstrained updates, before transitioning to constrained updates. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + + Raises: + ValueError: If unconstrained_steps is provided, but global_step is not. 
+ """ + + def unconstrained_fn(): + """Returns an `Op` for minimizing the unconstrained problem.""" + return self.minimize_unconstrained( + minimization_problem=minimization_problem, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + def constrained_fn(): + """Returns an `Op` for minimizing the constrained problem.""" + return self.minimize_constrained( + minimization_problem=minimization_problem, + global_step=global_step, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + name=name, + grad_loss=grad_loss) + + if unconstrained_steps is not None: + if global_step is None: + raise ValueError( + "global_step cannot be None if unconstrained_steps is provided") + unconstrained_steps_tensor = ops.convert_to_tensor(unconstrained_steps) + dtype = unconstrained_steps_tensor.dtype + return control_flow_ops.cond( + standard_ops.cast(global_step, dtype) < unconstrained_steps_tensor, + true_fn=unconstrained_fn, + false_fn=constrained_fn) + else: + return constrained_fn() diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py new file mode 100644 index 00000000000..01c6e4f08af --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py @@ -0,0 +1,375 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines `AdditiveExternalRegretOptimizer`. + +This optimizer minimizes a `ConstrainedMinimizationProblem` by introducing +Lagrange multipliers, and using `tf.train.Optimizer`s to jointly optimize over +the model parameters and Lagrange multipliers. + +For the purposes of constrained optimization, at least in theory, +external-regret minimization suffices if the `ConstrainedMinimizationProblem` +we're optimizing doesn't have any `proxy_constraints`, while swap-regret +minimization should be used if `proxy_constraints` are present. + +For more specifics, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +The formulation used by the AdditiveExternalRegretOptimizer--which is simply the +usual Lagrangian formulation--can be found in Definition 1, and is discussed in +Section 3. This optimizer is most similar to Algorithm 3 in Appendix C.3, with +the two differences being that it uses proxy constraints (if they're provided) +in the update of the model parameters, and uses `tf.train.Optimizer`s, instead +of SGD, for the "inner" updates. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.contrib.constrained_optimization.python import constrained_optimizer + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer as train_optimizer + + +def _project_multipliers_wrt_euclidean_norm(multipliers, radius): + """Projects its argument onto the feasible region. + + The feasible region is the set of all vectors with nonnegative elements that + sum to at most `radius`. + + Args: + multipliers: 1d tensor, the Lagrange multipliers to project. + radius: float, the radius of the feasible region. + + Returns: + The 1d tensor that results from projecting `multipliers` onto the feasible + region w.r.t. the Euclidean norm. + + Raises: + ValueError: if the `multipliers` tensor does not have a fully-known shape, + or is not one-dimensional. + """ + multipliers_shape = multipliers.get_shape() + if multipliers_shape is None: + raise ValueError("multipliers must have known shape") + if multipliers_shape.ndims != 1: + raise ValueError( + "multipliers must be one dimensional (instead is %d-dimensional)" % + multipliers_shape.ndims) + dimension = multipliers_shape[0].value + if dimension is None: + raise ValueError("multipliers must have fully-known shape") + + def while_loop_condition(iteration, multipliers, inactive, old_inactive): + """Returns false if the while loop should terminate.""" + del multipliers # Needed by the body, but not the condition. + not_done = (iteration < dimension) + not_converged = standard_ops.reduce_any( + standard_ops.not_equal(inactive, old_inactive)) + return standard_ops.logical_and(not_done, not_converged) + + def while_loop_body(iteration, multipliers, inactive, old_inactive): + """Performs one iteration of the projection.""" + del old_inactive # Needed by the condition, but not the body. + iteration += 1 + scale = standard_ops.minimum( + 0.0, + (radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum( + 1.0, standard_ops.reduce_sum(inactive))) + multipliers += scale * inactive + new_inactive = standard_ops.to_float(multipliers > 0) + multipliers *= new_inactive + return (iteration, multipliers, new_inactive, inactive) + + iteration = standard_ops.constant(0) + inactive = standard_ops.ones_like(multipliers) + + # We actually want a do-while loop, so we explicitly call while_loop_body() + # once before tf.while_loop(). + iteration, multipliers, inactive, old_inactive = while_loop_body( + iteration, multipliers, inactive, inactive) + iteration, multipliers, inactive, old_inactive = control_flow_ops.while_loop( + while_loop_condition, + while_loop_body, + loop_vars=(iteration, multipliers, inactive, old_inactive), + name="euclidean_projection") + + return multipliers + + +@six.add_metaclass(abc.ABCMeta) +class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): + """Base class representing an `_ExternalRegretOptimizer`. + + This class contains most of the logic for performing constrained + optimization, minimizing external regret for the constraints player. What it + *doesn't* do is keep track of the internal state (the Lagrange multipliers). 
+ Instead, the state is accessed via the _initial_state(), + _lagrange_multipliers(), _constraint_grad_and_var() and _projection_op() + methods. + + The reason for this is that we want to make it easy to implement different + representations of the internal state. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by `_ExternalRegretOptimizer`s--which is simply the usual + Lagrangian formulation--can be found in Definition 1, and is discussed in + Section 3. Such optimizers are most similar to Algorithm 3 in Appendix C.3. + """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Constructs a new `_ExternalRegretOptimizer`. + + The difference between `optimizer` and `constraint_optimizer` (if the latter + is provided) is that the former is used for learning the model parameters, + while the latter us used for the Lagrange multipliers. If no + `constraint_optimizer` is provided, then `optimizer` is used for both. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of the ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multipliers. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multipliers. + + Returns: + A new `_ExternalRegretOptimizer`. + """ + super(_ExternalRegretOptimizer, self).__init__(optimizer=optimizer) + self._constraint_optimizer = constraint_optimizer + + @property + def constraint_optimizer(self): + """Returns the `tf.train.Optimizer` used for the Lagrange multipliers.""" + return self._constraint_optimizer + + @abc.abstractmethod + def _initial_state(self, num_constraints): + pass + + @abc.abstractmethod + def _lagrange_multipliers(self, state): + pass + + @abc.abstractmethod + def _constraint_grad_and_var(self, state, gradient): + pass + + @abc.abstractmethod + def _projection_op(self, state, name=None): + pass + + def minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. + + The `optimizer` constructor parameter will be used to update the model + parameters, while the Lagrange multipliers will be updated using + `constrained_optimizer` (if provided) or `optimizer` (if not). + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. + name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + objective = minimization_problem.objective + + constraints = minimization_problem.constraints + proxy_constraints = minimization_problem.proxy_constraints + if proxy_constraints is None: + proxy_constraints = constraints + # Flatten both constraints tensors to 1d. 
+ num_constraints = minimization_problem.num_constraints + constraints = standard_ops.reshape(constraints, shape=(num_constraints,)) + proxy_constraints = standard_ops.reshape( + proxy_constraints, shape=(num_constraints,)) + + # We use a lambda to initialize the state so that, if this function call is + # inside the scope of a tf.control_dependencies() block, the dependencies + # will not be applied to the initializer. + state = standard_ops.Variable( + lambda: self._initial_state(num_constraints), + trainable=False, + name="external_regret_optimizer_state") + + multipliers = self._lagrange_multipliers(state) + loss = ( + objective + standard_ops.tensordot(multipliers, proxy_constraints, 1)) + multipliers_gradient = constraints + + update_ops = [] + if self.constraint_optimizer is None: + # If we don't have a separate constraint_optimizer, then we use + # self._optimizer for both the update of the model parameters, and that of + # the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + grads_and_vars.append( + self._constraint_grad_and_var(state, multipliers_gradient)) + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + else: + # If we have a separate constraint_optimizer, then we use self._optimizer + # for the update of the model parameters, and self._constraint_optimizer + # for that of the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + multiplier_grads_and_vars = [ + self._constraint_grad_and_var(state, multipliers_gradient) + ] + + gradients = [ + gradient for gradient, _ in grads_and_vars + multiplier_grads_and_vars + if gradient is not None + ] + with ops.control_dependencies(gradients): + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + update_ops.append( + self.constraint_optimizer.apply_gradients( + multiplier_grads_and_vars, name="optimizer_state_update")) + + with ops.control_dependencies(update_ops): + if global_step is None: + # If we don't have a global step, just project, and we're done. + return self._projection_op(state, name=name) + else: + # If we have a global step, then we need to increment it in addition to + # projecting. + projection_op = self._projection_op(state, name="project") + with ops.colocate_with(global_step): + global_step_op = state_ops.assign_add( + global_step, 1, name="global_step_increment") + return control_flow_ops.group(projection_op, global_step_op, name=name) + + +class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer): + """A `ConstrainedOptimizer` based on external-regret minimization. + + This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly + minimize over the model parameters, and maximize over Lagrange multipliers, + with the latter maximization using additive updates and an algorithm that + minimizes external regret. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". 
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by this optimizer--which is simply the usual Lagrangian + formulation--can be found in Definition 1, and is discussed in Section 3. It + is most similar to Algorithm 3 in Appendix C.3, with the two differences being + that it uses proxy constraints (if they're provided) in the update of the + model parameters, and uses `tf.train.Optimizer`s, instead of SGD, for the + "inner" updates. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + maximum_multiplier_radius=None): + """Constructs a new `AdditiveExternalRegretOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multipliers. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multipliers. + maximum_multiplier_radius: float, an optional upper bound to impose on the + sum of the Lagrange multipliers. + + Returns: + A new `AdditiveExternalRegretOptimizer`. + + Raises: + ValueError: If the maximum_multiplier_radius parameter is nonpositive. + """ + super(AdditiveExternalRegretOptimizer, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + + if maximum_multiplier_radius and (maximum_multiplier_radius <= 0.0): + raise ValueError("maximum_multiplier_radius must be strictly positive") + + self._maximum_multiplier_radius = maximum_multiplier_radius + + def _initial_state(self, num_constraints): + # For an AdditiveExternalRegretOptimizer, the internal state is simply a + # tensor of Lagrange multipliers with shape (m,), where m is the number of + # constraints. + return standard_ops.zeros((num_constraints,), dtype=dtypes.float32) + + def _lagrange_multipliers(self, state): + return state + + def _constraint_grad_and_var(self, state, gradient): + # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? + return (-gradient, state) + + def _projection_op(self, state, name=None): + with ops.colocate_with(state): + if self._maximum_multiplier_radius: + projected_multipliers = _project_multipliers_wrt_euclidean_norm( + state, self._maximum_multiplier_radius) + else: + projected_multipliers = standard_ops.maximum(state, 0.0) + return state_ops.assign(state, projected_multipliers, name=name) diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py new file mode 100644 index 00000000000..9b4bf627100 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py @@ -0,0 +1,136 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for constrained_optimization.python.external_regret_optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.constrained_optimization.python import external_regret_optimizer +from tensorflow.contrib.constrained_optimization.python import test_util + +from tensorflow.python.ops import standard_ops +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent + + +class AdditiveExternalRegretOptimizerWrapper( + external_regret_optimizer.AdditiveExternalRegretOptimizer): + """Testing wrapper class around AdditiveExternalRegretOptimizer. + + This class is identical to AdditiveExternalRegretOptimizer, except that it + caches the internal optimization state when _lagrange_multipliers() is called, + so that we can test that the Lagrange multipliers take on their expected + values. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + maximum_multiplier_radius=None): + """Same as AdditiveExternalRegretOptimizer.__init__.""" + super(AdditiveExternalRegretOptimizerWrapper, self).__init__( + optimizer=optimizer, + constraint_optimizer=constraint_optimizer, + maximum_multiplier_radius=maximum_multiplier_radius) + self._cached_lagrange_multipliers = None + + @property + def lagrange_multipliers(self): + """Returns the cached Lagrange multipliers.""" + return self._cached_lagrange_multipliers + + def _lagrange_multipliers(self, state): + """Caches the internal state for testing.""" + self._cached_lagrange_multipliers = super( + AdditiveExternalRegretOptimizerWrapper, + self)._lagrange_multipliers(state) + return self._cached_lagrange_multipliers + + +class ExternalRegretOptimizerTest(test.TestCase): + + def test_project_multipliers_wrt_euclidean_norm(self): + """Tests Euclidean projection routine on some known values.""" + multipliers1 = standard_ops.constant([-0.1, -0.6, -0.3]) + expected_projected_multipliers1 = np.array([0.0, 0.0, 0.0]) + + multipliers2 = standard_ops.constant([-0.1, 0.6, 0.3]) + expected_projected_multipliers2 = np.array([0.0, 0.6, 0.3]) + + multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1]) + expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0]) + + with self.test_session() as session: + projected_multipliers1 = session.run( + external_regret_optimizer._project_multipliers_wrt_euclidean_norm( + multipliers1, 1.0)) + projected_multipliers2 = session.run( + external_regret_optimizer._project_multipliers_wrt_euclidean_norm( + multipliers2, 1.0)) + projected_multipliers3 = session.run( + external_regret_optimizer._project_multipliers_wrt_euclidean_norm( + multipliers3, 1.0)) + + self.assertAllClose( + expected_projected_multipliers1, + projected_multipliers1, + rtol=0, + atol=1e-6) + self.assertAllClose( + expected_projected_multipliers2, + projected_multipliers2, + rtol=0, + atol=1e-6) + self.assertAllClose( + expected_projected_multipliers3, + projected_multipliers3, + rtol=0, + atol=1e-6) + + def test_additive_external_regret_optimizer(self): + """Tests that the Lagrange multipliers update as expected.""" + minimization_problem = test_util.ConstantMinimizationProblem( + np.array([0.6, -0.1, 0.4])) + optimizer = AdditiveExternalRegretOptimizerWrapper( + gradient_descent.GradientDescentOptimizer(1.0), + maximum_multiplier_radius=1.0) + train_op = 
optimizer.minimize_constrained(minimization_problem) + + expected_multipliers = [ + np.array([0.0, 0.0, 0.0]), + np.array([0.6, 0.0, 0.4]), + np.array([0.7, 0.0, 0.3]), + np.array([0.8, 0.0, 0.2]), + np.array([0.9, 0.0, 0.1]), + np.array([1.0, 0.0, 0.0]), + np.array([1.0, 0.0, 0.0]), + ] + + multipliers = [] + with self.test_session() as session: + session.run(standard_ops.global_variables_initializer()) + while len(multipliers) < len(expected_multipliers): + multipliers.append(session.run(optimizer.lagrange_multipliers)) + session.run(train_op) + + for expected, actual in zip(expected_multipliers, multipliers): + self.assertAllClose(expected, actual, rtol=0, atol=1e-6) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py new file mode 100644 index 00000000000..04014ab4aeb --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py @@ -0,0 +1,595 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines `{Additive,Multiplicative}SwapRegretOptimizer`s. + +These optimizers minimize a `ConstrainedMinimizationProblem` by using a +swap-regret minimizing algorithm (either SGD or multiplicative weights) to learn +what weights should be associated with the objective function and constraints. +These algorithms do *not* use Lagrange multipliers, but the idea is similar. +The main differences between the formulation used here, and the standard +Lagrangian formulation, are that (i) the objective function is weighted, in +addition to the constraints, and (ii) we learn a matrix of weights, instead of a +vector. + +For the purposes of constrained optimization, at least in theory, +external-regret minimization suffices if the `ConstrainedMinimizationProblem` +we're optimizing doesn't have any `proxy_constraints`, while swap-regret +minimization should be used if `proxy_constraints` are present. + +For more specifics, please refer to: + +> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex +> Constrained Optimization". +> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + +The formulation used by both of the SwapRegretOptimizers can be found in +Definition 2, and is discussed in Section 4. The +`MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2 in Section 4, +with the difference being that it uses `tf.train.Optimizer`s, instead of SGD, +for the "inner" updates. The `AdditiveSwapRegretOptimizer` differs further in +that it performs additive (instead of multiplicative) updates of the stochastic +matrix. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import math + +import six + +from tensorflow.contrib.constrained_optimization.python import constrained_optimizer + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import optimizer as train_optimizer + + +def _maximal_eigenvector_power_method(matrix, + epsilon=1e-6, + maximum_iterations=100): + """Returns the maximal right-eigenvector of `matrix` using the power method. + + Args: + matrix: 2D Tensor, the matrix of which we will find the maximal + right-eigenvector. + epsilon: nonnegative float, if two iterations of the power method differ (in + L2 norm) by no more than epsilon, we will terminate. + maximum_iterations: nonnegative int, if we perform this many iterations, we + will terminate. + + Result: + The maximal right-eigenvector of `matrix`. + + Raises: + ValueError: If the epsilon or maximum_iterations parameters violate their + bounds. + """ + if epsilon <= 0.0: + raise ValueError("epsilon must be strictly positive") + if maximum_iterations <= 0: + raise ValueError("maximum_iterations must be strictly positive") + + def while_loop_condition(iteration, eigenvector, old_eigenvector): + """Returns false if the while loop should terminate.""" + not_done = (iteration < maximum_iterations) + not_converged = (standard_ops.norm(eigenvector - old_eigenvector) > epsilon) + return standard_ops.logical_and(not_done, not_converged) + + def while_loop_body(iteration, eigenvector, old_eigenvector): + """Performs one iteration of the power method.""" + del old_eigenvector # Needed by the condition, but not the body. + iteration += 1 + # We need to use tf.matmul() and tf.expand_dims(), instead of + # tf.tensordot(), since the former will infer the shape of the result, while + # the latter will not (tf.while_loop() needs the shapes). + new_eigenvector = standard_ops.matmul( + matrix, standard_ops.expand_dims(eigenvector, 1))[:, 0] + new_eigenvector /= standard_ops.norm(new_eigenvector) + return (iteration, new_eigenvector, eigenvector) + + iteration = standard_ops.constant(0) + eigenvector = standard_ops.ones_like(matrix[:, 0]) + eigenvector /= standard_ops.norm(eigenvector) + + # We actually want a do-while loop, so we explicitly call while_loop_body() + # once before tf.while_loop(). + iteration, eigenvector, old_eigenvector = while_loop_body( + iteration, eigenvector, eigenvector) + iteration, eigenvector, old_eigenvector = control_flow_ops.while_loop( + while_loop_condition, + while_loop_body, + loop_vars=(iteration, eigenvector, old_eigenvector), + name="power_method") + + return eigenvector + + +def _project_stochastic_matrix_wrt_euclidean_norm(matrix): + """Projects its argument onto the set of left-stochastic matrices. + + This algorithm is O(n^3) at worst, where `matrix` is n*n. It can be done in + O(n^2 * log(n)) time by sorting each column (and maybe better with a different + algorithm), but the algorithm implemented here is easier to implement in + TensorFlow. + + Args: + matrix: 2d square tensor, the matrix to project. + + Returns: + The 2d square tensor that results from projecting `matrix` onto the set of + left-stochastic matrices w.r.t. the Euclidean norm applied column-wise + (i.e. the Frobenius norm). 
+ + Raises: + ValueError: if the `matrix` tensor does not have a fully-known shape, or is + not two-dimensional and square. + """ + matrix_shape = matrix.get_shape() + if matrix_shape is None: + raise ValueError("matrix must have known shape") + if matrix_shape.ndims != 2: + raise ValueError( + "matrix must be two dimensional (instead is %d-dimensional)" % + matrix_shape.ndims) + if matrix_shape[0] != matrix_shape[1]: + raise ValueError("matrix must be be square (instead has shape (%d,%d))" % + (matrix_shape[0], matrix_shape[1])) + dimension = matrix_shape[0].value + if dimension is None: + raise ValueError("matrix must have fully-known shape") + + def while_loop_condition(iteration, matrix, inactive, old_inactive): + """Returns false if the while loop should terminate.""" + del matrix # Needed by the body, but not the condition. + not_done = (iteration < dimension) + not_converged = standard_ops.reduce_any( + standard_ops.not_equal(inactive, old_inactive)) + return standard_ops.logical_and(not_done, not_converged) + + def while_loop_body(iteration, matrix, inactive, old_inactive): + """Performs one iteration of the projection.""" + del old_inactive # Needed by the condition, but not the body. + iteration += 1 + scale = (1.0 - standard_ops.reduce_sum( + matrix, axis=0, keep_dims=True)) / standard_ops.maximum( + 1.0, standard_ops.reduce_sum(inactive, axis=0, keep_dims=True)) + matrix += scale * inactive + new_inactive = standard_ops.to_float(matrix > 0) + matrix *= new_inactive + return (iteration, matrix, new_inactive, inactive) + + iteration = standard_ops.constant(0) + inactive = standard_ops.ones_like(matrix) + + # We actually want a do-while loop, so we explicitly call while_loop_body() + # once before tf.while_loop(). + iteration, matrix, inactive, old_inactive = while_loop_body( + iteration, matrix, inactive, inactive) + iteration, matrix, inactive, old_inactive = control_flow_ops.while_loop( + while_loop_condition, + while_loop_body, + loop_vars=(iteration, matrix, inactive, old_inactive), + name="euclidean_projection") + + return matrix + + +def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix): + """Projects its argument onto the set of log-left-stochastic matrices. + + Args: + log_matrix: 2d square tensor, the element-wise logarithm of the matrix to + project. + + Returns: + The 2d square tensor that results from projecting exp(`matrix`) onto the set + of left-stochastic matrices w.r.t. the KL-divergence applied column-wise. + """ + + # For numerical reasons, make sure that the largest matrix element is zero + # before exponentiating. + log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keep_dims=True) + log_matrix -= standard_ops.log( + standard_ops.reduce_sum( + standard_ops.exp(log_matrix), axis=0, keep_dims=True)) + return log_matrix + + +@six.add_metaclass(abc.ABCMeta) +class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer): + """Base class representing a `_SwapRegretOptimizer`. + + This class contains most of the logic for performing constrained optimization, + minimizing external regret for the constraints player. What it *doesn't* do is + keep track of the internal state (the stochastic matrix). Instead, the state + is accessed via the _initial_state(), _stochastic_matrix(), + _constraint_grad_and_var() and _projection_op() methods. + + The reason for this is that we want to make it easy to implement different + representations of the internal state. 
For example, for additive updates, it's + most natural to store the stochastic matrix directly, whereas for + multiplicative updates, it's most natural to store its element-wise logarithm. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by `_SwapRegretOptimizer`s can be found in Definition 2, + and is discussed in Section 4. Such optimizers are most similar to Algorithm + 2 in Section 4. Most notably, the internal state is a left-stochastic matrix + of shape (m+1,m+1), where m is the number of constraints. + """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Constructs a new `_SwapRegretOptimizer`. + + The difference between `optimizer` and `constraint_optimizer` (if the latter + is provided) is that the former is used for learning the model parameters, + while the latter us used for the update to the constraint/objective weight + matrix (the analogue of Lagrange multipliers). If no `constraint_optimizer` + is provided, then `optimizer` is used for both. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multiplier analogues. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multiplier analogues. + + Returns: + A new `_SwapRegretOptimizer`. + """ + super(_SwapRegretOptimizer, self).__init__(optimizer=optimizer) + self._constraint_optimizer = constraint_optimizer + + @property + def constraint_optimizer(self): + """Returns the `tf.train.Optimizer` used for the matrix.""" + return self._constraint_optimizer + + @abc.abstractmethod + def _initial_state(self, num_constraints): + pass + + @abc.abstractmethod + def _stochastic_matrix(self, state): + pass + + def _distribution(self, state): + distribution = _maximal_eigenvector_power_method( + self._stochastic_matrix(state)) + distribution = standard_ops.abs(distribution) + distribution /= standard_ops.reduce_sum(distribution) + return distribution + + @abc.abstractmethod + def _constraint_grad_and_var(self, state, gradient): + pass + + @abc.abstractmethod + def _projection_op(self, state, name=None): + pass + + def minimize_constrained(self, + minimization_problem, + global_step=None, + var_list=None, + gate_gradients=train_optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + name=None, + grad_loss=None): + """Returns an `Op` for minimizing the constrained problem. + + The `optimizer` constructor parameter will be used to update the model + parameters, while the constraint/objective weight matrix (the analogue of + Lagrange multipliers) will be updated using `constrained_optimizer` (if + provided) or `optimizer` (if not). Whether the matrix updates are additive + or multiplicative depends on the derived class. + + Args: + minimization_problem: ConstrainedMinimizationProblem, the problem to + optimize. + global_step: as in `tf.train.Optimizer`'s `minimize` method. + var_list: as in `tf.train.Optimizer`'s `minimize` method. + gate_gradients: as in `tf.train.Optimizer`'s `minimize` method. + aggregation_method: as in `tf.train.Optimizer`'s `minimize` method. + colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize` + method. 
+ name: as in `tf.train.Optimizer`'s `minimize` method. + grad_loss: as in `tf.train.Optimizer`'s `minimize` method. + + Returns: + TensorFlow Op. + """ + objective = minimization_problem.objective + + constraints = minimization_problem.constraints + proxy_constraints = minimization_problem.proxy_constraints + if proxy_constraints is None: + proxy_constraints = constraints + # Flatten both constraints tensors to 1d. + num_constraints = minimization_problem.num_constraints + constraints = standard_ops.reshape(constraints, shape=(num_constraints,)) + proxy_constraints = standard_ops.reshape( + proxy_constraints, shape=(num_constraints,)) + + # We use a lambda to initialize the state so that, if this function call is + # inside the scope of a tf.control_dependencies() block, the dependencies + # will not be applied to the initializer. + state = standard_ops.Variable( + lambda: self._initial_state(num_constraints), + trainable=False, + name="swap_regret_optimizer_state") + + zero_and_constraints = standard_ops.concat( + (standard_ops.zeros((1,)), constraints), axis=0) + objective_and_proxy_constraints = standard_ops.concat( + (standard_ops.expand_dims(objective, 0), proxy_constraints), axis=0) + + distribution = self._distribution(state) + loss = standard_ops.tensordot(distribution, objective_and_proxy_constraints, + 1) + matrix_gradient = standard_ops.matmul( + standard_ops.expand_dims(zero_and_constraints, 1), + standard_ops.expand_dims(distribution, 0)) + + update_ops = [] + if self.constraint_optimizer is None: + # If we don't have a separate constraint_optimizer, then we use + # self._optimizer for both the update of the model parameters, and that of + # the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + grads_and_vars.append( + self._constraint_grad_and_var(state, matrix_gradient)) + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + else: + # If we have a separate constraint_optimizer, then we use self._optimizer + # for the update of the model parameters, and self._constraint_optimizer + # for that of the internal state. + grads_and_vars = self.optimizer.compute_gradients( + loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + matrix_grads_and_vars = [ + self._constraint_grad_and_var(state, matrix_gradient) + ] + + gradients = [ + gradient for gradient, _ in grads_and_vars + matrix_grads_and_vars + if gradient is not None + ] + with ops.control_dependencies(gradients): + update_ops.append( + self.optimizer.apply_gradients(grads_and_vars, name="update")) + update_ops.append( + self.constraint_optimizer.apply_gradients( + matrix_grads_and_vars, name="optimizer_state_update")) + + with ops.control_dependencies(update_ops): + if global_step is None: + # If we don't have a global step, just project, and we're done. + return self._projection_op(state, name=name) + else: + # If we have a global step, then we need to increment it in addition to + # projecting. 
+ projection_op = self._projection_op(state, name="project") + with ops.colocate_with(global_step): + global_step_op = state_ops.assign_add( + global_step, 1, name="global_step_increment") + return control_flow_ops.group(projection_op, global_step_op, name=name) + + +class AdditiveSwapRegretOptimizer(_SwapRegretOptimizer): + """A `ConstrainedOptimizer` based on swap-regret minimization. + + This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly + minimize over the model parameters, and maximize over constraint/objective + weight matrix (the analogue of Lagrange multipliers), with the latter + maximization using additive updates and an algorithm that minimizes swap + regret. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by this optimizer can be found in Definition 2, and is + discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with + the differences being that it uses `tf.train.Optimizer`s, instead of SGD, for + the "inner" updates, and performs additive (instead of multiplicative) updates + of the stochastic matrix. + """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Constructs a new `AdditiveSwapRegretOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multiplier analogues. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multiplier analogues. + + Returns: + A new `AdditiveSwapRegretOptimizer`. + """ + # TODO(acotter): add a parameter determining the initial values of the + # matrix elements (like initial_multiplier_radius in + # MultiplicativeSwapRegretOptimizer). + super(AdditiveSwapRegretOptimizer, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + + def _initial_state(self, num_constraints): + # For an AdditiveSwapRegretOptimizer, the internal state is a tensor of + # shape (m+1,m+1), where m is the number of constraints, representing a + # left-stochastic matrix. + dimension = num_constraints + 1 + # Initialize by putting all weight on the objective, and none on the + # constraints. + return standard_ops.concat( + (standard_ops.ones( + (1, dimension)), standard_ops.zeros((dimension - 1, dimension))), + axis=0) + + def _stochastic_matrix(self, state): + return state + + def _constraint_grad_and_var(self, state, gradient): + # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? + return (-gradient, state) + + def _projection_op(self, state, name=None): + with ops.colocate_with(state): + return state_ops.assign( + state, + _project_stochastic_matrix_wrt_euclidean_norm(state), + name=name) + + +class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer): + """A `ConstrainedOptimizer` based on swap-regret minimization. + + This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly + minimize over the model parameters, and maximize over constraint/objective + weight matrix (the analogue of Lagrange multipliers), with the latter + maximization using multiplicative updates and an algorithm that minimizes swap + regret. + + For more specifics, please refer to: + + > Cotter, Jiang and Sridharan. 
"Two-Player Games for Efficient Non-Convex + > Constrained Optimization". + > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500) + + The formulation used by this optimizer can be found in Definition 2, and is + discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with + the difference being that it uses `tf.train.Optimizer`s, instead of SGD, for + the "inner" updates. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + minimum_multiplier_radius=1e-3, + initial_multiplier_radius=None): + """Constructs a new `MultiplicativeSwapRegretOptimizer`. + + Args: + optimizer: tf.train.Optimizer, used to optimize the objective and + proxy_constraints portion of ConstrainedMinimizationProblem. If + constraint_optimizer is not provided, this will also be used to optimize + the Lagrange multiplier analogues. + constraint_optimizer: optional tf.train.Optimizer, used to optimize the + Lagrange multiplier analogues. + minimum_multiplier_radius: float, each element of the matrix will be lower + bounded by `minimum_multiplier_radius` divided by one plus the number of + constraints. + initial_multiplier_radius: float, the initial value of each element of the + matrix associated with a constraint (i.e. excluding those elements + associated with the objective) will be `initial_multiplier_radius` + divided by one plus the number of constraints. Defaults to the value of + `minimum_multiplier_radius`. + + Returns: + A new `MultiplicativeSwapRegretOptimizer`. + + Raises: + ValueError: If the two radius parameters are inconsistent. + """ + super(MultiplicativeSwapRegretOptimizer, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + + if (minimum_multiplier_radius <= 0.0) or (minimum_multiplier_radius >= 1.0): + raise ValueError("minimum_multiplier_radius must be in the range (0,1)") + if initial_multiplier_radius is None: + initial_multiplier_radius = minimum_multiplier_radius + elif (initial_multiplier_radius < + minimum_multiplier_radius) or (minimum_multiplier_radius > 1.0): + raise ValueError("initial_multiplier_radius must be in the range " + "[minimum_multiplier_radius,1]") + + self._minimum_multiplier_radius = minimum_multiplier_radius + self._initial_multiplier_radius = initial_multiplier_radius + + def _initial_state(self, num_constraints): + # For a MultiplicativeSwapRegretOptimizer, the internal state is a tensor of + # shape (m+1,m+1), where m is the number of constraints, representing the + # element-wise logarithm of a left-stochastic matrix. + dimension = num_constraints + 1 + # Initialize by putting as much weight as possible on the objective, and as + # little as possible on the constraints. + log_initial_one = math.log(1.0 - (self._initial_multiplier_radius * + (dimension - 1) / (dimension))) + log_initial_zero = math.log(self._initial_multiplier_radius / dimension) + return standard_ops.concat( + (standard_ops.constant( + log_initial_one, dtype=dtypes.float32, shape=(1, dimension)), + standard_ops.constant( + log_initial_zero, + dtype=dtypes.float32, + shape=(dimension - 1, dimension))), + axis=0) + + def _stochastic_matrix(self, state): + return standard_ops.exp(state) + + def _constraint_grad_and_var(self, state, gradient): + # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True? 
+ return (-gradient, state) + + def _projection_op(self, state, name=None): + with ops.colocate_with(state): + # Gets the dimension of the state (num_constraints + 1)--all of these + # assertions are of things that should be impossible, since the state + # passed into this method will have the same shape as that returned by + # _initial_state(). + state_shape = state.get_shape() + assert state_shape is not None + assert state_shape.ndims == 2 + assert state_shape[0] == state_shape[1] + dimension = state_shape[0].value + assert dimension is not None + + minimum_log_multiplier = standard_ops.log( + self._minimum_multiplier_radius / standard_ops.to_float(dimension)) + + return state_ops.assign( + state, + standard_ops.maximum( + _project_log_stochastic_matrix_wrt_kl_divergence(state), + minimum_log_multiplier), + name=name) diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py new file mode 100644 index 00000000000..34c4543dca9 --- /dev/null +++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py @@ -0,0 +1,212 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for constrained_optimization.python.swap_regret_optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.constrained_optimization.python import swap_regret_optimizer +from tensorflow.contrib.constrained_optimization.python import test_util + +from tensorflow.python.ops import standard_ops +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent + + +class AdditiveSwapRegretOptimizerWrapper( + swap_regret_optimizer.AdditiveSwapRegretOptimizer): + """Testing wrapper class around AdditiveSwapRegretOptimizer. + + This class is identical to AdditiveSwapRegretOptimizer, except that it caches + the internal optimization state when _stochastic_matrix() is called, so that + we can test that the stochastic matrices take on their expected values. 
+ """ + + def __init__(self, optimizer, constraint_optimizer=None): + """Same as AdditiveSwapRegretOptimizer.__init__().""" + super(AdditiveSwapRegretOptimizerWrapper, self).__init__( + optimizer=optimizer, constraint_optimizer=constraint_optimizer) + self._cached_stochastic_matrix = None + + @property + def stochastic_matrix(self): + """Returns the cached stochastic matrix.""" + return self._cached_stochastic_matrix + + def _stochastic_matrix(self, state): + """Caches the internal state for testing.""" + self._cached_stochastic_matrix = super(AdditiveSwapRegretOptimizerWrapper, + self)._stochastic_matrix(state) + return self._cached_stochastic_matrix + + +class MultiplicativeSwapRegretOptimizerWrapper( + swap_regret_optimizer.MultiplicativeSwapRegretOptimizer): + """Testing wrapper class around MultiplicativeSwapRegretOptimizer. + + This class is identical to MultiplicativeSwapRegretOptimizer, except that it + caches the internal optimization state when _stochastic_matrix() is called, so + that we can test that the stochastic matrices take on their expected values. + """ + + def __init__(self, + optimizer, + constraint_optimizer=None, + minimum_multiplier_radius=None, + initial_multiplier_radius=None): + """Same as MultiplicativeSwapRegretOptimizer.__init__().""" + super(MultiplicativeSwapRegretOptimizerWrapper, self).__init__( + optimizer=optimizer, + constraint_optimizer=constraint_optimizer, + minimum_multiplier_radius=1e-3, + initial_multiplier_radius=initial_multiplier_radius) + self._cached_stochastic_matrix = None + + @property + def stochastic_matrix(self): + """Returns the cached stochastic matrix.""" + return self._cached_stochastic_matrix + + def _stochastic_matrix(self, state): + """Caches the internal state for testing.""" + self._cached_stochastic_matrix = super( + MultiplicativeSwapRegretOptimizerWrapper, + self)._stochastic_matrix(state) + return self._cached_stochastic_matrix + + +class SwapRegretOptimizerTest(test.TestCase): + + def test_maximum_eigenvector_power_method(self): + """Tests power method routine on some known left-stochastic matrices.""" + matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]]) + matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]]) + + with self.test_session() as session: + eigenvector1 = session.run( + swap_regret_optimizer._maximal_eigenvector_power_method( + standard_ops.constant(matrix1))) + eigenvector2 = session.run( + swap_regret_optimizer._maximal_eigenvector_power_method( + standard_ops.constant(matrix2))) + + # Check that eigenvector1 and eigenvector2 are eigenvectors of matrix1 and + # matrix2 (respectively) with associated eigenvalue 1. 
+ matrix_eigenvector1 = np.tensordot(matrix1, eigenvector1, axes=1) + matrix_eigenvector2 = np.tensordot(matrix2, eigenvector2, axes=1) + self.assertAllClose(eigenvector1, matrix_eigenvector1, rtol=0, atol=1e-6) + self.assertAllClose(eigenvector2, matrix_eigenvector2, rtol=0, atol=1e-6) + + def test_project_stochastic_matrix_wrt_euclidean_norm(self): + """Tests Euclidean projection routine on some known values.""" + matrix = standard_ops.constant([[-0.1, -0.1, 0.4], [-0.8, 0.4, 1.2], + [-0.3, 0.1, 0.2]]) + expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], + [0.4, 0.3, 0.0]]) + + with self.test_session() as session: + projected_matrix = session.run( + swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm( + matrix)) + + self.assertAllClose( + expected_projected_matrix, projected_matrix, rtol=0, atol=1e-6) + + def test_project_log_stochastic_matrix_wrt_kl_divergence(self): + """Tests KL-divergence projection routine on some known values.""" + matrix = standard_ops.constant([[0.2, 0.8, 0.6], [0.1, 0.2, 1.5], + [0.2, 1.0, 0.9]]) + expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], + [0.4, 0.5, 0.3]]) + + with self.test_session() as session: + projected_matrix = session.run( + standard_ops.exp( + swap_regret_optimizer. + _project_log_stochastic_matrix_wrt_kl_divergence( + standard_ops.log(matrix)))) + + self.assertAllClose( + expected_projected_matrix, projected_matrix, rtol=0, atol=1e-6) + + def test_additive_swap_regret_optimizer(self): + """Tests that the stochastic matrices update as expected.""" + minimization_problem = test_util.ConstantMinimizationProblem( + np.array([0.6, -0.1, 0.4])) + optimizer = AdditiveSwapRegretOptimizerWrapper( + gradient_descent.GradientDescentOptimizer(1.0)) + train_op = optimizer.minimize_constrained(minimization_problem) + + # Calculated using a numpy+python implementation of the algorithm. + expected_matrices = [ + np.array([[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]), + np.array([[0.66666667, 1.0, 1.0, 1.0], [0.26666667, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], [0.06666667, 0.0, 0.0, 0.0]]), + np.array([[0.41666667, 0.93333333, 1.0, + 0.98333333], [0.46666667, 0.05333333, 0.0, + 0.01333333], [0.0, 0.0, 0.0, 0.0], + [0.11666667, 0.01333333, 0.0, 0.00333333]]), + ] + + matrices = [] + with self.test_session() as session: + session.run(standard_ops.global_variables_initializer()) + while len(matrices) < len(expected_matrices): + matrices.append(session.run(optimizer.stochastic_matrix)) + session.run(train_op) + + for expected, actual in zip(expected_matrices, matrices): + self.assertAllClose(expected, actual, rtol=0, atol=1e-6) + + def test_multiplicative_swap_regret_optimizer(self): + """Tests that the stochastic matrices update as expected.""" + minimization_problem = test_util.ConstantMinimizationProblem( + np.array([0.6, -0.1, 0.4])) + optimizer = MultiplicativeSwapRegretOptimizerWrapper( + gradient_descent.GradientDescentOptimizer(1.0), + initial_multiplier_radius=0.8) + train_op = optimizer.minimize_constrained(minimization_problem) + + # Calculated using a numpy+python implementation of the algorithm. 
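+    # The first entry is the initial state: with initial_multiplier_radius=0.8
+    # and a (3+1)x(3+1) matrix, the objective row starts at 1 - 0.8 * 3 / 4 = 0.4
+    # and each constraint row starts at 0.8 / 4 = 0.2.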
+    expected_matrices = [
+        np.array([[0.4, 0.4, 0.4, 0.4], [0.2, 0.2, 0.2, 0.2],
+                  [0.2, 0.2, 0.2, 0.2], [0.2, 0.2, 0.2, 0.2]]),
+        np.array([[0.36999014, 0.38528351, 0.38528351, 0.38528351], [
+            0.23517483, 0.21720297, 0.21720297, 0.21720297
+        ], [0.17774131, 0.18882719, 0.18882719, 0.18882719],
+                  [0.21709373, 0.20868632, 0.20868632, 0.20868632]]),
+        np.array([[0.33972109, 0.36811863, 0.37118462, 0.36906575], [
+            0.27114826, 0.23738228, 0.23376693, 0.23626491
+        ], [0.15712313, 0.17641793, 0.17858959, 0.17708679],
+                  [0.23200752, 0.21808115, 0.21645886, 0.21758255]]),
+    ]
+
+    matrices = []
+    with self.test_session() as session:
+      session.run(standard_ops.global_variables_initializer())
+      while len(matrices) < len(expected_matrices):
+        matrices.append(session.run(optimizer.stochastic_matrix))
+        session.run(train_op)
+
+    for expected, actual in zip(expected_matrices, matrices):
+      self.assertAllClose(expected, actual, rtol=0, atol=1e-6)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/constrained_optimization/python/test_util.py b/tensorflow/contrib/constrained_optimization/python/test_util.py
new file mode 100644
index 00000000000..704b36ca4c9
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/test_util.py
@@ -0,0 +1,58 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains helpers used by tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.constrained_optimization.python import constrained_minimization_problem
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import standard_ops
+
+
+class ConstantMinimizationProblem(
+    constrained_minimization_problem.ConstrainedMinimizationProblem):
+  """A `ConstrainedMinimizationProblem` with constant constraint violations.
+
+  This minimization problem is intended for use in performing simple tests of
+  the Lagrange multiplier (or equivalent) update in the optimizers. There is a
+  one-element "dummy" model parameter, but it should be ignored.
+  """
+
+  def __init__(self, constraints):
+    """Constructs a new `ConstantMinimizationProblem`.
+
+    Args:
+      constraints: 1d numpy array, the constant constraint violations.
+
+    Returns:
+      A new `ConstantMinimizationProblem`.
+    """
+    # We make a fake 1-parameter linear objective so that we don't get a "no
+    # variables to optimize" error.
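+    # Only the constant constraint violations drive the multiplier (or
+    # stochastic-matrix) updates in the tests; the dummy variable's value is
+    # never inspected.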
+ self._objective = standard_ops.Variable(0.0, dtype=dtypes.float32) + self._constraints = standard_ops.constant(constraints, dtype=dtypes.float32) + + @property + def objective(self): + """Returns the objective function.""" + return self._objective + + @property + def constraints(self): + """Returns the constant constraint violations.""" + return self._constraints diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index a5e065b93a2..74f2ec22ffa 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -152,6 +152,22 @@ class CrfTest(test.TestCase): self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) + def testCrfLogNormZeroSeqLength(self): + """ + Test `crf_log_norm` when `sequence_lengths` contains one or more zeros. + """ + with self.test_session() as sess: + inputs = constant_op.constant(np.ones([2, 10, 5], + dtype=np.float32)) + transition_params = constant_op.constant(np.ones([5, 5], + dtype=np.float32)) + sequence_lengths = constant_op.constant(np.zeros([2], + dtype=np.int32)) + expected_log_norm = np.zeros([2], dtype=np.float32) + log_norm = crf.crf_log_norm(inputs, sequence_lengths, transition_params) + tf_log_norm = sess.run(log_norm) + self.assertAllClose(tf_log_norm, expected_log_norm) + def testCrfLogLikelihood(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) @@ -292,10 +308,10 @@ class CrfTest(test.TestCase): dtype=np.float32)) sequence_lengths = constant_op.constant(np.zeros([2], dtype=np.int32)) - values = crf.crf_decode(inputs, transition_params, sequence_lengths) - tags, scores = sess.run(values) - self.assertEqual(len(tags.shape), 2) - self.assertEqual(len(scores.shape), 1) + tags, scores = crf.crf_decode(inputs, transition_params, sequence_lengths) + tf_tags, tf_scores = sess.run([tags, scores]) + self.assertEqual(len(tf_tags.shape), 2) + self.assertEqual(len(tf_scores.shape), 1) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index e37c029cebf..2d2cbdc1990 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -52,6 +52,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops @@ -90,9 +91,13 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0] example_inds = array_ops.reshape( math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) - return array_ops.gather_nd( + sequence_scores = array_ops.gather_nd( array_ops.squeeze(inputs, [1]), array_ops.concat([example_inds, tag_indices], axis=1)) + sequence_scores = array_ops.where(math_ops.less_equal(sequence_lengths, 0), + array_ops.zeros_like(sequence_scores), + sequence_scores) + return sequence_scores def _multi_seq_fn(): # Compute the scores of the given tag sequence. @@ -128,7 +133,12 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over # the "initial state" (the unary potentials). 
def _single_seq_fn(): - return math_ops.reduce_logsumexp(first_input, [1]) + log_norm = math_ops.reduce_logsumexp(first_input, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = array_ops.where(math_ops.less_equal(sequence_lengths, 0), + array_ops.zeros_like(log_norm), + log_norm) + return log_norm def _multi_seq_fn(): """Forward computation of alpha values.""" @@ -137,13 +147,21 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): # Compute the alpha values in the forward algorithm in order to get the # partition function. forward_cell = CrfForwardRnnCell(transition_params) + # Sequence length is not allowed to be less than zero. + sequence_lengths_less_one = math_ops.maximum( + constant_op.constant(0, dtype=sequence_lengths.dtype), + sequence_lengths - 1) _, alphas = rnn.dynamic_rnn( cell=forward_cell, inputs=rest_of_input, - sequence_length=sequence_lengths - 1, + sequence_length=sequence_lengths_less_one, initial_state=first_input, dtype=dtypes.float32) log_norm = math_ops.reduce_logsumexp(alphas, [1]) + # Mask `log_norm` of the sequences with length <= zero. + log_norm = array_ops.where(math_ops.less_equal(sequence_lengths, 0), + array_ops.zeros_like(log_norm), + log_norm) return log_norm max_seq_len = array_ops.shape(inputs)[1] @@ -479,7 +497,7 @@ def crf_decode(potentials, transition_params, sequence_length): initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] - # sequence length is not allowed to be less than zero + # Sequence length is not allowed to be less than zero. sequence_length_less_one = math_ops.maximum(0, sequence_length - 1) backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] crf_fwd_cell, diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index d68015ae156..aeefa3cee62 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -25,7 +25,7 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/eager/python:checkpointable_utils", + "//tensorflow/contrib/checkpoint/python:split_dependency", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 6fb56b08587..8285ea04926 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -54,11 +54,11 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import adagrad from tensorflow.python.training import adam -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training.checkpointable import util as checkpointable_utils CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM @@ -717,7 +717,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): inputs = 3. 
* array_ops.ones([num_applications, num_layers, input_size], dtype=dtypes.float32) cudnn_output, _ = cudnn_layer(inputs) - status.assert_consumed().run_restore_ops() + status.run_restore_ops() second_save_path = cudnn_checkpoint.save(checkpoint_prefix) restore_layer = compatible_cell_fn() restore_layer_checkpoint = checkpointable_utils.Checkpoint( @@ -728,7 +728,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase): restore_layer_output, current_state = restore_layer( inputs=3. * array_ops.ones([1, input_size]), state=current_state) - status.assert_consumed().run_restore_ops() + status.run_restore_ops() self.assertTrue(restore_layer.variables) for variable, expected_value in zip( restore_layer.variables, expected_variable_values): @@ -1072,6 +1072,17 @@ class CudnnRNNTestParamsSize(test_util.TensorFlowTestCase): class CudnnRNNTestTraining(test_util.TensorFlowTestCase): + def setUp(self): + super(CudnnRNNTestTraining, self).setUp() + self._reset_rnd_gen_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", + str(False)) + self._rnn_use_v2 = os.environ.get("TF_CUDNN_RNN_USE_V2", "0") + + def tearDown(self): + super(CudnnRNNTestTraining, self).tearDown() + os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = self._reset_rnd_gen_state + os.environ["TF_CUDNN_RNN_USE_V2"] = self._rnn_use_v2 + def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1): """Compute the numeric gradient of y wrt to x. @@ -1184,11 +1195,10 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase): def _TestOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, batch_size, seq_length, dir_count, dropout, dtype, - delta, tolerance): + use_v2, delta, tolerance): # Gradient checking runs two forward ops with almost the same input. Need to # make sure the drop patterns across the two runs are the same. logging.info("Training test with config: %s", locals()) - old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False)) os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) np.random.seed(1234) @@ -1196,6 +1206,10 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase): has_input_c = (rnn_mode == CUDNN_LSTM) direction = (CUDNN_RNN_UNIDIRECTION if dir_count == 1 else CUDNN_RNN_BIDIRECTION) + if use_v2: + os.environ["TF_CUDNN_RNN_USE_V2"] = "1" + else: + os.environ["TF_CUDNN_RNN_USE_V2"] = "0" model = CudnnTestModel( rnn_mode, num_layers, @@ -1245,22 +1259,22 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase): self._GradientCheck( sess, total_sum, all_inputs, tolerance=tolerance, delta=delta) - os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state def _TestSimpleTrainingHelper(self, rnn_mode, test_configs): dropouts = [0, 0.5, 1.] 
- for config, dropout in itertools.product(test_configs, dropouts): + v2_options = [str(False), str(True)] + for config, dropout, use_v2 in itertools.product(test_configs, dropouts, + v2_options): dtype = config.get("dtype", dtypes.float32) delta = config.get("delta", 1e-4) tolerance = config.get("tolerance", 1e-6) dir_count = config.get("dir_count", 1) shape = config["shape"] with ops.Graph().as_default(): - self._TestOneSimpleTraining(rnn_mode, shape["num_layers"], - shape["num_units"], shape["input_size"], - shape["batch_size"], shape["seq_length"], - dir_count, dropout, dtype, delta, - tolerance) + self._TestOneSimpleTraining( + rnn_mode, shape["num_layers"], shape["num_units"], + shape["input_size"], shape["batch_size"], shape["seq_length"], + dir_count, dropout, dtype, use_v2, delta, tolerance) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 00d9544602a..d58198faf35 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -358,7 +358,8 @@ class _CudnnRNN(base_layer.Layer): "CUDA/CuDNN generations.") # Initialize opaque params with a tensor. self.kernel = vs.get_variable( - "opaque_kernel", initializer=opaque_params_t, validate_shape=False) + "opaque_kernel", dtype=self._plain_dtype, + initializer=opaque_params_t, validate_shape=False) # Create saveable in the outer scope of the cudnn subgraph, such that # alternative subgraph with platform-independent rnn cells can load the # checkpoints directly. diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index c28c3a18e40..ed0a26bbd87 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -17,13 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.eager.python import checkpointable_utils +import os +from tensorflow.contrib.checkpoint.python import split_dependency from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras.engine import base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_cudnn_rnn_ops from tensorflow.python.ops import init_ops @@ -32,8 +33,8 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.training import checkpointable as checkpointable_lib from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import base as checkpointable_lib CUDNN_RNN_UNIDIRECTION = "unidirectional" CUDNN_RNN_BIDIRECTION = "bidirectional" @@ -318,7 +319,7 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): dependencies too (typically the cuDNN `Layer`). dtype: The dtype for the canonical parameter Tensors. 
""" - split_dependencies = checkpointable_utils.split_dependency( + split_dependencies = split_dependency.split_dependency( component_names=self._param_names, component_dtypes=(dtype,) * len(self._param_names), fill_save_buffer_fn=self._checkpointable_save, @@ -901,19 +902,27 @@ def _cudnn_rnn(inputs, check_direction(direction) check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) - outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn( - input=inputs, - input_h=input_h, - input_c=input_c, - params=params, - is_training=is_training, - rnn_mode=rnn_mode, - input_mode=input_mode, - direction=direction, - dropout=dropout, - seed=seed, - seed2=seed2, - name=name) + # TODO(jamesqin): switch default value to "1" on May 25th 2018, and get rid + # of V1 ops. + use_cudnn_v2 = os.environ.get("TF_CUDNN_RNN_USE_V2", "0") + args = { + "input": inputs, + "input_h": input_h, + "input_c": input_c, + "params": params, + "is_training": is_training, + "rnn_mode": rnn_mode, + "input_mode": input_mode, + "direction": direction, + "dropout": dropout, + "seed": seed, + "seed2": seed2, + "name": name + } + if use_cudnn_v2 is not "1": + outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args) + else: + outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args) return (outputs, output_h, output_c) @@ -1640,31 +1649,6 @@ class CudnnRNNRelu(_CudnnRNNNoInputC): _NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER -@ops.RegisterGradient("CudnnRNN") -def _cudnn_rnn_backward(op, *grad): - if not op.get_attr("is_training"): - raise ValueError( - "CudnnRNN must set is_training to True to be used in gradients") - return gen_cudnn_rnn_ops.cudnn_rnn_backprop( - input=op.inputs[0], - input_h=op.inputs[1], - input_c=op.inputs[2], - params=op.inputs[3], - output=op.outputs[0], - output_h=op.outputs[1], - output_c=op.outputs[2], - output_backprop=grad[0], - output_h_backprop=grad[1], - output_c_backprop=grad[2], - reserve_space=op.outputs[3], - dropout=op.get_attr("dropout"), - seed=op.get_attr("seed"), - seed2=op.get_attr("seed2"), - rnn_mode=op.get_attr("rnn_mode"), - input_mode=op.get_attr("input_mode"), - direction=op.get_attr("direction")) - - ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 637b1dc46cb..a25aa852510 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -23,6 +23,8 @@ removing existing functionality. See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Counter +@@CheckpointInputPipelineHook +@@CsvDataset @@SqlDataset @@assert_element_shape @@ -41,6 +43,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. 
@@prefetch_to_device @@read_batch_features @@rejection_resample +@@sample_from_datasets @@scan @@shuffle_and_repeat @@sliding_window_batch @@ -69,9 +72,12 @@ from tensorflow.contrib.data.python.ops.get_single_element import get_single_ele from tensorflow.contrib.data.python.ops.grouping import bucket_by_sequence_length from tensorflow.contrib.data.python.ops.grouping import group_by_window from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave +from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datasets from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave +from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device +from tensorflow.contrib.data.python.ops.readers import CsvDataset from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset from tensorflow.contrib.data.python.ops.readers import make_csv_dataset from tensorflow.contrib.data.python.ops.readers import read_batch_features diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 83ada6fb67d..7b69e10441e 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -18,6 +18,27 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "directed_interleave_dataset_op", + srcs = ["directed_interleave_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +cc_library( + name = "csv_dataset_op", + srcs = ["csv_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + cc_library( name = "ignore_errors_dataset_op", srcs = ["ignore_errors_dataset_op.cc"], @@ -52,6 +73,8 @@ cc_library( cc_library( name = "dataset_kernels", deps = [ + ":csv_dataset_op", + ":directed_interleave_dataset_op", ":ignore_errors_dataset_op", ":prefetching_kernels", ":threadpool_dataset_op", diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc new file mode 100644 index 00000000000..76e54a284e0 --- /dev/null +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -0,0 +1,508 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/parsing_ops.cc. 
+#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow/core/lib/io/random_inputstream.h" + +namespace tensorflow { +namespace { + +class CSVDatasetOp : public DatasetOpKernel { + public: + explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + OpInputList record_defaults_list; + OP_REQUIRES_OK(ctx, + ctx->input_list("record_defaults", &record_defaults_list)); + for (int i = 0; i < record_defaults_list.size(); ++i) { + OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2, + errors::InvalidArgument( + "There should only be 1 default per field but field ", i, + " has ", record_defaults_list[i].NumElements())); + } + + const Tensor* select_cols_tensor; + OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor)); + OP_REQUIRES(ctx, select_cols_tensor->dims() == 1, + errors::InvalidArgument("`select_cols` must be a vector.")); + + int64 buffer_size; + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "buffer_size", &buffer_size)); + OP_REQUIRES(ctx, buffer_size > 0, + errors::InvalidArgument("buffer_size should be positive")); + + string delim; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "field_delim", &delim)); + OP_REQUIRES(ctx, delim.size() == 1, + errors::InvalidArgument("field_delim should be only 1 char")); + + bool header; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "header", &header)); + + bool use_quote_delim; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "use_quote_delim", + &use_quote_delim)); + string na_value; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "na_value", &na_value)); + + std::vector record_defaults; + record_defaults.reserve(record_defaults_list.size()); + for (const Tensor& t : record_defaults_list) { + record_defaults.push_back(t); + } + + std::vector filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat()(i)); + } + + std::vector select_cols; + select_cols.reserve(select_cols_tensor->NumElements()); + for (int i = 0; i < select_cols_tensor->NumElements(); ++i) { + select_cols.push_back(select_cols_tensor->flat()(i)); + } + OP_REQUIRES( + ctx, output_types_.size() == select_cols.size() || select_cols.empty(), + errors::InvalidArgument("select_cols should match output size")); + for (int i = 1; i < select_cols.size(); i++) { + OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i], + errors::InvalidArgument( + "select_cols should be strictly increasing indices")); + } + OP_REQUIRES( + ctx, select_cols.empty() || select_cols.front() >= 0, + errors::InvalidArgument("select_cols should be non-negative indices")); + bool select_all_cols = select_cols.empty(); + + *output = new Dataset( + ctx, std::move(filenames), header, buffer_size, output_types_, + output_shapes_, std::move(record_defaults), std::move(select_cols), + select_all_cols, 
use_quote_delim, delim[0], std::move(na_value)); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, std::vector filenames, bool header, + int64 buffer_size, const DataTypeVector& output_types, + const std::vector& output_shapes, + std::vector record_defaults, std::vector select_cols, + bool select_all_cols, bool use_quote_delim, char delim, + string na_value) + : GraphDatasetBase(ctx), + filenames_(std::move(filenames)), + header_(header), + buffer_size_(buffer_size), + out_type_(output_types), + output_shapes_(output_shapes), + record_defaults_(std::move(record_defaults)), + select_cols_(std::move(select_cols)), + select_all_cols_(select_all_cols), + use_quote_delim_(use_quote_delim), + delim_(delim), + na_value_(std::move(na_value)) {} + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::CSV")})); + } + + const DataTypeVector& output_dtypes() const override { return out_type_; } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() override { return "CSVDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + // TODO(rachelim): Implement this + std::vector input_tensors; + TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output)); + return errors::Unimplemented("CSVDataset: AsGraphDefInternal"); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + // We are currently processing a file, so try to read the next record + if (buffered_input_stream_) { + Status s = ReadRecord(ctx, out_tensors); + if (s.ok() || !errors::IsOutOfRange(s)) { + // Not at the end of file, return OK or non-EOF errors to caller. + *end_of_sequence = false; + return s; + } + // We have reached the end of the current file, so maybe + // move on to next file. + ResetStreamsLocked(); + ++current_file_index_; + } + // Iteration ends when there are no more files to process. + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + // TODO(rachelim): Implement save + return errors::Unimplemented("CSVDataset: SaveInternal"); + } + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + // TODO(rachelim): Implement restore + return errors::Unimplemented("CSVDataset: RestoreInternal"); + } + + private: + // Reads a record by parsing the input buffer, and converting extracted + // fields to output tensors as we go. + Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // Extracts fields from line(s) from the buffered input stream. 
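+        // Each pass through the loop below parses a single field, honoring
+        // quoting when use_quote_delim_ is set, and only materializes the
+        // fields selected by select_cols_ (or all fields if none were given).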
+ out_tensors->reserve(dataset()->record_defaults_.size()); + + string input; + TF_RETURN_IF_ERROR(buffered_input_stream_->ReadLine(&input)); + + size_t current_idx = 0; + size_t num_fields_parsed = 0; + size_t selector_idx = 0; // Keep track of index into select_cols + + while (current_idx < input.size()) { + // In each iteration, parse one field + if (input[current_idx] == '\n' || input[current_idx] == '\r') { + // This should never happen, because buffered input reader splits + // input on newlines. + return errors::InvalidArgument("Parsing error."); + } + + bool quoted = false; + bool include = + (dataset()->select_all_cols_ || + dataset()->select_cols_[selector_idx] == num_fields_parsed); + + if (dataset()->use_quote_delim_ && input[current_idx] == '"') { + quoted = true; + current_idx++; + } + + // Parse the body of the field + string field; + if (!quoted) { + while (current_idx < input.size() && + input[current_idx] != dataset()->delim_) { + if ((dataset()->use_quote_delim_ && input[current_idx] == '"') || + input[current_idx] == '\n' || input[current_idx] == '\r') { + return errors::InvalidArgument( + "Unquoted fields cannot have quotes/CRLFs inside"); + } + if (include) field += input[current_idx]; + current_idx++; + } // Exit condition: end of input, or current index at delim + + // Go to next field or the end + current_idx++; + } else { + // Quoted field needs to be ended with '"' and delim or end + while (true) { + if (current_idx >= input.size() - 1 || input.empty()) { + if (current_idx == input.size() - 1 && + input[current_idx] == '"') { + // We're at the end of the input, and the quote terminates the + // record. Go to end. + current_idx++; + break; + } + // If there's no terminating quote, it means our buffered record + // line reader split a record up. This can happen if there is a + // newline encased in quotes. The next line is also part of the + // record, so we read it and reset the index. + if (include && current_idx == input.size() - 1) { + // TODO(rachelim): Instead of building up a string, keep track + // of terminal indices (or starting char* and length) + // Also look into using /lib/strings/Scanner + field += input[current_idx]; + } + if (include) { + field += '\n'; + } + current_idx = 0; + Status s = buffered_input_stream_->ReadLine(&input); + if (!s.ok()) { + return errors::InvalidArgument( + "Quoted field has to end with quote followed by delim, " + "CRLF, or EOF"); + } + } else if (input[current_idx] == '"' && + input[current_idx + 1] == dataset()->delim_) { + // End of field, go to next field or end + current_idx += 2; + break; + } else if (input[current_idx] == '"') { + // Current char is a quote. Since we're not at end of field, + // the next character must also be a quote. 
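+                // (CSV-style quote escaping, as exercised by the tests:
+                // the raw field "say ""hi""" decodes to the string
+                //   say "hi"
+                // while a lone quote in the middle of a quoted field is
+                // rejected below.)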
+                if (input[current_idx + 1] != '"') {
+                  return errors::InvalidArgument(
+                      "Quote inside a string has to be escaped by another "
+                      "quote");
+                }
+                if (include) field += '"';
+                current_idx += 2;
+              } else {
+                if (include) field += input[current_idx];
+                current_idx++;
+              }
+            }
+          }
+
+          num_fields_parsed++;
+
+          if (include) {
+            // Add the tensor to the result
+            TF_RETURN_IF_ERROR(FieldToOutput(ctx, std::move(field),
+                                             selector_idx, out_tensors));
+            selector_idx++;
+            // Terminate early if we have all the fields we want
+            if (selector_idx == dataset()->select_cols_.size())
+              return Status::OK();
+          }
+        }  // Exit condition: current_idx has reached the end of record
+
+        // Check if the last field is empty, and include it if necessary
+        bool include =
+            (dataset()->select_all_cols_ ||
+             dataset()->select_cols_[selector_idx] == num_fields_parsed);
+        if (include && !input.empty() &&
+            input[input.size() - 1] == dataset()->delim_) {
+          TF_RETURN_IF_ERROR(
+              FieldToOutput(ctx, string(), selector_idx, out_tensors));
+        }
+
+        // Check that number of fields matches
+        if (out_tensors->size() != dataset()->out_type_.size()) {
+          return errors::InvalidArgument("Expect ", dataset()->out_type_.size(),
+                                         " fields but have ",
+                                         out_tensors->size(), " in record");
+        }
+        return Status::OK();
+      }
+
+      // Given a string field, and its index in the output,
+      // converts it to a Tensor of the right type and adds it to the
+      // out_tensors vector.
+      Status FieldToOutput(IteratorContext* ctx, string field,
+                           size_t output_idx,
+                           std::vector<Tensor>* out_tensors) {
+        if (output_idx >= dataset()->out_type_.size()) {
+          // We can get here if we're selecting all columns, but the number of
+          // fields exceeds the number of defaults provided
+          return errors::InvalidArgument("Expect ", dataset()->out_type_.size(),
+                                         " fields but have more in record");
+        }
+        const DataType& dtype = dataset()->out_type_[output_idx];
+        Tensor component(ctx->allocator({}), dtype, {});
+        if ((field.empty() || field == dataset()->na_value_) &&
+            dataset()->record_defaults_[output_idx].NumElements() != 1) {
+          // If the field is empty or NA value, and default is not given,
+          // report error.
+          return errors::InvalidArgument("Field ", output_idx,
+                                         " is required but missing in record!");
+        }
+
+        switch (dtype) {
+          // For each case, if the field is empty, we use the default.
+          // Otherwise, we convert it to the right type.
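+          // For example (illustration only): with a record default of 0 for
+          // an int32 column and na_value == "NA", the fields "", "NA" and
+          // "7" produce the scalars 0, 0 and 7 respectively, while a
+          // non-numeric field such as "abc" is reported as InvalidArgument
+          // in the corresponding case below.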
+          case DT_INT32: {
+            if (field.empty() || field == dataset()->na_value_) {
+              component.scalar<int32>()() =
+                  dataset()->record_defaults_[output_idx].flat<int32>()(0);
+            } else {
+              int32 value;
+              if (!strings::safe_strto32(field, &value)) {
+                return errors::InvalidArgument(
+                    "Field ", output_idx,
+                    " in record is not a valid int32: ", field);
+              }
+              component.scalar<int32>()() = value;
+            }
+            break;
+          }
+          case DT_INT64: {
+            if (field.empty() || field == dataset()->na_value_) {
+              component.scalar<int64>()() =
+                  dataset()->record_defaults_[output_idx].flat<int64>()(0);
+            } else {
+              int64 value;
+              if (!strings::safe_strto64(field, &value)) {
+                return errors::InvalidArgument(
+                    "Field ", output_idx,
+                    " in record is not a valid int64: ", field);
+              }
+              component.scalar<int64>()() = value;
+            }
+            break;
+          }
+          case DT_FLOAT: {
+            if (field.empty() || field == dataset()->na_value_) {
+              component.scalar<float>()() =
+                  dataset()->record_defaults_[output_idx].flat<float>()(0);
+            } else {
+              float value;
+              if (!strings::safe_strtof(field.c_str(), &value)) {
+                return errors::InvalidArgument(
+                    "Field ", output_idx,
+                    " in record is not a valid float: ", field);
+              }
+              component.scalar<float>()() = value;
+            }
+            break;
+          }
+          case DT_DOUBLE: {
+            if (field.empty() || field == dataset()->na_value_) {
+              component.scalar<double>()() =
+                  dataset()->record_defaults_[output_idx].flat<double>()(0);
+            } else {
+              double value;
+              if (!strings::safe_strtod(field.c_str(), &value)) {
+                return errors::InvalidArgument(
+                    "Field ", output_idx,
+                    " in record is not a valid double: ", field);
+              }
+              component.scalar<double>()() = value;
+            }
+            break;
+          }
+          case DT_STRING: {
+            if (field.empty() || field == dataset()->na_value_) {
+              component.scalar<string>()() =
+                  dataset()->record_defaults_[output_idx].flat<string>()(0);
+            } else {
+              component.scalar<string>()() = std::move(field);
+            }
+            break;
+          }
+          default:
+            return errors::InvalidArgument("csv: data type ", dtype,
+                                           " not supported in field ",
+                                           output_idx);
+        }
+        out_tensors->push_back(std::move(component));
+        return Status::OK();
+      }
+
+      // Sets up reader streams to read from the file at `current_file_index_`.
+      Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        if (current_file_index_ >= dataset()->filenames_.size()) {
+          return errors::InvalidArgument(
+              "current_file_index_:", current_file_index_,
+              " >= filenames_.size():", dataset()->filenames_.size());
+        }
+
+        // Actually move on to next file.
+        TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
+            dataset()->filenames_[current_file_index_], &file_));
+        input_stream_.reset(
+            new io::RandomAccessInputStream(file_.get(), false));
+        // TODO(rachelim): Maintain our own buffer so we don't read every
+        // record twice
+        buffered_input_stream_.reset(new io::BufferedInputStream(
+            input_stream_.get(), dataset()->buffer_size_, false));
+        if (dataset()->header_) {
+          // Ignore header line
+          string str;
+          Status s = buffered_input_stream_->ReadLine(&str);
+          if (errors::IsOutOfRange(s)) {
+            return errors::InvalidArgument("Can't read header of empty file");
+          }
+        }
+        return Status::OK();
+      }
+
+      // Resets all reader streams.
+      void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        input_stream_.reset();
+        buffered_input_stream_.reset();
+        file_.reset();
+      }
+
+      mutex mu_;
+      std::unique_ptr<io::RandomAccessInputStream> input_stream_
+          GUARDED_BY(mu_);
+      std::unique_ptr<io::BufferedInputStream> buffered_input_stream_
+          GUARDED_BY(mu_);
+      size_t current_file_index_ GUARDED_BY(mu_) = 0;
+      std::unique_ptr<RandomAccessFile> file_
+          GUARDED_BY(mu_);  // must outlive input_stream_
+    };  // class Iterator
+
+    const std::vector<string> filenames_;
+    const bool header_;
+    const int64 buffer_size_;
+    const DataTypeVector out_type_;
+    const std::vector<PartialTensorShape> output_shapes_;
+    const std::vector<Tensor> record_defaults_;
+    const std::vector<int64> select_cols_;
+    const bool select_all_cols_;
+    const bool use_quote_delim_;
+    const char delim_;
+    const string na_value_;
+  };  // class Dataset
+
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
+};  // class CSVDatasetOp
+
+// Register the kernel implementation for CSVDataset.
+REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
new file mode 100644
index 00000000000..48d37341625
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
@@ -0,0 +1,274 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class DirectedInterleaveDatasetOp : public DatasetOpKernel {
+ public:
+  explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
+      : DatasetOpKernel(ctx) {}
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+    DatasetBase* selector_input;
+    OP_REQUIRES_OK(ctx,
+                   GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
+
+    OP_REQUIRES(
+        ctx,
+        selector_input->output_dtypes().size() == 1 &&
+            selector_input->output_dtypes()[0] == DT_INT64 &&
+            selector_input->output_shapes().size() == 1 &&
+            selector_input->output_shapes()[0].IsCompatibleWith(
+                PartialTensorShape({})),
+        errors::InvalidArgument(
+            "The selector input must be a dataset of scalar int64 elements."));
+
+    std::vector<DatasetBase*> data_inputs;
+    for (size_t i = 1; i < ctx->num_inputs(); ++i) {
+      DatasetBase* input;
+      OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
+      data_inputs.push_back(input);
+
+      OP_REQUIRES(
+          ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
+          errors::InvalidArgument(
+              "All inputs must have the same output_dtypes. First input "
+              "has types ",
+              DataTypeVectorString(data_inputs[0]->output_dtypes()),
+              ", and input ", i - 1, " has types ",
+              DataTypeVectorString(input->output_dtypes())));
+    }
+    *output = new Dataset(ctx, selector_input, std::move(data_inputs));
+  }
+
+ private:
+  class Dataset : public GraphDatasetBase {
+   public:
+    Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
+            std::vector<DatasetBase*> data_inputs)
+        : GraphDatasetBase(ctx),
+          selector_input_(selector_input),
+          data_inputs_(std::move(data_inputs)) {
+      selector_input_->Ref();
+
+      output_shapes_ = data_inputs_[0]->output_shapes();
+      data_inputs_[0]->Ref();
+      for (size_t i = 1; i < data_inputs_.size(); ++i) {
+        const DatasetBase* data_input = data_inputs_[i];
+        data_input->Ref();
+        for (size_t j = 0; j < output_shapes_.size(); ++j) {
+          output_shapes_[j] = MostSpecificCompatibleShape(
+              output_shapes_[j], data_input->output_shapes()[j]);
+        }
+      }
+    }
+
+    ~Dataset() override {
+      selector_input_->Unref();
+      for (DatasetBase* data_input : data_inputs_) {
+        data_input->Unref();
+      }
+    }
+
+    std::unique_ptr<IteratorBase> MakeIterator(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(new Iterator(
+          {this, strings::StrCat(prefix, "::DirectedInterleave")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return data_inputs_[0]->output_dtypes();
+    }
+
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return output_shapes_;
+    }
+
+    string DebugString() override {
+      return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
+    }
+
+   protected:
+    Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* selector_input_node;
+      TF_RETURN_IF_ERROR(
+          b->AddParentDataset(ctx, selector_input_, &selector_input_node));
+      std::vector<Node*> data_input_nodes(data_inputs_.size());
+      for (size_t i = 0; i < data_inputs_.size(); ++i) {
+        TF_RETURN_IF_ERROR(
+            b->AddParentDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
+      }
+      TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}},
+                                       {{1, data_input_nodes}}, {}, output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params),
+            selector_input_impl_(params.dataset->selector_input_->MakeIterator(
+                params.prefix + ".selector")),
+            num_active_inputs_(params.dataset->data_inputs_.size()) {
+        data_input_impls_.reserve(params.dataset->data_inputs_.size());
+        for (size_t i = 0; i < params.dataset->data_inputs_.size(); ++i) {
+          const DatasetBase* data_input = params.dataset->data_inputs_[i];
+          data_input_impls_.push_back(data_input->MakeIterator(
+              strings::StrCat(params.prefix, "[", i, "]")));
+        }
+      }
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);
+        if (!selector_input_impl_) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
+
+        while (true) {
+          std::vector<Tensor> selector_result;
+          *end_of_sequence = false;
+          TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
+              ctx, &selector_result, end_of_sequence));
+          if (*end_of_sequence) {
+            selector_input_impl_.reset();
+            for (auto& data_input_impl : data_input_impls_) {
+              data_input_impl.reset();
+            }
+            return Status::OK();
+          }
+
+          int64 selected_input = selector_result[0].scalar<int64>()();
+          if (selected_input < 0 ||
+              selected_input >= data_input_impls_.size()) {
+            return errors::InvalidArgument(
+                "Selector index out of range: ", selected_input,
+                " >= ", data_input_impls_.size());
+          }
+
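+          // If the chosen input still has elements, forward exactly one of
+          // them. Otherwise the selected input has already been exhausted:
+          // fall through, log the warning below, and draw the next selector
+          // value instead.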
+          if (data_input_impls_[selected_input]) {
+            bool end_of_selected_input = false;
+            TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
+                ctx, out_tensors, &end_of_selected_input));
+
+            if (!end_of_selected_input) {
+              return Status::OK();
+            }
+
+            data_input_impls_[selected_input].reset();
+            --num_active_inputs_;
+
+            if (num_active_inputs_ == 0) {
+              selector_input_impl_.reset();
+              *end_of_sequence = true;
+              return Status::OK();
+            }
+          }
+
+          LOG(WARNING) << "DirectedInterleave selected an exhausted input: "
+                       << selected_input;
+        }
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        if (selector_input_impl_) {
+          TF_RETURN_IF_ERROR(SaveParent(writer, selector_input_impl_));
+        } else {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
+        }
+        for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+          const auto& data_input_impl = data_input_impls_[i];
+          if (data_input_impl) {
+            TF_RETURN_IF_ERROR(SaveParent(writer, data_input_impl));
+          } else {
+            TF_RETURN_IF_ERROR(writer->WriteScalar(
+                full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
+                ""));
+          }
+        }
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        if (!reader->Contains(full_name("selector_input_impl_empty"))) {
+          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, selector_input_impl_));
+        } else {
+          selector_input_impl_.reset();
+        }
+        for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+          if (!reader->Contains(full_name(
+                  strings::StrCat("data_input_impl_empty[", i, "]")))) {
+            TF_RETURN_IF_ERROR(
+                RestoreParent(ctx, reader, data_input_impls_[i]));
+          } else {
+            data_input_impls_[i].reset();
+          }
+        }
+        return Status::OK();
+      }
+
+     private:
+      mutex mu_;
+      std::unique_ptr<IteratorBase> selector_input_impl_ GUARDED_BY(mu_);
+      std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
+          GUARDED_BY(mu_);
+      int64 num_active_inputs_ GUARDED_BY(mu_);
+    };
+
+    static PartialTensorShape MostSpecificCompatibleShape(
+        const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
+      PartialTensorShape output_tensorshape;
+      if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
+        return output_tensorshape;
+      auto dims1 = ts1.dim_sizes();
+      auto dims2 = ts2.dim_sizes();
+      for (int d = 0; d < ts1.dims(); d++) {
+        if (dims1[d] == dims2[d])
+          output_tensorshape.Concatenate(dims1[d]);
+        else
+          output_tensorshape.Concatenate(-1);
+      }
+      return output_tensorshape;
+    }
+
+    const DatasetBase* const selector_input_;
+    const std::vector<DatasetBase*> data_inputs_;
+    std::vector<PartialTensorShape> output_shapes_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
+                        DirectedInterleaveDatasetOp);
+
+}  // namespace
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index cf0a8bbccb5..f271d269ab1 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -17,6 +17,57 @@ limitations under the License.
 namespace tensorflow {
+REGISTER_OP("DirectedInterleaveDataset")
+    .Input("selector_input_dataset: variant")
+    .Input("data_input_datasets: N * variant")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .Attr("N: int >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
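+
+If the selector produces the index of a data input that has already been
+exhausted, that selector element is skipped (and a warning is logged) and the
+next selector element is consumed instead. Iteration ends when the selector
+dataset ends or when every data input has been exhausted.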
+ +selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines + which of the `N` data inputs should produce the next output element. +data_input_datasets: `N` datasets with the same type that will be interleaved + according to the values of `selector_input_dataset`. +)doc"); + +REGISTER_OP("CSVDataset") + .Input("filenames: string") + .Input("buffer_size: int64") + .Input("header: bool") + .Input("field_delim: string") + .Input("use_quote_delim: bool") + .Input("na_value: string") + .Input("select_cols: int64") + .Input("record_defaults: output_types") + .Output("handle: variant") + .Attr("output_types: list({float,double,int32,int64,string}) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + // `buffer_size`, `header`, `field_delim`, `use_quote_delim`, + // `na_value` must be scalars + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + // `select_cols` must be a vector + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &unused)); + // `record_defaults` must be a list of scalars...? + for (size_t i = 7; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &unused)); + } + return shape_inference::ScalarShape(c); + }); + REGISTER_OP("IgnoreErrorsDataset") .Input("input_dataset: variant") .Output("handle: variant") diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index b475c9fa6b1..2ebc80fa637 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,14 +4,16 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test", "tf_py_test") py_test( name = "batch_dataset_op_test", size = "medium", srcs = ["batch_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + ], deps = [ ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:batching", @@ -32,7 +34,7 @@ py_test( py_test( name = "bucketing_test", - size = "small", + size = "medium", srcs = ["bucketing_test.py"], srcs_version = "PY2AND3", deps = [ @@ -117,12 +119,28 @@ py_library( ], ) +py_test( + name = "csv_dataset_op_test", + size = "small", + srcs = ["csv_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:readers", + "//third_party/py/numpy", + ], +) + py_test( name = "filter_dataset_op_test", size = "small", srcs = ["filter_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + "optonly", + ], deps = [ ":dataset_serialization_test", "//tensorflow/python:array_ops", @@ -211,7 +229,11 @@ py_test( size = "medium", srcs = ["map_dataset_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + "noasan", # times out + "optonly", + ], deps = [ 
":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:error_ops", @@ -280,6 +302,7 @@ py_test( name = "reader_dataset_ops_test", size = "medium", srcs = ["reader_dataset_ops_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ @@ -294,6 +317,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", + "//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", @@ -306,15 +330,22 @@ py_test( srcs = ["resample_test.py"], shard_count = 2, srcs_version = "PY2AND3", - tags = ["noasan"], + tags = [ + "noasan", + "optonly", + ], deps = [ "//tensorflow/contrib/data/python/ops:resampling", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", ], ) @@ -333,6 +364,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", "//third_party/py/numpy", ], ) @@ -396,6 +428,7 @@ py_test( srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -471,12 +504,11 @@ py_test( ], ) -py_test( +cuda_py_test( name = "prefetching_ops_test", size = "small", srcs = ["prefetching_ops_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ "//tensorflow/contrib/data/python/ops:prefetching_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", @@ -507,3 +539,23 @@ tf_py_test( "//third_party/py/numpy", ], ) + +tf_py_test( + name = "writer_ops_test", + size = "small", + srcs = ["writer_ops_test.py"], + additional_deps = [ + "//tensorflow/contrib/data/python/ops:writers", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:lib", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:readers", + ], +) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 413d8737978..2568b899d7e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -18,15 +18,18 @@ from __future__ import division from __future__ import print_function import math +import time import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from 
tensorflow.python.ops import array_ops @@ -34,6 +37,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test +from tensorflow.python.util import compat class BatchDatasetTest(test.TestCase): @@ -151,6 +155,69 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(op) + def testUnbatchDatasetWithStrings(self): + data = tuple([math_ops.range(10) for _ in range(3)]) + data = dataset_ops.Dataset.from_tensor_slices(data) + data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z)) + expected_types = (dtypes.int32, dtypes.string, dtypes.int32) + data = data.batch(2) + self.assertEqual(expected_types, data.output_types) + data = data.apply(batching.unbatch()) + self.assertEqual(expected_types, data.output_types) + + iterator = data.make_one_shot_iterator() + op = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op)) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(op) + + def testUnbatchDatasetWithSparseTensor(self): + st = sparse_tensor.SparseTensorValue( + indices=[[i, i] for i in range(10)], + values=list(range(10)), + dense_shape=[10, 10]) + data = dataset_ops.Dataset.from_tensors(st) + data = data.apply(batching.unbatch()) + data = data.batch(5) + data = data.apply(batching.unbatch()) + iterator = data.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + st_row = sess.run(next_element) + self.assertEqual([i], st_row.indices) + self.assertEqual([i], st_row.values) + self.assertEqual([10], st_row.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testUnbatchDatasetWithDenseAndSparseTensor(self): + st = sparse_tensor.SparseTensorValue( + indices=[[i, i] for i in range(10)], + values=list(range(10)), + dense_shape=[10, 10]) + data = dataset_ops.Dataset.from_tensors((list(range(10)), st)) + data = data.apply(batching.unbatch()) + data = data.batch(5) + data = data.apply(batching.unbatch()) + iterator = data.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(10): + dense_elem, st_row = sess.run(next_element) + self.assertEqual(i, dense_elem) + self.assertEqual([i], st_row.indices) + self.assertEqual([i], st_row.values) + self.assertEqual([10], st_row.dense_shape) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testUnbatchSingleElementTupleDataset(self): data = tuple([(math_ops.range(10),) for _ in range(3)]) data = dataset_ops.Dataset.from_tensor_slices(data) @@ -191,6 +258,53 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(op) + def testUnbatchEmpty(self): + data = dataset_ops.Dataset.from_tensors( + (constant_op.constant([]), constant_op.constant([], shape=[0, 4]), + constant_op.constant([], shape=[0, 4, 0]))) + data = data.apply(batching.unbatch()) + iterator = data.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testUnbatchStaticShapeMismatch(self): + data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8), + np.arange(9))) + with self.assertRaises(ValueError): + data.apply(batching.unbatch()) + + def 
testUnbatchDynamicShapeMismatch(self): + ph1 = array_ops.placeholder(dtypes.int32, shape=[None]) + ph2 = array_ops.placeholder(dtypes.int32, shape=None) + data = dataset_ops.Dataset.from_tensors((ph1, ph2)) + data = data.apply(batching.unbatch()) + iterator = data.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + # Mismatch in the 0th dimension. + sess.run( + iterator.initializer, + feed_dict={ + ph1: np.arange(7).astype(np.int32), + ph2: np.arange(8).astype(np.int32) + }) + with self.assertRaises(errors.InvalidArgumentError): + print(sess.run(next_element)) + + # No 0th dimension (i.e. scalar value) for one component. + sess.run( + iterator.initializer, + feed_dict={ + ph1: np.arange(7).astype(np.int32), + ph2: 7 + }) + with self.assertRaises(errors.InvalidArgumentError): + print(sess.run(next_element)) + def testBatchAndDropRemainder(self): components = (np.arange(7), np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], @@ -313,7 +427,9 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def _testMapAndBatchDatasetHelper(self, num_parallel_batches=1): + def _testMapAndBatchDatasetHelper(self, + num_parallel_calls=None, + num_parallel_batches=None): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size). @@ -332,6 +448,7 @@ class BatchDatasetTest(test.TestCase): batching.map_and_batch( map_func=_map_fn, batch_size=batch_size, + num_parallel_calls=num_parallel_calls, num_parallel_batches=num_parallel_batches)) .make_initializable_iterator()) init_op = iterator.initializer @@ -383,12 +500,18 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def testMapAndBatchDataset(self): + def testMapAndBatch(self): return self._testMapAndBatchDatasetHelper() - def testMapAndBatchDatasetWithParallelBatching(self): + def testMapAndBatchWithParallelBatches(self): return self._testMapAndBatchDatasetHelper(num_parallel_batches=10) + def testMapAndBatchWithSequentialCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=1) + + def testMapAndBatchWithParallelCalls(self): + return self._testMapAndBatchDatasetHelper(num_parallel_calls=2) + def _testMapAndBatchPartialBatchHelper(self, drop_remainder=False): iterator = ( dataset_ops.Dataset.range(10).apply( @@ -516,9 +639,7 @@ class BatchDatasetSerializationTest( lambda x: array_ops.fill([x], x)).apply( batching.dense_to_sparse_batch(4, [12])) - # TODO(b/70988345): Re-enable when sparse tensors are properly supported by - # the DatasetSerializationTestBase. 
- def _testDenseToSparseBatchDatasetCore(self): + def testDenseToSparseBatchDatasetCore(self): components = np.random.randint(5, size=(40,)).astype(np.int32) diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) @@ -545,6 +666,86 @@ class BatchDatasetSerializationTest( self.run_core_tests(self._build_dataset_nested_sparse, None, 1) +class UnbatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch( + batch_size).apply(batching.unbatch()) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + +class MapAndBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testNumParallelBatches(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_batches = 2 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + def testNumParallelCalls(self): + range_size = 11 + num_repeats = 2 + batch_size = 5 + total_outputs = range_size * num_repeats + num_outputs_drop_remainder = total_outputs // batch_size + num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size)) + num_parallel_calls = 7 + + def build_ds(range_start, drop_remainder=False): + + def _map_fn(x): + return math_ops.square(x) + + return dataset_ops.Dataset.range( + range_start, range_start + range_size).repeat(num_repeats).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_calls=num_parallel_calls, + drop_remainder=drop_remainder)) + + self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15), + num_outputs_keep_remainder) + self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True), + num_outputs_drop_remainder) + + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): @@ -586,10 +787,12 @@ class RestructuredDatasetTest(test.TestCase): def test_assert_element_shape(self): def create_unknown_shape_dataset(x): - return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32), - np.zeros((3, 4), dtype=np.int32)), - [x], - [dtypes.float32, dtypes.int32]) + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset) 
unknown_shapes = (tensor_shape.TensorShape(None), @@ -626,10 +829,12 @@ class RestructuredDatasetTest(test.TestCase): def test_assert_wrong_element_shape_on_unknown_shape_dataset(self): def create_unknown_shape_dataset(x): - return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32), - np.zeros((3, 4), dtype=np.int32)), - [x], - [dtypes.float32, dtypes.int32]) + return script_ops.py_func( + lambda _: ( # pylint: disable=g-long-lambda + np.ones(2, dtype=np.float32), + np.zeros((3, 4), dtype=np.int32)), + [x], + [dtypes.float32, dtypes.int32]) dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset) unknown_shapes = (tensor_shape.TensorShape(None), @@ -649,5 +854,77 @@ class RestructuredDatasetTest(test.TestCase): sess.run(get_next) +class UnbatchDatasetBenchmark(test.Benchmark): + + def benchmarkNativeUnbatch(self): + batch_sizes = [1, 2, 5, 10, 20, 50] + elems_per_trial = 10000 + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors("element").repeat(None) + batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = dataset.batch(batch_size_placeholder) + dataset = dataset.apply(batching.unbatch()) + dataset = dataset.skip(elems_per_trial) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for batch_size in batch_sizes: + deltas = [] + for _ in range(5): + sess.run( + iterator.initializer, + feed_dict={batch_size_placeholder: batch_size}) + start = time.time() + sess.run(next_element.op) + end = time.time() + deltas.append((end - start) / elems_per_trial) + + median_wall_time = np.median(deltas) + print("Unbatch (native) batch size: %d Median wall time per element:" + " %f microseconds" % (batch_size, median_wall_time * 1e6)) + self.report_benchmark( + iters=10000, + wall_time=median_wall_time, + name="benchmark_unbatch_dataset_native_batch_size_%d" % + batch_size) + + # Include a benchmark of the previous `unbatch()` implementation that uses + # a composition of more primitive ops. Eventually we'd hope to generate code + # that is as good in both cases. 
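+  # Concretely, the "unfused" pipeline built below is batch(batch_size)
+  # followed by flat_map(Dataset.from_tensor_slices), which is the
+  # composition that unbatch() previously expanded to.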
+ def benchmarkOldUnbatchImplementation(self): + batch_sizes = [1, 2, 5, 10, 20, 50] + elems_per_trial = 10000 + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors("element").repeat(None) + batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) + dataset = dataset.batch(batch_size_placeholder) + dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices) + dataset = dataset.skip(elems_per_trial) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for batch_size in batch_sizes: + deltas = [] + for _ in range(5): + sess.run( + iterator.initializer, + feed_dict={batch_size_placeholder: batch_size}) + start = time.time() + sess.run(next_element.op) + end = time.time() + deltas.append((end - start) / elems_per_trial) + + median_wall_time = np.median(deltas) + print("Unbatch (unfused) batch size: %d Median wall time per element:" + " %f microseconds" % (batch_size, median_wall_time * 1e6)) + self.report_benchmark( + iters=10000, + wall_time=median_wall_time, + name="benchmark_unbatch_dataset_unfused_batch_size_%d" % + batch_size) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 6002cc73c8b..bd3e034211c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -35,6 +36,179 @@ from tensorflow.python.ops import string_ops from tensorflow.python.platform import test +class GroupByReducerTest(test.TestCase): + + def checkResults(self, dataset, shapes, values): + self.assertEqual(shapes, dataset.output_shapes) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + for expected in values: + got = sess.run(get_next) + self.assertEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSum(self): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).apply( + grouping.group_by_reducer(lambda x: x % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) + + def testAverage(self): + + def reduce_fn(x, y): + return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / ( + x[1] + 1), x[1] + 1 + + reducer = grouping.Reducer( + init_func=lambda _: (0.0, 0.0), + reduce_func=reduce_fn, + finalize_func=lambda x: x[0]) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).apply( + grouping.group_by_reducer( + lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[i - 1, i]) + + def testConcat(self): + components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray) + reducer = grouping.Reducer( + init_func=lambda x: "", + reduce_func=lambda x, y: x + y[0], + finalize_func=lambda x: x) + for i in range(1, 11): + 
dataset = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensor_slices(components), + dataset_ops.Dataset.range(2 * i))).apply( + grouping.group_by_reducer(lambda x, y: y % 2, reducer)) + self.checkResults( + dataset, + shapes=tensor_shape.scalar(), + values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]]) + + def testSparseSum(self): + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1], dtype=np.int64)), + dense_shape=np.array([1, 1])) + + reducer = grouping.Reducer( + init_func=lambda _: _sparse(np.int64(0)), + reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]), + finalize_func=lambda x: x.values[0]) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply( + grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) + + def testChangingStateShape(self): + + def reduce_fn(x, _): + # Statically known rank, but dynamic length. + larger_dim = array_ops.concat([x[0], x[0]], 0) + # Statically unknown rank. + larger_rank = array_ops.expand_dims(x[1], 0) + return larger_dim, larger_rank + + reducer = grouping.Reducer( + init_func=lambda x: ([0], 1), + reduce_func=reduce_fn, + finalize_func=lambda x: x) + + for i in range(1, 11): + dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply( + grouping.group_by_reducer(lambda x: x, reducer)) + self.assertEqual([None], dataset.output_shapes[0].as_list()) + self.assertIs(None, dataset.output_shapes[1].ndims) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + x, y = sess.run(get_next) + self.assertAllEqual([0] * (2**i), x) + self.assertAllEqual(np.array(1, ndmin=i), y) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testTypeMismatch(self): + reducer = grouping.Reducer( + init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32), + reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64), + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The element types for the new state must match the initial state."): + dataset.apply( + grouping.group_by_reducer(lambda _: np.int64(0), reducer)) + + # TODO(b/78665031): Remove once non-scalar keys are supported. + def testInvalidKeyShape(self): + reducer = grouping.Reducer( + init_func=lambda x: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + ValueError, "`key_func` must return a single tf.int64 tensor."): + dataset.apply( + grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer)) + + # TODO(b/78665031): Remove once non-int64 keys are supported. 
+ def testInvalidKeyType(self): + reducer = grouping.Reducer( + init_func=lambda x: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + ValueError, "`key_func` must return a single tf.int64 tensor."): + dataset.apply( + grouping.group_by_reducer(lambda _: "wrong", reducer)) + + +class GroupByReducerSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + return dataset_ops.Dataset.from_tensor_slices(components).apply( + grouping.group_by_reducer(lambda x: x % 5, reducer)) + + def testCoreGroupByReducer(self): + components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 5, + verify_exhausted=True) + + class GroupByWindowTest(test.TestCase): def testSimple(self): @@ -61,7 +235,7 @@ class GroupByWindowTest(test.TestCase): self.assertEqual(len(components), sum(counts)) num_full_batches = len([c for c in counts if c == 4]) - self.assertGreaterEqual(num_full_batches, 23) + self.assertGreaterEqual(num_full_batches, 24) self.assertTrue(all(c == 4 for c in counts[:num_full_batches])) def testImmediateOutput(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py new file mode 100644 index 00000000000..641a389c033 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -0,0 +1,378 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for CsvDatasetOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile +import time + +import numpy as np + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.client import session +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + + +class CsvDatasetOpTest(test.TestCase): + + def _assert_datasets_equal(self, g, ds1, ds2): + assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, ' + '%s') % (ds1.output_shapes, + ds2.output_shapes) + assert ds1.output_types == ds2.output_types + assert ds1.output_classes == ds2.output_classes + next1 = ds1.make_one_shot_iterator().get_next() + next2 = ds2.make_one_shot_iterator().get_next() + with self.test_session(graph=g) as sess: + # Run through datasets and check that outputs match, or errors match. + while True: + try: + op1 = sess.run(next1) + except (errors.OutOfRangeError, ValueError) as e: + # If op1 throws an exception, check that op2 throws same exception. + with self.assertRaises(type(e)): + sess.run(next2) + break + op2 = sess.run(next2) + self.assertAllEqual(op1, op2) + + def setup_files(self, inputs): + filenames = [] + for i, ip in enumerate(inputs): + fn = os.path.join(self.get_temp_dir(), 'temp_%d.txt' % i) + with open(fn, 'w') as f: + f.write('\n'.join(ip)) + filenames.append(fn) + return filenames + + def _make_test_datasets(self, inputs, **kwargs): + # Test by comparing its output to what we could get with map->decode_csv + filenames = self.setup_files(inputs) + dataset_expected = core_readers.TextLineDataset(filenames) + dataset_expected = dataset_expected.map( + lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) + dataset_actual = readers.CsvDataset(filenames, **kwargs) + return (dataset_actual, dataset_expected) + + def _test_by_comparison(self, inputs, **kwargs): + """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv).""" + with ops.Graph().as_default() as g: + dataset_actual, dataset_expected = self._make_test_datasets( + inputs, **kwargs) + self._assert_datasets_equal(g, dataset_actual, dataset_expected) + + def _test_dataset(self, + inputs, + expected_output=None, + expected_err_re=None, + **kwargs): + """Checks that elements produced by CsvDataset match expected output.""" + # Convert str type because py3 tf strings are bytestrings + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, **kwargs) + nxt = dataset.make_one_shot_iterator().get_next() + if expected_err_re is None: + # Verify that output is expected, without errors + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = sess.run(nxt) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + else: + # Verify that OpError is produced as expected + with 
self.assertRaisesOpError(expected_err_re): + while True: + try: + sess.run(nxt) + except errors.OutOfRangeError: + break + + def testCsvDataset_floatRequired(self): + record_defaults = [[]] * 4 + inputs = [['1,2,3,4']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_int(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_float(self): + record_defaults = [[0.0]] * 4 + inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_string(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"hello, it is me",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_mixedTypes(self): + record_defaults = [ + constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.float32), + constant_op.constant([], dtype=dtypes.string), + constant_op.constant([], dtype=dtypes.float64) + ] + inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withUseQuoteDelimFalse(self): + record_defaults = [['']] * 4 + inputs = [['1,2,"3,4"', '"5,6",7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) + + def testCsvDataset_withFieldDelim(self): + record_defaults = [[0]] * 4 + inputs = [['1:2:3:4', '5:6:7:8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, field_delim=':') + + def testCsvDataset_withEmptyValues(self): + record_defaults = [[0]] * 4 + inputs = [['1,,3,4', ',6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withNaValue(self): + record_defaults = [[0]] * 4 + inputs = [['1,NA,3,4', 'NA,6,7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, na_value='NA') + + def testCsvDataset_withSelectCols(self): + record_defaults = [[0]] * 2 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, select_cols=[1, 2]) + + def testCsvDataset_withSelectColsTooHigh(self): + record_defaults = [[0]] * 2 + inputs = [['1,2,3,4', '5,6,7,8']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 1 in record', + record_defaults=record_defaults, + select_cols=[3, 4]) + + def testCsvDataset_withMultipleFiles(self): + record_defaults = [[0]] * 4 + inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', 
'"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withLeadingAndTrailingSpaces(self): + record_defaults = [[0.0]] * 4 + inputs = [['0, 1, 2, 3']] + expected = [[0.0, 1.0, 2.0, 3.0]] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithMissingDefault(self): + record_defaults = [[]] * 2 + inputs = [['0,']] + self._test_dataset( + inputs, + expected_err_re='Field 1 is required but missing in record!', + record_defaults=record_defaults) + + def testCsvDataset_errorWithFewerDefaultsThanFields(self): + record_defaults = [[0.0]] * 2 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have more in record', + record_defaults=record_defaults) + + def testCsvDataset_errorWithMoreDefaultsThanFields(self): + record_defaults = [[0.0]] * 5 + inputs = [['0,1,2,3']] + self._test_dataset( + inputs, + expected_err_re='Expect 5 fields but have 4 in record', + record_defaults=record_defaults) + + def testCsvDataset_withHeader(self): + record_defaults = [[0]] * 2 + inputs = [['col1,col2', '1,2']] + expected = [[1, 2]] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withHeaderAndNoRecords(self): + record_defaults = [[0]] * 2 + inputs = [['col1,col2']] + expected = [] + self._test_dataset( + inputs, + expected, + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_errorWithHeaderEmptyFile(self): + record_defaults = [[0]] * 2 + inputs = [[]] + self._test_dataset( + inputs, + expected_err_re="Can't read header of empty file", + record_defaults=record_defaults, + header=True, + ) + + def testCsvDataset_withEmptyFile(self): + record_defaults = [['']] * 2 + inputs = [['']] # Empty file + self._test_dataset( + inputs, expected_output=[], record_defaults=record_defaults) + + def testCsvDataset_errorWithEmptyRecord(self): + record_defaults = [['']] * 2 + inputs = [['', '1,2']] # First record is empty + self._test_dataset( + inputs, + expected_err_re='Expect 2 fields but have 0 in record', + record_defaults=record_defaults) + + def testCsvDataset_withChainedOps(self): + # Testing that one dataset can create multiple iterators fine. + # `repeat` creates multiple iterators from the same C++ Dataset. + record_defaults = [[0]] * 4 + inputs = [['1,,3,4', '5,6,,8']] + ds_actual, ds_expected = self._make_test_datasets( + inputs, record_defaults=record_defaults) + with ops.Graph().as_default() as g: + self._assert_datasets_equal(g, + ds_actual.repeat(5).prefetch(1), + ds_expected.repeat(5).prefetch(1)) + + def testCsvDataset_withTypeDefaults(self): + # Testing using dtypes as record_defaults for required fields + record_defaults = [dtypes.float32, dtypes.float32] + inputs = [['1.0,2.0', '3.0,4.0']] + self._test_dataset( + inputs, + [[1.0, 2.0], [3.0, 4.0]], + record_defaults=record_defaults, + ) + + +class CsvDatasetBenchmark(test.Benchmark): + """Benchmarks for the various ways of creating a dataset from CSV files. + """ + + def _setUp(self): + # Since this isn't test.TestCase, have to manually create a test dir + gfile.MakeDirs(googletest.GetTempDir()) + self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir()) + + self._num_cols = [4, 64, 256] + self._batch_size = 500 + self._filenames = [] + for n in self._num_cols: + fn = os.path.join(self._temp_dir, 'file%d.csv' % n) + with open(fn, 'w') as f: + # Just write 10 rows and use `repeat`... 
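+        # (This keeps the on-disk fixtures small; both benchmarks call
+        # `repeat()` on the dataset, so the iterator never runs out of
+        # records during timing.)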
+ row = ','.join(['1.23456E12' for _ in range(n)]) + f.write('\n'.join([row for _ in range(10)])) + self._filenames.append(fn) + + def _tearDown(self): + gfile.DeleteRecursively(self._temp_dir) + + def _runBenchmark(self, dataset, num_cols, prefix): + next_element = dataset.make_one_shot_iterator().get_next() + with session.Session() as sess: + for _ in range(5): + sess.run(next_element) + deltas = [] + for _ in range(10): + start = time.time() + sess.run(next_element) + end = time.time() + deltas.append(end - start) + median_wall_time = np.median(deltas) / 100 + print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols, + median_wall_time)) + self.report_benchmark( + iters=self._batch_size, + wall_time=median_wall_time, + name='%s_with_cols_%d' % (prefix, num_cols)) + + def benchmarkBatchThenMap(self): + self._setUp() + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + dataset = dataset.batch(self._batch_size) + self._runBenchmark(dataset, num_cols, 'csv_map_then_batch') + self._tearDown() + + def benchmarkCsvDataset(self): + self._setUp() + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [[0.0]] * num_cols} + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + dataset = dataset.batch(self._batch_size) + self._runBenchmark(dataset, num_cols, 'csv_fused_dataset') + self._tearDown() + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index dbc35097ddd..78ecce8f7da 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -163,7 +163,7 @@ class DatasetSerializationTestBase(test.TestCase): num_outputs, sparse_tensors=False, verify_exhausted=True): - """Verifies that restoring into an already initilized iterator works. + """Verifies that restoring into an already initialized iterator works. Args: ds_fn: See `run_core_tests`. 
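The serialization test base above verifies restoring into an already initialized iterator; a rough sketch of the user-facing pattern, assuming the contrib helper `make_saveable_from_iterator` is available in this version:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

# Register the iterator state with the Saver so checkpoints capture it.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  for _ in range(10):
    sess.run(next_element)
  ckpt_path = saver.save(sess, '/tmp/iterator_ckpt')  # hypothetical path

with tf.Session() as sess:
  sess.run(iterator.initializer)   # iterator starts from scratch...
  saver.restore(sess, ckpt_path)   # ...then the checkpoint rewinds it to element 10
  print(sess.run(next_element))    # 10
```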
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 256ad8d94dc..43aa4b1bd02 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -30,6 +30,7 @@ from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -94,6 +95,76 @@ class InterleaveDatasetSerializationTest( self.run_core_tests(_build_dataset, None, 20) +class ParallelInterleaveDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.input_values = np.array([4, 5, 6], dtype=np.int64) + self.num_repeats = 2 + self.num_outputs = np.sum(self.input_values) * 2 + + def _build_ds(self, cycle_length, block_length, sloppy=False): + return (dataset_ops.Dataset.from_tensor_slices( + self.input_values).repeat(self.num_repeats).apply( + interleave_ops.parallel_interleave( + lambda x: dataset_ops.Dataset.range(10 * x, 11 * x), + cycle_length, block_length, sloppy))) + + def testSerializationCore(self): + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + self.run_core_tests( + lambda: self._build_ds(cycle_length, block_length), + lambda: self._build_ds(cycle_length * 2, block_length * 1), + self.num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests(lambda: self._build_ds(cycle_length, block_length), + None, self.num_outputs) + + def testSerializationWithSloppy(self): + break_points = self.gen_break_points(self.num_outputs, 10) + expected_outputs = np.repeat( + np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]), + self.num_repeats).tolist() + + def run_test(cycle_length, block_length): + actual = self.gen_outputs( + lambda: self._build_ds(cycle_length, block_length, True), + break_points, self.num_outputs) + self.assertSequenceEqual(sorted(actual), expected_outputs) + + # cycle_length > 1, block_length > 1 + run_test(2, 3) + # cycle_length = 1 + run_test(1, 3) + # block_length = 1 + run_test(2, 1) + + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).apply( + interleave_ops.parallel_interleave(_interleave_fn, 1)) + + self.run_core_tests(_build_dataset, None, 20) + + class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): @@ -338,7 +409,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContentionWithRaces(self, sloppy=False): """Tests where all the workers race in producing elements. 
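The serialization and sloppiness tests above exercise `parallel_interleave`; a short sketch of the transformation itself (sizes are illustrative):

```python
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops

# Pull from up to cycle_length inner datasets at once, taking block_length
# consecutive elements from each before moving on. sloppy=True lets whichever
# worker is ready produce next, trading deterministic order for throughput.
dataset = dataset_ops.Dataset.range(4).apply(
    interleave_ops.parallel_interleave(
        lambda x: dataset_ops.Dataset.range(10 * x, 10 * x + 5),
        cycle_length=2,
        block_length=3,
        sloppy=False))

next_element = dataset.make_one_shot_iterator().get_next()
```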
- Note: this is in contrast with the prevous test which carefully sequences + Note: this is in contrast with the previous test which carefully sequences the execution of the map functions. Args: @@ -424,7 +495,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False): """Tests where all the workers race in producing elements. - Note: this is in contrast with the prevous test which carefully sequences + Note: this is in contrast with the previous test which carefully sequences the execution of the map functions. @@ -836,5 +907,114 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(self.next_element) +class DirectedInterleaveDatasetTest(test.TestCase): + + def testBasic(self): + selector_dataset = dataset_ops.Dataset.range(10).repeat(100) + input_datasets = [ + dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) + ] + dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, + input_datasets) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(100): + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def _normalize(self, vec): + return vec / vec.sum() + + def _chi2(self, expected, actual): + actual = np.asarray(actual) + expected = np.asarray(expected) + diff = actual - expected + chi2 = np.sum(diff * diff / expected, axis=0) + return chi2 + + def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): + # Create a dataset that samples each integer in `[0, num_datasets)` + # with probability given by `weights[i]`. + dataset = interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(num_datasets) + ], weights) + dataset = dataset.take(num_samples) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + freqs = np.zeros([num_datasets]) + for _ in range(num_samples): + freqs[sess.run(next_element)] += 1 + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + return freqs + + def testSampleFromDatasets(self): + random_seed.set_random_seed(1619) + num_samples = 10000 + rand_probs = self._normalize(np.random.random_sample((15,))) + + # Use chi-squared test to assert that the observed distribution matches the + # expected distribution. Based on the implementation in + # "tensorflow/python/kernel_tests/multinomial_op_test.py". + for probs in [[.85, .05, .1], rand_probs]: + probs = np.asarray(probs) + classes = len(probs) + freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) + + # Also check that `weights` as a dataset samples correctly. 
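A minimal sketch of `sample_from_datasets`, the API these distribution checks cover (weights and seed are illustrative):

```python
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops

datasets = [dataset_ops.Dataset.from_tensors(i).repeat() for i in range(3)]

# Draw elements from the three constituent datasets with probabilities
# 0.6 / 0.3 / 0.2; weights must be float32/float64, and may also be supplied
# as a dataset of weight vectors, as the tests above exercise.
sampled = interleave_ops.sample_from_datasets(
    datasets, weights=[0.6, 0.3, 0.1], seed=1619).take(100)

next_element = sampled.make_one_shot_iterator().get_next()
```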
+ probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() + freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) + + def testErrors(self): + with self.assertRaisesRegexp(ValueError, + r"vector of length `len\(datasets\)`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[0.25, 0.25, 0.25, 0.25]) + + with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[1, 1]) + + with self.assertRaisesRegexp(TypeError, "must have the same type"): + interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(0.0) + ]) + + +class SampleFromDatasetsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, probs, num_samples): + dataset = interleave_ops.sample_from_datasets( + [ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(len(probs)) + ], + probs, + seed=1813) + return dataset.take(num_samples) + + def testSerializationCore(self): + self.run_core_tests( + lambda: self._build_dataset([0.5, 0.5], 100), + lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 1075302bae9..e0237198b7d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import ops from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -256,6 +257,29 @@ class TFRecordDatasetSerializationTest( lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) +def _interleave(iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + class ReadBatchFeaturesTest(test.TestCase): def setUp(self): @@ -355,8 +379,8 @@ class ReadBatchFeaturesTest(test.TestCase): yield j, i def _next_record_interleaved(file_indices, cycle_length): - return self._interleave([_next_record([i]) for i in file_indices], - cycle_length) + return _interleave([_next_record([i]) for i in file_indices], + cycle_length) file_batch = [] keywords_batch_indices = [] @@ -397,28 +421,6 @@ class ReadBatchFeaturesTest(test.TestCase): [len(file_batch), keywords_batch_max_len], record_batch ] - def _interleave(self, iterators, cycle_length): - pending_iterators = iterators - open_iterators = [] - num_open = 0 - for i in range(cycle_length): - if pending_iterators: - open_iterators.append(pending_iterators.pop(0)) - num_open += 1 - - 
while num_open: - for i in range(min(cycle_length, len(open_iterators))): - if open_iterators[i] is None: - continue - try: - yield next(open_iterators[i]) - except StopIteration: - if pending_iterators: - open_iterators[i] = pending_iterators.pop(0) - else: - open_iterators[i] = None - num_open -= 1 - def _verify_records(self, sess, batch_size, @@ -620,14 +622,12 @@ class MakeCsvDatasetTest(test.TestCase): f.close() return fn - def _create_file(self, fileno, header=True, comment=True): + def _create_file(self, fileno, header=True): rows = [] if header: rows.append(self.COLUMNS) for recno in range(self._num_records): rows.append(self._csv_values(fileno, recno)) - if comment: - rows.append("# Some comment goes here. Ignore me.") return self._write_file("csv_file%d.csv" % fileno, rows) def _create_files(self): @@ -648,9 +648,7 @@ class MakeCsvDatasetTest(test.TestCase): shuffle=False, shuffle_seed=None, header=True, - comment="#", na_value="", - default_float_type=dtypes.float32, ): return readers.make_csv_dataset( filenames, @@ -662,9 +660,7 @@ class MakeCsvDatasetTest(test.TestCase): shuffle=shuffle, shuffle_seed=shuffle_seed, header=header, - comment=comment, na_value=na_value, - default_float_type=default_float_type, select_columns=select_cols, ) @@ -786,29 +782,6 @@ class MakeCsvDatasetTest(test.TestCase): num_epochs=10, label_name=None) - def testMakeCSVDataset_withNoComments(self): - """Tests that datasets can be created from CSV files with no header line. - """ - defaults = self.DEFAULTS - file_without_header = self._create_file( - len(self._test_filenames), comment=False) - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - dataset = self._make_csv_dataset( - file_without_header, - defaults, - batch_size=2, - num_epochs=10, - comment=None, - ) - self._verify_records( - sess, - dataset, - [len(self._test_filenames)], - batch_size=2, - num_epochs=10, - ) - def testMakeCSVDataset_withNoHeader(self): """Tests that datasets can be created from CSV files with no header line. """ @@ -876,7 +849,7 @@ class MakeCsvDatasetTest(test.TestCase): In that case, we should infer the types from the first N records. 
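A sketch of the type-inference path being tested, assuming a hypothetical `example.csv` with a header row (keyword arguments limited to those shown in the tests above):

```python
from tensorflow.contrib.data.python.ops import readers

# With no column defaults supplied, make_csv_dataset samples rows from the
# file and infers each column's dtype, widening as the data demands
# (int32 -> int64 -> float32 -> float64 -> string).
dataset = readers.make_csv_dataset(
    'example.csv',
    batch_size=4,
    label_name=None,
    num_epochs=1,
    shuffle=False,
    header=True)

# Each element is a dict mapping column names to batched feature tensors.
features = dataset.make_one_shot_iterator().get_next()
```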
""" - # Test that it works with standard test files (with comments, header, etc) + # Test that it works with standard test files (with header, etc) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = self._make_csv_dataset( @@ -889,7 +862,9 @@ class MakeCsvDatasetTest(test.TestCase): num_epochs=10, defaults=[[], [], [], [], [""]]) - # Test on a deliberately tricky file + def testMakeCSVDataset_withTypeInferenceTricky(self): + # Test on a deliberately tricky file (type changes as we read more rows, and + # there are null values) fn = os.path.join(self.get_temp_dir(), "file.csv") expected_dtypes = [ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float32, @@ -914,20 +889,29 @@ class MakeCsvDatasetTest(test.TestCase): column_names=None, label_name=None, na_value="NAN", - default_float_type=dtypes.float32, ) features = dataset.make_one_shot_iterator().get_next() # Check that types match for i in range(len(expected_dtypes)): + print(features["col%d" % i].dtype, expected_dtypes[i]) assert features["col%d" % i].dtype == expected_dtypes[i] for i in range(len(rows)): assert sess.run(features) == dict(zip(col_names, expected[i])) - # With float64 as default type for floats + def testMakeCSVDataset_withTypeInferenceAllTypes(self): + # Test that we make the correct inference for all types with fallthrough + fn = os.path.join(self.get_temp_dir(), "file.csv") expected_dtypes = [ - dtypes.int32, dtypes.int64, dtypes.float64, dtypes.float64, + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string, dtypes.string ] + col_names = ["col%d" % i for i in range(len(expected_dtypes))] + rows = [[1, 2**31 + 1, 1.0, 4e40, "abc", ""]] + expected = [[ + 1, 2**31 + 1, 1.0, 4e40, "abc".encode("utf-8"), "".encode("utf-8") + ]] + self._write_file("file.csv", [col_names] + rows) + with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = self._make_csv_dataset( @@ -936,7 +920,6 @@ class MakeCsvDatasetTest(test.TestCase): column_names=None, label_name=None, na_value="NAN", - default_float_type=dtypes.float64, ) features = dataset.make_one_shot_iterator().get_next() # Check that types match @@ -1086,5 +1069,189 @@ class MakeCsvDatasetTest(test.TestCase): self.assertFalse(all_equal) +class MakeTFRecordDatasetTest(TFRecordDatasetTestBase): + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length, + drop_final_batch, + use_parser_fn): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return _interleave([_next_record([i]) for i in file_indices], + cycle_length) + + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for f, r in next_records: + record = self._record(f, r) + if use_parser_fn: + record = record[1:] + record_batch.append(record) + batch_index += 1 + if len(record_batch) == batch_size: + yield record_batch + record_batch = [] + batch_index = 0 + if record_batch and not drop_final_batch: + yield record_batch + + def _verify_records(self, + sess, + outputs, + batch_size, + file_index, + num_epochs, + interleave_cycle_length, + drop_final_batch, + use_parser_fn): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in 
self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length, + drop_final_batch, use_parser_fn): + actual_batch = sess.run(outputs) + self.assertAllEqual(expected_batch, actual_batch) + + def _read_test(self, batch_size, num_epochs, file_index=None, + num_parallel_reads=1, drop_final_batch=False, parser_fn=False): + if file_index is None: + file_pattern = self.test_filenames + else: + file_pattern = self.test_filenames[file_index] + + if parser_fn: + fn = lambda x: string_ops.substr(x, 1, 999) + else: + fn = None + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + outputs = readers.make_tf_record_dataset( + file_pattern=file_pattern, + num_epochs=num_epochs, + batch_size=batch_size, + parser_fn=fn, + num_parallel_reads=num_parallel_reads, + drop_final_batch=drop_final_batch, + shuffle=False).make_one_shot_iterator().get_next() + self._verify_records( + sess, outputs, batch_size, file_index, num_epochs=num_epochs, + interleave_cycle_length=num_parallel_reads, + drop_final_batch=drop_final_batch, use_parser_fn=parser_fn) + with self.assertRaises(errors.OutOfRangeError): + sess.run(outputs) + + def testRead(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + # Basic test: read from file 0. + self._read_test(batch_size, num_epochs, 0) + + # Basic test: read from file 1. + self._read_test(batch_size, num_epochs, 1) + + # Basic test: read from both files. + self._read_test(batch_size, num_epochs) + + # Basic test: read from both files, with parallel reads. + self._read_test(batch_size, num_epochs, num_parallel_reads=8) + + def testDropFinalBatch(self): + for batch_size in [1, 2, 10]: + for num_epochs in [1, 3]: + # Read from file 0. + self._read_test(batch_size, num_epochs, 0, drop_final_batch=True) + + # Read from both files. + self._read_test(batch_size, num_epochs, drop_final_batch=True) + + # Read from both files, with parallel reads. 
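A condensed sketch of `make_tf_record_dataset` with the options exercised by these tests (the file pattern is hypothetical):

```python
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.ops import string_ops

# Reads matching TFRecord files with 8 parallel readers, shuffles, applies an
# optional per-record parser_fn, batches, and drops a short final batch.
dataset = readers.make_tf_record_dataset(
    file_pattern='/tmp/data-*.tfrecord',
    batch_size=32,
    parser_fn=lambda record: string_ops.substr(record, 1, 999),
    num_parallel_reads=8,
    num_epochs=1,
    shuffle=True,
    shuffle_seed=21345,
    drop_final_batch=True)

next_batch = dataset.make_one_shot_iterator().get_next()
```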
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8, + drop_final_batch=True) + + def testParserFn(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for drop_final_batch in [False, True]: + self._read_test(batch_size, num_epochs, parser_fn=True, + drop_final_batch=drop_final_batch) + self._read_test(batch_size, num_epochs, num_parallel_reads=8, + parser_fn=True, drop_final_batch=drop_final_batch) + + def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1, + seed=None): + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.make_tf_record_dataset( + file_pattern=self.test_filenames, + num_epochs=num_epochs, + batch_size=batch_size, + num_parallel_reads=num_parallel_reads, + shuffle=True, + shuffle_seed=seed) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + sess.run(iterator.initializer) + first_batches = [] + try: + while True: + first_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + sess.run(iterator.initializer) + second_batches = [] + try: + while True: + second_batches.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + + self.assertEqual(len(first_batches), len(second_batches)) + if seed is not None: + # if you set a seed, should get the same results + for i in range(len(first_batches)): + self.assertAllEqual(first_batches[i], second_batches[i]) + + expected = [] + for f in range(self._num_files): + for r in range(self._num_records): + expected.extend([self._record(f, r)] * num_epochs) + + for batches in (first_batches, second_batches): + actual = [] + for b in batches: + actual.extend(b) + self.assertAllEqual(sorted(expected), sorted(actual)) + + def testShuffle(self): + for batch_size in [1, 2]: + for num_epochs in [1, 3]: + for num_parallel_reads in [1, 2]: + # Test that all expected elements are produced + self._shuffle_test(batch_size, num_epochs, num_parallel_reads) + # Test that elements are produced in a consistent order if + # you specify a seed. + self._shuffle_test(batch_size, num_epochs, num_parallel_reads, + seed=21345) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 5f47dcb3399..bdc003a8a5b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -18,6 +18,9 @@ from __future__ import division from __future__ import print_function import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin +import time +from absl.testing import parameterized from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.data.ops import dataset_ops @@ -30,26 +33,86 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class ResampleTest(test.TestCase): +def _time_resampling( + test_obj, data_np, target_dist, init_dist, num_to_sample): + dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat() - def testInitialKnownDistribution(self): - self._testDistribution(initial_known=True) + # Reshape distribution via rejection sampling. 
+ dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist, + seed=142)) - def testInitialNotKnownDistribution(self): - self._testDistribution(initial_known=False) + get_next = dataset.make_one_shot_iterator().get_next() - def _testDistribution(self, initial_known): + with test_obj.test_session() as sess: + start_time = time.time() + for _ in xrange(num_to_sample): + sess.run(get_next) + end_time = time.time() + + return end_time - start_time + + +class ResampleTest(test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + ("InitialDistributionKnown", True), + ("InitialDistributionUnknown", False)) + def testDistribution(self, initial_known): classes = np.random.randint(5, size=(20000,)) # Uniformly sampled target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] initial_dist = [0.2] * 5 if initial_known else None - iterator = (dataset_ops.Dataset.from_tensor_slices(classes).shuffle( - 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).apply( - resampling.rejection_resample( - target_dist=target_dist, - initial_dist=initial_dist, - class_func=lambda c, _: c, - seed=27)).make_one_shot_iterator()) - get_next = iterator.get_next() + classes = math_ops.to_int64(classes) # needed for Windows build. + dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle( + 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat() + + get_next = dataset.apply( + resampling.rejection_resample( + target_dist=target_dist, + initial_dist=initial_dist, + class_func=lambda c, _: c, + seed=27)).make_one_shot_iterator().get_next() + + with self.test_session() as sess: + returned = [] + while len(returned) < 4000: + returned.append(sess.run(get_next)) + + returned_classes, returned_classes_and_data = zip(*returned) + _, returned_data = zip(*returned_classes_and_data) + self.assertAllEqual([compat.as_bytes(str(c)) + for c in returned_classes], returned_data) + total_returned = len(returned_classes) + class_counts = np.array([ + len([True for v in returned_classes if v == c]) + for c in range(5)]) + returned_dist = class_counts / total_returned + self.assertAllClose(target_dist, returned_dist, atol=1e-2) + + @parameterized.named_parameters( + ("OnlyInitial", True), + ("NotInitial", False)) + def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist): + init_dist = [0.5, 0.5] + target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test that this works. + num_samples = 100 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + dataset = dataset_ops.Dataset.from_tensor_slices(data_np) + + # Reshape distribution. + dataset = dataset.apply( + resampling.rejection_resample( + class_func=lambda x: x, + target_dist=target_dist, + initial_dist=init_dist)) + + get_next = dataset.make_one_shot_iterator().get_next() with self.test_session() as sess: returned = [] @@ -57,25 +120,11 @@ class ResampleTest(test.TestCase): while True: returned.append(sess.run(get_next)) - returned_classes, returned_classes_and_data = zip(*returned) - _, returned_data = zip(*returned_classes_and_data) - self.assertAllEqual([compat.as_bytes(str(c)) - for c in returned_classes], returned_data) - total_returned = len(returned_classes) - # Subsampling rejects a large percentage of the initial data in - # this case. 
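A minimal sketch of `rejection_resample` as exercised by these tests (the class distributions are illustrative):

```python
import numpy as np
from tensorflow.contrib.data.python.ops import resampling
from tensorflow.python.data.ops import dataset_ops

# Labels are heavily skewed toward class 0; resample to roughly 50/50.
labels = np.random.choice(2, size=10000, p=[0.9, 0.1]).astype(np.int64)
dataset = dataset_ops.Dataset.from_tensor_slices(labels).repeat()

# If initial_dist is omitted, the initial distribution is estimated online.
resampled = dataset.apply(
    resampling.rejection_resample(
        class_func=lambda label: label,
        target_dist=[0.5, 0.5],
        initial_dist=[0.9, 0.1],
        seed=27))

# Elements of the resampled dataset are (class, original_element) pairs.
next_element = resampled.make_one_shot_iterator().get_next()
```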
- self.assertGreater(total_returned, 20000 * 0.2) - class_counts = np.array([ - len([True for v in returned_classes if v == c]) - for c in range(5)]) - returned_dist = class_counts / total_returned - self.assertAllClose(target_dist, returned_dist, atol=1e-2) - def testRandomClasses(self): init_dist = [0.25, 0.25, 0.25, 0.25] target_dist = [0.0, 0.0, 0.0, 1.0] num_classes = len(init_dist) - # We don't need many samples to test a dirac-delta target distribution + # We don't need many samples to test a dirac-delta target distribution. num_samples = 100 data_np = np.random.choice(num_classes, num_samples, p=init_dist) @@ -109,5 +158,23 @@ class ResampleTest(test.TestCase): self.assertAllClose(target_dist, bincount, atol=1e-2) + +class ResampleDatasetBenchmark(test.Benchmark): + + def benchmarkResamplePerformance(self): + init_dist = [0.25, 0.25, 0.25, 0.25] + target_dist = [0.0, 0.0, 0.0, 1.0] + num_classes = len(init_dist) + # We don't need many samples to test a dirac-delta target distribution + num_samples = 1000 + data_np = np.random.choice(num_classes, num_samples, p=init_dist) + + resample_time = _time_resampling( + self, data_np, target_dist, init_dist, num_to_sample=1000) + + self.report_benchmark( + iters=1000, wall_time=resample_time, name="benchmark_resample") + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index e0494736b72..eb2ceff8935 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -24,24 +24,31 @@ import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test class ScanDatasetTest(test.TestCase): - def _count(self, start, step): - return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( - scan_ops.scan(start, lambda state, _: (state + step, state))) + def _counting_dataset(self, start, scan_fn): + return dataset_ops.Dataset.from_tensors(0).repeat().apply( + scan_ops.scan(start, scan_fn)) def testCount(self): + def make_scan_fn(step): + return lambda state, _: (state + step, state) + start = array_ops.placeholder(dtypes.int32, shape=[]) step = array_ops.placeholder(dtypes.int32, shape=[]) take = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = self._count(start, step).take(take).make_initializable_iterator() + iterator = self._counting_dataset( + start, make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() with self.test_session() as sess: @@ -57,19 +64,55 @@ class ScanDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + @test_util.run_in_graph_and_eager_modes() def testFibonacci(self): iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply( scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])) ).make_one_shot_iterator() + + if context.executing_eagerly(): + next_element = 
iterator.get_next + else: + get_next = iterator.get_next() + next_element = lambda: get_next + + self.assertEqual(1, self.evaluate(next_element())) + self.assertEqual(1, self.evaluate(next_element())) + self.assertEqual(2, self.evaluate(next_element())) + self.assertEqual(3, self.evaluate(next_element())) + self.assertEqual(5, self.evaluate(next_element())) + self.assertEqual(8, self.evaluate(next_element())) + + def testSparseCount(self): + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def make_scan_fn(step): + return lambda state, _: (_sparse(state.values[0] + step), state) + + start = array_ops.placeholder(dtypes.int32, shape=[]) + step = array_ops.placeholder(dtypes.int32, shape=[]) + take = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = self._counting_dataset( + _sparse(start), + make_scan_fn(step)).take(take).make_initializable_iterator() next_element = iterator.get_next() with self.test_session() as sess: - self.assertEqual(1, sess.run(next_element)) - self.assertEqual(1, sess.run(next_element)) - self.assertEqual(2, sess.run(next_element)) - self.assertEqual(3, sess.run(next_element)) - self.assertEqual(5, sess.run(next_element)) - self.assertEqual(8, sess.run(next_element)) + + for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), + (10, 2, 10), (10, -1, 10), + (10, -2, 10)]: + sess.run(iterator.initializer, + feed_dict={start: start_val, step: step_val, take: take_val}) + for expected, _ in zip( + itertools.count(start_val, step_val), range(take_val)): + self.assertEqual(expected, sess.run(next_element).values[0]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) def testChangingStateShape(self): # Test the fixed-point shape invariant calculations: start with @@ -125,7 +168,7 @@ class ScanDatasetTest(test.TestCase): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) -class ScanDatasetSerialzationTest( +class ScanDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_dataset(self, num_elements): diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py index b13ad9ba4e5..d0cb203a3af 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py @@ -48,8 +48,8 @@ class SequenceDatasetSerializationTest( self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10) def testInvalidSkip(self): - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 1'): + with self.assertRaisesRegexp(ValueError, + 'Shape must be rank 0 but is rank 1'): self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0) def _build_take_dataset(self, count): @@ -75,8 +75,8 @@ class SequenceDatasetSerializationTest( self.run_core_tests(lambda: self._build_take_dataset(0), None, 0) def testInvalidTake(self): - with self.assertRaisesRegexp( - ValueError, 'Shape must be rank 0 but is rank 1'): + with self.assertRaisesRegexp(ValueError, + 'Shape must be rank 0 but is rank 1'): self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0) def _build_repeat_dataset(self, count, take_count=3): diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index 
e26cef8ec52..4148addf287 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -22,6 +22,7 @@ import os import sqlite3 +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetTest(test.TestCase): +class SqlDatasetTestBase(test.TestCase): def _createSqlDataset(self, output_types, num_repeats=1): dataset = readers.SqlDataset(self.driver_name, self.data_source_name, @@ -92,6 +93,9 @@ class SqlDatasetTest(test.TestCase): conn.commit() conn.close() + +class SqlDatasetTest(SqlDatasetTestBase): + # Test that SqlDataset can read from a database table. def testReadResultSet(self): init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string, @@ -652,5 +656,27 @@ class SqlDatasetTest(test.TestCase): sess.run(get_next) +class SqlDatasetSerializationTest( + SqlDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_repeats): + data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite") + driver_name = array_ops.placeholder_with_default( + array_ops.constant("sqlite", dtypes.string), shape=[]) + query = ("SELECT first_name, last_name, motto FROM students ORDER BY " + "first_name DESC") + output_types = (dtypes.string, dtypes.string, dtypes.string) + return readers.SqlDataset(driver_name, data_source_name, query, + output_types).repeat(num_repeats) + + def testSQLSaveable(self): + num_repeats = 4 + num_outputs = num_repeats * 2 + self.run_core_tests(lambda: self._build_dataset(num_repeats), + lambda: self._build_dataset(num_repeats // 2), + num_outputs) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index c3a7f291c59..5c74ed6ae72 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -50,17 +50,17 @@ class StatsDatasetTest(test.TestCase): self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) def testBytesProduced(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).map( lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( - stats_ops.bytes_produced_stats("bytes_produced")) + stats_ops.bytes_produced_stats("bytes_produced")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) expected_sum = 0.0 for i in range(100): self.assertAllEqual( @@ -76,16 +76,16 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum) def testLatencyStats(self): - dataset = dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats("record_latency")) - iterator = 
dataset.make_initializable_iterator() stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) self._assertSummaryHasCount( @@ -95,16 +95,15 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) def testReinitialize(self): - dataset = dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats("record_latency")) - iterator = dataset.make_initializable_iterator() stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run(stats_aggregator_subscriber) for j in range(5): sess.run(iterator.initializer) for i in range(100): @@ -130,17 +129,17 @@ class StatsDatasetTest(test.TestCase): sess.run(next_element) def testMultipleTags(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( stats_ops.latency_stats("record_latency")).apply( - stats_ops.latency_stats("record_latency_2")) + stats_ops.latency_stats("record_latency_2")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) self._assertSummaryHasCount( @@ -154,17 +153,17 @@ class StatsDatasetTest(test.TestCase): sess.run(summary_t), "record_latency_2", 100.0) def testRepeatedTags(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( stats_ops.latency_stats("record_latency")).apply( - stats_ops.latency_stats("record_latency")) + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) next_element = iterator.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator.initializer, stats_aggregator_subscriber]) + sess.run(iterator.initializer) for i in range(100): self.assertEqual(i, sess.run(next_element)) self._assertSummaryHasCount( @@ -174,19 +173,17 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) def testMultipleIteratorsSameAggregator(self): + stats_aggregator = stats_ops.StatsAggregator() dataset = 
dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats("record_latency")) + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) iterator_0 = dataset.make_initializable_iterator() iterator_1 = dataset.make_initializable_iterator() - stats_aggregator = stats_ops.StatsAggregator() - stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0), - stats_aggregator.subscribe(iterator_1)] next_element = iterator_0.get_next() + iterator_1.get_next() summary_t = stats_aggregator.get_summary() with self.test_session() as sess: - sess.run([iterator_0.initializer, iterator_1.initializer, - stats_aggregator_subscribers]) + sess.run([iterator_0.initializer, iterator_1.initializer]) for i in range(100): self.assertEqual(i * 2, sess.run(next_element)) self._assertSummaryHasCount( @@ -195,20 +192,6 @@ class StatsDatasetTest(test.TestCase): sess.run(next_element) self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) - def testMultipleStatsAggregatorsSameIteratorFail(self): - dataset = dataset_ops.Dataset.range(100).apply( - stats_ops.latency_stats("record_latency")) - iterator = dataset.make_initializable_iterator() - stats_aggregator_0 = stats_ops.StatsAggregator() - stats_aggregator_1 = stats_ops.StatsAggregator() - - with self.test_session() as sess: - sess.run(stats_aggregator_0.subscribe(iterator)) - # TODO(mrry): Consider making this allowable (and also allowing - # aggregators to unsubscribe). - with self.assertRaises(errors.FailedPreconditionError): - sess.run(stats_aggregator_1.subscribe(iterator)) - class StatsDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): @@ -269,5 +252,9 @@ class StatsDatasetSerializationTest( None, num_outputs) +# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the +# transformation `stats_ops.set_stats_aggregator`, since we don't support +# serializing StatsAggregator yet. + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py new file mode 100644 index 00000000000..c603ecc5ab2 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py @@ -0,0 +1,117 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
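To illustrate the stats_ops change above, a sketch of attaching a `StatsAggregator` with the `set_stats_aggregator` transformation rather than by subscribing it to an iterator:

```python
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops

stats_aggregator = stats_ops.StatsAggregator()

# The aggregator is now wired in as a dataset transformation.
dataset = dataset_ops.Dataset.range(100).apply(
    stats_ops.latency_stats('record_latency')).apply(
        stats_ops.set_stats_aggregator(stats_aggregator))

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()  # summary holding the recorded stats
```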
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.ops import writers +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import python_io +from tensorflow.python.lib.io import tf_record +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class TFRecordWriterTest(test.TestCase): + + def setUp(self): + super(TFRecordWriterTest, self).setUp() + self._num_records = 7 + self.filename = array_ops.placeholder(dtypes.string, shape=[]) + self.compression_type = array_ops.placeholder_with_default("", shape=[]) + + input_dataset = readers.TFRecordDataset([self.filename], + self.compression_type) + self.writer = writers.TFRecordWriter( + self._outputFilename(), self.compression_type).write(input_dataset) + + def _record(self, i): + return compat.as_bytes("Record %d" % (i)) + + def _createFile(self, options=None): + filename = self._inputFilename() + writer = python_io.TFRecordWriter(filename, options) + for i in range(self._num_records): + writer.write(self._record(i)) + writer.close() + return filename + + def _inputFilename(self): + return os.path.join(self.get_temp_dir(), "tf_record.in.txt") + + def _outputFilename(self): + return os.path.join(self.get_temp_dir(), "tf_record.out.txt") + + def testWrite(self): + with self.test_session() as sess: + sess.run( + self.writer, feed_dict={ + self.filename: self._createFile(), + }) + for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())): + self.assertAllEqual(self._record(i), r) + + def testWriteZLIB(self): + options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB) + with self.test_session() as sess: + sess.run( + self.writer, + feed_dict={ + self.filename: self._createFile(options), + self.compression_type: "ZLIB", + }) + for i, r in enumerate( + tf_record.tf_record_iterator(self._outputFilename(), options=options)): + self.assertAllEqual(self._record(i), r) + + def testWriteGZIP(self): + options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP) + with self.test_session() as sess: + sess.run( + self.writer, + feed_dict={ + self.filename: self._createFile(options), + self.compression_type: "GZIP", + }) + for i, r in enumerate( + tf_record.tf_record_iterator(self._outputFilename(), options=options)): + self.assertAllEqual(self._record(i), r) + + def testFailDataset(self): + with self.assertRaises(TypeError): + writers.TFRecordWriter(self._outputFilename(), + self.compression_type).write("whoops") + + def testFailDType(self): + input_dataset = dataset_ops.Dataset.from_tensors(10) + with self.assertRaises(TypeError): + writers.TFRecordWriter(self._outputFilename(), + self.compression_type).write(input_dataset) + + def testFailShape(self): + input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]]) + with self.assertRaises(TypeError): + writers.TFRecordWriter(self._outputFilename(), + self.compression_type).write(input_dataset) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 0e4590829b1..eceecfd1744 100644 --- 
a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -45,6 +45,27 @@ py_library( "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python:training", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +py_test( + name = "iterator_ops_test", + size = "small", + srcs = ["iterator_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", ], ) @@ -172,8 +193,18 @@ py_library( srcs = ["interleave_ops.py"], srcs_version = "PY2AND3", deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + ":random_ops", + "//tensorflow/contrib/stateless", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", ], ) @@ -183,6 +214,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":batching", + ":interleave_ops", ":scan_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -192,6 +224,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) @@ -270,6 +303,18 @@ py_library( ], ) +py_library( + name = "writers", + srcs = [ + "writers.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:dtypes", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + tf_gen_op_wrapper_py( name = "gen_dataset_ops", out = "gen_dataset_ops.py", @@ -332,6 +377,7 @@ py_library( ":stats_ops", ":threadpool", ":unique", + ":writers", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 1eba010b562..b9393de4e90 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -80,28 +80,98 @@ def dense_to_sparse_batch(batch_size, row_shape): return _apply_fn +class UnbatchDataset(dataset_ops.Dataset): + """A dataset that splits the elements of its input into multiple elements.""" + + def __init__(self, input_dataset): + """See `unbatch()` for more details.""" + super(UnbatchDataset, self).__init__() + flat_shapes = nest.flatten(input_dataset.output_shapes) + if any(s.ndims == 0 for s in flat_shapes): + raise ValueError("Cannot unbatch an input with scalar components.") + known_batch_dim = tensor_shape.Dimension(None) + for s in flat_shapes: + try: + known_batch_dim = known_batch_dim.merge_with(s[0]) + except ValueError: + raise ValueError("Cannot unbatch an input whose components have " + "different batch sizes.") + self._input_dataset = input_dataset + + def _as_variant_tensor(self): + return gen_dataset_ops.unbatch_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, 
self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return nest.map_structure(lambda s: s[1:], + self._input_dataset.output_shapes) + + @property + def output_types(self): + return self._input_dataset.output_types + + def unbatch(): - """A Transformation which splits the elements of a dataset. + """Splits elements of a dataset into multiple elements on the batch dimension. For example, if elements of the dataset are shaped `[B, a0, a1, ...]`, - where `B` may vary from element to element, then for each element in - the dataset, the unbatched dataset will contain `B` consecutive elements + where `B` may vary for each input element, then for each element in the + dataset, the unbatched dataset will contain `B` consecutive elements of shape `[a0, a1, ...]`. + ```python + # NOTE: The following example uses `{ ... }` to represent the contents + # of a dataset. + a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] } + + a.apply(tf.contrib.data.unbatch()) == { + 'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'} + ``` + Returns: A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + if not sparse.any_sparse(dataset.output_classes): + return UnbatchDataset(dataset) - def unbatch_map(arg, *rest): + # NOTE(mrry): We must ensure that any SparseTensors in `dataset` + # are normalized to the rank-1 dense representation, so that the + # sparse-oblivious unbatching logic will slice them + # appropriately. This leads to a somewhat inefficient re-encoding step + # for all SparseTensor components. + # TODO(mrry): Consider optimizing this in future + # if it turns out to be a bottleneck. + def normalize(arg, *rest): if rest: - return dataset_ops.Dataset.from_tensor_slices((arg,) + rest) + return sparse.serialize_many_sparse_tensors((arg,) + rest) else: - return dataset_ops.Dataset.from_tensor_slices(arg) + return sparse.serialize_many_sparse_tensors(arg) - return dataset.flat_map(map_func=unbatch_map) + normalized_dataset = dataset.map(normalize) + + # NOTE(mrry): Our `map()` has lost information about the sparseness + # of any SparseTensor components, so re-apply the structure of the + # original dataset. + restructured_dataset = _RestructuredDataset( + normalized_dataset, + dataset.output_types, + dataset.output_shapes, + dataset.output_classes, + allow_unsafe_cast=True) + return UnbatchDataset(restructured_dataset) return _apply_fn @@ -265,7 +335,8 @@ class _RestructuredDataset(dataset_ops.Dataset): dataset, output_types, output_shapes=None, - output_classes=None): + output_classes=None, + allow_unsafe_cast=False): """Creates a new dataset with the given output types and shapes. The given `dataset` must have a structure that is convertible: @@ -283,22 +354,27 @@ class _RestructuredDataset(dataset_ops.Dataset): If omitted, the shapes will be inherited from `dataset`. output_classes: (Optional.) A nested structure of class types. If omitted, the class types will be inherited from `dataset`. + allow_unsafe_cast: (Optional.) If `True`, the caller may switch the + reported output types and shapes of the restructured dataset, e.g. to + switch a sparse tensor represented as `tf.variant` to its user-visible + type and shape. Raises: ValueError: If either `output_types` or `output_shapes` is not compatible with the structure of `dataset`. 
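A runnable sketch of the `unbatch()` transformation described above, applied to a dense dataset (sizes are illustrative):

```python
from tensorflow.contrib.data.python.ops import batching
from tensorflow.python.data.ops import dataset_ops

# Twelve scalars batched as [5, 5, 2]; unbatch() flattens them back into the
# original twelve elements, regardless of the (possibly varying) batch size.
dataset = dataset_ops.Dataset.range(12).batch(5)
unbatched = dataset.apply(batching.unbatch())

next_element = unbatched.make_one_shot_iterator().get_next()  # 0, 1, ..., 11
```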
""" super(_RestructuredDataset, self).__init__() - self._dataset = dataset + self._input_dataset = dataset - # Validate that the types are compatible. - output_types = nest.map_structure(dtypes.as_dtype, output_types) - flat_original_types = nest.flatten(dataset.output_types) - flat_new_types = nest.flatten(output_types) - if flat_original_types != flat_new_types: - raise ValueError( - "Dataset with output types %r cannot be restructured to have output " - "types %r" % (dataset.output_types, output_types)) + if not allow_unsafe_cast: + # Validate that the types are compatible. + output_types = nest.map_structure(dtypes.as_dtype, output_types) + flat_original_types = nest.flatten(dataset.output_types) + flat_new_types = nest.flatten(output_types) + if flat_original_types != flat_new_types: + raise ValueError( + "Dataset with output types %r cannot be restructured to have " + "output types %r" % (dataset.output_types, output_types)) self._output_types = output_types @@ -308,18 +384,19 @@ class _RestructuredDataset(dataset_ops.Dataset): nest.flatten( dataset.output_shapes)) else: - # Validate that the shapes are compatible. - nest.assert_same_structure(output_types, output_shapes) - flat_original_shapes = nest.flatten(dataset.output_shapes) - flat_new_shapes = nest.flatten_up_to(output_types, output_shapes) + if not allow_unsafe_cast: + # Validate that the shapes are compatible. + nest.assert_same_structure(output_types, output_shapes) + flat_original_shapes = nest.flatten(dataset.output_shapes) + flat_new_shapes = nest.flatten_up_to(output_types, output_shapes) - for original_shape, new_shape in zip(flat_original_shapes, - flat_new_shapes): - if not original_shape.is_compatible_with(new_shape): - raise ValueError( - "Dataset with output shapes %r cannot be restructured to have " - "incompatible output shapes %r" % (dataset.output_shapes, - output_shapes)) + for original_shape, new_shape in zip(flat_original_shapes, + flat_new_shapes): + if not original_shape.is_compatible_with(new_shape): + raise ValueError( + "Dataset with output shapes %r cannot be restructured to have " + "incompatible output shapes %r" % (dataset.output_shapes, + output_shapes)) self._output_shapes = nest.map_structure_up_to( output_types, tensor_shape.as_shape, output_shapes) if output_classes is None: @@ -331,7 +408,7 @@ class _RestructuredDataset(dataset_ops.Dataset): self._output_classes = output_classes def _as_variant_tensor(self): - return self._dataset._as_variant_tensor() # pylint: disable=protected-access + return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access @property def output_classes(self): @@ -370,9 +447,10 @@ def assert_element_shape(expected_shapes): def _check_shape(*elements): flatten_tensors = nest.flatten(elements) flatten_shapes = nest.flatten(expected_shapes) - checked_tensors = [with_shape(shape, tensor) - for shape, tensor in zip(flatten_shapes, - flatten_tensors)] + checked_tensors = [ + with_shape(shape, tensor) + for shape, tensor in zip(flatten_shapes, flatten_tensors) + ] return nest.pack_sequence_as(elements, checked_tensors) def _apply_fn(dataset): @@ -388,14 +466,14 @@ def assert_element_shape(expected_shapes): class _MapAndBatchDataset(dataset_ops.MapDataset): """A `Dataset` that maps a function over a batch of elements.""" - def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches, + def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls, drop_remainder): """See `Dataset.map()` for details.""" 
super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) self._batch_size_t = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - self._num_parallel_batches_t = ops.convert_to_tensor( - num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + self._num_parallel_calls_t = ops.convert_to_tensor( + num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") self._drop_remainder_t = ops.convert_to_tensor( drop_remainder, dtype=dtypes.bool, name="drop_remainder") @@ -405,12 +483,12 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def _as_variant_tensor(self): # pylint: disable=protected-access input_resource = self._input_dataset._as_variant_tensor() - return gen_dataset_ops.map_and_batch_dataset( + return gen_dataset_ops.map_and_batch_dataset_v2( input_resource, self._map_func.captured_inputs, f=self._map_func, batch_size=self._batch_size_t, - num_parallel_batches=self._num_parallel_batches_t, + num_parallel_calls=self._num_parallel_calls_t, drop_remainder=self._drop_remainder_t, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), @@ -433,8 +511,9 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def map_and_batch(map_func, batch_size, - num_parallel_batches=1, - drop_remainder=False): + num_parallel_batches=None, + drop_remainder=False, + num_parallel_calls=None): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset @@ -450,21 +529,37 @@ def map_and_batch(map_func, nested structure of tensors. batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements of this dataset to combine in a single batch. - num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the - number of batches to create in parallel. On one hand, higher values can - help mitigate the effect of stragglers. On the other hand, higher values - can increase contention if CPU is scarce. - drop_remainder: A `tf.bool` scalar `tf.Tensor`, representing whether the - last batch should be dropped in case its size is smaller than desired; - the default behavior is not to drop the smaller batch. + num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`, + representing the number of batches to create in parallel. On one hand, + higher values can help mitigate the effect of stragglers. On the other + hand, higher values can increase contention if CPU is scarce. + drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing + whether the last batch should be dropped in case its size is smaller than + desired; the default behavior is not to drop the smaller batch. + num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, + representing the number of elements to process in parallel. If not + specified, `batch_size * num_parallel_batches` elements will be + processed in parallel. Returns: A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. + + Raises: + ValueError: If both `num_parallel_batches` and `num_parallel_calls` are + specified. 
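To make the new argument handling concrete: `num_parallel_calls` now defaults to `batch_size` when neither parallelism argument is given, is derived as `batch_size * num_parallel_batches` when only the older argument is given, and passing both raises `ValueError`. A hedged usage sketch, illustrative only and assuming the `tf.contrib.data.map_and_batch` export:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(100)

# Fused map + batch. Passing both num_parallel_calls and
# num_parallel_batches together would raise a ValueError.
ds = ds.apply(tf.contrib.data.map_and_batch(
    map_func=lambda x: x * 2,
    batch_size=10,
    num_parallel_calls=4))

batch = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    print(sess.run(batch))  # [ 0  2  4 ... 18]
```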
""" + if num_parallel_batches is None and num_parallel_calls is None: + num_parallel_calls = batch_size + elif num_parallel_batches is not None and num_parallel_calls is None: + num_parallel_calls = batch_size * num_parallel_batches + elif num_parallel_batches is not None and num_parallel_calls is not None: + raise ValueError("The `num_parallel_batches` and `num_parallel_calls` " + "arguments are mutually exclusive.") + def _apply_fn(dataset): return _MapAndBatchDataset(dataset, map_func, batch_size, - num_parallel_batches, drop_remainder) + num_parallel_calls, drop_remainder) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 0531f9cbb9d..ea229b5b27b 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -33,6 +34,35 @@ from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops +def group_by_reducer(key_func, reducer): + """A transformation that groups elements and performs a reduction. + + This transformation maps element of a dataset to a key using `key_func` and + groups the elements by key. The `reducer` is used to process each group; its + `init_func` is used to initialize state for each group when it is created, the + `reduce_func` is used to update the state every time an element is mapped to + the matching group, and the `finalize_func` is used to map the final state to + an output value. + + Args: + key_func: A function mapping a nested structure of tensors + (having shapes and types defined by `self.output_shapes` and + `self.output_types`) to a scalar `tf.int64` tensor. + reducer: An instance of `Reducer`, which captures the reduction logic using + the `init_func`, `reduce_func`, and `finalize_func` functions. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. 
+ """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return GroupByReducerDataset(dataset, key_func, reducer) + + return _apply_fn + + def group_by_window(key_func, reduce_func, window_size=None, @@ -227,6 +257,250 @@ class _VariantDataset(dataset_ops.Dataset): return self._output_types +class GroupByReducerDataset(dataset_ops.Dataset): + """A `Dataset` that groups its input and performs a reduction.""" + + def __init__(self, input_dataset, key_func, reducer): + """See `group_by_reducer()` for details.""" + super(GroupByReducerDataset, self).__init__() + + self._input_dataset = input_dataset + + self._make_key_func(key_func, input_dataset) + self._make_init_func(reducer.init_func) + self._make_reduce_func(reducer.reduce_func, input_dataset) + self._make_finalize_func(reducer.finalize_func) + + def _make_key_func(self, key_func, input_dataset): + """Make wrapping Defun for key_func.""" + + @function.Defun(*nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes))) + def tf_key_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the input_dataset. + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(input_dataset.output_types, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) + # pylint: disable=protected-access + if dataset_ops._should_unpack_args(nested_args): + ret = key_func(*nested_args) + # pylint: enable=protected-access + else: + ret = key_func(nested_args) + ret = ops.convert_to_tensor(ret) + if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar(): + raise ValueError( + "`key_func` must return a single tf.int64 tensor. " + "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape())) + return ret + + self._key_func = tf_key_func + self._key_func.add_to_graph(ops.get_default_graph()) + + def _make_init_func(self, init_func): + """Make wrapping Defun for init_func.""" + + @function.Defun(dtypes.int64) + def tf_init_func(key): + """A wrapper for Defun that facilitates shape inference.""" + key.set_shape([]) + ret = init_func(key) + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + self._state_classes = sparse.get_classes(ret) + self._state_shapes = nest.pack_sequence_as( + ret, [t.get_shape() for t in nest.flatten(ret)]) + self._state_types = nest.pack_sequence_as( + ret, [t.dtype for t in nest.flatten(ret)]) + + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + self._init_func = tf_init_func + self._init_func.add_to_graph(ops.get_default_graph()) + + def _make_reduce_func(self, reduce_func, input_dataset): + """Make wrapping Defun for reduce_func.""" + + # Iteratively rerun the reduce function until reaching a fixed point on + # `self._state_shapes`. + need_to_rerun = True + while need_to_rerun: + + # Create a list in which `tf_reduce_func` will store the new shapes. 
+ flat_new_state_shapes = [] + + @function.Defun(*(nest.flatten( + sparse.as_dense_types( + self._state_types, self._state_classes)) + nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes)))) + def tf_reduce_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + for arg, shape in zip( + args, + nest.flatten( + sparse.as_dense_shapes(self._state_shapes, self._state_classes)) + + nest.flatten( + sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes))): + arg.set_shape(shape) + + pivot = len(nest.flatten(self._state_shapes)) + nested_state_args = nest.pack_sequence_as(self._state_types, + args[:pivot]) + nested_state_args = sparse.deserialize_sparse_tensors( + nested_state_args, self._state_types, self._state_shapes, + self._state_classes) + nested_input_args = nest.pack_sequence_as(input_dataset.output_types, + args[pivot:]) + nested_input_args = sparse.deserialize_sparse_tensors( + nested_input_args, input_dataset.output_types, + input_dataset.output_shapes, input_dataset.output_classes) + + ret = reduce_func(nested_state_args, nested_input_args) + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + # Extract shape information from the returned values. + flat_new_state = nest.flatten(ret) + flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state]) + + # Extract and validate type information from the returned values. + for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)): + if t.dtype != dtype: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, + nest.pack_sequence_as(self._state_types, + [t.dtype for t in flat_new_state]))) + + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, + [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + # Use the private method that will execute `tf_reduce_func` but delay + # adding it to the graph in case we need to rerun the function. 
+ tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access + + flat_state_shapes = nest.flatten(self._state_shapes) + weakened_state_shapes = [ + old.most_specific_compatible_shape(new) + for old, new in zip(flat_state_shapes, flat_new_state_shapes) + ] + + need_to_rerun = False + for old_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if old_shape.ndims is not None and ( + weakened_shape.ndims is None or + old_shape.as_list() != weakened_shape.as_list()): + need_to_rerun = True + break + + if need_to_rerun: + self._state_shapes = nest.pack_sequence_as(self._state_shapes, + weakened_state_shapes) + + self._reduce_func = tf_reduce_func + self._reduce_func.add_to_graph(ops.get_default_graph()) + + def _make_finalize_func(self, finalize_func): + """Make wrapping Defun for finalize_func.""" + + @function.Defun(*(nest.flatten( + sparse.as_dense_types(self._state_types, self._state_classes)))) + def tf_finalize_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + for arg, shape in zip( + args, + nest.flatten( + sparse.as_dense_shapes(self._state_shapes, self._state_classes))): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(self._state_types, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, self._state_types, self._state_shapes, + self._state_classes) + + ret = finalize_func(nested_args) + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + self._output_classes = sparse.get_classes(ret) + self._output_shapes = nest.pack_sequence_as( + ret, [t.get_shape() for t in nest.flatten(ret)]) + self._output_types = nest.pack_sequence_as( + ret, [t.dtype for t in nest.flatten(ret)]) + + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + self._finalize_func = tf_finalize_func + self._finalize_func.add_to_graph(ops.get_default_graph()) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + def _as_variant_tensor(self): + return gen_dataset_ops.group_by_reducer_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._key_func.captured_inputs, + self._init_func.captured_inputs, + self._reduce_func.captured_inputs, + self._finalize_func.captured_inputs, + key_func=self._key_func, + init_func=self._init_func, + reduce_func=self._reduce_func, + finalize_func=self._finalize_func, + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + class GroupByWindowDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a windowed reduction.""" @@ -336,3 +610,30 @@ class GroupByWindowDataset(dataset_ops.Dataset): sparse.as_dense_types(self.output_types, self.output_classes)), output_shapes=nest.flatten( sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + +class Reducer(object): + """A reducer is used for reducing a set of elements. 
+ + A reducer is represented as a tuple of the three functions: + 1) initialization function: key => initial state + 2) reduce function: (old state, input) => new state + 3) finalization function: state => result + """ + + def __init__(self, init_func, reduce_func, finalize_func): + self._init_func = init_func + self._reduce_func = reduce_func + self._finalize_func = finalize_func + + @property + def init_func(self): + return self._init_func + + @property + def reduce_func(self): + return self._reduce_func + + @property + def finalize_func(self): + return self._finalize_func diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 91f19da02d4..812a50ecbf1 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -17,7 +17,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib import stateless +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.contrib.data.python.ops import random_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation @@ -140,3 +151,92 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): prefetch_input_elements=None) return _apply_fn + + +class DirectedInterleaveDataset(dataset_ops.Dataset): + """A substitute for `Dataset.interleave()` on a fixed list of datasets.""" + + def __init__(self, selector_input, data_inputs): + self._selector_input = selector_input + self._data_inputs = list(data_inputs) + + for data_input in data_inputs[1:]: + if (data_input.output_types != data_inputs[0].output_types or + data_input.output_classes != data_inputs[0].output_classes): + raise TypeError("All datasets must have the same type.") + + def _as_variant_tensor(self): + # pylint: disable=protected-access + return gen_dataset_ops.directed_interleave_dataset( + self._selector_input._as_variant_tensor(), + [data_input._as_variant_tensor() for data_input in self._data_inputs], + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + # pylint: enable=protected-access + + @property + def output_classes(self): + return self._data_inputs[0].output_classes + + @property + def output_shapes(self): + ret = self._data_inputs[0].output_shapes + for data_input in self._data_inputs[1:]: + ret = nest.pack_sequence_as(ret, [ + ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip( + nest.flatten(ret), nest.flatten(data_input.output_shapes)) + ]) + return ret + + @property + def output_types(self): + return self._data_inputs[0].output_types + + +def sample_from_datasets(datasets, weights=None, seed=None): + """Samples elements at random from the datasets in `datasets`. + + Args: + datasets: A list of @{tf.data.Dataset} objects with compatible structure. + weights: (Optional.) 
A list of `len(datasets)` floating-point values where + `weights[i]` represents the probability with which an element should be + sampled from `datasets[i]`, or a @{tf.data.Dataset} object where each + element is such a list. Defaults to a uniform distribution across + `datasets`. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + @{tf.set_random_seed} for behavior. + + Returns: + A dataset that interleaves elements from `datasets` at random, according to + `weights` if provided, otherwise with uniform probability. + + Raises: + TypeError: If the `datasets` or `weights` arguments have the wrong type. + ValueError: If the `weights` argument is specified and does not match the + length of the `datasets` element. + """ + num_datasets = len(datasets) + if weights is None: + weights = dataset_ops.Dataset.from_tensors([1.0] * num_datasets).repeat() + elif not isinstance(weights, dataset_ops.Dataset): + weights = ops.convert_to_tensor(weights, name="weights") + if weights.dtype not in (dtypes.float32, dtypes.float64): + raise TypeError("`weights` must be convertible to a tensor of " + "`tf.float32` or `tf.float64` elements.") + if not weights.shape.is_compatible_with([num_datasets]): + raise ValueError("`weights` must be a vector of length `len(datasets)`.") + weights = dataset_ops.Dataset.from_tensors(weights).repeat() + + # The `stateless_multinomial()` op expects log-probabilities, as opposed to + # weights. + logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits")) + def select_dataset(logits, seed): + return array_ops.squeeze( + stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) + selector_input = dataset_ops.Dataset.zip( + (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) + + return DirectedInterleaveDataset(selector_input, datasets) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index d736029fb03..f1d0e5cddc2 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -16,10 +16,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.training import saver +from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import session_run_hook def make_saveable_from_iterator(iterator): @@ -60,14 +62,14 @@ def make_saveable_from_iterator(iterator): return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access -class _Saveable(saver.BaseSaverBuilder.SaveableObject): +class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject): """SaveableObject for saving/restoring iterator state.""" def __init__(self, iterator_resource): serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) specs = [ - saver.BaseSaverBuilder.SaveSpec(serialized_iterator, "", - iterator_resource.name + "-state") + saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "", + iterator_resource.name + "-state") ] super(_Saveable, self).__init__(iterator_resource, specs, iterator_resource.name) @@ -75,3 +77,160 @@ class _Saveable(saver.BaseSaverBuilder.SaveableObject): def restore(self, restored_tensors, 
unused_restored_shapes): with ops.colocate_with(self.op): return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) + + +class CheckpointInputPipelineHook(session_run_hook.SessionRunHook): + """Checkpoints input pipeline state every N steps or seconds. + + This hook saves the state of the iterators in the `Graph` so that when + training is resumed the input pipeline continues from where it left off. + This could potentially avoid overfitting in certain pipelines where the + number of training steps per eval are small compared to the dataset + size or if the training pipeline is pre-empted. + + Differences from `CheckpointSaverHook`: + 1. Saves only the input pipelines in the "iterators" collection and not the + global variables or other saveable objects. + 2. Does not write the `GraphDef` and `MetaGraphDef` to the summary. + + Example of checkpointing the training pipeline: + + ```python + est = tf.estimator.Estimator(model_fn) + while True: + est.train( + train_input_fn, + hooks=[tf.contrib.data.CheckpointInputPipelineHook(est)], + steps=train_steps_per_eval) + # Note: We do not pass the hook here. + metrics = est.evaluate(eval_input_fn) + if should_stop_the_training(metrics): + break + ``` + + This hook should be used if the input pipeline state needs to be saved + separate from the model checkpoint. Doing so may be useful for a few reasons: + 1. The input pipeline checkpoint may be large, if there are large shuffle + or prefetch buffers for instance, and may bloat the checkpoint size. + 2. If the input pipeline is shared between training and validation, restoring + the checkpoint during validation may override the validation input + pipeline. + + For saving the input pipeline checkpoint alongside the model weights use + @{tf.contrib.data.make_saveable_from_iterator} directly to create a + `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however, + that you will need to be careful not to restore the training iterator during + eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS + collector when building the eval graph. + """ + + def __init__(self, estimator): + """Initializes a `CheckpointInputPipelineHook`. + + Args: + estimator: Estimator. + + Raises: + ValueError: One of `save_steps` or `save_secs` should be set. + ValueError: At most one of saver or scaffold should be set. + """ + # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or + # of the form "input__.ckpt" for distributed pipelines. + # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is + # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix + # to be different to avoid conflicts with the model checkpoint. + + # pylint: disable=protected-access + checkpoint_prefix = "input" + if estimator._config.num_worker_replicas > 1: + # Distributed setting. + suffix = "_{}_{}".format(estimator._config.task_type, + estimator._config.task_id) + checkpoint_prefix += suffix + # pylint: enable=protected-access + + # We use a composition paradigm instead of inheriting from + # `CheckpointSaverHook` because `Estimator` does an `isinstance` check + # to check whether a `CheckpointSaverHook` is already present in the list + # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook` + # would thwart this behavior. This hook checkpoints *only the iterators* + # and not the graph variables. 
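The docstring above points to `make_saveable_from_iterator` as the alternative when iterator state should live in the model checkpoint itself rather than in a separate one. A hedged sketch of that path (training graph only, per the caveat above; the dataset is made up):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(100)
iterator = ds.make_initializable_iterator()

# Record the iterator state as a SaveableObject so tf.train.Saver
# checkpoints it alongside the model variables. Do NOT add this in the
# eval graph, or restoring the model will clobber the eval pipeline.
saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

saver = tf.train.Saver()  # picks up SAVEABLE_OBJECTS automatically
```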
+ self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook( + estimator.model_dir, + save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access + save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access + checkpoint_basename=checkpoint_prefix + ".ckpt") + + # Name for the protocol buffer file that will contain the list of most + # recent checkpoints stored as a `CheckpointState` protocol buffer. + # This file, kept in the same directory as the checkpoint files, is + # automatically managed by the `Saver` to keep track of recent checkpoints. + # The default name used by the `Saver` for this file is "checkpoint". Here + # we use the name "checkpoint_" so that in case the + # `checkpoint_dir` is the same as the model checkpoint directory, there are + # no conflicts during restore. + self._latest_filename = "checkpoint_" + checkpoint_prefix + + def begin(self): + # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS` + # collection if no `Saver` or `Scaffold` is provided. + # pylint: disable=protected-access + if (self._checkpoint_saver_hook._saver is None and + self._checkpoint_saver_hook._scaffold is None): + iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS) + saveables = [_Saveable(i) for i in iterators] + self._checkpoint_saver_hook._saver = _CustomSaver(saveables, + self._latest_filename) + # pylint: enable=protected-access + self._checkpoint_saver_hook.begin() + + def after_create_session(self, session, coord): + # Check if there is an existing checkpoint. If so, restore from it. + # pylint: disable=protected-access + latest_checkpoint_path = saver_lib.latest_checkpoint( + self._checkpoint_saver_hook._checkpoint_dir, + latest_filename=self._latest_filename) + if latest_checkpoint_path: + self._checkpoint_saver_hook._get_saver().restore(session, + latest_checkpoint_path) + else: + # The checkpoint saved here is the state at step "global_step". + # Note: We do not save the GraphDef or MetaGraphDef here. + global_step = session.run(self._checkpoint_saver_hook._global_step_tensor) + self._checkpoint_saver_hook._save(session, global_step) + self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step) + # pylint: enable=protected-access + + def before_run(self, run_context): + return self._checkpoint_saver_hook.before_run(run_context) + + def after_run(self, run_context, run_values): + self._checkpoint_saver_hook.after_run(run_context, run_values) + + def end(self, session): + self._checkpoint_saver_hook.end(session) + + +class _CustomSaver(saver_lib.Saver): + """`Saver` with a different default `latest_filename`. + + This is used in the `CheckpointInputPipelineHook` to avoid conflicts with + the model ckpt saved by the `CheckpointSaverHook`. 
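For reference, the `latest_filename` mechanism the custom saver builds on is stock `tf.train.Saver` behavior; a small sketch (the checkpoint directory is hypothetical and must exist):

```python
import tensorflow as tf

v = tf.Variable(0, name="v")
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The CheckpointState proto is written to "checkpoint_input" instead of
    # the default "checkpoint", so it cannot collide with the model's file.
    saver.save(sess, "/tmp/ckpt_dir/input.ckpt",
               latest_filename="checkpoint_input")
    print(tf.train.latest_checkpoint("/tmp/ckpt_dir",
                                     latest_filename="checkpoint_input"))
```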
+ """ + + def __init__(self, var_list, latest_filename): + super(_CustomSaver, self).__init__(var_list) + self._latest_filename = latest_filename + + def save(self, + sess, + save_path, + global_step=None, + latest_filename=None, + meta_graph_suffix="meta", + write_meta_graph=True, + write_state=True, + strip_default_attrs=False): + return super(_CustomSaver, self).save( + sess, save_path, global_step, latest_filename or self._latest_filename, + meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs) diff --git a/tensorflow/contrib/data/python/ops/iterator_ops_test.py b/tensorflow/contrib/data/python/ops/iterator_ops_test.py new file mode 100644 index 00000000000..30a993b1f70 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/iterator_ops_test.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for experimental iterator_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import training_util + + +class CheckpointInputPipelineHookTest(test.TestCase): + + @staticmethod + def _model_fn(features, labels, mode, config): + del labels + del mode + del config + global_step = training_util.get_or_create_global_step() + update_global_step_op = global_step.assign_add(1) + latest_feature = variables.Variable( + 0, name='latest_feature', dtype=dtypes.int64) + store_latest_feature_op = latest_feature.assign(features) + ops.add_to_collection('my_vars', global_step) + ops.add_to_collection('my_vars', latest_feature) + return model_fn.EstimatorSpec( + mode='train', + train_op=control_flow_ops.group( + [update_global_step_op, store_latest_feature_op]), + loss=constant_op.constant(2.0)) + + def _read_vars(self, model_dir): + """Returns (global_step, latest_feature).""" + with ops.Graph().as_default() as g: + ckpt_path = saver_lib.latest_checkpoint(model_dir) + meta_filename = ckpt_path + '.meta' + saver_lib.import_meta_graph(meta_filename) + saver = saver_lib.Saver() + with self.test_session(graph=g) as sess: + saver.restore(sess, ckpt_path) + return sess.run(ops.get_collection('my_vars')) + + def _build_iterator_saver_hook(self, est): + return iterator_ops.CheckpointInputPipelineHook(est) + + def testReturnDatasetFromInputFn(self): + + def _input_fn(): + return 
dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testBuildIteratorInInputFn(self): + + def _input_fn(): + ds = dataset_ops.Dataset.range(10) + iterator = ds.make_one_shot_iterator() + return iterator.get_next() + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + + def testDoNotRestore(self): + + def _input_fn(): + return dataset_ops.Dataset.range(10) + + est = estimator.Estimator(model_fn=self._model_fn) + + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1)) + est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3)) + # Hook not provided, input pipeline was not restored. + est.train(_input_fn, steps=2) + self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1)) + + def testRaiseErrorIfNoIterator(self): + + def _input_fn(): + return constant_op.constant(1, dtype=dtypes.int64) + + est = estimator.Estimator(model_fn=self._model_fn) + + with self.assertRaises(ValueError): + est.train( + _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 89c04dc89a2..e4c9f8b58a2 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -114,11 +114,13 @@ class _PrefetchToDeviceIterator(object): ret = remote_iterator.get_next() return nest.flatten(sparse.serialize_sparse_tensors(ret)) + iterator_device = gen_dataset_ops.iterator_get_device( + self._input_iterator._iterator_resource) + with ops.device(device): self._buffering_resource = function_buffering_resource( f=_prefetch_fn, - target_device=gen_dataset_ops.iterator_get_device( - self._input_iterator._iterator_resource), + target_device=iterator_device, string_arg=input_iterator_handle, buffer_size=buffer_size, shared_name=shared_name) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 4ec8ae1c79d..75c31a944a0 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -18,15 +18,16 @@ from __future__ import division from __future__ import print_function import csv -from math import ceil import numpy as np from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from 
tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -34,9 +35,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import string_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation @@ -68,7 +67,7 @@ def _is_valid_float(str_val, float_dtype): return False -def _infer_type(str_val, na_value, prev_type, float_dtype): +def _infer_type(str_val, na_value, prev_type): """Given a string, infers its tensor type. Infers the type of a value by picking the least 'permissive' type possible, @@ -79,29 +78,34 @@ def _infer_type(str_val, na_value, prev_type, float_dtype): na_value: Additional string to recognize as a NA/NaN CSV value. prev_type: Type previously inferred based on values of this column that we've seen up till now. - float_dtype: Either `tf.float32` or `tf.float64`. Denotes what float type - to parse float strings as. Returns: Inferred dtype. """ if str_val in ("", na_value): + # If the field is null, it gives no extra information about its type return prev_type - if _is_valid_int32(str_val) and prev_type in (None, dtypes.int32): - return dtypes.int32 + type_list = [ + dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string + ] # list of types to try, ordered from least permissive to most - if _is_valid_int64(str_val) and prev_type in (None, dtypes.int32, - dtypes.int64): - return dtypes.int64 + type_functions = [ + _is_valid_int32, + _is_valid_int64, + lambda str_val: _is_valid_float(str_val, dtypes.float32), + lambda str_val: _is_valid_float(str_val, dtypes.float64), + lambda str_val: True, + ] # Corresponding list of validation functions - if _is_valid_float(str_val, float_dtype) and prev_type != dtypes.string: - return float_dtype - - return dtypes.string + for i in range(len(type_list)): + validation_fn = type_functions[i] + if validation_fn(str_val) and (prev_type is None or + prev_type in type_list[:i + 1]): + return type_list[i] -def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, - comment): +def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header): + """Generator that yields rows of CSV file(s) in order.""" for fn in filenames: with file_io.FileIO(fn, "r") as f: rdr = csv.reader( @@ -112,9 +116,6 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, next(rdr) # Skip header lines for csv_row in rdr: - if comment is not None and csv_row[0].startswith(comment): - continue # Skip comment lines - if len(csv_row) != num_cols: raise ValueError( "Problem inferring types: CSV row has different number of fields " @@ -123,22 +124,21 @@ def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim, - na_value, header, comment, float_dtype, - num_rows_for_inference, select_columns): + na_value, header, num_rows_for_inference, + select_columns): """Infers column types from the first N valid CSV records of files.""" if select_columns is None: select_columns = range(num_cols) inferred_types = [None] * len(select_columns) for i, csv_row in enumerate( - _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header, - comment)): + _next_csv_row(filenames, num_cols, 
field_delim, use_quote_delim, header)): if num_rows_for_inference is not None and i >= num_rows_for_inference: break for j, col_index in enumerate(select_columns): inferred_types[j] = _infer_type(csv_row[col_index], na_value, - inferred_types[j], float_dtype) + inferred_types[j]) # Replace None's with a default type inferred_types = [t or dtypes.string for t in inferred_types] @@ -156,12 +156,21 @@ def _infer_column_names(filenames, field_delim, use_quote_delim): "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE } with file_io.FileIO(filenames[0], "r") as f: - column_names = next(csv.reader(f, **csv_kwargs)) + try: + column_names = next(csv.reader(f, **csv_kwargs)) + except StopIteration: + raise ValueError(("Received StopIteration when reading the header line " + "of %s. Empty file?") % filenames[0]) for name in filenames[1:]: with file_io.FileIO(name, "r") as f: - if next(csv.reader(f, **csv_kwargs)) != column_names: - raise ValueError("Files have different column names in the header row.") + try: + if next(csv.reader(f, **csv_kwargs)) != column_names: + raise ValueError( + "Files have different column names in the header row.") + except StopIteration: + raise ValueError(("Received StopIteration when reading the header line " + "of %s. Empty file?") % filenames[0]) return column_names @@ -189,6 +198,112 @@ def _get_sorted_col_indices(select_columns, column_names): return result +def _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed): + """Optionally shuffle and repeat dataset, as requested.""" + if num_epochs != 1 and shuffle: + # Use shuffle_and_repeat for perf + return dataset.apply( + shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, + shuffle_seed)) + elif shuffle: + return dataset.shuffle(shuffle_buffer_size, shuffle_seed) + elif num_epochs != 1: + return dataset.repeat(num_epochs) + return dataset + + +def make_tf_record_dataset( + file_pattern, + batch_size, + parser_fn=None, + num_epochs=None, + shuffle=True, + shuffle_buffer_size=None, + shuffle_seed=None, + prefetch_buffer_size=None, + num_parallel_reads=None, + num_parallel_parser_calls=None, + drop_final_batch=False): + """Reads and optionally parses TFRecord files into a dataset. + + Provides common functionality such as batching, optional parsing, shuffling, + and performant defaults. + + Args: + file_pattern: List of files or patterns of TFRecord file paths. + See @{tf.gfile.Glob} for pattern rules. + batch_size: An int representing the number of records to combine + in a single batch. + parser_fn: (Optional.) A function accepting string input to parse + and process the record contents. This function must map records + to components of a fixed shape, so they may be batched. By + default, uses the record contents unmodified. + num_epochs: (Optional.) An int specifying the number of times this + dataset is repeated. If None (the default), cycles through the + dataset forever. + shuffle: (Optional.) A bool that indicates whether the input + should be shuffled. Defaults to `True`. + shuffle_buffer_size: (Optional.) Buffer size to use for + shuffling. A large buffer size ensures better shuffling, but + increases memory usage and startup time. + shuffle_seed: (Optional.) Randomization seed to use for shuffling. + prefetch_buffer_size: (Optional.) An int specifying the number of + feature batches to prefetch for performance improvement. + Defaults to auto-tune. Set to 0 to disable prefetching. + num_parallel_reads: (Optional.) 
Number of threads used to read + records from files. By default or if set to a value >1, the + results will be interleaved. + num_parallel_parser_calls: (Optional.) Number of parallel + records to parse in parallel. Defaults to an automatic selection. + drop_final_batch: (Optional.) Whether the last batch should be + dropped in case its size is smaller than `batch_size`; the + default behavior is not to drop the smaller batch. + + Returns: + A dataset, where each element matches the output of `parser_fn` + except it will have an additional leading `batch-size` dimension, + or a `batch_size`-length 1-D tensor of strings if `parser_fn` is + unspecified. + """ + files = dataset_ops.Dataset.list_files( + file_pattern, shuffle=shuffle, seed=shuffle_seed) + + if num_parallel_reads is None: + # Note: We considered auto-tuning this value, but there is a concern + # that this affects the mixing of records from different files, which + # could affect training convergence/accuracy, so we are defaulting to + # a constant for now. + num_parallel_reads = 24 + dataset = core_readers.TFRecordDataset( + files, num_parallel_reads=num_parallel_reads) + + if shuffle_buffer_size is None: + # TODO(josh11b): Auto-tune this value when not specified + shuffle_buffer_size = 10000 + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + + if parser_fn is None: + if drop_final_batch: + dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) + else: + dataset = dataset.batch(batch_size) + else: + # TODO(josh11b): if num_parallel_parser_calls is None, use some function + # of num cores instead of map_and_batch's default behavior of one batch. + dataset = dataset.apply(batching.map_and_batch( + parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls, + drop_remainder=drop_final_batch)) + + if prefetch_buffer_size is None: + prefetch_buffer_size = -1 # tf.config.data.AUTOTUNE + if prefetch_buffer_size == 0: + return dataset + else: + return dataset.prefetch(buffer_size=prefetch_buffer_size) + + def make_csv_dataset( file_pattern, batch_size, @@ -200,7 +315,6 @@ def make_csv_dataset( use_quote_delim=True, na_value="", header=True, - comment=None, num_epochs=None, shuffle=True, shuffle_buffer_size=10000, @@ -209,7 +323,6 @@ def make_csv_dataset( num_parallel_reads=1, num_parallel_parser_calls=2, sloppy=False, - default_float_type=dtypes.float32, num_rows_for_inference=100, ): """Reads CSV files into a dataset. @@ -222,8 +335,8 @@ def make_csv_dataset( Args: file_pattern: List of files or patterns of file paths containing CSV records. See @{tf.gfile.Glob} for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. column_names: An optional list of strings that corresponds to the CSV columns, in order. One per column of the input record. If this is not provided, infers the column names from the first row of the records. @@ -263,15 +376,11 @@ def make_csv_dataset( header: A bool that indicates whether the first rows of provided CSV files correspond to header lines with column names, and should not be included in the data. - comment: An optional character string that marks lines that should not be - parsed as csv records. If this is provided, all lines that start with - this character will not be parsed. num_epochs: An int specifying the number of times this dataset is repeated. 
If None, cycles through the dataset forever. shuffle: A bool that indicates whether the input should be shuffled. shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size - ensures better shuffling, but would increase memory usage and startup - time. + ensures better shuffling, but increases memory usage and startup time. shuffle_seed: Randomization seed to use for shuffling. prefetch_buffer_size: An int specifying the number of feature batches to prefetch for performance improvement. Recommended value is the number of @@ -285,8 +394,6 @@ def make_csv_dataset( produced is deterministic prior to shuffling (elements are still randomized if `shuffle=True`. Note that if the seed is set, then order of elements after shuffling is deterministic). Defaults to `False`. - default_float_type: Either `tf.float32` or `tf.float64`. If defaults are - not provided, float-like strings are interpreted to be this type. num_rows_for_inference: Number of rows of a file to use for type inference if record_defaults is not provided. If None, reads all the rows of all the files. Defaults to 100. @@ -308,8 +415,6 @@ def make_csv_dataset( dataset = dataset.shuffle(len(filenames), shuffle_seed) # Clean arguments; figure out column names and defaults - if comment is not None and len(comment) != 1: - raise ValueError("`comment` arg must be a single-character string or None") if column_names is None: if not header: @@ -332,8 +437,7 @@ def make_csv_dataset( # construction time column_defaults = _infer_column_defaults( filenames, len(column_names), field_delim, use_quote_delim, na_value, - header, comment, default_float_type, num_rows_for_inference, - select_columns) + header, num_rows_for_inference, select_columns) if select_columns is not None and len(column_defaults) != len(select_columns): raise ValueError( @@ -347,71 +451,189 @@ def make_csv_dataset( if label_name is not None and label_name not in column_names: raise ValueError("`label_name` provided must be one of the columns.") - # Define map and filter functions - def filter_fn(line): - return math_ops.not_equal(string_ops.substr(line, 0, 1), comment) - def filename_to_dataset(filename): - ds = core_readers.TextLineDataset(filename) - if header: - ds = ds.skip(1) - if comment is not None: - ds = ds.filter(filter_fn) - return ds + return CsvDataset( + filename, + record_defaults=column_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + na_value=na_value, + select_cols=select_columns, + header=header) - def decode_csv(line): - """Decodes CSV line into features. + def map_fn(*columns): + """Organizes columns into a features dictionary. Args: - line: String tensor corresponding to one csv record. + *columns: list of `Tensor`s corresponding to one csv record. Returns: A dictionary of feature names to values for that particular record. If label_name is provided, extracts the label feature to be returned as the second element of the tuple. 
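For readers of the reader changes above and below, a hedged usage sketch of the two high-level helpers: the new `make_tf_record_dataset` and the reworked `make_csv_dataset` (which now infers column types itself). The file patterns, feature spec, and `label` column are made up for illustration:

```python
import tensorflow as tf

# Batched TFRecord input with per-record parsing (new helper in this change).
feature_spec = {"x": tf.FixedLenFeature([], tf.int64)}
records = tf.contrib.data.make_tf_record_dataset(
    file_pattern="/tmp/data/*.tfrecord",   # hypothetical files
    batch_size=32,
    parser_fn=lambda s: tf.parse_single_example(s, feature_spec),
    num_epochs=1)

# Batched CSV input; column names and types are inferred from the files.
csv = tf.contrib.data.make_csv_dataset(
    file_pattern="/tmp/data/*.csv",        # hypothetical files
    batch_size=32,
    label_name="label",                    # assumes a 'label' column exists
    num_epochs=1)
features, label = csv.make_one_shot_iterator().get_next()
```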
""" - columns = parsing_ops.decode_csv( - line, - column_defaults, - field_delim=field_delim, - use_quote_delim=use_quote_delim, - na_value=na_value, - select_cols=select_columns, - ) features = dict(zip(column_names, columns)) if label_name is not None: label = features.pop(label_name) return features, label return features - # Read files sequentially or in parallel + # Read files sequentially (if num_parallel_reads=1) or in parallel dataset = dataset.apply( interleave_ops.parallel_interleave( filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy)) - if num_epochs != 1 and shuffle: - # Use shuffle_and_repeat for perf - dataset = dataset.apply( - shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, - shuffle_seed)) - elif shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) - elif num_epochs != 1: - dataset = dataset.repeat(num_epochs) - - # Use map_and_batch for perf - # TODO(b/76425672): use num_parallel_calls for better performance tuning when - # that is added - dataset = dataset.apply( - batching.map_and_batch( - map_func=decode_csv, - batch_size=batch_size, - num_parallel_batches=int( - ceil(num_parallel_parser_calls / batch_size)))) + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + # Apply batch before map for perf, because map has high overhead relative + # to the size of the computation in each map + dataset = dataset.batch(batch_size=batch_size) + dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls) dataset = dataset.prefetch(prefetch_buffer_size) + return dataset +_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB + + +class CsvDataset(dataset_ops.Dataset): + """A Dataset comprising lines from one or more CSV files.""" + + def __init__(self, + filenames, + record_defaults, + buffer_size=None, + header=False, + field_delim=",", + use_quote_delim=True, + na_value="", + select_cols=None): + """Creates a `CsvDataset` by reading and decoding CSV files. + + The elements of this dataset correspond to records from the file(s). + RFC 4180 format is expected for CSV files + (https://tools.ietf.org/html/rfc4180) + Note that we allow leading and trailing spaces with int or float field. + + + For example, suppose we have a file 'my_file0.csv' with four CSV columns of + different data types: + ``` + abcdefg,4.28E10,5.55E6,12 + hijklmn,-5.3E14,,2 + ``` + + We can construct a CsvDataset from it as follows: + ```python + dataset = tf.contrib.data.CsvDataset( + "my_file*.csv", + [tf.float32, # Required field, use dtype or empty tensor + tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0 + tf.int32, # Required field, use dtype or empty tensor + ], + select_cols=[1,2,3] # Only parse last three columns + ) + ``` + + The expected output of its iterations is: + ```python + next = dataset.make_one_shot_iterator().get_next() + with tf.Session() as sess: + while True: + try: + print(sess.run(nxt)) + except tf.errors.OutOfRangeError: + break + + >> (4.28e10, 5.55e6, 12) + >> (-5.3e14, 0.0, 2) + ``` + + Args: + filenames: A `tf.string` tensor containing one or more filenames. + record_defaults: A list of default values for the CSV fields. Each item in + the list is either a valid CSV `DType` (float32, float64, int32, int64, + string), or a `Tensor` object with one of the above types. One per + column of CSV data, with either a scalar `Tensor` default value for the + column if it is optional, or `DType` or empty `Tensor` if required. 
If + both this and `select_columns` are specified, these must have the same + lengths, and `column_defaults` is assumed to be sorted in order of + increasing column index. + buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes + to buffer while reading files. Defaults to 4MB. + header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s) + have header line(s) that should be skipped when parsing. Defaults to + `False`. + field_delim: (Optional.) A `tf.string` scalar containing the delimiter + character that separates fields in a record. Defaults to `","`. + use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats + double quotation marks as regular characters inside of string fields + (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`. + na_value: (Optional.) A `tf.string` scalar indicating a value that will + be treated as NA/NaN. + select_cols: (Optional.) A sorted list of column indices to select from + the input data. If specified, only this subset of columns will be + parsed. Defaults to parsing all columns. + """ + super(CsvDataset, self).__init__() + self._filenames = ops.convert_to_tensor( + filenames, dtype=dtypes.string, name="filenames") + record_defaults = [ + constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x + for x in record_defaults + ] + self._record_defaults = ops.convert_n_to_tensor( + record_defaults, name="record_defaults") + self._buffer_size = convert.optional_param_to_tensor( + "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES) + self._header = ops.convert_to_tensor( + header, dtype=dtypes.bool, name="header") + self._field_delim = ops.convert_to_tensor( + field_delim, dtype=dtypes.string, name="field_delim") + self._use_quote_delim = ops.convert_to_tensor( + use_quote_delim, dtype=dtypes.bool, name="use_quote_delim") + self._na_value = ops.convert_to_tensor( + na_value, dtype=dtypes.string, name="na_value") + self._select_cols = convert.optional_param_to_tensor( + "select_cols", + select_cols, + argument_default=[], + argument_dtype=dtypes.int64, + ) + self._output_shapes = tuple( + tensor_shape.scalar() for _ in range(len(record_defaults))) + self._output_types = tuple(d.dtype for d in self._record_defaults) + self._output_classes = tuple( + ops.Tensor for _ in range(len(record_defaults))) + + def _as_variant_tensor(self): + # Constructs graph node for the dataset op. + return contrib_gen_dataset_ops.csv_dataset( + filenames=self._filenames, + record_defaults=self._record_defaults, + buffer_size=self._buffer_size, + header=self._header, + output_shapes=self._output_shapes, + field_delim=self._field_delim, + use_quote_delim=self._use_quote_delim, + na_value=self._na_value, + select_cols=self._select_cols, + ) + + @property + def output_types(self): + return self._output_types + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_classes(self): + return self._output_classes + + def make_batched_features_dataset(file_pattern, batch_size, features, @@ -471,8 +693,8 @@ def make_batched_features_dataset(file_pattern, Args: file_pattern: List of files or patterns of file paths containing `Example` records. See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. 
See `tf.parse_example`. reader: A function or class that can be @@ -528,16 +750,8 @@ def make_batched_features_dataset(file_pattern, dataset = dataset.map(lambda _, v: v) # Apply dataset repeat and shuffle transformations. - repeat_dataset = (num_epochs != 1) - if repeat_dataset and shuffle: - # Used fused shuffle_and_repeat operation for better performance - dataset = dataset.apply( - shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs, - shuffle_seed)) - elif repeat_dataset: - dataset = dataset.repeat(num_epochs) - elif shuffle: - dataset = dataset.shuffle(shuffle_buffer_size, shuffle_seed) + dataset = _maybe_shuffle_and_repeat( + dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) if drop_final_batch: dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) @@ -611,8 +825,8 @@ def read_batch_features(file_pattern, Args: file_pattern: List of files or patterns of file paths containing `Example` records. See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. + batch_size: An int representing the number of records to combine + in a single batch. features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. reader: A function or class that can be diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index b465397437a..bad6edd5147 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -20,10 +20,12 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import logging_ops @@ -50,80 +52,182 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): A `Dataset` transformation function, which can be passed to @{tf.data.Dataset.apply}. """ - def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - dist_estimation_batch_size = 32 target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist") class_values_ds = dataset.map(class_func) + + # Get initial distribution. 
if initial_dist is not None: initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") - acceptance_dist = _calculate_acceptance_probs(initial_dist_t, - target_dist_t) + acceptance_dist, prob_of_original = ( + _calculate_acceptance_probs_with_mixing(initial_dist_t, + target_dist_t)) initial_dist_ds = dataset_ops.Dataset.from_tensors( initial_dist_t).repeat() acceptance_dist_ds = dataset_ops.Dataset.from_tensors( acceptance_dist).repeat() + prob_of_original_ds = dataset_ops.Dataset.from_tensors( + prob_of_original).repeat() else: - num_classes = (target_dist_t.shape[0].value or - array_ops.shape(target_dist_t)[0]) - smoothing_constant = 10 - initial_examples_per_class_seen = array_ops.fill( - [num_classes], np.int64(smoothing_constant)) - - def update_estimate_and_tile(num_examples_per_class_seen, c): - updated_examples_per_class_seen, dist = _estimate_data_distribution( - c, num_examples_per_class_seen) - tiled_dist = array_ops.tile( - array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) - return updated_examples_per_class_seen, tiled_dist - - initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) - .apply(scan_ops.scan(initial_examples_per_class_seen, - update_estimate_and_tile)) - .apply(batching.unbatch())) - acceptance_dist_ds = initial_dist_ds.map( - lambda initial: _calculate_acceptance_probs(initial, target_dist_t)) - - def maybe_warn_on_large_rejection(accept_dist, initial_dist): - proportion_rejected = math_ops.reduce_sum( - (1 - accept_dist) * initial_dist) - return control_flow_ops.cond( - math_ops.less(proportion_rejected, .5), - lambda: accept_dist, - lambda: logging_ops.Print( # pylint: disable=g-long-lambda - accept_dist, [proportion_rejected, initial_dist, accept_dist], - message="Proportion of examples rejected by sampler is high: ", - summarize=100, - first_n=10)) - - acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, - initial_dist_ds)) - .map(maybe_warn_on_large_rejection)) - - def _gather_and_copy(class_val, acceptance_prob, data): - return (class_val, array_ops.gather(acceptance_prob, class_val), data) - current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( - (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) - filtered_ds = ( - current_probabilities_and_class_and_data_ds - .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) - return filtered_ds.map(lambda class_value, _, data: (class_value, data)) + initial_dist_ds = _estimate_initial_dist_ds( + target_dist_t, class_values_ds) + acceptance_and_original_prob_ds = initial_dist_ds.map( + lambda initial: _calculate_acceptance_probs_with_mixing( + initial, target_dist_t)) + acceptance_dist_ds = acceptance_and_original_prob_ds.map( + lambda accept_prob, _: accept_prob) + prob_of_original_ds = acceptance_and_original_prob_ds.map( + lambda _, prob_original: prob_original) + filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, + class_values_ds, seed) + # Prefetch filtered dataset for speed. 
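# Illustrative usage sketch (not part of this patch): applying the
# `rejection_resample` transformation being built here to steer an imbalanced
# labeled dataset toward a uniform class distribution. The input data below is
# made up for illustration.
import tensorflow as tf

examples = tf.data.Dataset.from_tensor_slices(
    (tf.random_uniform([100, 4]),
     tf.random_uniform([100], maxval=2, dtype=tf.int32)))
resampled = examples.apply(rejection_resample(
    class_func=lambda _, label: label,   # map each element to its class
    target_dist=[0.5, 0.5],
    seed=42))
# Each element of `resampled` is a (class value, original element) pair.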
+ filtered_ds = filtered_ds.prefetch(3) + prob_original_static = _get_prob_original_static( + initial_dist_t, target_dist_t) if initial_dist is not None else None + if prob_original_static == 1: + return dataset_ops.Dataset.zip((class_values_ds, dataset)) + elif prob_original_static == 0: + return filtered_ds + else: + return interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds], + weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]), + seed=seed) return _apply_fn -def _calculate_acceptance_probs(initial_probs, target_probs): - """Calculate the per-class acceptance rates. +def _get_prob_original_static(initial_dist_t, target_dist_t): + """Returns the static probability of sampling from the original. + + `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters + an Op that it isn't defined for. We have some custom logic to avoid this. Args: - initial_probs: The class probabilities of the data. - target_probs: The desired class proportion in minibatches. - Returns: - A list of the per-class acceptance probabilities. + initial_dist_t: A tensor of the initial distribution. + target_dist_t: A tensor of the target distribution. - This method is based on solving the following analysis: + Returns: + The probability of sampling from the original distribution as a constant, + if it is a constant, or `None`. + """ + init_static = tensor_util.constant_value(initial_dist_t) + target_static = tensor_util.constant_value(target_dist_t) + + if init_static is None or target_static is None: + return None + else: + return np.min(target_static / init_static) + + +def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds, + seed): + """Filters a dataset based on per-class acceptance probabilities. + + Args: + dataset: The dataset to be filtered. + acceptance_dist_ds: A dataset of acceptance probabilities. + initial_dist_ds: A dataset of the initial probability distribution, given or + estimated. + class_values_ds: A dataset of the corresponding classes. + seed: (Optional.) Python integer seed for the resampler. + + Returns: + A dataset of (class value, data) after filtering. 
+ """ + def maybe_warn_on_large_rejection(accept_dist, initial_dist): + proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist) + return control_flow_ops.cond( + math_ops.less(proportion_rejected, .5), + lambda: accept_dist, + lambda: logging_ops.Print( # pylint: disable=g-long-lambda + accept_dist, [proportion_rejected, initial_dist, accept_dist], + message="Proportion of examples rejected by sampler is high: ", + summarize=100, + first_n=10)) + + acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, + initial_dist_ds)) + .map(maybe_warn_on_large_rejection)) + + def _gather_and_copy(class_val, acceptance_prob, data): + return class_val, array_ops.gather(acceptance_prob, class_val), data + + current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip( + (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy) + filtered_ds = ( + current_probabilities_and_class_and_data_ds + .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) + return filtered_ds.map(lambda class_value, _, data: (class_value, data)) + + +def _estimate_initial_dist_ds( + target_dist_t, class_values_ds, dist_estimation_batch_size=32, + smoothing_constant=10): + num_classes = (target_dist_t.shape[0].value or + array_ops.shape(target_dist_t)[0]) + initial_examples_per_class_seen = array_ops.fill( + [num_classes], np.int64(smoothing_constant)) + + def update_estimate_and_tile(num_examples_per_class_seen, c): + updated_examples_per_class_seen, dist = _estimate_data_distribution( + c, num_examples_per_class_seen) + tiled_dist = array_ops.tile( + array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) + return updated_examples_per_class_seen, tiled_dist + + initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) + .apply(scan_ops.scan(initial_examples_per_class_seen, + update_estimate_and_tile)) + .apply(batching.unbatch())) + + return initial_dist_ds + + +def _get_target_to_initial_ratio(initial_probs, target_probs): + # Add tiny to initial_probs to avoid divide by zero. + denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) + return target_probs / denom + + +def _estimate_data_distribution(c, num_examples_per_class_seen): + """Estimate data distribution as labels are seen. + + Args: + c: The class labels. Type `int32`, shape `[batch_size]`. + num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, + containing counts. + + Returns: + num_examples_per_lass_seen: Updated counts. Type `int64`, shape + `[num_classes]`. + dist: The updated distribution. Type `float32`, shape `[num_classes]`. + """ + num_classes = num_examples_per_class_seen.get_shape()[0].value + # Update the class-count based on what labels are seen in batch. + num_examples_per_class_seen = math_ops.add( + num_examples_per_class_seen, math_ops.reduce_sum( + array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) + init_prob_estimate = math_ops.truediv( + num_examples_per_class_seen, + math_ops.reduce_sum(num_examples_per_class_seen)) + dist = math_ops.cast(init_prob_estimate, dtypes.float32) + return num_examples_per_class_seen, dist + + +def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): + """Calculates the acceptance probabilities and mixing ratio. + + In this case, we assume that we can *either* sample from the original data + distribution with probability `m`, or sample from a reshaped distribution + that comes from rejection sampling on the original distribution. 
This + rejection sampling is done on a per-class basis, with `a_i` representing the + probability of accepting data from class `i`. + + This method is based on solving the following analysis for the reshaped + distribution: Let F be the probability of a rejection (on any example). Let p_i be the proportion of examples in the data in class i (init_probs) @@ -152,39 +256,39 @@ def _calculate_acceptance_probs(initial_probs, target_probs): 0 <= t_i <= 1, sum_i(t_i) = 1 ``` - A solution for a_i in terms of the other variables is the following: ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` - """ - # Add tiny to initial_probs to avoid divide by zero. - denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) - ratio_l = target_probs / denom - # Calculate list of acceptance probabilities. - max_ratio = math_ops.reduce_max(ratio_l) - return ratio_l / max_ratio + If we try to minimize the amount of data rejected, we get the following: + M_max = max_i [ t_i / p_i ] + M_min = min_i [ t_i / p_i ] -def _estimate_data_distribution(c, num_examples_per_class_seen): - """Estimate data distribution as labels are seen. + The desired probability of accepting data if it comes from class `i`: + + a_i = (t_i/p_i - m) / (M_max - m) + + The desired probability of pulling a data element from the original dataset, + rather than the filtered one: + + m = M_min Args: - c: The class labels. Type `int32`, shape `[batch_size]`. - num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, - containing counts. + initial_probs: A Tensor of the initial probability distribution, given or + estimated. + target_probs: A Tensor of the corresponding classes. Returns: - num_examples_per_lass_seen: Updated counts. Type `int64`, shape - `[num_classes]`. - dist: The updated distribution. Type `float32`, shape `[num_classes]`. + (A 1D Tensor with the per-class acceptance probabilities, the desired + probability of pull from the original distribution.) """ - num_classes = num_examples_per_class_seen.get_shape()[0].value - # Update the class-count based on what labels are seen in batch. - num_examples_per_class_seen = math_ops.add( - num_examples_per_class_seen, math_ops.reduce_sum( - array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) - init_prob_estimate = math_ops.truediv( - num_examples_per_class_seen, - math_ops.reduce_sum(num_examples_per_class_seen)) - dist = math_ops.cast(init_prob_estimate, dtypes.float32) - return num_examples_per_class_seen, dist + ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs) + max_ratio = math_ops.reduce_max(ratio_l) + min_ratio = math_ops.reduce_min(ratio_l) + + # Target prob to sample from original distribution. + m = min_ratio + + # TODO(joelshor): Simplify fraction, if possible. 
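# Worked example (illustrative, not part of this patch): for
# initial_probs p = [0.5, 0.5] and target_probs t = [0.8, 0.2],
#   ratio_l = t / p = [1.6, 0.4],  m = min(ratio_l) = 0.4,
#   a_i = (ratio_l - m) / (max(ratio_l) - m) = [1.0, 0.0].
# Drawing from the original data with probability m = 0.4 and otherwise from
# the rejection-filtered stream (which here accepts only class 0) yields
#   0.4 * [0.5, 0.5] + 0.6 * [1.0, 0.0] = [0.8, 0.2] = t,
# i.e. the mixture reproduces the target distribution exactly.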
+ a_i = (ratio_l - m) / (max_ratio - m) + return a_i, m \ No newline at end of file diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 1c88366273f..e911ad0fa05 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import gen_dataset_ops @@ -36,18 +37,22 @@ class _ScanDataset(dataset_ops.Dataset): self._input_dataset = input_dataset with ops.name_scope("initial_state"): + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. self._initial_state = nest.pack_sequence_as(initial_state, [ - ops.convert_to_tensor(t, name="component_%d" % i) + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor( + t, name="component_%d" % i) for i, t in enumerate(nest.flatten(initial_state)) ]) - # Compute initial values for the state shapes and types based on - # the initial state. These will be refined by running - # `tf_scan_func` one or more times below. - # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor. + # Compute initial values for the state classes, shapes and types based on + # the initial state. The shapes may be refined by running `tf_scan_func` one + # or more times below. + self._state_classes = sparse.get_classes(self._initial_state) self._state_shapes = nest.pack_sequence_as( self._initial_state, - [t.shape for t in nest.flatten(self._initial_state)]) + [t.get_shape() for t in nest.flatten(self._initial_state)]) self._state_types = nest.pack_sequence_as( self._initial_state, [t.dtype for t in nest.flatten(self._initial_state)]) @@ -57,72 +62,107 @@ class _ScanDataset(dataset_ops.Dataset): self._output_shapes = None self._output_types = None - # Iteratively rerun the scan function until reaching a fixed pont on + # Iteratively rerun the scan function until reaching a fixed point on # `self._state_shapes`. need_to_rerun = True while need_to_rerun: - flat_state_shapes = nest.flatten(self._state_shapes) - flat_state_types = nest.flatten(self._state_types) - - # Create a list in which `tf_scan_func` will store the s + # Create a list in which `tf_scan_func` will store the new shapes. flat_new_state_shapes = [] - @function.Defun(*(flat_state_types + nest.flatten( - sparse.as_dense_types(input_dataset.output_types, - input_dataset.output_classes)))) + @function.Defun(*(nest.flatten( + sparse.as_dense_types( + self._state_types, self._state_classes)) + nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes)))) def tf_scan_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the state and input_dataset. - # TODO(b/69424092): Check that neither inputs nor outputs are sparse. 
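# Illustrative usage sketch (not part of this patch). The contract enforced by
# `tf_scan_func` here is scan_func(old_state, input_element) ->
# (new_state, output_value); for example, a running total over a dataset:
#
#   import numpy as np
#   totals = dataset_ops.Dataset.range(5).apply(
#       scan_ops.scan(np.int64(0), lambda state, x: (state + x, state + x)))
#   # yields 0, 1, 3, 6, 10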
- dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, - input_dataset.output_classes) - for arg, shape in zip(args, - flat_state_shapes + nest.flatten(dense_shapes)): + for arg, shape in zip( + args, + nest.flatten( + sparse.as_dense_shapes(self._state_shapes, self._state_classes)) + + nest.flatten( + sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes))): arg.set_shape(shape) - pivot = len(flat_state_shapes) - old_state = nest.pack_sequence_as(self._initial_state, args[:pivot]) - input_value = nest.pack_sequence_as(input_dataset.output_types, - args[pivot:]) + pivot = len(nest.flatten(self._state_shapes)) + print(self._state_classes) + nested_state_args = nest.pack_sequence_as(self._state_types, + args[:pivot]) + nested_state_args = sparse.deserialize_sparse_tensors( + nested_state_args, self._state_types, self._state_shapes, + self._state_classes) + print(input_dataset.output_classes) + nested_input_args = nest.pack_sequence_as(input_dataset.output_types, + args[pivot:]) + nested_input_args = sparse.deserialize_sparse_tensors( + nested_input_args, input_dataset.output_types, + input_dataset.output_shapes, input_dataset.output_classes) - ret = scan_func(old_state, input_value) + ret = scan_func(nested_state_args, nested_input_args) if not isinstance(ret, collections.Sequence) or len(ret) != 2: raise TypeError("The scan function must return a pair comprising the " "new state and the output value.") + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) new_state, output_value = ret - flat_new_state = [ - ops.convert_to_tensor(t) for t in nest.flatten(new_state) - ] - flat_output_value = [ - ops.convert_to_tensor(t) for t in nest.flatten(output_value) - ] + # Extract and validate class information from the returned values. + for t, clazz in zip( + nest.flatten(new_state), nest.flatten(self._state_classes)): + if not isinstance(t, clazz): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_classes, + nest.pack_sequence_as( + self._state_types, + [type(t) for t in nest.flatten(new_state)]))) + self._output_classes = sparse.get_classes(output_value) # Extract shape information from the returned values. - flat_new_state_shapes.extend([t.shape for t in flat_new_state]) + flat_new_state_shapes.extend( + [t.get_shape() for t in nest.flatten(new_state)]) self._output_shapes = nest.pack_sequence_as( - output_value, [t.shape for t in flat_output_value]) + output_value, [t.get_shape() for t in nest.flatten(output_value)]) # Extract and validate type information from the returned values. - for t, dtype in zip(flat_new_state, flat_state_types): + for t, dtype in zip( + nest.flatten(new_state), nest.flatten(self._state_types)): if t.dtype != dtype: raise TypeError( "The element types for the new state must match the initial " "state. Expected %s; got %s." 
% - (self._state_types, nest.pack_sequence_as( - self._state_types, [t.dtype for t in flat_new_state]))) - self._output_classes = nest.pack_sequence_as( - output_value, [ops.Tensor for _ in flat_output_value]) + (self._state_types, + nest.pack_sequence_as( + self._state_types, + [t.dtype for t in nest.flatten(new_state)]))) self._output_types = nest.pack_sequence_as( - output_value, [t.dtype for t in flat_output_value]) + output_value, [t.dtype for t in nest.flatten(output_value)]) - return flat_new_state + flat_output_value + # Serialize any sparse tensors. + new_state = nest.pack_sequence_as(new_state, [ + t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state)) + ]) + output_value = nest.pack_sequence_as(output_value, [ + t for t in nest.flatten( + sparse.serialize_sparse_tensors(output_value)) + ]) + return nest.flatten(new_state) + nest.flatten(output_value) # Use the private method that will execute `tf_scan_func` but delay # adding it to the graph in case we need to rerun the function. tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access + flat_state_shapes = nest.flatten(self._state_shapes) weakened_state_shapes = [ original.most_specific_compatible_shape(new) for original, new in zip(flat_state_shapes, flat_new_state_shapes) @@ -144,12 +184,13 @@ class _ScanDataset(dataset_ops.Dataset): weakened_state_shapes) self._scan_func = tf_scan_func + self._scan_func.add_to_graph(ops.get_default_graph()) def _as_variant_tensor(self): input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access return gen_dataset_ops.scan_dataset( input_t, - nest.flatten(self._initial_state), + nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)), self._scan_func.captured_inputs, f=self._scan_func, output_types=nest.flatten( diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index b5cf0fcfe91..3cbaab5affd 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -85,32 +84,60 @@ class StatsAggregator(object): """ return gen_dataset_ops.stats_aggregator_summary(self._resource) - def subscribe(self, iterator): - """Returns a @{tf.Operation} to associate this aggregator with `iterator`. - Note: Each @{tf.data.Iterator} can be associated with at most one - `StatsAggregator`. After running the operation that this function - returns, all statistics recorded in the iteration of `iterator` - will be stored in `stats_aggregator`. +class _SetStatsAggregatorDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and sets given stats_aggregator.""" - Args: - iterator: A @{tf.data.Iterator} object. + def __init__(self, input_dataset, stats_aggregator): + super(_SetStatsAggregatorDataset, self).__init__() + self._input_dataset = input_dataset + self._stats_aggregator = stats_aggregator - Returns: - A @{tf.Operation} that, when run, associates this aggregator with - `iterator`. 
- """ - if not isinstance(iterator, iterator_ops.Iterator): - raise TypeError("`iterator` must be a `tf.data.Iterator` object.") - return gen_dataset_ops.iterator_set_stats_aggregator( - iterator._iterator_resource, self._resource) # pylint: disable=protected-access + def _as_variant_tensor(self): + return gen_dataset_ops.set_stats_aggregator_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._stats_aggregator._resource, # pylint: disable=protected-access + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`. +def set_stats_aggregator(stats_aggregator): + """Set the given stats_aggregator for aggregating the input dataset stats. + + Args: + stats_aggregator: A `StatsAggregator` object. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _SetStatsAggregatorDataset(dataset, stats_aggregator) + + return _apply_fn def bytes_produced_stats(tag): """Records the number of bytes produced by each element of the input dataset. - To consume the statistics, associate a `StatsAggregator` with an iterator - over the output dataset. + To consume the statistics, associate a `StatsAggregator` with the output + dataset. Args: tag: String. All statistics recorded by the returned transformation will @@ -131,8 +158,8 @@ def bytes_produced_stats(tag): def latency_stats(tag): """Records the latency of producing each element of the input dataset. - To consume the statistics, associate a `StatsAggregator` with an iterator - over the output dataset. + To consume the statistics, associate a `StatsAggregator` with the output + dataset. Args: tag: String. All statistics recorded by the returned transformation will diff --git a/tensorflow/contrib/data/python/ops/writers.py b/tensorflow/contrib/data/python/ops/writers.py new file mode 100644 index 00000000000..f53bd3f7383 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/writers.py @@ -0,0 +1,58 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Python wrappers for tf.data writers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class TFRecordWriter(object): + """Writes data to a TFRecord file.""" + + def __init__(self, filename, compression_type=None): + self._filename = ops.convert_to_tensor( + filename, dtypes.string, name="filename") + self._compression_type = convert.optional_param_to_tensor( + "compression_type", + compression_type, + argument_default="", + argument_dtype=dtypes.string) + + def write(self, dataset): + """Returns a @{tf.Operation} to write a dataset to a file. + + Args: + dataset: a @{tf.data.Dataset} whose elements are to be written to a file + + Returns: + A @{tf.Operation} that, when run, writes contents of `dataset` to a file. + """ + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + if (dataset.output_types != dtypes.string or + dataset.output_shapes != tensor_shape.scalar()): + raise TypeError( + "`dataset` must produce scalar `DT_STRING` tensors whereas it " + "produces shape {0} and types {1}".format(dataset.output_shapes, + dataset.output_types)) + return gen_dataset_ops.dataset_to_tf_record( + dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 14de1e8f491..44a4481021c 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -116,7 +116,8 @@ in the input function gives a solid boost in performance. When using ## Caveats This feature is in early stages and there are a lot of improvements forthcoming: -* Metrics are not yet supported during distributed training. +* Metrics are not yet supported during distributed training. They are still +supported during the evaluation. * Summaries are only computed in the first tower in `MirroredStrategy`. * Evaluation is not yet distributed. 
* Eager support is in the works; performance can be more challenging with eager diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 5aad21cccd3..64a77bbed1d 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -21,11 +21,11 @@ py_library( srcs = ["values.py"], visibility = ["//tensorflow:internal"], deps = [ + ":input_ops", ":prefetching_ops_v2", "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/contrib/eager/python:datasets", "//tensorflow/python:array_ops", - "//tensorflow/python:checkpointable", "//tensorflow/python:control_flow_ops", "//tensorflow/python:device_util", "//tensorflow/python:distribute", @@ -33,6 +33,7 @@ py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/eager:context", + "//tensorflow/python/training/checkpointable:base", "@six_archive//:six", ], ) @@ -42,6 +43,7 @@ cuda_py_test( srcs = ["values_test.py"], additional_deps = [ ":mirrored_strategy", + ":multi_worker_test_base", ":values", "//tensorflow/core:protos_all_py", "//tensorflow/python/data/ops:dataset_ops", @@ -57,6 +59,9 @@ cuda_py_test( "//tensorflow/python/eager:test", "//tensorflow/python/estimator:model_fn", ], + tags = [ + "no_pip", + ], ) py_library( @@ -81,6 +86,19 @@ py_library( ], ) +py_library( + name = "multi_worker_strategy", + srcs = ["multi_worker_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":mirrored_strategy", + ":values", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:training", + "//tensorflow/python:util", + ], +) + py_library( name = "one_device_strategy", srcs = ["one_device_strategy.py"], @@ -131,7 +149,9 @@ py_library( deps = [ ":mirrored_strategy", ":one_device_strategy", + ":tpu_strategy", "//tensorflow/contrib/optimizer_v2:training", + "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:util", @@ -215,6 +235,24 @@ cuda_py_test( ], ) +py_library( + name = "multi_worker_test_base", + testonly = 1, + srcs = ["multi_worker_test_base.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + ], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:distributed_framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python/eager:test", + ], +) + py_library( name = "step_fn", srcs = ["step_fn.py"], @@ -225,21 +263,50 @@ py_library( ], ) -cuda_py_test( - name = "minimize_loss_test", - srcs = ["minimize_loss_test.py"], - additional_deps = [ - ":combinations", - ":single_loss_example", - "@absl_py//absl/testing:parameterized", - "//third_party/py/numpy", +py_library( + name = "tpu_strategy", + srcs = ["tpu_strategy.py"], + visibility = ["//tensorflow:internal"], + deps = [ + ":one_device_strategy", + ":values", + "//tensorflow/contrib/tpu", + "//tensorflow/contrib/tpu:tpu_py", + "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + ], +) + +py_library( + name = "minimize_loss_test_lib", + testonly = 1, + srcs = ["minimize_loss_test.py"], + deps = [ + ":combinations", + ":mirrored_strategy", + ":single_loss_example", + "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", 
"//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", "//tensorflow/python/ops/losses", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +cuda_py_test( + name = "minimize_loss_test", + srcs = ["minimize_loss_test.py"], + additional_deps = [ + ":minimize_loss_test_lib", ], tags = [ "multi_and_single_gpu", @@ -297,6 +364,7 @@ py_library( srcs = ["single_loss_example.py"], deps = [ ":step_fn", + "//tensorflow/contrib/data/python/ops:batching", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:layers", @@ -402,24 +470,24 @@ py_library( ], ) -py_test( +cuda_py_test( name = "cross_tower_ops_test", srcs = ["cross_tower_ops_test.py"], - srcs_version = "PY2AND3", - tags = [ - "no_pip", - ], - deps = [ + additional_deps = [ ":combinations", ":cross_tower_ops", ":values", + "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", - "@absl_py//absl/testing:parameterized", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", ], ) @@ -448,3 +516,34 @@ cuda_py_test( "//tensorflow/python/data/ops:iterator_ops", ], ) + +py_library( + name = "input_ops", + srcs = ["input_ops.py"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/util:nest", + ], +) + +cuda_py_test( + name = "input_ops_test", + srcs = ["input_ops_test.py"], + additional_deps = [ + ":input_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:errors", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:io_ops", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python:util", + ], + tags = [ + "no_pip", + ], +) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 02b1e7ef9fc..514fd271477 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -45,16 +45,19 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python import tpu_strategy from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.training import adam +from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import gradient_descent from tensorflow.python.util import tf_inspect GPU_TEST = "test_gpu" in sys.argv[0] +TPU_TEST = "test_tpu" in sys.argv[0] def generate(combinations): @@ -67,26 +70,31 @@ def generate(combinations): -- there should always be a "mode" argument. Accepted values are "eager" and "graph". -- arguments of the test method must match by name to get the corresponding - value of the combination. Tests must accept all arguments (except "mode", - which is optional). - -- distribution argument is special. 
It is meant for passing instances of - DistributionStrategy. Each instance is to be passed as `(, - )` tuple, where is the number of required - GPUs. If the required number of GPUs for the DistributionStrategy isn't - available then the test case is going to be skipped. + value of the combination. Tests must accept all arguments except the + "mode", "required_tpu" and "required_gpus". + -- "distribution" argument is special and optional. It is meant for passing + instances of DistributionStrategy. Each instance is to be passed as via + `NamedDistribution`. If using "distribution", "required_gpus" and + "required_tpu" should be specified via the NamedDistribution instance, + rather than as separate arguments. + -- "required_tpu" argument is special and optional. If not `None`, then the + test will be skipped if TPUs aren't available. + -- "required_gpus" argument is special and optional. If not `None`, then the + test will be skipped if the specified number of GPUs aren't available. Returns: a decorator that will cause the test method to be run under the specified conditions. Raises: - ValueError - if "mode" argument wasn't either "eager" or "graph. + ValueError - if "mode" argument wasn't either "eager" or "graph". """ def decorator(test_function): """The decorator to be returned.""" # Generate good test names that can be used with --test_filter. + named_combinations = [] for combination in combinations: # We use OrderedDicts in `combine()` and `times()` to ensure stable # order of keys in each dictionary. @@ -97,25 +105,46 @@ def generate(combinations): "".join(filter(str.isalnum, str(value)))) for key, value in combination.items() ]) - combination.update({"testcase_name": "_test{}".format(name)}) + named_combinations.append( + OrderedDict( + list(combination.items()) + [("testcase_name", + "_test{}".format(name))])) - @parameterized.named_parameters(*combinations) + @parameterized.named_parameters(*named_combinations) def decorated(self, **kwargs): """A wrapped test method that sets up `test_function`.""" assert "mode" in kwargs mode = kwargs["mode"] - if "distribution" in kwargs: - distribution = kwargs["distribution"] - kwargs["distribution"] = distribution.strategy - if not distribution.required_gpus: - if GPU_TEST: - self.skipTest("Test that doesn't require GPUs.") - elif context.num_gpus() < distribution.required_gpus: - self.skipTest( - "{} GPUs are not available for this test. {} GPUs are available". - format(distribution.required_gpus, context.num_gpus())) + distribution = kwargs.pop("distribution", None) + required_tpu = kwargs.pop("required_tpu", False) + required_gpus = kwargs.pop("required_gpus", None) + if distribution: + assert required_gpus is None, ( + "Do not use `required_gpus` and `distribution` together.") + assert required_tpu is False, ( + "Do not use `required_tpu` and `distribution` together.") + kwargs["distribution"] = distribution.strategy + required_gpus = distribution.required_gpus + required_tpu = distribution.required_tpu + + if required_tpu and not TPU_TEST: + self.skipTest("Test requires a TPU, but it's not available.") + if not required_tpu and TPU_TEST: + self.skipTest("Test that doesn't require a TPU.") + + if not required_gpus: + if GPU_TEST: + self.skipTest("Test that doesn't require GPUs.") + elif context.num_gpus() < required_gpus: + self.skipTest( + "{} GPUs are not available for this test. {} GPUs are available". 
+ format(required_gpus, context.num_gpus())) + + # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu` + # that the user might have specified. `kwargs` still has `mode`, which + # the test is allowed to accept or ignore. requested_arguments = tf_inspect.getfullargspec(test_function).args missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( set(requested_arguments + ["mode"])) @@ -232,10 +261,12 @@ class NamedObject(object): class NamedDistribution(object): """Translates DistributionStrategy and its data into a good name.""" - def __init__(self, name, distribution, required_gpus): + def __init__(self, name, distribution, required_gpus=None, + required_tpu=False): self._distribution = distribution self._name = name self._required_gpus = required_gpus + self._required_tpu = required_tpu def __repr__(self): return self._name @@ -248,20 +279,36 @@ class NamedDistribution(object): def required_gpus(self): return self._required_gpus + @property + def required_tpu(self): + return self._required_tpu + +default_strategy = NamedDistribution( + "Default", + distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + required_gpus=None) one_device_strategy = NamedDistribution( "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"), - None) + required_gpus=None) +tpu_strategy_single_iteration = NamedDistribution( + "TPUSingleIteration", + tpu_strategy.TPUStrategy(iterations_per_step=1), + required_tpu=True) +tpu_strategy = NamedDistribution( + "TPU", tpu_strategy.TPUStrategy(), required_tpu=True) +# Note that we disable prefetching for testing since prefetching makes +# the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", - mirrored_strategy.MirroredStrategy(["/gpu:0", "/cpu:0"]), 1) -mirrored_strategy_without_prefetch = NamedDistribution( - "MirroredCPUAndGPUNoPrefetch", mirrored_strategy.MirroredStrategy( - ["/gpu:0", "/cpu:0"], prefetch_on_device=False), 1) + ["/gpu:0", "/cpu:0"], prefetch_on_device=False), + required_gpus=1) mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", - mirrored_strategy.MirroredStrategy(["/gpu:0", "/gpu:1"]), 2) + mirrored_strategy.MirroredStrategy( + ["/gpu:0", "/gpu:1"], prefetch_on_device=False), + required_gpus=2) adam_optimizer_v1_fn = NamedObject( "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index cff717db80f..c6a1bf6a9f6 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -53,15 +53,14 @@ def _validate_value_destination_pairs(value_destination_pairs): return True +# TODO(yuefengz): consider calling this function in the caller of CrossTowerOps. 
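# Illustrative usage sketch (not part of this patch): a test parameterized over
# the named distributions defined above via `combinations.generate`. The class
# and test names are hypothetical; `test`, `parameterized`, and `combinations`
# are the usual imports in these test files.
class DistributionSmokeTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(
      distribution=[combinations.one_device_strategy,
                    combinations.mirrored_strategy_with_gpu_and_cpu],
      mode=["graph", "eager"]))
  def testRunsUnderDistribution(self, distribution):
    with distribution.scope():
      pass  # build and run the model under test here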
def _get_devices_from(destinations): if isinstance(destinations, value_lib.DistributedValues): return list(destinations.devices) elif isinstance(destinations, six.string_types): - return [device_util.canonicalize(destinations)] + return [device_util.resolve(destinations)] else: - return [ - device_util.canonicalize(destination) for destination in destinations - ] + return [device_util.resolve(destination) for destination in destinations] def _devices_match(left, right): diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py index c5a520ab5ae..34410a64701 100644 --- a/tensorflow/contrib/distribute/python/estimator_integration_test.py +++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py @@ -61,7 +61,8 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase, mode=['graph'], distribution=[ combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus ])) def test_complete_flow_with_mode(self, distribution): label_dimension = 2 diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py index b87224251ca..2b05884b9b9 100644 --- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py +++ b/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""An example tf.keras model that is trained using MirroredStrategy.""" +"""An example of training tf.keras Model using MirroredStrategy.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from sys import argv + +import sys + import numpy as np import tensorflow as tf @@ -33,30 +35,37 @@ def input_fn(): def main(args): if len(args) < 2: - print('You must specify model_dir for checkpoints such as' - ' /tmp/tfkeras_example./') + print('You must specify model_dir for checkpoints such as' + ' /tmp/tfkeras_example/.') return - print('Using %s to store checkpoints.' % args[1]) - - strategy = tf.contrib.distribute.MirroredStrategy( - ['/device:GPU:0', '/device:GPU:1']) - config = tf.estimator.RunConfig(train_distribute=strategy) - optimizer = tf.train.GradientDescentOptimizer(0.2) + model_dir = args[1] + print('Using %s to store checkpoints.' % model_dir) + # Define tf.keras Model. model = tf.keras.Sequential() model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,))) model.add(tf.keras.layers.Dense(1, activation='sigmoid')) + # Compile tf.keras Model. + optimizer = tf.train.GradientDescentOptimizer(0.2) model.compile(loss='binary_crossentropy', optimizer=optimizer) model.summary() tf.keras.backend.set_learning_phase(True) - keras_estimator = tf.keras.estimator.model_to_estimator( - keras_model=model, config=config, model_dir=args[1]) + # Define a DistributionStrategy and convert the tf.keras Model to a + # tf.Estimator that utilizes the DistributionStrategy. 
+ strategy = tf.contrib.distribute.MirroredStrategy( + ['/device:GPU:0', '/device:GPU:1']) + config = tf.estimator.RunConfig(train_distribute=strategy) + keras_estimator = tf.keras.estimator.model_to_estimator( + keras_model=model, config=config, model_dir=model_dir) + + # Train and evaluate the tf.Estimator. keras_estimator.train(input_fn=input_fn, steps=10) eval_result = keras_estimator.evaluate(input_fn=input_fn) print('Eval result: {}'.format(eval_result)) + if __name__ == '__main__': - tf.app.run(argv=argv) + tf.app.run(argv=sys.argv) diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py new file mode 100644 index 00000000000..1f24f629479 --- /dev/null +++ b/tensorflow/contrib/distribute/python/input_ops.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Input-pipeline utilities for Distribution strategies.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import readers +from tensorflow.python.data.util import nest +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging + +# TODO(priyag): Any other reader datasets to consider here? +_READER_DATASET_OPS = [ + "TextLineDataset", + "TFRecordDataset", + "FixedLengthRecordDataset" +] + + +# pylint: disable=protected-access +def auto_shard_dataset(dataset, num_shards, index): + """Shard the input pipeline by sharding the underlying list of files. + + Args: + dataset: A `tf.data.Dataset` instance, typically the result of a bunch of + dataset transformations. + num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of + shards operating in parallel. Same usage as in `Dataset.shard`. + index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. + Same usage as in `Dataset.shard`. + + Returns: + A modified `Dataset` obtained by updating the pipeline sharded by the + files. + + Raises: + NotImplementedError: If we cannot automatically determine a good way to + shard the input dataset. + """ + + # TODO(priyag): Clone datasets instead of updating in place, similar to the + # clone method for TFRecordDataset. + def _auto_shard_impl(dataset, found_reader_op): + """Recursive implementation of auto sharding.""" + + if not found_reader_op: + # TODO(priyag): Make this check more robust by enforcing some common + # property on reader datasets. 
+ if (isinstance(dataset, readers.TextLineDataset) or + isinstance(dataset, readers.FixedLengthRecordDataset)): + filenames_tensor = dataset._filenames + num_files = array_ops.size(filenames_tensor) + sharded_filenames_tensor = array_ops.gather( + filenames_tensor, math_ops.range(index, num_files, num_shards)) + dataset._filenames = sharded_filenames_tensor + return dataset + elif isinstance(dataset, readers.TFRecordDataset): + # `TFRecordDataset` needs to be handled separately than other readers + # because it converts filenames to a dataset first. Also, we clone it + # instead of updating in place because it has special logic in the + # constructor. Eventually we will change all cases to clone datasets + # instead of updating in-place. + return dataset._clone( + filenames=dataset._filenames.shard(num_shards, index)) + elif hasattr(dataset, "_map_func"): + # TODO(priyag): Make this check more robust by enforcing some common + # property on all map/flatmap/interleave datasets. + map_func_def = dataset._map_func.definition + for node in map_func_def.node_def: + if node.op in _READER_DATASET_OPS: + found_reader_op = True + break + elif node.op == "FlatMapDataset": + # TODO(priyag): Should this check for other map datasets? Should it + # be recursive? It is too specific to implementation of + # TFRecordDataset right now. + nested_func_name = node.attr["f"].func.name + nested_func = ops.get_default_graph()._functions[nested_func_name] + for nested_node in nested_func.definition.node_def: + if nested_node.op in _READER_DATASET_OPS: + found_reader_op = True + break + if found_reader_op: + break + if found_reader_op: + dataset._input_dataset = _auto_shard_impl( + dataset._input_dataset, found_reader_op) + return dataset + + # TODO(priyag): Make _input_dataset(s) a common property of all datasets to + # make this check more robust. + if hasattr(dataset, "_input_dataset"): + dataset._input_dataset = _auto_shard_impl( + dataset._input_dataset, found_reader_op) + if hasattr(dataset, "_dataset_to_concatenate"): + # Special case for `ConcatentateDataset`. We want to shard all input + # datasets. + dataset._dataset_to_concatenate = _auto_shard_impl( + dataset._dataset_to_concatenate, found_reader_op) + return dataset + + if hasattr(dataset, "_datasets"): + # Special case for `ZipDataset`. + dataset._datasets = nest.pack_sequence_as(dataset._datasets, [ + _auto_shard_impl(ds, found_reader_op) + for ds in nest.flatten(dataset._datasets) + ]) + return dataset + + if not found_reader_op: + tf_logging.warn( + "Could not find a standard reader in the input pipeline" + "(one of TextLineDataset, TFRecordDataset, FixedLengthRecordDataset)." + "Falling back to sharding the dataset anyway. Please verify" + "correctness of auto-sharding for your input.") + + # TODO(priyag): What do we want to do if the number of filenames is + # uneven in the number of shards? By default, this will just return as + # many items it can before throwing OutOfRangeError. + # TODO(priyag): This will shard the filenames before any shuffling of the + # filename dataset. It might be desirable to shard after shuffling + # filenames? If so, how do we achieve that? 
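# Note on the fallback below (illustrative, not part of this patch):
# `Dataset.shard(num_shards, index)` keeps every num_shards-th element starting
# at `index`, e.g.
#   Dataset.range(6).shard(2, 0)  ->  0, 2, 4
#   Dataset.range(6).shard(2, 1)  ->  1, 3, 5
# so when no standard reader op is found, the pipeline is sharded element-wise
# rather than by input file.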
+ return dataset.shard(num_shards, index) + + return _auto_shard_impl(dataset=dataset, found_reader_op=False) diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py new file mode 100644 index 00000000000..16179c3a490 --- /dev/null +++ b/tensorflow/contrib/distribute/python/input_ops_test.py @@ -0,0 +1,265 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for input pipeline modifications for distribution strategies.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.contrib.distribute.python import input_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.framework import errors +from tensorflow.python.lib.io import python_io +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class AutoShardDatasetTest(test.TestCase): + + def setUp(self): + super(AutoShardDatasetTest, self).setUp() + self._num_files = 10 + self._num_records = 4 + self._num_shards = 2 + self._shard_index = 0 + self._record_bytes = 10 + + def _record(self, r, f): + return compat.as_bytes("Record %d of file %d" % (r, f)) + + def _text_line(self, r, f): + return compat.as_bytes("Text line %d of file %d" % (r, f)) + + def _fixed_length_record(self, r, f): + return compat.as_bytes(str((r * f) % 10) * self._record_bytes) + + def _createTFRecordFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + record = self._record(j, i) + writer.write(record) + writer.close() + return filenames + + def _createTextFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i) + filenames.append(fn) + contents = [] + for j in range(self._num_records): + contents.append(self._text_line(j, i)) + if j + 1 != self._num_records or i == 0: + contents.append(b"\r\n") + contents = b"".join(contents) + + with open(fn, "wb") as f: + f.write(contents) + return filenames + + def _createFixedLengthRecordFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i) + filenames.append(fn) + with open(fn, "wb") as f: + for j in range(self._num_records): + f.write(self._fixed_length_record(j, i)) + return filenames + + def _verifySimpleShardingOutput(self, dataset, record_fn): + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + 
for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + self.assertAllEqual(record_fn(r, f), sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testTFRecordDataset(self): + dataset = readers.TFRecordDataset(self._createTFRecordFiles()) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._record) + + def testFlatMap(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.flat_map(readers.TFRecordDataset) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._record) + + def testInterleave(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.interleave( + readers.TFRecordDataset, cycle_length=4, block_length=self._num_records) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + # Since block_length == num records in each file, the output will still + # contain records in order of files. + self._verifySimpleShardingOutput(dataset, self._record) + + def testParallelInterleave(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.apply(interleave_ops.parallel_interleave( + readers.TFRecordDataset, + cycle_length=4, + block_length=self._num_records)) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + # Since block_length == num records in each file, the output will still + # contain records in order of files. + self._verifySimpleShardingOutput(dataset, self._record) + + def testListfiles(self): + filenames = self._createTFRecordFiles() + file_pattern = filenames[0].rsplit("/", 1)[0] + "/tf_record.*.txt" + dataset = dataset_ops.Dataset.list_files(file_pattern, shuffle=False) + dataset = dataset.flat_map(readers.TFRecordDataset) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + actual, expected = [], [] + for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + actual.append(sess.run(next_element)) + expected.append(self._record(r, f)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self.assertAllEqual(expected, actual) + + def testComplexPipeline(self): + # Setup a complex input pipeline. + batch_size = 2 + num_epochs = 5 + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createTFRecordFiles()) + dataset = dataset.shuffle(buffer_size=self._num_files) + dataset = dataset.flat_map(readers.TFRecordDataset) + dataset = dataset.prefetch(buffer_size=batch_size) + dataset = dataset.shuffle(2 * self._num_files * self._num_records) + dataset = dataset.repeat(num_epochs) + dataset = dataset.apply(batching.map_and_batch( + lambda x: x, batch_size=batch_size)) + dataset = dataset.prefetch(buffer_size=None) + + # Auto shard. + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + # Verify output. 
+ iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + actual = [] + num_iterations = (self._num_files * self._num_records * num_epochs) // ( + self._num_shards * batch_size) + for _ in range(num_iterations): + actual.extend(sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + expected = [] + for f in range(0, self._num_files, self._num_shards): + for r in range(self._num_records): + expected.append(self._record(r, f)) + expected *= num_epochs + + self.assertAllEqual(sorted(expected), sorted(actual)) + + def testZip(self): + dataset1 = readers.TFRecordDataset(self._createTFRecordFiles()) + dataset2 = readers.TextLineDataset(self._createTextFiles()) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + record_fn = lambda r, f: (self._record(r, f), self._text_line(r, f)) + self._verifySimpleShardingOutput(dataset, record_fn) + + def testConcat(self): + dataset1 = readers.TFRecordDataset(self._createTFRecordFiles()) + dataset2 = readers.TextLineDataset(self._createTextFiles()) + dataset = dataset1.concatenate(dataset2) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + with self.test_session() as sess: + for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + self.assertAllEqual(self._record(r, f), sess.run(next_element)) + for f in range(self._shard_index, self._num_files, self._num_shards): + for r in range(self._num_records): + self.assertAllEqual(self._text_line(r, f), sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testTextLineReader(self): + dataset = readers.TextLineDataset(self._createTextFiles()) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._text_line) + + def testTextLineReaderWithFlatMap(self): + dataset = dataset_ops.Dataset.from_tensor_slices(self._createTextFiles()) + dataset = dataset.flat_map(readers.TextLineDataset) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._text_line) + + def testFixedLengthReader(self): + dataset = readers.FixedLengthRecordDataset( + self._createFixedLengthRecordFiles(), self._record_bytes) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._fixed_length_record) + + def testFixedLengthReaderWithFlatMap(self): + dataset = dataset_ops.Dataset.from_tensor_slices( + self._createFixedLengthRecordFiles()) + dataset = dataset.flat_map( + lambda f: readers.FixedLengthRecordDataset(f, self._record_bytes)) + dataset = input_ops.auto_shard_dataset( + dataset, self._num_shards, self._shard_index) + + self._verifySimpleShardingOutput(dataset, self._fixed_length_record) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index 0fa90df79bb..5c056a7c73d 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -25,6 +25,7 @@ from 
tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import mirrored_strategy from tensorflow.contrib.distribute.python.single_loss_example import batchnorm_example from tensorflow.contrib.distribute.python.single_loss_example import minimize_loss_example +from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -42,16 +43,30 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.times( combinations.distributions_and_v1_optimizers(), combinations.combine(mode=["graph"], use_callable_loss=[True, False]) - + combinations.combine(mode=["eager"], use_callable_loss=[True]))) - def testTrainNetwork(self, distribution, optimizer_fn, - use_callable_loss=True): + + combinations.combine(mode=["eager"], use_callable_loss=[True]), + combinations.combine(is_tpu=[False])) + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=[ + combinations.adam_optimizer_v1_fn, + # TODO(isaprykin): Make Adam v2 work with while_loops + # and TPUs. + ], + mode=["graph"], + use_callable_loss=[False], + is_tpu=[True])) + def testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss, + is_tpu): with distribution.scope(): - model_fn, dataset, layer = minimize_loss_example( - optimizer_fn, - use_bias=True, - use_callable_loss=use_callable_loss) + model_fn, dataset_fn, layer = minimize_loss_example( + optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - iterator = distribution.distribute_dataset(dataset) + # TODO(isaprykin): Eliminate `is_tpu`. Probably add a + # `DistributionStrategy.create_monitor` so that each DistributionStrategy + # could influence its training loop. That method would return an instance + # of Monitor. TPUMonitor would execute tpu.initialize_system() and + # tpu.shutdown_system(). 
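# A note on the pattern used from here on (a sketch with stand-in names, not
# the real DistributionStrategy API): the tests now hand the strategy a
# `dataset_fn` instead of an already-built dataset, so the strategy itself
# decides where and how many times the input pipeline is constructed.
class ToyStrategy(object):
  def __init__(self, num_towers):
    self._num_towers = num_towers

  def distribute_dataset(self, dataset_fn):
    # The strategy, not the caller, builds each copy; a dataset constructed
    # up front could not be rebuilt once per tower or worker like this.
    return [dataset_fn() for _ in range(self._num_towers)]

per_tower = ToyStrategy(num_towers=2).distribute_dataset(lambda: iter([1., 2.]))
assert len(per_tower) == 2  # two independently constructed input pipelines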
+ iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.group( @@ -60,6 +75,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): if not context.executing_eagerly(): with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -70,6 +87,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): weights.append(self.evaluate(distribution.fetch(layer.kernel))) biases.append(self.evaluate(distribution.fetch(layer.bias))) + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + error = abs(numpy.add(numpy.squeeze(weights), numpy.squeeze(biases)) - 1) is_not_increasing = all(y <= x for x, y in zip(error, error[1:])) self.assertTrue(is_not_increasing) @@ -78,8 +99,18 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): combinations.times( combinations.distributions_and_v1_optimizers() + combinations.distributions_and_v2_optimizers(), - combinations.combine(mode=["graph", "eager"]))) - def testOptimizerInsideModelFn(self, distribution, optimizer_fn): + combinations.combine(mode=["graph", "eager"], is_tpu=[False])) + + combinations.combine( + distribution=[combinations.tpu_strategy], + optimizer_fn=[ + combinations.adam_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn, + ], + mode=["graph"], + is_tpu=[True])) + + def testOptimizerInsideModelFn(self, distribution, optimizer_fn, is_tpu): created_variables = [] trainable_variables = [] @@ -94,13 +125,14 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # `distribution.scope`. with variable_scope.variable_creator_scope( appending_creator), distribution.scope(): - model_fn, dataset, layer = minimize_loss_example( + model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=True, create_optimizer_inside_model_fn=True) - iterator = distribution.distribute_dataset(dataset) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.group( @@ -109,11 +141,17 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): if not context.executing_eagerly(): with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) run_step() + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + def get_expected_variables(optimizer_fn, num_parameter_devices): variables_map = { "GradientDescent": ["dense/kernel", "dense/bias"], @@ -137,40 +175,59 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): set(created_variables)) @combinations.generate( - combinations.times(combinations.distributions_and_v1_optimizers(), - combinations.combine( - mode=["graph", "eager"], - momentum=[0.8, 0.9, 0.99], - renorm=[False, True]))) + combinations.times( + combinations.combine(momentum=[0.8, 0.9, 0.99], renorm=[False, True]), + combinations.times( + combinations.distributions_and_v1_optimizers(), + combinations.combine( + mode=["graph", "eager"], + is_tpu=[False], + # TODO(isaprykin): Allow False here. Currently subsequent + # towers will re-execute UPDATE_OPS of previous towers. 
+ update_ops_in_cross_tower_mode=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy_single_iteration], + optimizer_fn=[ + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn + ], + mode=["graph"], + is_tpu=[True], + update_ops_in_cross_tower_mode=[False]))) def testTrainNetworkWithBatchNorm(self, distribution, optimizer_fn, momentum, - renorm): + renorm, is_tpu, + update_ops_in_cross_tower_mode): """Verifies that moving mean updates are reduced across towers.""" with distribution.scope(): num_towers = len(distribution.worker_devices) - model_fn, dataset, batchnorm = batchnorm_example( + model_fn, dataset_fn, batchnorm = batchnorm_example( optimizer_fn, batch_per_epoch=num_towers, momentum=momentum, - renorm=renorm) + renorm=renorm, + update_ops_in_tower_mode=not update_ops_in_cross_tower_mode) - # Disable prefetching since that makes the specific input on each device - # to be non deterministic, and this test relies on specific input being - # on each device. + # Make sure prefetching is disabled since that makes the + # specific input on each device to be non deterministic, and + # this test relies on specific input being on each device. if isinstance(distribution, mirrored_strategy.MirroredStrategy): - distribution._prefetch_on_device = False - iterator = distribution.distribute_dataset(dataset) + self.assertFalse(distribution._prefetch_on_device) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): - return control_flow_ops.group( - distribution.unwrap( - distribution.call_for_each_tower( - model_fn, - iterator.get_next(), - run_concurrently=batchnorm.built)) + - ops.get_collection(ops.GraphKeys.UPDATE_OPS)) + fetches = distribution.unwrap( + distribution.call_for_each_tower( + model_fn, iterator.get_next(), + run_concurrently=batchnorm.built)) + if update_ops_in_cross_tower_mode: + fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) + return control_flow_ops.group(fetches) if not context.executing_eagerly(): with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) @@ -194,22 +251,40 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): expected_moving_mean - averaged_batch_mean(i)) * (1.0 - momentum)) self.assertNear(expected_moving_means[i], moving_means[i], 0.0001) + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + @combinations.generate( combinations.times( combinations.combine( - distribution=[combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus], - optimizer_fn=[combinations.gradient_descent_optimizer_v1_fn, - combinations.gradient_descent_optimizer_v2_fn], - loss_reduction=[losses_impl.Reduction.SUM, - losses_impl.Reduction.MEAN, - losses_impl.Reduction.SUM_OVER_BATCH_SIZE, - losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS]), - combinations.combine(mode=["graph"], use_callable_loss=[True, False]) - + combinations.combine(mode=["eager"], use_callable_loss=[True]))) + optimizer_fn=[ + combinations.gradient_descent_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v2_fn + ], + loss_reduction=[ + losses_impl.Reduction.SUM, losses_impl.Reduction.MEAN, + losses_impl.Reduction.SUM_OVER_BATCH_SIZE, + losses_impl.Reduction.SUM_OVER_NONZERO_WEIGHTS + ]), + combinations.times( + combinations.combine( + 
distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + is_tpu=[False]), + combinations.combine( + mode=["graph"], use_callable_loss=[True, False]) + + combinations.combine(mode=["eager"], use_callable_loss=[True])) + + combinations.combine( + distribution=[combinations.tpu_strategy_single_iteration], + is_tpu=[True], + mode=["graph"], + use_callable_loss=[True, False]))) def testMeanVsSum(self, distribution, optimizer_fn, loss_reduction, - use_callable_loss): + use_callable_loss, is_tpu): with distribution.scope(): all_vars = [] @@ -230,10 +305,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): else: return optimizer.minimize(loss_fn()) - features = dataset_ops.Dataset.from_tensors([[2.], [7.]]) - labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) - dataset = dataset_ops.Dataset.zip((features, labels)).repeat() - iterator = distribution.distribute_dataset(dataset) + def dataset_fn(): + features = dataset_ops.Dataset.from_tensors([[2.], [7.]]) + labels = dataset_ops.Dataset.from_tensors([[6.], [21.]]) + return dataset_ops.Dataset.zip((features, labels)).repeat() + + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.group( @@ -242,12 +320,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): if not context.executing_eagerly(): with self.test_session() as sess: + if is_tpu: + sess.run(tpu.initialize_system()) run_step = sess.make_callable(run_step()) self.evaluate(variables_lib.global_variables_initializer()) run_step() - self.assertEqual(distribution.num_towers, len(all_vars)) v = all_vars[0] self.assertTrue(all([v is vi for vi in all_vars[1:]])) weight = numpy.squeeze(self.evaluate(distribution.fetch(v))) @@ -274,6 +353,10 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): # One of the mean loss reductions. self.assertNear(weight, 2 + 10.6, 0.0001) + if is_tpu: + with self.test_session() as sess: + sess.run(tpu.shutdown_system()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index eb0edb3a11d..89f2c431fec 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -80,6 +80,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): dict((d, i) for i, d in enumerate(devices))) self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device + # TODO(yuefengz): consider setting the default device. def _create_variable(self, next_creator, *args, **kwargs): """Create a mirrored variable. 
See `DistributionStrategy.scope`.""" @@ -110,10 +111,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): kwargs["name"] = "%s/replica_%d" % (var0name, i) # Initialize replicas with the same value: if context.executing_eagerly(): - initial_value = index[devices[0]].value() + kwargs["initial_value"] = array_ops.identity( + index[devices[0]].value()) else: - initial_value = index[devices[0]].initial_value - kwargs["initial_value"] = array_ops.identity(initial_value) + def initial_value_fn(device=d): + with ops.device(device): + return array_ops.identity(index[devices[0]].initial_value) + kwargs["initial_value"] = initial_value_fn with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) @@ -140,10 +144,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): g.add_to_collections(collections, result) return result - def distribute_dataset(self, dataset): - per_device_dataset = values.PerDeviceDataset( - dataset, self._devices, self._prefetch_on_device) - return per_device_dataset.make_one_shot_iterator() + def distribute_dataset(self, dataset_fn): + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), self._devices, + self._prefetch_on_device) def _broadcast(self, tensor, destinations): # TODO(josh11b): In eager mode, use one thread per device, or async mode. @@ -321,7 +325,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _fetch(self, val, destination, fn): """Return a copy of `val` or `fn(val)` on `destination`.""" - assert isinstance(destination, six.string_types) if isinstance(val, values.TowerLocalVariable): val = self.reduce(val.reduce_method, val, destinations=destination) with ops.device(destination): diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 9e9f06da8e2..3f9a02b249d 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -28,9 +28,12 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import distribute as distribute_lib @@ -116,7 +119,6 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase): self.assertEqual(expected, self.evaluate(unwrapped[0])) -@test_util.with_c_api class MirroredStrategyVariableCreationTest(test.TestCase): config = config_pb2.ConfigProto() @@ -247,8 +249,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase): dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) - features = dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) - features = dist.distribute_dataset(features).get_next() + features = dist.distribute_dataset( + lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) + ).make_one_shot_iterator().get_next() with dist.scope(): result = dist.call_for_each_tower( @@ -369,22 +372,27 
@@ class MirroredStrategyVariableCreationTest(test.TestCase): expected_sum = 0.0 expected_mean = 0.0 for i, d in enumerate(dist.worker_devices): - # Test access within a device scope, should see different values. - with ops.device(d): - v_sum_value = self.evaluate(ret_v_sum.read_value()) - v_mean_value = self.evaluate(ret_v_mean.read_value()) - expected = i + 3.0 - self.assertEqual(expected, v_sum_value) - expected_sum += expected - expected = i * 6.0 - self.assertEqual(expected, v_mean_value) - expected_mean += expected - - # fetch() should return the value you get by applying the - # reduction across all towers. - self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) + # Should see different values on different devices. + v_sum_value = self.evaluate(ret_v_sum.get(d).read_value()) + v_mean_value = self.evaluate(ret_v_mean.get(d).read_value()) + expected = i + 3.0 + self.assertEqual(expected, v_sum_value) + expected_sum += expected + expected = i * 6.0 + self.assertEqual(expected, v_mean_value) + expected_mean += expected expected_mean /= len(dist.worker_devices) + + # Without get(device), should return the value you get by + # applying the reduction across all towers (whether you use + # fetch(), get(), or nothing). + self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum))) self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean))) + self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get())) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get())) + if not context.executing_eagerly(): + self.assertEqual(expected_sum, self.evaluate(ret_v_sum)) + self.assertEqual(expected_mean, self.evaluate(ret_v_mean)) # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not # testing this in eager mode. @@ -430,6 +438,30 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("foo/" + name + ":0", v0.name) self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + def testDynamicRnnVariables(self): + def model_fn(): + inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) + cell_fw = rnn_cell_impl.LSTMCell(300) + cell_bw = rnn_cell_impl.LSTMCell(300) + (outputs, _) = rnn.bidirectional_dynamic_rnn( + cell_fw, + cell_bw, + inputs, + dtype=dtypes.float32) + return outputs + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + # Two variables are created by the RNN layer. 
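# Illustrative note (device names taken from the strategy above, variable
# names abbreviated): each element of `result` is expected to behave like a
# mapping from device to that tower's copy of the variable, e.g.
#   {"/device:GPU:0": <tf.Variable "bidirectional_rnn/.../kernel:0">,
#    "/device:CPU:0": <tf.Variable "tower_1/bidirectional_rnn/.../kernel:0">}
# so it is enough to check the DistributedValues type and the "tower_1/"
# prefix on the second tower's copy, as the assertions below do.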
+ self.assertEquals(2, len(result)) + for v in result: + self.assertIsInstance(v, values.DistributedValues) + _, v1 = dist.unwrap(v) + self.assertStartsWith(v1.name, "tower_1/") + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py index a1ef0ecc77a..61cbe6df813 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py @@ -27,7 +27,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.training import distribute as distribute_lib -@test_util.with_c_api class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): def _get_distribution_strategy(self): @@ -53,7 +52,6 @@ class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase): self._test_call_and_merge_exceptions(self._get_distribution_strategy()) -@test_util.with_c_api class VariableCreatorStackTest(test.TestCase): def testCreatorStacksAreThreadLocal(self): diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py new file mode 100644 index 00000000000..a552b370ebf --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py @@ -0,0 +1,141 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes implementing a mirrored DistributionStrategy for multiple workers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import partial + +from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.python.training import device_util +from tensorflow.python.training import server_lib +from tensorflow.python.util import nest + + +# TODO(yuefengz): support between-graph replication. +# TODO(yuefengz): merge this class into its base class. +# TODO(yuefengz): in some cases, we probably want to use configure method to +# configure this class. +# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after the +# class is introduced. +class MultiWorkerMirroredStrategy(MirroredStrategy): + """Mirrored strategy that works on multiple workers with in-graph replication. + + There are several important concepts for distributed TensorFlow, e.g. + `client`, `job`, 'task', `cluster`, `in-graph replication` and + 'synchronous training' and they have already been defined in the + [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed). 
+ The distribution strategy inherits these concepts as well and in addition to + that we also clarify several more concepts: + * **In-graph replication**: the `client` creates a single `tf.Graph` that + specifies tasks for devices on all workers. The `client` then creates a + client session which will talk to the `master` service of a `worker`. Then + the `master` will partition the graph and distribute the work to all + participating workers. + * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one + physical machine. We will have multiple `worker`s with different `task` + indices. They all do similar things, except that one worker additionally + checkpoints model variables, writes summaries, etc. + + This class maps one tower to one device on a worker. It mirrors all model + variables on all towers. For example, if you have two `worker`s and each + `worker` has 4 GPUs, it will create 8 copies of the model variables on these 8 + GPUs. Then, as in MirroredStrategy, each tower performs its computation + with its own copy of the variables, unless it is in cross-tower mode, where + variable or tensor reduction happens. + """ + + def __init__(self, + num_gpus_per_worker=1, + worker_job_name=None, + num_workers=None, + cluster=None, + cross_tower_ops=None, + prefetch_on_device=None): + """Initialize the strategy object. + + Args: + num_gpus_per_worker: number of GPUs per worker. If it is zero, the local + CPU will be used. + worker_job_name: the job name for `worker`, typically just 'worker'. + num_workers: the number of workers. If it is 0, it degenerates to a + single-worker MirroredStrategy. + cluster: a `tf.train.ClusterSpec` object or a dict that can be used to + construct a `tf.train.ClusterSpec` object or a `tf.train.ClusterDef` + protocol buffer. It is an alternative way to initialize this object. + cross_tower_ops: the cross tower ops to use. If None, a default one will + be used. If the configure method is called, the best one for the + configuration will be chosen. + prefetch_on_device: a boolean to specify whether to prefetch input to + each worker's devices. + + Raises: + ValueError: if an unexpected `cluster` is given. + """ + if cluster is None: + self._workers = [ + '/job:%s/task:%d' % (worker_job_name, task_index) + for task_index in range(num_workers) + ] + else: + if isinstance(cluster, (dict, cluster_pb2.ClusterDef)): + cluster_spec = server_lib.ClusterSpec(cluster) + elif isinstance(cluster, server_lib.ClusterSpec): + cluster_spec = cluster + else: + raise ValueError( + '`cluster` should be a dict, a `tf.train.ClusterSpec` or a ' + '`tf.train.ClusterDef` object') + + self._workers = [] + for job in sorted(cluster_spec.jobs): + for task in range(cluster_spec.num_tasks(job)): + self._workers.append('/job:%s/task:%d' % (job, task)) + + self._num_gpus_per_worker = num_gpus_per_worker + if num_gpus_per_worker > 0: + self._worker_device_map = { + worker: [ + device_util.canonicalize(worker + '/device:GPU:%d' % gpu) + for gpu in range(num_gpus_per_worker) + ] for worker in self._workers + } + else: + self._worker_device_map = { + worker: [device_util.canonicalize(worker, '/device:CPU:0')] + for worker in self._workers + } + self._devices = nest.flatten(self._worker_device_map.values()) + + super(MultiWorkerMirroredStrategy, self).__init__( + devices=self._devices, prefetch_on_device=prefetch_on_device) + + # Setting `_default_device` will add a device scope in the + # distribution.scope. We set the default device to the first worker.
When + # users specify device under distribution.scope by + # with tf.device("/cpu:0"): + # ... + # their ops will end up on the cpu device of its first worker, e.g. + # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode. + self._default_device = self._workers[0] + + def distribute_dataset(self, dataset_fn): + return values.MultiWorkerDataset( + partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, + self._prefetch_on_device) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py new file mode 100644 index 00000000000..09c859b32a3 --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py @@ -0,0 +1,62 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MultiWorkerMirroredStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import multi_worker_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base +from tensorflow.contrib.distribute.python import strategy_test_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.training import server_lib + + +class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase, + strategy_test_lib.DistributionTestBase): + + def _get_distribution_strategy(self): + return multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster=server_lib.ClusterSpec({ + 'worker': ['/job:worker/task:0', '/job:worker/task:1'] + }), + num_gpus_per_worker=context.num_gpus()) + + def testMinimizeLossGraph(self): + self._test_minimize_loss_graph(self._get_distribution_strategy()) + + +class DeviceScopeTest(test.TestCase): + """Test the device scope of MultiWorkerMirroredStrategy.""" + + def testDeviceScope(self): + with context.graph_mode(): + strategy = multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']}, + num_gpus_per_worker=context.num_gpus()) + with strategy.scope(): + a = constant_op.constant(1.) + with ops.device('/cpu:0'): + b = constant_op.constant(1.) + self.assertEqual(a.device, '/job:worker/task:0') + self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py new file mode 100644 index 00000000000..f659be5f425 --- /dev/null +++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py @@ -0,0 +1,90 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base testing class for strategies that require multiple nodes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import copy + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.client import session +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util + + +class MultiWorkerTestBase(test.TestCase): + """Base class for testing multi-node strategies and datasets.""" + + @classmethod + def setUpClass(cls): + """Create a local cluster with 2 workers.""" + num_workers = 2 + # Leave some memory for the CUDA runtime. + gpu_mem_frac = 0.7 / num_workers + default_config = config_pb2.ConfigProto() + default_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac + + # The local cluster takes some portion of the local GPUs and there is no way + # for the cluster to terminate unless using multiple processes. Therefore, + # we have to create only one cluster throughout a test process. + workers, _ = test_util.create_local_cluster( + num_workers, num_ps=0, worker_config=default_config) + cls._master_target = workers[0].target + + @contextlib.contextmanager + def test_session(self, graph=None, config=None): + """Create a test session with the master target set to the testing cluster. + + This overrides the base class' method, removes arguments that are not needed + by the multi-node case and creates a test session that connects to the local + testing cluster. + + Args: + graph: Optional graph to use during the returned session. + config: An optional config_pb2.ConfigProto to use to configure the + session. + + Yields: + A Session object that should be used as a context manager to surround + the graph building and execution code in a test case.
+ """ + if self.id().endswith('.test_session'): + self.skipTest('Not a test.') + + if config is None: + config = config_pb2.ConfigProto(allow_soft_placement=True) + else: + config = copy.deepcopy(config) + # Don't perform optimizations for tests so we don't inadvertently run + # gpu ops on cpu + config.graph_options.optimizer_options.opt_level = -1 + config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + + if graph is None: + if self._cached_session is None: # pylint: disable=access-member-before-definition + self._cached_session = session.Session( + graph=None, config=config, target=self._master_target) + sess = self._cached_session + with sess.graph.as_default(), sess.as_default(): + yield sess + else: + with session.Session( + graph=graph, config=config, target=self._master_target) as sess: + yield sess diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 39c49442b9c..09b6d4a515a 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -21,8 +21,6 @@ from __future__ import print_function import six from tensorflow.contrib.distribute.python import values -from tensorflow.contrib.eager.python import datasets -from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -38,9 +36,11 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): # doing something that won't work with other DistributionStrategy # implementations? - def __init__(self, device): + def __init__(self, device, prefetch_on_device=None): super(OneDeviceStrategy, self).__init__() self._device = device + self._prefetch_on_device = prefetch_on_device + self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): # No need to distinguish tower-local variables when not mirroring, @@ -62,11 +62,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.colocate_with(colocate_with): return next_creator(*args, **kwargs) - def distribute_dataset(self, dataset): - if context.executing_eagerly(): - return datasets.Iterator(dataset) - else: - return dataset.make_one_shot_iterator() + def distribute_dataset(self, dataset_fn): + return values.PerDeviceDataset( + self._call_dataset_fn(dataset_fn), [self._device], + self._prefetch_on_device) def _broadcast(self, tensor, destinations): return tensor diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py index 7101ed0756f..7aad8a953cb 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py @@ -24,7 +24,6 @@ from tensorflow.python.eager import test from tensorflow.python.framework import test_util -@test_util.with_c_api class OneDeviceStrategyTest(strategy_test_lib.DistributionTestBase): def _get_distribution_strategy(self): diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index a0912b625f4..abd3a65ac4e 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -39,10 +39,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): def 
testTrainNetwork(self, distribution, optimizer_fn, use_callable_loss=True): with distribution.scope(): - model_fn, dataset, layer = minimize_loss_example( + model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - iterator = distribution.distribute_dataset(dataset) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return control_flow_ops.group(distribution.unwrap( diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py index dfcbb8568f9..7b3670b45ab 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -26,6 +26,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest as data_nest from tensorflow.python.data.util import sparse +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops @@ -34,26 +35,55 @@ from tensorflow.python.util import nest # pylint: disable=protected-access class _PrefetchToDeviceIterator(object): - """A replacement for @{tf.data.Iterator} that prefetches to another device.""" + """A replacement for @{tf.data.Iterator} that prefetches to another device. - def __init__(self, input_dataset, devices, buffer_size): + Args: + input_dataset: The input dataset. + one_shot: If true, we make a one shot iterator that's already initialized. + devices: Devices on which to prefetch. + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be + shared under the given name across multiple sessions that share the + same devices (e.g. when using a remote server). Only used if one_shot + is False. + + Returns: + An Iterator type object. 
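  Example (an illustrative sketch only, mirroring the usage in prefetching_ops_v2_test.py; it assumes a GPU is available):

    device_dataset = host_dataset.apply(
        prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
    iterator = device_dataset.make_initializable_iterator()
    # `iterator` is an instance of this class with one_shot=False, so
    # `iterator.initializer` must be run before the first get_next() and can
    # be run again to restart the pipeline.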
+ """ + + def __init__(self, + input_dataset, + one_shot, + devices, + buffer_size, + shared_name=None): self._input_dataset = input_dataset self._get_next_call_count = 0 + self._one_shot = one_shot + if shared_name is None: + shared_name = "" self._devices = devices - input_iterator = input_dataset.make_one_shot_iterator() - input_iterator_handle = input_iterator.string_handle() + + if self._one_shot: + self._input_iterator = input_dataset.make_one_shot_iterator() + else: + self._input_iterator = iterator_ops.Iterator.from_structure( + self._input_dataset.output_types, self._input_dataset.output_shapes, + shared_name, self._input_dataset.output_classes) + input_iterator_handle = self._input_iterator.string_handle() @function.Defun(dtypes.string) def _prefetch_fn(handle): """Prefetches one element from `input_iterator`.""" remote_iterator = iterator_ops.Iterator.from_string_handle( - handle, input_iterator.output_types, input_iterator.output_shapes, - input_iterator.output_classes) + handle, self._input_iterator.output_types, + self._input_iterator.output_shapes, + self._input_iterator.output_classes) ret = remote_iterator.get_next() return nest.flatten(sparse.serialize_sparse_tensors(ret)) target_device = gen_dataset_ops.iterator_get_device( - input_iterator._iterator_resource) + self._input_iterator._iterator_resource) self._buffering_resources = [] for device in nest.flatten(self._devices): with ops.device(device): @@ -61,9 +91,19 @@ class _PrefetchToDeviceIterator(object): f=_prefetch_fn, target_device=target_device, string_arg=input_iterator_handle, - buffer_size=buffer_size) + buffer_size=buffer_size, + shared_name=shared_name) self._buffering_resources.append(buffer_resource_handle) + if not self._one_shot: + reset_ops = [] + for buffer_resource in self._buffering_resources: + reset_ops.append( + prefetching_ops.function_buffering_resource_reset(buffer_resource)) + with ops.control_dependencies(reset_ops): + self._initializer = self._input_iterator.make_initializer( + self._input_dataset) + def get_next(self, name=None): """See @{tf.data.Iterator.get_next}.""" self._get_next_call_count += 1 @@ -92,6 +132,12 @@ class _PrefetchToDeviceIterator(object): return nest.pack_sequence_as(self._devices, flat_result) + @property + def initializer(self): + if self._one_shot: + raise NotImplementedError("Can't initialize a one_shot_iterator") + return self._initializer + @property def output_classes(self): return self._input_dataset.output_classes @@ -115,13 +161,24 @@ class _PrefetchToDeviceDataset(dataset_ops.Dataset): self._buffer_size = buffer_size if buffer_size is not None else 1 def make_one_shot_iterator(self): - return _PrefetchToDeviceIterator(self._input_dataset, self._devices, - self._buffer_size) + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=True, + devices=self._devices, + buffer_size=self._buffer_size) def make_initializable_iterator(self, shared_name=None): - raise NotImplementedError("`prefetch_to_devices()` is not currently " - "compatible with initializable iterators. Use " - "`make_one_shot_iterator()` instead.") + if context.executing_eagerly(): + raise RuntimeError( + "make_initializable_iterator is not supported when eager " + "execution is enabled.") + + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=False, + devices=self._devices, + buffer_size=self._buffer_size, + shared_name=shared_name) def _as_variant_tensor(self): # TODO(mrry): Raise this error earlier (e.g. 
when one of the Dataset diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py index 8ed16f46078..a68dbce6c7d 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py @@ -64,5 +64,27 @@ class PrefetchingOpsV2Test(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testPrefetchToTwoDevicesWithReinit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) + + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(5): + sess.run(next_element) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + sess.run(iterator.initializer) + for _ in range(5): + sess.run(next_element) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py index 713494d603b..a0b452fc2d4 100644 --- a/tensorflow/contrib/distribute/python/shared_variable_creator_test.py +++ b/tensorflow/contrib/distribute/python/shared_variable_creator_test.py @@ -44,7 +44,6 @@ class CanonicalizeVariableNameTest(test.TestCase): self.assertEquals("foo_a", self._canonicalize("foo_a")) -@test_util.with_c_api class SharedVariableCreatorTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() diff --git a/tensorflow/contrib/distribute/python/single_loss_example.py b/tensorflow/contrib/distribute/python/single_loss_example.py index cef5fd2f894..d1fdb3279cf 100644 --- a/tensorflow/contrib/distribute/python/single_loss_example.py +++ b/tensorflow/contrib/distribute/python/single_loss_example.py @@ -18,9 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.distribute.python import step_fn from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.layers import normalization from tensorflow.python.ops import array_ops @@ -29,7 +31,10 @@ from tensorflow.python.ops import math_ops def single_loss_example(optimizer_fn, distribution, use_bias=False): """Build a very simple network to use in tests and examples.""" - dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + + def dataset_fn(): + return dataset_ops.Dataset.from_tensors([[1.]]).repeat() + optimizer = optimizer_fn() layer = core.Dense(1, use_bias=use_bias) @@ -37,8 +42,8 @@ def single_loss_example(optimizer_fn, distribution, use_bias=False): y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) return y * y - single_loss_step = step_fn.StandardSingleLossStep(dataset, loss_fn, optimizer, - distribution) + single_loss_step = step_fn.StandardSingleLossStep(dataset_fn, loss_fn, + optimizer, distribution) # Layer is returned for inspecting the kernels in tests. 
return single_loss_step, layer @@ -49,7 +54,14 @@ def minimize_loss_example(optimizer_fn, use_callable_loss=True, create_optimizer_inside_model_fn=False): """Example of non-distribution-aware legacy code.""" - dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + + def dataset_fn(): + dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat() + # TODO(isaprykin): map_and_batch with drop_remainder causes shapes to be + # fully defined for TPU. Remove this when XLA supports dynamic shapes. + return dataset.apply( + batching.map_and_batch(lambda x: x, batch_size=1, drop_remainder=True)) + # An Optimizer instance is created either outside or inside model_fn. outer_optimizer = None if not create_optimizer_inside_model_fn: @@ -71,32 +83,43 @@ def minimize_loss_example(optimizer_fn, else: return optimizer.minimize(loss_fn()) - return model_fn, dataset, layer + return model_fn, dataset_fn, layer def batchnorm_example(optimizer_fn, batch_per_epoch=1, momentum=0.9, - renorm=False): + renorm=False, + update_ops_in_tower_mode=False): """Example of non-distribution-aware legacy code with batch normalization.""" - # input shape is [16, 8], input values are increasing in both dimensions. - dataset = dataset_ops.Dataset.from_tensor_slices( - [[[float(x * 8 + y + z * 100) - for y in range(8)] - for x in range(16)] - for z in range(batch_per_epoch)]).repeat() + + def dataset_fn(): + # input shape is [16, 8], input values are increasing in both dimensions. + return dataset_ops.Dataset.from_tensor_slices( + [[[float(x * 8 + y + z * 100) + for y in range(8)] + for x in range(16)] + for z in range(batch_per_epoch)]).repeat() + optimizer = optimizer_fn() batchnorm = normalization.BatchNormalization( renorm=renorm, momentum=momentum, fused=False) + layer = core.Dense(1, use_bias=False) def model_fn(x): + """A model that uses batchnorm.""" def loss_fn(): - y = math_ops.reduce_sum(batchnorm(x, training=True), axis=1) - loss = math_ops.reduce_mean(y - constant_op.constant(1.)) + y = batchnorm(x, training=True) + with ops.control_dependencies( + ops.get_collection(ops.GraphKeys.UPDATE_OPS) + if update_ops_in_tower_mode else []): + loss = math_ops.reduce_mean( + math_ops.reduce_sum(layer(y)) - constant_op.constant(1.)) + # `x` and `y` will be fetched by the gradient computation, but not `loss`. return loss # Callable loss. return optimizer.minimize(loss_fn) - return model_fn, dataset, batchnorm + return model_fn, dataset_fn, batchnorm diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index 82514c64be4..d1910622b38 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -49,12 +49,14 @@ class StandardInputStep(Step): """Step with a standard implementation of input handling. Args: - input_dataset: a tf.data Dataset that provides input. + dataset_fn: a function that returns a tf.data Dataset that produces the + input for the model. """ - def __init__(self, input_dataset, distribution): + def __init__(self, dataset_fn, distribution): Step.__init__(self, distribution) - self._distributed_input = distribution.distribute_dataset(input_dataset) + self._distributed_input = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def inputs(self): return self._distributed_input.get_next() @@ -76,14 +78,15 @@ class StandardSingleLossStep(StandardInputStep): ``` Args: - input_dataset: a tf.data Dataset that provides input. 
+ dataset_fn: a function that returns a tf.data Dataset that produces the + input for the model. loss_fn: a function that returns loss. optimizer: an optimizer that implements an update rule. distribution: a `DistributionStrategy` object. """ - def __init__(self, input_dataset, loss_fn, optimizer, distribution): - StandardInputStep.__init__(self, input_dataset, distribution) + def __init__(self, dataset_fn, loss_fn, optimizer, distribution): + StandardInputStep.__init__(self, dataset_fn, distribution) self._loss_fn = loss_fn self._optimizer = optimizer self._is_run_concurrently = False diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py new file mode 100644 index 00000000000..75441786a61 --- /dev/null +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -0,0 +1,128 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TPU Distribution Strategy. + +This is experimental. It's not ready for general use. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +from tensorflow.contrib import tpu +from tensorflow.contrib.distribute.python import one_device_strategy +from tensorflow.contrib.distribute.python import values +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.util import nest + + +class TPUStrategy(one_device_strategy.OneDeviceStrategy): + """Experimental TPU distribution strategy implementation.""" + + def __init__(self, + num_cores_per_host=2, + iterations_per_step=2): + # TODO(isaprykin): Generalize the defaults. They are currently tailored for + # the unit test. + super(TPUStrategy, self).__init__('/cpu:0') + # TODO(isaprykin): Auto-detect number of cores and hosts. + self._num_cores_per_host = num_cores_per_host + # TODO(isaprykin): This might have to be per-call. + self._iterations_per_step = iterations_per_step + + def distribute_dataset(self, dataset_fn): + return values.PerIterationDataset( + self._call_dataset_fn(dataset_fn), self._iterations_per_step, + self._num_cores_per_host) + + def _call_for_each_tower(self, fn, *args, **kwargs): + kwargs.pop('run_concurrently', None) + + inputs = {'args': args, 'kwargs': kwargs} + flat_inputs = nest.flatten(inputs) + + feed_mask = [isinstance(f, values.PerIteration) for f in flat_inputs] + + feeds = lambda: itertools.compress(flat_inputs, feed_mask) + shapes = [f.get_shape() for f in feeds()] + if any([not s.is_fully_defined() for s in shapes]): + raise ValueError( + 'TPU currently requires fully defined shapes. 
Either use ' + 'set_shape() on the input tensors or use ' + 'dataset.apply(map_and_batch(..., drop_remainder=True)).') + types = [f.get_dtype() for f in feeds()] + + def infeed_input(i): + """Get input, split it and then enqueue.""" + iteration_inputs = [f.get(i) for f in feeds()] + infeed_inputs = [[inputs_per_core[core_id] + for inputs_per_core in iteration_inputs] + for core_id in range(self._num_cores_per_host)] + + infeed_ops = [] + for core_id, infeed_input in enumerate(infeed_inputs): + infeed_ops.append( + tpu_ops.infeed_enqueue_tuple( + inputs=infeed_input, shapes=shapes, device_ordinal=core_id)) + + with ops.control_dependencies(infeed_ops): + return i + 1 + + with ops.device('/task:0/device:CPU:0'): + enqueue_ops = control_flow_ops.while_loop( + lambda i: i < self._iterations_per_step, + infeed_input, [constant_op.constant(0)], + parallel_iterations=1) + + def dequeueing_fn(*args, **kwargs): + """Dequeue input arguments and supply them to `fn`.""" + del args, kwargs + dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes) + dequeued = iter(dequeued) + + fn_inputs = [] + for inp, is_feed in zip(flat_inputs, feed_mask): + if is_feed: + fn_inputs.append(next(dequeued)) + else: + fn_inputs.append(inp) + + fn_inputs = nest.pack_sequence_as(inputs, fn_inputs) + return fn(*fn_inputs['args'], **fn_inputs['kwargs']) + + def iterate_on_tpu(): + return tpu.repeat(self._iterations_per_step, dequeueing_fn, []) + + with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access + tpu_result = tpu.batch_parallel( + iterate_on_tpu, [], num_shards=self._num_cores_per_host) + + return control_flow_ops.group(tpu_result, enqueue_ops) + + def _reduce(self, method_string, value, destinations): + del destinations # TPU is graph mode only. Rely on implicit Send/Recv. + if method_string == 'mean': + # TODO(jhseu): Revisit once we support model-parallelism. + value *= (1. 
/ self._num_cores_per_host) + return tpu_ops.cross_replica_sum(value) + + @property + def num_towers(self): + return self._num_cores_per_host diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 87bf0590384..49b4e24daa4 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -27,16 +27,18 @@ import weakref import six from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import input_ops from tensorflow.contrib.distribute.python import prefetching_ops_v2 -from tensorflow.contrib.eager.python import datasets from tensorflow.python.eager import context +from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.training import checkpointable +from tensorflow.python.ops import math_ops from tensorflow.python.training import device_util from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util import nest @@ -59,7 +61,7 @@ class DistributedValues(object): else: device = distribute_lib.get_update_device() if device is None: - device = device_util.current() + return self._get_cross_tower() device = device_util.canonicalize(device) try: return self._index[device] @@ -313,6 +315,18 @@ class MirroredVariable(DistributedVariable, Mirrored, def assign(self, *args, **kwargs): return self.get(device=_get_update_device()).assign(*args, **kwargs) + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return array_ops.identity(self._index[device]) + return array_ops.identity(self._primary_var) + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._primary_var._as_graph_element() + return self.get()._as_graph_element() + def _gather_saveables_for_checkpoint(self): """Overrides CheckpointableBase method. @@ -357,6 +371,12 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access +def _assert_tower_context(): + if not distribute_lib.get_tower_context(): + raise RuntimeError( + "Tower-local variables may only be assigned in a tower context.") + + class TowerLocalVariable(DistributedVariable, PerDevice, checkpointable.CheckpointableBase): """Holds a map from device to variables whose values are reduced on save.""" @@ -367,18 +387,35 @@ class TowerLocalVariable(DistributedVariable, PerDevice, super(TowerLocalVariable, self).__init__(index) def assign_sub(self, *args, **kwargs): + _assert_tower_context() return self.get().assign_sub(*args, **kwargs) def assign_add(self, *args, **kwargs): + _assert_tower_context() return self.get().assign_add(*args, **kwargs) def assign(self, *args, **kwargs): + _assert_tower_context() return self.get().assign(*args, **kwargs) @property def reduce_method(self): return self._reduce_method + def _get_cross_tower(self): + all_components = tuple(self._index.values()) + # TODO(josh11b): Use a strategy-specific method. 
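# Plain-Python picture of the reduction below (per-tower component values are
# made up for illustration): a "sum" tower-local variable reads back in
# cross-tower context as the sum of its copies, a "mean" one as their average.
components = (3.0, 4.0)                # one value per tower
total = sum(components)                # what math_ops.add_n computes
mean = total * (1. / len(components))  # the "mean" reduce_method branch
assert (total, mean) == (7.0, 3.5)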
+ total = math_ops.add_n(all_components) + if self._reduce_method == "mean": + return total * (1./ len(all_components)) + return total + + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._get_cross_tower() + return self.get()._as_graph_element() + def _gather_saveables_for_checkpoint(self): """Overrides CheckpointableBase method. @@ -510,6 +547,10 @@ class PerDeviceDataIterator(object): self._devices = devices self._prefetch_on_device = prefetch_on_device + @property + def initializer(self): + return self._iterator.initializer + def get_next(self, name=None): """Scatter the input across devices.""" if self._prefetch_on_device: @@ -545,7 +586,8 @@ class PerDeviceDataset(object): "Prefetching is only supported in graph mode currently") if self._prefetch_on_device: - self._dataset = dataset + self._dataset = dataset.apply( + prefetching_ops_v2.prefetch_to_devices(self._devices)) else: # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. @@ -555,18 +597,204 @@ class PerDeviceDataset(object): def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" - if self._prefetch_on_device: - on_device_dataset = self._dataset.apply( - prefetching_ops_v2.prefetch_to_devices(self._devices)) - dataset_iterator = on_device_dataset.make_one_shot_iterator() - elif context.executing_eagerly(): - dataset_iterator = datasets.Iterator(self._dataset) - else: - dataset_iterator = self._dataset.make_one_shot_iterator() - + dataset_iterator = self._dataset.make_one_shot_iterator() return PerDeviceDataIterator( dataset_iterator, self._devices, self._prefetch_on_device) + def make_initializable_iterator(self): + """Get an initializable iterator for the distributed PerDeviceDataset.""" + dataset_iterator = self._dataset.make_initializable_iterator() + return PerDeviceDataIterator( + dataset_iterator, self._devices, self._prefetch_on_device) + + +class MultiWorkerDataIterator(object): + """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`.""" + + def __init__(self, iterators, worker_device_map): + """Initialize the MultiWorkerDataIterator object. + + Args: + iterators: a dict mapping from each worker to an iterator for + that worker. + worker_device_map: a dict mapping from each worker's devices to a list of + devices that belong to this worker. + + Raises: + ValueError: if iterators and worker_device_map are not compatible. + """ + self._iterators = iterators + self._worker_device_map = worker_device_map + if set(self._iterators) != set(self._worker_device_map): + raise ValueError("iterators and worker_device_map are not compatible.") + + @property + def initializer(self): + return control_flow_ops.group( + [iterator.initializer for iterator in self._iterators.values()]) + + def get_next(self, name=None): + """Scatter the input across hosts and devices.""" + index = {} + for worker, iterator in six.iteritems(self._iterators): + if name is not None: + d = tf_device.DeviceSpec.from_string(worker) + new_name = "%s_%s_%d" % (name, d.job, d.task) + else: + new_name = None + with ops.device(worker): + data_per_worker = iterator.get_next(name=new_name) + + worker_devices = self._worker_device_map[worker] + # Ungroup these per-device value so as to get a flat map from devices to + # values. 
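# For example (worker and device names are hypothetical), with two workers
# owning two GPUs each, the loop below turns the per-worker results
#   {"/job:worker/task:0": values for its GPU:0 and GPU:1,
#    "/job:worker/task:1": values for its GPU:0 and GPU:1}
# into one flat map keyed by fully qualified device, e.g.
#   {"/job:worker/task:0/device:GPU:0": a0, ...,
#    "/job:worker/task:1/device:GPU:1": b1},
# which regroup() then wraps back into per-device structures for the caller.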
+ for d in worker_devices: + v = select_device(d, data_per_worker) + if d in index: + raise ValueError("Duplicated devices in worker_device_map: %r" % v) + index[d] = v + + return regroup(index) + + +class MultiWorkerDataset(object): + """Like a `tf.data.Dataset` that distributes data to different workers. + + Each worker gets one shard of the input dataset. It is currently not working + in + eager mode. + """ + + def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None): + """Initialize the MultiWorkerDataset object. + + Args: + dataset_fn: a function that returns a `tf.data.Dataset`. + worker_device_map: a dict mapping from each worker to a list of devices + that belong to this worker. + prefetch_on_device: whether to prefetch to devices. + """ + self._worker_device_map = worker_device_map + self._datasets = {} + # TODO(yuefengz, priyag): support different set of jobs for input + # processing. + for i, (worker, worker_devices) in enumerate( + six.iteritems(worker_device_map)): + with ops.device(worker): + worker_input = dataset_fn() + worker_input = input_ops.auto_shard_dataset( + worker_input, len(worker_device_map), i) + self._datasets[worker] = PerDeviceDataset( + worker_input, worker_devices, prefetch_on_device=prefetch_on_device) + + def make_one_shot_iterator(self): + iterators = {} + for worker, dataset in six.iteritems(self._datasets): + with ops.device(worker): + iterators[worker] = dataset.make_one_shot_iterator() + return MultiWorkerDataIterator(iterators, self._worker_device_map) + + def make_initializable_iterator(self): + iterators = {} + for worker, dataset in six.iteritems(self._datasets): + with ops.device(worker): + iterators[worker] = dataset.make_initializable_iterator() + return MultiWorkerDataIterator(iterators, self._worker_device_map) + + +class _PerKey(object): + """Holds data associated by keys.""" + + def __init__(self, *index): + # pylint: disable=protected-access + self._index = list(index) + + def get(self, iteration): + return array_ops.gather(self._index, iteration) + + def get_shape(self): + return self._index[-1][-1].get_shape() + + def get_dtype(self): + return self._index[-1][-1].dtype + + def __str__(self): + return "%s:%s" % (self.__class__.__name__, self._index) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._index) + + +class PerIteration(_PerKey): + """Holds input for multiple iterations at once.""" + + def __init__(self, *index): + # pylint: disable=protected-access + super(PerIteration, self).__init__(*[batch._index for batch in index]) + + +class Batches(_PerKey): + pass + + +class MultiIterator(object): + """Iterator that returns results of multiple get_next()s.""" + + def __init__(self, dataset_iterator, iterations, batches_per_iteration): + self._dataset_iterator = dataset_iterator + self._iterations = iterations + self._batches_per_iteration = batches_per_iteration + + def get_next(self, name=None): + """Return PerIteration with `iterations x batches_per_iteration` inputs.""" + data = [] + for _ in range(self._batches_per_iteration): + batch = [] + for _ in range(self._iterations): + batch.append(self._dataset_iterator.get_next(name=name)) + data.append(batch) + + # Here is an example. Suppose each get_next returns a tuple of two tensors. 
+ # For 3 `iterations` and 2 `batches_per_iteration`, the `data` is: + # [[(a,z), (b,y), (c,x)], [(A,Z), (B,Y), (C,X)]] + # + # After the first `map_structure` it gets transformed to: + # [(Batches(a, A), Batches(z, Z)), + # (Batches(b, B), Batches(y, Y)), + # (Batches(c, C), Batches(x, X))] + # + # After the second `map_structure` it gets transformed to a tuple of: + # (PerIteration([Batches(a, A), Batches(b, B), Batches(c, C)]), + # PerIteration([Batches(z, Z), Batches(y, Y), Batches(x, X)])) + + data = nest.map_structure(Batches, *data) + data = nest.map_structure(PerIteration, *data) + + return data + + @property + def initializer(self): + return self._dataset_iterator.initializer + + +class PerIterationDataset(object): + """A dataset that returns MultiIterators.""" + + def __init__(self, dataset, iterations, batches_per_iteration): + self._dataset = dataset + self._iterations = iterations + self._batches_per_iteration = batches_per_iteration + + def make_one_shot_iterator(self): + iterator = self._dataset.make_one_shot_iterator() + return MultiIterator(iterator, self._iterations, + self._batches_per_iteration) + + def make_initializable_iterator(self): + iterator = self._dataset.make_initializable_iterator() + return MultiIterator(iterator, self._iterations, + self._batches_per_iteration) + class MapOutput(object): """Map can result in multiple outputs per device.""" diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 5c0d4b7d6c7..1c95758d96a 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -18,9 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import os from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops @@ -32,12 +34,14 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import saver as saver_lib +from tensorflow.python.util import nest -@test_util.with_c_api class DistributedValuesTest(test.TestCase): def testGetEager(self): @@ -76,7 +80,6 @@ class DistributedValuesTest(test.TestCase): v = values.DistributedValues({"/device:cpu:0": 42}) -@test_util.with_c_api class DistributedDelegateTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() @@ -159,7 +162,6 @@ def _make_mirrored(): return v, devices, mirrored -@test_util.with_c_api class RegroupAndSelectDeviceTest(test.TestCase): def _is_per_device(self, result, expected, klass=values.PerDevice): @@ -312,7 +314,6 @@ class RegroupAndSelectDeviceTest(test.TestCase): merged_estimator_spec)) -@test_util.with_c_api class PerDeviceDatasetTest(test.TestCase): config = config_pb2.ConfigProto() @@ -408,8 +409,157 @@ class PerDeviceDatasetTest(test.TestCase): expected_values = [[i, i+1] for i in range(0, 10, 2)] self._test_iterator(devices, dataset, expected_values) + def 
testInitializableIterator(self): + with context.graph_mode(): + devices = ["/device:CPU:0"] + # Using random input since that is only allowed with initializable + # iterator. + dataset = dataset_ops.Dataset.from_tensor_slices( + random_ops.random_uniform((10,))) + + per_device_dataset = values.PerDeviceDataset( + dataset, devices, prefetch_on_device=False) + iterator = per_device_dataset.make_initializable_iterator() + + self.evaluate(iterator.initializer) + next_element = iterator.get_next() + for _ in range(10): + self.evaluate(next_element) + + # Should fail after the input is finished. + with self.assertRaises(errors.OutOfRangeError): + self.evaluate(next_element) + + # After re-initializing the iterator, should be able to iterate again. + self.evaluate(iterator.initializer) + for _ in range(10): + self.evaluate(next_element) + + +class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase): + + def _test_iterator(self, iterator, devices, expected_values): + next_element = iterator.get_next() + for device in devices: + v = values.select_device(device, next_element) + # The `v` here can be a tuple. + for element in nest.flatten(v): + self.assertTrue(element.device in device) + + for expected_value in expected_values: + actual = self.evaluate( + [values.select_device(d, next_element) for d in devices]) + self.assertEqual(expected_value, actual) + + with self.assertRaises(errors.OutOfRangeError): + self.evaluate([values.select_device(d, next_element) for d in devices]) + + def _test_dataset(self, dataset_fn, worker_device_map, devices, + expected_values): + multi_worker_dataset = values.MultiWorkerDataset( + dataset_fn, worker_device_map, prefetch_on_device=False) + multi_worker_iterator = multi_worker_dataset.make_one_shot_iterator() + self._test_iterator(multi_worker_iterator, devices, expected_values) + + def _cpu_devices(self): + worker_device_map = collections.OrderedDict( + [("/job:worker/replica:0/task:0", + ["/job:worker/replica:0/task:0/device:CPU:0"]), + ("/job:worker/replica:0/task:1", + ["/job:worker/replica:0/task:1/device:CPU:0"])]) + devices = [ + "/job:worker/replica:0/task:0/device:CPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ] + return worker_device_map, devices + + def _cpu_and_one_gpu_devices(self): + # The worker_device_map doesn't have to be a OrderDict object, this is just + # to simplify the testing so that we can pass expected values as a list + # instead of a dict. 
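+    # Hedged note: with auto-sharding across two workers, task 0 is expected to
+    # receive the even elements of range(8) and task 1 the odd ones, so listing
+    # GPU before CPU here lets the tests spell out expected values as flat
+    # lists.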
+ worker_device_map = collections.OrderedDict( + [("/job:worker/replica:0/task:0", [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0" + ]), ("/job:worker/replica:0/task:1", [ + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ])]) + devices = [ + "/job:worker/replica:0/task:0/device:GPU:0", + "/job:worker/replica:0/task:0/device:CPU:0", + "/job:worker/replica:0/task:1/device:GPU:0", + "/job:worker/replica:0/task:1/device:CPU:0" + ] + return worker_device_map, devices + + def testDataDistributionOneDevicePerWorker(self): + worker_device_map, devices = self._cpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + self._test_dataset(dataset_fn, worker_device_map, devices, + [[0, 1], [2, 3], [4, 5], [6, 7]]) + + def testDataDistributionTwoDevicePerWorker(self): + if context.num_gpus() < 1: + self.skipTest("A GPU is not available for this test.") + worker_device_map, devices = self._cpu_and_one_gpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + self._test_dataset(dataset_fn, worker_device_map, devices, + [[0, 2, 1, 3], [4, 6, 5, 7]]) + + def testTupleDataset(self): + worker_device_map, devices = self._cpu_devices() + + with context.graph_mode(): + + def dataset_fn(): + dataset1 = dataset_ops.Dataset.range(8) + dataset2 = dataset_ops.Dataset.range(8).map(lambda x: x**2) + return dataset_ops.Dataset.zip((dataset1, dataset2)) + + expected_values = [ + [(i, i**2), (i + 1, (i + 1)**2)] for i in range(0, 8, 2) + ] + self._test_dataset(dataset_fn, worker_device_map, devices, + expected_values) + + def testInitializableIterator(self): + worker_device_map, devices = self._cpu_devices() + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + multi_worker_dataset = values.MultiWorkerDataset( + dataset_fn, worker_device_map, prefetch_on_device=False) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + + self.evaluate(multi_worker_iterator.initializer) + self._test_iterator(multi_worker_iterator, devices, + [[0, 1], [2, 3], [4, 5], [6, 7]]) + + # After re-initializing the iterator, should be able to iterate again. + self.evaluate(multi_worker_iterator.initializer) + self._test_iterator(multi_worker_iterator, devices, + [[0, 1], [2, 3], [4, 5], [6, 7]]) + + def testValueErrorForIterator(self): + # Incompatiable arguments. + with self.assertRaises(ValueError): + values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"}) + + # Test duplicated devices under same worker. 
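+    # Appending task 0's CPU device a second time below should make
+    # get_next() trip the duplicate-device check in MultiWorkerDataIterator.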
+ worker_device_map, _ = self._cpu_devices() + worker_device_map["/job:worker/replica:0/task:0"].append( + "/job:worker/replica:0/task:0/device:CPU:0") + with context.graph_mode(): + dataset_fn = lambda: dataset_ops.Dataset.range(8) + multi_worker_dataset = values.MultiWorkerDataset( + dataset_fn, worker_device_map, prefetch_on_device=False) + multi_worker_iterator = multi_worker_dataset.make_initializable_iterator() + with self.assertRaises(ValueError): + multi_worker_iterator.get_next() + -@test_util.with_c_api class MirroredVariableTest(test.TestCase): config = config_pb2.ConfigProto() @@ -555,6 +705,21 @@ class MirroredVariableTest(test.TestCase): save_path = self._save_normal() self._restore_mirrored(save_path) + @test_util.run_in_graph_and_eager_modes(config=config) + def testFetchAMirroredVariable(self): + if context.num_gpus() < 1 or context.executing_eagerly(): + self.skipTest("A GPU is not available for this test or it's eager mode.") + + with self.test_session( + graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy( + ["/device:GPU:0"]).scope(): + with ops.device("/device:GPU:0"): + v = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + mirrored = values.MirroredVariable({"/device:GPU:0": v}, v) + sess.run(variables_lib.global_variables_initializer()) + sess.run({"complicated": mirrored}) + _devices = ["/device:GPU:0", "/device:CPU:0"] @@ -571,7 +736,6 @@ def _make_tower_local(method): return v, tower_local -@test_util.with_c_api class TowerLocalVariableTest(test.TestCase): config = config_pb2.ConfigProto() diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 20e432b88dc..6192f04c8b6 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -94,7 +94,7 @@ cuda_py_test( cuda_py_test( name = "distribution_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/distribution_test.py"], additional_deps = [ ":distributions_py", @@ -337,7 +337,7 @@ cuda_py_test( cuda_py_test( name = "mvn_tril_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/mvn_tril_test.py"], additional_deps = [ ":distributions_py", @@ -372,6 +372,7 @@ cuda_py_test( "//tensorflow/python:random_ops", "//tensorflow/python:variables", ], + shard_count = 4, ) cuda_py_test( @@ -459,7 +460,7 @@ cuda_py_test( cuda_py_test( name = "batch_reshape_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/batch_reshape_test.py"], additional_deps = [ ":distributions_py", @@ -578,7 +579,7 @@ cuda_py_test( cuda_py_test( name = "wishart_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/wishart_test.py"], additional_deps = [ ":distributions_py", @@ -709,6 +710,8 @@ cuda_py_test( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:client_testlib", ], + shard_count = 4, + tags = ["noasan"], # times out, http://b/78588814 ) cuda_py_test( @@ -865,7 +868,7 @@ cuda_py_test( cuda_py_test( name = "batch_normalization_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/bijectors/batch_normalization_test.py"], additional_deps = [ ":bijectors_py", @@ -877,6 +880,7 @@ cuda_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], + tags = ["optonly"], ) cuda_py_test( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py index 59d549b7b80..f2bb2d3325a 100644 --- 
a/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/batch_reshape_test.py @@ -448,8 +448,7 @@ class _BatchReshapeTest(object): else: with self.test_session(): - with self.assertRaisesOpError(r"`batch_shape` size must match " - r"`distributions.batch_shape` size"): + with self.assertRaisesOpError(r"Shape sizes do not match."): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, @@ -457,8 +456,13 @@ class _BatchReshapeTest(object): def test_non_positive_shape(self): dims = 2 - new_batch_shape = [-1, -2] # -1*-2=2 so will pass size check. - old_batch_shape = [2] + old_batch_shape = [4] + if self.is_static_shape: + # Unknown first dimension does not trigger size check. Note that + # any dimension < 0 is treated statically as unknown. + new_batch_shape = [-1, 0] + else: + new_batch_shape = [-2, -2] # -2 * -2 = 4, same size as the old shape. new_batch_shape_ph = ( constant_op.constant(np.int32(new_batch_shape)) if self.is_static_shape @@ -471,7 +475,7 @@ class _BatchReshapeTest(object): mvn = mvn_lib.MultivariateNormalDiag(scale_diag=scale_ph) if self.is_static_shape: - with self.assertRaisesRegexp(ValueError, r".*must be positive.*"): + with self.assertRaisesRegexp(ValueError, r".*must be >=-1.*"): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, @@ -479,7 +483,7 @@ class _BatchReshapeTest(object): else: with self.test_session(): - with self.assertRaisesOpError(r".*must be positive.*"): + with self.assertRaisesOpError(r".*must be >=-1.*"): batch_reshape_lib.BatchReshape( distribution=mvn, batch_shape=new_batch_shape_ph, diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py index ca20442c394..dc45114b1c2 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency from tensorflow.python.platform import test @@ -188,6 +189,15 @@ class ChainBijectorTest(test.TestCase): -np.log(6, dtype=np.float32) - np.sum(x), self.evaluate(chain.inverse_log_det_jacobian(y, event_ndims=1))) + def testChainIldjWithPlaceholder(self): + chain = Chain((Exp(), Exp())) + samples = array_ops.placeholder( + dtype=np.float32, shape=[None, 10], name="samples") + ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0) + self.assertTrue(ildj is not None) + with self.test_session(): + ildj.eval({samples: np.zeros([2, 10], np.float32)}) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py new file mode 100644 index 00000000000..a5f5219588f --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py @@ -0,0 +1,109 @@ +# Copyright 2018 The 
TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.ordered import Ordered +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.platform import test + + + +class OrderedBijectorTest(test.TestCase): + """Tests correctness of the ordered transformation.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + @test_util.run_in_graph_and_eager_modes() + def testBijectorVector(self): + with self.test_session(): + ordered = Ordered() + self.assertEqual("ordered", ordered.name) + x = np.asarray([[2., 3, 4], [4., 8, 13]]) + y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]] + self.assertAllClose(y, self.evaluate(ordered.forward(x))) + self.assertAllClose(x, self.evaluate(ordered.inverse(y))) + self.assertAllClose( + np.sum(np.asarray(y)[..., 1:], axis=-1), + self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)), + atol=0., + rtol=1e-7) + self.assertAllClose( + self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)), + self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)), + atol=0., + rtol=1e-7) + + def testBijectorUnknownShape(self): + with self.test_session(): + ordered = Ordered() + self.assertEqual("ordered", ordered.name) + x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) + real_x = np.asarray([[2., 3, 4], [4., 8, 13]]) + y = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32) + real_y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]] + self.assertAllClose(real_y, ordered.forward(x).eval( + feed_dict={x: real_x})) + self.assertAllClose(real_x, ordered.inverse(y).eval( + feed_dict={y: real_y})) + self.assertAllClose( + np.sum(np.asarray(real_y)[..., 1:], axis=-1), + ordered.inverse_log_det_jacobian(y, event_ndims=1).eval( + feed_dict={y: real_y}), + atol=0., + rtol=1e-7) + self.assertAllClose( + -ordered.inverse_log_det_jacobian(y, event_ndims=1).eval( + feed_dict={y: real_y}), + ordered.forward_log_det_jacobian(x, event_ndims=1).eval( + feed_dict={x: real_x}), + atol=0., + rtol=1e-7) + + @test_util.run_in_graph_and_eager_modes() + def testShapeGetters(self): + with self.test_session(): + x = tensor_shape.TensorShape([4]) + y = tensor_shape.TensorShape([4]) + bijector = Ordered(validate_args=True) + self.assertAllEqual(y, bijector.forward_event_shape(x)) + self.assertAllEqual(y.as_list(), + self.evaluate(bijector.forward_event_shape_tensor( + x.as_list()))) + self.assertAllEqual(x, bijector.inverse_event_shape(y)) + 
self.assertAllEqual(x.as_list(), + self.evaluate(bijector.inverse_event_shape_tensor( + y.as_list()))) + + def testBijectiveAndFinite(self): + with self.test_session(): + ordered = Ordered() + x = np.sort(self._rng.randn(3, 10), axis=-1).astype(np.float32) + y = (self._rng.randn(3, 10)).astype(np.float32) + assert_bijective_and_finite(ordered, x, y, event_ndims=1) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 46f2c63f9b0..d44e49b4874 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -22,15 +22,12 @@ import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test -@test_util.with_c_api class _ReshapeBijectorTest(object): """Base class for testing the reshape transformation. @@ -265,7 +262,6 @@ class _ReshapeBijectorTest(object): raise NotImplementedError("Subclass failed to implement `build_shapes`.") -@test_util.with_c_api class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -305,21 +301,13 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): bijector, x, y, event_ndims=2, rtol=1e-6, atol=0) def testInvalidDimensionsOpError(self): - if ops._USE_C_API: - error_message = "Invalid value in tensor used for shape: -2" - else: - error_message = "elements must be either positive integers or `-1`." 
- self._testInvalidDimensionsOpError(error_message) + self._testInvalidDimensionsOpError( + "Invalid value in tensor used for shape: -2") def testInputOutputMismatchOpError(self): - if ops._USE_C_API: - error_message = "Cannot reshape a tensor with" - else: - error_message = "Input to reshape is a tensor with" - self._testInputOutputMismatchOpError(error_message) + self._testInputOutputMismatchOpError("Cannot reshape a tensor with") -@test_util.with_c_api class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): @@ -341,7 +329,6 @@ class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): self._testInputOutputMismatchOpError("Input to reshape is a tensor with") -@test_util.with_c_api class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): def build_shapes(self, shape_in, shape_out): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py index 68e0d9cb827..f42feae25d8 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_test.py @@ -190,11 +190,30 @@ class DistributionTest(test.TestCase): y = dist._set_sample_static_shape(x, sample_shape) self.assertTrue(y.get_shape().ndims is None) + def testNameScopeWorksCorrectly(self): + x = tfd.Normal(loc=0., scale=1., name="x") + x_duplicate = tfd.Normal(loc=0., scale=1., name="x") + with ops.name_scope("y") as name: + y = tfd.Bernoulli(logits=0., name=name) + x_sample = x.sample(name="custom_sample") + x_sample_duplicate = x.sample(name="custom_sample") + x_log_prob = x.log_prob(0., name="custom_log_prob") + x_duplicate_sample = x_duplicate.sample(name="custom_sample") + + self.assertEqual(x.name, "x/") + self.assertEqual(x_duplicate.name, "x_1/") + self.assertEqual(y.name, "y/") + self.assertTrue(x_sample.name.startswith("x/custom_sample")) + self.assertTrue(x_sample_duplicate.name.startswith("x/custom_sample_1")) + self.assertTrue(x_log_prob.name.startswith("x/custom_log_prob")) + self.assertTrue(x_duplicate_sample.name.startswith( + "x_1/custom_sample")) + def testStrWorksCorrectlyScalar(self): normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) self.assertEqual( ("tf.distributions.Normal(" - "\"Normal\", " + "\"Normal/\", " "batch_shape=(), " "event_shape=(), " "dtype=float16)"), # Got the dtype right. @@ -203,7 +222,7 @@ class DistributionTest(test.TestCase): chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") self.assertEqual( ("tf.distributions.Chi2(" - "\"silly\", " # What a silly name that is! + "\"silly/\", " # What a silly name that is! "batch_shape=(2,), " "event_shape=(), " "dtype=float32)"), @@ -211,7 +230,7 @@ class DistributionTest(test.TestCase): exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) self.assertEqual( - ("tf.distributions.Exponential(\"Exponential\", " + ("tf.distributions.Exponential(\"Exponential/\", " # No batch shape. 
"event_shape=(), " "dtype=float32)"), @@ -222,7 +241,7 @@ class DistributionTest(test.TestCase): loc=np.zeros([2, 2]), name="MVN") self.assertEqual( ("tf.distributions.MultivariateNormalDiag(" - "\"MVN\", " + "\"MVN/\", " "batch_shape=(2,), " "event_shape=(2,), " "dtype=float64)"), @@ -233,7 +252,7 @@ class DistributionTest(test.TestCase): name="MVN2") self.assertEqual( ("tf.distributions.MultivariateNormalDiag(" - "\"MVN2\", " + "\"MVN2/\", " "batch_shape=(?,), " # Partially known. "event_shape=(3,), " "dtype=float32)"), @@ -243,7 +262,7 @@ class DistributionTest(test.TestCase): normal = tfd.Normal(loc=np.float16(0), scale=np.float16(1)) self.assertEqual( (""), # Got the dtype right. @@ -252,7 +271,7 @@ class DistributionTest(test.TestCase): chi2 = tfd.Chi2(df=np.float32([1., 2.]), name="silly") self.assertEqual( (""), @@ -261,7 +280,7 @@ class DistributionTest(test.TestCase): exp = tfd.Exponential(rate=array_ops.placeholder(dtype=dtypes.float32)) self.assertEqual( ("" " event_shape=()" " dtype=float32>"), @@ -272,7 +291,7 @@ class DistributionTest(test.TestCase): loc=np.zeros([2, 2]), name="MVN") self.assertEqual( (""), @@ -283,7 +302,7 @@ class DistributionTest(test.TestCase): name="MVN2") self.assertEqual( (""), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py index 1a02fbefb8e..b0035263927 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py @@ -52,7 +52,7 @@ class MultivariateNormalFullCovarianceTest(test.TestCase): mu = [1., 2.] sigma = [[1., 0.], [0., 1.]] mvn = ds.MultivariateNormalFullCovariance(mu, sigma, name="Billy") - self.assertEqual(mvn.name, "Billy") + self.assertEqual(mvn.name, "Billy/") def testDoesNotRaiseIfInitializedWithSymmetricMatrix(self): with self.test_session(): @@ -131,8 +131,8 @@ class MultivariateNormalFullCovarianceTest(test.TestCase): return mu, sigma def testKLBatch(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -156,6 +156,33 @@ class MultivariateNormalFullCovarianceTest(test.TestCase): self.assertAllClose(expected_kl_0, kl_v[0]) self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLBatchBroadcast(self): + batch_shape = [2] + event_shape = [3] + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + # No batch shape. 
+ mu_b, sigma_b = self._random_mu_and_sigma([], event_shape) + mvn_a = ds.MultivariateNormalFullCovariance( + loc=mu_a, + covariance_matrix=sigma_a, + validate_args=True) + mvn_b = ds.MultivariateNormalFullCovariance( + loc=mu_b, + covariance_matrix=sigma_b, + validate_args=True) + + kl = ds.kl_divergence(mvn_a, mvn_b) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], + mu_b, sigma_b) + expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], + mu_b, sigma_b) + self.assertAllClose(expected_kl_0, kl_v[0]) + self.assertAllClose(expected_kl_1, kl_v[1]) + def _compute_non_batch_kl(mu_a, sigma_a, mu_b, sigma_b): """Non-batch KL for N(mu_a, sigma_a), N(mu_b, sigma_b).""" diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py index 685f32883da..b556d061238 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py @@ -235,8 +235,8 @@ class MultivariateNormalTriLTest(test.TestCase): return mu, sigma def testKLNonBatch(self): - batch_shape = () - event_shape = (2,) + batch_shape = [] + event_shape = [2] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -257,8 +257,8 @@ class MultivariateNormalTriLTest(test.TestCase): self.assertAllClose(expected_kl, kl_v) def testKLBatch(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mu_b, sigma_b = self._random_mu_and_sigma(batch_shape, event_shape) @@ -282,9 +282,36 @@ class MultivariateNormalTriLTest(test.TestCase): self.assertAllClose(expected_kl_0, kl_v[0]) self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLBatchBroadcast(self): + batch_shape = [2] + event_shape = [3] + with self.test_session(): + mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) + # No batch shape. 
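+      # Illustrative reading: mvn_b is built with an empty batch shape, so the
+      # KL computed below should broadcast it against mvn_a's batch of 2 and
+      # return one KL value per batch member (checked element-wise afterwards).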
+ mu_b, sigma_b = self._random_mu_and_sigma([], event_shape) + mvn_a = ds.MultivariateNormalTriL( + loc=mu_a, + scale_tril=np.linalg.cholesky(sigma_a), + validate_args=True) + mvn_b = ds.MultivariateNormalTriL( + loc=mu_b, + scale_tril=np.linalg.cholesky(sigma_b), + validate_args=True) + + kl = ds.kl_divergence(mvn_a, mvn_b) + self.assertEqual(batch_shape, kl.get_shape()) + + kl_v = kl.eval() + expected_kl_0 = _compute_non_batch_kl(mu_a[0, :], sigma_a[0, :, :], + mu_b, sigma_b) + expected_kl_1 = _compute_non_batch_kl(mu_a[1, :], sigma_a[1, :, :], + mu_b, sigma_b) + self.assertAllClose(expected_kl_0, kl_v[0]) + self.assertAllClose(expected_kl_1, kl_v[1]) + def testKLTwoIdenticalDistributionsIsZero(self): - batch_shape = (2,) - event_shape = (3,) + batch_shape = [2] + event_shape = [3] with self.test_session(): mu_a, sigma_a = self._random_mu_and_sigma(batch_shape, event_shape) mvn_a = ds.MultivariateNormalTriL( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py index 96805733178..b91a610acf1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/seed_stream_test.py @@ -65,6 +65,16 @@ class SeedStreamTest(test.TestCase): self.assertAllUnique( outputs + [strm2() for _ in range(50)] + [strm3() for _ in range(50)]) + def testInitFromOtherSeedStream(self): + strm1 = seed_stream.SeedStream(seed=4, salt="salt") + strm2 = seed_stream.SeedStream(strm1, salt="salt") + strm3 = seed_stream.SeedStream(strm1, salt="another salt") + out1 = [strm1() for _ in range(50)] + out2 = [strm2() for _ in range(50)] + out3 = [strm3() for _ in range(50)] + self.assertAllEqual(out1, out2) + self.assertAllUnique(out1 + out3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py index c8d795c3f6a..243b5a03485 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py @@ -584,7 +584,6 @@ class DistributionShapeTest(test.TestCase): def testDistributionShapeGetDimsStatic(self): with self.test_session(): - shaper = _DistributionShape(batch_ndims=0, event_ndims=0) shaper = _DistributionShape(batch_ndims=0, event_ndims=0) x = 1 self.assertAllEqual((_empty_shape, _empty_shape, _empty_shape), diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index 69f3d57ff00..d813831bef8 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -144,8 +144,8 @@ class Autoregressive(distribution_lib.Distribution): `distribution_fn(sample0).event_shape.num_elements()` are both `None`. ValueError: if `num_steps < 1`. 
""" - parameters = locals() - with ops.name_scope(name): + parameters = distribution_util.parent_frame_arguments() + with ops.name_scope(name) as name: self._distribution_fn = distribution_fn self._sample0 = sample0 self._distribution0 = (distribution_fn() if sample0 is None diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index bf5590cd552..c709318f765 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ @@ -41,9 +42,6 @@ class BatchReshape(distribution_lib.Distribution): This "meta-distribution" reshapes the batch dimensions of another distribution. - Note: Unlike `tf.reshape`, the `BatchReshape` distribution does not support - `-1` for flattening. - #### Examples ```python @@ -51,7 +49,7 @@ class BatchReshape(distribution_lib.Distribution): dtype = np.float32 dims = 2 - new_batch_shape = [1, 2, 3] + new_batch_shape = [1, 2, -1] old_batch_shape = [6] scale = np.ones(old_batch_shape + [dims], dtype) @@ -85,8 +83,9 @@ class BatchReshape(distribution_lib.Distribution): Args: distribution: The base distribution instance to reshape. Typically an instance of `Distribution`. - batch_shape: Positive `int`-like vector-shaped `Tensor` representing the - new shape of the batch dimensions. + batch_shape: Positive `int`-like vector-shaped `Tensor` representing + the new shape of the batch dimensions. Up to one dimension may contain + `-1`, meaning the remainder of the batch size. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -104,31 +103,28 @@ class BatchReshape(distribution_lib.Distribution): ValueError: if `batch_shape` size is not the same as a `distribution.batch_shape` size. """ - parameters = locals() + parameters = distribution_util.parent_frame_arguments() name = name or "BatchReshape" + distribution.name - self._distribution = distribution with ops.name_scope(name, values=[batch_shape]) as name: - self._batch_shape_ = ops.convert_to_tensor( - batch_shape, - dtype=dtypes.int32, - name="batch_shape") - self._batch_shape_static = tensor_util.constant_value(self._batch_shape_) - if self._batch_shape_static is not None: - self._batch_shape_static = np.int32(self._batch_shape_static) - self._runtime_assertions = validate_init_args( - self._distribution, - self._batch_shape_, - validate_args, - self._batch_shape_static) + # The unexpanded batch shape may contain up to one dimension of -1. 
+ self._batch_shape_unexpanded = ops.convert_to_tensor( + batch_shape, dtype=dtypes.int32, name="batch_shape") + validate_init_args_statically(distribution, self._batch_shape_unexpanded) + batch_shape, batch_shape_static, runtime_assertions = calculate_reshape( + distribution.batch_shape_tensor(), self._batch_shape_unexpanded, + validate_args) + self._distribution = distribution + self._batch_shape_ = batch_shape + self._batch_shape_static = batch_shape_static + self._runtime_assertions = runtime_assertions super(BatchReshape, self).__init__( - dtype=self._distribution.dtype, - reparameterization_type=self._distribution.reparameterization_type, + dtype=distribution.dtype, + reparameterization_type=distribution.reparameterization_type, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, graph_parents=( - [self._batch_shape_] + - self._distribution._graph_parents), # pylint: disable=protected-access + [self._batch_shape_unexpanded] + distribution._graph_parents), # pylint: disable=protected-access name=name) @property @@ -140,7 +136,7 @@ class BatchReshape(distribution_lib.Distribution): return array_ops.identity(self._batch_shape_) def _batch_shape(self): - return tensor_shape.TensorShape(self._batch_shape_static) + return self._batch_shape_static def _event_shape_tensor(self): with ops.control_dependencies(self._runtime_assertions): @@ -152,11 +148,13 @@ class BatchReshape(distribution_lib.Distribution): def _sample_n(self, n, seed=None): with ops.control_dependencies(self._runtime_assertions): x = self.distribution.sample(sample_shape=n, seed=seed) - new_shape = array_ops.concat([ - [n], - self.batch_shape_tensor(), - self.event_shape_tensor(), - ], axis=0) + new_shape = array_ops.concat( + [ + [n], + self._batch_shape_unexpanded, + self.event_shape_tensor(), + ], + axis=0) return array_ops.reshape(x, new_shape) def _log_prob(self, x): @@ -213,9 +211,9 @@ class BatchReshape(distribution_lib.Distribution): event_ndims = (array_ops.size(self.event_shape_tensor()) if self.event_shape.ndims is None else self.event_shape.ndims) - batch_ndims = (array_ops.size(self.batch_shape_tensor()) - if self.batch_shape.ndims is None - else self.batch_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) sample_ndims = x_ndims - batch_ndims - event_ndims if isinstance(sample_ndims, int): static_sample_shape = x.shape[:sample_ndims] @@ -238,10 +236,11 @@ class BatchReshape(distribution_lib.Distribution): self.event_shape_tensor(), ], axis=0) result = fn(array_ops.reshape(x, old_shape)) - new_shape = array_ops.concat([ - sample_shape, - self.batch_shape_tensor(), - ], axis=0) + new_shape = array_ops.concat( + [ + sample_shape, + self._batch_shape_unexpanded, + ], axis=0) result = array_ops.reshape(result, new_shape) if (static_sample_shape.ndims is not None and self.batch_shape.ndims is not None): @@ -261,8 +260,7 @@ class BatchReshape(distribution_lib.Distribution): if static_event_shape_list is None: static_event_shape_list = [self.event_shape] new_shape = array_ops.concat( - [self.batch_shape_tensor()] + event_shape_list, - axis=0) + [self._batch_shape_unexpanded] + event_shape_list, axis=0) result = array_ops.reshape(fn(), new_shape) if (self.batch_shape.ndims is not None and self.event_shape.ndims is not None): @@ -281,9 +279,9 @@ class BatchReshape(distribution_lib.Distribution): event_ndims = (array_ops.size(self.event_shape_tensor()) if self.event_shape.ndims is None else 
self.event_shape.ndims) - batch_ndims = (array_ops.size(self.batch_shape_tensor()) - if self.batch_shape.ndims is None - else self.batch_shape.ndims) + batch_ndims = ( + array_ops.size(self._batch_shape_unexpanded) + if self.batch_shape.ndims is None else self.batch_shape.ndims) expected_batch_event_ndims = batch_ndims + event_ndims if (isinstance(x_ndims, int) and @@ -355,62 +353,56 @@ class BatchReshape(distribution_lib.Distribution): return runtime_assertions -def validate_init_args( - distribution, - batch_shape, - validate_args, - batch_shape_static): +def calculate_reshape(original_shape, new_shape, validate=False, name=None): + """Calculates the reshaped dimensions (replacing up to one -1 in reshape).""" + batch_shape_static = tensor_util.constant_value_as_shape(new_shape) + if batch_shape_static.is_fully_defined(): + return np.int32(batch_shape_static.as_list()), batch_shape_static, [] + with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]): + original_size = math_ops.reduce_prod(original_shape) + implicit_dim = math_ops.equal(new_shape, -1) + size_implicit_dim = ( + original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape))) + new_ndims = array_ops.shape(new_shape) + expanded_new_shape = array_ops.where( # Assumes exactly one `-1`. + implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape) + validations = [] if not validate else [ + check_ops.assert_rank( + original_shape, 1, message="Original shape must be a vector."), + check_ops.assert_rank( + new_shape, 1, message="New shape must be a vector."), + check_ops.assert_less_equal( + math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32), + 1, + message="At most one dimension can be unknown."), + check_ops.assert_positive( + expanded_new_shape, message="Shape elements must be >=-1."), + check_ops.assert_equal( + math_ops.reduce_prod(expanded_new_shape), + original_size, + message="Shape sizes do not match."), + ] + return expanded_new_shape, batch_shape_static, validations + + +def validate_init_args_statically(distribution, batch_shape): """Helper to __init__ which makes or raises assertions.""" - with ops.name_scope(name="validate_init_args", - values=[batch_shape] + distribution._graph_parents): # pylint: disable=protected-access - runtime_assertions = [] + if batch_shape.shape.ndims is not None: + if batch_shape.shape.ndims != 1: + raise ValueError("`batch_shape` must be a vector " + "(saw rank: {}).".format(batch_shape.shape.ndims)) - if batch_shape.shape.ndims is not None: - if batch_shape.shape.ndims != 1: - raise ValueError("`batch_shape` must be a vector " - "(saw rank: {}).".format( - batch_shape.shape.ndims)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_rank( - batch_shape, - 1, - message="`batch_shape` must be a vector.", - name="assert_batch_shape_is_vector"), - ] + batch_shape_static = tensor_util.constant_value_as_shape(batch_shape) + batch_size_static = batch_shape_static.num_elements() + dist_batch_size_static = distribution.batch_shape.num_elements() - batch_size_static = np.prod(batch_shape_static) - dist_batch_size_static = ( - None if not distribution.batch_shape.is_fully_defined() - else np.prod(distribution.batch_shape).value) + if batch_size_static is not None and dist_batch_size_static is not None: + if batch_size_static != dist_batch_size_static: + raise ValueError("`batch_shape` size ({}) must match " + "`distribution.batch_shape` size ({}).".format( + batch_size_static, dist_batch_size_static)) - if batch_size_static is not None and 
dist_batch_size_static is not None: - if batch_size_static != dist_batch_size_static: - raise ValueError("`batch_shape` size ({}) must match " - "`distribution.batch_shape` size ({}).".format( - batch_size_static, - dist_batch_size_static)) - elif validate_args: - runtime_assertions += [ - check_ops.assert_equal( - math_ops.reduce_prod(batch_shape), - math_ops.reduce_prod(distribution.batch_shape_tensor()), - message=("`batch_shape` size must match " - "`distributions.batch_shape` size."), - name="assert_batch_size"), - ] - - if batch_shape_static is not None: - if np.any(batch_shape_static < 1): - raise ValueError("`batch_shape` elements must be positive " - "(i.e., larger than zero).") - elif validate_args: - runtime_assertions += [ - check_ops.assert_positive( - batch_shape, - message=("`batch_shape` elements must be positive " - "(i.e., larger than zero)."), - name="assert_batch_shape_positive") - ] - - return runtime_assertions + if batch_shape_static.dims is not None: + if any( + dim.value is not None and dim.value < 1 for dim in batch_shape_static): + raise ValueError("`batch_shape` elements must be >=-1.") diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index babce80396c..51478dbeffa 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -30,6 +30,7 @@ @@Invert @@Kumaraswamy @@MaskedAutoregressiveFlow +@@Ordered @@Permute @@PowerTransform @@RealNVP @@ -67,6 +68,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * +from tensorflow.contrib.distributions.python.ops.bijectors.ordered import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 85ad23e4133..b158a51bb02 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -20,10 +20,9 @@ from __future__ import print_function import itertools -from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector @@ -36,15 +35,6 @@ def _use_static_shape(input_tensor, ndims): return input_tensor.shape.is_fully_defined() and isinstance(ndims, int) -def _maybe_get_event_ndims_statically(event_ndims): - static_event_ndims = (event_ndims if isinstance(event_ndims, int) - else tensor_util.constant_value(event_ndims)) - if static_event_ndims is not None: - return static_event_ndims - - return event_ndims - - def _compute_min_event_ndims(bijector_list, compute_forward=True): """Computes the min_event_ndims associated with the give list of bijectors. 
@@ -238,13 +228,13 @@ class Chain(bijector.Bijector): return y def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant( - 0., dtype=y.dtype.base_dtype, name="inverse_log_det_jacobian") + y = ops.convert_to_tensor(y, name="y") + ildj = math_ops.cast(0., dtype=y.dtype.base_dtype) if not self.bijectors: return ildj - event_ndims = _maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_event_ndims_statically( self.inverse_min_event_ndims) if _use_static_shape(y, event_ndims): @@ -258,11 +248,12 @@ class Chain(bijector.Bijector): if _use_static_shape(y, event_ndims): event_shape = b.inverse_event_shape(event_shape) - event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_event_ndims_statically( + event_shape.ndims) else: event_shape = b.inverse_event_shape_tensor(event_shape) - event_ndims = _maybe_get_event_ndims_statically( - array_ops.rank(event_shape)) + event_ndims = self._maybe_get_event_ndims_statically( + array_ops.size(event_shape)) y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -274,13 +265,12 @@ class Chain(bijector.Bijector): def _forward_log_det_jacobian(self, x, **kwargs): x = ops.convert_to_tensor(x, name="x") - fldj = constant_op.constant( - 0., dtype=x.dtype, name="inverse_log_det_jacobian") + fldj = math_ops.cast(0., dtype=x.dtype.base_dtype) if not self.bijectors: return fldj - event_ndims = _maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_event_ndims_statically( self.forward_min_event_ndims) if _use_static_shape(x, event_ndims): @@ -293,13 +283,21 @@ class Chain(bijector.Bijector): x, event_ndims=event_ndims, **kwargs.get(b.name, {})) if _use_static_shape(x, event_ndims): event_shape = b.forward_event_shape(event_shape) - event_ndims = _maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims) else: event_shape = b.forward_event_shape_tensor(event_shape) - event_ndims = _maybe_get_event_ndims_statically( - array_ops.rank(event_shape)) + event_ndims = self._maybe_get_event_ndims_statically( + array_ops.size(event_shape)) x = b.forward(x, **kwargs.get(b.name, {})) return fldj + def _maybe_get_event_ndims_statically(self, event_ndims): + event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically( + event_ndims) + if event_ndims_ is None: + return event_ndims + return event_ndims_ + + diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index caae2adcfac..268c8d03426 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -53,7 +53,7 @@ class CholeskyOuterProduct(bijector.Bijector): its spectrum), and that the product of two positive-diagonal lower-triangular matrices is another positive-diagonal lower-triangular matrix. - A simple inductive argument (proceding one column of L_3 at a time) shows + A simple inductive argument (proceeding one column of L_3 at a time) shows that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. @@ -170,7 +170,7 @@ class CholeskyOuterProduct(bijector.Bijector): sum_weighted_log_diag = array_ops.squeeze( math_ops.matmul(math_ops.log(diag), exponents[..., array_ops.newaxis]), - squeeze_dims=-1) + axis=-1) fldj = p_float * np.log(2.) 
+ sum_weighted_log_diag return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index 1904239a0e7..84a3289ba21 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -18,14 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.ops.distributions import bijector as bijector_lib +from tensorflow.python.ops.distributions import bijector __all__ = [ "Invert", ] -class Invert(bijector_lib.Bijector): +class Invert(bijector.Bijector): """Bijector which inverts another Bijector. Example Use: [ExpGammaDistribution (see Background & Context)]( diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index ef56cf6ddda..83667b0e80c 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -32,7 +32,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import template as template_ops from tensorflow.python.ops import variable_scope as variable_scope_lib -from tensorflow.python.ops.distributions import bijector as bijector_lib +from tensorflow.python.ops.distributions import bijector __all__ = [ @@ -42,7 +42,7 @@ __all__ = [ ] -class MaskedAutoregressiveFlow(bijector_lib.Bijector): +class MaskedAutoregressiveFlow(bijector.Bijector): """Affine MaskedAutoregressiveFlow bijector for vector-valued events. The affine autoregressive flow [(Papamakarios et al., 2016)][3] provides a diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py new file mode 100644 index 00000000000..3f03592f314 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py @@ -0,0 +1,125 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ordered bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + + +__all__ = [ + "Ordered", +] + + +class Ordered(bijector.Bijector): + """Bijector which maps a tensor x_k that has increasing elements in the last + dimension to an unconstrained tensor y_k. 
+ + Both the domain and the codomain of the mapping is [-inf, inf], however, + the input of the forward mapping must be strictly increasing. + The inverse of the bijector applied to a normal random vector `y ~ N(0, 1)` + gives back a sorted random vector with the same distribution `x ~ N(0, 1)` + where `x = sort(y)` + + On the last dimension of the tensor, Ordered bijector performs: + `y[0] = x[0]` + `y[1:] = math_ops.log(x[1:] - x[:-1])` + + #### Example Use: + + ```python + bijector.Ordered().forward([2, 3, 4]) + # Result: [2., 0., 0.] + + bijector.Ordered().inverse([0.06428002, -1.07774478, -0.71530371]) + # Result: [0.06428002, 0.40464228, 0.8936858] + ``` + """ + + def __init__(self, validate_args=False, name="ordered"): + super(Ordered, self).__init__( + forward_min_event_ndims=1, + validate_args=validate_args, + name=name) + + def _forward_event_shape(self, input_shape): + if input_shape.ndims is None or input_shape[-1] is None: + return input_shape + return tensor_shape.TensorShape([input_shape[-1]]) + + def _forward_event_shape_tensor(self, input_shape): + return (input_shape[-1])[..., array_ops.newaxis] + + def _inverse_event_shape(self, output_shape): + if output_shape.ndims is None or output_shape[-1] is None: + return output_shape + if output_shape[-1] <= 1: + raise ValueError("output_shape[-1] = %d <= 1" % output_shape[-1]) + return tensor_shape.TensorShape([output_shape[-1]]) + + def _inverse_event_shape_tensor(self, output_shape): + if self.validate_args: + is_greater_one = check_ops.assert_greater( + output_shape[-1], 1, message="Need last dimension greater than 1.") + output_shape = control_flow_ops.with_dependencies( + [is_greater_one], output_shape) + return (output_shape[-1])[..., array_ops.newaxis] + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + y0 = x[..., 0, array_ops.newaxis] + yk = math_ops.log(x[..., 1:] - x[..., :-1]) + y = array_ops.concat([y0, yk], axis=-1) + return y + + def _inverse(self, y): + x0 = y[..., 0, array_ops.newaxis] + xk = math_ops.exp(y[..., 1:]) + x = array_ops.concat([x0, xk], axis=-1) + return math_ops.cumsum(x, axis=-1) + + def _inverse_log_det_jacobian(self, y): + # The Jacobian of the inverse mapping is lower + # triangular, with the diagonal elements being: + # J[i,i] = 1 if i=1, and + # exp(y_i) if 1GPU->CPU, which forces - # a sync. This is a roundabout way, yes. + # then this will force a copy from CPU->NON_CPU_DEVICE->CPU, + # which forces a sync. This is a roundabout way, yes. 
tf.constant(1.).cpu() - def _benchmark_eager_apply(self, label, defun=False, execution_mode=None): + def _benchmark_eager_apply(self, label, device_and_format, defun=False, + execution_mode=None, compiled=False): with tfe.execution_mode(execution_mode): - device, data_format = device_and_data_format() + device, data_format = device_and_format model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.defun(model.call, compiled=compiled) batch_size = 64 num_burn = 5 num_iters = 30 with tf.device(device): - images, _ = random_batch(batch_size) + images, _ = random_batch(batch_size, data_format) for _ in xrange(num_burn): model(images, training=False).cpu() if execution_mode: @@ -220,30 +224,35 @@ class ResNet50Benchmarks(tf.test.Benchmark): tfe.async_wait() self._report(label, start, num_iters, device, batch_size, data_format) - def benchmark_eager_apply(self): - self._benchmark_eager_apply('eager_apply', defun=False) + def benchmark_eager_apply_sync(self): + self._benchmark_eager_apply('eager_apply', device_and_data_format(), + defun=False) def benchmark_eager_apply_async(self): self._benchmark_eager_apply( - 'eager_apply_async', defun=False, execution_mode=tfe.ASYNC) + 'eager_apply_async', device_and_data_format(), defun=False, + execution_mode=tfe.ASYNC) def benchmark_eager_apply_with_defun(self): - self._benchmark_eager_apply('eager_apply_with_defun', defun=True) + self._benchmark_eager_apply('eager_apply_with_defun', + device_and_data_format(), defun=True) def _benchmark_eager_train(self, label, make_iterator, + device_and_format, defun=False, - execution_mode=None): + execution_mode=None, + compiled=False): with tfe.execution_mode(execution_mode): - device, data_format = device_and_data_format() + device, data_format = device_and_format for batch_size in self._train_batch_sizes(): - (images, labels) = random_batch(batch_size) + (images, labels) = random_batch(batch_size, data_format) num_burn = 3 num_iters = 10 model = resnet50.ResNet50(data_format) if defun: - model.call = tfe.defun(model.call) + model.call = tfe.defun(model.call, compiled=compiled) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): @@ -253,7 +262,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): train_one_step(model, images, labels, optimizer) if execution_mode: tfe.async_wait() - self._force_gpu_sync() + self._force_device_sync() gc.collect() start = time.time() @@ -262,22 +271,25 @@ class ResNet50Benchmarks(tf.test.Benchmark): train_one_step(model, images, labels, optimizer) if execution_mode: tfe.async_wait() - self._force_gpu_sync() + self._force_device_sync() self._report(label, start, num_iters, device, batch_size, data_format) - def benchmark_eager_train(self): - self._benchmark_eager_train('eager_train', MockIterator, defun=False) + def benchmark_eager_train_sync(self): + self._benchmark_eager_train('eager_train', MockIterator, + device_and_data_format(), defun=False) def benchmark_eager_train_async(self): self._benchmark_eager_train( 'eager_train_async', MockIterator, + device_and_data_format(), defun=False, execution_mode=tfe.ASYNC) def benchmark_eager_train_with_defun(self): self._benchmark_eager_train( - 'eager_train_with_defun', MockIterator, defun=True) + 'eager_train_with_defun', MockIterator, + device_and_data_format(), defun=True) def benchmark_eager_train_datasets(self): @@ -287,7 +299,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset', make_iterator, 
defun=False) + 'eager_train_dataset', make_iterator, + device_and_data_format(), defun=False) def benchmark_eager_train_datasets_with_defun(self): @@ -297,7 +310,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): return tfe.Iterator(ds) self._benchmark_eager_train( - 'eager_train_dataset_with_defun', make_iterator, defun=True) + 'eager_train_dataset_with_defun', make_iterator, + device_and_data_format(), defun=True) if __name__ == '__main__': diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD new file mode 100644 index 00000000000..638c57d1c92 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/BUILD @@ -0,0 +1,25 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +cuda_py_test( + name = "scan_test", + size = "small", + srcs = ["scan_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "scan_graph_test", + size = "small", + srcs = ["scan_graph_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py new file mode 100644 index 00000000000..4661dafbed1 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit test for tf.scan under graph mode execution.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +import tensorflow as tf + + +class ScanBenchmark(tf.test.Benchmark): + + def runScan(self, n): + elems = np.arange(n) + start_time = time.time() + sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) + with tf.Session() as sess: + sess.run(sum_op) + wall_time = time.time() - start_time + + self.report_benchmark( + name='scan', + iters=n, + wall_time=wall_time) + + def benchmarkScan32000(self): + self.runScan(32000) + + def benchmarkScan1M(self): + self.runScan(1000000) + + def benchmarkScan2M(self): + self.runScan(2000000) + + def benchmarkScan4M(self): + self.runScan(4000000) + + def benchmarkScan8M(self): + self.runScan(8000000) + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py new file mode 100644 index 00000000000..b8c7cf1fe5b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/scan/scan_test.py @@ -0,0 +1,56 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit test for tf.scan under eager execution.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy as np +import tensorflow as tf + + +class ScanBenchmark(tf.test.Benchmark): + + def runScan(self, n): + elems = np.arange(n) + start_time = time.time() + _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1) + wall_time = time.time() - start_time + + self.report_benchmark( + name='scan', + iters=n, + wall_time=wall_time) + + def benchmarkScan2000(self): + self.runScan(2000) + + def benchmarkScan4000(self): + self.runScan(4000) + + def benchmarkScan8000(self): + self.runScan(8000) + + def benchmarkScan16000(self): + self.runScan(16000) + + def benchmarkScan32000(self): + self.runScan(32000) + +if __name__ == '__main__': + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index f825a2a7363..8ac553e0ae7 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -34,10 +34,10 @@ import tensorflow.contrib.eager as tfe from tensorflow.contrib.eager.python.examples.spinn import data from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util -from tensorflow.core.protobuf import checkpointable_object_graph_pb2 from tensorflow.python.eager import test from tensorflow.python.framework import test_util -from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import saver +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=g-bad-import-order @@ -421,10 +421,8 @@ class SpinnTest(test_util.TensorFlowTestCase): # 5. Verify that checkpoints exist and contains all the expected variables. 
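The function scanned in both benchmark files above is a running sum, so its output can be checked against `np.cumsum`. A small eager-mode sanity check (a sketch only, not part of the patch):

```python
# Sanity check for the function being benchmarked above: scanning a + x over
# elems yields the prefix sums of elems, i.e. np.cumsum(elems).
import numpy as np
import tensorflow as tf

tf.enable_eager_execution()

elems = np.arange(10)
prefix_sums = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
assert np.array_equal(prefix_sums.numpy(), np.cumsum(elems))
```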
self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) - object_graph_string = checkpoint_utils.load_variable( - config.logdir, name="_CHECKPOINTABLE_OBJECT_GRAPH") - object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph() - object_graph.ParseFromString(object_graph_string) + object_graph = checkpointable_utils.object_metadata( + saver.latest_checkpoint(config.logdir)) ckpt_variable_names = set() for node in object_graph.nodes: for attribute in node.attributes: diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 907f9204c2d..1ae6415d5ec 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -30,7 +30,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.training import checkpointable +from tensorflow.python.training.checkpointable import base as checkpointable _to_replace = re.compile("[^A-Za-z0-9.]") diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index f0fe4ce8c53..98a98a8d358 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -30,8 +30,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops -from tensorflow.python.training import checkpointable_utils from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import util as checkpointable_utils class MetricsTest(test.TestCase): @@ -146,8 +146,6 @@ class MetricsTest(test.TestCase): self.assertAllEqual(2.0, m2.result()) def testNamesWithSpaces(self): - # Verify two metrics with the same class and name don't - # accidentally share state. m1 = metrics.Mean("has space") m1(0) self.assertEqual(m1.name, "has space") @@ -186,8 +184,8 @@ class MetricsTest(test.TestCase): self.assertEqual(self.evaluate(value), 2.5) def testTwoMeansGraph(self): - # Verify two metrics with the same class and name don't - # accidentally share state. + # Verify two metrics with the same name in the same graph raises a + # ValueError. 
with context.graph_mode(): m1 = metrics.Mean() m1(0) diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 2f8721324f5..f801d9a47b2 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -23,14 +23,16 @@ import os import weakref from tensorflow.python.eager import context -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops -from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer +from tensorflow.python.keras.engine import base_layer as keras_base_layer from tensorflow.python.layers import base from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util +from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils # pylint: disable=protected-access # Explanation for protected-access disable: Network has lots of same-class and @@ -52,9 +54,40 @@ def _network_name_scope_naming(current_variable_scope): return current_variable_scope.name + "/" +_NETWORK_DEPRECATION_MESSAGE = ( + "Please inherit from `tf.keras.Model`, and see its documentation for " + "details. `tf.keras.Model` should be a drop-in replacement for " + "`tfe.Network` in most cases, but note that `track_layer` is no longer " + "necessary or supported. Instead, `Layer` instances are tracked on " + "attribute assignment (see the section of `tf.keras.Model`'s documentation " + "on subclassing). Since the output of `track_layer` is often assigned to " + "an attribute anyway, most code can be ported by simply removing the " + "`track_layer` calls.\n\n`tf.keras.Model` works with all TensorFlow " + "`Layer` instances, including those from `tf.layers`, but switching to " + "the `tf.keras.layers` versions along with the migration to " + "`tf.keras.Model` is recommended, since it will preserve variable names. " + "Feel free to import it with an alias to avoid excess typing :)." +) + + class Network(base.Layer): """Represents the composition of a set of Layers. + *Deprecated*. Please inherit from `tf.keras.Model`, and see its documentation + for details. `tf.keras.Model` should be a drop-in replacement for + `tfe.Network` in most cases, but note that `track_layer` is no longer + necessary or supported. Instead, `Layer` instances are tracked on attribute + assignment (see the section of `tf.keras.Model`'s documentation on + subclassing). Since the output of `track_layer` is often assigned to an + attribute anyway, most code can be ported by simply removing the `track_layer` + calls. + + `tf.keras.Model` works with all TensorFlow `Layer` instances, including those + from `tf.layers`, but switching to the `tf.keras.layers` versions along with + the migration to `tf.keras.Model` is recommended, since it will preserve + variable names. Feel free to import it with an alias to avoid excess typing + :). + `Network` implements the `Layer` interface and adds convenience methods for managing sub-`Layer`s, such as listing variables. @@ -112,6 +145,7 @@ class Network(base.Layer): # - Detect layers used in __call__ that weren't registered with track_layer. # - Convert inputs to __call__ to tensors. 
+ @deprecation.deprecated(date=None, instructions=_NETWORK_DEPRECATION_MESSAGE) def __init__(self, name=None): """Configure the `Network`. @@ -130,6 +164,10 @@ class Network(base.Layer): ValueError: If `name` is not valid. Note that some naming errors will instead be raised when the `Network` is called. """ + if context.executing_eagerly(): + logging.warning( + ("** tfe.Network is deprecated and will be removed in a future " + "version.\n\n%s") % _NETWORK_DEPRECATION_MESSAGE) if isinstance(name, variable_scope.VariableScope): raise ValueError("VariableScopes are not valid Network names.") if name is not None and "/" in name: @@ -152,6 +190,11 @@ class Network(base.Layer): self._variable_scope_counts_on_init = ( variable_scope.get_variable_scope_store().variable_scopes_count) + def _gather_saveables_for_checkpoint(self): + raise NotImplementedError( + "tfe.Network does not support object-based checkpointing.\n\n%s" + % _NETWORK_DEPRECATION_MESSAGE) + def _name_scope_name(self, current_variable_scope): """Overrides Layer op naming to match variable naming.""" return _network_name_scope_naming( @@ -502,10 +545,10 @@ class Sequential(Network): def add(self, layer_func): if isinstance(layer_func, base.Layer): - args = estimator_util.fn_args(layer_func.call) + args = function_utils.fn_args(layer_func.call) self.track_layer(layer_func) elif callable(layer_func): - args = estimator_util.fn_args(layer_func) + args = function_utils.fn_args(layer_func) else: raise TypeError( "Sequential.add() takes only tf.layers.Layer objects or callables; " @@ -706,6 +749,9 @@ def _make_prefix_stripping_map_fn(scope_name): return _strip_variable_prefix +@deprecation.deprecated(date=None, instructions=( + "Please inherit from tf.keras.Model instead of tfe.Network, and use " + "tf.keras.Model.save_weights.")) def save_network_checkpoint( network, save_path, global_step=None, map_func=None): """Save variables from the Network to a checkpoint. @@ -905,6 +951,9 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func, _add_deferred_restoration(network, deferred_restoration) +@deprecation.deprecated(date=None, instructions=( + "Please inherit from tf.keras.Model instead of tfe.Network, and use " + "tf.keras.Model.load_weights.")) def restore_network_checkpoint(network, save_path, map_func=None): """Restore the Network from a checkpoint. 
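The deprecation decorators and `_NETWORK_DEPRECATION_MESSAGE` above all point at the same migration. A minimal before/after sketch of that port (the `TwoLayer*` classes are made up for illustration and are not part of the patch):

```python
# Illustrative migration from the deprecated tfe.Network to tf.keras.Model,
# as described by _NETWORK_DEPRECATION_MESSAGE above.
import tensorflow as tf
import tensorflow.contrib.eager as tfe


class TwoLayerNetwork(tfe.Network):  # deprecated style
  def __init__(self):
    super(TwoLayerNetwork, self).__init__()
    # Sub-layers had to be registered explicitly via track_layer.
    self.dense1 = self.track_layer(tf.layers.Dense(10, activation=tf.nn.relu))
    self.dense2 = self.track_layer(tf.layers.Dense(1))

  def call(self, inputs):
    return self.dense2(self.dense1(inputs))


class TwoLayerModel(tf.keras.Model):  # recommended replacement
  def __init__(self):
    super(TwoLayerModel, self).__init__()
    # Layers assigned as attributes are tracked automatically; track_layer is
    # no longer needed. Using tf.keras.layers preserves variable names.
    self.dense1 = tf.keras.layers.Dense(10, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(1)

  def call(self, inputs):
    return self.dense2(self.dense1(inputs))
```

The only structural change is dropping `track_layer`; since its return value was typically assigned to an attribute anyway, most call sites port by deleting the wrapper call.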
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index f43376d5d77..c92bd15b253 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import training_util +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: disable=not-callable @@ -62,6 +63,12 @@ class RegularizedNetwork(network.Network): class NetworkTest(test.TestCase): + def test_checkpointing_not_implemented(self): + checkpoint_directory = self.get_temp_dir() + checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork()) + with self.assertRaises(NotImplementedError): + checkpoint.save(checkpoint_directory) + def _save_modify_load_network_built(self, net, global_step=None): checkpoint_directory = self.get_temp_dir() checkpoint_path = network.save_network_checkpoint( diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 1a7f7b85e68..4032e755f6e 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -102,7 +102,6 @@ class SaverTest(test.TestCase): # Can still restore it. saver.restore(ckpt_prefix) self.assertEqual(v1.read_value().numpy(), 1.0) - self.assertEqual(v1.read_value().numpy(), 1.0) # However, cannot restore it with default name. with self.assertRaisesOpError('not found in checkpoint'): saver = _saver.Saver([v1, v2]).restore(ckpt_prefix) diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 79dd117854e..5826700c73e 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -120,9 +120,9 @@ from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Vari from tensorflow.python.ops.variable_scope import EagerVariableStore from tensorflow.python.ops import script_ops from tensorflow.python.ops import template -from tensorflow.python.training.checkpointable import Checkpointable -from tensorflow.python.training.checkpointable_utils import CheckpointableSaver -from tensorflow.python.training.checkpointable_utils import Checkpoint +from tensorflow.python.training.checkpointable.base import Checkpointable +from tensorflow.python.training.checkpointable.util import CheckpointableSaver +from tensorflow.python.training.checkpointable.util import Checkpoint from tensorflow.python.util.all_util import remove_undocumented py_func = script_ops.eager_py_func diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index e80ccbb74d8..db50b33af2e 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -57,7 +57,7 @@ class TFETest(test_util.TensorFlowTestCase): return math_ops.multiply(x, x) grad = tfe.gradients_function(square) - self.assertEquals([6], [x.numpy() for x in grad(3)]) + self.assertEquals([6], [x.numpy() for x in grad(3.)]) def testGradOfGrad(self): @@ -66,7 +66,7 @@ class TFETest(test_util.TensorFlowTestCase): grad = tfe.gradients_function(square) gradgrad = tfe.gradients_function(lambda x: grad(x)[0]) - self.assertEquals([2], [x.numpy() for x in gradgrad(3)]) + self.assertEquals([2], [x.numpy() for x in gradgrad(3.)]) def testCustomGrad(self): @@ -80,7 +80,7 @@ 
class TFETest(test_util.TensorFlowTestCase): return y, grad_fn grad = tfe.gradients_function(f) - self.assertEquals([12], [x.numpy() for x in grad(3)]) + self.assertEquals([12], [x.numpy() for x in grad(3.)]) def testGPU(self): if tfe.num_gpus() <= 0: diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 9f4cd44afbe..d5d2abf8c4c 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -14,11 +14,14 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ + ":baseline", ":boosted_trees", ":dnn", ":dnn_linear_combined", + ":export", ":extenders", ":head", + ":hooks", ":linear", ":logit_fns", ":multi_head", @@ -28,6 +31,49 @@ py_library( ], ) +py_library( + name = "baseline", + srcs = ["python/estimator/baseline.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:baseline", + ], +) + +py_test( + name = "baseline_test", + size = "small", + srcs = ["python/estimator/baseline_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":baseline", + ":head", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:metric_keys", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "boosted_trees", srcs = ["python/estimator/boosted_trees.py"], @@ -77,6 +123,7 @@ py_test( tags = [ "no_pip", "notsan", + "optonly", # times out http://b/79220679 ], deps = [ ":dnn", @@ -179,6 +226,43 @@ py_test( ], ) +py_library( + name = "export", + srcs = [ + "python/estimator/export.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "export_test", + size = "medium", + srcs = ["python/estimator/export_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/62863147 + deps = [ + ":export", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:metrics", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", + ], +) + py_library( name = "head", srcs = [ @@ -210,7 +294,7 @@ py_library( py_test( name = "head_test", - size = "small", + size = "medium", srcs = ["python/estimator/head_test.py"], srcs_version = "PY2AND3", deps = [ @@ -238,6 +322,36 @@ py_test( ], ) +py_library( + name = "hooks", + srcs = [ + "python/estimator/hooks.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "hooks_test", + size = "medium", + 
srcs = ["python/estimator/hooks_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":hooks", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "linear", srcs = ["python/estimator/linear.py"], @@ -250,7 +364,7 @@ py_library( py_test( name = "linear_test", - size = "small", + size = "medium", srcs = ["python/estimator/linear_test.py"], srcs_version = "PY2AND3", tags = [ @@ -283,9 +397,9 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework_ops", + "//tensorflow/python:util", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:linear", - "//tensorflow/python/estimator:util", ], ) @@ -367,6 +481,7 @@ py_library( "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:model_fn", @@ -447,21 +562,27 @@ py_test( srcs_version = "PY2AND3", tags = [ "no_pip", + "noasan", # times out "notsan", + "optonly", # times out http://b/79220679 ], deps = [ + ":head", ":rnn", + "//tensorflow/contrib/data", "//tensorflow/core:protos_all_py", "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:lib", "//tensorflow/python:math_ops", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variables", "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:parsing_utils", "//tensorflow/python/feature_column", "//third_party/py/numpy", "@six_archive//:six", diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index be20d1b7770..788ac5ca704 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -19,11 +19,14 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.estimator.python.estimator.baseline import * from tensorflow.contrib.estimator.python.estimator.boosted_trees import * from tensorflow.contrib.estimator.python.estimator.dnn import * from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * +from tensorflow.contrib.estimator.python.estimator.export import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * +from tensorflow.contrib.estimator.python.estimator.hooks import * from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * @@ -38,11 +41,14 @@ _allowed_symbols = [ 'binary_classification_head', 'clip_gradients_by_norm', 'forward_features', + 'InMemoryEvaluatorHook', + 'logistic_regression_head', 'multi_class_head', 'multi_head', 'multi_label_head', 'poisson_regression_head', 'regression_head', + 'BaselineEstimator', 'DNNEstimator', 'DNNLinearCombinedEstimator', 'LinearEstimator', @@ -54,6 +60,9 @@ _allowed_symbols = [ 'replicate_model_fn', 'TowerOptimizer', 'RNNClassifier', + 'RNNEstimator', + 
'export_saved_model_for_mode', + 'export_all_saved_models', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline.py b/tensorflow/contrib/estimator/python/estimator/baseline.py new file mode 100644 index 00000000000..beffbee7306 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/baseline.py @@ -0,0 +1,98 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Baseline estimators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import baseline + + +class BaselineEstimator(estimator.Estimator): + """An estimator that can establish a simple baseline. + + The estimator uses a user-specified head. + + This estimator ignores feature values and will learn to predict the average + value of each label. E.g. for single-label classification problems, this will + predict the probability distribution of the classes as seen in the labels. + For multi-label classification problems, it will predict the ratio of examples + that contain each class. + + Example: + + ```python + + # Build baseline multi-label classifier. + estimator = BaselineEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3)) + + # Input builders + def input_fn_train: # returns x, y (where y represents label's class index). + pass + + def input_fn_eval: # returns x, y (where y represents label's class index). + pass + + # Fit model. + estimator.train(input_fn=input_fn_train) + + # Evaluates cross entropy between the test and train labels. + loss = classifier.evaluate(input_fn=input_fn_eval)["loss"] + + # For each class, predicts the ratio of training examples that contain the + # class. + predictions = classifier.predict(new_samples) + + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column` passed to the `head` constructor is not `None`, a feature + with `key=weight_column` whose value is a `Tensor`. + """ + + def __init__(self, + head, + model_dir=None, + optimizer='Ftrl', + config=None): + """Initializes a BaselineEstimator instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + optimizer: String, `tf.Optimizer` object, or callable that creates the + optimizer to use for training. If not specified, will use + `FtrlOptimizer` with a default learning rate of 0.3. + config: `RunConfig` object to configure the runtime settings. 
+ """ + def _model_fn(features, labels, mode, config): + return baseline._baseline_model_fn( # pylint: disable=protected-access + features=features, + labels=labels, + mode=mode, + head=head, + optimizer=optimizer, + config=config) + super(BaselineEstimator, self).__init__( + model_fn=_model_fn, + model_dir=model_dir, + config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py new file mode 100644 index 00000000000..d0e3e670f73 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py @@ -0,0 +1,430 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for baseline.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import baseline +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.python.client import session as tf_session +from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import distribute as distribute_lib +from tensorflow.python.training import optimizer +from tensorflow.python.training import saver + +# Names of variables created by model. 
+BIAS_NAME = 'baseline/bias' + + +def assert_close(expected, actual, rtol=1e-04, name='assert_close'): + with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope: + expected = ops.convert_to_tensor(expected, name='expected') + actual = ops.convert_to_tensor(actual, name='actual') + rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected) + rtol = ops.convert_to_tensor(rtol, name='rtol') + return check_ops.assert_less( + rdiff, + rtol, + data=('Condition expected =~ actual did not hold element-wise:' + 'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff, + 'rtol = ', rtol,), + name=scope) + + +def save_variables_to_ckpt(model_dir): + init_all_op = [variables.global_variables_initializer()] + with tf_session.Session() as sess: + sess.run(init_all_op) + saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt')) + + +def _baseline_estimator_fn( + weight_column=None, label_dimension=1, *args, **kwargs): + """Returns a BaselineEstimator that uses regression_head.""" + return baseline.BaselineEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension, + # Tests in core (from which this test inherits) test the sum loss. + loss_reduction=losses.Reduction.SUM), + *args, **kwargs) + + +class BaselineEstimatorEvaluationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_evaluation_batch(self): + """Tests evaluation for batch_size==2.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) + eval_metrics = baseline_estimator.evaluate( + input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. + # Training loss is the sum over batch = 9 + 9 = 18 + # Average loss is the average over batch = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 18., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_weights(self): + """Tests evaluation with weights.""" + with ops.Graph().as_default(): + variables.Variable([13.0], name=BIAS_NAME) + variables.Variable( + 100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + def _input_fn(): + features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))} + labels = ((10.,), (10.,)) + return features, labels + + baseline_estimator = _baseline_estimator_fn( + weight_column='weights', + model_dir=self._model_dir) + eval_metrics = baseline_estimator.evaluate(input_fn=_input_fn, steps=1) + + # Logit is bias = 13, while label is 10. + # Loss per example is 3**2 = 9. 
+ # Training loss is the weighted sum over batch = 9 + 2*9 = 27 + # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9 + self.assertDictEqual({ + metric_keys.MetricKeys.LOSS: 27., + metric_keys.MetricKeys.LOSS_MEAN: 9., + ops.GraphKeys.GLOBAL_STEP: 100 + }, eval_metrics) + + def test_evaluation_for_multi_dimensions(self): + label_dim = 2 + with ops.Graph().as_default(): + variables.Variable([46.0, 58.0], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn( + label_dimension=label_dim, + model_dir=self._model_dir) + input_fn = numpy_io.numpy_input_fn( + x={ + 'age': np.array([[2., 4., 5.]]), + }, + y=np.array([[46., 58.]]), + batch_size=1, + num_epochs=None, + shuffle=False) + eval_metrics = baseline_estimator.evaluate(input_fn=input_fn, steps=1) + + self.assertItemsEqual( + (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN, + ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys()) + + # Logit is bias which is [46, 58] + self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS]) + + +class BaselineEstimatorPredictTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def test_1d(self): + """Tests predict when all variables are one-dimensional.""" + with ops.Graph().as_default(): + variables.Variable([.2], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': np.array([[2.]])}, + y=None, + batch_size=1, + num_epochs=1, + shuffle=False) + predictions = baseline_estimator.predict(input_fn=predict_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + # x * weight + bias = 2. * 10. 
+ .2 = 20.2 + self.assertAllClose([[.2]], predicted_scores) + + def testMultiDim(self): + """Tests predict when all variables are multi-dimenstional.""" + batch_size = 2 + label_dimension = 3 + with ops.Graph().as_default(): + variables.Variable( # shape=[label_dimension] + [.2, .4, .6], name=BIAS_NAME) + variables.Variable(100, name='global_step', dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + baseline_estimator = _baseline_estimator_fn( + label_dimension=label_dimension, + model_dir=self._model_dir) + + predict_input_fn = numpy_io.numpy_input_fn( + # x shape=[batch_size, x_dim] + x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])}, + y=None, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + predictions = baseline_estimator.predict(input_fn=predict_input_fn) + predicted_scores = list([x['predictions'] for x in predictions]) + # score = bias, shape=[batch_size, label_dimension] + self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]], + predicted_scores) + + +class BaselineEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn, + input_dimension, label_dimension, prediction_length): + feature_columns = [ + feature_column_lib.numeric_column('x', shape=(input_dimension,)) + ] + est = _baseline_estimator_fn( + label_dimension=label_dimension, + model_dir=self._model_dir) + + # TRAIN + # learn y = x + est.train(train_input_fn, steps=200) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores)) + + # PREDICT + predictions = np.array( + [x['predictions'] for x in est.predict(predict_input_fn)]) + self.assertAllEqual((prediction_length, label_dimension), predictions.shape) + + # EXPORT + feature_spec = feature_column_lib.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + label_dimension = 2 + input_dimension = label_dimension + batch_size = 10 + prediction_length = batch_size + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=None, + batch_size=batch_size, + num_epochs=1, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=input_dimension, + label_dimension=label_dimension, + prediction_length=prediction_length) + + +class BaselineEstimatorTrainingTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _mock_optimizer(self, expected_loss=None): + 
expected_var_names = [ + '%s:0' % BIAS_NAME + ] + + def _minimize(loss, global_step=None, var_list=None): + trainable_vars = var_list or ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertItemsEqual(expected_var_names, + [var.name for var in trainable_vars]) + + # Verify loss. We can't check the value directly, so we add an assert op. + self.assertEquals(0, loss.shape.ndims) + if expected_loss is None: + if global_step is not None: + return distribute_lib.increment_var(global_step) + return control_flow_ops.no_op() + assert_loss = assert_close( + math_ops.to_float(expected_loss, name='expected'), + loss, + name='assert_loss') + with ops.control_dependencies((assert_loss,)): + if global_step is not None: + return distribute_lib.increment_var(global_step) + return control_flow_ops.no_op() + + mock_optimizer = test.mock.NonCallableMock( + spec=optimizer.Optimizer, + wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer')) + mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize) + + # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks. + # So, return mock_optimizer itself for deepcopy. + mock_optimizer.__deepcopy__ = lambda _: mock_optimizer + return mock_optimizer + + def _assert_checkpoint(self, + label_dimension, + expected_global_step, + expected_bias=None): + shapes = { + name: shape + for (name, shape) in checkpoint_utils.list_variables(self._model_dir) + } + + self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP]) + self.assertEqual(expected_global_step, + checkpoint_utils.load_variable(self._model_dir, + ops.GraphKeys.GLOBAL_STEP)) + + self.assertEqual([label_dimension], shapes[BIAS_NAME]) + if expected_bias is not None: + self.assertEqual(expected_bias, + checkpoint_utils.load_variable(self._model_dir, + BIAS_NAME)) + + def testFromScratch(self): + # Create BaselineRegressor. + label = 5. + age = 17 + # loss = (logits - label)^2 = (0 - 5.)^2 = 25. + mock_optimizer = self._mock_optimizer(expected_loss=25.) + baseline_estimator = _baseline_estimator_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. + num_steps = 10 + baseline_estimator.train( + input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=num_steps, + expected_bias=[0.]) + + def testFromCheckpoint(self): + # Create initial checkpoint. + bias = 7.0 + initial_global_step = 100 + with ops.Graph().as_default(): + variables.Variable([bias], name=BIAS_NAME) + variables.Variable( + initial_global_step, + name=ops.GraphKeys.GLOBAL_STEP, + dtype=dtypes.int64) + save_variables_to_ckpt(self._model_dir) + + # logits = bias = 6. + # loss = (logits - label)^2 = (7 - 5)^2 = 4 + mock_optimizer = self._mock_optimizer(expected_loss=4.) + baseline_estimator = _baseline_estimator_fn( + model_dir=self._model_dir, + optimizer=mock_optimizer) + self.assertEqual(0, mock_optimizer.minimize.call_count) + + # Train for a few steps, and validate optimizer and final checkpoint. 
+ num_steps = 10 + baseline_estimator.train( + input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps) + self.assertEqual(1, mock_optimizer.minimize.call_count) + self._assert_checkpoint( + label_dimension=1, + expected_global_step=initial_global_step + num_steps, + expected_bias=[bias]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py index 314c54ed003..bd641014e9e 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -17,10 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees +def _validate_input_fn_and_repeat_dataset(train_input_fn): + """Validates whether the input_fn is valid, and repeat() if tf.Dataset.""" + def _input_fn(): + result_input_fn = train_input_fn() + if isinstance(result_input_fn, dataset_ops.Dataset): + return result_input_fn.repeat() + return result_input_fn + + return _input_fn + + class _BoostedTreesEstimator(estimator.Estimator): """An Estimator for Tensorflow Boosted Trees models.""" @@ -36,6 +48,7 @@ class _BoostedTreesEstimator(estimator.Estimator): l1_regularization=0., l2_regularization=0., tree_complexity=0., + min_node_weight=0., config=None): """Initializes a `BoostedTreesEstimator` instance. @@ -65,13 +78,16 @@ class _BoostedTreesEstimator(estimator.Estimator): l2_regularization: regularization multiplier applied to the square weights of the tree leafs. tree_complexity: regularization factor to penalize trees with more leaves. + min_node_weight: minimum hessian a node must have for a split to be + considered. The value will be compared with sum(leaf_hessian)/ + (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. """ # pylint:disable=protected-access # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity) + tree_complexity, min_node_weight) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( @@ -96,6 +112,7 @@ def boosted_trees_classifier_train_in_memory( l1_regularization=0., l2_regularization=0., tree_complexity=0., + min_node_weight=0., config=None, train_hooks=None): """Trains a boosted tree classifier with in memory dataset. @@ -108,10 +125,13 @@ def boosted_trees_classifier_train_in_memory( bucketized_feature_2 = bucketized_column( numeric_column('feature_2'), BUCKET_BOUNDARIES_2) - def input_fn_train(): + def train_input_fn(): dataset = create-dataset-from-training-data - # Don't use repeat or cache, since it is assumed to be one epoch - # This is either tf.data.Dataset, or a tuple of feature dict and label. + # This is tf.data.Dataset of a tuple of feature dict and label. + # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}), + # Dataset.from_tensors(label_array))) + # The returned Dataset shouldn't be batched. + # If Dataset repeats, only the first repetition would be used for training. 
return dataset classifier = boosted_trees_classifier_train_in_memory( @@ -162,6 +182,9 @@ def boosted_trees_classifier_train_in_memory( l2_regularization: regularization multiplier applied to the square weights of the tree leafs. tree_complexity: regularization factor to penalize trees with more leaves. + min_node_weight: minimum hessian a node must have for a split to be + considered. The value will be compared with sum(leaf_hessian)/ + (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. train_hooks: a list of Hook instances to be passed to estimator.train(). @@ -184,7 +207,7 @@ def boosted_trees_classifier_train_in_memory( # HParams for the model. tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity) + tree_complexity, min_node_weight) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( @@ -202,7 +225,9 @@ def boosted_trees_classifier_train_in_memory( in_memory_classifier = estimator.Estimator( model_fn=_model_fn, model_dir=model_dir, config=config) - in_memory_classifier.train(input_fn=train_input_fn, hooks=train_hooks) + in_memory_classifier.train( + input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn), + hooks=train_hooks) return in_memory_classifier # pylint: enable=protected-access @@ -220,6 +245,7 @@ def boosted_trees_regressor_train_in_memory( l1_regularization=0., l2_regularization=0., tree_complexity=0., + min_node_weight=0., config=None, train_hooks=None): """Trains a boosted tree regressor with in memory dataset. @@ -232,10 +258,13 @@ def boosted_trees_regressor_train_in_memory( bucketized_feature_2 = bucketized_column( numeric_column('feature_2'), BUCKET_BOUNDARIES_2) - def input_fn_train(): + def train_input_fn(): dataset = create-dataset-from-training-data - # Don't use repeat or cache, since it is assumed to be one epoch - # This is either tf.data.Dataset, or a tuple of feature dict and label. + # This is tf.data.Dataset of a tuple of feature dict and label. + # e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}), + # Dataset.from_tensors(label_array))) + # The returned Dataset shouldn't be batched. + # If Dataset repeats, only the first repetition would be used for training. return dataset regressor = boosted_trees_regressor_train_in_memory( @@ -279,6 +308,9 @@ def boosted_trees_regressor_train_in_memory( l2_regularization: regularization multiplier applied to the square weights of the tree leafs. tree_complexity: regularization factor to penalize trees with more leaves. + min_node_weight: minimum hessian a node must have for a split to be + considered. The value will be compared with sum(leaf_hessian)/ + (batch_size * n_batches_per_layer). config: `RunConfig` object to configure the runtime settings. train_hooks: a list of Hook instances to be passed to estimator.train(). @@ -300,7 +332,7 @@ def boosted_trees_regressor_train_in_memory( # HParams for the model. 
tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, - tree_complexity) + tree_complexity, min_node_weight) def _model_fn(features, labels, mode, config): return canned_boosted_trees._bt_model_fn( @@ -317,7 +349,9 @@ def boosted_trees_regressor_train_in_memory( in_memory_regressor = estimator.Estimator( model_fn=_model_fn, model_dir=model_dir, config=config) - in_memory_regressor.train(input_fn=train_input_fn, hooks=train_hooks) + in_memory_regressor.train( + input_fn=_validate_input_fn_and_repeat_dataset(train_input_fn), + hooks=train_hooks) return in_memory_regressor # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py index eee59106876..76cbefe5e94 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py @@ -21,6 +21,7 @@ import numpy as np from tensorflow.contrib.estimator.python.estimator import boosted_trees from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2 +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.feature_column import feature_column @@ -49,12 +50,24 @@ def _make_train_input_fn(is_classification): """Makes train input_fn for classification/regression.""" def _input_fn(): - features = dict(FEATURES_DICT) - if is_classification: - labels = CLASSIFICATION_LABELS - else: - labels = REGRESSION_LABELS - return features, labels + features_dict = dict(FEATURES_DICT) + labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS + return features_dict, labels + + return _input_fn + + +def _make_train_input_fn_dataset(is_classification): + """Makes input_fn using Dataset.""" + + def _input_fn(): + features_dict = dict(FEATURES_DICT) + labels = CLASSIFICATION_LABELS if is_classification else REGRESSION_LABELS + ds = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(features_dict), + dataset_ops.Dataset.from_tensors(labels) + )) + return ds return _input_fn @@ -132,15 +145,13 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) est = boosted_trees.boosted_trees_classifier_train_in_memory( - train_input_fn=train_input_fn, - feature_columns=self._feature_columns, - n_trees=1, - max_depth=5) + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) # It will stop after 5 steps because of the max depth and num trees. self._assert_checkpoint( est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) - # Check eval. + # Check evaluate and predict. eval_res = est.evaluate(input_fn=train_input_fn, steps=1) self.assertAllClose(eval_res['accuracy'], 1.0) # Validate predictions. 
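The `min_node_weight` hyperparameter added to boosted_trees.py above is compared against `sum(leaf_hessian) / (batch_size * n_batches_per_layer)`. A worked example of that threshold, with made-up numbers and assuming the natural reading that a node must reach the threshold for a split to be considered:

```python
# Made-up numbers illustrating the min_node_weight check described in the
# boosted_trees.py docstrings above.
batch_size = 5
n_batches_per_layer = 1
sum_leaf_hessian = 0.4
min_node_weight = 0.1

node_weight = sum_leaf_hessian / (batch_size * n_batches_per_layer)  # 0.08
split_considered = node_weight >= min_node_weight                    # False here
```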
@@ -148,24 +159,59 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertAllClose([[0], [1], [1], [0], [0]], [pred['class_ids'] for pred in predictions]) + def testBinaryClassifierTrainInMemoryWithDataset(self): + train_input_fn = _make_train_input_fn_dataset(is_classification=True) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_classifier_train_in_memory( + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['accuracy'], 1.0) + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose([[0], [1], [1], [0], [0]], + [pred['class_ids'] for pred in predictions]) + def testRegressorTrainInMemoryAndEvalAndInfer(self): train_input_fn = _make_train_input_fn(is_classification=False) predict_input_fn = numpy_io.numpy_input_fn( x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) est = boosted_trees.boosted_trees_regressor_train_in_memory( - train_input_fn=train_input_fn, - feature_columns=self._feature_columns, - n_trees=1, - max_depth=5) + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) # It will stop after 5 steps because of the max depth and num trees. self._assert_checkpoint( est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) - # Check eval. + # Check evaluate and predict. + eval_res = est.evaluate(input_fn=train_input_fn, steps=1) + self.assertAllClose(eval_res['average_loss'], 2.478283) + predictions = list(est.predict(input_fn=predict_input_fn)) + self.assertAllClose( + [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], + [pred['predictions'] for pred in predictions]) + + def testRegressorTrainInMemoryWithDataset(self): + train_input_fn = _make_train_input_fn_dataset(is_classification=False) + predict_input_fn = numpy_io.numpy_input_fn( + x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False) + + est = boosted_trees.boosted_trees_regressor_train_in_memory( + train_input_fn=train_input_fn, feature_columns=self._feature_columns, + n_trees=1, max_depth=5) + # It will stop after 5 steps because of the max depth and num trees. + self._assert_checkpoint( + est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5) + # Check evaluate and predict. eval_res = est.evaluate(input_fn=train_input_fn, steps=1) self.assertAllClose(eval_res['average_loss'], 2.478283) - # Validate predictions. predictions = list(est.predict(input_fn=predict_input_fn)) self.assertAllClose( [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]], diff --git a/tensorflow/contrib/estimator/python/estimator/dnn.py b/tensorflow/contrib/estimator/python/estimator/dnn.py index cf6e3329d2e..7ff25b95c07 100644 --- a/tensorflow/contrib/estimator/python/estimator/dnn.py +++ b/tensorflow/contrib/estimator/python/estimator/dnn.py @@ -93,7 +93,7 @@ class DNNEstimator(estimator.Estimator): dropout=None, input_layer_partitioner=None, config=None): - """Initializes a `DNNClassifier` instance. + """Initializes a `DNNEstimator` instance. 
Args: head: A `_Head` instance constructed with a method such as diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py new file mode 100644 index 00000000000..03cf6f107c1 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export.py @@ -0,0 +1,223 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper for methods to export train/eval graphs from Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import model_fn as model_fn_lib + + +def export_saved_model_for_mode( + estimator, export_dir_base, input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False, + mode=model_fn_lib.ModeKeys.PREDICT): + # pylint: disable=line-too-long + """Exports a single train/eval/predict graph as a SavedModel. + + For a detailed guide, see + @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}. + + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn, steps=1000) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + export_dir = tf.contrib.estimator.export_saved_model_for_mode( + classifier, + export_dir_base='my_model/', + input_receiver_fn=train_rcvr_fn, + mode=model_fn_lib.ModeKeys.TRAIN) + + # export_dir is a timestamped directory with the SavedModel, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + weights = graph.get_tensor_by_name(''linear/linear_model/age/weights') + ... + ``` + + This method is a wrapper for _export_all_saved_models, and wraps a raw + input_receiver_fn in a dictionary to pass in to that function. + See _export_all_saved_models for full docs. + + See tf.contrib.estimator.export_saved_model_for_mode for the currently + exposed version of this function. + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn: a function that takes no argument and + returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. 
+ checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + mode: tf.estimator.ModeKeys value indicating with mode will be exported. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_saved_model_for_mode( + export_dir_base, input_receiver_fn, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs, + mode=mode) + # pylint: enable=protected-access + + +def export_all_saved_models( + estimator, export_dir_base, input_receiver_fn_map, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + # pylint: disable=line-too-long + """Exports requested train/eval/predict graphs as separate SavedModels. + + See tf.contrib.estimator.export_all_saved_models for the currently + exposed version of this function. + + For each mode passed in via the input_receiver_fn_map, + this method builds a new graph by calling the input_receiver_fn to obtain + feature and label `Tensor`s. Next, this method calls the `Estimator`'s + model_fn in the passed mode to generate the model graph based on + those features and labels, and restores the given checkpoint + (or, lacking that, the most recent checkpoint) into the graph. + Only one of the modes is used for saving variables to the SavedModel + (order of preference: TRAIN, EVAL, then PREDICT), such that up to three + MetaGraphDefs are saved with a single set of variables in a single + SavedModel directory. + + For prediction, the exported `MetaGraphDef` will provide one `SignatureDef` + for each element of the export_outputs dict returned from the model_fn, + named using the same keys. One of these keys is always + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which + signature will be served when a serving request does not specify one. + For each signature, the outputs are provided by the corresponding + `ExportOutput`s, and the inputs are always the input receivers provided by + the serving_input_receiver_fn. + + For training and evaluation, the train_op is stored in an extra collection, + and loss, metrics, and predictions are included in a SignatureDef for the + mode in question. + + Extra assets may be written into the SavedModel via the assets_extra + argument. This should be a dict, where each key gives a destination path + (including the filename) relative to the assets.extra directory. The + corresponding value gives the full path of the source file to be copied. + For example, the simple case of copying a single file without renaming it + is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. 
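As a concrete sketch of the `assets_extra` mechanism described just above (the destination and source paths are placeholders, and `classifier` / `rcvr_fn_map` are the objects built in the sample usage that follows), the call might look like:

```python
# Copies /tmp/label_vocab.txt into <export_dir>/assets.extra/vocab/label_vocab.txt
# for the exported SavedModel. Paths are hypothetical.
tf.contrib.estimator.export_all_saved_models(
    classifier,
    export_dir_base='my_model/',
    input_receiver_fn_map=rcvr_fn_map,
    assets_extra={'vocab/label_vocab.txt': '/tmp/label_vocab.txt'})
```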
+ + Sample usage: + ```python + classifier = tf.estimator.LinearClassifier( + feature_columns=[age, language]) + classifier.train(input_fn=input_fn) + + feature_spec = { + 'age': tf.placeholder(dtype=tf.int64), + 'language': array_ops.placeholder(dtype=tf.string) + } + label_spec = tf.placeholder(dtype=dtypes.int64) + + train_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + serve_rcvr_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn( + feature_spec) + + rcvr_fn_map = { + model_fn_lib.ModeKeys.TRAIN: train_rcvr_fn, + model_fn_lib.ModeKeys.PREDICT: serve_rcvr_fn, + } + + export_dir = tf.contrib.estimator.export_all_saved_models( + classifier, + export_dir_base='my_model/', + input_receiver_fn_map=rcvr_fn_map) + + # export_dirs is a dict of directories with SavedModels, which + # can be used for serving, analysis with TFMA, or directly loaded in. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + weights = graph.get_tensor_by_name('linear/linear_model/age/weights') + ... + ``` + + Args: + estimator: an instance of tf.estimator.Estimator + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn + mappings, where the input_receiver_fn is a function that takes no + argument and returns the appropriate subclass of `InputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). + + Returns: + A dict of tf.estimator.ModeKeys value to string path for each exported + directory. + + Raises: + ValueError: if any input_receiver_fn is None, no export_outputs + are provided, or no checkpoint can be found. + """ + # pylint: enable=line-too-long + + # pylint: disable=protected-access + return estimator._export_all_saved_models( + export_dir_base, input_receiver_fn_map, + assets_extra=assets_extra, + as_text=as_text, + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) + # pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/export_test.py b/tensorflow/contrib/estimator/python/estimator/export_test.py new file mode 100644 index 00000000000..050821ee672 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/export_test.py @@ -0,0 +1,373 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib wrapping of export_saved_model_for_mode functionality. + +These are direct copies of the tests included in core, with import locations +changed. These should be removed when the functionality in core is part of the +public API. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from tensorflow.contrib.estimator.python.estimator import export as contrib_export +from tensorflow.python.client import session +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import training +from tensorflow.python.util import compat + + +def _model_fn_for_export_tests(features, labels, mode): + _, _ = features, labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + update_global_step = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([update_global_step]): + train_op = constant_op.constant(2.) 
+ return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=train_op, + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + + +def _x_y_input_fn(): + return ({'x': constant_op.constant([[1], [1]]), + 'y': constant_op.constant([[2], [2]])}, + constant_op.constant([[1], [1]])) + + +def _model_fn_with_x_y(features, labels, mode): + _ = labels + variables.Variable(1., name='weight') + scores = constant_op.constant([3.]) + classes = constant_op.constant(['wumpus']) + if mode == model_fn_lib.ModeKeys.PREDICT: + variables.Variable(36., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + export_outputs={ + 'test': export_output.ClassificationOutput(scores, classes)}) + else: + prefix = 'eval_' if mode == model_fn_lib.ModeKeys.EVAL else '' + + multiplied = math_ops.multiply( + features['x'], features['y'], name='{}multiplied'.format(prefix)) + metrics = {'mean': metrics_lib.mean(features['x'] - features['y'], + name='{}mean'.format(prefix))} + variables.Variable(1., name='later_var') + variables.Variable(3., name='name_collision') + return model_fn_lib.EstimatorSpec( + mode, + predictions=multiplied, + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + +def _get_serving_input_receiver_fn(): + feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64), + 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)} + return export.build_parsing_serving_input_receiver_fn(feature_spec) + + +def _get_supervised_input_receiver_fn(): + feature_spec = { + 'x': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_x'), + 'y': array_ops.placeholder( + dtype=dtypes.int64, shape=(2, 1), name='feature_y') + } + label_spec = array_ops.placeholder( + dtype=dtypes.float32, shape=[1], name='truth') + + return export.build_raw_supervised_input_receiver_fn( + feature_spec, label_spec) + + +class EstimatorExportTest(test.TestCase): + + def test_export_saved_model_train(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.TRAIN) + + def test_export_saved_model_eval(self): + self._test_export_saved_model_for_mode( + _get_supervised_input_receiver_fn(), model_fn_lib.ModeKeys.EVAL) + + def test_export_saved_model_predict(self): + self._test_export_saved_model_for_mode( + _get_serving_input_receiver_fn(), model_fn_lib.ModeKeys.PREDICT) + + def _test_export_saved_model_for_mode(self, input_receiver_fn, mode): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_for_export_tests) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_saved_model_for_mode( + est, export_dir_base, input_receiver_fn, mode=mode) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + self._validate_exported_files(export_dir) + + # Restore, to validate that the export was well-formed. + tag_set = model_fn_lib.EXPORT_TAG_MAP[mode] + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('name_collision_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_receiver_map(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertFalse('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_train_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertTrue('mean/update_op' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_eval_only(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertTrue('eval_mean/value' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_no_serving(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('multiplied' in graph_ops) + self.assertFalse('eval_multiplied' in graph_ops) + self.assertTrue('feature_x' in graph_ops) + self.assertTrue('weight' in graph_ops) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.EVAL], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('eval_multiplied' in graph_ops) + self.assertFalse('multiplied' in graph_ops) + # TODO(karmel): is this the desired behavior when names are shared? + self.assertTrue('feature_x_1' in graph_ops) + self.assertTrue('feature_y_1' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. 
+ gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_three_defs(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + for tag_set in model_fn_lib.EXPORT_TAG_MAP.values(): + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, tag_set, export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('global_step/Assign' in graph_ops) + self.assertTrue('global_step/Initializer/zeros' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_proto_roundtrip_all_vars(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertFalse('later_var' in graph_ops) + self.assertTrue('weight' in graph_ops) + + # Clean up. + gfile.DeleteRecursively(tmpdir) + + def test_export_all_saved_models_name_collision(self): + input_receiver_fn_map = { + model_fn_lib.ModeKeys.TRAIN: _get_supervised_input_receiver_fn(), + model_fn_lib.ModeKeys.PREDICT: _get_serving_input_receiver_fn() + } + export_dir, tmpdir = self._test_export_all_saved_models( + input_receiver_fn_map) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.TRAINING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertEqual(3, collection_vars[-1].eval()) + + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('name_collision' in graph_ops) + self.assertFalse('name_collision_1' in graph_ops) + collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # This is a non-obvious detail: when we load the estimator spec + # for predict, name_collision gets set to 36. However, we then restore + # from checkpoint, which should overwrite that var and make it the 3 + # from training. In practice, this would not be a good way to write + # a model_fn, but leaving this check in for now to ensure consistency + # with what would happen given our current order of spec, then + # checkpoint. + self.assertEqual(3, collection_vars[-1].eval()) + + # Clean up. 
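The tests above load each export back under the tag set matching its mode. As a rough editorial sketch of the same round trip for the EVAL graph only (the estimator `est`, the placeholder specs, and the export directory are all illustrative, following the docstring sample in export.py):

```python
import tensorflow as tf

feature_spec = {'age': tf.placeholder(dtype=tf.int64),
                'language': tf.placeholder(dtype=tf.string)}
label_spec = tf.placeholder(dtype=tf.int64)

eval_rcvr_fn = tf.contrib.estimator.build_raw_supervised_input_receiver_fn(
    feature_spec, label_spec)

export_dir = tf.contrib.estimator.export_saved_model_for_mode(
    est,                               # a trained tf.estimator.Estimator
    export_dir_base='export_eval/',
    input_receiver_fn=eval_rcvr_fn,
    mode=tf.estimator.ModeKeys.EVAL)

# Load the exported evaluation graph back under the EVAL tag.
with tf.Graph().as_default() as graph:
  with tf.Session(graph=graph) as sess:
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.EVAL], export_dir)
```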
+ gfile.DeleteRecursively(tmpdir) + + def _test_export_all_saved_models(self, input_receiver_fn_map): + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn_with_x_y) + est.train(input_fn=_x_y_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('export')) + export_dir = contrib_export.export_all_saved_models( + est, export_dir_base, input_receiver_fn_map) + + # Check that all the files are in the right places. + self.assertTrue(gfile.Exists(export_dir_base)) + + self._validate_exported_files(export_dir) + + return export_dir, tmpdir + + def _validate_exported_files(self, export_dir): + self.assertTrue(gfile.Exists(export_dir)) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue(gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index 201699ed775..bf08be09e7b 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -22,12 +22,12 @@ import six from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export.export_output import PredictOutput from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.util import function_utils _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) @@ -330,7 +330,7 @@ class _TransformGradients(optimizer_lib.Optimizer): def _verify_metric_fn_args(metric_fn): - args = set(estimator_util.fn_args(metric_fn)) + args = set(function_utils.fn_args(metric_fn)) invalid_args = list(args - _VALID_METRIC_FN_ARGS) if invalid_args: raise ValueError('metric_fn (%s) has following not expected args: %s' % @@ -339,7 +339,7 @@ def _verify_metric_fn_args(metric_fn): def _call_metric_fn(metric_fn, features, labels, predictions, config): """Calls metric fn with proper arguments.""" - metric_fn_args = estimator_util.fn_args(metric_fn) + metric_fn_args = function_utils.fn_args(metric_fn) kwargs = {} if 'features' in metric_fn_args: kwargs['features'] = features diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index ae2fd8b4902..8b97f86db19 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six + from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys @@ -72,6 +74,33 @@ def 
multi_class_head(n_classes, shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.multi_class_head(n_classes=3) + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.multi_class_head(n_classes=3) + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use `binary_classification_head`). @@ -139,6 +168,33 @@ def binary_classification_head( shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.binary_classification_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.binary_classification_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -205,11 +261,39 @@ def regression_head(weight_column=None, shape `[D0, D1, ... DN, label_dimension]`. Also supports custom `inverse_link_fn`, also known as 'mean function'. - `inverse_link_fn` takes `logits` as argument and returns predicted values. - This function is the inverse of the link function defined in + `inverse_link_fn` is only used in `PREDICT` mode. It takes `logits` as + argument and returns predicted values. This function is the inverse of the + link function defined in https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function Namely, for poisson regression, set `inverse_link_fn=tf.exp`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. 
Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -234,7 +318,7 @@ def regression_head(weight_column=None, Raises: ValueError: If `label_dimension` or `loss_reduction` is invalid. """ - return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access + return head_lib._regression_head( # pylint:disable=protected-access weight_column=weight_column, label_dimension=label_dimension, loss_reduction=loss_reduction, @@ -269,6 +353,33 @@ def poisson_regression_head( This is implemented as a generalized linear model, see https://en.wikipedia.org/wiki/Generalized_linear_model. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.poisson_regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.poisson_regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -296,7 +407,7 @@ def poisson_regression_head( def _poisson_loss(labels, logits): return nn.log_poisson_loss( targets=labels, log_input=logits, compute_full_loss=compute_full_loss) - return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access + return head_lib._regression_head( # pylint:disable=protected-access weight_column=weight_column, label_dimension=label_dimension, loss_reduction=loss_reduction, @@ -305,12 +416,103 @@ def poisson_regression_head( name=name) +def logistic_regression_head( + weight_column=None, + loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, + name=None): + """Creates a `_Head` for logistic regression. + + Uses `sigmoid_cross_entropy_with_logits` loss, which is the same as + `binary_classification_head`. The differences compared to + `binary_classification_head` are: + + * Does not support `label_vocabulary`. Instead, labels must be float in the + range [0, 1]. + * Does not calculate some metrics that do not make sense, such as AUC. + * In `PREDICT` mode, only returns logits and predictions + (`=tf.sigmoid(logits)`), whereas `binary_classification_head` also returns + probabilities, classes, and class_ids. + * Export output defaults to `RegressionOutput`, whereas + `binary_classification_head` defaults to `PredictOutput`. + + The head expects `logits` with shape `[D0, D1, ... DN, 1]`. + In many applications, the shape is `[batch_size, 1]`. + + The `labels` shape must match `logits`, namely + `[D0, D1, ... DN]` or `[D0, D1, ... DN, 1]`. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]` or `[D0, D1, ... 
DN, 1]`. + + This is implemented as a generalized linear model, see + https://en.wikipedia.org/wiki/Generalized_linear_model. + + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.logistic_regression_head() + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.logistic_regression_head() + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + + Args: + weight_column: A string or a `_NumericColumn` created by + `tf.feature_column.numeric_column` defining feature column representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch and label dimension. Defaults to + `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by + `batch size * label_dimension`. See `tf.losses.Reduction`. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. + + Returns: + An instance of `_Head` for logistic regression. + + Raises: + ValueError: If `loss_reduction` is invalid. + """ + def _logistic_loss(labels, logits): + labels = head_lib._assert_range( # pylint:disable=protected-access + labels, n_classes=2, message='Labels must be in range [0, 1]') + return nn.sigmoid_cross_entropy_with_logits( + labels=labels, logits=logits) + return head_lib._regression_head( # pylint:disable=protected-access + weight_column=weight_column, + label_dimension=1, + loss_reduction=loss_reduction, + loss_fn=_logistic_loss, + inverse_link_fn=math_ops.sigmoid, + name=name) + + def multi_label_head(n_classes, weight_column=None, thresholds=None, label_vocabulary=None, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + classes_for_class_based_metrics=None, name=None): """Creates a `_Head` for multi-label classification. @@ -342,6 +544,33 @@ def multi_label_head(n_classes, shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies `label_vocabulary` to the input labels before passing them to `loss_fn`. + The head can be used with a canned estimator. Example: + + ```python + my_head = tf.contrib.estimator.multi_label_head(n_classes=3) + my_estimator = tf.contrib.estimator.DNNEstimator( + head=my_head, + hidden_units=..., + feature_columns=...) + ``` + + It can also be used with a custom `model_fn`. Example: + + ```python + def _my_model_fn(features, labels, mode): + my_head = tf.contrib.estimator.multi_label_head(n_classes=3) + logits = tf.keras.Model(...)(features) + + return my_head.create_estimator_spec( + features=features, + mode=mode, + labels=labels, + optimizer=tf.AdagradOptimizer(learning_rate=0.1), + logits=logits) + + my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn) + ``` + Args: n_classes: Number of classes, must be greater than 1 (for 1 class, use `binary_classification_head`). @@ -363,6 +592,10 @@ def multi_label_head(n_classes, reduce training loss over batch. 
Defaults to `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by batch size. See `tf.losses.Reduction`. loss_fn: Optional loss function. + classes_for_class_based_metrics: List of integer class IDs or string class + names for which per-class metrics are evaluated. If integers, all must be + in the range `[0, n_classes - 1]`. If strings, all must be in + `label_vocabulary`. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -370,8 +603,8 @@ def multi_label_head(n_classes, An instance of `_Head` for multi-label classification. Raises: - ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is - invalid. + ValueError: if `n_classes`, `thresholds`, `loss_reduction`, `loss_fn` or + `metric_class_ids` is invalid. """ thresholds = tuple(thresholds) if thresholds else tuple() if n_classes is None or n_classes < 2: @@ -396,10 +629,31 @@ def multi_label_head(n_classes, if (loss_reduction not in losses.Reduction.all() or loss_reduction == losses.Reduction.NONE): raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) + classes_for_class_based_metrics = tuple( + [] if classes_for_class_based_metrics is None + else classes_for_class_based_metrics) + if classes_for_class_based_metrics: + if isinstance(classes_for_class_based_metrics[0], six.string_types): + if not label_vocabulary: + raise ValueError( + 'label_vocabulary must be provided when ' + 'classes_for_class_based_metrics are sting.') + class_ids = [] + for class_string in classes_for_class_based_metrics: + class_ids.append(label_vocabulary.index(class_string)) + classes_for_class_based_metrics = tuple(class_ids) + else: + for class_id in classes_for_class_based_metrics: + if (class_id < 0) or (class_id >= n_classes): + raise ValueError( + 'All classes_for_class_based_metrics must be in range [0, {}]. ' + 'Given: {}'.format(n_classes - 1, class_id)) return _MultiLabelHead( n_classes=n_classes, weight_column=weight_column, thresholds=thresholds, label_vocabulary=label_vocabulary, loss_reduction=loss_reduction, - loss_fn=loss_fn, name=name) + loss_fn=loss_fn, + classes_for_class_based_metrics=classes_for_class_based_metrics, + name=name) class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access @@ -412,6 +666,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access label_vocabulary=None, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None, + classes_for_class_based_metrics=None, name=None): self._n_classes = n_classes self._weight_column = weight_column @@ -419,6 +674,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access self._label_vocabulary = label_vocabulary self._loss_reduction = loss_reduction self._loss_fn = loss_fn + self._classes_for_class_based_metrics = classes_for_class_based_metrics self._name = name @property @@ -485,7 +741,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access reduction=losses.Reduction.NONE) # Averages loss over classes. 
unweighted_loss = math_ops.reduce_mean( - unweighted_loss, axis=-1, keep_dims=True) + unweighted_loss, axis=-1, keepdims=True) weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, features=features, weight_column=self._weight_column, logits=logits) training_loss = losses.compute_weighted_loss( @@ -496,10 +752,10 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weights=weights, processed_labels=processed_labels) - def create_estimator_spec( + def _create_tpu_estimator_spec( self, features, mode, logits, labels=None, optimizer=None, train_op_fn=None, regularization_losses=None): - """Returns an `EstimatorSpec`. + """Returns an `model_fn._TPUEstimatorSpec`. Args: features: Input `dict` of `Tensor` or `SparseTensor` objects. @@ -522,7 +778,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to avoid scaling errors. Returns: - `EstimatorSpec`. + `model_fn._TPUEstimatorSpec`. Raises: ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN mode, or if both are set. @@ -542,7 +798,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access classifier_output = head_lib._classification_output( # pylint:disable=protected-access scores=probabilities, n_classes=self._n_classes, label_vocabulary=self._label_vocabulary) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ @@ -565,16 +821,18 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access # Eval. if mode == model_fn.ModeKeys.EVAL: - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.EVAL, predictions=predictions, loss=regularized_training_loss, - eval_metric_ops=self._eval_metric_ops( - labels=processed_labels, - probabilities=probabilities, - weights=weights, - unreduced_loss=unreduced_loss, - regularization_loss=regularization_loss)) + eval_metrics=head_lib._create_eval_metrics_tuple( # pylint:disable=protected-access + self._eval_metric_ops, { + 'labels': processed_labels, + 'probabilities': probabilities, + 'weights': weights, + 'unreduced_loss': unreduced_loss, + 'regularization_loss': regularization_loss, + })) # Train. 
if optimizer is not None: @@ -608,7 +866,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access summary.scalar( head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION), # pylint:disable=protected-access regularization_loss) - return model_fn.EstimatorSpec( + return model_fn._TPUEstimatorSpec( # pylint:disable=protected-access mode=model_fn.ModeKeys.TRAIN, predictions=predictions, loss=regularized_training_loss, @@ -671,4 +929,36 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weights=weights, threshold=threshold, name=recall_key)) + for class_id in self._classes_for_class_based_metrics: + batch_rank = array_ops.rank(probabilities) - 1 + begin = array_ops.concat( + [array_ops.zeros([batch_rank], dtype=dtypes.int32), [class_id]], + axis=0) + size = array_ops.concat( + [-1 * array_ops.ones([batch_rank], dtype=dtypes.int32), [1]], + axis=0) + class_probabilities = array_ops.slice( + probabilities, begin=begin, size=size) + class_labels = array_ops.slice(labels, begin=begin, size=size) + prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, prob_key)] = ( # pylint:disable=protected-access + head_lib._predictions_mean( # pylint:disable=protected-access + predictions=class_probabilities, + weights=weights, + name=prob_key)) + auc_key = keys.AUC_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, auc_key)] = ( # pylint:disable=protected-access + head_lib._auc( # pylint:disable=protected-access + labels=class_labels, + predictions=class_probabilities, + weights=weights, + name=auc_key)) + auc_pr_key = keys.AUC_PR_AT_CLASS % class_id + metric_ops[head_lib._summary_key(self._name, auc_pr_key)] = ( # pylint:disable=protected-access + head_lib._auc( # pylint:disable=protected-access + labels=class_labels, + predictions=class_probabilities, + weights=weights, + curve='PR', + name=auc_pr_key)) return metric_ops diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 98962ca4277..d6c158608b5 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -175,6 +175,21 @@ class MultiLabelHead(test.TestCase): r'loss_fn has unexpected args: \[\'name\'\]'): head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + def test_classes_for_class_based_metrics_invalid(self): + with self.assertRaisesRegexp( + ValueError, + r'All classes_for_class_based_metrics must be in range \[0, 2\]\. 
' + r'Given: -1'): + head_lib.multi_label_head( + n_classes=3, classes_for_class_based_metrics=[2, -1]) + + def test_classes_for_class_based_metrics_string_invalid(self): + with self.assertRaisesRegexp( + ValueError, r'\'z\' is not in list'): + head_lib.multi_label_head( + n_classes=3, label_vocabulary=['a', 'b', 'c'], + classes_for_class_based_metrics=['c', 'z']) + def test_name(self): head = head_lib.multi_label_head(n_classes=4, name='foo') self.assertEqual('foo', head.name) @@ -591,6 +606,81 @@ class MultiLabelHead(test.TestCase): expected_loss=expected_loss, expected_metrics=expected_metrics) + def test_eval_with_classes_for_class_based_metrics(self): + head = head_lib.multi_label_head( + n_classes=2, classes_for_class_based_metrics=[0, 1]) + + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) + + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_CLASS % 0: 0., + keys.AUC_PR_AT_CLASS % 0: 1., + keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_CLASS % 1: 1., + keys.AUC_PR_AT_CLASS % 1: 1., + } + + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + + def test_eval_with_classes_for_class_based_metrics_string(self): + head = head_lib.multi_label_head( + n_classes=2, label_vocabulary=['a', 'b'], + classes_for_class_based_metrics=['a', 'b']) + + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = sparse_tensor.SparseTensor( + values=['a', 'a', 'b'], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + labels_onehot = np.array([[1, 0], [1, 1]], dtype=np.int64) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits)) + # Sum over examples, divide by batch_size. + expected_loss = 0.5 * np.sum( + _sigmoid_cross_entropy(labels=labels_onehot, logits=logits)) + + keys = metric_keys.MetricKeys + expected_metrics = { + # Average loss over examples. + keys.LOSS_MEAN: expected_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. 
+ keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + keys.PROBABILITY_MEAN_AT_CLASS % 0: np.sum(_sigmoid(logits[:, 0])) / 2., + keys.AUC_AT_CLASS % 0: 0., + keys.AUC_PR_AT_CLASS % 0: 1., + keys.PROBABILITY_MEAN_AT_CLASS % 1: np.sum(_sigmoid(logits[:, 1])) / 2., + keys.AUC_AT_CLASS % 1: 1., + keys.AUC_PR_AT_CLASS % 1: 1., + } + + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + def test_eval_with_weights(self): n_classes = 2 head = head_lib.multi_label_head(n_classes, weight_column='example_weights') @@ -1211,5 +1301,124 @@ class PoissonRegressionHead(test.TestCase): self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval()) +class LogisticRegressionHead(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + + def test_train(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[.4], [.6], [.8]], dtype=np.float32) + # Following the documentation in + # tf.nn.sigmoid_cross_entropy_with_logits: + # With x = logits, z = labels. + # loss = max(x, 0) - x * z + log(1 + exp(-abs(x))) + # loss = [0 - 0 * 0.4 + ln(1 + exp(-0)), + # 0 + 1 * 0.6 + ln(1 + exp(-1)), + # 1 - 1 * 0.8 + ln(1 + exp(-1))] + # = [0.6931, 0.9133, 0.5133] + # training_loss = (0.6931 + 0.9133 + 0.5133) / 3 + expected_loss = 0.7066 + atol = 0.001 + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + with ops.control_dependencies((check_ops.assert_near( + math_ops.to_float(expected_loss), math_ops.to_float(loss), + atol=atol, name='assert_loss'),)): + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + loss, train_result = sess.run([spec.loss, spec.train_op]) + self.assertAlmostEqual(expected_loss, loss, delta=atol) + self.assertEqual(expected_train_result, train_result) + + def test_train_labels_too_large(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[.4], [1.2], [.8]], dtype=np.float32) + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Labels must be in range \[0, 1\]\] .* \[\[0.4\]\[1.2\]\[0.8\]\]'): + _ = sess.run(spec.loss) + + def test_train_labels_negative(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. 
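The loss arithmetic spelled out in the comments of `test_train` above can be double-checked with a few lines of NumPy; this is purely an editorial sanity check of the expected value 0.7066, not part of the test file:

```python
import numpy as np

# x = logits, z = labels; per-example sigmoid cross-entropy:
# loss = max(x, 0) - x * z + log(1 + exp(-|x|))
x = np.array([0., -1., 1.])
z = np.array([0.4, 0.6, 0.8])
per_example = np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))
print(np.round(per_example, 4))      # [0.6931 0.9133 0.5133]
print(round(per_example.mean(), 4))  # 0.7066
```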
+ logits = np.array([[0], [-1], [1]], dtype=np.float32) + labels = np.array([[.4], [-0.2], [.8]], dtype=np.float32) + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return constant_op.constant(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[Labels must be in range \[0, 1\]\] .* \[\[0.4\]\[-0.2\]\[0.8\]\]' + ): + _ = sess.run(spec.loss) + + def test_predict(self): + head = head_lib.logistic_regression_head() + + # Create estimator spec. + logits = np.array([[0], [-1], [1]], dtype=np.float32) + expected_predictions = 1. / (1. + np.exp(-logits)) + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + # Assert spec contains expected tensors. + keys = prediction_keys.PredictionKeys + self.assertItemsEqual( + (keys.PREDICTIONS, keys.LOGITS), spec.predictions.keys()) + self.assertEqual(dtypes.float32, spec.predictions[keys.PREDICTIONS].dtype) + self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype) + + # Assert predictions. + with self.test_session(): + _initialize_variables(self, spec.scaffold) + self.assertAllClose( + expected_predictions, spec.predictions[keys.PREDICTIONS].eval()) + self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval()) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py new file mode 100644 index 00000000000..4808b9ee30e --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Some useful session run hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.training import training + + +# pylint: disable=protected-access +class InMemoryEvaluatorHook(training.SessionRunHook): + """Hook to run evaluation in training without a checkpoint. + + Example: + + ```python + def train_input_fn(): + ... + return train_dataset + + def eval_input_fn(): + ... + return eval_dataset + + estimator = tf.estimator.DNNClassifier(...) 
+ + evaluator = tf.contrib.estimator.InMemoryEvaluatorHook( + estimator, eval_input_fn) + estimator.train(train_input_fn, hooks=[evaluator]) + ``` + + Current limitations of this approach are: + * It doesn't support multi-node distributed mode. + * It doesn't support saveable objects other than variables (such as boosted + tree support) + * It doesn't support custom saver logic (such as ExponentialMovingAverage + support) + + """ + + def __init__(self, + estimator, + input_fn, + steps=None, + hooks=None, + name=None, + every_n_iter=100): + """Initializes a `InMemoryEvaluatorHook`. + + Args: + estimator: A `tf.estimator.Estimator` instance to call evaluate. + input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A + function that constructs the input data for evaluation. + See @{$get_started/premade_estimators#create_input_functions} for more + information. The function should construct and return one of + the following: + + * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a + tuple (features, labels) with same constraints as below. + * A tuple (features, labels): Where `features` is a `Tensor` or a + dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + + steps: Equivalent to the `steps` arg to `estimator.evaluate`. Number of + steps for which to evaluate model. If `None`, evaluates until `input_fn` + raises an end-of-input exception. + hooks: Equivalent to the `hooks` arg to `estimator.evaluate`. List of + `SessionRunHook` subclass instances. Used for callbacks inside the + evaluation call. + name: Equivalent to the `name` arg to `estimator.evaluate`. Name of the + evaluation if user needs to run multiple evaluations on different data + sets, such as on training data vs test data. Metrics for different + evaluations are saved in separate folders, and appear separately in + tensorboard. + every_n_iter: `int`, runs the evaluator once every N training iteration. + + Raises: + ValueError: if `every_n_iter` is non-positive or it's not a single machine + training + """ + if every_n_iter is None or every_n_iter <= 0: + raise ValueError('invalid every_n_iter=%s.' 
% every_n_iter) + if (estimator.config.num_ps_replicas > 0 or + estimator.config.num_worker_replicas > 1): + raise ValueError( + 'InMemoryEvaluator supports only single machine (aka Local) setting.') + self._estimator = estimator + self._input_fn = input_fn + self._steps = steps + self._name = name + self._every_n_iter = every_n_iter + self._eval_dir = os.path.join(self._estimator.model_dir, 'eval' + if not name else 'eval_' + name) + + self._graph = None + self._hooks = estimator_lib._check_hooks_type(hooks) + self._hooks.extend(self._estimator._convert_eval_steps_to_hooks(steps)) + self._timer = training.SecondOrStepTimer(every_steps=every_n_iter) + + def begin(self): + """Build eval graph and restoring op.""" + self._timer.reset() + self._iter_count = 0 + self._graph = ops.Graph() + with self._graph.as_default(): + (self._scaffold, self._update_op, self._eval_dict, + self._all_hooks) = self._estimator._evaluate_build_graph( + self._input_fn, self._hooks, checkpoint_path=None) + + if self._scaffold.saver is not None: + raise ValueError('InMemoryEvaluator does not support custom saver') + if self._scaffold.init_fn is not None: + raise ValueError('InMemoryEvaluator does not support custom init_fn') + + self._var_name_to_eval_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + self._var_name_to_placeholder = { + v.name: array_ops.placeholder(v.dtype) + for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + + def after_create_session(self, session, coord): # pylint: disable=unused-argument + """Does first run which shows the eval metrics before training.""" + if ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS): + raise ValueError( + 'InMemoryEvaluator does not support saveables other than global ' + 'variables.') + self._var_name_to_train_var = { + v.name: v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + } + var_names_to_transfer = set(self._var_name_to_placeholder.keys()) & set( + self._var_name_to_train_var.keys()) + # Filter training var names that are not exist in evaluation + self._var_name_to_train_var = { + v_name: self._var_name_to_train_var[v_name] + for v_name in var_names_to_transfer + } + # Filter eval var names that are not exist in training + self._var_name_to_eval_var = { + v_name: self._var_name_to_eval_var[v_name] + for v_name in var_names_to_transfer + } + + with self._graph.as_default(): + self._var_feed_op = control_flow_ops.group([ + state_ops.assign(self._var_name_to_eval_var[v_name], + self._var_name_to_placeholder[v_name]) + for v_name in var_names_to_transfer + ]) + + self._evaluate(session) + + def _evaluate(self, train_session): + var_name_to_value = train_session.run(self._var_name_to_train_var) + placeholder_to_value = { + self._var_name_to_placeholder[v_name]: var_name_to_value[v_name] + for v_name in var_name_to_value + } + + def feed_variables(scaffold, session): + del scaffold + session.run(self._var_feed_op, feed_dict=placeholder_to_value) + + scaffold = training.Scaffold( + init_fn=feed_variables, copy_from_scaffold=self._scaffold) + + with self._graph.as_default(): + return self._estimator._evaluate_run( + checkpoint_path=None, + scaffold=scaffold, + update_op=self._update_op, + eval_dict=self._eval_dict, + all_hooks=self._all_hooks, + output_dir=self._eval_dir) + + self._timer.update_last_triggered_step(self._iter_count) + + def after_run(self, run_context, run_values): # pylint: disable=unused-argument + """Runs evaluator.""" + self._iter_count += 1 + if 
self._timer.should_trigger_for_step(self._iter_count): + self._evaluate(run_context.session) + + def end(self, session): # pylint: disable=unused-argument + """Runs evaluator for final model.""" + self._evaluate(session) + + +# pylint: enable=protected-access diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py new file mode 100644 index 00000000000..95ae971852e --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py @@ -0,0 +1,318 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import json +import os + +from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import estimator_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.feature_column import feature_column as feature_column_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.summary import summary_iterator +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import training + + +def summary_step_keyword_to_value_mapping(dir_): + writer_cache.FileWriterCache.clear() + + # Get last Event written. + event_paths = glob.glob(os.path.join(dir_, 'events*')) + step_keyword_to_value = {} + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step not in step_keyword_to_value: + step_keyword_to_value[last_event.step] = {} + if last_event.summary is not None: + for value in last_event.summary.value: + step_keyword_to_value[last_event.step][value.tag] = value.simple_value + + return step_keyword_to_value + + +def get_summary_value(dir_, step, keyword): + """Get summary value for given step and keyword.""" + + writer_cache.FileWriterCache.clear() + # Get last Event written. 
+ event_paths = glob.glob(os.path.join(dir_, 'events*')) + print('XXX', event_paths) + for last_event in summary_iterator.summary_iterator(event_paths[-1]): + if last_event.step == step and last_event.summary is not None: + for value in last_event.summary.value: + if keyword in value.tag: + return value.simple_value + return None + + +class InMemoryEvaluatorHookTest(test.TestCase): + + def test_runs_eval_metrics(self): + + def model_fn(features, labels, mode): + _ = labels + if estimator_lib.ModeKeys.TRAIN == mode: + with ops.control_dependencies([features]): + train_op = state_ops.assign_add(training.get_global_step(), 1) + return estimator_lib.EstimatorSpec( + mode, loss=constant_op.constant(3.), train_op=train_op) + if estimator_lib.ModeKeys.EVAL == mode: + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(5.), + eval_metric_ops={'mean_of_features': metrics_lib.mean(features)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # 4.5 = sum(range(10))/10 + # before training + self.assertEqual(4.5, step_keyword_to_value[0]['mean_of_features']) + # intervals (every_n_iter=4) + self.assertEqual(4.5, step_keyword_to_value[4]['mean_of_features']) + self.assertEqual(4.5, step_keyword_to_value[8]['mean_of_features']) + # end + self.assertEqual(4.5, step_keyword_to_value[10]['mean_of_features']) + + def test_uses_latest_variable_value(self): + + def model_fn(features, labels, mode): + _ = labels + step = training.get_global_step() + w = variable_scope.get_variable( + 'w', + shape=[], + initializer=init_ops.zeros_initializer(), + dtype=dtypes.int64) + if estimator_lib.ModeKeys.TRAIN == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + step_inc = state_ops.assign_add(training.get_global_step(), 1) + with ops.control_dependencies([step_inc]): + assign_w_to_step_plus_2 = w.assign(step + 2) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + train_op=assign_w_to_step_plus_2) + if estimator_lib.ModeKeys.EVAL == mode: + # to consume features, we have control dependency + with ops.control_dependencies([features]): + loss = constant_op.constant(5.) + return estimator_lib.EstimatorSpec( + mode, + loss=loss, + # w is constant in each step, so the mean. 
+ # w = 0 if step==0 else step+2 + eval_metric_ops={'mean_of_const': metrics_lib.mean(w)}) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + estimator, input_fn, every_n_iter=4) + estimator.train(input_fn, hooks=[evaluator]) + + self.assertTrue(os.path.isdir(estimator.eval_dir())) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + estimator.eval_dir()) + # w = 0 if step==0 else step+2 + self.assertEqual(0, step_keyword_to_value[0]['mean_of_const']) + self.assertEqual(6, step_keyword_to_value[4]['mean_of_const']) + self.assertEqual(12, step_keyword_to_value[10]['mean_of_const']) + + def test_dnn_classifier(self): + embedding = feature_column_lib.embedding_column( + feature_column_lib.categorical_column_with_vocabulary_list( + 'wire_cast', ['kima', 'omar', 'stringer']), 8) + dnn = estimator_lib.DNNClassifier( + feature_columns=[embedding], hidden_units=[3, 1]) + + def train_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['omar'], ['kima']] + }, [[0], [1]])).repeat(3) + + def eval_input_fn(): + return dataset_ops.Dataset.from_tensors(({ + 'wire_cast': [['stringer'], ['kima']] + }, [[0], [1]])).repeat(2) + + evaluator = hooks_lib.InMemoryEvaluatorHook( + dnn, eval_input_fn, name='in-memory') + dnn.train(train_input_fn, hooks=[evaluator]) + self.assertTrue(os.path.isdir(dnn.eval_dir('in-memory'))) + step_keyword_to_value = summary_step_keyword_to_value_mapping( + dnn.eval_dir('in-memory')) + + final_metrics = dnn.evaluate(eval_input_fn) + step = final_metrics[ops.GraphKeys.GLOBAL_STEP] + for summary_tag in final_metrics: + if summary_tag == ops.GraphKeys.GLOBAL_STEP: + continue + self.assertEqual(final_metrics[summary_tag], + step_keyword_to_value[step][summary_tag]) + + def test_raise_error_with_multi_worker(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'] + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_ps(self): + tf_config = { + 'cluster': { + run_config_lib.TaskType.CHIEF: ['host0:0'], + run_config_lib.TaskType.PS: ['host1:1'], + }, + 'task': { + 'type': run_config_lib.TaskType.CHIEF, + 'index': 0 + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + dnn = estimator_lib.DNNClassifier( + feature_columns=[feature_column_lib.numeric_column('x')], + hidden_units=[3, 1]) + + def eval_input_fn(): + pass + + with self.assertRaisesRegexp(ValueError, 'supports only single machine'): + hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn) + + def test_raise_error_with_custom_saver_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(saver=training.Saver()), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def 
input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom saver'): + evaluator.begin() + + def test_raise_error_with_custom_init_fn_in_eval(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + + def init_fn(scaffold, session): + _, _ = scaffold, session + + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_fn=init_fn), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support custom init_fn'): + evaluator.begin() + + def test_raise_error_with_saveables_other_than_global_variables(self): + + def model_fn(features, labels, mode): + _, _ = features, labels + w = variables.Variable( + initial_value=[0.], + trainable=False, + collections=[ops.GraphKeys.SAVEABLE_OBJECTS]) + init_op = control_flow_ops.group( + [w.initializer, training.get_global_step().initializer]) + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant(3.), + scaffold=training.Scaffold(init_op=init_op), + train_op=constant_op.constant(5.), + eval_metric_ops={ + 'mean_of_features': metrics_lib.mean(constant_op.constant(2.)) + }) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + + def input_fn(): + return dataset_ops.Dataset.range(10) + + evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn) + with self.assertRaisesRegexp(ValueError, 'does not support saveables'): + estimator.train(input_fn, hooks=[evaluator]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns.py b/tensorflow/contrib/estimator/python/estimator/logit_fns.py index 09c2862ccd3..c8b0dd62970 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns.py @@ -41,10 +41,10 @@ from __future__ import print_function import six -from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import dnn as dnn_core from tensorflow.python.estimator.canned import linear as linear_core from tensorflow.python.framework import ops +from tensorflow.python.util import function_utils # pylint: disable=protected-access dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder @@ -72,7 +72,7 @@ def call_logit_fn(logit_fn, features, mode, params, config): ValueError: if logit_fn does not return a Tensor or a dictionary mapping strings to Tensors. 
""" - logit_fn_args = util.fn_args(logit_fn) + logit_fn_args = function_utils.fn_args(logit_fn) kwargs = {} if 'mode' in logit_fn_args: kwargs['mode'] = mode diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index a8774d6dab9..cda23aa437f 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -32,7 +32,6 @@ import six from tensorflow.core.framework import node_def_pb2 from tensorflow.python.client import device_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import device as framework_device from tensorflow.python.framework import ops as ops_lib @@ -47,8 +46,13 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging from tensorflow.python.training import device_setter as device_setter_lib from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils +@deprecation.deprecated( + '2018-05-31', + 'Please use `tf.contrib.distribute.MirroredStrategy` instead.') def replicate_model_fn(model_fn, loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, devices=None): @@ -255,6 +259,9 @@ class TowerOptimizer(optimizer_lib.Optimizer): COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states' + @deprecation.deprecated( + '2018-05-31', + 'Please use `tf.contrib.distribute.MirroredStrategy` instead.') def __init__(self, optimizer_or_optimizer_fn): """Wrap an existing optimizer for gathering gradients across towers. 
@@ -514,7 +521,7 @@ def _get_loss_towers(model_fn, """Replicate the loss computation across devices.""" tower_specs = [] - model_fn_args = util.fn_args(model_fn) + model_fn_args = function_utils.fn_args(model_fn) optional_params = {} if 'params' in model_fn_args: optional_params['params'] = copy.deepcopy(params) diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index 144b45982c8..dd8a3a95f1b 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -540,59 +540,6 @@ class ReplicateAcrossASingleDeviceWithoutTowerOptimizer( self.assertEqual(7.0, session.run(c)) -class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase): - - def model_fn(self, mode, features, labels, params): - c = variable_scope.get_variable( - 'c', - initializer=constant_op.constant(10, dtype=dtypes.float64), - dtype=dtypes.float64) - - features = features['features'] - predictions = math_ops.multiply(features, c) - - loss = losses.absolute_difference( - labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) - loss = math_ops.reduce_sum(loss) - - metrics = { - 'accuracy': metrics_lib.accuracy(labels, predictions), - 'auc': metrics_lib.auc(labels, predictions) - } - - optimizer = replicate_model_fn.TowerOptimizer( - gradient_descent.GradientDescentOptimizer(params['learning_rate'])) - - return model_fn_lib.EstimatorSpec( - mode=mode, - loss=loss, - eval_metric_ops=metrics, - predictions={'probabilities': predictions}, - train_op=optimizer.minimize(loss)) - - @property - def params(self): - params = {} - params['learning_rate'] = 1.0 - return params - - def test_train_single_tower(self): - features = np.array([[1.0], [2.0]]) - labels = np.array([[1.0], [2.0]]) - - train_input_fn = numpy_io.numpy_input_fn( - x={'features': features}, y=labels, batch_size=2, shuffle=False) - - with self.test_session(): - estimator = estimator_lib.Estimator( - model_fn=self.model_fn, - model_dir=tempfile.mkdtemp(), - params=self.params) - estimator.train(train_input_fn, steps=1) - - self.assertEqual(7.0, estimator.get_variable_value('c')) - - class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase): def model_fn(self, mode, features, labels, params): diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index b475c12f5af..7c49cd00d16 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -229,6 +229,7 @@ def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns, rnn_outputs, _ = rnn.dynamic_rnn( cell=cell, inputs=sequence_input, + sequence_length=sequence_length, dtype=dtypes.float32, time_major=False) last_activations = _select_last_activations(rnn_outputs, sequence_length) @@ -328,6 +329,19 @@ def _rnn_model_fn(features, logits=logits) +def _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type): + """Assert arguments are valid and return rnn_cell_fn.""" + if rnn_cell_fn and (num_units or cell_type != USE_DEFAULT): + raise ValueError( + 'num_units and cell_type must not be specified when using rnn_cell_fn' + ) + if not rnn_cell_fn: + if cell_type == USE_DEFAULT: + cell_type = 'basic_rnn' + rnn_cell_fn = _make_rnn_cell_fn(num_units, cell_type) + return rnn_cell_fn + + class RNNClassifier(estimator.Estimator): """A classifier 
for TensorFlow RNN models. @@ -341,8 +355,8 @@ class RNNClassifier(estimator.Estimator): token_emb = embedding_column(categorical_column=token_sequence, ...) estimator = RNNClassifier( - num_units=[32, 16], cell_type='lstm', - sequence_feature_columns=[token_emb]) + sequence_feature_columns=[token_emb], + num_units=[32, 16], cell_type='lstm') # Input builders def input_fn_train: # returns x, y @@ -438,8 +452,8 @@ class RNNClassifier(estimator.Estimator): encoded as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there will be errors if vocabulary is not provided and labels are string. - optimizer: An instance of `tf.Optimizer` used to train the model. Defaults - to Adagrad optimizer. + optimizer: An instance of `tf.Optimizer` or string specifying optimizer + type. Defaults to Adagrad optimizer. input_layer_partitioner: Optional. Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. @@ -448,14 +462,7 @@ class RNNClassifier(estimator.Estimator): ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not compatible. """ - if rnn_cell_fn and (num_units or cell_type != USE_DEFAULT): - raise ValueError( - 'num_units and cell_type must not be specified when using rnn_cell_fn' - ) - if not rnn_cell_fn: - if cell_type == USE_DEFAULT: - cell_type = 'basic_rnn' - rnn_cell_fn = _make_rnn_cell_fn(num_units, cell_type) + rnn_cell_fn = _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type) if n_classes == 2: head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access @@ -479,3 +486,137 @@ class RNNClassifier(estimator.Estimator): config=config) super(RNNClassifier, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) + + +class RNNEstimator(estimator.Estimator): + """An Estimator for TensorFlow RNN models with user-specified head. + + Example: + + ```python + token_sequence = sequence_categorical_column_with_hash_bucket(...) + token_emb = embedding_column(categorical_column=token_sequence, ...) + + estimator = RNNEstimator( + head=tf.contrib.estimator.regression_head(), + sequence_feature_columns=[token_emb], + num_units=[32, 16], cell_type='lstm') + + # Or with custom RNN cell: + def rnn_cell_fn(mode): + cells = [ tf.contrib.rnn.LSTMCell(size) for size in [32, 16] ] + if mode == tf.estimator.ModeKeys.TRAIN: + cells = [ tf.contrib.rnn.DropoutWrapper(cell, input_keep_prob=0.5) + for cell in cells ] + return tf.contrib.rnn.MultiRNNCell(cells) + + estimator = RNNEstimator( + head=tf.contrib.estimator.regression_head(), + sequence_feature_columns=[token_emb], + rnn_cell_fn=rnn_cell_fn) + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if the head's `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + * for each `column` in `sequence_feature_columns`: + - a feature with `key=column.name` whose `value` is a `SparseTensor`. 
+ * for each `column` in `context_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss and predicted output are determined by the specified head. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + head, + sequence_feature_columns, + context_feature_columns=None, + num_units=None, + cell_type=USE_DEFAULT, + rnn_cell_fn=None, + model_dir=None, + optimizer='Adagrad', + input_layer_partitioner=None, + config=None): + """Initializes a `RNNClassifier` instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. This specifies the model's + output and loss function to be optimized. + sequence_feature_columns: An iterable containing the `FeatureColumn`s + that represent sequential input. All items in the set should either be + sequence columns (e.g. `sequence_numeric_column`) or constructed from + one (e.g. `embedding_column` with `sequence_categorical_column_*` as + input). + context_feature_columns: An iterable containing the `FeatureColumn`s + for contextual input. The data represented by these columns will be + replicated and given to the RNN at each timestep. These columns must be + instances of classes derived from `_DenseColumn` such as + `numeric_column`, not the sequential variants. + num_units: Iterable of integer number of hidden units per RNN layer. If + set, `cell_type` must also be specified and `rnn_cell_fn` must be + `None`. + cell_type: A subclass of `tf.nn.rnn_cell.RNNCell` or a string specifying + the cell type. Supported strings are: `'basic_rnn'`, `'lstm'`, and + `'gru'`. If set, `num_units` must also be specified and `rnn_cell_fn` + must be `None`. + rnn_cell_fn: A function with one argument, a `tf.estimator.ModeKeys`, and + returns an object of type `tf.nn.rnn_cell.RNNCell` that will be used to + construct the RNN. If set, `num_units` and `cell_type` cannot be set. + This is for advanced users who need additional customization beyond + `num_units` and `cell_type`. Note that `tf.nn.rnn_cell.MultiRNNCell` is + needed for stacked RNNs. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + optimizer: An instance of `tf.Optimizer` or string specifying optimizer + type. Defaults to Adagrad optimizer. + input_layer_partitioner: Optional. Partitioner for input layer. Defaults + to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: `RunConfig` object to configure the runtime settings. + + Raises: + ValueError: If `num_units`, `cell_type`, and `rnn_cell_fn` are not + compatible. 
+ """ + rnn_cell_fn = _assert_rnn_cell_fn(rnn_cell_fn, num_units, cell_type) + + def _model_fn(features, labels, mode, config): + return _rnn_model_fn( + features=features, + labels=labels, + mode=mode, + head=head, + rnn_cell_fn=rnn_cell_fn, + sequence_feature_columns=tuple(sequence_feature_columns or []), + context_feature_columns=tuple(context_feature_columns or []), + optimizer=optimizer, + input_layer_partitioner=input_layer_partitioner, + config=config) + super(RNNEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py index 393f94f5c7d..959b40371aa 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py @@ -25,12 +25,15 @@ import tempfile import numpy as np import six +from tensorflow.contrib.data.python.ops import readers +from tensorflow.contrib.estimator.python.estimator import head as head_lib from tensorflow.contrib.estimator.python.estimator import rnn from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import metric_keys +from tensorflow.python.estimator.canned import parsing_utils from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export from tensorflow.python.estimator.inputs import numpy_io @@ -38,9 +41,9 @@ from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.lib.io import python_io from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import rnn_cell from tensorflow.python.ops import state_ops @@ -50,7 +53,6 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import checkpoint_utils -from tensorflow.python.training import input as input_lib from tensorflow.python.training import monitored_session from tensorflow.python.training import optimizer from tensorflow.python.training import training_util @@ -984,7 +986,10 @@ class RNNClassifierPredictionTest(test.TestCase): predictions[prediction_keys.PredictionKeys.CLASSES]) -class RNNClassifierIntegrationTest(test.TestCase): +class BaseRNNClassificationIntegrationTest(object): + + def __init__(self, _create_estimator_fn): + self._create_estimator_fn = _create_estimator_fn def setUp(self): self._model_dir = tempfile.mkdtemp() @@ -994,20 +999,11 @@ class RNNClassifierIntegrationTest(test.TestCase): writer_cache.FileWriterCache.clear() shutil.rmtree(self._model_dir) - def _test_complete_flow( - self, train_input_fn, eval_input_fn, predict_input_fn, n_classes, - batch_size): - col = seq_fc.sequence_categorical_column_with_hash_bucket( - 'tokens', hash_bucket_size=10) - embed = fc.embedding_column(col, dimension=2) - feature_columns = [embed] - + def _test_complete_flow(self, feature_columns, train_input_fn, eval_input_fn, + predict_input_fn, 
n_classes, batch_size): cell_units = [4, 2] - est = rnn.RNNClassifier( - num_units=cell_units, - sequence_feature_columns=feature_columns, - n_classes=n_classes, - model_dir=self._model_dir) + est = self._create_estimator_fn(feature_columns, n_classes, cell_units, + self._model_dir) # TRAIN num_steps = 10 @@ -1026,10 +1022,10 @@ class RNNClassifierIntegrationTest(test.TestCase): self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) # EXPORT - feature_spec = { - 'tokens': parsing_ops.VarLenFeature(dtypes.string), - 'label': parsing_ops.FixedLenFeature([1], dtypes.int64), - } + feature_spec = parsing_utils.classifier_parse_example_spec( + feature_columns, + label_key='label', + label_dtype=dtypes.int64) serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( feature_spec) export_dir = est.export_savedmodel(tempfile.mkdtemp(), @@ -1069,7 +1065,13 @@ class RNNClassifierIntegrationTest(test.TestCase): batch_size=batch_size, shuffle=False) + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + feature_columns = [embed] + self._test_complete_flow( + feature_columns=feature_columns, train_input_fn=train_input_fn, eval_input_fn=eval_input_fn, predict_input_fn=predict_input_fn, @@ -1082,7 +1084,8 @@ class RNNClassifierIntegrationTest(test.TestCase): batch_size = 10 words = [b'dog', b'cat', b'bird', b'the', b'a', b'sat', b'flew', b'slept'] - serialized_examples = [] + _, examples_file = tempfile.mkstemp() + writer = python_io.TFRecordWriter(examples_file) for _ in range(batch_size): sequence_length = random.randint(1, len(words)) sentence = random.sample(words, sequence_length) @@ -1096,30 +1099,36 @@ class RNNClassifierIntegrationTest(test.TestCase): feature_pb2.Feature(int64_list=feature_pb2.Int64List( value=[label])), })) - serialized_examples.append(example.SerializeToString()) + writer.write(example.SerializeToString()) + writer.close() + + col = seq_fc.sequence_categorical_column_with_hash_bucket( + 'tokens', hash_bucket_size=10) + embed = fc.embedding_column(col, dimension=2) + feature_columns = [embed] + feature_spec = parsing_utils.classifier_parse_example_spec( + feature_columns, + label_key='label', + label_dtype=dtypes.int64) - feature_spec = { - 'tokens': parsing_ops.VarLenFeature(dtypes.string), - 'label': parsing_ops.FixedLenFeature([1], dtypes.int64), - } def _train_input_fn(): - features = parsing_ops.parse_example(serialized_examples, feature_spec) - labels = features.pop('label') - return features, labels + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec) + return dataset.map(lambda features: (features, features.pop('label'))) def _eval_input_fn(): - features = parsing_ops.parse_example( - input_lib.limit_epochs(serialized_examples, num_epochs=1), - feature_spec) - labels = features.pop('label') - return features, labels + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec, num_epochs=1) + return dataset.map(lambda features: (features, features.pop('label'))) def _predict_input_fn(): - features = parsing_ops.parse_example( - input_lib.limit_epochs(serialized_examples, num_epochs=1), - feature_spec) - features.pop('label') - return features, None + dataset = readers.make_batched_features_dataset( + examples_file, batch_size, feature_spec, num_epochs=1) + def features_fn(features): + features.pop('label') + return features + return dataset.map(features_fn) self._test_complete_flow( 
+ feature_columns=feature_columns, train_input_fn=_train_input_fn, eval_input_fn=_eval_input_fn, predict_input_fn=_predict_input_fn, @@ -1127,5 +1136,37 @@ class RNNClassifierIntegrationTest(test.TestCase): batch_size=batch_size) +def _rnn_classifier_fn(feature_columns, n_classes, cell_units, model_dir): + return rnn.RNNClassifier( + num_units=cell_units, + sequence_feature_columns=feature_columns, + n_classes=n_classes, + model_dir=model_dir) + + +class RNNClassifierIntegrationTest(BaseRNNClassificationIntegrationTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + BaseRNNClassificationIntegrationTest.__init__(self, _rnn_classifier_fn) + + +def _rnn_estimator_fn(feature_columns, n_classes, cell_units, model_dir): + return rnn.RNNEstimator( + head=head_lib.multi_class_head(n_classes=n_classes), + num_units=cell_units, + sequence_feature_columns=feature_columns, + model_dir=model_dir) + + +class RNNEstimatorIntegrationTest(BaseRNNClassificationIntegrationTest, + test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + BaseRNNClassificationIntegrationTest.__init__(self, _rnn_estimator_fn) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 0a648d5d40e..effec42f028 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -215,6 +215,7 @@ tf_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:sparse_tensor", ], + shard_count = 4, ) # Estimators tests diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc index 2a6c97e8b95..025534d540b 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc @@ -32,6 +32,7 @@ #include "tensorflow/core/lib/gtl/top_n.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 811fa89bc38..5cef4068ed1 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -107,7 +107,7 @@ class WALSModel(object): # the prep_gramian_op for row(column) can be run. worker_init_op = model.worker_init - # To be run once per integration sweep before the row(column) update + # To be run once per iteration sweep before the row(column) update # initialize ops can be run. Note that in the distributed training # situations, this should only be run by the chief trainer. All other # trainers need to block until this is done. @@ -436,7 +436,7 @@ class WALSModel(object): gramian: Variable storing the gramian calculated from the factors. Returns: - A op that updates the gramian with the calculated value from the factors. + An op that updates the gramian with the calculated value from the factors. 
""" partial_gramians = [] for f in factors: diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index 5d77bc77e12..e076631bc16 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -54,10 +54,10 @@ def _covariance(x, diag): diagonal matrix just the diagonal is returned. """ num_points = math_ops.to_float(array_ops.shape(x)[0]) - x -= math_ops.reduce_mean(x, 0, keep_dims=True) + x -= math_ops.reduce_mean(x, 0, keepdims=True) if diag: cov = math_ops.reduce_sum( - math_ops.square(x), 0, keep_dims=True) / (num_points - 1) + math_ops.square(x), 0, keepdims=True) / (num_points - 1) else: cov = math_ops.matmul(x, x, transpose_a=True) / (num_points - 1) return cov @@ -313,7 +313,7 @@ class GmmAlgorithm(object): # TODO(xavigonzalvo): look into alternatives to log for # reparametrization of variance parameters. det_expanded = math_ops.reduce_sum( - math_ops.log(self._covs + 1e-3), 1, keep_dims=True) + math_ops.log(self._covs + 1e-3), 1, keepdims=True) diff = shard - self._means x2 = math_ops.square(diff) cov_expanded = array_ops.expand_dims(1.0 / (self._covs + 1e-3), 2) @@ -351,7 +351,7 @@ class GmmAlgorithm(object): shard_id: id of current shard_id. """ self._prior_probs[shard_id] = math_ops.reduce_logsumexp( - self._probs[shard_id], axis=1, keep_dims=True) + self._probs[shard_id], axis=1, keepdims=True) def _define_expectation_operation(self, shard_id): # Shape broadcasting. @@ -375,7 +375,7 @@ class GmmAlgorithm(object): """ # Soft assignment of each data point to each of the two clusters. self._points_in_k[shard_id] = math_ops.reduce_sum( - self._w[shard_id], 0, keep_dims=True) + self._w[shard_id], 0, keepdims=True) # Partial means. w_mul_x = array_ops.expand_dims( math_ops.matmul( @@ -397,7 +397,7 @@ class GmmAlgorithm(object): # Compute the effective number of data points assigned to component k. with ops.control_dependencies(self._w): points_in_k = array_ops.squeeze( - math_ops.add_n(self._points_in_k), squeeze_dims=[0]) + math_ops.add_n(self._points_in_k), axis=[0]) # Update alpha. if 'w' in self._params: final_points_in_k = points_in_k / num_batches @@ -454,7 +454,7 @@ class GmmAlgorithm(object): for shard_id, prior_probs in enumerate(self._prior_probs): op.append(prior_probs + math_ops.log(self._w[shard_id])) self._scores = array_ops.squeeze( - math_ops.reduce_logsumexp(op, axis=2, keep_dims=True), axis=0) + math_ops.reduce_logsumexp(op, axis=2, keepdims=True), axis=0) def gmm(inp, diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index bfe338c9f9a..9ffdd3ba5e8 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -374,11 +374,11 @@ class KMeansClustering(estimator.Estimator): than `num_clusters`, a TensorFlow runtime error occurs. distance_metric: The distance metric used for clustering. One of: * `KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`: Euclidean distance - between vectors `u` and `v` is defined as `\\(||u - v||_2\\)` + between vectors `u` and `v` is defined as \\(||u - v||_2\\) which is the square root of the sum of the absolute squares of the elements' difference. * `KMeansClustering.COSINE_DISTANCE`: Cosine distance between vectors - `u` and `v` is defined as `\\(1 - (u . v) / (||u||_2 ||v||_2)\\)`. + `u` and `v` is defined as \\(1 - (u . v) / (||u||_2 ||v||_2)\\). 
random_seed: Python integer. Seed for PRNG used to initialize centers. use_mini_batch: A boolean specifying whether to use the mini-batch k-means algorithm. See explanation above. diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index 35341406a08..cca1a054193 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -28,7 +28,7 @@ #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" using tensorflow::strings::StrCat; diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h index a8d5a0dd83f..bf2aa755458 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h +++ b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h @@ -53,7 +53,7 @@ Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second, int32 samples_per_second, int32 channel_count, const std::vector& samples, string* output_data); -// Reads an video file using ffmpeg adn converts it into a RGB24 in uint8 +// Reads an video file using ffmpeg and converts it into a RGB24 in uint8 // [frames, height, width, 3]. The w, h, and frames are obtained from ffmpeg. Status ReadVideoFile(const string& filename, std::vector* output_data, uint32* width, uint32* height, uint32* frames); diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index b1c8ad49eaf..249debbdf6d 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -93,6 +93,7 @@ tf_kernel_library( ], deps = [ "//tensorflow/core:framework", + "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", ], alwayslink = 1, @@ -177,6 +178,8 @@ cuda_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_array_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", ], ) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index a52907f163f..10d1ecc738d 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -72,7 +72,9 @@ See the @{$python/contrib.framework} guide. 
@@variable @@VariableDeviceChooser @@convolutional_delta_orthogonal +@@convolutional_orthogonal_1d @@convolutional_orthogonal_2d +@@convolutional_orthogonal_3d @@zero_initializer @@load_checkpoint @@ -106,6 +108,7 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.framework.python.framework import * +from tensorflow.contrib.framework.python.framework import nest from tensorflow.contrib.framework.python.ops import * # pylint: enable=unused-import,wildcard-import @@ -116,10 +119,28 @@ from tensorflow.python.framework.smart_cond import smart_cond from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec +from tensorflow.python.ops.array_ops import broadcast_to from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal +from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d +from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['nest'] +_allowed_symbols = ['nest', 'broadcast_to'] +_nest_allowed_symbols = [ + 'assert_same_structure', + 'is_sequence', + 'flatten', + 'flatten_dict_items', + 'pack_sequence_as', + 'map_structure', + 'assert_shallow_structure', + 'flatten_up_to', + 'map_structure_up_to', + 'get_traverse_shallow_structure', + 'yield_flat_paths', + 'flatten_with_joined_string_paths', +] +remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols) remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc index 5bf6b675295..6ab3f460b36 100644 --- a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc +++ b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_var.h" namespace tensorflow { @@ -85,4 +86,74 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_KERNELS +template +class ZeroVarInitializer : public OpKernel { + public: + explicit ZeroVarInitializer(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_)); + } + + void Compute(OpKernelContext* ctx) override { + Var* variable = nullptr; + OP_REQUIRES_OK(ctx, LookupOrCreateResource( + ctx, HandleFromInput(ctx, 0), &variable, + [this, ctx](Var** var_ptr) { + *var_ptr = new Var(dtype_); + PersistentTensor unused; + Tensor* var_tensor = nullptr; + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + attr.set_nic_compatible(true); + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + dtype_, shape_, &unused, &var_tensor, attr)); + + functor::TensorSetZero()( + ctx->eigen_device(), + var_tensor->flat()); + + *(*var_ptr)->tensor() = *var_tensor; + + return Status::OK(); + })); + + core::ScopedUnref scoped(variable); + mutex_lock ml(*variable->mu()); + + OP_REQUIRES(ctx, !variable->is_initialized, + errors::InvalidArgument("input is already initialized")); + + variable->is_initialized = true; + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); + output->scalar()() = HandleFromInput(ctx, 0); + } + + private: + DataType dtype_; + TensorShape shape_; +}; + +#define REGISTER_CPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("ZeroVarInitializer") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype"), \ + ZeroVarInitializer); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); +#undef REGISTER_CPU_KERNELS + +#if GOOGLE_CUDA +#define REGISTER_GPU_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("ZeroVarInitializer") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("dtype") \ + .HostMemory("var"), \ + ZeroVarInitializer); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +#undef REGISTER_GPU_KERNELS +#endif // GOOGLE_CUDA + } // namespace tensorflow diff --git a/tensorflow/contrib/framework/ops/variable_ops.cc b/tensorflow/contrib/framework/ops/variable_ops.cc index 706134ba9a5..f6ee6cdb571 100644 --- a/tensorflow/contrib/framework/ops/variable_ops.cc +++ b/tensorflow/contrib/framework/ops/variable_ops.cc @@ -39,4 +39,33 @@ ref: Should be from a `Variable` node. output_ref:= Same as "ref". )doc"); +REGISTER_OP("ZeroVarInitializer") + .Input("var: resource") + .Output("output_var: resource") + .Attr("dtype: type") + .Attr("shape: shape") + .SetAllowsUninitializedInput() + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Scalar()); + DataType t; + TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t)); + PartialTensorShape p; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &p)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); + c->set_output_handle_shapes_and_types( + 0, std::vector{{s, t}}); + + return Status::OK(); + }) + .Doc(R"doc( +Initialize 'var' with all zeros. This op requires that the resource var is not +initialized. The var will first be allocated memory, then be filled with all +zeros. This op is intended to save memory during initialization, +if you use this op, you should not run initializer of the var. + +var: Should be a ResourceVariable. +output_var:= Same as "var". 
+)doc"); + } // namespace tensorflow diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index a2834b64893..af1b404cb51 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -48,7 +48,7 @@ class LocalVariabletest(test.TestCase): variables = variables_lib.local_variables() self.assertEquals(2, len(variables)) self.assertRaises(errors_impl.OpError, sess.run, variables) - variables_lib.initialize_variables(variables).run() + variables_lib.variables_initializer(variables).run() self.assertAllEqual(set([value0, value1]), set(sess.run(variables))) @@ -78,7 +78,6 @@ class AssertScalarIntTest(test.TestCase): [3, 4], dtype=dtypes.int32)) -@test_util.with_c_api class WithShapeTest(test.TestCase): def _assert_with_shape(self, tensor, expected_value, expected_shape, @@ -216,25 +215,18 @@ class WithShapeTest(test.TestCase): tensor_partial_shape.set_shape([None, 2]) for incompatible_shape in [[0], [1]]: - if ops._USE_C_API: - error_message = "Shapes must be equal rank, but are 2 and 1" - else: - error_message = r"Shapes \(\?, 2\) and \([01],\) are not compatible" self.assertRaisesRegexp( - ValueError, error_message, + ValueError, "Shapes must be equal rank, but are 2 and 1", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[1, 2, 1]]: self.assertRaisesRegexp(ValueError, "Dimensions must be equal", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[2, 1]]: - if ops._USE_C_API: - error_message = (r"Dimension 1 in both shapes must be equal, but are " - r"2 and 1. Shapes are \[\?,2\] and \[2,1\].") - else: - error_message = r"Shapes \(\?, 2\) and \(2, 1\) are not compatible" self.assertRaisesRegexp( - ValueError, error_message, + ValueError, + r"Dimension 1 in both shapes must be equal, but are 2 and 1. " + r"Shapes are \[\?,2\] and \[2,1\].", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) compatible_shape = [2, 2] diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py index bd764ed57a6..72835c3ad86 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -202,7 +202,7 @@ class CriticalSection(object): or lazy way that may cause a deadlock. ValueError: If `exclusive_resource_access` is not provided (is `True`) and another `CriticalSection` has an execution requesting the same - resources as in `*args`, `**kwargs`, and any additionaly captured + resources as in `*args`, `**kwargs`, and any additionally captured inputs in `fn`. Note, even if `exclusive_resource_access` is `True`, if another execution in another `CriticalSection` was created without `exclusive_resource_access=True`, a `ValueError` will be raised. 
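
For context on the `CriticalSection` docstring touched above (and the new test in the next file), a minimal usage sketch, assuming the contrib-era `tf.contrib.framework.CriticalSection` export and its `execute(fn)` signature; this is illustrative only:

```python
import tensorflow as tf

cs = tf.contrib.framework.CriticalSection()
v = tf.get_variable('v', initializer=0, use_resource=True)

def add_one():
  # All reads/writes of `v` issued inside execute() run under the critical
  # section's lock, so concurrent executions cannot interleave.
  return v.assign_add(1)

out = cs.execute(add_one)

with tf.Session() as sess:
  sess.run(v.initializer)
  print(sess.run(out))  # 1
```
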
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py index ba660295cb3..df7d7e9dae8 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework.python.ops import critical_section_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -330,6 +332,25 @@ class CriticalSectionTest(test.TestCase): self.evaluate(v.initializer) self.assertEqual(10, self.evaluate(out)) + @test_util.run_in_graph_and_eager_modes() + def testInsideFunction(self): + cs = critical_section_ops.CriticalSection() + v = resource_variable_ops.ResourceVariable(1) + def fn(): + return v.read_value() + + # map() creates a TensorFlow function. + ds = dataset_ops.Dataset.range(1).map(lambda _: cs.execute(fn)) + + def get_first(): + if context.executing_eagerly(): + return self.evaluate(ds.make_one_shot_iterator().get_next()) + itr = ds.make_initializable_iterator() + self.evaluate([v.initializer, itr.initializer]) + return self.evaluate(itr.get_next()) + + self.assertEqual(1, get_first()) + # TODO(ebrevdo): Re-enable once CriticalSection is in core. # # def testCriticalSectionAndExecuteOpSaverRoundTrip(self): diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 0754c3e0e30..40ae01bfcce 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import resource_loader from tensorflow.python.platform import tf_logging as logging @@ -82,7 +83,12 @@ def zero_initializer(ref, use_locking=True, name="zero_initializer"): """ loader.load_op_library( resource_loader.get_path_to_datafile("_variable_ops.so")) - return gen_variable_ops.zero_initializer(ref, name=name) + if resource_variable_ops.is_resource_variable(ref): + return gen_variable_ops.zero_var_initializer( + ref.handle, shape=ref.shape, dtype=ref.dtype, name=name) + else: + return gen_variable_ops.zero_initializer(ref, name=name) + @deprecated(None, "Please switch to tf.train.assert_global_step") def assert_global_step(global_step_tensor): diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 2f06df93acb..37ea6eb12ab 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -1284,6 +1284,32 @@ class ZeroInitializerOpTest(test.TestCase): [10, 20], dtype=dtype), use_init) +class ZeroVarInitializerOpTest(test.TestCase): + + def _testZeroVarInitializer(self, shape, initializer, use_init): + var = resource_variable_ops.ResourceVariable(initializer) + var_zero = variables_lib2.zero_initializer(var) + + with self.test_session() as sess: + with 
self.assertRaisesOpError('Error while reading resource variable'): + var.eval() + if use_init: + sess.run(var.initializer) + with self.assertRaisesOpError('input is already initialized'): + var_zero.eval() + self.assertAllClose(np.ones(shape), var.eval()) + else: + var_zero.eval() + self.assertAllClose(np.zeros(shape), var.eval()) + + def testZeroVarInitializer(self): + for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64): + for use_init in (False, True): + self._testZeroVarInitializer([10, 20], + array_ops.ones([10, 20], dtype=dtype), + use_init) + + class FilterVariablesTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 0e06575d96f..2458f7554af 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -247,7 +247,7 @@ class FusedConv2DBiasActivationOp : public OpKernel { }; #if GOOGLE_CUDA -namespace dnn = ::perftools::gputools::dnn; +namespace dnn = se::dnn; // A dummy type to group forward convolution autotune results together. struct ConvBiasActivationAutoTuneGroup { @@ -543,7 +543,8 @@ void LaunchFusedConv2DBiasActivationOp:: fused_conv_parameters, &algorithm_config)) { std::vector algorithms; CHECK(stream->parent()->GetConvolveAlgorithms( - fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), + fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo( + stream->parent()), &algorithms)); dnn::ProfileResult best_result; dnn::ProfileResult best_result_no_scratch; diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index e3fc6bf0f03..4092b320042 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -112,6 +112,7 @@ class GANEstimator(estimator.Estimator): generator_optimizer=None, discriminator_optimizer=None, get_hooks_fn=None, + get_eval_metric_ops_fn=None, add_summaries=None, use_loss_summaries=True, config=None): @@ -146,6 +147,9 @@ class GANEstimator(estimator.Estimator): list of hooks. These hooks are run on the generator and discriminator train ops, and can be used to implement the GAN training scheme. Defaults to `train.get_sequential_train_hooks()`. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. 
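
The new `get_eval_metric_ops_fn` argument documented above takes a `GANModel` and returns a dict of metric ops that are merged into the evaluation metrics. A minimal sketch, mirroring the custom metric used in the test further below (the function name here is illustrative):

```python
import tensorflow as tf

def get_eval_metric_ops(gan_model):
  # `gan_model` is a tfgan GANModel tuple; keys become metric/summary names
  # and values are the usual (value, update_op) pairs from tf.metrics.
  return {
      'mse_custom_metric': tf.metrics.mean_squared_error(
          gan_model.real_data, gan_model.generated_data),
  }

# Wired up via the new constructor argument, e.g.:
#   GANEstimator(..., get_eval_metric_ops_fn=get_eval_metric_ops)
```
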
@@ -160,7 +164,8 @@ class GANEstimator(estimator.Estimator): else discriminator_optimizer) gan_head = head_lib.gan_head( generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries, get_hooks_fn=get_hooks_fn) + use_loss_summaries, get_hooks_fn=get_hooks_fn, + get_eval_metric_ops_fn=get_eval_metric_ops_fn) return _gan_model_fn( features, labels, mode, generator_fn, discriminator_fn, gan_head, add_summaries) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 387a62bd741..955482599b3 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -194,6 +195,12 @@ class GANEstimatorIntegrationTest(test.TestCase): lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) return training.GradientDescentOptimizer(lr) + def get_metrics(gan_model): + return { + 'mse_custom_metric': metrics_lib.mean_squared_error( + gan_model.real_data, gan_model.generated_data) + } + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) est = estimator.GANEstimator( @@ -203,6 +210,7 @@ class GANEstimatorIntegrationTest(test.TestCase): discriminator_loss_fn=losses.wasserstein_discriminator_loss, generator_optimizer=gopt, discriminator_optimizer=dopt, + get_eval_metric_ops_fn=get_metrics, model_dir=self._model_dir) # TRAIN @@ -213,6 +221,9 @@ class GANEstimatorIntegrationTest(test.TestCase): scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) + self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], + scores['loss']) + self.assertIn('mse_custom_metric', six.iterkeys(scores)) # PREDICT predictions = np.array([x for x in est.predict(predict_input_fn)]) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 4750f94d9a6..513e451e7bf 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -25,17 +25,21 @@ from tensorflow.contrib.gan.python import train as tfgan_train from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.canned import head from tensorflow.python.framework import ops +from tensorflow.python.ops import metrics as metrics_lib __all__ = [ 'GANHead', 'gan_head', ] +def _summary_key(head_name, val): + return '%s/%s' % (val, head_name) if head_name else val + def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=tfgan_train.get_sequential_train_hooks(), - name=None): + get_eval_metric_ops_fn=None, name=None): """Creates a `GANHead`. 
Args: @@ -47,9 +51,12 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer: Same as `generator_optimizer`, but for the discriminator updates. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. - If `None`, uses defaults. - get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. + If `None`, uses defaults. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. @@ -62,6 +69,7 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer=discriminator_optimizer, use_loss_summaries=use_loss_summaries, get_hooks_fn=get_hooks_fn, + get_eval_metric_ops_fn=get_eval_metric_ops_fn, name=name) @@ -72,6 +80,7 @@ class GANHead(head._Head): # pylint: disable=protected-access generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=None, + get_eval_metric_ops_fn=None, name=None): """`Head` for GAN training. @@ -85,8 +94,11 @@ class GANHead(head._Head): # pylint: disable=protected-access discriminator updates. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. - get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. Defaults to `train.get_sequential_train_hooks()` + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. Defaults to `train.get_sequential_train_hooks()` + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ @@ -115,6 +127,8 @@ class GANHead(head._Head): # pylint: disable=protected-access self._generator_optimizer = generator_optimizer self._discriminator_optimizer = discriminator_optimizer self._get_hooks_fn = get_hooks_fn + self._get_eval_metric_ops_fn = get_eval_metric_ops_fn + self._name = name @property def name(self): @@ -184,13 +198,26 @@ class GANHead(head._Head): # pylint: disable=protected-access gan_loss = self.create_loss( features=None, mode=mode, logits=gan_model, labels=None) scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + with ops.name_scope(None, 'metrics', + [gan_loss.generator_loss, + gan_loss.discriminator_loss]): + eval_metric_ops = { + _summary_key(self._name, 'generator_loss'): + metrics_lib.mean(gan_loss.generator_loss), + _summary_key(self._name, 'discriminator_loss'): + metrics_lib.mean(gan_loss.discriminator_loss) + } + if self._get_eval_metric_ops_fn is not None: + custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('get_eval_metric_ops_fn must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) return model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.EVAL, predictions=gan_model.generated_data, loss=scalar_loss, - # TODO(joelshor): Add metrics. If head name provided, append it to - # metric keys. 
- eval_metric_ops={}) + eval_metric_ops=eval_metric_ops) elif mode == model_fn_lib.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError('train_op_fn can not be None.') diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 8168f005cd1..6587f1fc600 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -62,9 +62,14 @@ class GANHeadTest(test.TestCase): generator_loss_fn=dummy_loss, discriminator_loss_fn=dummy_loss, generator_optimizer=training.GradientDescentOptimizer(1.0), - discriminator_optimizer=training.GradientDescentOptimizer(1.0)) + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + get_eval_metric_ops_fn=self.get_metrics) self.assertTrue(isinstance(self.gan_head, head.GANHead)) + def get_metrics(self, gan_model): + self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel)) + return {} + def _test_modes_helper(self, mode): self.gan_head.create_estimator_spec( features=None, diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 47e51415fd9..d914f549457 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -488,25 +488,25 @@ def frechet_classifier_distance(real_images, The Frechet Inception distance. A floating-point scalar of the same type as the output of `classifier_fn`. """ - real_images_list = array_ops.split( real_images, num_or_size_splits=num_batches) generated_images_list = array_ops.split( generated_images, num_or_size_splits=num_batches) - imgs = array_ops.stack(real_images_list + generated_images_list) + real_imgs = array_ops.stack(real_images_list) + generated_imgs = array_ops.stack(generated_images_list) # Compute the activations using the memory-efficient `map_fn`. - activations = functional_ops.map_fn( - fn=classifier_fn, - elems=imgs, - parallel_iterations=1, - back_prop=False, - swap_memory=True, - name='RunClassifier') + def compute_activations(elems): + return functional_ops.map_fn(fn=classifier_fn, + elems=elems, + parallel_iterations=1, + back_prop=False, + swap_memory=True, + name='RunClassifier') - # Split the activations by the real and generated images. - real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0) + real_a = compute_activations(real_imgs) + gen_a = compute_activations(generated_imgs) # Ensure the activations have the right shapes. real_a = array_ops.concat(array_ops.unstack(real_a), 0) @@ -697,18 +697,20 @@ def frechet_classifier_distance_from_activations(real_activations, # Compute mean and covariance matrices of activations. 
m = math_ops.reduce_mean(real_activations, 0) m_w = math_ops.reduce_mean(generated_activations, 0) - num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) + num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0]) + num_examples_generated = math_ops.to_double( + array_ops.shape(generated_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T real_centered = real_activations - m sigma = math_ops.matmul( real_centered, real_centered, transpose_a=True) / ( - num_examples - 1) + num_examples_real - 1) gen_centered = generated_activations - m_w sigma_w = math_ops.matmul( gen_centered, gen_centered, transpose_a=True) / ( - num_examples - 1) + num_examples_generated - 1) # Find the Tr(sqrt(sigma sigma_w)) component of FID sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py index 4b10bc0f8e6..4b1105f6bd4 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -161,7 +161,7 @@ def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim): proj = random_ops.random_normal( [array_ops.shape(a)[1], random_projection_dim]) proj *= math_ops.rsqrt( - math_ops.reduce_sum(math_ops.square(proj), 0, keep_dims=True)) + math_ops.reduce_sum(math_ops.square(proj), 0, keepdims=True)) # Project both distributions and sort them. proj_a = math_ops.matmul(a, proj) proj_b = math_ops.matmul(b, proj) diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py index df71187fbd9..a9b8faa7126 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Miscellanous utilities for TFGAN code and examples.""" +"""Miscellaneous utilities for TFGAN code and examples.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py index f8b372546b6..650eab97a39 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py @@ -64,11 +64,11 @@ def _statistics(x, axes): y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x # Compute true mean while keeping the dims for proper broadcasting. 
- shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keep_dims=True)) + shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True)) - shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True) + shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True) mean = shifted_mean + shift - mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True) + mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True) mean = array_ops.squeeze(mean, axes) mean_squared = array_ops.squeeze(mean_squared, axes) diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 2889e937436..9f5fee45422 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -570,7 +570,7 @@ class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): 'predicted_distributions': self._predicted_distributions, } self._expected_loss = 1.61610 - self._expected_op_name = 'mutual_information_loss/mul' + self._expected_op_name = 'mutual_information_loss/mul_1' self._batch_size = 2 diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 73acd05b60a..6fa43059f31 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -710,7 +710,10 @@ def gan_train_ops( be used to train a generator/discriminator pair. """ if isinstance(model, namedtuples.CycleGANModel): - saved_params = locals() + # Get and store all arguments other than model and loss from locals. + # Contents of locals should not be modified, may not affect values. So make + # a copy. https://docs.python.org/2/library/functions.html#locals. + saved_params = dict(locals()) saved_params.pop('model', None) saved_params.pop('loss', None) kwargs = saved_params.pop('kwargs', {}) diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index 28f68cec8cc..94f522c04e5 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -155,7 +155,7 @@ class GdrRemoteRendezvous : public BaseRemoteRendezvous { } Device* dst_device; - Status s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device); + Status s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device); if (!s.ok()) { sess->worker_cache->ReleaseWorker(src_worker, rwi); done(s, Args(), recv_args, Tensor{}, false); diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index a320a3f232f..592d37b432e 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -677,7 +677,7 @@ def copy_with_input_replacements(sgv, replacement_ts, def _add_control_flow_ops(ops, control_ios): - """Complete `ops` so that the tranformed graph is valid. + """Complete `ops` so that the transformed graph is valid. Partially copying a graph can lead to a malformed graph. For instance, copying half of a while construct is likely to result in an invalid graph. 
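[Editor's note] As an aside on the dict(locals()) change in gan_train_ops above: a small standalone illustration (not from the patch, with a hypothetical function name) of why the mapping returned by locals() is copied before entries are popped from it.

    # Standalone illustration (not from the patch): take a snapshot of the
    # current arguments, then mutate the snapshot rather than locals() itself,
    # whose contents should not be modified.
    def gan_train_ops_like(model=None, loss=None, check_for_unused_update_ops=True):
      saved_params = dict(locals())      # copy of the argument mapping
      saved_params.pop('model', None)    # safe: only the copy is mutated
      saved_params.pop('loss', None)
      return saved_params                # remaining kwargs to forward

    print(gan_train_ops_like())  # {'check_for_unused_update_ops': True}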
diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index 60281951dda..66939fbb0f0 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -115,7 +115,7 @@ static void CheckOpsSupport(const GraphDef& graph_def, HexagonOpsDefinitions::getInstance(); LOG(INFO) << "Checking " << graph_def.node_size() << " nodes"; LOG(INFO) << "dump_all_nodes = " << dump_all_nodes - << ", dump_shape_and_tpye = " << dump_shape_and_type; + << ", dump_shape_and_type = " << dump_shape_and_type; std::unordered_set unsupported_ops; bool all_supported = true; diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index e982030bc89..f230d93da4a 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -17,7 +17,7 @@ ### API This module provides functions for image manipulation; currently, chrominance -transformas (including changing saturation and hue) in YIQ space and +transforms (including changing saturation and hue) in YIQ space and projective transforms (including rotation) are supported. ## Image Transformation `Ops` @@ -25,6 +25,8 @@ projective transforms (including rotation) are supported. @@angles_to_projective_transforms @@compose_transforms @@adjust_yiq_hsv +@@flat_transforms_to_matrices +@@matrices_to_flat_transforms @@random_yiq_hsv @@rotate @@transform @@ -58,6 +60,8 @@ from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_ from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_transforms from tensorflow.contrib.image.python.ops.image_ops import compose_transforms from tensorflow.contrib.image.python.ops.image_ops import connected_components +from tensorflow.contrib.image.python.ops.image_ops import flat_transforms_to_matrices +from tensorflow.contrib.image.python.ops.image_ops import matrices_to_flat_transforms from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform from tensorflow.contrib.image.python.ops.image_ops import translate diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc index 645abbf0b0e..bbb3a3b18fd 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc @@ -59,7 +59,7 @@ void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count, delta_h, scale_s, scale_v, tranformation_matrix.flat().data(), tranformation_matrix.flat().size()); // Call cuBlas C = A * B directly. - auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + auto no_transpose = se::blas::Transpose::kNoTranspose; auto a_ptr = AsDeviceMemory(input->flat().data(), input->flat().size()); auto b_ptr = AsDeviceMemory(tranformation_matrix.flat().data(), diff --git a/tensorflow/contrib/keras/api/keras/activations/__init__.py b/tensorflow/contrib/keras/api/keras/activations/__init__.py index d04838c218d..3f0184276f6 100644 --- a/tensorflow/contrib/keras/api/keras/activations/__init__.py +++ b/tensorflow/contrib/keras/api/keras/activations/__init__.py @@ -19,22 +19,22 @@ from __future__ import division from __future__ import print_function # Activation functions. 
-from tensorflow.python.keras._impl.keras.activations import elu -from tensorflow.python.keras._impl.keras.activations import hard_sigmoid -from tensorflow.python.keras._impl.keras.activations import linear -from tensorflow.python.keras._impl.keras.activations import relu -from tensorflow.python.keras._impl.keras.activations import selu -from tensorflow.python.keras._impl.keras.activations import sigmoid -from tensorflow.python.keras._impl.keras.activations import softmax -from tensorflow.python.keras._impl.keras.activations import softplus -from tensorflow.python.keras._impl.keras.activations import softsign -from tensorflow.python.keras._impl.keras.activations import tanh +from tensorflow.python.keras.activations import elu +from tensorflow.python.keras.activations import hard_sigmoid +from tensorflow.python.keras.activations import linear +from tensorflow.python.keras.activations import relu +from tensorflow.python.keras.activations import selu +from tensorflow.python.keras.activations import sigmoid +from tensorflow.python.keras.activations import softmax +from tensorflow.python.keras.activations import softplus +from tensorflow.python.keras.activations import softsign +from tensorflow.python.keras.activations import tanh # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.activations import deserialize -from tensorflow.python.keras._impl.keras.activations import serialize -from tensorflow.python.keras._impl.keras.activations import get +from tensorflow.python.keras.activations import deserialize +from tensorflow.python.keras.activations import serialize +from tensorflow.python.keras.activations import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py index abf8393ae45..6dfb5cab17c 100644 --- a/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/inception_v3/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.inception_v3 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.inception_v3 import InceptionV3 -from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input +from tensorflow.python.keras.applications.inception_v3 import decode_predictions +from tensorflow.python.keras.applications.inception_v3 import InceptionV3 +from tensorflow.python.keras.applications.inception_v3 import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py index b809e91193b..67306cc51e1 100644 --- a/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/mobilenet/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.mobilenet import decode_predictions -from tensorflow.python.keras._impl.keras.applications.mobilenet import MobileNet -from tensorflow.python.keras._impl.keras.applications.mobilenet import preprocess_input +from tensorflow.python.keras.applications.mobilenet import 
decode_predictions +from tensorflow.python.keras.applications.mobilenet import MobileNet +from tensorflow.python.keras.applications.mobilenet import preprocess_input del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py index 530805d150b..a25ff48b593 100644 --- a/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/resnet50/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.resnet50 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.resnet50 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.resnet50 import ResNet50 +from tensorflow.python.keras.applications.resnet50 import decode_predictions +from tensorflow.python.keras.applications.resnet50 import preprocess_input +from tensorflow.python.keras.applications.resnet50 import ResNet50 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py index 118361604bb..4964b1b7deb 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg16/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg16 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg16 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg16 import VGG16 +from tensorflow.python.keras.applications.vgg16 import decode_predictions +from tensorflow.python.keras.applications.vgg16 import preprocess_input +from tensorflow.python.keras.applications.vgg16 import VGG16 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py index cda52628f3c..afb3abebdd6 100644 --- a/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/vgg19/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.applications.vgg19 import decode_predictions -from tensorflow.python.keras._impl.keras.applications.vgg19 import preprocess_input -from tensorflow.python.keras._impl.keras.applications.vgg19 import VGG19 +from tensorflow.python.keras.applications.vgg19 import decode_predictions +from tensorflow.python.keras.applications.vgg19 import preprocess_input +from tensorflow.python.keras.applications.vgg19 import VGG19 del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py index ae9cd9cd18c..2e3335d02af 100644 --- a/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py +++ b/tensorflow/contrib/keras/api/keras/applications/xception/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from 
tensorflow.python.keras._impl.keras.applications.xception import decode_predictions -from tensorflow.python.keras._impl.keras.applications.xception import preprocess_input -from tensorflow.python.keras._impl.keras.applications.xception import Xception +from tensorflow.python.keras.applications.xception import decode_predictions +from tensorflow.python.keras.applications.xception import preprocess_input +from tensorflow.python.keras.applications.xception import Xception del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/backend/__init__.py b/tensorflow/contrib/keras/api/keras/backend/__init__.py index 10ef5a75852..a7553640142 100644 --- a/tensorflow/contrib/keras/api/keras/backend/__init__.py +++ b/tensorflow/contrib/keras/api/keras/backend/__init__.py @@ -19,144 +19,144 @@ from __future__ import division from __future__ import print_function # pylint: disable=redefined-builtin -from tensorflow.python.keras._impl.keras.backend import abs -from tensorflow.python.keras._impl.keras.backend import all -from tensorflow.python.keras._impl.keras.backend import any -from tensorflow.python.keras._impl.keras.backend import arange -from tensorflow.python.keras._impl.keras.backend import argmax -from tensorflow.python.keras._impl.keras.backend import argmin -from tensorflow.python.keras._impl.keras.backend import backend -from tensorflow.python.keras._impl.keras.backend import batch_dot -from tensorflow.python.keras._impl.keras.backend import batch_flatten -from tensorflow.python.keras._impl.keras.backend import batch_get_value -from tensorflow.python.keras._impl.keras.backend import batch_normalization -from tensorflow.python.keras._impl.keras.backend import batch_set_value -from tensorflow.python.keras._impl.keras.backend import bias_add -from tensorflow.python.keras._impl.keras.backend import binary_crossentropy -from tensorflow.python.keras._impl.keras.backend import cast -from tensorflow.python.keras._impl.keras.backend import cast_to_floatx -from tensorflow.python.keras._impl.keras.backend import categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import clear_session -from tensorflow.python.keras._impl.keras.backend import clip -from tensorflow.python.keras._impl.keras.backend import concatenate -from tensorflow.python.keras._impl.keras.backend import constant -from tensorflow.python.keras._impl.keras.backend import conv1d -from tensorflow.python.keras._impl.keras.backend import conv2d -from tensorflow.python.keras._impl.keras.backend import conv2d_transpose -from tensorflow.python.keras._impl.keras.backend import conv3d -from tensorflow.python.keras._impl.keras.backend import cos -from tensorflow.python.keras._impl.keras.backend import count_params -from tensorflow.python.keras._impl.keras.backend import ctc_batch_cost -from tensorflow.python.keras._impl.keras.backend import ctc_decode -from tensorflow.python.keras._impl.keras.backend import ctc_label_dense_to_sparse -from tensorflow.python.keras._impl.keras.backend import dot -from tensorflow.python.keras._impl.keras.backend import dropout -from tensorflow.python.keras._impl.keras.backend import dtype -from tensorflow.python.keras._impl.keras.backend import elu -from tensorflow.python.keras._impl.keras.backend import epsilon -from tensorflow.python.keras._impl.keras.backend import equal -from tensorflow.python.keras._impl.keras.backend import eval -from tensorflow.python.keras._impl.keras.backend import exp -from tensorflow.python.keras._impl.keras.backend import expand_dims -from 
tensorflow.python.keras._impl.keras.backend import eye -from tensorflow.python.keras._impl.keras.backend import flatten -from tensorflow.python.keras._impl.keras.backend import floatx -from tensorflow.python.keras._impl.keras.backend import foldl -from tensorflow.python.keras._impl.keras.backend import foldr -from tensorflow.python.keras._impl.keras.backend import function -from tensorflow.python.keras._impl.keras.backend import gather -from tensorflow.python.keras._impl.keras.backend import get_session -from tensorflow.python.keras._impl.keras.backend import get_uid -from tensorflow.python.keras._impl.keras.backend import get_value -from tensorflow.python.keras._impl.keras.backend import gradients -from tensorflow.python.keras._impl.keras.backend import greater -from tensorflow.python.keras._impl.keras.backend import greater_equal -from tensorflow.python.keras._impl.keras.backend import hard_sigmoid -from tensorflow.python.keras._impl.keras.backend import image_data_format -from tensorflow.python.keras._impl.keras.backend import in_test_phase -from tensorflow.python.keras._impl.keras.backend import in_top_k -from tensorflow.python.keras._impl.keras.backend import in_train_phase -from tensorflow.python.keras._impl.keras.backend import int_shape -from tensorflow.python.keras._impl.keras.backend import is_sparse -from tensorflow.python.keras._impl.keras.backend import l2_normalize -from tensorflow.python.keras._impl.keras.backend import learning_phase -from tensorflow.python.keras._impl.keras.backend import less -from tensorflow.python.keras._impl.keras.backend import less_equal -from tensorflow.python.keras._impl.keras.backend import log -from tensorflow.python.keras._impl.keras.backend import manual_variable_initialization -from tensorflow.python.keras._impl.keras.backend import map_fn -from tensorflow.python.keras._impl.keras.backend import max -from tensorflow.python.keras._impl.keras.backend import maximum -from tensorflow.python.keras._impl.keras.backend import mean -from tensorflow.python.keras._impl.keras.backend import min -from tensorflow.python.keras._impl.keras.backend import minimum -from tensorflow.python.keras._impl.keras.backend import moving_average_update -from tensorflow.python.keras._impl.keras.backend import name_scope -from tensorflow.python.keras._impl.keras.backend import ndim -from tensorflow.python.keras._impl.keras.backend import normalize_batch_in_training -from tensorflow.python.keras._impl.keras.backend import not_equal -from tensorflow.python.keras._impl.keras.backend import one_hot -from tensorflow.python.keras._impl.keras.backend import ones -from tensorflow.python.keras._impl.keras.backend import ones_like -from tensorflow.python.keras._impl.keras.backend import permute_dimensions -from tensorflow.python.keras._impl.keras.backend import placeholder -from tensorflow.python.keras._impl.keras.backend import pool2d -from tensorflow.python.keras._impl.keras.backend import pool3d -from tensorflow.python.keras._impl.keras.backend import pow -from tensorflow.python.keras._impl.keras.backend import print_tensor -from tensorflow.python.keras._impl.keras.backend import prod -from tensorflow.python.keras._impl.keras.backend import random_binomial -from tensorflow.python.keras._impl.keras.backend import random_normal -from tensorflow.python.keras._impl.keras.backend import random_normal_variable -from tensorflow.python.keras._impl.keras.backend import random_uniform -from tensorflow.python.keras._impl.keras.backend import random_uniform_variable -from 
tensorflow.python.keras._impl.keras.backend import relu -from tensorflow.python.keras._impl.keras.backend import repeat -from tensorflow.python.keras._impl.keras.backend import repeat_elements -from tensorflow.python.keras._impl.keras.backend import reset_uids -from tensorflow.python.keras._impl.keras.backend import reshape -from tensorflow.python.keras._impl.keras.backend import resize_images -from tensorflow.python.keras._impl.keras.backend import resize_volumes -from tensorflow.python.keras._impl.keras.backend import reverse -from tensorflow.python.keras._impl.keras.backend import rnn -from tensorflow.python.keras._impl.keras.backend import round -from tensorflow.python.keras._impl.keras.backend import separable_conv2d -from tensorflow.python.keras._impl.keras.backend import set_epsilon -from tensorflow.python.keras._impl.keras.backend import set_floatx -from tensorflow.python.keras._impl.keras.backend import set_image_data_format -from tensorflow.python.keras._impl.keras.backend import set_learning_phase -from tensorflow.python.keras._impl.keras.backend import set_session -from tensorflow.python.keras._impl.keras.backend import set_value -from tensorflow.python.keras._impl.keras.backend import shape -from tensorflow.python.keras._impl.keras.backend import sigmoid -from tensorflow.python.keras._impl.keras.backend import sign -from tensorflow.python.keras._impl.keras.backend import sin -from tensorflow.python.keras._impl.keras.backend import softmax -from tensorflow.python.keras._impl.keras.backend import softplus -from tensorflow.python.keras._impl.keras.backend import softsign -from tensorflow.python.keras._impl.keras.backend import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.backend import spatial_2d_padding -from tensorflow.python.keras._impl.keras.backend import spatial_3d_padding -from tensorflow.python.keras._impl.keras.backend import sqrt -from tensorflow.python.keras._impl.keras.backend import square -from tensorflow.python.keras._impl.keras.backend import squeeze -from tensorflow.python.keras._impl.keras.backend import stack -from tensorflow.python.keras._impl.keras.backend import std -from tensorflow.python.keras._impl.keras.backend import stop_gradient -from tensorflow.python.keras._impl.keras.backend import sum -from tensorflow.python.keras._impl.keras.backend import switch -from tensorflow.python.keras._impl.keras.backend import tanh -from tensorflow.python.keras._impl.keras.backend import temporal_padding -from tensorflow.python.keras._impl.keras.backend import to_dense -from tensorflow.python.keras._impl.keras.backend import transpose -from tensorflow.python.keras._impl.keras.backend import truncated_normal -from tensorflow.python.keras._impl.keras.backend import update -from tensorflow.python.keras._impl.keras.backend import update_add -from tensorflow.python.keras._impl.keras.backend import update_sub -from tensorflow.python.keras._impl.keras.backend import var -from tensorflow.python.keras._impl.keras.backend import variable -from tensorflow.python.keras._impl.keras.backend import zeros -from tensorflow.python.keras._impl.keras.backend import zeros_like +from tensorflow.python.keras.backend import abs +from tensorflow.python.keras.backend import all +from tensorflow.python.keras.backend import any +from tensorflow.python.keras.backend import arange +from tensorflow.python.keras.backend import argmax +from tensorflow.python.keras.backend import argmin +from tensorflow.python.keras.backend import backend +from tensorflow.python.keras.backend 
import batch_dot +from tensorflow.python.keras.backend import batch_flatten +from tensorflow.python.keras.backend import batch_get_value +from tensorflow.python.keras.backend import batch_normalization +from tensorflow.python.keras.backend import batch_set_value +from tensorflow.python.keras.backend import bias_add +from tensorflow.python.keras.backend import binary_crossentropy +from tensorflow.python.keras.backend import cast +from tensorflow.python.keras.backend import cast_to_floatx +from tensorflow.python.keras.backend import categorical_crossentropy +from tensorflow.python.keras.backend import clear_session +from tensorflow.python.keras.backend import clip +from tensorflow.python.keras.backend import concatenate +from tensorflow.python.keras.backend import constant +from tensorflow.python.keras.backend import conv1d +from tensorflow.python.keras.backend import conv2d +from tensorflow.python.keras.backend import conv2d_transpose +from tensorflow.python.keras.backend import conv3d +from tensorflow.python.keras.backend import cos +from tensorflow.python.keras.backend import count_params +from tensorflow.python.keras.backend import ctc_batch_cost +from tensorflow.python.keras.backend import ctc_decode +from tensorflow.python.keras.backend import ctc_label_dense_to_sparse +from tensorflow.python.keras.backend import dot +from tensorflow.python.keras.backend import dropout +from tensorflow.python.keras.backend import dtype +from tensorflow.python.keras.backend import elu +from tensorflow.python.keras.backend import epsilon +from tensorflow.python.keras.backend import equal +from tensorflow.python.keras.backend import eval +from tensorflow.python.keras.backend import exp +from tensorflow.python.keras.backend import expand_dims +from tensorflow.python.keras.backend import eye +from tensorflow.python.keras.backend import flatten +from tensorflow.python.keras.backend import floatx +from tensorflow.python.keras.backend import foldl +from tensorflow.python.keras.backend import foldr +from tensorflow.python.keras.backend import function +from tensorflow.python.keras.backend import gather +from tensorflow.python.keras.backend import get_session +from tensorflow.python.keras.backend import get_uid +from tensorflow.python.keras.backend import get_value +from tensorflow.python.keras.backend import gradients +from tensorflow.python.keras.backend import greater +from tensorflow.python.keras.backend import greater_equal +from tensorflow.python.keras.backend import hard_sigmoid +from tensorflow.python.keras.backend import image_data_format +from tensorflow.python.keras.backend import in_test_phase +from tensorflow.python.keras.backend import in_top_k +from tensorflow.python.keras.backend import in_train_phase +from tensorflow.python.keras.backend import int_shape +from tensorflow.python.keras.backend import is_sparse +from tensorflow.python.keras.backend import l2_normalize +from tensorflow.python.keras.backend import learning_phase +from tensorflow.python.keras.backend import less +from tensorflow.python.keras.backend import less_equal +from tensorflow.python.keras.backend import log +from tensorflow.python.keras.backend import manual_variable_initialization +from tensorflow.python.keras.backend import map_fn +from tensorflow.python.keras.backend import max +from tensorflow.python.keras.backend import maximum +from tensorflow.python.keras.backend import mean +from tensorflow.python.keras.backend import min +from tensorflow.python.keras.backend import minimum +from tensorflow.python.keras.backend import 
moving_average_update +from tensorflow.python.keras.backend import name_scope +from tensorflow.python.keras.backend import ndim +from tensorflow.python.keras.backend import normalize_batch_in_training +from tensorflow.python.keras.backend import not_equal +from tensorflow.python.keras.backend import one_hot +from tensorflow.python.keras.backend import ones +from tensorflow.python.keras.backend import ones_like +from tensorflow.python.keras.backend import permute_dimensions +from tensorflow.python.keras.backend import placeholder +from tensorflow.python.keras.backend import pool2d +from tensorflow.python.keras.backend import pool3d +from tensorflow.python.keras.backend import pow +from tensorflow.python.keras.backend import print_tensor +from tensorflow.python.keras.backend import prod +from tensorflow.python.keras.backend import random_binomial +from tensorflow.python.keras.backend import random_normal +from tensorflow.python.keras.backend import random_normal_variable +from tensorflow.python.keras.backend import random_uniform +from tensorflow.python.keras.backend import random_uniform_variable +from tensorflow.python.keras.backend import relu +from tensorflow.python.keras.backend import repeat +from tensorflow.python.keras.backend import repeat_elements +from tensorflow.python.keras.backend import reset_uids +from tensorflow.python.keras.backend import reshape +from tensorflow.python.keras.backend import resize_images +from tensorflow.python.keras.backend import resize_volumes +from tensorflow.python.keras.backend import reverse +from tensorflow.python.keras.backend import rnn +from tensorflow.python.keras.backend import round +from tensorflow.python.keras.backend import separable_conv2d +from tensorflow.python.keras.backend import set_epsilon +from tensorflow.python.keras.backend import set_floatx +from tensorflow.python.keras.backend import set_image_data_format +from tensorflow.python.keras.backend import set_learning_phase +from tensorflow.python.keras.backend import set_session +from tensorflow.python.keras.backend import set_value +from tensorflow.python.keras.backend import shape +from tensorflow.python.keras.backend import sigmoid +from tensorflow.python.keras.backend import sign +from tensorflow.python.keras.backend import sin +from tensorflow.python.keras.backend import softmax +from tensorflow.python.keras.backend import softplus +from tensorflow.python.keras.backend import softsign +from tensorflow.python.keras.backend import sparse_categorical_crossentropy +from tensorflow.python.keras.backend import spatial_2d_padding +from tensorflow.python.keras.backend import spatial_3d_padding +from tensorflow.python.keras.backend import sqrt +from tensorflow.python.keras.backend import square +from tensorflow.python.keras.backend import squeeze +from tensorflow.python.keras.backend import stack +from tensorflow.python.keras.backend import std +from tensorflow.python.keras.backend import stop_gradient +from tensorflow.python.keras.backend import sum +from tensorflow.python.keras.backend import switch +from tensorflow.python.keras.backend import tanh +from tensorflow.python.keras.backend import temporal_padding +from tensorflow.python.keras.backend import to_dense +from tensorflow.python.keras.backend import transpose +from tensorflow.python.keras.backend import truncated_normal +from tensorflow.python.keras.backend import update +from tensorflow.python.keras.backend import update_add +from tensorflow.python.keras.backend import update_sub +from tensorflow.python.keras.backend import var 
+from tensorflow.python.keras.backend import variable +from tensorflow.python.keras.backend import zeros +from tensorflow.python.keras.backend import zeros_like del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py index 2d884790ddb..10e05f2969b 100644 --- a/tensorflow/contrib/keras/api/keras/callbacks/__init__.py +++ b/tensorflow/contrib/keras/api/keras/callbacks/__init__.py @@ -18,19 +18,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.callbacks import BaseLogger -from tensorflow.python.keras._impl.keras.callbacks import Callback -from tensorflow.python.keras._impl.keras.callbacks import CSVLogger -from tensorflow.python.keras._impl.keras.callbacks import EarlyStopping -from tensorflow.python.keras._impl.keras.callbacks import History -from tensorflow.python.keras._impl.keras.callbacks import LambdaCallback -from tensorflow.python.keras._impl.keras.callbacks import LearningRateScheduler -from tensorflow.python.keras._impl.keras.callbacks import ModelCheckpoint -from tensorflow.python.keras._impl.keras.callbacks import ProgbarLogger -from tensorflow.python.keras._impl.keras.callbacks import ReduceLROnPlateau -from tensorflow.python.keras._impl.keras.callbacks import RemoteMonitor -from tensorflow.python.keras._impl.keras.callbacks import TensorBoard -from tensorflow.python.keras._impl.keras.callbacks import TerminateOnNaN +from tensorflow.python.keras.callbacks import BaseLogger +from tensorflow.python.keras.callbacks import Callback +from tensorflow.python.keras.callbacks import CSVLogger +from tensorflow.python.keras.callbacks import EarlyStopping +from tensorflow.python.keras.callbacks import History +from tensorflow.python.keras.callbacks import LambdaCallback +from tensorflow.python.keras.callbacks import LearningRateScheduler +from tensorflow.python.keras.callbacks import ModelCheckpoint +from tensorflow.python.keras.callbacks import ProgbarLogger +from tensorflow.python.keras.callbacks import ReduceLROnPlateau +from tensorflow.python.keras.callbacks import RemoteMonitor +from tensorflow.python.keras.callbacks import TensorBoard +from tensorflow.python.keras.callbacks import TerminateOnNaN del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/constraints/__init__.py b/tensorflow/contrib/keras/api/keras/constraints/__init__.py index 152606d8ebb..08debf974ec 100644 --- a/tensorflow/contrib/keras/api/keras/constraints/__init__.py +++ b/tensorflow/contrib/keras/api/keras/constraints/__init__.py @@ -19,21 +19,21 @@ from __future__ import division from __future__ import print_function # Constraints functions / callable classes. 
-from tensorflow.python.keras._impl.keras.constraints import Constraint -from tensorflow.python.keras._impl.keras.constraints import max_norm -from tensorflow.python.keras._impl.keras.constraints import MaxNorm -from tensorflow.python.keras._impl.keras.constraints import min_max_norm -from tensorflow.python.keras._impl.keras.constraints import MinMaxNorm -from tensorflow.python.keras._impl.keras.constraints import non_neg -from tensorflow.python.keras._impl.keras.constraints import NonNeg -from tensorflow.python.keras._impl.keras.constraints import unit_norm -from tensorflow.python.keras._impl.keras.constraints import UnitNorm +from tensorflow.python.keras.constraints import Constraint +from tensorflow.python.keras.constraints import max_norm +from tensorflow.python.keras.constraints import MaxNorm +from tensorflow.python.keras.constraints import min_max_norm +from tensorflow.python.keras.constraints import MinMaxNorm +from tensorflow.python.keras.constraints import non_neg +from tensorflow.python.keras.constraints import NonNeg +from tensorflow.python.keras.constraints import unit_norm +from tensorflow.python.keras.constraints import UnitNorm # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.constraints import deserialize -from tensorflow.python.keras._impl.keras.constraints import serialize -from tensorflow.python.keras._impl.keras.constraints import get +from tensorflow.python.keras.constraints import deserialize +from tensorflow.python.keras.constraints import serialize +from tensorflow.python.keras.constraints import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py index b5371a03fd5..a5a6fdab445 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/boston_housing/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.boston_housing import load_data +from tensorflow.python.keras.datasets.boston_housing import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py index 68d3eb789ea..e74e5f347df 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar10/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar10 import load_data +from tensorflow.python.keras.datasets.cifar10 import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py index ca937426733..8f5753a6360 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/cifar100/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.cifar100 import load_data +from tensorflow.python.keras.datasets.cifar100 import load_data del absolute_import del division diff --git 
a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py index 1c6396d2d32..bd6ec4b8dfb 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/imdb/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.imdb import get_word_index -from tensorflow.python.keras._impl.keras.datasets.imdb import load_data +from tensorflow.python.keras.datasets.imdb import get_word_index +from tensorflow.python.keras.datasets.imdb import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py index 364255f3387..f61145655bd 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/mnist/__init__.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.mnist import load_data +from tensorflow.python.keras.datasets.mnist import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py index bb6791a344a..ade31f4ea9c 100644 --- a/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py +++ b/tensorflow/contrib/keras/api/keras/datasets/reuters/__init__.py @@ -18,8 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.datasets.reuters import get_word_index -from tensorflow.python.keras._impl.keras.datasets.reuters import load_data +from tensorflow.python.keras.datasets.reuters import get_word_index +from tensorflow.python.keras.datasets.reuters import load_data del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/initializers/__init__.py b/tensorflow/contrib/keras/api/keras/initializers/__init__.py index 6b1fcfd2d95..c6bdc4f0dac 100644 --- a/tensorflow/contrib/keras/api/keras/initializers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/initializers/__init__.py @@ -19,30 +19,30 @@ from __future__ import division from __future__ import print_function # Initializer functions / callable classes. 
-from tensorflow.python.keras._impl.keras.initializers import Constant -from tensorflow.python.keras._impl.keras.initializers import Identity -from tensorflow.python.keras._impl.keras.initializers import Initializer -from tensorflow.python.keras._impl.keras.initializers import Ones -from tensorflow.python.keras._impl.keras.initializers import Orthogonal -from tensorflow.python.keras._impl.keras.initializers import RandomNormal -from tensorflow.python.keras._impl.keras.initializers import RandomUniform -from tensorflow.python.keras._impl.keras.initializers import TruncatedNormal -from tensorflow.python.keras._impl.keras.initializers import VarianceScaling -from tensorflow.python.keras._impl.keras.initializers import Zeros +from tensorflow.python.keras.initializers import Constant +from tensorflow.python.keras.initializers import Identity +from tensorflow.python.keras.initializers import Initializer +from tensorflow.python.keras.initializers import Ones +from tensorflow.python.keras.initializers import Orthogonal +from tensorflow.python.keras.initializers import RandomNormal +from tensorflow.python.keras.initializers import RandomUniform +from tensorflow.python.keras.initializers import TruncatedNormal +from tensorflow.python.keras.initializers import VarianceScaling +from tensorflow.python.keras.initializers import Zeros # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.initializers import glorot_normal -from tensorflow.python.keras._impl.keras.initializers import glorot_uniform -from tensorflow.python.keras._impl.keras.initializers import he_normal -from tensorflow.python.keras._impl.keras.initializers import he_uniform -from tensorflow.python.keras._impl.keras.initializers import lecun_normal -from tensorflow.python.keras._impl.keras.initializers import lecun_uniform +from tensorflow.python.keras.initializers import glorot_normal +from tensorflow.python.keras.initializers import glorot_uniform +from tensorflow.python.keras.initializers import he_normal +from tensorflow.python.keras.initializers import he_uniform +from tensorflow.python.keras.initializers import lecun_normal +from tensorflow.python.keras.initializers import lecun_uniform # Auxiliary utils. -from tensorflow.python.keras._impl.keras.initializers import deserialize -from tensorflow.python.keras._impl.keras.initializers import serialize -from tensorflow.python.keras._impl.keras.initializers import get +from tensorflow.python.keras.initializers import deserialize +from tensorflow.python.keras.initializers import serialize +from tensorflow.python.keras.initializers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/layers/__init__.py b/tensorflow/contrib/keras/api/keras/layers/__init__.py index acf0a5e1799..938c881fcbe 100644 --- a/tensorflow/contrib/keras/api/keras/layers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/layers/__init__.py @@ -20,128 +20,128 @@ from __future__ import print_function # Generic layers. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.engine import Input -from tensorflow.python.keras._impl.keras.engine import InputLayer -from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine import Layer +from tensorflow.python.keras.engine import Input +from tensorflow.python.keras.engine import InputLayer +from tensorflow.python.keras.engine import InputSpec +from tensorflow.python.keras.engine import Layer # Advanced activations. 
-from tensorflow.python.keras._impl.keras.layers.advanced_activations import LeakyReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import PReLU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ELU -from tensorflow.python.keras._impl.keras.layers.advanced_activations import ThresholdedReLU +from tensorflow.python.keras.layers.advanced_activations import LeakyReLU +from tensorflow.python.keras.layers.advanced_activations import PReLU +from tensorflow.python.keras.layers.advanced_activations import ELU +from tensorflow.python.keras.layers.advanced_activations import ThresholdedReLU # Convolution layers. -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Conv3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConv2D +from tensorflow.python.keras.layers.convolutional import Conv1D +from tensorflow.python.keras.layers.convolutional import Conv2D +from tensorflow.python.keras.layers.convolutional import Conv3D +from tensorflow.python.keras.layers.convolutional import Conv2DTranspose +from tensorflow.python.keras.layers.convolutional import Conv3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConv2D # Convolution layer aliases. -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution2DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import Convolution3DTranspose -from tensorflow.python.keras._impl.keras.layers.convolutional import SeparableConvolution2D +from tensorflow.python.keras.layers.convolutional import Convolution1D +from tensorflow.python.keras.layers.convolutional import Convolution2D +from tensorflow.python.keras.layers.convolutional import Convolution3D +from tensorflow.python.keras.layers.convolutional import Convolution2DTranspose +from tensorflow.python.keras.layers.convolutional import Convolution3DTranspose +from tensorflow.python.keras.layers.convolutional import SeparableConvolution2D # Image processing layers. 
-from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling1D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling2D -from tensorflow.python.keras._impl.keras.layers.convolutional import UpSampling3D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding1D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding2D -from tensorflow.python.keras._impl.keras.layers.convolutional import ZeroPadding3D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping1D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping2D -from tensorflow.python.keras._impl.keras.layers.convolutional import Cropping3D +from tensorflow.python.keras.layers.convolutional import UpSampling1D +from tensorflow.python.keras.layers.convolutional import UpSampling2D +from tensorflow.python.keras.layers.convolutional import UpSampling3D +from tensorflow.python.keras.layers.convolutional import ZeroPadding1D +from tensorflow.python.keras.layers.convolutional import ZeroPadding2D +from tensorflow.python.keras.layers.convolutional import ZeroPadding3D +from tensorflow.python.keras.layers.convolutional import Cropping1D +from tensorflow.python.keras.layers.convolutional import Cropping2D +from tensorflow.python.keras.layers.convolutional import Cropping3D # Convolutional-recurrent layers. -from tensorflow.python.keras._impl.keras.layers.convolutional_recurrent import ConvLSTM2D +from tensorflow.python.keras.layers.convolutional_recurrent import ConvLSTM2D # Core layers. -from tensorflow.python.keras._impl.keras.layers.core import Masking -from tensorflow.python.keras._impl.keras.layers.core import Dropout -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout1D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout2D -from tensorflow.python.keras._impl.keras.layers.core import SpatialDropout3D -from tensorflow.python.keras._impl.keras.layers.core import Activation -from tensorflow.python.keras._impl.keras.layers.core import Reshape -from tensorflow.python.keras._impl.keras.layers.core import Permute -from tensorflow.python.keras._impl.keras.layers.core import Flatten -from tensorflow.python.keras._impl.keras.layers.core import RepeatVector -from tensorflow.python.keras._impl.keras.layers.core import Lambda -from tensorflow.python.keras._impl.keras.layers.core import Dense -from tensorflow.python.keras._impl.keras.layers.core import ActivityRegularization +from tensorflow.python.keras.layers.core import Masking +from tensorflow.python.keras.layers.core import Dropout +from tensorflow.python.keras.layers.core import SpatialDropout1D +from tensorflow.python.keras.layers.core import SpatialDropout2D +from tensorflow.python.keras.layers.core import SpatialDropout3D +from tensorflow.python.keras.layers.core import Activation +from tensorflow.python.keras.layers.core import Reshape +from tensorflow.python.keras.layers.core import Permute +from tensorflow.python.keras.layers.core import Flatten +from tensorflow.python.keras.layers.core import RepeatVector +from tensorflow.python.keras.layers.core import Lambda +from tensorflow.python.keras.layers.core import Dense +from tensorflow.python.keras.layers.core import ActivityRegularization # Embedding layers. -from tensorflow.python.keras._impl.keras.layers.embeddings import Embedding +from tensorflow.python.keras.layers.embeddings import Embedding # Locally-connected layers. 
-from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected1D -from tensorflow.python.keras._impl.keras.layers.local import LocallyConnected2D +from tensorflow.python.keras.layers.local import LocallyConnected1D +from tensorflow.python.keras.layers.local import LocallyConnected2D # Merge layers. -from tensorflow.python.keras._impl.keras.layers.merge import Add -from tensorflow.python.keras._impl.keras.layers.merge import Multiply -from tensorflow.python.keras._impl.keras.layers.merge import Average -from tensorflow.python.keras._impl.keras.layers.merge import Maximum -from tensorflow.python.keras._impl.keras.layers.merge import Concatenate -from tensorflow.python.keras._impl.keras.layers.merge import Dot -from tensorflow.python.keras._impl.keras.layers.merge import add -from tensorflow.python.keras._impl.keras.layers.merge import multiply -from tensorflow.python.keras._impl.keras.layers.merge import average -from tensorflow.python.keras._impl.keras.layers.merge import maximum -from tensorflow.python.keras._impl.keras.layers.merge import concatenate -from tensorflow.python.keras._impl.keras.layers.merge import dot +from tensorflow.python.keras.layers.merge import Add +from tensorflow.python.keras.layers.merge import Multiply +from tensorflow.python.keras.layers.merge import Average +from tensorflow.python.keras.layers.merge import Maximum +from tensorflow.python.keras.layers.merge import Concatenate +from tensorflow.python.keras.layers.merge import Dot +from tensorflow.python.keras.layers.merge import add +from tensorflow.python.keras.layers.merge import multiply +from tensorflow.python.keras.layers.merge import average +from tensorflow.python.keras.layers.merge import maximum +from tensorflow.python.keras.layers.merge import concatenate +from tensorflow.python.keras.layers.merge import dot # Noise layers. -from tensorflow.python.keras._impl.keras.layers.noise import AlphaDropout -from tensorflow.python.keras._impl.keras.layers.noise import GaussianNoise -from tensorflow.python.keras._impl.keras.layers.noise import GaussianDropout +from tensorflow.python.keras.layers.noise import AlphaDropout +from tensorflow.python.keras.layers.noise import GaussianNoise +from tensorflow.python.keras.layers.noise import GaussianDropout # Normalization layers. -from tensorflow.python.keras._impl.keras.layers.normalization import BatchNormalization +from tensorflow.python.keras.layers.normalization import BatchNormalization # Pooling layers. 
-from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAveragePooling3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPooling3D +from tensorflow.python.keras.layers.pooling import MaxPooling1D +from tensorflow.python.keras.layers.pooling import MaxPooling2D +from tensorflow.python.keras.layers.pooling import MaxPooling3D +from tensorflow.python.keras.layers.pooling import AveragePooling1D +from tensorflow.python.keras.layers.pooling import AveragePooling2D +from tensorflow.python.keras.layers.pooling import AveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling1D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling2D +from tensorflow.python.keras.layers.pooling import GlobalAveragePooling3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPooling3D # Pooling layer aliases. 
-from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import MaxPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import AvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalAvgPool3D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool1D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D -from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D +from tensorflow.python.keras.layers.pooling import MaxPool1D +from tensorflow.python.keras.layers.pooling import MaxPool2D +from tensorflow.python.keras.layers.pooling import MaxPool3D +from tensorflow.python.keras.layers.pooling import AvgPool1D +from tensorflow.python.keras.layers.pooling import AvgPool2D +from tensorflow.python.keras.layers.pooling import AvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool1D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool2D +from tensorflow.python.keras.layers.pooling import GlobalAvgPool3D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool1D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool2D +from tensorflow.python.keras.layers.pooling import GlobalMaxPool3D # Recurrent layers. -from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN -from tensorflow.python.keras._impl.keras.layers.recurrent import GRU -from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM +from tensorflow.python.keras.layers.recurrent import SimpleRNN +from tensorflow.python.keras.layers.recurrent import GRU +from tensorflow.python.keras.layers.recurrent import LSTM # Wrapper functions -from tensorflow.python.keras._impl.keras.layers.wrappers import Wrapper -from tensorflow.python.keras._impl.keras.layers.wrappers import Bidirectional -from tensorflow.python.keras._impl.keras.layers.wrappers import TimeDistributed +from tensorflow.python.keras.layers.wrappers import Wrapper +from tensorflow.python.keras.layers.wrappers import Bidirectional +from tensorflow.python.keras.layers.wrappers import TimeDistributed del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/losses/__init__.py b/tensorflow/contrib/keras/api/keras/losses/__init__.py index 66721b694f5..c4476a7bbd5 100644 --- a/tensorflow/contrib/keras/api/keras/losses/__init__.py +++ b/tensorflow/contrib/keras/api/keras/losses/__init__.py @@ -19,26 +19,26 @@ from __future__ import division from __future__ import print_function # Loss functions. 
-from tensorflow.python.keras._impl.keras.losses import binary_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import categorical_hinge -from tensorflow.python.keras._impl.keras.losses import cosine_proximity -from tensorflow.python.keras._impl.keras.losses import hinge -from tensorflow.python.keras._impl.keras.losses import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.losses import logcosh -from tensorflow.python.keras._impl.keras.losses import mean_absolute_error -from tensorflow.python.keras._impl.keras.losses import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_error -from tensorflow.python.keras._impl.keras.losses import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.losses import poisson -from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.losses import squared_hinge +from tensorflow.python.keras.losses import binary_crossentropy +from tensorflow.python.keras.losses import categorical_crossentropy +from tensorflow.python.keras.losses import categorical_hinge +from tensorflow.python.keras.losses import cosine_proximity +from tensorflow.python.keras.losses import hinge +from tensorflow.python.keras.losses import kullback_leibler_divergence +from tensorflow.python.keras.losses import logcosh +from tensorflow.python.keras.losses import mean_absolute_error +from tensorflow.python.keras.losses import mean_absolute_percentage_error +from tensorflow.python.keras.losses import mean_squared_error +from tensorflow.python.keras.losses import mean_squared_logarithmic_error +from tensorflow.python.keras.losses import poisson +from tensorflow.python.keras.losses import sparse_categorical_crossentropy +from tensorflow.python.keras.losses import squared_hinge # Auxiliary utils. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.losses import deserialize -from tensorflow.python.keras._impl.keras.losses import serialize -from tensorflow.python.keras._impl.keras.losses import get +from tensorflow.python.keras.losses import deserialize +from tensorflow.python.keras.losses import serialize +from tensorflow.python.keras.losses import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/metrics/__init__.py b/tensorflow/contrib/keras/api/keras/metrics/__init__.py index 59faf037bce..7317fdb52c5 100644 --- a/tensorflow/contrib/keras/api/keras/metrics/__init__.py +++ b/tensorflow/contrib/keras/api/keras/metrics/__init__.py @@ -19,28 +19,28 @@ from __future__ import division from __future__ import print_function # Metrics functions. 
-from tensorflow.python.keras._impl.keras.metrics import binary_accuracy -from tensorflow.python.keras._impl.keras.metrics import binary_crossentropy -from tensorflow.python.keras._impl.keras.metrics import categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import cosine_proximity -from tensorflow.python.keras._impl.keras.metrics import hinge -from tensorflow.python.keras._impl.keras.metrics import kullback_leibler_divergence -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_error -from tensorflow.python.keras._impl.keras.metrics import mean_absolute_percentage_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_error -from tensorflow.python.keras._impl.keras.metrics import mean_squared_logarithmic_error -from tensorflow.python.keras._impl.keras.metrics import poisson -from tensorflow.python.keras._impl.keras.metrics import sparse_categorical_crossentropy -from tensorflow.python.keras._impl.keras.metrics import sparse_top_k_categorical_accuracy -from tensorflow.python.keras._impl.keras.metrics import squared_hinge -from tensorflow.python.keras._impl.keras.metrics import top_k_categorical_accuracy +from tensorflow.python.keras.metrics import binary_accuracy +from tensorflow.python.keras.metrics import binary_crossentropy +from tensorflow.python.keras.metrics import categorical_accuracy +from tensorflow.python.keras.metrics import categorical_crossentropy +from tensorflow.python.keras.metrics import cosine_proximity +from tensorflow.python.keras.metrics import hinge +from tensorflow.python.keras.metrics import kullback_leibler_divergence +from tensorflow.python.keras.metrics import mean_absolute_error +from tensorflow.python.keras.metrics import mean_absolute_percentage_error +from tensorflow.python.keras.metrics import mean_squared_error +from tensorflow.python.keras.metrics import mean_squared_logarithmic_error +from tensorflow.python.keras.metrics import poisson +from tensorflow.python.keras.metrics import sparse_categorical_crossentropy +from tensorflow.python.keras.metrics import sparse_top_k_categorical_accuracy +from tensorflow.python.keras.metrics import squared_hinge +from tensorflow.python.keras.metrics import top_k_categorical_accuracy # Auxiliary utils. 
# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.metrics import deserialize -from tensorflow.python.keras._impl.keras.metrics import serialize -from tensorflow.python.keras._impl.keras.metrics import get +from tensorflow.python.keras.metrics import deserialize +from tensorflow.python.keras.metrics import serialize +from tensorflow.python.keras.metrics import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/models/__init__.py b/tensorflow/contrib/keras/api/keras/models/__init__.py index 2fb4ac0960d..3a196984cd8 100644 --- a/tensorflow/contrib/keras/api/keras/models/__init__.py +++ b/tensorflow/contrib/keras/api/keras/models/__init__.py @@ -18,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.models import load_model -from tensorflow.python.keras._impl.keras.models import Model -from tensorflow.python.keras._impl.keras.models import model_from_config -from tensorflow.python.keras._impl.keras.models import model_from_json -from tensorflow.python.keras._impl.keras.models import model_from_yaml -from tensorflow.python.keras._impl.keras.models import save_model -from tensorflow.python.keras._impl.keras.models import Sequential +from tensorflow.python.keras.models import load_model +from tensorflow.python.keras.models import Model +from tensorflow.python.keras.models import model_from_config +from tensorflow.python.keras.models import model_from_json +from tensorflow.python.keras.models import model_from_yaml +from tensorflow.python.keras.models import save_model +from tensorflow.python.keras.models import Sequential del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py index 44f47bc47f4..4849a067479 100644 --- a/tensorflow/contrib/keras/api/keras/optimizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/optimizers/__init__.py @@ -19,20 +19,20 @@ from __future__ import division from __future__ import print_function # Optimizer classes. -from tensorflow.python.keras._impl.keras.optimizers import Adadelta -from tensorflow.python.keras._impl.keras.optimizers import Adagrad -from tensorflow.python.keras._impl.keras.optimizers import Adam -from tensorflow.python.keras._impl.keras.optimizers import Adamax -from tensorflow.python.keras._impl.keras.optimizers import Nadam -from tensorflow.python.keras._impl.keras.optimizers import Optimizer -from tensorflow.python.keras._impl.keras.optimizers import RMSprop -from tensorflow.python.keras._impl.keras.optimizers import SGD +from tensorflow.python.keras.optimizers import Adadelta +from tensorflow.python.keras.optimizers import Adagrad +from tensorflow.python.keras.optimizers import Adam +from tensorflow.python.keras.optimizers import Adamax +from tensorflow.python.keras.optimizers import Nadam +from tensorflow.python.keras.optimizers import Optimizer +from tensorflow.python.keras.optimizers import RMSprop +from tensorflow.python.keras.optimizers import SGD # Auxiliary utils. 
# pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.optimizers import deserialize -from tensorflow.python.keras._impl.keras.optimizers import serialize -from tensorflow.python.keras._impl.keras.optimizers import get +from tensorflow.python.keras.optimizers import deserialize +from tensorflow.python.keras.optimizers import serialize +from tensorflow.python.keras.optimizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py index b96e7675527..1f9e82b41bf 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/image/__init__.py @@ -18,20 +18,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.image import apply_transform -from tensorflow.python.keras._impl.keras.preprocessing.image import array_to_img -from tensorflow.python.keras._impl.keras.preprocessing.image import DirectoryIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import flip_axis -from tensorflow.python.keras._impl.keras.preprocessing.image import ImageDataGenerator -from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array -from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator -from tensorflow.python.keras._impl.keras.preprocessing.image import load_img -from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator -from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear -from tensorflow.python.keras._impl.keras.preprocessing.image import random_shift -from tensorflow.python.keras._impl.keras.preprocessing.image import random_zoom +from tensorflow.python.keras.preprocessing.image import apply_transform +from tensorflow.python.keras.preprocessing.image import array_to_img +from tensorflow.python.keras.preprocessing.image import DirectoryIterator +from tensorflow.python.keras.preprocessing.image import flip_axis +from tensorflow.python.keras.preprocessing.image import ImageDataGenerator +from tensorflow.python.keras.preprocessing.image import img_to_array +from tensorflow.python.keras.preprocessing.image import Iterator +from tensorflow.python.keras.preprocessing.image import load_img +from tensorflow.python.keras.preprocessing.image import NumpyArrayIterator +from tensorflow.python.keras.preprocessing.image import random_channel_shift +from tensorflow.python.keras.preprocessing.image import random_rotation +from tensorflow.python.keras.preprocessing.image import random_shear +from tensorflow.python.keras.preprocessing.image import random_shift +from tensorflow.python.keras.preprocessing.image import random_zoom del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py index 112f6af5e58..9a93b6fb57f 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/sequence/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from 
__future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table -from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences -from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams +from tensorflow.python.keras.preprocessing.sequence import make_sampling_table +from tensorflow.python.keras.preprocessing.sequence import pad_sequences +from tensorflow.python.keras.preprocessing.sequence import skipgrams del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py index 5bf1a2fb21d..86386a9b676 100644 --- a/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py +++ b/tensorflow/contrib/keras/api/keras/preprocessing/text/__init__.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot -from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence -from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer +from tensorflow.python.keras.preprocessing.text import one_hot +from tensorflow.python.keras.preprocessing.text import text_to_word_sequence +from tensorflow.python.keras.preprocessing.text import Tokenizer del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py index 3e707ccab57..d668e39c09c 100644 --- a/tensorflow/contrib/keras/api/keras/regularizers/__init__.py +++ b/tensorflow/contrib/keras/api/keras/regularizers/__init__.py @@ -19,19 +19,19 @@ from __future__ import division from __future__ import print_function # Regularizer functions / callable classes. -from tensorflow.python.keras._impl.keras.regularizers import L1L2 -from tensorflow.python.keras._impl.keras.regularizers import Regularizer +from tensorflow.python.keras.regularizers import L1L2 +from tensorflow.python.keras.regularizers import Regularizer # Functional interface. # pylint: disable=g-bad-import-order -from tensorflow.python.keras._impl.keras.regularizers import l1 -from tensorflow.python.keras._impl.keras.regularizers import l2 -from tensorflow.python.keras._impl.keras.regularizers import l1_l2 +from tensorflow.python.keras.regularizers import l1 +from tensorflow.python.keras.regularizers import l2 +from tensorflow.python.keras.regularizers import l1_l2 # Auxiliary utils. 
-from tensorflow.python.keras._impl.keras.regularizers import deserialize -from tensorflow.python.keras._impl.keras.regularizers import serialize -from tensorflow.python.keras._impl.keras.regularizers import get +from tensorflow.python.keras.regularizers import deserialize +from tensorflow.python.keras.regularizers import serialize +from tensorflow.python.keras.regularizers import get del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/utils/__init__.py b/tensorflow/contrib/keras/api/keras/utils/__init__.py index a7c2179fe7a..47cd01b924f 100644 --- a/tensorflow/contrib/keras/api/keras/utils/__init__.py +++ b/tensorflow/contrib/keras/api/keras/utils/__init__.py @@ -18,21 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer -from tensorflow.python.keras._impl.keras.utils.data_utils import get_file -from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence -from tensorflow.python.keras._impl.keras.utils.data_utils import SequenceEnqueuer -from tensorflow.python.keras._impl.keras.utils.generic_utils import custom_object_scope -from tensorflow.python.keras._impl.keras.utils.generic_utils import CustomObjectScope -from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object -from tensorflow.python.keras._impl.keras.utils.generic_utils import get_custom_objects -from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar -from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object -from tensorflow.python.keras._impl.keras.utils.io_utils import HDF5Matrix -from tensorflow.python.keras._impl.keras.utils.layer_utils import convert_all_kernels_in_model -from tensorflow.python.keras._impl.keras.utils.np_utils import normalize -from tensorflow.python.keras._impl.keras.utils.np_utils import to_categorical -from tensorflow.python.keras._impl.keras.utils.vis_utils import plot_model +from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer +from tensorflow.python.keras.utils.data_utils import get_file +from tensorflow.python.keras.utils.data_utils import Sequence +from tensorflow.python.keras.utils.data_utils import SequenceEnqueuer +from tensorflow.python.keras.utils.generic_utils import custom_object_scope +from tensorflow.python.keras.utils.generic_utils import CustomObjectScope +from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras.utils.generic_utils import get_custom_objects +from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.keras.utils.generic_utils import serialize_keras_object +from tensorflow.python.keras.utils.io_utils import HDF5Matrix +from tensorflow.python.keras.utils.layer_utils import convert_all_kernels_in_model +from tensorflow.python.keras.utils.np_utils import normalize +from tensorflow.python.keras.utils.np_utils import to_categorical +from tensorflow.python.keras.utils.vis_utils import plot_model del absolute_import del division diff --git a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py index a46f859273e..c4b7aa765c2 100644 --- a/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py +++ b/tensorflow/contrib/keras/api/keras/wrappers/scikit_learn/__init__.py @@ -18,8 +18,8 @@ from __future__ import 
absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasClassifier -from tensorflow.python.keras._impl.keras.wrappers.scikit_learn import KerasRegressor +from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier +from tensorflow.python.keras.wrappers.scikit_learn import KerasRegressor del absolute_import del division diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py index 91929184a2e..2ff4d41d75f 100644 --- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py +++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py @@ -31,7 +31,7 @@ from tensorflow.python.platform import googletest def _inner_product(x, y): - """Inner product between tensors x and y. + r"""Inner product between tensors x and y. The input tensors are assumed to be in ROW representation, that is, the method returns \\(x * y^T\\). @@ -131,10 +131,6 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): mapped_dim = 5000 stddev = 5.0 - # TODO(sibyl-vie3Poto): Reduce test's running time before moving to third_party. One - # possible way to speed the test up is to compute both the approximate and - # the exact kernel matrix directly using matrix operations instead of - # computing the values for each pair of points separately. points_shape = [1, input_dim] points = [ random_ops.random_uniform(shape=points_shape, maxval=1.0) diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py index e8e3353091d..d6b1a61b716 100644 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -223,26 +223,26 @@ def minimize_loss_single_machine(loss, (cov_update_thunks, inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() - with tf.device(device): - train_op = optimizer.minimize(loss, global_step=g_step) - def make_update_op(update_thunks): - update_op = [thunk() for thunk in update_thunks] - return tf.group(*update_op) + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) cov_update_op = make_update_op(cov_update_thunks) - with tf.control_dependencies([train_op, cov_update_op]): + with tf.control_dependencies([cov_update_op]): inverse_op = tf.cond( - tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + tf.equal(tf.mod(g_step, _INVERT_EVERY), 0), lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + with tf.device(device): + train_op = optimizer.minimize(loss, global_step=g_step) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): global_step_, loss_, accuracy_, _ = sess.run( - [g_step, loss, accuracy, inverse_op]) + [g_step, loss, accuracy, train_op]) - if (global_step_ + 1) % _INVERT_EVERY == 0: + if global_step_ % _INVERT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) @@ -325,7 +325,7 @@ def distributed_grads_only_and_ops_chief_worker( All workers perform gradient computation. Chief worker applies gradient after averaging the gradients obtained from all the workers. All workers block - execution untill the update is applied. Chief worker runs covariance and + execution until the update is applied. 
Chief worker runs covariance and inverse update ops. Covariance and inverse matrices are placed on parameter servers in a round robin manner. For further details on synchronous distributed optimization check `tf.train.SyncReplicasOptimizer`. @@ -357,24 +357,25 @@ def distributed_grads_only_and_ops_chief_worker( task_id, num_worker_tasks, num_ps_tasks, layer_collection) (cov_update_thunks, inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() - train_op = sync_optimizer.minimize(loss, global_step=global_step) tf.logging.info("Starting training.") hooks = [sync_optimizer.make_session_run_hook(is_chief)] def make_update_op(update_thunks): - update_op = [thunk() for thunk in update_thunks] - return tf.group(*update_op) + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) if is_chief: cov_update_op = make_update_op(cov_update_thunks) - with tf.control_dependencies([train_op, cov_update_op]): - update_op = tf.cond( - tf.equal(tf.mod(global_step + 1, invert_every), 0), + with tf.control_dependencies([cov_update_op]): + inverse_op = tf.cond( + tf.equal(tf.mod(global_step, invert_every), 0), lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = sync_optimizer.minimize(loss, global_step=global_step) else: - update_op = train_op + train_op = sync_optimizer.minimize(loss, global_step=global_step) with tf.train.MonitoredTrainingSession( master=master, @@ -384,7 +385,7 @@ def distributed_grads_only_and_ops_chief_worker( stop_grace_period_secs=0) as sess: while not sess.should_stop(): global_step_, loss_, accuracy_, _ = sess.run( - [global_step, loss, accuracy, update_op]) + [global_step, loss, accuracy, train_op]) tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) return accuracy_ @@ -577,25 +578,25 @@ def train_mnist_multitower(data_dir, num_epochs, num_towers, (cov_update_thunks, inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() - train_op = optimizer.minimize(loss, global_step=g_step) - def make_update_op(update_thunks): - update_op = [thunk() for thunk in update_thunks] - return tf.group(*update_op) + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) cov_update_op = make_update_op(cov_update_thunks) - with tf.control_dependencies([train_op, cov_update_op]): + with tf.control_dependencies([cov_update_op]): inverse_op = tf.cond( - tf.equal(tf.mod(g_step + 1, _INVERT_EVERY), 0), + tf.equal(tf.mod(g_step, _INVERT_EVERY), 0), lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=g_step) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): global_step_, loss_, accuracy_, _ = sess.run( - [g_step, loss, accuracy, inverse_op]) + [g_step, loss, accuracy, train_op]) - if (global_step_ + 1) % _INVERT_EVERY == 0: + if global_step_ % _INVERT_EVERY == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, loss_, accuracy_) diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py index 87eed03888c..ea2b252a057 100644 --- a/tensorflow/contrib/kfac/examples/mlp.py +++ b/tensorflow/contrib/kfac/examples/mlp.py @@ -105,18 +105,21 @@ def build_model(examples, labels, num_labels, layer_collection): return loss, accuracy -def minimize(loss, accuracy, layer_collection, session_config=None): +def 
minimize(loss, accuracy, layer_collection, num_towers, session_config=None): """Minimize 'loss' with KfacOptimizer. Args: loss: 0-D Tensor. Loss to be minimized. accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. layer_collection: LayerCollection instance. Describes layers in model. + num_towers: int. Number of CPUs to split minibatch across. session_config: tf.ConfigProto. Configuration for tf.Session(). Returns: accuracy of classifier on final minibatch. """ + devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers)) + # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2 # every 10k iterations. tf.logging.info("Building KFAC Optimizer.") @@ -125,27 +128,38 @@ def minimize(loss, accuracy, layer_collection, session_config=None): learning_rate=tf.train.exponential_decay( 0.00002, global_step, 10000, 0.5, staircase=True), cov_ema_decay=0.95, - damping=0.0001, + damping=0.0005, layer_collection=layer_collection, - momentum=0.99) - train_op = optimizer.minimize(loss, global_step=global_step) + momentum=0.99, + placement_strategy="round_robin", + cov_devices=devices, + inv_devices=devices) + + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt + # once that gets moved over? Could still leave more advanced examples as they + # are (e.g. train_mnist_estimator in this file) + + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + # We update the inverses only every 20 iterations. + inverse_op = tf.cond( + tf.equal(tf.mod(global_step, 100), 0), + lambda: make_update_op(inv_update_thunks), tf.no_op) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=global_step) tf.logging.info("Starting training.") with tf.train.MonitoredTrainingSession(config=session_config) as sess: while not sess.should_stop(): - # K-FAC has 3 primary ops, - # - train_op: Update the weights with the minibatch's gradient. - # - cov_update_op: Update statistics used for building K-FAC's - # preconditioner matrix. - # - inv_update_op: Update preconditioner matrix using statistics. - # - # The first 2 of these are cheap and should be done with each step. The - # latter is more expensive, and should be updated ~100 iterations. - global_step_, loss_, accuracy_, _, _ = sess.run( - [global_step, loss, accuracy, train_op, optimizer.cov_update_op]) - - if global_step_ % 100 == 0: - sess.run(optimizer.inv_update_op) + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, train_op]) if global_step_ % 100 == 0: tf.logging.info("global_step: %d | loss: %f | accuracy: %f", @@ -180,7 +194,7 @@ def train_mnist(data_dir, num_epochs, use_fake_data=False): loss, accuracy = build_model(examples, labels, 10, layer_collection) # Fit model. 
- minimize(loss, accuracy, layer_collection) + minimize(loss, accuracy, layer_collection, 1) def train_mnist_multitower(data_dir, @@ -238,7 +252,8 @@ def train_mnist_multitower(data_dir, "CPU": num_towers }) return minimize( - loss, accuracy, layer_collection, session_config=session_config) + loss, accuracy, layer_collection, num_towers, + session_config=session_config) def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): @@ -298,13 +313,26 @@ def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): layer_collection=layer_collection, momentum=0.99) + (cov_update_thunks, + inv_update_thunks) = optimizer.make_vars_and_create_op_thunks() + + def make_update_op(update_thunks): + update_ops = [thunk() for thunk in update_thunks] + return tf.group(*update_ops) + + def make_batch_executed_op(update_thunks, batch_size=1): + return tf.group(*tf.contrib.kfac.utils.batch_execute( + global_step, update_thunks, batch_size=batch_size)) + # Run cov_update_op every step. Run 1 inv_update_ops per step. - cov_update_op = optimizer.cov_update_op - inv_update_op = tf.group( - tf.contrib.kfac.utils.batch_execute( - global_step, optimizer.inv_update_thunks, batch_size=1)) - with tf.control_dependencies([cov_update_op, inv_update_op]): - train_op = optimizer.minimize(loss, global_step=global_step) + cov_update_op = make_update_op(cov_update_thunks) + with tf.control_dependencies([cov_update_op]): + # But make sure to execute all the inverse ops on the first step + inverse_op = tf.cond(tf.equal(global_step, 0), + lambda: make_update_op(inv_update_thunks), + lambda: make_batch_executed_op(inv_update_thunks)) + with tf.control_dependencies([inverse_op]): + train_op = optimizer.minimize(loss, global_step=global_step) # Print metrics every 5 sec. 
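The convnet.py and mlp.py edits above repeatedly apply the same restructuring: make_vars_and_create_op_thunks() separates variable creation from op creation, and the covariance update, the periodic inverse update, and the parameter update are chained with control dependencies so that a single sess.run(train_op) drives all three, instead of fetching inverse_op or extra update ops explicitly. A condensed sketch of that pattern, not part of the patch; _INVERT_EVERY stands in for whichever update period a given example defines:

    import tensorflow as tf

    _INVERT_EVERY = 10  # placeholder period; the examples above define their own


    def build_kfac_train_op(optimizer, loss, global_step):
      """Chains cov update -> periodic inverse update -> minimize, as above."""
      (cov_update_thunks,
       inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()

      def make_update_op(update_thunks):
        # Each thunk builds one update op; group them into a single op.
        return tf.group(*[thunk() for thunk in update_thunks])

      cov_update_op = make_update_op(cov_update_thunks)
      with tf.control_dependencies([cov_update_op]):
        # Refresh the inverses only every _INVERT_EVERY steps, otherwise no-op.
        inverse_op = tf.cond(
            tf.equal(tf.mod(global_step, _INVERT_EVERY), 0),
            lambda: make_update_op(inv_update_thunks), tf.no_op)
        with tf.control_dependencies([inverse_op]):
          return optimizer.minimize(loss, global_step=global_step)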
hooks = [ diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py index 6de775cc799..adecda71666 100644 --- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py +++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py @@ -157,7 +157,7 @@ class ConvNetTest(tf.test.TestCase): num_ps_tasks=0, master="", data_dir=None, - num_epochs=1, + num_epochs=2, op_strategy="chief_worker", use_fake_data=True) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 2477d2bfc12..6e4a8d71baa 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -58,6 +58,7 @@ py_test( deps = [ "//tensorflow/contrib/kfac/python/ops:fisher_blocks", "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:linear_operator", "//tensorflow/contrib/kfac/python/ops:utils", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -96,6 +97,7 @@ py_test( srcs = ["optimizer_test.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", "//tensorflow/contrib/kfac/python/ops:layer_collection", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index f22dbcf2156..0e65d419a31 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -81,7 +81,7 @@ class EstimatorTest(test.TestCase): damping=0.2, layer_collection=self.layer_collection ) - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() # Check that we throw an error if we don't include registered variables, # i.e. 
self.weights @@ -91,7 +91,7 @@ class EstimatorTest(test.TestCase): cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection) - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) def testVariableWrongNumberOfUses(self, mock_uses): @@ -101,7 +101,7 @@ class EstimatorTest(test.TestCase): cov_ema_decay=0.1, damping=0.2, layer_collection=self.layer_collection) - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def testInvalidEstimationMode(self): with self.assertRaises(ValueError): @@ -111,7 +111,7 @@ class EstimatorTest(test.TestCase): damping=0.2, layer_collection=self.layer_collection, estimation_mode="not_a_real_mode") - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def testGradientsModeBuild(self): with self._graph.as_default(): @@ -121,7 +121,7 @@ class EstimatorTest(test.TestCase): damping=0.2, layer_collection=self.layer_collection, estimation_mode="gradients") - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def testEmpiricalModeBuild(self): with self._graph.as_default(): @@ -131,7 +131,7 @@ class EstimatorTest(test.TestCase): damping=0.2, layer_collection=self.layer_collection, estimation_mode="empirical") - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def testCurvaturePropModeBuild(self): with self._graph.as_default(): @@ -141,7 +141,7 @@ class EstimatorTest(test.TestCase): damping=0.2, layer_collection=self.layer_collection, estimation_mode="curvature_prop") - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def testExactModeBuild(self): with self._graph.as_default(): @@ -151,7 +151,7 @@ class EstimatorTest(test.TestCase): damping=0.2, layer_collection=self.layer_collection, estimation_mode="exact") - est.make_ops_and_vars() + est.make_vars_and_create_op_thunks() def test_cov_update_thunks(self): """Ensures covariance update ops run once per global_step.""" @@ -215,8 +215,11 @@ class EstimatorTest(test.TestCase): inv_devices=["/cpu:{}".format(i) for i in range(2)]) # Construct an op that executes one covariance update per step. 
- (cov_update_ops, _, inv_update_ops, _, _, - _) = fisher_estimator.make_ops_and_vars(scope="test") + (cov_update_thunks, + inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks( + scope="test") + cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) + inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) self.assertEqual(cov_update_ops[0].device, "/device:CPU:0") self.assertEqual(cov_update_ops[1].device, "/device:CPU:1") self.assertEqual(inv_update_ops[0].device, "/device:CPU:0") diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index 6eda6c31e34..86ec7a095af 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -21,7 +21,9 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb +from tensorflow.contrib.kfac.python.ops import fisher_factors as ff from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import linear_operator as lo from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -34,6 +36,19 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. +ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + +# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our +# inverse is something other than the identity" are actually broken. They never +# run the covariance update ops and so the inverse actually is the identity +# (possible plus the damping term, which would still make it a multiple of the +# identity). + + def _make_psd(dim): """Constructs a PSD matrix of the given dimension.""" mat = np.ones((dim, dim), dtype=np.float32) @@ -46,8 +61,9 @@ class UtilsTest(test.TestCase): def testComputePiTracenorm(self): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) - left_factor = array_ops.diag([1., 2., 0., 1.]) - right_factor = array_ops.ones([2., 2.]) + diag = ops.convert_to_tensor([1., 2., 0., 1.]) + left_factor = lo.LinearOperatorDiag(diag) + right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2])) # pi is the sqrt of the left trace norm divided by the right trace norm pi = fb.compute_pi_tracenorm(left_factor, right_factor) @@ -245,7 +261,6 @@ class NaiveDiagonalFBTest(test.TestCase): full = sess.run(block.full_fisher_block()) explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) - self.assertAllClose(output_flat, explicit) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index 2a3592c53fd..fad47cd02f3 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -35,6 +35,13 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. 
+ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + + def make_damping_func(damping): return fb._package_func(lambda: damping, damping) @@ -70,18 +77,6 @@ class FisherFactorTestingDummy(ff.FisherFactor): def get_cov(self): return NotImplementedError - def left_multiply(self, x, damping): - return NotImplementedError - - def right_multiply(self, x, damping): - return NotImplementedError - - def left_multiply_matpower(self, x, exp, damping): - return NotImplementedError - - def right_multiply_matpower(self, x, exp, damping): - return NotImplementedError - def instantiate_inv_variables(self): return NotImplementedError @@ -91,14 +86,35 @@ class FisherFactorTestingDummy(ff.FisherFactor): def _get_data_device(self): raise NotImplementedError + def register_matpower(self, exp, damping_func): + raise NotImplementedError -class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): - """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. + def register_cholesky(self, damping_func): + raise NotImplementedError + + def register_cholesky_inverse(self, damping_func): + raise NotImplementedError + + def get_matpower(self, exp, damping_func): + raise NotImplementedError + + def get_cholesky(self, damping_func): + raise NotImplementedError + + def get_cholesky_inverse(self, damping_func): + raise NotImplementedError + + def get_cov_as_linear_operator(self): + raise NotImplementedError + + +class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor): + """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor. """ def __init__(self, shape): self._shape = shape - super(InverseProvidingFactorTestingDummy, self).__init__() + super(DenseSquareMatrixFactorTestingDummy, self).__init__() @property def _var_scope(self): @@ -230,13 +246,13 @@ class FisherFactorTest(test.TestCase): self.assertEqual(0, len(factor.make_inverse_update_ops())) -class InverseProvidingFactorTest(test.TestCase): +class DenseSquareMatrixFactorTest(test.TestCase): def testRegisterDampedInverse(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) shape = [2, 2] - factor = InverseProvidingFactorTestingDummy(shape) + factor = DenseSquareMatrixFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' damping_funcs = [make_damping_func(0.1), @@ -248,22 +264,25 @@ class InverseProvidingFactorTest(test.TestCase): factor.instantiate_inv_variables() - inv = factor.get_inverse(damping_funcs[0]) - self.assertEqual(inv, factor.get_inverse(damping_funcs[1])) - self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2])) - self.assertEqual(factor.get_inverse(damping_funcs[2]), - factor.get_inverse(damping_funcs[3])) + inv = factor.get_inverse(damping_funcs[0]).to_dense() + self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense()) + self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense()) + self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(), + factor.get_inverse(damping_funcs[3]).to_dense()) factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - self.assertEqual(set([inv, factor.get_inverse(damping_funcs[2])]), - set(factor_vars)) + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + self.assertEqual(set([inv, + factor.get_inverse(damping_funcs[2]).to_dense()]), + set(factor_tensors)) self.assertEqual(shape, inv.get_shape()) def testRegisterMatpower(self): with tf_ops.Graph().as_default(): 
random_seed.set_random_seed(200) shape = [3, 3] - factor = InverseProvidingFactorTestingDummy(shape) + factor = DenseSquareMatrixFactorTestingDummy(shape) factor_var_scope = 'dummy/a_b_c' # TODO(b/74201126): Change to using the same func for both once @@ -278,10 +297,13 @@ class InverseProvidingFactorTest(test.TestCase): factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, factor_var_scope) - matpower1 = factor.get_matpower(-0.5, damping_func_1) - matpower2 = factor.get_matpower(2, damping_func_2) - self.assertEqual(set([matpower1, matpower2]), set(factor_vars)) + factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars) + + matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense() + matpower2 = factor.get_matpower(2, damping_func_2).to_dense() + + self.assertEqual(set([matpower1, matpower2]), set(factor_tensors)) self.assertEqual(shape, matpower1.get_shape()) self.assertEqual(shape, matpower2.get_shape()) @@ -297,7 +319,7 @@ class InverseProvidingFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[1., 2.], [3., 4.]]) - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) damping_funcs = [] @@ -316,7 +338,8 @@ class InverseProvidingFactorTest(test.TestCase): sess.run(ops) for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD): # The inverse op will assign the damped inverse of cov to the inv var. - new_invs.append(sess.run(factor.get_inverse(damping_funcs[i]))) + new_invs.append( + sess.run(factor.get_inverse(damping_funcs[i]).to_dense())) # We want to see that the new invs are all different from each other. for i in range(len(new_invs)): @@ -328,7 +351,7 @@ class InverseProvidingFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[6., 2.], [2., 4.]]) - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power damping = 0.5 @@ -341,7 +364,7 @@ class InverseProvidingFactorTest(test.TestCase): sess.run(tf_variables.global_variables_initializer()) sess.run(ops[0]) - matpower = sess.run(factor.get_matpower(exp, damping_func)) + matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense()) matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp) self.assertAllClose(matpower, matpower_np) @@ -349,7 +372,7 @@ class InverseProvidingFactorTest(test.TestCase): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric - factor = InverseProvidingFactorTestingDummy(cov.shape) + factor = DenseSquareMatrixFactorTestingDummy(cov.shape) factor._cov = array_ops.constant(cov, dtype=dtypes.float32) damping_func = make_damping_func(0) @@ -361,12 +384,12 @@ class InverseProvidingFactorTest(test.TestCase): sess.run(tf_variables.global_variables_initializer()) # The inverse op will assign the damped inverse of cov to the inv var. 
- old_inv = sess.run(factor.get_inverse(damping_func)) + old_inv = sess.run(factor.get_inverse(damping_func).to_dense()) self.assertAllClose( sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv) sess.run(ops) - new_inv = sess.run(factor.get_inverse(damping_func)) + new_inv = sess.run(factor.get_inverse(damping_func).to_dense()) self.assertAllClose(new_inv, np.linalg.inv(cov)) @@ -411,7 +434,7 @@ class NaiveDiagonalFactorTest(test.TestCase): tensor = array_ops.ones((2, 3), name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) factor.instantiate_cov_variables() - self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list()) + self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) def testNaiveDiagonalFactorInitFloat64(self): with tf_ops.Graph().as_default(): @@ -420,7 +443,7 @@ class NaiveDiagonalFactorTest(test.TestCase): tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) factor.instantiate_cov_variables() - cov = factor.get_cov_var() + cov = factor.get_cov() self.assertEqual(cov.dtype, dtype) self.assertEqual([6, 1], cov.get_shape().as_list()) @@ -444,7 +467,7 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase): vocab_size = 5 factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) factor.instantiate_cov_variables() - cov = factor.get_cov_var() + cov = factor.get_cov() self.assertEqual(cov.shape.as_list(), [vocab_size]) def testCovarianceUpdateOp(self): @@ -502,7 +525,7 @@ class ConvDiagonalFactorTest(test.TestCase): self.kernel_height * self.kernel_width * self.in_channels, self.out_channels ], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(): @@ -564,7 +587,7 @@ class ConvDiagonalFactorTest(test.TestCase): self.kernel_height * self.kernel_width * self.in_channels + 1, self.out_channels ], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) # Ensure update op doesn't crash. cov_update_op = factor.make_covariance_update_op(0.0) @@ -654,13 +677,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): # Ensure shape of covariance matches input size of filter. input_size = in_channels * (width**3) self.assertEqual([input_size, input_size], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) # Ensure cov_update_op doesn't crash. with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be rank-8, as the filter will be applied at each corner of # the 4-D cube. @@ -685,13 +708,13 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): # Ensure shape of covariance matches input size of filter. self.assertEqual([in_channels, in_channels], - factor.get_cov_var().shape.as_list()) + factor.get_cov().shape.as_list()) # Ensure cov_update_op doesn't crash. with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be rank-9, as the filter will be applied at each location. 
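The renames above (InverseProvidingFactor to DenseSquareMatrixFactor, get_cov_var to get_cov) come with a behavioural shift that the tests now exercise: get_inverse and get_matpower return K-FAC linear-operator objects rather than raw variables, so callers materialize them with to_dense(). A small helper in the spirit of the updated tests; factor, damping_func and sess are hypothetical stand-ins for whatever the calling code already has:

    def dense_inverse_and_matpower(sess, factor, damping_func, exp=-0.5):
      """Materializes a factor's damped inverse and matrix power, post-change.

      Assumes factor is a DenseSquareMatrixFactor-style object whose inverse
      and matpower variables were registered, instantiated and updated already.
      """
      # Both getters now return linear-operator wrappers; to_dense() recovers
      # the dense matrix that the old API handed back directly.
      inverse = factor.get_inverse(damping_func).to_dense()
      matpower = factor.get_matpower(exp, damping_func).to_dense()
      return sess.run([inverse, matpower])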
self.assertMatrixRank(9, cov) @@ -716,7 +739,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be the sum of 3 * 2 = 6 outer products. self.assertMatrixRank(6, cov) @@ -742,7 +765,7 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): with self.test_session() as sess: sess.run(tf_variables.global_variables_initializer()) sess.run(factor.make_covariance_update_op(0.0)) - cov = sess.run(factor.get_cov_var()) + cov = sess.run(factor.get_cov()) # Cov should be rank = in_channels, as only the center of the filter # receives non-zero input for each input channel. @@ -814,6 +837,21 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase): new_cov = sess.run(factor.make_covariance_update_op(0.)) self.assertAllClose([[(1. + 4.) / 2.]], new_cov) + def testSubSample(self): + with tf_ops.Graph().as_default(): + patches_1 = array_ops.constant(1, shape=(10, 2)) + patches_2 = array_ops.constant(1, shape=(10, 8)) + patches_3 = array_ops.constant(1, shape=(3, 3)) + patches_1_sub = ff._subsample_for_cov_computation(patches_1) + patches_2_sub = ff._subsample_for_cov_computation(patches_2) + patches_3_sub = ff._subsample_for_cov_computation(patches_3) + patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0] + patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0] + patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0] + self.assertEqual(2, patches_1_sub_batch_size) + self.assertEqual(8, patches_2_sub_batch_size) + self.assertEqual(3, patches_3_sub_batch_size) + class ConvOutputKroneckerFactorTest(ConvFactorTestCase): diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py index 9325aa1b732..560a9b0b426 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.kfac.python.ops import fisher_factors as ff from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import optimizer from tensorflow.python.framework import ops @@ -32,6 +33,13 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test +# We need to set these constants since the numerical values used in the tests +# were chosen when these used to be the defaults. 
+ff.set_global_constants(init_covariances_at_zero=False, + zero_debias=False, + init_inverses_at_zero=False) + + def dummy_layer_collection(): lcoll = lc.LayerCollection() dummy = array_ops.constant([1., 2.]) @@ -186,6 +194,11 @@ class OptimizerTest(test.TestCase): layer_collection, momentum=0.5, momentum_type='regular') + (cov_update_thunks, + inv_update_thunks) = opt.make_vars_and_create_op_thunks() + cov_update_ops = tuple(thunk() for thunk in cov_update_thunks) + inv_update_ops = tuple(thunk() for thunk in inv_update_thunks) + grads_and_vars = opt.compute_gradients(output, [weights, bias]) all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars] @@ -193,6 +206,8 @@ class OptimizerTest(test.TestCase): sess.run(tf_variables.global_variables_initializer()) old_vars = sess.run(all_vars) + sess.run(cov_update_ops) + sess.run(inv_update_ops) sess.run(op) new_vars = sess.run(all_vars) diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index b897fd68a08..3c01eb65e7a 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -35,12 +35,16 @@ py_library( srcs = ["fisher_factors.py"], srcs_version = "PY2AND3", deps = [ + ":linear_operator", ":utils", "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:special_math_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -60,6 +64,19 @@ py_library( ], ) +py_library( + name = "linear_operator", + srcs = ["linear_operator.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python/ops/linalg", + "@six_archive//:six", + ], +) + py_library( name = "loss_functions", srcs = ["loss_functions.py"], diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index d11c9c82881..854f885c26f 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -57,8 +57,8 @@ def make_fisher_estimator(placement_strategy=None, **kwargs): if placement_strategy in [None, "round_robin"]: return FisherEstimatorRoundRobin(**kwargs) else: - raise ValueError("Unimplemented vars and ops placement strategy : %s", - placement_strategy) + raise ValueError("Unimplemented vars and ops " + "placement strategy : {}".format(placement_strategy)) # pylint: enable=abstract-class-instantiated @@ -81,7 +81,9 @@ class FisherEstimator(object): exps=(-1,), estimation_mode="gradients", colocate_gradients_with_ops=True, - name="FisherEstimator"): + name="FisherEstimator", + compute_cholesky=False, + compute_cholesky_inverse=False): """Create a FisherEstimator object. Args: @@ -124,6 +126,12 @@ class FisherEstimator(object): name: A string. A name given to this estimator, which is added to the variable scope when constructing variables and ops. (Default: "FisherEstimator") + compute_cholesky: Bool. Whether or not the FisherEstimator will be + able to multiply vectors by the Cholesky factor. + (Default: False) + compute_cholesky_inverse: Bool. Whether or not the FisherEstimator + will be able to multiply vectors by the Cholesky factor inverse. + (Default: False) Raises: ValueError: If no losses have been registered with layer_collection. 
""" @@ -142,6 +150,8 @@ class FisherEstimator(object): self._made_vars = False self._exps = exps + self._compute_cholesky = compute_cholesky + self._compute_cholesky_inverse = compute_cholesky_inverse self._name = name @@ -170,44 +180,6 @@ class FisherEstimator(object): def name(self): return self._name - @abc.abstractmethod - def make_ops_and_vars(self, scope=None): - """Make ops and vars with a specific placement strategy. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. For example in case of - round robin placement a new device is chosen for each factor by cycling - through list of devices in the cov_devices argument. If cov_devices is None - then no explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the inv_devices argument. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all ops will execute, inside of a variable scope of the given - name. (Default: None) - - Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - inv_update_op: inv_update_ops grouped into a single op. - cov_update_thunks: Thunks that make the ops in cov_update_ops. - inv_update_thunks: Thunks that make the ops in inv_update_ops. - """ - pass - @abc.abstractmethod def make_vars_and_create_op_thunks(self, scope=None): """Make vars and create op thunks with a specific placement strategy. @@ -300,9 +272,54 @@ class FisherEstimator(object): A list of (transformed vector, var) pairs in the same order as vecs_and_vars. """ + assert exp in self._exps + fcn = lambda fb, vec: fb.multiply_matpower(vec, exp) return self._apply_transformation(vecs_and_vars, fcn) + def multiply_cholesky(self, vecs_and_vars, transpose=False): + """Multiplies the vecs by the corresponding Cholesky factors. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factors are transposed before + multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert self._compute_cholesky + + fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) + + def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False): + """Mults the vecs by the inverses of the corresponding Cholesky factors. + + Note: if you are using Cholesky inverse multiplication to sample from + a matrix-variate Gaussian you will want to multiply by the transpose. + Let L be the Cholesky factor of F and observe that + + L^-T * L^-1 = (L * L^T)^-1 = F^-1 . 
+ + Thus we want to multiply by L^-T in order to sample from Gaussian with + covariance F^-1. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transpose: Bool. If true the Cholesky factor inverses are transposed + before multiplying the vecs. (Default: False) + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + assert self._compute_cholesky_inverse + + fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose) + return self._apply_transformation(vecs_and_vars, fcn) + def _instantiate_factors(self): """Instantiates FisherFactors' variables. @@ -333,9 +350,13 @@ class FisherEstimator(object): return self._made_vars def _register_matrix_functions(self): - for exp in self._exps: - for block in self.blocks: + for block in self.blocks: + for exp in self._exps: block.register_matpower(exp) + if self._compute_cholesky: + block.register_cholesky() + if self._compute_cholesky_inverse: + block.register_cholesky_inverse() def _finalize_layer_collection(self): self._layers.create_subgraph() diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py index 33c96965061..9c9fef471f8 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator_lib.py +++ b/tensorflow/contrib/kfac/python/ops/estimator_lib.py @@ -25,6 +25,7 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'FisherEstimator', + 'make_fisher_estimator', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 00b3673a742..3a5c8eb5f96 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -83,34 +83,22 @@ def normalize_damping(damping, num_replications): def compute_pi_tracenorm(left_cov, right_cov): - """Computes the scalar constant pi for Tikhonov regularization/damping. + r"""Computes the scalar constant pi for Tikhonov regularization/damping. $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$ See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. Args: - left_cov: The left Kronecker factor "covariance". - right_cov: The right Kronecker factor "covariance". + left_cov: A LinearOperator object. The left Kronecker factor "covariance". + right_cov: A LinearOperator object. The right Kronecker factor "covariance". Returns: The computed scalar constant pi for these Kronecker Factors (as a Tensor). """ - - def _trace(cov): - if len(cov.shape) == 1: - # Diagonal matrix. - return math_ops.reduce_sum(cov) - elif len(cov.shape) == 2: - # Full matrix. - return math_ops.trace(cov) - else: - raise ValueError( - "What's the trace of a Tensor of rank %d?" % len(cov.shape)) - # Instead of dividing by the dim of the norm, we multiply by the dim of the # other norm. This works out the same in the ratio. 
- left_norm = _trace(left_cov) * right_cov.shape.as_list()[0] - right_norm = _trace(right_cov) * left_cov.shape.as_list()[0] + left_norm = left_cov.trace() * int(right_cov.domain_dimension) + right_norm = right_cov.trace() * int(left_cov.domain_dimension) return math_ops.sqrt(left_norm / right_norm) @@ -188,6 +176,16 @@ class FisherBlock(object): """ pass + @abc.abstractmethod + def register_cholesky(self): + """Registers a Cholesky factor to be computed by the block.""" + pass + + @abc.abstractmethod + def register_cholesky_inverse(self): + """Registers an inverse Cholesky factor to be computed by the block.""" + pass + def register_inverse(self): """Registers a matrix inverse to be computed by the block.""" self.register_matpower(-1) @@ -228,6 +226,33 @@ class FisherBlock(object): """ return self.multiply_matpower(vector, 1) + @abc.abstractmethod + def multiply_cholesky(self, vector, transpose=False): + """Multiplies the vector by the (damped) Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. If true the Cholesky factor is transposed before + multiplying the vector. (Default: False) + + Returns: + The vector left-multiplied by the (damped) Cholesky-factor of the block. + """ + pass + + @abc.abstractmethod + def multiply_cholesky_inverse(self, vector, transpose=False): + """Multiplies vector by the (damped) inverse Cholesky-factor of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + transpose: Bool. If true the Cholesky factor inverse is transposed + before multiplying the vector. (Default: False) + Returns: + Vector left-multiplied by (damped) inverse Cholesky-factor of the block. + """ + pass + @abc.abstractmethod def tensors_to_compute_grads(self): """Returns the Tensor(s) with respect to which this FisherBlock needs grads. 
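The trace-norm ratio used above is simple enough to check by hand; the following is a minimal NumPy sketch of the same computation (illustrative only, not the TF implementation — `A` and `B` stand in for the two Kronecker factor covariances):

```python
import numpy as np

def pi_tracenorm(left_cov, right_cov):
  # pi = sqrt((trace(A) / dim(A)) / (trace(B) / dim(B))). Multiplying each
  # trace by the other factor's dimension, instead of dividing by its own,
  # gives the same ratio -- the trick noted in compute_pi_tracenorm.
  left_norm = np.trace(left_cov) * right_cov.shape[0]
  right_norm = np.trace(right_cov) * left_cov.shape[0]
  return np.sqrt(left_norm / right_norm)

A = 2.0 * np.eye(4)    # trace/dim = 2.0
B = 0.5 * np.eye(10)   # trace/dim = 0.5
print(pi_tracenorm(A, B))  # sqrt(2.0 / 0.5) = 2.0
```

The identity L^-T * L^-1 = (L * L^T)^-1 = F^-1 quoted in the multiply_cholesky_inverse docstring can be verified the same way with np.linalg.cholesky on any small positive-definite F.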
@@ -275,15 +300,32 @@ class FullFB(FisherBlock): def register_matpower(self, exp): self._factor.register_matpower(exp, self._damping_func) - def multiply_matpower(self, vector, exp): + def register_cholesky(self): + self._factor.register_cholesky(self._damping_func) + + def register_cholesky_inverse(self): + self._factor.register_cholesky_inverse(self._damping_func) + + def _multiply_matrix(self, matrix, vector, transpose=False): vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply_matpower( - vector_flat, exp, self._damping_func) + out_flat = matrix.matmul(vector_flat, adjoint=transpose) return utils.column_to_tensors(vector, out_flat) + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector, transpose=transpose) + def full_fisher_block(self): """Explicitly constructs the full Fisher block.""" - return self._factor.get_cov() + return self._factor.get_cov_as_linear_operator().to_dense() def tensors_to_compute_grads(self): return self._params @@ -305,7 +347,47 @@ class FullFB(FisherBlock): return math_ops.reduce_sum(self._batch_sizes) -class NaiveDiagonalFB(FisherBlock): +@six.add_metaclass(abc.ABCMeta) +class DiagonalFB(FisherBlock): + """A base class for FisherBlocks that use diagonal approximations.""" + + def register_matpower(self, exp): + # Not needed for this. Matrix powers are computed on demand in the + # diagonal case + pass + + def register_cholesky(self): + # Not needed for this. Cholesky's are computed on demand in the + # diagonal case + pass + + def register_cholesky_inverse(self): + # Not needed for this. Cholesky inverses's are computed on demand in the + # diagonal case + pass + + def _multiply_matrix(self, matrix, vector): + vector_flat = utils.tensors_to_column(vector) + out_flat = matrix.matmul(vector_flat) + return utils.column_to_tensors(vector, out_flat) + + def multiply_matpower(self, vector, exp): + matrix = self._factor.get_matpower(exp, self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky(self, vector, transpose=False): + matrix = self._factor.get_cholesky(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def multiply_cholesky_inverse(self, vector, transpose=False): + matrix = self._factor.get_cholesky_inverse(self._damping_func) + return self._multiply_matrix(matrix, vector) + + def full_fisher_block(self): + return self._factor.get_cov_as_linear_operator().to_dense() + + +class NaiveDiagonalFB(DiagonalFB): """FisherBlock using a diagonal matrix approximation. This type of approximation is generically applicable but quite primitive. @@ -333,20 +415,6 @@ class NaiveDiagonalFB(FisherBlock): self._factor = self._layer_collection.make_or_get_factor( fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) - def register_matpower(self, exp): - # Not needed for this. 
Matrix powers are computed on demand in the - # diagonal case - pass - - def multiply_matpower(self, vector, exp): - vector_flat = utils.tensors_to_column(vector) - out_flat = self._factor.left_multiply_matpower( - vector_flat, exp, self._damping_func) - return utils.column_to_tensors(vector, out_flat) - - def full_fisher_block(self): - return self._factor.get_cov() - def tensors_to_compute_grads(self): return self._params @@ -452,7 +520,7 @@ class InputOutputMultiTower(object): return self.__outputs -class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock): +class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB): """FisherBlock for fully-connected (dense) layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a fully @@ -497,32 +565,8 @@ class FullyConnectedDiagonalFB(InputOutputMultiTower, FisherBlock): self._damping_func = _package_func(lambda: damping, (damping,)) - def register_matpower(self, exp): - # Not needed for this. Matrix powers are computed on demand in the - # diagonal case - pass - def multiply_matpower(self, vector, exp): - """Multiplies the vector by the (damped) matrix-power of the block. - - Args: - vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape - [input_size, output_size] corresponding to layer's weights. If not, a - 2-tuple of the former and a Tensor of shape [output_size] corresponding - to the layer's bias. - exp: A scalar representing the power to raise the block before multiplying - it by the vector. - - Returns: - The vector left-multiplied by the (damped) matrix-power of the block. - """ - reshaped_vec = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply_matpower( - reshaped_vec, exp, self._damping_func) - return utils.mat2d_to_layer_params(vector, reshaped_out) - - -class ConvDiagonalFB(InputOutputMultiTower, FisherBlock): +class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB): """FisherBlock for 2-D convolutional layers using a diagonal approx. Estimates the Fisher Information matrix's diagonal entries for a convolutional @@ -621,17 +665,6 @@ class ConvDiagonalFB(InputOutputMultiTower, FisherBlock): self._num_locations) self._damping_func = _package_func(damping_func, damping_id) - def register_matpower(self, exp): - # Not needed for this. Matrix powers are computed on demand in the - # diagonal case - pass - - def multiply_matpower(self, vector, exp): - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = self._factor.left_multiply_matpower( - reshaped_vect, exp, self._damping_func) - return utils.mat2d_to_layer_params(vector, reshaped_out) - class KroneckerProductFB(FisherBlock): """A base class for blocks with separate input and output Kronecker factors. @@ -640,9 +673,6 @@ class KroneckerProductFB(FisherBlock): output factors. 
""" - def __init__(self, layer_collection): - super(KroneckerProductFB, self).__init__(layer_collection) - def _setup_damping(self, damping, normalization=None): """Makes functions that compute the damping values for both factors.""" def compute_damping(): @@ -651,9 +681,10 @@ class KroneckerProductFB(FisherBlock): else: maybe_normalized_damping = damping - return compute_pi_adjusted_damping(self._input_factor.get_cov(), - self._output_factor.get_cov(), - maybe_normalized_damping**0.5) + return compute_pi_adjusted_damping( + self._input_factor.get_cov_as_linear_operator(), + self._output_factor.get_cov_as_linear_operator(), + maybe_normalized_damping**0.5) if normalization is not None: damping_id = ("compute_pi_adjusted_damping", @@ -675,6 +706,14 @@ class KroneckerProductFB(FisherBlock): self._input_factor.register_matpower(exp, self._input_damping_func) self._output_factor.register_matpower(exp, self._output_damping_func) + def register_cholesky(self): + self._input_factor.register_cholesky(self._input_damping_func) + self._output_factor.register_cholesky(self._output_damping_func) + + def register_cholesky_inverse(self): + self._input_factor.register_cholesky_inverse(self._input_damping_func) + self._output_factor.register_cholesky_inverse(self._output_damping_func) + @property def _renorm_coeff(self): """Kronecker factor multiplier coefficient. @@ -687,17 +726,47 @@ class KroneckerProductFB(FisherBlock): """ return 1.0 - def multiply_matpower(self, vector, exp): + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = self._output_factor.right_multiply_matpower( - reshaped_vector, exp, self._output_damping_func) - reshaped_out = self._input_factor.left_multiply_matpower( - reshaped_out, exp, self._input_damping_func) - if self._renorm_coeff != 1.0: - renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype) - reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype) + reshaped_out = right_factor.matmul_right(reshaped_vector, + adjoint=transpose_right) + reshaped_out = left_factor.matmul(reshaped_out, + adjoint=transpose_left) + if extra_scale != 1.0: + reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype) return utils.mat2d_to_layer_params(vector, reshaped_out) + def multiply_matpower(self, vector, exp): + left_factor = self._input_factor.get_matpower( + exp, self._input_damping_func) + right_factor = self._output_factor.get_matpower( + exp, self._output_damping_func) + extra_scale = float(self._renorm_coeff)**exp + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale) + + def multiply_cholesky(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky(self._input_damping_func) + right_factor = self._output_factor.get_cholesky(self._output_damping_func) + extra_scale = float(self._renorm_coeff)**0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + extra_scale=extra_scale, + transpose_left=transpose, + transpose_right=not transpose) + + def multiply_cholesky_inverse(self, vector, transpose=False): + left_factor = self._input_factor.get_cholesky_inverse( + self._input_damping_func) + right_factor = self._output_factor.get_cholesky_inverse( + self._output_damping_func) + extra_scale = float(self._renorm_coeff)**-0.5 + return self._multiply_factored_matrix(left_factor, right_factor, vector, + 
extra_scale=extra_scale, + transpose_left=transpose, + transpose_right=not transpose) + def full_fisher_block(self): """Explicitly constructs the full Fisher block. @@ -706,8 +775,8 @@ class KroneckerProductFB(FisherBlock): Returns: The full Fisher block. """ - left_factor = self._input_factor.get_cov() - right_factor = self._output_factor.get_cov() + left_factor = self._input_factor.get_cov_as_linear_operator().to_dense() + right_factor = self._output_factor.get_cov_as_linear_operator().to_dense() return self._renorm_coeff * utils.kronecker_product(left_factor, right_factor) @@ -796,7 +865,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB): class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): - """FisherBlock for convolutional layers using the basic KFC approx. + r"""FisherBlock for convolutional layers using the basic KFC approx. Estimates the Fisher Information matrix's blog for a convolutional layer. @@ -945,10 +1014,10 @@ class DepthwiseConvDiagonalFB(ConvDiagonalFB): self._filter_shape = (filter_height, filter_width, in_channels, in_channels * channel_multiplier) - def multiply_matpower(self, vector, exp): + def _multiply_matrix(self, matrix, vector): conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower( - conv2d_vector, exp) + conv2d_result = super( + DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector) return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) @@ -1016,10 +1085,14 @@ class DepthwiseConvKFCBasicFB(ConvKFCBasicFB): self._filter_shape = (filter_height, filter_width, in_channels, in_channels * channel_multiplier) - def multiply_matpower(self, vector, exp): + def _multiply_factored_matrix(self, left_factor, right_factor, vector, + extra_scale=1.0, transpose_left=False, + transpose_right=False): conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector) - conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower( - conv2d_vector, exp) + conv2d_result = super( + DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix( + left_factor, right_factor, conv2d_vector, extra_scale=extra_scale, + transpose_left=transpose_left, transpose_right=transpose_right) return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result) @@ -1233,6 +1306,8 @@ class InputOutputMultiTowerMultiUse(InputOutputMultiTower): else: raise ValueError("Global config variable TOWER_STRATEGY must be one of " "'concat' or 'separate'.") + else: + inputs = tuple(inputs) # Now we perform the analogous processing for grads_list if isinstance(grads_list[0][0], (list, tuple)): @@ -1275,6 +1350,8 @@ class InputOutputMultiTowerMultiUse(InputOutputMultiTower): else: raise ValueError("Global config variable TOWER_STRATEGY must be one of " "'concat' or 'separate'.") + else: + grads_list = tuple(tuple(grads) for grads in grads_list) if self._num_uses is None: raise ValueError("You must supply a value for the num_uses argument if " @@ -1664,3 +1741,12 @@ class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse, return utils.mat2d_to_layer_params(vector, Z) # pylint: enable=invalid-name + + def multiply_cholesky(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") + + def multiply_cholesky_inverse(self, vector): + raise NotImplementedError("FullyConnectedSeriesFB does not support " + "Cholesky computations.") + diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py 
b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 0d40d265a17..b43232dfafa 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -24,6 +24,7 @@ import contextlib import numpy as np import six +from tensorflow.contrib.kfac.python.ops import linear_operator as lo from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as tf_ops @@ -32,6 +33,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -41,10 +43,14 @@ from tensorflow.python.util import nest # Whether to initialize covariance estimators at a zero matrix (or the identity # matrix). -INIT_COVARIANCES_AT_ZERO = False +INIT_COVARIANCES_AT_ZERO = True # Whether to zero-debias the moving averages. -ZERO_DEBIAS = False +ZERO_DEBIAS = True + +# Whether to initialize inverse (and other such matrices computed from the cov +# matrices) to the zero matrix (or the identity matrix). +INIT_INVERSES_AT_ZERO = True # When the number of inverses requested from a FisherFactor exceeds this value, # the inverses are computed using an eigenvalue decomposition. @@ -55,6 +61,22 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 # matrix powers. Must be nonnegative. EIGENVALUE_CLIPPING_THRESHOLD = 0.0 +# Used to subsample the flattened extracted image patches. The number of +# outer products per row of the covariance matrix should not exceed this +# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True. +_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1 + +# Used to subsample the inputs passed to the extract image patches. The batch +# size of number of inputs to extract image patches is multiplied by this +# factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True. +_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5 + +# If True, then subsamples the tensor passed to compute the covaraince matrix. +_SUB_SAMPLE_OUTER_PRODUCTS = False + +# If True, then subsamples the tensor passed to compute the covaraince matrix. +_SUB_SAMPLE_INPUTS = False + # TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data # passed to the factors from the blocks will be concatenated across towers # (lazilly via PartitionedTensor objects). 
Otherwise a tuple of tensors over @@ -65,42 +87,64 @@ TOWER_STRATEGY = "concat" def set_global_constants(init_covariances_at_zero=None, zero_debias=None, + init_inverses_at_zero=None, eigenvalue_decomposition_threshold=None, eigenvalue_clipping_threshold=None, + max_num_outer_products_per_cov_row=None, + sub_sample_outer_products=None, + inputs_to_extract_patches_factor=None, + sub_sample_inputs=None, tower_strategy=None): """Sets various global constants used by the classes in this module.""" global INIT_COVARIANCES_AT_ZERO global ZERO_DEBIAS + global INIT_INVERSES_AT_ZERO global EIGENVALUE_DECOMPOSITION_THRESHOLD global EIGENVALUE_CLIPPING_THRESHOLD + global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW + global _SUB_SAMPLE_OUTER_PRODUCTS + global _INPUTS_TO_EXTRACT_PATCHES_FACTOR + global _SUB_SAMPLE_INPUTS global TOWER_STRATEGY if init_covariances_at_zero is not None: INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero if zero_debias is not None: ZERO_DEBIAS = zero_debias + if init_inverses_at_zero is not None: + INIT_INVERSES_AT_ZERO = init_inverses_at_zero if eigenvalue_decomposition_threshold is not None: EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold if eigenvalue_clipping_threshold is not None: EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold + if max_num_outer_products_per_cov_row is not None: + _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row + if sub_sample_outer_products is not None: + _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products + if inputs_to_extract_patches_factor is not None: + _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor + if sub_sample_inputs is not None: + _SUB_SAMPLE_INPUTS = sub_sample_inputs if tower_strategy is not None: TOWER_STRATEGY = tower_strategy def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument - return array_ops.diag(array_ops.ones(shape[0], dtype)) + if INIT_INVERSES_AT_ZERO: + return array_ops.zeros(shape, dtype=dtype) + return linalg_ops.eye(num_rows=shape[0], dtype=dtype) def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument if INIT_COVARIANCES_AT_ZERO: - return array_ops.diag(array_ops.zeros(shape[0], dtype)) - return array_ops.diag(array_ops.ones(shape[0], dtype)) + return array_ops.zeros(shape, dtype=dtype) + return linalg_ops.eye(num_rows=shape[0], dtype=dtype) -def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: disable=unused-argument +def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument if INIT_COVARIANCES_AT_ZERO: - return array_ops.zeros(shape, dtype) - return array_ops.ones(shape, dtype) + return array_ops.zeros(shape, dtype=dtype) + return array_ops.ones(shape, dtype=dtype) @contextlib.contextmanager @@ -227,6 +271,58 @@ def graph_func_to_string(func): return list_to_string(func.func_id) +def _subsample_for_cov_computation(array, name=None): + """Subsamples the first dimension of the array. + + `array`(A) is a tensor of shape `[batch_size, dim_2]`. Then the covariance + matrix(A^TA) is of shape `dim_2 ** 2`. Subsample only if the number of outer + products per row of the covariance matrix is greater than + `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`. + + Args: + array: Tensor, of shape `[batch_size, dim_2]`. + name: `string`, Default(None) + + Returns: + A tensor of shape `[max_samples, dim_2]`. + + Raises: + ValueError: If array's is not matrix-shaped. 
+ ValueError: If array's batch_size cannot be inferred. + + """ + with tf_ops.name_scope(name, "subsample", [array]): + array = tf_ops.convert_to_tensor(array) + if len(array.shape) != 2: + raise ValueError("Input param array must be a matrix.") + + batch_size = array.shape.as_list()[0] + if batch_size is None: + raise ValueError("Unable to get batch_size from input param array.") + + num_cov_rows = array.shape.as_list()[-1] + max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows) + if batch_size <= max_batch_size: + return array + + return _random_tensor_gather(array, max_batch_size) + + +def _random_tensor_gather(array, max_size): + """Generates a random set of indices and gathers the value at the indcices. + + Args: + array: Tensor, of shape `[batch_size, dim_2]`. + max_size: int, Number of indices to sample. + + Returns: + A tensor of shape `[max_size, ...]`. + """ + batch_size = array.shape.as_list()[0] + indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size] + return array_ops.gather(array, indices) + + @six.add_metaclass(abc.ABCMeta) class FisherFactor(object): """Base class for objects modeling factors of approximate Fisher blocks. @@ -314,7 +410,7 @@ class FisherFactor(object): the cov update. Returns: - Tensor of same shape as self.get_cov_var(). + Tensor of same shape as self.get_cov(). """ pass @@ -363,78 +459,43 @@ class FisherFactor(object): """Create and return update ops corresponding to registered computations.""" pass - @abc.abstractmethod def get_cov(self): - """Get full covariance matrix. - - Returns: - Tensor of shape [n, n]. Represents all parameter-parameter correlations - captured by this FisherFactor. - """ - pass - - def get_cov_var(self): - """Get variable backing this FisherFactor. - - May or may not be the same as self.get_cov() - - Returns: - Variable of shape self._cov_shape. - """ return self._cov @abc.abstractmethod - def left_multiply_matpower(self, x, exp, damping_func): - """Left multiplies 'x' by matrix power of this factor (w/ damping applied). - - This calculation is essentially: - (C + damping * I)**exp * x - where * is matrix-multiplication, ** is matrix power, I is the identity - matrix, and C is the matrix represented by this factor. - - x can represent either a matrix or a vector. For some factors, 'x' might - represent a vector but actually be stored as a 2D matrix for convenience. - - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - exp: float. The matrix exponent to use. - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). - - Returns: - Tensor of same shape as 'x' representing the result of the multiplication. - """ + def get_cov_as_linear_operator(self): pass @abc.abstractmethod - def right_multiply_matpower(self, x, exp, damping_func): - """Right multiplies 'x' by matrix power of this factor (w/ damping applied). + def register_matpower(self, exp, damping_func): + pass - This calculation is essentially: - x * (C + damping * I)**exp - where * is matrix-multiplication, ** is matrix power, I is the identity - matrix, and C is the matrix represented by this factor. + @abc.abstractmethod + def register_cholesky(self, damping_func): + pass - Unlike left_multiply_matpower, x will always be a matrix. + @abc.abstractmethod + def register_cholesky_inverse(self, damping_func): + pass - Args: - x: Tensor. Represents a single vector. Shape depends on implementation. - exp: float. 
The matrix exponent to use. - damping_func: A function that computes a 0-D Tensor or a float which will - be the damping value used. i.e. damping = damping_func(). + @abc.abstractmethod + def get_matpower(self, exp, damping_func): + pass - Returns: - Tensor of same shape as 'x' representing the result of the multiplication. - """ + @abc.abstractmethod + def get_cholesky(self, damping_func): + pass + + @abc.abstractmethod + def get_cholesky_inverse(self, damping_func): pass -class InverseProvidingFactor(FisherFactor): - """Base class for FisherFactors that maintain inverses explicitly. +class DenseSquareMatrixFactor(FisherFactor): + """Base class for FisherFactors that are stored as dense square matrices. - This class explicitly calculates and stores inverses of covariance matrices - provided by the underlying FisherFactor implementation. It is assumed that - vectors can be represented as 2-D matrices. + This class explicitly calculates and stores inverses of their `cov` matrices, + which must be square dense matrices. Subclasses must implement the _compute_new_cov method, and the _var_scope and _cov_shape properties. @@ -453,7 +514,19 @@ class InverseProvidingFactor(FisherFactor): self._eigendecomp = None self._damping_funcs_by_id = {} # {hashable: lambda} - super(InverseProvidingFactor, self).__init__() + self._cholesky_registrations = set() # { hashable } + self._cholesky_inverse_registrations = set() # { hashable } + + self._cholesky_by_damping = {} # { hashable: variable } + self._cholesky_inverse_by_damping = {} # { hashable: variable } + + super(DenseSquareMatrixFactor, self).__init__() + + def get_cov_as_linear_operator(self): + assert self.get_cov().shape.ndims == 2 + return lo.LinearOperatorFullMatrix(self.get_cov(), + is_self_adjoint=True, + is_square=True) def _register_damping(self, damping_func): damping_id = graph_func_to_id(damping_func) @@ -478,8 +551,6 @@ class InverseProvidingFactor(FisherFactor): be the damping value used. i.e. damping = damping_func(). """ if exp == 1.0: - # We don't register these. The user shouldn't even be calling this - # function with exp = 1.0. return damping_id = self._register_damping(damping_func) @@ -487,6 +558,38 @@ class InverseProvidingFactor(FisherFactor): if (exp, damping_id) not in self._matpower_registrations: self._matpower_registrations.add((exp, damping_id)) + def register_cholesky(self, damping_func): + """Registers a Cholesky factor to be maintained and served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_cholesky. + + Args: + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). + """ + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_registrations: + self._cholesky_registrations.add(damping_id) + + def register_cholesky_inverse(self, damping_func): + """Registers an inverse Cholesky factor to be maintained/served on demand. + + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_cholesky_inverse. + + Args: + damping_func: A function that computes a 0-D Tensor or a float which will + be the damping value used. i.e. damping = damping_func(). 
+ """ + damping_id = self._register_damping(damping_func) + + if damping_id not in self._cholesky_inverse_registrations: + self._cholesky_inverse_registrations.add(damping_id) + def instantiate_inv_variables(self): """Makes the internal "inverse" variable(s).""" @@ -504,6 +607,32 @@ class InverseProvidingFactor(FisherFactor): assert (exp, damping_id) not in self._matpower_by_exp_and_damping self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower + for damping_id in self._cholesky_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + chol = variable_scope.get_variable( + "cholesky_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_by_damping + self._cholesky_by_damping[damping_id] = chol + + for damping_id in self._cholesky_inverse_registrations: + damping_func = self._damping_funcs_by_id[damping_id] + damping_string = graph_func_to_string(damping_func) + with variable_scope.variable_scope(self._var_scope): + cholinv = variable_scope.get_variable( + "cholesky_inverse_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + assert damping_id not in self._cholesky_inverse_by_damping + self._cholesky_inverse_by_damping[damping_id] = cholinv + def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" ops = [] @@ -521,7 +650,8 @@ class InverseProvidingFactor(FisherFactor): # We precompute these so we don't need to evaluate them multiple times (for # each matrix power that uses them) - damping_value_by_id = {damping_id: self._damping_funcs_by_id[damping_id]() + damping_value_by_id = {damping_id: math_ops.cast( + self._damping_funcs_by_id[damping_id](), self._dtype) for damping_id in self._damping_funcs_by_id} if use_eig: @@ -542,29 +672,91 @@ class InverseProvidingFactor(FisherFactor): self._matpower_by_exp_and_damping.items()): assert exp == -1 damping = damping_value_by_id[damping_id] - ops.append(matpower.assign(utils.posdef_inv(self._cov, damping))) + ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping))) + + # TODO(b/77902055): If inverses are being computed with Cholesky's + # we can share the work. Instead this code currently just computes the + # Cholesky a second time. It does at least share work between requests for + # Cholesky's and Cholesky inverses with the same damping id. 
+ for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items(): + cholesky_ops = [] + + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + + if damping_id in self._cholesky_by_damping: + cholesky = self._cholesky_by_damping[damping_id] + cholesky_ops.append(cholesky.assign(cholesky_value)) + + identity = linalg_ops.eye(cholesky_value.shape.as_list()[0], + dtype=cholesky_value.dtype) + cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value, + identity) + cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value)) + + ops.append(control_flow_ops.group(*cholesky_ops)) + + for damping_id, cholesky in self._cholesky_by_damping.items(): + if damping_id not in self._cholesky_inverse_by_damping: + damping = damping_value_by_id[damping_id] + cholesky_value = utils.cholesky(self.get_cov(), damping) + ops.append(cholesky.assign(cholesky_value)) self._eigendecomp = False return ops def get_inverse(self, damping_func): # Just for backwards compatibility of some old code and tests - damping_id = graph_func_to_id(damping_func) - return self._matpower_by_exp_and_damping[(-1, damping_id)] + return self.get_matpower(-1, damping_func) def get_matpower(self, exp, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + if exp != 1: + damping_id = graph_func_to_id(damping_func) + matpower = self._matpower_by_exp_and_damping[(exp, damping_id)] + else: + matpower = self.get_cov() + identity = linalg_ops.eye(matpower.shape.as_list()[0], + dtype=matpower.dtype) + matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity + + assert matpower.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(matpower, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + def get_cholesky(self, damping_func): # Note that this function returns a variable which gets updated by the # inverse ops. It may be stale / inconsistent with the latest value of # get_cov(). damping_id = graph_func_to_id(damping_func) - return self._matpower_by_exp_and_damping[(exp, damping_id)] + cholesky = self._cholesky_by_damping[damping_id] + assert cholesky.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky, + is_non_singular=True, + is_square=True) + + def get_cholesky_inverse(self, damping_func): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). + damping_id = graph_func_to_id(damping_func) + cholesky_inv = self._cholesky_inverse_by_damping[damping_id] + assert cholesky_inv.shape.ndims == 2 + return lo.LinearOperatorFullMatrix(cholesky_inv, + is_non_singular=True, + is_square=True) def get_eigendecomp(self): """Creates or retrieves eigendecomposition of self._cov.""" # Unlike get_matpower this doesn't retrieve a stored variable, but instead # always computes a fresh version from the current value of get_cov(). if not self._eigendecomp: - eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov()) # The matrix self._cov is positive semidefinite by construction, but the # numerical eigenvalues could be negative due to numerical errors, so here @@ -575,45 +767,8 @@ class InverseProvidingFactor(FisherFactor): return self._eigendecomp - def get_cov(self): - # Variable contains full covariance matrix. 
- return self.get_cov_var() - def left_multiply_matpower(self, x, exp, damping_func): - if isinstance(x, tf_ops.IndexedSlices): - raise ValueError("Left-multiply not yet supported for IndexedSlices.") - - if x.shape.ndims != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) - - if exp == 1: - return math_ops.matmul(self.get_cov(), x) + damping_func() * x - - return math_ops.matmul(self.get_matpower(exp, damping_func), x) - - def right_multiply_matpower(self, x, exp, damping_func): - if isinstance(x, tf_ops.IndexedSlices): - if exp == 1: - n = self.get_cov().shape[0] - damped_cov = self.get_cov() + damping_func() * array_ops.eye(n) - return utils.matmul_sparse_dense(x, damped_cov) - - return utils.matmul_sparse_dense(x, self.get_matpower(exp, damping_func)) - - if x.shape.ndims != 2: - raise ValueError( - "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." - % (x,)) - - if exp == 1: - return math_ops.matmul(x, self.get_cov()) + damping_func() * x - - return math_ops.matmul(x, self.get_matpower(exp, damping_func)) - - -class FullFactor(InverseProvidingFactor): +class FullFactor(DenseSquareMatrixFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. Note that this uses the naive "square the sum estimator", and so is applicable @@ -672,42 +827,52 @@ class DiagonalFactor(FisherFactor): """ def __init__(self): - self._damping_funcs_by_id = {} # { hashable: lambda } super(DiagonalFactor, self).__init__() + def get_cov_as_linear_operator(self): + assert self._matrix_diagonal.shape.ndims == 1 + return lo.LinearOperatorDiag(self._matrix_diagonal, + is_self_adjoint=True, + is_square=True) + @property def _cov_initializer(self): return diagonal_covariance_initializer + @property + def _matrix_diagonal(self): + return array_ops.reshape(self.get_cov(), [-1]) + def make_inverse_update_ops(self): return [] def instantiate_inv_variables(self): pass - def get_cov(self): - # self.get_cov() could be any shape, but it must have one entry per - # parameter. Flatten it into a vector. - cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1]) - return array_ops.diag(cov_diag_vec) - - def left_multiply_matpower(self, x, exp, damping_func): - matpower = (self.get_cov_var() + damping_func())**exp - - if isinstance(x, tf_ops.IndexedSlices): - return utils.matmul_diag_sparse(array_ops.reshape(matpower, [-1]), x) - - if x.shape != matpower.shape: - raise ValueError("x (%s) and cov (%s) must have same shape." % - (x, matpower)) - return matpower * x - - def right_multiply_matpower(self, x, exp, damping_func): - raise NotImplementedError("Only left-multiply is currently supported.") - def register_matpower(self, exp, damping_func): pass + def register_cholesky(self, damping_func): + pass + + def register_cholesky_inverse(self, damping_func): + pass + + def get_matpower(self, exp, damping_func): + matpower_diagonal = (self._matrix_diagonal + + math_ops.cast(damping_func(), self._dtype))**exp + return lo.LinearOperatorDiag(matpower_diagonal, + is_non_singular=True, + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + def get_cholesky(self, damping_func): + return self.get_matpower(0.5, damping_func) + + def get_cholesky_inverse(self, damping_func): + return self.get_matpower(-0.5, damping_func) + class NaiveDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approximation of any type of param's Fisher. 
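For the diagonal factors no registration or stored inverse is needed because every requested matrix function is an elementwise power of the stored diagonal. A minimal NumPy sketch of what get_matpower, get_cholesky and get_cholesky_inverse reduce to (the values of `diag` and `damping` here are made up for illustration):

```python
import numpy as np

diag = np.array([4.0, 9.0, 16.0])  # hypothetical covariance diagonal
damping = 1.0

def diag_matpower(exp):
  # (C + damping * I)**exp for diagonal C is just an elementwise power,
  # so it can be formed on demand instead of being precomputed and stored.
  return (diag + damping) ** exp

chol = diag_matpower(0.5)        # what get_cholesky wraps in a diag operator
chol_inv = diag_matpower(-0.5)   # what get_cholesky_inverse wraps

damped = np.diag(diag + damping)
assert np.allclose(np.diag(chol) @ np.diag(chol).T, damped)  # L L^T == C + damping*I
assert np.allclose(chol * chol_inv, np.ones_like(diag))      # inverse is the reciprocal
```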
@@ -1082,7 +1247,7 @@ class ConvDiagonalFactor(DiagonalFactor): return self._inputs[tower].device -class FullyConnectedKroneckerFactor(InverseProvidingFactor): +class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor): """Kronecker factor for the input or output side of a fully-connected layer. """ @@ -1135,7 +1300,7 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): return self._tensors[0][tower].device -class ConvInputKroneckerFactor(InverseProvidingFactor): +class ConvInputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the input side of a convolutional layer. Estimates E[ a a^T ] where a is the inputs to a convolutional layer given @@ -1153,7 +1318,9 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): dilation_rate=None, data_format=None, extract_patches_fn=None, - has_bias=False): + has_bias=False, + sub_sample_inputs=None, + sub_sample_patches=None): """Initializes ConvInputKroneckerFactor. Args: @@ -1173,6 +1340,10 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): patches. One of "extract_convolution_patches", "extract_image_patches", "extract_pointwise_conv2d_patches". has_bias: bool. If True, append 1 to in_channel. + sub_sample_inputs: `bool`. If True, then subsample the inputs from which + the image patches are extracted. (Default: None) + sub_sample_patches: `bool`, If `True` then subsample the extracted + patches.(Default: None) """ self._inputs = inputs self._filter_shape = filter_shape @@ -1182,7 +1353,15 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): self._data_format = data_format self._extract_patches_fn = extract_patches_fn self._has_bias = has_bias + if sub_sample_inputs is None: + self._sub_sample_inputs = _SUB_SAMPLE_INPUTS + else: + self._sub_sample_inputs = sub_sample_inputs + if sub_sample_patches is None: + self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS + else: + self._sub_sample_patches = sub_sample_patches super(ConvInputKroneckerFactor, self).__init__() @property @@ -1215,6 +1394,10 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): assert source == 0 inputs = self._inputs[tower] + if self._sub_sample_inputs: + batch_size = inputs.shape.as_list()[0] + max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR) + inputs = _random_tensor_gather(inputs, max_size) # TODO(b/64144716): there is potential here for a big savings in terms of # memory use. @@ -1260,8 +1443,12 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): # |Delta| = number of spatial offsets, and J = number of input maps # for convolutional layer l. patches_flat = array_ops.reshape(patches, [-1, flatten_size]) + # We append a homogenous coordinate to patches_flat if the layer has # bias parameters. This gives us [[A_l]]_H from the paper. + if self._sub_sample_patches: + patches_flat = _subsample_for_cov_computation(patches_flat) + if self._has_bias: patches_flat = append_homog(patches_flat) # We call compute_cov without passing in a normalizer. compute_cov uses @@ -1277,7 +1464,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): return self._inputs[tower].device -class ConvOutputKroneckerFactor(InverseProvidingFactor): +class ConvOutputKroneckerFactor(DenseSquareMatrixFactor): r"""Kronecker factor for the output side of a convolutional layer. 
Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer @@ -1567,6 +1754,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): psi_var) in self._option1quants_by_damping.items(): damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) invsqrtC0 = math_ops.matmul( eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) @@ -1595,6 +1783,7 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): mu_var) in self._option2quants_by_damping.items(): damping = self._damping_funcs_by_id[damping_id]() + damping = math_ops.cast(damping, self._dtype) # compute C0^(-1/2) invsqrtC0 = math_ops.matmul( diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 366e2a82d56..cbbfe7212c9 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -182,7 +182,7 @@ class LayerCollection(object): self._graph = graph or ops.get_default_graph() self._loss_dict = {} # {str: LossFunction} self._subgraph = None - self._default_generic_approximation = APPROX_FULL_NAME + self._default_generic_approximation = APPROX_DIAGONAL_NAME self._default_embedding_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME self._default_conv2d_approximation = APPROX_KRONECKER_NAME diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py new file mode 100644 index 00000000000..61cb955ae85 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/linear_operator.py @@ -0,0 +1,95 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""SmartMatrices definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.linalg import linalg +from tensorflow.python.ops.linalg import linalg_impl +from tensorflow.python.ops.linalg import linear_operator_util as lou + + +class LinearOperatorExtras(object): # pylint: disable=missing-docstring + + def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + if isinstance(x, ops.IndexedSlices): + return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -2 if adjoint else -1 + arg_dim = -1 if adjoint_arg else -2 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"): + + with self._name_scope(name, values=[x]): + + if isinstance(x, ops.IndexedSlices): + return self._matmul_right_sparse( + x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + x = ops.convert_to_tensor(x, name="x") + self._check_input_dtype(x) + + self_dim = -1 if adjoint else -2 + arg_dim = -2 if adjoint_arg else -1 + self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim]) + + return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg) + + +class LinearOperatorFullMatrix(LinearOperatorExtras, + linalg.LinearOperatorFullMatrix): + + # TODO(b/78117889) Remove this definition once core LinearOperator + # has _matmul_right. 
+ def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + return lou.matmul_with_broadcast( + x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint) + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + assert not adjoint and not adjoint_arg + return utils.matmul_sparse_dense(x, self._matrix) + + +class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring + linalg.LinearOperatorDiag): + + def _matmul_right(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + x = linalg_impl.adjoint(x) if adjoint_arg else x + return diag_mat * x + + def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False): + diag_mat = math_ops.conj(self._diag) if adjoint else self._diag + assert not adjoint_arg + return utils.matmul_diag_sparse(diag_mat, x) + + def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False): + raise NotImplementedError diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index e7d4243fc3d..42d525c2c21 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -613,19 +613,19 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, def multiply_fisher(self, vector): probs = self._probs return vector * probs - probs * math_ops.reduce_sum( - vector * probs, axis=-1, keep_dims=True) + vector * probs, axis=-1, keepdims=True) def multiply_fisher_factor(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - probs * math_ops.reduce_sum( - sqrt_probs * vector, axis=-1, keep_dims=True) + sqrt_probs * vector, axis=-1, keepdims=True) def multiply_fisher_factor_transpose(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum( - probs * vector, axis=-1, keep_dims=True) + probs * vector, axis=-1, keepdims=True) def multiply_fisher_factor_replicated_one_hot(self, index): assert len(index) == 1, "Length of index was {}".format(len(index)) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py index 705a871d482..4279cb27928 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py @@ -33,7 +33,6 @@ _allowed_symbols = [ "CategoricalLogitsNegativeLogProbLoss", "OnehotCategoricalLogitsNegativeLogProbLoss", "MultiBernoulliNegativeLogProbLoss", - "MultiBernoulliNegativeLogProbLoss", "insert_slice_in_zeros", ] diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index f01c5a83221..b7f63d8d94a 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import warnings # pylint disable=long-line from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp from tensorflow.contrib.kfac.python.ops import estimator as est @@ -67,7 +66,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): the local approximation with the Fisher information matrix, and to regularize the update direction by making it closer to the 
gradient. If damping is adapted during training then this value is used for - initializing damping varaible. + initializing damping variable. (Higher damping means the update looks more like a standard gradient update - see Tikhonov regularization.) layer_collection: The layer collection object, which holds the fisher @@ -115,7 +114,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._estimation_mode = estimation_mode self._colocate_gradients_with_ops = colocate_gradients_with_ops - # The below paramaters are required only if damping needs to be adapated. + # The below parameters are required only if damping needs to be adapted. # These parameters can be set by calling # set_damping_adaptation_params() explicitly. self._damping_adaptation_decay = 0.95 @@ -196,7 +195,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): min_damping: `float`(Optional), Minimum value the damping parameter can take. Default value 1e-5. damping_adaptation_decay: `float`(Optional), The `damping` parameter is - multipled by the `damping_adaptation_decay` every + multiplied by the `damping_adaptation_decay` every `damping_adaptation_interval` number of iterations. Default value 0.99. damping_adaptation_interval: `int`(Optional), Number of steps in between updating the `damping` parameter. Default value 5. @@ -243,62 +242,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): def damping_adaptation_interval(self): return self._damping_adaptation_interval - @property - def cov_update_thunks(self): - self._maybe_make_and_save_everything() - return self._cov_update_thunks - - @property - def cov_update_ops(self): - self._maybe_make_and_save_everything() - return self._cov_update_ops - - @property - def cov_update_op(self): - self._maybe_make_and_save_everything() - return self._cov_update_op - - @property - def inv_update_thunks(self): - self._maybe_make_and_save_everything() - return self._inv_update_thunks - - @property - def inv_update_ops(self): - self._maybe_make_and_save_everything() - return self._inv_update_ops - - @property - def inv_update_op(self): - self._maybe_make_and_save_everything() - return self._inv_update_op - - def _maybe_make_and_save_everything(self): - if not self._fisher_est.made_vars(): - warnings.warn("These convenience properties will be depcrecated soon. " - "Please use explicit op/thunk creation methods instead " - "(e.g. make_ops_and_vars, etc).", - DeprecationWarning) - (self._cov_update_ops, self._cov_update_op, self._inv_update_ops, - self._inv_update_op, self._cov_update_thunks, - self._inv_update_thunks) = self.make_ops_and_vars() - - def make_ops_and_vars(self): - """Make ops and vars with device placement `self._placement_strategy`. - - See `FisherEstimator.make_ops_and_vars` for details. - - Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_op: inv_update_ops grouped into a single op. - """ - return self._fisher_est.make_ops_and_vars(scope=self.get_name()) - def make_vars_and_create_op_thunks(self): """Make vars and create op thunks. 
@@ -385,7 +328,6 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): Returns: An `Operation` that applies the specified gradients. """ - self._maybe_make_and_save_everything() # In Python 3, grads_and_vars can be a zip() object which can only be # iterated over once. By converting it to a list, we ensure that it can be # iterated over more than once. diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py index bf12dbaa9ad..c4454325aeb 100644 --- a/tensorflow/contrib/kfac/python/ops/placement.py +++ b/tensorflow/contrib/kfac/python/ops/placement.py @@ -21,8 +21,6 @@ from __future__ import print_function import itertools from tensorflow.python.framework import ops as tf_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import variable_scope def _make_thunk_on_device(func, device): @@ -35,7 +33,7 @@ def _make_thunk_on_device(func, device): class RoundRobinPlacementMixin(object): """Implements round robin placement strategy for ops and variables.""" - def __init__(self, cov_devices=None, inv_devices=None, *args, **kwargs): + def __init__(self, cov_devices=None, inv_devices=None, **kwargs): """Initializes the RoundRobinPlacementMixin class. Args: @@ -45,66 +43,15 @@ class RoundRobinPlacementMixin(object): inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion computations will be placed on these devices in a round-robin fashion. Can be None, which means that no devices are specified. - *args: - **kwargs: + **kwargs: Keyword arguments passed on to the superclass constructor. """ - super(RoundRobinPlacementMixin, self).__init__(*args, **kwargs) + super(RoundRobinPlacementMixin, self).__init__(**kwargs) self._cov_devices = cov_devices self._inv_devices = inv_devices - def make_ops_and_vars(self, scope=None): - """Make ops and vars with a round-robin device placement strategy. - - For each factor, all of that factor's cov variables and their associated - update ops will be placed on a particular device. A new device is chosen - for each factor by cycling through list of devices in the - `self._cov_devices` attribute. If `self._cov_devices` is `None` then no - explicit device placement occurs. - - An analogous strategy is followed for inverse update ops, with the list of - devices being given by the `self._inv_devices` attribute. - - Inverse variables on the other hand are not placed on any specific device - (they will just use the current the device placement context, whatever - that happens to be). The idea is that the inverse variable belong where - they will be accessed most often, which is the device that actually applies - the preconditioner to the gradient. The user will be responsible for setting - the device context for this. - - Args: - scope: A string or None. If None it will be set to the name of this - estimator (given by the name property). All variables will be created, - and all ops will execute, inside of a variable scope of the given - name. (Default: None) - - Returns: - cov_update_ops: List of ops that compute the cov updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - cov_update_op: cov_update_ops grouped into a single op. - inv_update_ops: List of ops that compute the inv updates. Corresponds - one-to-one with the list of factors given by the "factors" property. - inv_update_op: inv_update_ops grouped into a single op. - cov_update_thunks: Thunks that make the ops in cov_update_ops. - inv_update_thunks: Thunks that make the ops in inv_update_ops. 
- """ - (cov_update_thunks, - inv_update_thunks) = self.make_vars_and_create_op_thunks(scope=scope) - cov_update_ops = [thunk() for thunk in cov_update_thunks] - inv_update_ops = [thunk() for thunk in inv_update_thunks] - - scope = self.name if scope is None else scope - with variable_scope.variable_scope(scope): - cov_update_op = control_flow_ops.group(cov_update_ops, - name="cov_update_op") - inv_update_op = control_flow_ops.group(inv_update_ops, - name="inv_update_op") - - return (cov_update_ops, cov_update_op, inv_update_ops, inv_update_op, - cov_update_thunks, inv_update_thunks) - def make_vars_and_create_op_thunks(self, scope=None): - """Make vars and create op thunks w/ a round-robin device placement strat. + """Make vars and create op thunks w/ a round-robin device placement start. For each factor, all of that factor's cov variables and their associated update ops will be placed on a particular device. A new device is chosen diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index b6f42815e79..144295f4c7e 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -235,6 +235,13 @@ posdef_eig_functions = { } +def cholesky(tensor, damping): + """Computes the inverse of tensor + damping * identity.""" + identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) + damping = math_ops.cast(damping, dtype=tensor.dtype) + return linalg_ops.cholesky(tensor + damping * identity) + + class SubGraph(object): """Defines a subgraph given by all the dependencies of a given set of outputs. """ @@ -553,13 +560,17 @@ def is_data_format_channel_last(data_format): return data_format.endswith("C") -def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name +def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name """Computes matmul(A, B) where A is sparse, B is dense. Args: A: tf.IndexedSlices with dense shape [m, n]. B: tf.Tensor with shape [n, k]. name: str. Name of op. + transpose_a: Bool. If true we transpose A before multiplying it by B. + (Default: False) + transpose_b: Bool. If true we transpose B before multiplying it by A. + (Default: False) Returns: tf.IndexedSlices resulting from matmul(A, B). @@ -573,7 +584,8 @@ def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name raise ValueError("A must represent a matrix. Found: %s." 
% A) if B.shape.ndims != 2: raise ValueError("B must be a matrix.") - new_values = math_ops.matmul(A.values, B) + new_values = math_ops.matmul( + A.values, B, transpose_a=transpose_a, transpose_b=transpose_b) return ops.IndexedSlices( new_values, A.indices, diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py index 0727f4cf887..39e9d65407f 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops_test.py @@ -660,7 +660,7 @@ class ReduceSumTest(Base): sum_lt = ops.reduce_sum(self.original_lt, {('channel', 'hihowareyou')}) golden_lt = core.LabeledTensor( math_ops.reduce_sum( - self.original_lt.tensor, 1, keep_dims=True), + self.original_lt.tensor, 1, keepdims=True), [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3]) self.assertLabeledTensorsEqual(sum_lt, golden_lt) @@ -668,7 +668,7 @@ class ReduceSumTest(Base): sum_lt = ops.reduce_sum(self.original_lt, ('channel', 'hihowareyou')) golden_lt = core.LabeledTensor( math_ops.reduce_sum( - self.original_lt.tensor, 1, keep_dims=True), + self.original_lt.tensor, 1, keepdims=True), [self.a0, ('channel', ['hihowareyou']), self.a2, self.a3]) self.assertLabeledTensorsEqual(sum_lt, golden_lt) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index d5b3b279a1b..7355a403aee 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -381,7 +381,7 @@ py_test( py_test( name = "rev_block_lib_test", - size = "small", + size = "medium", srcs = ["python/layers/rev_block_lib_test.py"], srcs_version = "PY2AND3", deps = [ diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py index bf251449820..dd2395f8c97 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import init_ops @@ -691,11 +692,12 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): index += num_val return grouped_vals + @test_util.enable_c_shapes def testEmbeddingLookupSparse(self): vocab_size = 13 batch_size = 10 param_shape = [2, 5] - expected_lookup_result_shape = [None] + param_shape + expected_lookup_result_shape = param_shape sp_ids, sp_weights, ids, weights, vals_per_batch_entry = ( self._RandomIdsAndWeights(batch_size, vocab_size)) @@ -719,7 +721,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase): None if ignore_weights else sp_weights, combiner=combiner) - self.assertEqual(embedding_sum.get_shape().as_list(), + self.assertEqual(embedding_sum.get_shape().as_list()[1:], expected_lookup_result_shape) tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 151fc7a0d73..2f3e57653c5 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1536,6 +1536,7 @@ def 
convolution3d_transpose( @add_arg_scope def dense_to_sparse(tensor, eos_token=0, outputs_collections=None, scope=None): """Converts a dense tensor into a sparse tensor. + An example use would be to convert dense labels to sparse ones so that they can be fed to the ctc_loss. @@ -2323,11 +2324,16 @@ def images_to_sequence(inputs, outputs_collections=None, scope=None): """Convert a batch of images into a batch of sequences. + Args: inputs: a (num_images, height, width, depth) tensor data_format: A string. `NHWC` (default) and `NCHW` are supported. outputs_collections: The collections to which the outputs are added. scope: Optional scope for name_scope. + + Raises: + ValueError: If `data_format` is not either NCHW or NHWC. + Returns: (width, num_images*height, depth) sequence tensor """ @@ -2833,6 +2839,7 @@ def sequence_to_images(inputs, outputs_collections=None, scope=None): """Convert a batch of sequences into a batch of images. + Args: inputs: (num_steps, num_batches, depth) sequence tensor height: the height of the images @@ -2840,6 +2847,7 @@ def sequence_to_images(inputs, Currently supports `'channels_first'` and `'channels_last'`. outputs_collections: The collections to which the outputs are added. scope: Optional scope for name_scope. + Returns: A tensor representing the output of the operation. """ @@ -2849,7 +2857,7 @@ def sequence_to_images(inputs, if num_batches is None: num_batches = -1 else: - num_batches = num_batches // height + num_batches //= height reshaped = array_ops.reshape(inputs, [width, num_batches, height, depth]) if output_data_format == 'channels_first': diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index b01fd5d5c95..56e9194cebb 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1333,7 +1333,7 @@ class DropoutTest(test.TestCase): with self.test_session(): images = np.random.uniform(size=(5, height, width, 3)) output = _layers.dropout(images) - self.assertEqual(output.op.name, 'Dropout/dropout/mul') + self.assertEqual(output.op.name, 'Dropout/dropout_1/mul') output.get_shape().assert_is_compatible_with( ops.convert_to_tensor(images).get_shape()) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 02d294c68f1..8ed9f446bcd 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -33,13 +33,14 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.framework.python import ops as contrib_framework_ops +from tensorflow.python.eager import backprop from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops as framework_ops from tensorflow.python.layers import base from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util +from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope @@ -50,6 +51,13 @@ __all__ = ["rev_block", "RevBlock", "recompute_grad"] LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*") _USE_DEFAULT = "__rev_block_lib_default" +_WRONG_VARS_ERR = """\ +The variables used on recompute 
were different than the variables originally +used. The function wrapped with @recompute_grad likely creates its own variable +scope with a default name and has been called twice in the same enclosing scope. +To fix, ensure each call to the function happens in its own unique variable +scope. +""" def _acc_grads(*lists_of_grads): @@ -146,7 +154,7 @@ def _scope_wrap(fn, scope): @functools.wraps(fn) def wrap(*args, **kwargs): - with variable_scope.variable_scope(scope): + with variable_scope.variable_scope(scope, use_resource=True): return fn(*args, **kwargs) return wrap @@ -221,95 +229,95 @@ class RevBlock(base.Layer): "build.") self.built = True - def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): - """Custom gradient fn for a block of reversible residual layers.""" - # Inputs have passed through an Identity. Recover the original Tensors to - # be able to match up side inputs. - assert [u"Identity"] == list(set([x.op.type for x in inputs])) - inputs = [x.op.inputs[0] for x in inputs] - side_inputs = inputs[2:] - del inputs + def _make_efficient_grad_fn(self, inputs_, ys_): + def _efficient_grad_fn(*grad_ys, **kwargs): + """Custom gradient fn for a block of reversible residual layers.""" + inputs = inputs_ + ys = ys_ + variables = kwargs["variables"] + side_inputs = inputs[2:] - f_side_idxs = [None] * len(self.f_side_input) - g_side_idxs = [None] * len(self.g_side_input) - assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) + f_side_idxs = [None] * len(self.f_side_input) + g_side_idxs = [None] * len(self.g_side_input) + assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) - for i, t in enumerate(side_inputs): - if t in self.f_side_input: - f_side_idxs[self.f_side_input.index(t)] = i - elif t in self.g_side_input: - g_side_idxs[self.g_side_input.index(t)] = i - else: - assert False + for i, t in enumerate(side_inputs): + if t in self.f_side_input: + f_side_idxs[self.f_side_input.index(t)] = i + elif t in self.g_side_input: + g_side_idxs[self.g_side_input.index(t)] = i + else: + assert False - f_vars = [[] for _ in range(self.num_layers)] - g_vars = [[] for _ in range(self.num_layers)] - f_vars_idxs = [[] for _ in range(self.num_layers)] - g_vars_idxs = [[] for _ in range(self.num_layers)] + f_vars = [[] for _ in range(self.num_layers)] + g_vars = [[] for _ in range(self.num_layers)] + f_vars_idxs = [[] for _ in range(self.num_layers)] + g_vars_idxs = [[] for _ in range(self.num_layers)] - for i, ref in enumerate(variables): - # Use the name to identify the layer number and function (f or g) - regex = LAYER_RE.match(ref.name) - layer_no = int(regex.group(1)) - fn_name = regex.group(2) - if fn_name == "f": - f_vars[layer_no].append(ref) - f_vars_idxs[layer_no].append(i) - else: - assert fn_name == "g" - g_vars[layer_no].append(ref) - g_vars_idxs[layer_no].append(i) + for i, ref in enumerate(variables): + # Use the name to identify the layer number and function (f or g) + regex = LAYER_RE.match(ref.name) + layer_no = int(regex.group(1)) + fn_name = regex.group(2) + if fn_name == "f": + f_vars[layer_no].append(ref) + f_vars_idxs[layer_no].append(i) + else: + assert fn_name == "g" + g_vars[layer_no].append(ref) + g_vars_idxs[layer_no].append(i) - f_var_grads = [] - g_var_grads = [] - f_side_grads = [] - g_side_grads = [] + f_var_grads = [] + g_var_grads = [] + f_side_grads = [] + g_side_grads = [] - # Reverse variable containers to go backward - f_vars.reverse() - g_vars.reverse() - f = list(self.f) - g = list(self.g) - f.reverse() - 
g.reverse() + # Reverse variable containers to go backward + f_vars.reverse() + g_vars.reverse() + f = list(self.f) + g = list(self.g) + f.reverse() + g.reverse() - with variable_scope.variable_scope(self.scope_name, reuse=True): - for i in xrange(self.num_layers): - ys, grad_ys, f_ret, g_ret = _rev_layer_backward( - ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], - self.g_side_input) + with variable_scope.variable_scope(self.scope_name, reuse=True): + for i in xrange(self.num_layers): + ys, grad_ys, f_ret, g_ret = _rev_layer_backward( + ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], + self.g_side_input) - grad_f_vars, grad_f_side = f_ret - grad_g_vars, grad_g_side = g_ret - f_var_grads.append(grad_f_vars) - g_var_grads.append(grad_g_vars) - f_side_grads.append(grad_f_side) - g_side_grads.append(grad_g_side) + grad_f_vars, grad_f_side = f_ret + grad_g_vars, grad_g_side = g_ret + f_var_grads.append(grad_f_vars) + g_var_grads.append(grad_g_vars) + f_side_grads.append(grad_f_side) + g_side_grads.append(grad_g_side) - # Accumulate layer gradients for f_side_input and g_side_input - acc_f_side_grads = _acc_grads(*f_side_grads) - acc_g_side_grads = _acc_grads(*g_side_grads) + # Accumulate layer gradients for f_side_input and g_side_input + acc_f_side_grads = _acc_grads(*f_side_grads) + acc_g_side_grads = _acc_grads(*g_side_grads) - # Use the stored idxs to put gradients in the passed-in order. - side_input_grads = [None] * len(side_inputs) - variable_grads = [None] * len(variables) + # Use the stored idxs to put gradients in the passed-in order. + side_input_grads = [None] * len(side_inputs) + variable_grads = [None] * len(variables) - # Variable gradients were collected in reverse layer order. Reverse to match - # idxs. - f_var_grads.reverse() - g_var_grads.reverse() - for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list( - zip(g_vars_idxs, g_var_grads)): - for i, grad in zip(idxs, grads): - variable_grads[i] = grad + # Variable gradients were collected in reverse layer order. Reverse to + # match idxs. 
+ f_var_grads.reverse() + g_var_grads.reverse() + for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list( + zip(g_vars_idxs, g_var_grads)): + for i, grad in zip(idxs, grads): + variable_grads[i] = grad - for i, grad in zip(f_side_idxs, acc_f_side_grads): - side_input_grads[i] = grad - for i, grad in zip(g_side_idxs, acc_g_side_grads): - side_input_grads[i] = grad + for i, grad in zip(f_side_idxs, acc_f_side_grads): + side_input_grads[i] = grad + for i, grad in zip(g_side_idxs, acc_g_side_grads): + side_input_grads[i] = grad - grad_x1, grad_x2 = grad_ys - return [grad_x1, grad_x2] + side_input_grads, variable_grads + grad_x1, grad_x2 = grad_ys + return [grad_x1, grad_x2] + side_input_grads, variable_grads + return _efficient_grad_fn def _forward(self, x1, x2): """Run forward through the reversible layers.""" @@ -317,10 +325,6 @@ class RevBlock(base.Layer): side_inputs = [self.f_side_input, self.g_side_input] flat_side_inputs = nest.flatten(side_inputs) - custom_grad_fn = ( - self._efficient_grad_fn if self._use_efficient_backprop else None) - - @_fn_with_custom_grad(custom_grad_fn) def _forward_wrap(x1_, x2_, *flat_side_inputs): f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs) return _rev_block_forward( @@ -333,7 +337,16 @@ class RevBlock(base.Layer): g_side_input=g_side, gate_outputs=self._use_efficient_backprop) - return _forward_wrap(x1, x2, *flat_side_inputs) + @custom_gradient.custom_gradient + def _forward_with_custom_grad(*args): + out = _forward_wrap(*args) # pylint: disable=no-value-for-parameter + grad_fn = self._make_efficient_grad_fn(args, out) + return out, grad_fn + + if self._use_efficient_backprop: + return _forward_with_custom_grad(x1, x2, *flat_side_inputs) + else: + return _forward_wrap(x1, x2, *flat_side_inputs) def _backward(self, y1, y2): """Run backward through the reversible layers.""" @@ -432,6 +445,10 @@ def enable_with_args(dec): def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """Decorator that recomputes the function on the backwards pass. + To use this function, you must use `ResourceVariable`s (i.e. + `variable_scope(name, use_resource=True)`), which are the default in Eager mode + and when running on TPU. + Args: fn: a function that takes Tensors (all as positional arguments) and returns a tuple of Tensors. 
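The requirement spelled out in this new docstring matches the usage pattern exercised by the updated tests later in this diff: decorate the function once, then give each call site its own `use_resource=True` variable scope so the recomputation sees exactly the variables created on the forward pass. A hedged sketch of that pattern (scope names and layer sizes are illustrative, adapted from the tests in this patch, not a canonical recipe):

```python
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import rev_block_lib

@rev_block_lib.recompute_grad
def layer_with_recompute(inputs):
  # Activations of this layer are recomputed on the backward pass instead of
  # being stored, trading extra compute for memory.
  return tf.layers.dense(inputs, 2)

with tf.variable_scope("block", use_resource=True):  # ResourceVariables required
  x = tf.ones((2, 4), tf.float32)
  # Each call gets its own scope; two calls in the same enclosing scope would
  # trigger the new _WRONG_VARS_ERR check during gradient construction.
  with tf.variable_scope("layer1", use_resource=True):
    out1 = layer_with_recompute(x)
  with tf.variable_scope("layer2", use_resource=True):
    out2 = layer_with_recompute(x) + out1
  loss = tf.reduce_sum(out2)

grads = tf.gradients(loss, [x] + tf.trainable_variables())
```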
@@ -472,44 +489,55 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): if use_data_dep_ == _USE_DEFAULT: use_data_dep_ = _is_on_tpu() - cached_vs = [] - cached_arg_scope = [] - - def grad_fn(inputs, variables, outputs, output_grads): - """Recompute outputs for gradient computation.""" - del outputs - # Recompute outputs - with framework_ops.control_dependencies(output_grads): - if use_data_dep_: - inputs = _force_data_dependency(output_grads, inputs) - with contrib_framework_ops.arg_scope(cached_arg_scope[0]): - with variable_scope.variable_scope(cached_vs[0], reuse=True): - outputs = fn(*inputs) - - if not (isinstance(outputs, list) or isinstance(outputs, tuple)): - outputs = [outputs] - outputs = list(outputs) - grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) - - if tupleize_grads: - if use_data_dep_: - grads = _tuple_with_data_dep(grads) - else: - grads = control_flow_ops.tuple(grads) - - grad_inputs = grads[:len(inputs)] - grad_vars = grads[len(inputs):] - return grad_inputs, grad_vars - - @_fn_with_custom_grad(grad_fn) + @custom_gradient.custom_gradient def fn_with_recompute(*args): - cached_vs.append(variable_scope.get_variable_scope()) - # TODO(rsepassi): Rm conditional in TF 1.4 - if hasattr(contrib_framework_ops, "current_arg_scope"): - cached_arg_scope.append(contrib_framework_ops.current_arg_scope()) - else: - cached_arg_scope.append({}) - return fn(*args) + """Wrapper for fn.""" + # Forward pass + vs = variable_scope.get_variable_scope() + arg_scope = contrib_framework_ops.current_arg_scope() + with backprop.GradientTape() as tape: + outputs = fn(*args) + original_vars = set(tape.watched_variables()) + + # Backward pass + def grad_fn(*output_grads, **kwargs): + """Recompute outputs for gradient computation.""" + variables = [] + if original_vars: + variables = kwargs["variables"] + if set(variables) != original_vars: + raise ValueError(_WRONG_VARS_ERR) + del kwargs + inputs = list(args) + # Recompute outputs + with framework_ops.control_dependencies(output_grads): + if use_data_dep_: + inputs = _force_data_dependency(output_grads, inputs) + with contrib_framework_ops.arg_scope(arg_scope): + with variable_scope.variable_scope(vs, reuse=True): + with backprop.GradientTape() as tape: + outputs = fn(*inputs) + recompute_vars = set(tape.watched_variables()) + if original_vars != recompute_vars: + raise ValueError(_WRONG_VARS_ERR) + + if not (isinstance(outputs, list) or isinstance(outputs, tuple)): + outputs = [outputs] + outputs = list(outputs) + grads = gradients_impl.gradients(outputs, inputs + variables, + output_grads) + + if tupleize_grads: + if use_data_dep_: + grads = _tuple_with_data_dep(grads) + else: + grads = control_flow_ops.tuple(grads) + + grad_inputs = grads[:len(inputs)] + grad_vars = grads[len(inputs):] + return grad_inputs, grad_vars + + return outputs, grad_fn return fn_with_recompute(*args) @@ -536,107 +564,6 @@ def _underlying_variable_ref(t): return None -def _fn_with_custom_grad(grad_fn, use_global_vars=False): - """Decorator to create a subgraph with a custom gradient function. - - The subgraph created by the decorated function is NOT put in a Defun and so - does not suffer from the limitations of the Defun (all subgraph ops on the - same device, no summaries). - - Args: - grad_fn: function with signature - (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars), - all of which are lists of Tensors. - use_global_vars: if True, variables will be the global variables created. 
- If False, will be the trainable variables. - - Returns: - Decorator for function such that the gradient is defined by grad_fn. - """ - - def dec(fn): - - @functools.wraps(fn) - def wrapped(*args): - return _fn_with_custom_grad_internal( - fn, args, grad_fn, use_global_vars=use_global_vars) - - return wrapped - - return dec - - -def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): - """Create a subgraph with a custom gradient. - - Args: - fn: function that takes inputs as arguments and produces 1 or more Tensors. - inputs: list, will be passed as fn(*inputs). - grad_fn: function with signature - (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars), - all of which are lists of Tensors. - use_global_vars: if True, variables will be the global variables created. - If False, will be the trainable variables. - - Returns: - fn(*inputs) - """ - vs = variable_scope.get_variable_scope() - get_vars_fn = ( - vs.global_variables if use_global_vars else vs.trainable_variables) - len_before_vars = len(get_vars_fn()) - inputs = [array_ops.identity(x) for x in inputs] - outputs = fn(*inputs) - train_vars = get_vars_fn()[len_before_vars:] - - if grad_fn is None: - return outputs - - if not (isinstance(outputs, tuple) or isinstance(outputs, list)): - outputs = [outputs] - outputs = list(outputs) - - defun_inputs = [inputs, train_vars, outputs] - - def custom_grad_fn(op, *dys): - """Custom grad fn applying grad_fn for identity Defun.""" - fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as( - defun_inputs, list(op.inputs)) - fn_vars = [_underlying_variable_ref(v) for v in fn_vars] - dys = list(dys) - assert len(fn_outputs) == len(outputs) - assert len(fn_outputs) == len(dys) - - grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys) - grad_outputs = [None] * len(fn_outputs) - return tuple(grad_inputs + grad_vars + grad_outputs) - - # The Defun takes as input the original inputs, the trainable variables - # created in fn, and the outputs. In the forward it passes through the - # outputs. In the backwards, it produces gradients for the original inputs - # and the trainable variables. - in_types = [t.dtype for t in inputs] - out_types = [t.dtype for t in outputs] - var_types = [t.dtype for t in train_vars] - - # Get a unique name for the Defun - with framework_ops.name_scope("identity_custom_grad") as ns: - defun_name = ns - - @function.Defun( - *(in_types + var_types + out_types), - func_name=defun_name, - python_grad_func=custom_grad_fn, - shape_func=lambda _: [t.get_shape() for t in outputs]) - def identity(*args): - _, _, outs = nest.pack_sequence_as(defun_inputs, args) - return tuple([array_ops.identity(t) for t in outs]) - - flat_inputs = nest.flatten(defun_inputs) - id_out = identity(*flat_inputs) - return id_out - - def _force_data_dependency(first_compute, then_compute): """Force all of `then_compute` to depend on all of `first_compute`. 
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index 392a490be15..997f53b9e1b 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -60,8 +60,8 @@ class RevBlockTest(test.TestCase): sess.run(variables.global_variables_initializer()) x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv]) - self.assertAllClose(x1, x1_inv) - self.assertAllClose(x2, x2_inv) + self.assertAllClose(x1, x1_inv, atol=1e-5) + self.assertAllClose(x2, x2_inv, atol=1e-5) def testBackwardForward(self): @@ -83,8 +83,8 @@ class RevBlockTest(test.TestCase): sess.run(variables.global_variables_initializer()) y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv]) - self.assertAllClose(y1, y1_inv) - self.assertAllClose(y2, y2_inv) + self.assertAllClose(y1, y1_inv, rtol=1e-5) + self.assertAllClose(y2, y2_inv, rtol=1e-5) def _testRevBlock(self, x=None, @@ -179,18 +179,16 @@ class RevBlockTest(test.TestCase): self._testRevBlock(f=[f1, f2, f1, f2]) - # TODO(rsepassi): Recent change to conv seems to have broken this test. Find - # out why. - def _testConvAndBatchNorm(self): + def testConvAndBatchNorm(self): x = random_ops.random_uniform( [self.BATCH_SIZE, 10, self.CHANNELS], dtype=dtypes.float32) def f(x): x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") - x = layers.batch_norm(x, is_training=True) + x = layers.batch_norm(x, is_training=False) x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") - x = layers.batch_norm(x, is_training=True) + x = layers.batch_norm(x, is_training=False) return x self._testRevBlock(x=x, f=f) @@ -278,7 +276,7 @@ class RecomputeTest(test.TestCase): ] outputs_and_vars = [] for name, wrapped_fn in names_and_fns: - with variable_scope.variable_scope(name) as vs: + with variable_scope.variable_scope(name, use_resource=True) as vs: out = math_ops.reduce_sum(wrapped_fn(x)) outputs_and_vars.append((out, vs.trainable_variables())) @@ -304,103 +302,45 @@ class RecomputeTest(test.TestCase): self.assertAllClose(current, g) current = g - def testResourceVariable(self): - @rev_block_lib.recompute_grad(tupleize_grads=True) + def testDoubleCallInSameScopeFails(self): + + @rev_block_lib.recompute_grad def layer_with_recompute(inputs): - var = variable_scope.get_variable("var", ()) - return var * inputs + return core_layers.dense(inputs, 2) - inputs = array_ops.ones((), dtypes.float32) with variable_scope.variable_scope("layer", use_resource=True): - outputs = layer_with_recompute(inputs) - loss = math_ops.square(outputs) - grads = gradients_impl.gradients(loss, variables.trainable_variables()) - self.assertEqual(1, len(grads)) - self.assertTrue(grads[0] is not None) + inputs = array_ops.ones((2, 4), dtypes.float32) + out1 = layer_with_recompute(inputs) + out2 = layer_with_recompute(inputs) + out1 + out = math_ops.reduce_sum(out2) + tvars = variables.trainable_variables() + assert len(tvars) == 4 + with self.assertRaisesWithPredicateMatch( + ValueError, "called twice in the same enclosing scope"): + gradients_impl.gradients(out, [inputs] + tvars) -class FnWithCustomGradTest(test.TestCase): + def testDoubleCallInUniqueScope(self): - def testCorrectness(self): + @rev_block_lib.recompute_grad + def layer_with_recompute(inputs): + with variable_scope.variable_scope("inner", use_resource=True): + return core_layers.dense(inputs, 2) - w = random_ops.random_uniform([6, 10]) + with 
variable_scope.variable_scope("layer", use_resource=True): + inputs = array_ops.ones((2, 4), dtypes.float32) - def fn(a, b, c): - return core_layers.dense( - a, - 10, - use_bias=False, - kernel_initializer=lambda shape, dtype, partition_info: w - ) + math_ops.matmul(b, c) + with variable_scope.variable_scope("layer1", use_resource=True): + out1 = layer_with_recompute(inputs) + with variable_scope.variable_scope("layer2", use_resource=True): + out2 = layer_with_recompute(inputs) + out1 + out = math_ops.reduce_sum(out2) - def grad_fn(inputs, trainable_variables, outputs, grad_outputs): - outputs = outputs[0] - grad_outputs = grad_outputs[0] - grad_inputs = gradients_impl.gradients( - outputs, inputs, grad_ys=grad_outputs) - grad_vars = gradients_impl.gradients( - outputs, trainable_variables, grad_ys=grad_outputs) - return grad_inputs, grad_vars - - custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn) - - a = random_ops.random_uniform([11, 6]) - b = random_ops.random_uniform([11, 7]) - c = random_ops.random_uniform([7, 10]) - - out = fn(a, b, c) - custom_out = custom_fn(a, b, c) - self.assertEqual(out.get_shape().as_list(), - custom_out.get_shape().as_list()) - - loss = math_ops.reduce_mean(out) - custom_loss = math_ops.reduce_mean(custom_out) - - grads = gradients_impl.gradients( - loss, [a, b, c] + [variables.trainable_variables()[0]]) - custom_grads = gradients_impl.gradients( - custom_loss, [a, b, c] + [variables.trainable_variables()[1]]) - - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - out_val, custom_out_val, grads_val, custom_grads_val = sess.run( - [out, custom_out, grads, custom_grads]) - self.assertAllClose(out_val, custom_out_val) - for g1, g2 in zip(grads_val, custom_grads_val): - self.assertAllClose(g1, g2) - - def testCustomGrad(self): - - def fn(a, b, c): - return core_layers.dense(a, 10, use_bias=False) + math_ops.matmul(b, c) - - def grad_fn(inputs, trainable_variables, unused_outputs, - unused_grad_outputs): - grad_inputs = [ - array_ops.ones_like(t) * (i + 1.) for i, t in enumerate(inputs) - ] - grad_vars = [ - array_ops.ones_like(t) * (i + len(inputs) + 1.) - for i, t in enumerate(trainable_variables) - ] - return grad_inputs, grad_vars - - a = random_ops.random_uniform([11, 6]) - b = random_ops.random_uniform([11, 7]) - c = random_ops.random_uniform([7, 10]) - w = random_ops.random_uniform([6, 10]) - out = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)(a, b, c) - loss = math_ops.reduce_mean(out) - grads = gradients_impl.gradients( - loss, [a, b, c, variables.trainable_variables()[0]]) - expected_grads = [ - array_ops.ones_like(t) * (i + 1.) 
for i, t in enumerate([a, b, c, w]) - ] - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - g_val, eg_val = sess.run([grads, expected_grads]) - for g1, g2 in zip(g_val, eg_val): - self.assertAllClose(g1, g2) + tvars = variables.trainable_variables() + assert len(tvars) == 4 + grads = gradients_impl.gradients(out, [inputs] + tvars) + for grad in grads: + self.assertTrue(grad is not None) if __name__ == "__main__": diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py index 3e639a180ef..69bb6be8145 100644 --- a/tensorflow/contrib/layers/python/layers/target_column.py +++ b/tensorflow/contrib/layers/python/layers/target_column.py @@ -270,7 +270,7 @@ class _RegressionTargetColumn(_TargetColumn): def logits_to_predictions(self, logits, proba=False): if self.num_label_columns == 1: - return array_ops.squeeze(logits, squeeze_dims=[1]) + return array_ops.squeeze(logits, axis=[1]) return logits def get_eval_ops(self, features, logits, labels, metrics=None): @@ -418,7 +418,7 @@ def _softmax_cross_entropy_loss(logits, target): "Instead got %s." % target.dtype) # sparse_softmax_cross_entropy_with_logits requires [batch_size] target. if len(target.get_shape()) == 2: - target = array_ops.squeeze(target, squeeze_dims=[1]) + target = array_ops.squeeze(target, axis=[1]) loss_vec = nn.sparse_softmax_cross_entropy_with_logits( labels=target, logits=logits) return loss_vec diff --git a/tensorflow/contrib/layers/python/layers/utils_test.py b/tensorflow/contrib/layers/python/layers/utils_test.py index 3409860add8..645dc1291eb 100644 --- a/tensorflow/contrib/layers/python/layers/utils_test.py +++ b/tensorflow/contrib/layers/python/layers/utils_test.py @@ -294,7 +294,6 @@ class NPositiveIntegersTest(test.TestCase): self.assertEqual(utils.n_positive_integers(2, 2), (2, 2)) self.assertEqual(utils.n_positive_integers(2, (2, 3)), (2, 3)) self.assertEqual(utils.n_positive_integers(3, (2, 3, 1)), (2, 3, 1)) - self.assertEqual(utils.n_positive_integers(3, (2, 3, 1)), (2, 3, 1)) self.assertEqual( utils.n_positive_integers(3, tensor_shape.TensorShape([2, 3, 1])), (2, 3, 1)) diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index d665fc9335c..0fdbe8f6308 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -281,7 +281,10 @@ py_test( size = "medium", srcs = ["python/learn/estimators/estimator_test.py"], srcs_version = "PY2AND3", - tags = ["manual"], + tags = [ + "manual", + "noasan", # times out + ], deps = [ ":learn", "//tensorflow/contrib/framework:framework_py", @@ -431,6 +434,7 @@ py_test( name = "kmeans_test", size = "medium", srcs = ["python/learn/estimators/kmeans_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = [ "noasan", # b/73741358 @@ -482,6 +486,7 @@ py_test( name = "state_saving_rnn_estimator_test", size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["noasan"], deps = [ @@ -741,7 +746,7 @@ py_test( tf_py_test( name = "graph_io_test", - size = "small", + size = "medium", srcs = ["python/learn/learn_io/graph_io_test.py"], additional_deps = [ ":learn", diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index d81a534b79b..9e5aaf3118d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ 
b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -715,7 +715,9 @@ class EstimatorTest(test.TestCase): ckpt = checkpoint_state_pb2.CheckpointState() text_format.Merge(checkpoint_file_content, ckpt) self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5') - self.assertAllEqual(['model.ckpt-1', 'model.ckpt-5'], + # TODO(b/78461127): Please modify tests to not directly rely on names of + # checkpoints. + self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths) def test_train_save_copy_reload(self): diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 2b4b6eff39f..339c4e0e360 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -777,7 +777,7 @@ class _RegressionHead(_SingleHead): key = prediction_key.PredictionKey.SCORES with ops.name_scope(None, "predictions", (logits,)): if self.logits_dimension == 1: - logits = array_ops.squeeze(logits, squeeze_dims=(1,), name=key) + logits = array_ops.squeeze(logits, axis=(1,), name=key) return {key: self._link_fn(logits)} def _metrics(self, eval_loss, predictions, labels, weights): @@ -974,7 +974,7 @@ def _softmax_cross_entropy_loss(labels, logits, weights=None): is_squeezed_labels = False # TODO(ptucker): This will break for dynamic shapes. if len(labels.get_shape()) == 2: - labels = array_ops.squeeze(labels, squeeze_dims=(1,)) + labels = array_ops.squeeze(labels, axis=(1,)) is_squeezed_labels = True loss = nn.sparse_softmax_cross_entropy_with_logits( @@ -1862,12 +1862,12 @@ def _get_arguments(func): if hasattr(func, "__code__"): # Regular function. return tf_inspect.getargspec(func) - elif hasattr(func, "__call__"): - # Callable object. - return _get_arguments(func.__call__) elif hasattr(func, "func"): # Partial function. return _get_arguments(func.func) + elif hasattr(func, "__call__"): + # Callable object. 
+ return _get_arguments(func.__call__) def _verify_loss_fn_args(loss_fn): diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py index b28835a8097..584556992a0 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py @@ -36,7 +36,6 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import random_ops from tensorflow.python.platform import benchmark from tensorflow.python.platform import flags from tensorflow.python.platform import test diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index d3bb0fda576..0a863f0e20c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -863,6 +863,38 @@ class LinearClassifierTest(test.TestCase): scores = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerWeightedSparseFeaturesOOVWithNoOOVBuckets(self): + """LinearClassifier with SDCAOptimizer with OOV features (-1 IDs).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + sparse_tensor.SparseTensor( + values=[2., 3., 1.], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 5]), + 'country': + sparse_tensor.SparseTensor( + # 'GB' is out of the vocabulary. + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 5]) + }, constant_op.constant([[1], [0], [1]]) + + country = feature_column_lib.sparse_column_with_keys( + 'country', keys=['US', 'CA', 'MK', 'IT', 'CN']) + country_weighted_by_price = feature_column_lib.weighted_sparse_column( + country, 'price') + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id') + classifier = linear.LinearClassifier( + feature_columns=[country_weighted_by_price], optimizer=sdca_optimizer) + classifier.fit(input_fn=input_fn, steps=50) + scores = classifier.evaluate(input_fn=input_fn, steps=1) + self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerCrossedFeatures(self): """Tests LinearClassifier with SDCAOptimizer and crossed features.""" diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index 8c85c431be6..14ee2ba6094 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -299,6 +299,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig): # so instead of breaking compatibility with that assumption, we # just manually initialize this field: self._train_distribute = None + self._device_fn = None gpu_options = config_pb2.GPUOptions( per_process_gpu_memory_fraction=gpu_memory_fraction) diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 3744abd860e..541da906173 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -38,19 +38,19 @@ from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators 
import run_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.python.estimator import estimator as core_estimator -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import saver from tensorflow.python.training import server_lib from tensorflow.python.util import compat +from tensorflow.python.util import function_utils __all__ = ["Experiment"] def _get_standardized_predicate_fn(predicate_fn): - pred_fn_args = estimator_util.fn_args(predicate_fn) + pred_fn_args = function_utils.fn_args(predicate_fn) if "checkpoint_path" not in pred_fn_args: # pylint: disable=unused-argument def _pred_fn_wrapper(eval_results, checkpoint_path): @@ -468,10 +468,15 @@ class Experiment(object): on which that evaluation was based. At the beginning of evaluation, the passed `eval_results` will be None so it's expected that the predicate function handles that gracefully. - When `predicate_fn` is not specified, continuous eval will run in an - infinite loop (if `train_steps` is None). or exit once global step - reaches `train_steps`. - + Continuous eval behavior under different conditions: + * When `predicate_fn` is specified: + + if `train_steps` is None, run until `predicate_fn` returns False. + + if `train_steps` is specified, run until either global step + reaches `train_steps` or `predicate_fn` returns False. + * When `predicate_fn` is not specified: + + if `train_steps` is None, run in an infinite loop. + + if `train_steps` is specified, run until global step reaches + `train_steps`. export: Whether to export from this step. Default is 'True'. 
Raises: diff --git a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py index 92976d1539c..9f2cadb0174 100644 --- a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py @@ -40,7 +40,7 @@ def mean_squared_error_regressor(tensor_in, labels, weights, biases, name=None): [tensor_in, labels]): predictions = nn.xw_plus_b(tensor_in, weights, biases) if len(labels.get_shape()) == 1 and len(predictions.get_shape()) == 2: - predictions = array_ops_.squeeze(predictions, squeeze_dims=[1]) + predictions = array_ops_.squeeze(predictions, axis=[1]) return predictions, losses.mean_squared_error(labels, predictions) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index c7cdb413121..f8106d1e4a7 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -343,7 +343,8 @@ def get_temp_export_dir(timestamped_export_dir): """ (dirname, basename) = os.path.split(timestamped_export_dir) temp_export_dir = os.path.join( - compat.as_bytes(dirname), compat.as_bytes('temp-{}'.format(basename))) + compat.as_bytes(dirname), + compat.as_bytes('temp-{}'.format(compat.as_text(basename)))) return temp_export_dir diff --git a/tensorflow/contrib/legacy_seq2seq/BUILD b/tensorflow/contrib/legacy_seq2seq/BUILD index 8c2c4fd29c0..4ce91a140f8 100644 --- a/tensorflow/contrib/legacy_seq2seq/BUILD +++ b/tensorflow/contrib/legacy_seq2seq/BUILD @@ -58,5 +58,8 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], - tags = ["noasan"], # times out b/63678675 + tags = [ + "noasan", # times out b/63678675 + "optonly", # times out (flaky) + ], ) diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index 8b7ff75ba5d..78b7970069f 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -42,22 +42,3 @@ cuda_py_test( "//tensorflow/python:platform_test", ], ) - -cuda_py_test( - name = "linear_operator_block_diag_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_block_diag_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], - shard_count = 5, - tags = ["noasan"], -) diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py index 14cc3b2b497..a262a099cf8 100644 --- a/tensorflow/contrib/linalg/__init__.py +++ b/tensorflow/contrib/linalg/__init__.py @@ -18,10 +18,14 @@ See the @{$python/contrib.linalg} guide. 
@@LinearOperator @@LinearOperatorBlockDiag +@@LinearOperatorCirculant +@@LinearOperatorCirculant2D +@@LinearOperatorCirculant3D @@LinearOperatorDiag @@LinearOperatorIdentity @@LinearOperatorScaledIdentity @@LinearOperatorFullMatrix +@@LinearOperatorKronecker @@LinearOperatorLowerTriangular @@LinearOperatorLowRankUpdate @@LinearOperatorComposition @@ -35,12 +39,14 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member from tensorflow.contrib.linalg.python.ops.linear_operator_addition import * -from tensorflow.contrib.linalg.python.ops.linear_operator_block_diag import * from tensorflow.python.ops.linalg.linear_operator import * +from tensorflow.python.ops.linalg.linear_operator_block_diag import * +from tensorflow.python.ops.linalg.linear_operator_circulant import * from tensorflow.python.ops.linalg.linear_operator_composition import * from tensorflow.python.ops.linalg.linear_operator_diag import * from tensorflow.python.ops.linalg.linear_operator_full_matrix import * from tensorflow.python.ops.linalg.linear_operator_identity import * +from tensorflow.python.ops.linalg.linear_operator_kronecker import * from tensorflow.python.ops.linalg.linear_operator_low_rank_update import * from tensorflow.python.ops.linalg.linear_operator_lower_triangular import * diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index ac50699f598..b5741967ab5 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -39,8 +39,8 @@ from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import googletest _MAX_ITERATIONS = 100 -_SHARD_NUMBERS = [None, 1, 3, 10] -_NUM_LOSS_PARTITIONS = [2, 4] +_SHARD_NUMBERS = [None, 1, 3] +_NUM_LOSS_PARTITIONS = [4] def make_example_proto(feature_dict, target, value=1.0): @@ -105,11 +105,13 @@ def make_example_dict(example_protos, example_weights): def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero): random.seed(1) + sparse_features = [ SparseFeatureColumn( - [int(i / num_non_zero) for i in range(num_examples * num_non_zero)], - [int(random.random() * dim) for _ in range( - num_examples * num_non_zero)], + [i for i in range(num_examples) for _ in range(num_non_zero)], [ + i for _ in range(num_examples) + for i in random.sample(range(dim), num_non_zero) + ], [num_non_zero**(-0.5) for _ in range(num_examples * num_non_zero)]) ] examples_dict = dict( @@ -289,6 +291,34 @@ class SdcaWithLogisticLossTest(SdcaModelTest): # It would be 0.01 without shuffling and 0.02 with adaptive sampling. 
self.assertNear(0.0, lr.approximate_duality_gap().eval(), err=1e-3) + def testSparseDuplicate(self): + # Setup test data + example_protos = [ + make_example_proto({ + 'age': [0] * 5, + 'gender': [0] * 5 + }, 0), + make_example_proto({ + 'age': [1] * 5, + 'gender': [1] * 5 + }, 1), + ] + example_weights = [1.0, 1.0] + with self._single_threaded_test_session(): + examples = make_example_dict(example_protos, example_weights) + variables = make_variable_dict(1, 1) + options = dict( + symmetric_l2_regularization=1, + symmetric_l1_regularization=0, + loss_type='logistic_loss') + + lr = SdcaModel(examples, variables, options) + variables_lib.global_variables_initializer().run() + train_op = lr.minimize() + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + 'Duplicate'): + train_op.run() + def testDistributedSimple(self): # Setup test data example_protos = [ diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index ec726bbed41..5015fb08481 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -49,6 +49,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): default_value, empty_key, num_shards=1, + checkpoint=True, name='ShardedMutableHashTable'): with ops.name_scope(name, 'sharded_mutable_hash_table') as scope: super(ShardedMutableDenseHashTable, self).__init__(key_dtype, @@ -61,6 +62,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): value_dtype=value_dtype, default_value=default_value, empty_key=empty_key, + checkpoint=checkpoint, name='%s-%d-of-%d' % (name, i + 1, num_shards))) self._table_shards = table_shards # TODO(andreasst): add a value_shape() method to LookupInterface diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py index 5d4572bf6c7..12039ecc6f3 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py @@ -37,18 +37,18 @@ class SDCAOptimizer(object): Example usage: ```python - real_feature_column = real_valued_column(...) - sparse_feature_column = sparse_column_with_hash_bucket(...) - sdca_optimizer = linear.SDCAOptimizer(example_id_column='example_id', - num_loss_partitions=1, - num_table_shards=1, - symmetric_l2_regularization=2.0) - classifier = tf.contrib.learn.LinearClassifier( - feature_columns=[real_feature_column, sparse_feature_column], - weight_column_name=..., - optimizer=sdca_optimizer) - classifier.fit(input_fn_train, steps=50) - classifier.evaluate(input_fn=input_fn_eval) + real_feature_column = real_valued_column(...) + sparse_feature_column = sparse_column_with_hash_bucket(...) 
+ sdca_optimizer = linear.SDCAOptimizer(example_id_column='example_id', + num_loss_partitions=1, + num_table_shards=1, + symmetric_l2_regularization=2.0) + classifier = tf.contrib.learn.LinearClassifier( + feature_columns=[real_feature_column, sparse_feature_column], + weight_column_name=..., + optimizer=sdca_optimizer) + classifier.fit(input_fn_train, steps=50) + classifier.evaluate(input_fn=input_fn_eval) ``` Here the expectation is that the `input_fn_*` functions passed to train and @@ -198,6 +198,14 @@ class SDCAOptimizer(object): example_ids = array_ops.reshape(id_tensor.indices[:, 0], [-1]) flat_ids = array_ops.reshape(id_tensor.values, [-1]) + # Prune invalid IDs (< 0) from the flat_ids, example_ids, and + # weight_tensor. These can come from looking up an OOV entry in the + # vocabulary (default value being -1). + is_id_valid = math_ops.greater_equal(flat_ids, 0) + flat_ids = array_ops.boolean_mask(flat_ids, is_id_valid) + example_ids = array_ops.boolean_mask(example_ids, is_id_valid) + weight_tensor = array_ops.boolean_mask(weight_tensor, is_id_valid) + projection_length = math_ops.reduce_max(flat_ids) + 1 # project ids based on example ids so that we can dedup ids that # occur multiple times for a single example. diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 9c4533079c7..55b984f260e 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -6,8 +6,6 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") -exports_files(["LICENSE"]) - exports_files(glob([ "testdata/*.bin", "testdata/*.pb", @@ -92,6 +90,8 @@ cc_library( deps = [":context"], ) +exports_files(["builtin_ops.h"]) + cc_library( name = "string", hdrs = [ @@ -112,6 +112,7 @@ cc_library( "interpreter.cc", "model.cc", "nnapi_delegate.cc", + "op_resolver.cc", "optional_debug_tools.cc", ], hdrs = [ @@ -122,6 +123,7 @@ cc_library( "interpreter.h", "model.h", "nnapi_delegate.h", + "op_resolver.h", "optional_debug_tools.h", ], copts = tflite_copts(), @@ -137,6 +139,7 @@ cc_library( "//tensorflow/contrib/lite/kernels:eigen_support", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/nnapi:nnapi_lib", + "//tensorflow/contrib/lite/profiling:profiler", "//tensorflow/contrib/lite/schema:schema_fbs", ], ) @@ -223,6 +226,18 @@ cc_test( ], ) +# Test OpResolver. +cc_test( + name = "op_resolver_test", + size = "small", + srcs = ["op_resolver_test.cc"], + deps = [ + ":framework", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + # Test the C extension API code. cc_test( name = "context_test", diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index b4504f246a0..cc8a8035d1d 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -1,4 +1,3 @@ - # Find where we're running from, so we can store generated files here. ifeq ($(origin MAKEFILE_DIR), undefined) MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) @@ -30,7 +29,7 @@ GENDIR := $(MAKEFILE_DIR)/gen/obj/ CXX := $(CC_PREFIX)gcc CXXFLAGS := --std=c++11 -O3 -DNDEBUG CC := $(CC_PREFIX)gcc -CFLAGS := -O3 -DNDEBUG +CCFLAGS := -O3 -DNDEBUG LDOPTS := LDOPTS += -L/usr/local/lib ARFLAGS := -r @@ -69,12 +68,12 @@ LIB_NAME := libtensorflow-lite.a LIB_PATH := $(LIBDIR)$(LIB_NAME) # A small example program that shows how to link against the library. 
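Stepping back to the `sdca_optimizer.py` hunk above: the added `boolean_mask` calls drop any ID that a vocabulary lookup returned as -1 (the out-of-vocabulary default) and apply the same mask to the example IDs and weights, so the three tensors stay aligned before `projection_length` is computed. A small illustrative sketch of that masking step with made-up values (not part of the patch):

```python
import tensorflow as tf

# Hypothetical lookup results: -1 marks an out-of-vocabulary feature.
flat_ids = tf.constant([3, -1, 7, 0], dtype=tf.int64)
example_ids = tf.constant([0, 0, 1, 2], dtype=tf.int64)
weight_tensor = tf.constant([2.0, 9.0, 3.0, 1.0])

is_id_valid = tf.greater_equal(flat_ids, 0)
flat_ids = tf.boolean_mask(flat_ids, is_id_valid)            # [3, 7, 0]
example_ids = tf.boolean_mask(example_ids, is_id_valid)      # [0, 1, 2]
weight_tensor = tf.boolean_mask(weight_tensor, is_id_valid)  # [2.0, 3.0, 1.0]

projection_length = tf.reduce_max(flat_ids) + 1               # 8
```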
-BENCHMARK_PATH := $(BINDIR)benchmark_model +MINIMAL_PATH := $(BINDIR)minimal -BENCHMARK_SRCS := \ -tensorflow/contrib/lite/tools/benchmark_model.cc -BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ -$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) +MINIMAL_SRCS := \ +tensorflow/contrib/lite/examples/minimal/minimal.cc +MINIMAL_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS)))) # What sources we want to compile, must be kept in sync with the main Bazel # build files. @@ -90,7 +89,8 @@ $(wildcard tensorflow/contrib/lite/kernels/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ $(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ -$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) +$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) \ +$(wildcard tensorflow/contrib/lite/downloads/fft2d/fftsg.c) # Remove any duplicates. CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) CORE_CC_EXCLUDE_SRCS := \ @@ -99,7 +99,7 @@ $(wildcard tensorflow/contrib/lite/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \ $(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ -$(BENCHMARK_SRCS) +$(MINIMAL_SRCS) # Filter out all the excluded files. TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) # File names of the intermediate files target compilation generates. @@ -118,17 +118,17 @@ $(OBJDIR)%.o: %.c $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ # The target that's compiled if there's no command-line arguments. -all: $(LIB_PATH) $(BENCHMARK_PATH) +all: $(LIB_PATH) $(MINIMAL_PATH) # Gathers together all the objects we've compiled into a single '.a' archive. $(LIB_PATH): $(LIB_OBJS) @mkdir -p $(dir $@) $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) -$(BENCHMARK_PATH): $(BENCHMARK_OBJS) $(LIB_PATH) +$(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH) @mkdir -p $(dir $@) $(CXX) $(CXXFLAGS) $(INCLUDES) \ - -o $(BENCHMARK_PATH) $(BENCHMARK_OBJS) \ + -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \ $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) # Gets rid of all generated files. diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md new file mode 100644 index 00000000000..8fd63d5cee7 --- /dev/null +++ b/tensorflow/contrib/lite/RELEASE.md @@ -0,0 +1,8 @@ +# Release 0.1.7 + +* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit + fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0). +* To reproduce the iOS library, it's required to cherry pick git commit + f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue. +* The code is based on TensorFlow 1.8.0 release candidate and it's very close + to TensorFlow 1.8.0 release. diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index b8f6b7fd59a..85216776823 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -124,19 +124,19 @@ def tf_to_tflite(name, src, options, out): out: name of the output flatbuffer file. 
""" - toco = "//tensorflow/contrib/lite/toco:toco" + toco_cmdline = " ".join([ + "//tensorflow/contrib/lite/toco:toco", + "--input_format=TENSORFLOW_GRAPHDEF", + "--output_format=TFLITE", + ("--input_file=$(location %s)" % src), + ("--output_file=$(location %s)" % out), + ] + options ) native.genrule( name = name, - srcs=[src, options], + srcs=[src], outs=[out], - cmd = ("$(location %s) " + - " --input_file=$(location %s) " + - " --output_file=$(location %s) " + - " --input_format=TENSORFLOW_GRAPHDEF" + - " --output_format=TFLITE" + - " `cat $(location %s)`") - % (toco, src, out, options), - tools= [toco], + cmd = toco_cmdline, + tools= ["//tensorflow/contrib/lite/toco:toco"], ) def tflite_to_json(name, src, out): diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 4910c89eaeb..8660c653ae4 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -161,6 +161,9 @@ typedef struct { typedef struct { } TfLitePadParams; +typedef struct { +} TfLitePadV2Params; + typedef struct { // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. // For now we will fix the maximum possible number of dimensions. @@ -227,6 +230,12 @@ typedef struct { TfLiteType output_type; } TfLiteArgMaxParams; +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; +} TfLiteTransposeConvParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 859bc7ab70d..7e285186f45 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -33,6 +33,7 @@ typedef enum { kTfLiteBuiltinDepthwiseConv2d = 4, kTfLiteBuiltinDequantize = 6, kTfLiteBuiltinEmbeddingLookup = 7, + kTfLiteBuiltinFloor = 8, kTfLiteBuiltinFullyConnected = 9, kTfLiteBuiltinHashtableLookup = 10, kTfLiteBuiltinL2Normalization = 11, @@ -83,6 +84,15 @@ typedef enum { kTfLiteBuiltinArgMax = 56, kTfLiteBuiltinMinimum = 57, kTfLiteBuiltinLess = 58, + kTfLiteBuiltinNeg = 59, + kTfLiteBuiltinPadv2 = 60, + kTfLiteBuiltinGreater = 61, + kTfLiteBuiltinGreaterEqual = 62, + kTfLiteBuiltinLessEqual = 63, + kTfLiteBuiltinSelect = 64, + kTfLiteBuiltinSlice = 65, + kTfLiteBuiltinSin = 66, + kTfLiteBuiltinTransposeConv = 67, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 0b38f43cd32..4eb66cc225e 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -275,7 +275,7 @@ typedef struct { typedef struct TfLiteContext { // Number of tensors in the context. - int tensors_size; + size_t tensors_size; // The execution plan contains a list of the node indices in execution // order. execution_plan->size is the current number of nodes. And, @@ -370,13 +370,21 @@ typedef struct _TfLiteRegistration { // Builtin codes. If this kernel refers to a builtin this is the code // of the builtin. This is so we can do marshaling to other frameworks like - // NN API. Note, it is the responsibility of the registration binder to - // set this properly. + // NN API. + // Note: It is the responsibility of the registration binder to set this + // properly. int32_t builtin_code; // Custom op name. If the op is a builtin, this will be null. + // Note: It is the responsibility of the registration binder to set this + // properly. // WARNING: This is an experimental interface that is subject to change. 
const char* custom_name; + + // The version of the op. + // Note: It is the responsibility of the registration binder to set this + // properly. + int version; } TfLiteRegistration; // WARNING: This is an experimental interface that is subject to change. @@ -397,13 +405,13 @@ typedef struct _TfLiteDelegate { // This can be null if the delegate doesn't use its own buffer. TfLiteStatus (*CopyFromBufferHandle)(TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, - void* data, int size); + void* data, size_t size); // Copy the data from raw memory to delegate buffer handle. // This can be null if the delegate doesn't use its own buffer. TfLiteStatus (*CopyToBufferHandle)(TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, - void* data, int size); + void* data, size_t size); // Free the Delegate Buffer Handle. Note: This only frees the handle, but // this doesn't release the underlying resource (e.g. textures). The diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index a93ed201d64..436c3e1d4ca 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -30,12 +30,15 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once +# the archive has been propagated in mirror.bazel.build. +GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip" +FFT2D_URL="https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. 
@@ -91,6 +94,7 @@ download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl" download_and_extract "${NEON_2_SSE_URL}" "${DOWNLOADS_DIR}/neon_2_sse" download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" +download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d" replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD index 49280129971..57000072561 100644 --- a/tensorflow/contrib/lite/examples/android/BUILD +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -42,7 +42,6 @@ android_binary( custom_package = "org.tensorflow.lite.demo", inline_constants = 1, manifest = "AndroidManifest.xml", - manifest_merger = "android", nocompress_extensions = [ ".tflite", ], diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index 2a64c1de725..e36218e4f12 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -62,8 +62,8 @@ void resize(T* out, uint8_t* in, int image_height, int image_width, {1, wanted_height, wanted_width, wanted_channels}, quant); ops::builtin::BuiltinOpResolver resolver; - TfLiteRegistration* resize_op = - resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR); + const TfLiteRegistration* resize_op = + resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR, 1); auto* params = reinterpret_cast<TfLiteResizeBilinearParams*>( malloc(sizeof(TfLiteResizeBilinearParams))); params->align_corners = false; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index a91467d345f..966fcd2a31f 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -17,6 +17,7 @@ limitations under the License.
#include #include #include +#include #include #include #include @@ -70,6 +71,22 @@ TfLiteStatus ReadLabelsFile(const string& file_name, return kTfLiteOk; } +void PrintProfilingInfo(const profiling::ProfileEvent* e, uint32_t op_index, + TfLiteRegistration registration) { + // output something like + // time (ms) , Node xxx, OpCode xxx, symbolic name + // 5.352, Node 5, OpCode 4, DEPTHWISE_CONV_2D + + LOG(INFO) << std::fixed << std::setw(10) << std::setprecision(3) + << (e->end_timestamp_us - e->begin_timestamp_us) / 1000.0 + << ", Node " << std::setw(3) << std::setprecision(3) << op_index + << ", OpCode " << std::setw(3) << std::setprecision(3) + << registration.builtin_code << ", " + << EnumNameBuiltinOperator( + static_cast<BuiltinOperator>(registration.builtin_code)) + << "\n"; +} + void RunInference(Settings* s) { if (!s->model_name.c_str()) { LOG(ERROR) << "no model file name\n"; @@ -166,19 +183,36 @@ void RunInference(Settings* s) { exit(-1); } + profiling::Profiler* profiler = new profiling::Profiler(); + interpreter->SetProfiler(profiler); + + if (s->profiling) profiler->StartProfiling(); + struct timeval start_time, stop_time; - gettimeofday(&start_time, NULL); + gettimeofday(&start_time, nullptr); for (int i = 0; i < s->loop_count; i++) { if (interpreter->Invoke() != kTfLiteOk) { LOG(FATAL) << "Failed to invoke tflite!\n"; } } - gettimeofday(&stop_time, NULL); + gettimeofday(&stop_time, nullptr); LOG(INFO) << "invoked \n"; LOG(INFO) << "average time: " << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000) << " ms \n"; + if (s->profiling) { + profiler->StopProfiling(); + auto profile_events = profiler->GetProfileEvents(); + for (int i = 0; i < profile_events.size(); i++) { + auto op_index = profile_events[i]->event_metadata; + const auto node_and_registration = + interpreter->node_and_registration(op_index); + const TfLiteRegistration registration = node_and_registration->second; + PrintProfilingInfo(profile_events[i], op_index, registration); + } + } + const int output_size = 1000; const size_t num_results = 5; const float threshold = 0.001f; @@ -217,13 +251,14 @@ void RunInference(Settings* s) { void display_usage() { LOG(INFO) << "label_image\n" - << "--accelerated, -a: [0|1], use Android NNAPI or note\n" + << "--accelerated, -a: [0|1], use Android NNAPI or not\n" << "--count, -c: loop interpreter->Invoke() for certain times\n" << "--input_mean, -b: input mean\n" << "--input_std, -s: input standard deviation\n" << "--image, -i: image_name.bmp\n" << "--labels, -l: labels for the model\n" << "--tflite_model, -m: model_name.tflite\n" + << "--profiling, -p: [0|1], profiling or not\n" << "--threads, -t: number of threads\n" << "--verbose, -v: [0|1] print more information\n" << "\n"; @@ -235,21 +270,22 @@ int Main(int argc, char** argv) { int c; while (1) { static struct option long_options[] = { - {"accelerated", required_argument, 0, 'a'}, - {"count", required_argument, 0, 'c'}, - {"verbose", required_argument, 0, 'v'}, - {"image", required_argument, 0, 'i'}, - {"labels", required_argument, 0, 'l'}, - {"tflite_model", required_argument, 0, 'm'}, - {"threads", required_argument, 0, 't'}, - {"input_mean", required_argument, 0, 'b'}, - {"input_std", required_argument, 0, 's'}, - {0, 0, 0, 0}}; + {"accelerated", required_argument, nullptr, 'a'}, + {"count", required_argument, nullptr, 'c'}, + {"verbose", required_argument, nullptr, 'v'}, + {"image", required_argument, nullptr, 'i'}, + {"labels", required_argument, nullptr, 'l'}, + {"tflite_model", required_argument, nullptr, 'm'}, +
{"profiling", required_argument, nullptr, 'p'}, + {"threads", required_argument, nullptr, 't'}, + {"input_mean", required_argument, nullptr, 'b'}, + {"input_std", required_argument, nullptr, 's'}, + {nullptr, 0, nullptr, 0}}; /* getopt_long stores the option index here. */ int option_index = 0; - c = getopt_long(argc, argv, "a:b:c:f:i:l:m:s:t:v:", long_options, + c = getopt_long(argc, argv, "a:b:c:f:i:l:m:p:s:t:v:", long_options, &option_index); /* Detect the end of the options. */ @@ -257,15 +293,14 @@ int Main(int argc, char** argv) { switch (c) { case 'a': - s.accel = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.accel = strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'b': - s.input_mean = strtod(optarg, NULL); + s.input_mean = strtod(optarg, nullptr); break; case 'c': - s.loop_count = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.loop_count = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'i': s.input_bmp_name = optarg; @@ -276,16 +311,20 @@ int Main(int argc, char** argv) { case 'm': s.model_name = optarg; break; + case 'p': + s.profiling = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) + break; case 's': - s.input_std = strtod(optarg, NULL); + s.input_std = strtod(optarg, nullptr); break; case 't': s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + optarg, nullptr, 10); break; case 'v': - s.verbose = strtol( // NOLINT(runtime/deprecated_fn) - optarg, (char**)NULL, 10); + s.verbose = + strtol(optarg, nullptr, 10); // NOLINT(runtime/deprecated_fn) break; case 'h': case '?': diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h index 4de32e33fb4..4b48014e1c7 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.h +++ b/tensorflow/contrib/lite/examples/label_image/label_image.h @@ -25,6 +25,7 @@ struct Settings { bool verbose = false; bool accel = false; bool input_floating = false; + bool profiling = false; int loop_count = 1; float input_mean = 127.5f; float input_std = 127.5f; diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/contrib/lite/examples/minimal/minimal.cc new file mode 100644 index 00000000000..106e3b02705 --- /dev/null +++ b/tensorflow/contrib/lite/examples/minimal/minimal.cc @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include + +// This is an example that is minimal to read a model +// from disk and perform inference. There is no data being loaded +// that is up to you to add as a user. +// +// NOTE: Do not add any dependencies to this that cannot be built with +// the minimal makefile. 
This example must remain trivial to build with +// the minimal build tool. +// +// Usage: minimal <tflite model> + +using namespace tflite; + +#define TFLITE_MINIMAL_CHECK(x) \ + if(!(x)) { \ + fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \ + exit(1); \ + } + + +int main(int argc, char *argv[]) { + if(argc != 2) { + fprintf(stderr, "Usage: %s <tflite model>\n", argv[0]); + return 1; + } + const char* filename = argv[1]; + + // Load model + std::unique_ptr<tflite::FlatBufferModel> model + = tflite::FlatBufferModel::BuildFromFile(filename); + TFLITE_MINIMAL_CHECK(model != nullptr); + + // Build the interpreter + tflite::ops::builtin::BuiltinOpResolver resolver; + InterpreterBuilder builder(*model.get(), resolver); + std::unique_ptr<tflite::Interpreter> interpreter; + builder(&interpreter); + TFLITE_MINIMAL_CHECK(interpreter != nullptr); + + // Allocate tensor buffers. + TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk); + + // Fill input buffers + // TODO(user): Insert code to fill input tensors + + // Run inference + TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk); + + // Read output buffers + // TODO(user): Insert getting data out code. + + return 0; +} diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md index fe208e47d1a..50cc146a87e 100644 --- a/tensorflow/contrib/lite/g3doc/apis.md +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -29,7 +29,7 @@ interpreter->AllocateTensors(); float* input = interpreter->typed_input_tensor<float>(0); // Fill `input`. interpreter->Invoke(); -float* output = interpreter->type_output_tensor<float>(0); +float* output = interpreter->typed_output_tensor<float>(0); ``` ### Data Alignment diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index d7cc854ebac..972e57f73e8 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -39,7 +39,7 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); - TfLiteTensor* input = GetInput(context, node, 0); + const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); int num_dims = NumDimensions(input); @@ -54,7 +54,7 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { using namespace tflite; - TfLiteTensor* input = GetInput(context, node,0); + const TfLiteTensor* input = GetInput(context, node,0); TfLiteTensor* output = GetOutput(context, node,0); float* input_data = input->data.f; diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index d8134d5a000..c1c8ef049f6 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -1,28 +1,63 @@ # List of Hosted Models -* [NASNet large](https://storage.googleapis.com/download.tensorflow.org/models/tflite/nasnet_large_2018_03_27.zip) -* [NASNet mobile](https://storage.googleapis.com/download.tensorflow.org/models/tflite/nasnet_mobile_2018_03_27.zip) -* [ResNet v2 101](https://storage.googleapis.com/download.tensorflow.org/models/tflite/resnet_v2_101_2018_03_27.zip) -* [ResNet v2 50](https://storage.googleapis.com/download.tensorflow.org/models/tflite/resnet_v2_50_2018_03_27.zip) -* [Inception ResNet v2](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_resnet_v2_2018_03_27.zip) -* [Inception
v4](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v4_2018_03_27.zip) -* [Inception v3 2015](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_2015_2017_11_10.zip) -* [Inception v3 Slim 2016](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip) -* [Mobilenet 0.25 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_128_float_2017_11_08.zip) -* [Mobilenet 0.25 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_160_float_2017_11_08.zip) -* [Mobilenet 0.25 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_192_float_2017_11_08.zip) -* [Mobilenet 0.25 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.25_224_float_2017_11_08.zip) -* [Mobilenet 0.50 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_128_float_2017_11_08.zip) -* [Mobilenet 0.50 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_160_float_2017_11_08.zip) -* [Mobilenet 0.50 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_192_float_2017_11_08.zip) -* [Mobilenet 0.50 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.50_224_float_2017_11_08.zip) -* [Mobilenet 0.75 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_128_float_2017_11_08.zip) -* [Mobilenet 0.75 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_160_float_2017_11_08.zip) -* [Mobilenet 0.75 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_192_float_2017_11_08.zip) -* [Mobilenet 0.75 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_0.75_224_float_2017_11_08.zip) -* [Mobilenet 1.0 128 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_128_float_2017_11_08.zip) -* [Mobilenet 1.0 160 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_160_float_2017_11_08.zip) -* [Mobilenet 1.0 192 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_192_float_2017_11_08.zip) -* [Mobilenet 1.0 224 Float](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_float_2017_11_08.zip) -* [Mobilenet 1.0 224 Quant](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) -* [Smart Reply 1.0 Android ](https://storage.googleapis.com/download.tensorflow.org/models/tflite/smartreply_1.0_2017_11_01.zip) +## Image classification (Float Models) + +Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance +------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------: +DenseNet | [paper](https://arxiv.org/abs/1608.06993), 
[tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms +SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms +NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 72.2% | 90.6% | 261 ms | 389 ms +NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.1% | 95.8% | 6697 ms | 7940 ms +ResNet_V2_50 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_50_2018_04_27.tgz) | 102.3 Mb | 68.1% | 88.4% | 942 ms | 1008 ms +ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_101_2018_04_27.tgz) | 178.3 Mb | 70.4% | 89.6% | 1880 ms | 1970 ms +Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 76.9% | 93.5% | 1433 ms | 1522 ms +Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 79.6% | 94.6% | 2986 ms | 3139 ms +Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 76.8% | 93.5% | 2731 ms | 2926 ms +Mobilenet_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.5% | 66.3% | 6.2 ms | 13.0 ms +Mobilenet_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.5% | 70.3% | 8.6 ms | 19.5 ms +Mobilenet_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.7% | 72.3% | 12.1 ms | 27.8 ms +Mobilenet_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.8% | 74.2% | 16.2 ms | 37.3 ms +Mobilenet_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.3% | 79.4% | 18.1 ms | 29.9 ms +Mobilenet_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.1% | 81.9% | 26.8 ms | 45.9 ms +Mobilenet_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), 
[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.6% | 35.6 ms | 65.3 ms +Mobilenet_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.3% | 84.9% | 47.6 ms | 164.2 ms +Mobilenet_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.1% | 83.9% | 34.6 ms | 48.7 ms +Mobilenet_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.3% | 86.0% | 51.3 ms | 75.2 ms +Mobilenet_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.2% | 87.3% | 71.7 ms | 107.0 ms +Mobilenet_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.4% | 88.2% | 95.7 ms | 143.4 ms +Mobilenet_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.8% | 57.4 ms | 76.8 ms +Mobilenet_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms +Mobilenet_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 70.0% | 89.2% | 118.6 ms | 167.3 ms +Mobilenet_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 70.9% | 89.9% | 160.1 ms | 224.3 ms + +^ The model files include both TF Lite FlatBuffer and Tensorflow frozen Graph. + +^^ The performance numbers are generated in the benchmark on Pixel-2 using +single thread large core. 
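As a rough illustration of how a single-threaded latency number like those above can be measured, the sketch below loads one of the hosted float models with the TensorFlow Lite C++ API and times repeated `Invoke()` calls, following the same flow as `label_image.cc` in this change. It is only a sketch: the model filename, the loop count, and the availability of `Interpreter::SetNumThreads` in this build are assumptions made for the example, not part of the hosted-models list.

```c++
// Latency sketch only. Assumptions: mobilenet_v1_1.0_224.tflite has been
// downloaded locally from the table above, and Interpreter::SetNumThreads is
// available in this build of TensorFlow Lite.
#include <sys/time.h>

#include <cstdio>
#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

namespace {
// Wall-clock time in milliseconds, as label_image.cc does with gettimeofday.
double NowMs() {
  struct timeval tv;
  gettimeofday(&tv, nullptr);
  return tv.tv_sec * 1000.0 + tv.tv_usec / 1000.0;
}
}  // namespace

int main() {
  // Placeholder path; point this at a model downloaded from the table above.
  auto model =
      tflite::FlatBufferModel::BuildFromFile("mobilenet_v1_1.0_224.tflite");
  if (!model) return 1;

  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter) return 1;

  interpreter->SetNumThreads(1);  // single thread, matching the note above
  if (interpreter->AllocateTensors() != kTfLiteOk) return 1;

  // Input tensors are left unfilled here; a real measurement feeds an actual
  // image, as label_image.cc does.
  constexpr int kLoops = 50;  // arbitrary loop count chosen for averaging
  const double start = NowMs();
  for (int i = 0; i < kLoops; ++i) {
    if (interpreter->Invoke() != kTfLiteOk) return 1;
  }
  std::printf("average latency: %.1f ms\n", (NowMs() - start) / kLoops);
  return 0;
}
```

A real measurement would also discard warm-up iterations and feed representative input data before timing.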
+ +## Image classification (Quantized Models) + +Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance +------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: +Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.9% | 65.8% | 3.7 ms +Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.5% | 69.1% | 5.5 ms +Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.8% | 71.9% | 7.9 ms +Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.2% | 73.8% | 10.4 ms +Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.9% | 78.9% | 8.8 ms +Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 81.3% | 13.0 ms +Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.4% | 83.2% | 18.3 ms +Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 62.2% | 84.5% | 24.7 ms +Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 59.8% | 82.8% | 16.2 ms +Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 63.9% | 85.5% | 24.3 ms +Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.2% | 87.1% | 33.8 ms +Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 67.9% | 88.1% | 45.4 ms +Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 64.0% | 85.5% | 24.9 ms +Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.3% | 87.7% | 37.4 ms +Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), 
[tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.0% | 88.9% | 51.9 ms +Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 69.7% | 89.5% | 70.2 ms + +## Other models + +Model | TF Lite FlatBuffer +----------------------- | :----------------: +Smart Reply 1.0 Android | [reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html), [tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip) diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md index 7a3a231626d..ab507893074 100644 --- a/tensorflow/contrib/lite/g3doc/rpi.md +++ b/tensorflow/contrib/lite/g3doc/rpi.md @@ -32,7 +32,7 @@ This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc v Log in to you RPI, install the toolchain. ```bash -sudo apt-get instal build-essential +sudo apt-get install build-essential ``` First, clone this TensorFlow repository. Run this at the root of the repository: diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 203924f03d3..d8c46e63315 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -132,10 +132,7 @@ TensorFlow operation not listed above are likely unsupported. Notably, the following common ops are not supported at the moment: * [tf.depth_to_space](https://www.tensorflow.org/api_docs/python/tf/depth_to_space) -* [tf.floor](https://www.tensorflow.org/api_docs/python/tf/floor) -* [tf.gather](https://www.tensorflow.org/api_docs/python/tf/gather) * [tf.image.resize_bilinear](https://www.tensorflow.org/api_docs/python/tf/image/resize_bilinear) -* [tf.slice](https://www.tensorflow.org/api_docs/python/tf/slice) * [tf.tanh](https://www.tensorflow.org/api_docs/python/tf/tanh) ## TensorFlow Lite Operations @@ -223,6 +220,23 @@ Options { } ``` +**CONV_2D_TRANSPOSE** + +``` +Inputs { + 0: output_shape + 1: filter + 2: 4D tensor +} +Outputs { + 0: the transpose (gradient) of conv2d +} +Options { + padding: SAME|VALID + stride_w,stride_h: stride of the filter window +} +``` + **DEPTHWISE_CONV_2D** ``` @@ -254,6 +268,17 @@ Outputs { } ``` +**FLOOR** + +``` +inputs { + 0: tensor +} +outputs: { + 0: result of computing element-wise floor of the input tensor +} +``` + **FULLY_CONNECTED** ``` @@ -271,6 +296,45 @@ Options { } ``` +**GATHER** + +``` +Inputs { + 0: params tensor + 1: indices tensor + 2: axis tensor (optional) +} +Outputs { + 0: a tensor with same type as the params tensor. +} +``` + +**GREATER** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + greater than the corresponding element of the second tensor. +} +``` + +**GREATER_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + greater than or equal to the corresponding element of the second tensor. +} +``` + **L2_NORMALIZATION** ``` @@ -315,6 +379,19 @@ Outputs { } ``` +**LESS_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is less + than or equal to the corresponding element of the second tensor. 
+} +``` + **LOCAL_RESPONSE_NORMALIZATION** ``` @@ -387,6 +464,17 @@ Options { } ``` +**NEG** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: elementwise negation of the input tensor +} +``` + **PAD** ``` @@ -463,6 +551,19 @@ Options { } ``` +**SLICE** + +``` +Inputs { + 0: tensor + 1: 1D tensor + 2: 1D tensor +} +Outputs { + 0: slice of the input tensor of the given size from the given begin index. +} +``` + **SOFTMAX** ``` @@ -548,7 +649,7 @@ Outputs { 0: slice of the input tensor of the given size } Options { - begin_mask: mask for begin indicies + begin_mask: mask for begin indices end_mask: mask for end indices shrink_axis_mask: mask that indicates which dimensions to remove } @@ -563,7 +664,7 @@ Inputs { } Outputs { 0: k largest element along each last dimensional slice - 1: indicies of values within the last dimension of the input ensor + 1: indices of values within the last dimension of the input ensor } ``` @@ -579,6 +680,20 @@ Outputs { } ``` +**SELECT** + +``` +Inputs { + 0: tensor + 1: tensor + 2: tensor +} +Outputs { + 0: tensor that contains the elementwise values of 'tensor 1' if the + corresponding value of 'tensor 0' is true or the value of 'tensor 2' if false. +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index f2586546088..ebb0aedc200 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -14,10 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/interpreter.h" + #include #include #include #include + #include "tensorflow/contrib/lite/arena_planner.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" @@ -26,6 +28,7 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/nnapi_delegate.h" +#include "tensorflow/contrib/lite/profiling/profiler.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/util.h" @@ -122,7 +125,8 @@ Interpreter::~Interpreter() { for (int i = 0; i < context_.tensors_size; i++) { TfLiteTensor* tensor = &context_.tensors[i]; - if (tensor->buffer_handle != kTfLiteNullBufferHandle) { + if (tensor->buffer_handle != kTfLiteNullBufferHandle && + tensor->delegate->FreeBufferHandle != nullptr) { tensor->delegate->FreeBufferHandle(tensor->delegate, &tensor->buffer_handle); } @@ -245,11 +249,8 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( // Initialize the output tensors's delegate-related fields. for (int tensor_index : subgraph.output_tensors) { TfLiteTensor* tensor = &tensors_[tensor_index]; - TF_LITE_ENSURE_EQ(&context_, tensor->delegate, nullptr); - TF_LITE_ENSURE_EQ(&context_, tensor->buffer_handle, - kTfLiteNullBufferHandle); - // buffer_handle will be filled in delegate's `Prepare` - // function. 
+ TF_LITE_ENSURE(&context_, tensor->delegate == nullptr || + tensor->delegate == delegate); tensor->delegate = delegate; } @@ -308,7 +309,12 @@ TfLiteStatus Interpreter::CheckTensorIndices(const char* label, for (int i = 0; i < length; i++) { int index = indices[i]; - if (index < kOptionalTensor || index >= context_.tensors_size) { + // Continue if index == kOptionalTensor before additional comparisons below, + // size_t(-1) is always >= context_tensors_size. + if (index == kOptionalTensor) { + continue; + } + if (index < 0 || static_cast(index) >= context_.tensors_size) { ReportError(&context_, "Invalid tensor index %d in %s\n", index, label); consistent_ = false; return kTfLiteError; @@ -318,7 +324,7 @@ TfLiteStatus Interpreter::CheckTensorIndices(const char* label, } TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, - int dims_size, size_t* bytes) { + size_t dims_size, size_t* bytes) { // TODO(aselle): Check for overflow here using overflow.h in TensorFlow // MultiplyWithoutOverflow. TF_LITE_ENSURE(&context_, bytes != nullptr); @@ -547,6 +553,7 @@ TfLiteStatus Interpreter::Invoke() { TfLiteNode& node = nodes_and_registration_[node_index].first; const TfLiteRegistration& registration = nodes_and_registration_[node_index].second; + SCOPED_OPERATOR_PROFILE(profiler_, node_index); // TODO(ycling): This is an extra loop through inputs to check if the data // need to be copied from Delegate buffer to raw memory, which is often not @@ -570,6 +577,12 @@ TfLiteStatus Interpreter::Invoke() { } } + if (!allow_buffer_handle_output_) { + for (int tensor_index : outputs_) { + EnsureTensorDataIsReadable(tensor_index); + } + } + return status; } @@ -638,7 +651,7 @@ TfLiteStatus Interpreter::GetNodeAndRegistration( } TfLiteStatus Interpreter::SetTensorParametersReadOnly( - int tensor_index, TfLiteType type, const char* name, const int rank, + int tensor_index, TfLiteType type, const char* name, const size_t rank, const int* dims, TfLiteQuantizationParams quantization, const char* buffer, size_t bytes, const Allocation* allocation) { if (state_ == kStateInvokableAndImmutable) { @@ -684,7 +697,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly( // bytes. The lifetime of buffer must be ensured to be greater or equal // to Interpreter. TfLiteStatus Interpreter::SetTensorParametersReadWrite( - int tensor_index, TfLiteType type, const char* name, const int rank, + int tensor_index, TfLiteType type, const char* name, const size_t rank, const int* dims, TfLiteQuantizationParams quantization) { if (state_ == kStateInvokableAndImmutable) { ReportError( diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index df67cce9de5..7315d836068 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -20,10 +20,12 @@ limitations under the License. 
#include #include #include + #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/memory_planner.h" +#include "tensorflow/contrib/lite/profiling/profiler.h" namespace tflite { @@ -148,7 +150,7 @@ class Interpreter { }; TfLiteStatus SetTensorParametersReadOnly( - int tensor_index, TfLiteType type, const char* name, const int rank, + int tensor_index, TfLiteType type, const char* name, const size_t rank, const int* dims, TfLiteQuantizationParams quantization, const char* buffer, size_t bytes, const Allocation* allocation = nullptr); @@ -163,7 +165,7 @@ class Interpreter { dims.data(), quantization); } TfLiteStatus SetTensorParametersReadWrite( - int tensor_index, TfLiteType type, const char* name, const int rank, + int tensor_index, TfLiteType type, const char* name, const size_t rank, const int* dims, TfLiteQuantizationParams quantization); // Functions to access tensor data @@ -187,10 +189,10 @@ class Interpreter { } // Return the number of tensors in the model. - int tensors_size() const { return context_.tensors_size; } + size_t tensors_size() const { return context_.tensors_size; } // Return the number of ops in the model. - int nodes_size() const { return nodes_and_registration_.size(); } + size_t nodes_size() const { return nodes_and_registration_.size(); } // WARNING: Experimental interface, subject to change const std::vector& execution_plan() const { return execution_plan_; } @@ -199,7 +201,7 @@ class Interpreter { // Overrides execution plan. This bounds checks indices sent in. TfLiteStatus SetExecutionPlan(const std::vector& new_plan); - // Get a tensor data structure. + // Get a mutable tensor data structure. // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this // read/write access to structure TfLiteTensor* tensor(int tensor_index) { @@ -208,9 +210,14 @@ class Interpreter { return &context_.tensors[tensor_index]; } + // Get an immutable tensor data structure. + const TfLiteTensor* tensor(int tensor_index) const { + if (tensor_index >= context_.tensors_size || tensor_index < 0) + return nullptr; + return &context_.tensors[tensor_index]; + } + // Get a pointer to an operation and registration data structure if in bounds. - // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this - // read/write access to structure const std::pair* node_and_registration( int node_index) const { if (node_index >= nodes_and_registration_.size() || node_index < 0) @@ -218,7 +225,8 @@ class Interpreter { return &nodes_and_registration_[node_index]; } - // Perform a checked cast to the appropriate tensor type. + // Perform a checked cast to the appropriate tensor type (mutable pointer + // version). template T* typed_tensor(int tensor_index) { if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { @@ -229,20 +237,46 @@ class Interpreter { return nullptr; } - // Return a pointer into the data of a given input tensor. The given index - // must be between 0 and inputs().size(). + // Perform a checked cast to the appropriate tensor type (immutable pointer + // version). + template + const T* typed_tensor(int tensor_index) const { + if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) { + if (tensor_ptr->type == typeToTfLiteType()) { + return reinterpret_cast(tensor_ptr->data.raw); + } + } + return nullptr; + } + + // Return a mutable pointer into the data of a given input tensor. 
The given + // index must be between 0 and inputs().size(). template T* typed_input_tensor(int index) { return typed_tensor(inputs_[index]); } - // Return a pointer into the data of a given output tensor. The given index - // must be between 0 and outputs().size(). + // Return an immutable pointer into the data of a given input tensor. The + // given index must be between 0 and inputs().size(). + template + const T* typed_input_tensor(int index) const { + return typed_tensor(inputs_[index]); + } + + // Return a mutable pointer into the data of a given output tensor. The given + // index must be between 0 and outputs().size(). template T* typed_output_tensor(int index) { return typed_tensor(outputs_[index]); } + // Return an immutable pointer into the data of a given output tensor. The + // given index must be between 0 and outputs().size(). + template + const T* typed_output_tensor(int index) const { + return typed_tensor(outputs_[index]); + } + // Change the dimensionality of a given tensor. Note, this is only acceptable // for tensor indices that are inputs. // Returns status of failure or success. @@ -282,6 +316,7 @@ class Interpreter { // Ensure the data in `tensor.data` is readable. In case delegate is used, // it might require to copy the data from delegate buffer to raw memory. + // WARNING: This is an experimental API and subject to change. TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) { TF_LITE_ENSURE(&context_, tensor_index < tensors_size()); TfLiteTensor* tensor = &tensors_[tensor_index]; @@ -320,6 +355,10 @@ class Interpreter { TfLiteBufferHandle* buffer_handle, TfLiteDelegate** delegate); + void SetProfiler(profiling::Profiler* profiler) { profiler_ = profiler; } + + profiling::Profiler* GetProfiler() { return profiler_; } + // The default capacity of `tensors_` vector. static constexpr int kTensorsReservedCapacity = 128; // The capacity headroom of `tensors_` vector before calling ops' @@ -328,6 +367,18 @@ class Interpreter { // pointers to existing tensors. static constexpr int kTensorsCapacityHeadroom = 16; + // Set if buffer handle output is allowed. + // + // When using hardware delegation, Interpreter will make the data of output + // tensors available in `tensor->data` by default. If the application can + // consume the buffer handle directly (e.g. reading output from OpenGL + // texture), it can set this flag to false, so Interpreter won't copy the data + // from buffer handle to CPU memory. + // WARNING: This is an experimental API and subject to change. + void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) { + allow_buffer_handle_output_ = allow_buffer_handle_output; + } + private: // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. @@ -385,7 +436,7 @@ class Interpreter { // Compute the number of bytes required to represent a tensor with dimensions // specified by the array dims (of length dims_size). Returns the status code // and bytes. - TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size, + TfLiteStatus BytesRequired(TfLiteType type, const int* dims, size_t dims_size, size_t* bytes); // Request an tensor be resized implementation. If the given tensor is of @@ -446,7 +497,7 @@ class Interpreter { // tensors. After calling this function, adding `kTensorsCapacityHeadroom` // more tensors won't invalidate the pointer to existing tensors. 
void EnsureTensorsVectorCapacity() { - const int required_capacity = tensors_size() + kTensorsCapacityHeadroom; + const size_t required_capacity = tensors_size() + kTensorsCapacityHeadroom; if (required_capacity > tensors_.capacity()) { tensors_.reserve(required_capacity); context_.tensors = tensors_.data(); @@ -518,6 +569,11 @@ class Interpreter { std::unique_ptr nnapi_delegate_; std::unique_ptr memory_planner_; + + bool allow_buffer_handle_output_ = false; + + // Profiler for this interpreter instance. + profiling::Profiler* profiler_; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 131e0880798..453c1ada1cf 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -887,15 +887,15 @@ class TestDelegate : public ::testing::Test { TfLiteIntArrayFree(nodes_to_separate); return kTfLiteOk; }; - delegate_.CopyToBufferHandle = [](TfLiteDelegate* delegate, - TfLiteBufferHandle buffer_handle, - void* data, int size) -> TfLiteStatus { + delegate_.CopyToBufferHandle = + [](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, + void* data, size_t size) -> TfLiteStatus { // TODO(ycling): Implement tests to test buffer copying logic. return kTfLiteOk; }; delegate_.CopyFromBufferHandle = [](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, - void* data, int size) -> TfLiteStatus { + void* data, size_t size) -> TfLiteStatus { // TODO(ycling): Implement tests to test buffer copying logic. return kTfLiteOk; }; diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 1dda55b8edf..593af81a18a 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -1,7 +1,9 @@ # Description: # TensorFlow Lite Java API. 
-package(default_visibility = ["//visibility:private"]) +package(default_visibility = [ + "//tensorflow/contrib/lite/java/ovic:__pkg__", +]) licenses(["notice"]) # Apache 2.0 @@ -46,23 +48,6 @@ android_library( ], ) -java_library( - name = "ovicbenchmarkerlib", - srcs = [ - "ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java", - "ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", - ], - javacopts = JAVACOPTS, - visibility = ["//visibility:public"], - deps = [ - ":libtensorflowlite_jni.so", - ":tensorflowlite_java", - "//tensorflow/contrib/lite/java/src/main/native", - "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", - "@org_checkerframework_qual", - ], -) - java_library( name = "tensorflowlitelib", srcs = glob( @@ -165,28 +150,6 @@ java_test( ], ) -java_test( - name = "OvicClassifierTest", - size = "medium", - srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], - data = [ - "ovic/src/testdata/float_model.lite", - "ovic/src/testdata/labels.txt", - "ovic/src/testdata/low_res_model.lite", - "ovic/src/testdata/quantized_model.lite", - "ovic/src/testdata/test_image_128.jpg", - "ovic/src/testdata/test_image_224.jpg", - ], - javacopts = JAVACOPTS, - test_class = "org.tensorflow.ovic.OvicClassifierTest", - visibility = ["//visibility:public"], - deps = [ - ":ovicbenchmarkerlib", - "@com_google_truth", - "@junit", - ], -) - filegroup( name = "libtensorflowlite_jni", srcs = select({ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml index ba63dce5d9a..95b6b7016f2 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml @@ -31,6 +31,7 @@ android:theme="@style/MaterialTheme"> diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java index 300786c3ca0..4f5662bc2d1 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -47,6 +47,8 @@ import android.os.HandlerThread; import android.support.annotation.NonNull; import android.support.v13.app.FragmentCompat; import android.support.v4.content.ContextCompat; +import android.text.SpannableString; +import android.text.SpannableStringBuilder; import android.util.Log; import android.util.Size; import android.view.LayoutInflater; @@ -54,6 +56,9 @@ import android.view.Surface; import android.view.TextureView; import android.view.View; import android.view.ViewGroup; +import android.widget.CompoundButton; +import android.widget.NumberPicker; +import android.widget.ToggleButton; import android.widget.TextView; import android.widget.Toast; import java.io.IOException; @@ -82,6 +87,8 @@ public class Camera2BasicFragment extends Fragment private boolean runClassifier = false; private boolean checkedPermissions = false; private TextView textView; + private ToggleButton toggle; + private NumberPicker np; private ImageClassifier classifier; /** Max preview width that is guaranteed by Camera2 API */ @@ -202,14 +209,21 @@ public class Camera2BasicFragment extends Fragment * * @param text The message to show */ - private void 
showToast(final String text) { + private void showToast(String s) { + SpannableStringBuilder builder = new SpannableStringBuilder(); + SpannableString str1 = new SpannableString(s); + builder.append(str1); + showToast(builder); + } + + private void showToast(SpannableStringBuilder builder) { final Activity activity = getActivity(); if (activity != null) { activity.runOnUiThread( new Runnable() { @Override public void run() { - textView.setText(text); + textView.setText(builder, TextView.BufferType.SPANNABLE); } }); } @@ -289,6 +303,24 @@ public class Camera2BasicFragment extends Fragment public void onViewCreated(final View view, Bundle savedInstanceState) { textureView = (AutoFitTextureView) view.findViewById(R.id.texture); textView = (TextView) view.findViewById(R.id.text); + toggle = (ToggleButton) view.findViewById(R.id.button); + + toggle.setOnCheckedChangeListener(new CompoundButton.OnCheckedChangeListener() { + public void onCheckedChanged(CompoundButton buttonView, boolean isChecked) { + classifier.setUseNNAPI(isChecked); + } + }); + + np = (NumberPicker) view.findViewById(R.id.np); + np.setMinValue(1); + np.setMaxValue(10); + np.setWrapSelectorWheel(true); + np.setOnValueChangedListener(new NumberPicker.OnValueChangeListener() { + @Override + public void onValueChange(NumberPicker picker, int oldVal, int newVal){ + classifier.setNumThreads(newVal); + } + }); } /** Load the model and labels. */ @@ -659,8 +691,9 @@ public class Camera2BasicFragment extends Fragment showToast("Uninitialized Classifier or invalid context."); return; } + SpannableStringBuilder textToShow = new SpannableStringBuilder(); Bitmap bitmap = textureView.getBitmap(classifier.getImageSizeX(), classifier.getImageSizeY()); - String textToShow = classifier.classifyFrame(bitmap); + classifier.classifyFrame(bitmap, textToShow); bitmap.recycle(); showToast(textToShow); } diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java index c57bb348c5b..7bb6afd9d8b 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -19,10 +19,11 @@ import android.app.Activity; import android.content.res.AssetFileDescriptor; import android.graphics.Bitmap; import android.os.SystemClock; +import android.text.SpannableString; +import android.text.SpannableStringBuilder; +import android.text.style.ForegroundColorSpan; +import android.text.style.RelativeSizeSpan; import android.util.Log; - -import org.tensorflow.lite.Interpreter; - import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; @@ -37,11 +38,15 @@ import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.PriorityQueue; +import org.tensorflow.lite.Interpreter; /** * Classifies images with Tensorflow Lite. */ public abstract class ImageClassifier { + // Display preferences + private static final float GOOD_PROB_THRESHOLD = 0.3f; + private static final int SMALL_COLOR = 0xffddaa88; /** Tag for the {@link Log}. */ private static final String TAG = "TfLiteCameraDemo"; @@ -99,10 +104,12 @@ public abstract class ImageClassifier { } /** Classifies a frame from the preview stream. 
*/ - String classifyFrame(Bitmap bitmap) { + void classifyFrame(Bitmap bitmap, SpannableStringBuilder builder) { + printTopKLabels(builder); + if (tflite == null) { Log.e(TAG, "Image classifier has not been initialized; Skipped."); - return "Uninitialized Classifier."; + builder.append(new SpannableString("Uninitialized Classifier.")); } convertBitmapToByteBuffer(bitmap); // Here's where the magic happens!!! @@ -115,9 +122,10 @@ public abstract class ImageClassifier { applyFilter(); // Print the results. - String textToShow = printTopKLabels(); - textToShow = Long.toString(endTime - startTime) + "ms" + textToShow; - return textToShow; + long duration = endTime - startTime; + SpannableString span = new SpannableString(duration + " ms"); + span.setSpan(new ForegroundColorSpan(android.graphics.Color.LTGRAY), 0, span.length(), 0); + builder.append(span); } void applyFilter() { @@ -142,6 +150,16 @@ public abstract class ImageClassifier { } } + public void setUseNNAPI(Boolean nnapi) { + if (tflite != null) + tflite.setUseNNAPI(nnapi); + } + + public void setNumThreads(int num_threads) { + if (tflite != null) + tflite.setNumThreads(num_threads); + } + /** Closes tflite to release resources. */ public void close() { tflite.close(); @@ -192,7 +210,7 @@ public abstract class ImageClassifier { } /** Prints top-K labels, to be shown in UI as the results. */ - private String printTopKLabels() { + private void printTopKLabels(SpannableStringBuilder builder) { for (int i = 0; i < getNumLabels(); ++i) { sortedLabels.add( new AbstractMap.SimpleEntry<>(labelList.get(i), getNormalizedProbability(i))); @@ -200,13 +218,27 @@ public abstract class ImageClassifier { sortedLabels.poll(); } } - String textToShow = ""; + final int size = sortedLabels.size(); - for (int i = 0; i < size; ++i) { + for (int i = 0; i < size; i++) { Map.Entry label = sortedLabels.poll(); - textToShow = String.format("\n%s: %4.2f", label.getKey(), label.getValue()) + textToShow; + SpannableString span = + new SpannableString(String.format("%s: %4.2f\n", label.getKey(), label.getValue())); + int color; + // Make it white when probability larger than threshold. + if (label.getValue() > GOOD_PROB_THRESHOLD) { + color = android.graphics.Color.WHITE; + } else { + color = SMALL_COLOR; + } + // Make first item bigger. + if (i == size - 1) { + float sizeScale = (i == size - 1) ? 
1.25f : 0.8f; + span.setSpan(new RelativeSizeSpan(sizeScale), 0, span.length(), 0); + } + span.setSpan(new ForegroundColorSpan(color), 0, span.length(), 0); + builder.insert(0, span); } - return textToShow; } /** diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png index c22509d8dfc..52cf2ab9529 100644 Binary files a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png index d68af39186c..b75f892c462 100644 Binary files a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png index 15e419b7ccd..36e14c48d14 100644 Binary files a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png index 342ce34e166..06dd2a740ec 100644 Binary files a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/logo.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/logo.png new file mode 100644 index 00000000000..b94bcfc081e Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/logo.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml index a84f1bbfa0c..ef8a9e08450 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml @@ -13,38 +13,55 @@ See the License for the specific language governing permissions and limitations under the License. 
--> - - + + + + + + + android:scaleType="centerInside" + android:src="@drawable/logo"/> - + android:textOff="@string/tflite" + android:textOn="@string/nnapi"/> + - + + + - - - diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml new file mode 100644 index 00000000000..ddb099a950c --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml @@ -0,0 +1,95 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml index 15305c436e0..e567009a424 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml @@ -14,32 +14,81 @@ limitations under the License. --> + android:layout_height="match_parent" + android:background="#bb7700"> + + + android:layout_alignParentTop="false" + android:background="#bb7700" + android:orientation="vertical" + android:weightSum="100"> - - - + + - + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml index 0a71dbd0e80..7af8f3a98c6 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml @@ -16,7 +16,7 @@ --> - TfLiteCameraDemo + TfLite Camera Demo + Threads: diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml index a08ec3eb629..29a033bcd43 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml @@ -21,4 +21,6 @@ NN:On NN:Off Use NNAPI + tflite + NNAPI diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml index 3f3bdfb4948..1752b3b5f97 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml @@ -14,5 +14,10 @@ limitations under the License. --> - + diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD new file mode 100644 index 00000000000..362d93636f7 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/BUILD @@ -0,0 +1,68 @@ +# Description: +# OVIC Benchmarker Java API. 
+ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") + +java_test( + name = "OvicClassifierTest", + size = "medium", + srcs = ["src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.ovic.OvicClassifierTest", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", + "@com_google_truth", + "@junit", + ], +) + +java_binary( + name = "ovic_validator", + srcs = ["src/main/java/org/tensorflow/ovic/OvicValidator.java"], + data = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + ], + main_class = "org.tensorflow.ovic.OvicValidator", + deps = [ + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java", + ], +) + +android_library( + name = "ovicbenchmarkerlib", + srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassifier.java", + "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + ], + manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml", + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) + +java_library( + name = "ovicbenchmarkerlib_java", + srcs = [ + "src/main/java/org/tensorflow/ovic/OvicClassifier.java", + "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java", + ], + javacopts = JAVACOPTS, + deps = [ + "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so", + "//tensorflow/contrib/lite/java:tensorflowlite_java", + "//tensorflow/contrib/lite/java/src/main/native", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@org_checkerframework_qual", + ], +) diff --git a/tensorflow/contrib/lite/java/ovic/README.md b/tensorflow/contrib/lite/java/ovic/README.md index 76c33838bfe..5efa70987e4 100644 --- a/tensorflow/contrib/lite/java/ovic/README.md +++ b/tensorflow/contrib/lite/java/ovic/README.md @@ -6,7 +6,7 @@ This folder contains building code for track one of the [Low Power ImageNet Reco Follow the steps [here](https://www.tensorflow.org/mobile/tflite/demo_android) to install Tensorflow, Bazel, and the Android NDK and SDK. -## To test the benchmarker: +## Test the benchmarker: The testing utilities helps the developers (you) to make sure that your submissions in TfLite format will be processed as expected in the competition's benchmarking system. @@ -37,47 +37,122 @@ unzip -j /tmp/ovic.zip -d tensorflow/contrib/lite/java/ovic/src/testdata/ You can run test with Bazel as below. This helps to ensure that the installation is correct. ```sh -bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java:OvicClassifierTest --test_output=all +bazel test --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:OvicClassifierTest --cxxopt=-Wno-all --test_output=all ``` ### Test your submissions -Once you have a submission that follows the instructions from the [competition site](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018), you can verify it as below. 
+Once you have a submission that follows the instructions from the [competition site](https://rebootingcomputing.ieee.org/home/sitemap/14-lpirc/80-low-power-image-recognition-challenge-lpirc-2018), you can verify it in two ways: + +#### Validate using randomly generated images + +You can call the validator binary below to verify that your model fits the format requirements. This often helps you catch size mismatches (e.g. output should be [1, 1001] instead of [1,1,1,1001]). Let's say the submission file is located at `/path/to/my_model.lite`, then call: + +```sh +bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/ovic:ovic_validator --cxxopt=-Wno-all +bazel-bin/tensorflow/contrib/lite/java/ovic/ovic_validator /path/to/my_model.lite +``` + +Successful validation should print the following message to the terminal: + +``` +Successfully validated /path/to/my_model.lite. + +``` + +#### Test that the model produces sensible outcomes + +You can go a step further to verify that the model produces results as expected. This helps you catch bugs during TOCO conversion (e.g. using the wrong mean and std values). * Move your submission to the testdata folder: -Let say the submission file is located at `/tmp/my_model.lite`, then - ```sh -cp /tmp/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/ +cp /path/to/my_model.lite tensorflow/contrib/lite/java/ovic/src/testdata/ ``` * Resize the test image to the resolutions that are expected by your submission: The test images can be found at `tensorflow/contrib/lite/java/ovic/src/testdata/test_image_*.jpg`. You may reuse these images if your image resolutions are 128x128 or 224x224. -* Add your model and test image to the BUILD rule: +* Add your model and test image to the BUILD rule at `tensorflow/contrib/lite/java/ovic/src/testdata/BUILD`: ```JSON -java_test( - name = "OvicClassifierTest", - size = "medium", - srcs = ["ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java"], - data = [ - "ovic/src/testdata/float_model.lite", - "ovic/src/testdata/labels.txt", - "ovic/src/testdata/low_res_model.lite", - "ovic/src/testdata/quantized_model.lite", - "ovic/src/testdata/test_image_128.jpg", - "ovic/src/testdata/test_image_224.jpg", - "ovic/src/testdata/my_model.lite", # <--- Your submission. - "ovic/src/testdata/my_test_image.jpg", # <--- Your test image. - ], - ... +filegroup( + name = "ovic_testdata", + srcs = [ + "@tflite_ovic_testdata//:float_model.lite", + "@tflite_ovic_testdata//:low_res_model.lite", + "@tflite_ovic_testdata//:quantized_model.lite", + "@tflite_ovic_testdata//:test_image_128.jpg", + "@tflite_ovic_testdata//:test_image_224.jpg", + "my_model.lite", # <--- Your submission. + "my_test_image.jpg", # <--- Your test image. + ], + ... ``` * Modify `OvicClassifierTest.java` to test your model. -Change `TEST_IMAGE_PATH` to `testdata/my_test_image.jpg`. If your model runs inference in floating point, change `FLOAT_MODEL_PATH` to `testdata/my_model.lite`. If your model runs [quantized inference](https://www.tensorflow.org/performance/quantization), change `QUANTIZED_MODEL_PATH` to `testdata/my_model.lite`. +Change `TEST_IMAGE_PATH` to `my_test_image.jpg`. Change either `FLOAT_MODEL_PATH` or `QUANTIZED_MODEL_PATH` to `my_model.lite`, depending on whether your model runs inference in float or [8-bit](https://www.tensorflow.org/performance/quantization). Now you can run the bazel tests to catch any runtime issues with the submission. + +Note: Please make sure that your submission passes the test.
If a submission fails the test, it will not be processed by the submission server. + +## Measure on-device latency + +We provide two ways to measure the on-device latency of your submission. The first is through our competition server, which is reliable and repeatable but limited to a few trials per day. The second is through the benchmarker APK, which requires a device and may not be as accurate as the server, but has a fast turnaround and no access limitations. We recommend that participants use the benchmarker APK for early development and reserve the competition server for evaluating promising submissions. + +### Running the benchmarker app + +Make sure that you have followed the instructions in [Test your submissions](#test-your-submissions) to add your model to the testdata folder and to the corresponding build rules. + +Modify `tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java`: + +* Add your model to the benchmarker APK by changing `MODEL_PATH` and `TEST_IMAGE_PATH` below to your submission and test image. + +``` + private static final String TEST_IMAGE_PATH = "my_test_image.jpg"; + private static final String MODEL_PATH = "my_model.lite"; +``` + +* Adjust the benchmark parameters when needed: + +You can change the length of each experiment and the processor affinity below. `BIG_CORE_MASK` is an integer whose binary encoding represents the set of cores to use. This number is phone-specific. For example, Pixel 2 has 8 cores: the 4 little cores are represented by the 4 least significant bits, and the 4 big cores by the 4 most significant bits. Therefore a mask value of 16, or `00010000` in binary, represents using only the first big core. The mask 32, or `00100000` in binary, uses the second big core and should deliver results identical to mask 16 because the big cores are interchangeable (a short decoding sketch appears after the sample-latency note below). + +``` + /** Wall time for each benchmarking experiment. */ + private static final double WALL_TIME = 3000; + /** Maximum number of iterations in each benchmarking experiment. */ + private static final int MAX_ITERATIONS = 100; + /** Mask for binding to a single big core. Pixel 1 (4), Pixel 2 (16). */ + private static final int BIG_CORE_MASK = 16; +``` + +Note: You'll need root access to the phone to change processor affinity. + +* Build and install the app. + +``` +bazel build -c opt --cxxopt=--std=c++11 --cxxopt=-Wno-all //tensorflow/contrib/lite/java/ovic/demo/app:ovic_benchmarker_binary +adb install -r bazel-bin/tensorflow/contrib/lite/java/ovic/demo/app/ovic_benchmarker_binary.apk +``` + +Start the app and click the dark-green `Start` button. The button should turn bright green, signaling that the experiment is running. The benchmarking results will be displayed after roughly the `WALL_TIME` you specified above. For example: + +``` +my_model.lite: Average latency=158.6ms after 20 runs. +``` + +### Sample latencies + +Note: the benchmarking results can differ considerably depending on the background processes running on the phone. A few things that help stabilize the app's readings are placing the phone on a cooling plate, restarting the phone, and shutting down internet access.
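To make the `BIG_CORE_MASK` encoding described above concrete, here is a minimal standalone sketch (a hypothetical class, not part of the app or this change) that decodes such a mask under the little/big bit layout described for Pixel 2:

```java
// Decodes an affinity mask, assuming bits 0-3 are little cores and bits 4-7 are big cores.
public final class AffinityMaskExample {
  private static String describe(int mask) {
    StringBuilder sb = new StringBuilder();
    for (int core = 0; core < 8; core++) {
      if ((mask & (1 << core)) != 0) {
        sb.append(core < 4 ? "little" : "big").append(" core ").append(core).append(' ');
      }
    }
    return sb.length() == 0 ? "no cores selected" : sb.toString().trim();
  }

  public static void main(String[] args) {
    System.out.println(describe(16)); // 0b00010000 -> "big core 4" (the first big core)
    System.out.println(describe(32)); // 0b00100000 -> "big core 5" (the second big core)
  }
}
```

The same reading applies to the mask that `readCpusAllowedMask()` parses back from `/proc/self/status` later in this change, which is how the app confirms that the `taskset` call took effect.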
+ +| Model | Pixel 1 latency (ms) | Pixel 2 latency (ms) | +| -------------------- |:---------------------:| --------------------:| +| float_model.lite | 120 | 155 | +| quantized_model.lite | 85 | 74 | +| low_res_model.lite | 4.2 | 4.0 | + +Since Pixel 2 has excellent support for 8-bit quantized models, we strongly recommend you to check out the [quantization training tutorial](https://www.tensorflow.org/performance/quantization). + diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml b/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml new file mode 100644 index 00000000000..55f2961fd71 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/AndroidManifest.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD new file mode 100644 index 00000000000..83974f4b337 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -0,0 +1,29 @@ +# Sample app for OVIC benchmarking. +licenses(["notice"]) # Apache 2.0 + +android_binary( + name = "ovic_benchmarker_binary", + srcs = [ + "OvicBenchmarker.java", + "OvicBenchmarkerActivity.java", + ], + assets = [ + "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", + "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", + ], + assets_dir = "", + custom_package = "ovic.demo.app", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".lite", + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = ["manual"], + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java similarity index 97% rename from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java rename to tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java index d0102883e6b..113ab74a20d 100644 --- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java +++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarker.java @@ -1,4 +1,4 @@ -/*Copyright 2018 Google LLC +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -package org.tensorflow.ovic; +package ovic.demo.app; import android.graphics.Bitmap; import android.os.SystemClock; @@ -22,6 +22,8 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; +import org.tensorflow.ovic.OvicClassifier; +import org.tensorflow.ovic.OvicSingleImageResult; /** * Class that benchmarks image classifier models. 
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java new file mode 100644 index 00000000000..59457c308ad --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java @@ -0,0 +1,247 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package ovic.demo.app; + +import android.app.Activity; +import android.content.res.AssetFileDescriptor; +import android.content.res.AssetManager; +import android.graphics.Bitmap; +import android.graphics.BitmapFactory; +import android.os.Bundle; +import android.os.Process; +import android.os.SystemClock; +import android.util.Log; +import android.view.View; +import android.widget.TextView; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStream; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.text.DecimalFormat; +import org.tensorflow.ovic.OvicSingleImageResult; + +/** Class that benchmarks image classifier models. */ +public class OvicBenchmarkerActivity extends Activity { + /** Tag for the {@link Log}. */ + private static final String TAG = "OvicBenchmarkerActivity"; + + /** Name of the label file stored in Assets. */ + private static final String LABEL_PATH = "labels.txt"; + + private static final String TEST_IMAGE_PATH = "test_image_224.jpg"; + private static final String MODEL_PATH = "float_model.lite"; + /** + * Each button press will launch a benchmarking experiment. The experiment stops when either the + * total native latency reaches WALL_TIME or the number of iterations reaches MAX_ITERATIONS, + * whichever comes first. + */ + /** Wall time for each benchmarking experiment. */ + private static final double WALL_TIME = 3000; + /** Maximum number of iterations in each benchmarking experiment. */ + private static final int MAX_ITERATIONS = 100; + /** Mask for binding to a single big core. Pixel 1 (4), Pixel 2 (16). */ + private static final int BIG_CORE_MASK = 16; + /** Amount of time in milliseconds to wait for affinity to set. */ + private static final int WAIT_TIME_FOR_AFFINITY = 1000; + + /* The model to be benchmarked. */ + private MappedByteBuffer model = null; + private InputStream labelInputStream = null; + private OvicBenchmarker benchmarker; + /** Inference result of each iteration. */ + OvicSingleImageResult iterResult = null; + + private TextView textView = null; + // private Button startButton = null; + private static final DecimalFormat df2 = new DecimalFormat(".##"); + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + + // TextView used to display the progress, for information purposes only.
+ textView = (TextView) findViewById(R.id.textView); + } + + private Bitmap loadTestBitmap() throws IOException { + InputStream imageStream = getAssets().open(TEST_IMAGE_PATH); + return BitmapFactory.decodeStream(imageStream); + } + + public void initializeTest() throws IOException { + Log.i(TAG, "Initializing benchmarker."); + benchmarker = new OvicBenchmarker(WALL_TIME); + AssetManager am = getAssets(); + AssetFileDescriptor fileDescriptor = am.openFd(MODEL_PATH); + FileInputStream modelInputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + FileChannel fileChannel = modelInputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + model = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + labelInputStream = am.open(LABEL_PATH); + } + + public Boolean doTestIteration() throws IOException, InterruptedException { + if (benchmarker == null) { + throw new RuntimeException("Benchmarker has not been initialized."); + } + if (benchmarker.shouldStop()) { + return false; + } + if (!benchmarker.readyToTest()) { + Log.i(TAG, "Getting ready to test."); + benchmarker.getReadyToTest(labelInputStream, model); + if (!benchmarker.readyToTest()) { + throw new RuntimeException("Failed to get the benchmarker ready."); + } + } + Log.i(TAG, "Going to do test iter."); + // Start testing. + Bitmap testImageBitmap = loadTestBitmap(); + iterResult = benchmarker.doTestIteration(testImageBitmap); + testImageBitmap.recycle(); + if (iterResult == null) { + throw new RuntimeException("Inference failed to produce a result."); + } + Log.i(TAG, iterResult.toString()); + return true; + } + + public void startPressed(View view) throws IOException { + Log.i(TAG, "Start pressed"); + try { + initializeTest(); + } catch (IOException e) { + Log.e(TAG, "Can't initialize benchmarker.", e); + throw e; + } + String displayText = ""; + try { + setProcessorAffinity(BIG_CORE_MASK); + } catch (IOException e) { + Log.e(TAG, e.getMessage()); + displayText = e.getMessage() + "\n"; + } + Log.i(TAG, "Successfully initialized benchmarker."); + int testIter = 0; + Boolean iterSuccess = false; + double totalLatency = 0.0; + while (testIter < MAX_ITERATIONS) { + try { + iterSuccess = doTestIteration(); + } catch (IOException e) { + Log.e(TAG, "Error during iteration " + testIter); + throw e; + } catch (InterruptedException e) { + Log.e(TAG, "Interrupted at iteration " + testIter); + } + if (!iterSuccess) { + break; + } + testIter++; + totalLatency += (double) iterResult.latency; + } + Log.i(TAG, "Benchmarking finished"); + + if (textView != null) { + if (testIter > 0) { + textView.setText( + displayText + + MODEL_PATH + + ": Average latency=" + + df2.format(totalLatency / testIter) + + "ms after " + + testIter + + " runs."); + } else { + textView.setText("Benchmarker failed to run even a single iteration."); + } + } + } + + private static void setProcessorAffinity(int mask) throws IOException { + int myPid = Process.myPid(); + Log.i(TAG, String.format("Setting processor affinity to 0x%02x", mask)); + + String command = String.format("taskset -a -p %x %d", mask, myPid); + try { + Runtime.getRuntime().exec(command).waitFor(); + } catch (InterruptedException e) { + throw new IOException("Interrupted: " + e); + } + + // Make sure the setting took effect: try for up to a second to confirm the change; if not, fail.
+ long startTimeMs = SystemClock.elapsedRealtime(); + while (true) { + int readBackMask = readCpusAllowedMask(); + if (readBackMask == mask) { + Log.i(TAG, String.format("Successfully set affinity to 0x%02x", mask)); + return; + } + if (SystemClock.elapsedRealtime() > startTimeMs + WAIT_TIME_FOR_AFFINITY) { + throw new IOException( + String.format( + "Core-binding failed: affinity set to 0x%02x but read back as 0x%02x\n" + + "please root device.", + mask, readBackMask)); + } + + try { + Thread.sleep(50); + } catch (InterruptedException e) { + // Ignore sleep interrupted, will sleep again and compare is final cross-check. + } + } + } + + public static int readCpusAllowedMask() throws IOException { + // Determine how many CPUs there are total + final String pathname = "/proc/self/status"; + final String resultPrefix = "Cpus_allowed:"; + File file = new File(pathname); + String line = ""; + String allowedCPU = ""; + Integer allowedMask = null; + BufferedReader bufReader = null; + try { + bufReader = new BufferedReader(new FileReader(file)); + while ((line = bufReader.readLine()) != null) { + if (line.startsWith(resultPrefix)) { + allowedMask = Integer.valueOf(line.substring(resultPrefix.length()).trim(), 16); + allowedCPU = bufReader.readLine(); + break; + } + } + } catch (RuntimeException e) { + throw new IOException( + "Invalid number in " + pathname + " line: \"" + line + "\": " + e.getMessage()); + } finally { + if (bufReader != null) { + bufReader.close(); + } + } + if (allowedMask == null) { + throw new IOException(pathname + " missing " + resultPrefix + " line"); + } + Log.i(TAG, allowedCPU); + return allowedMask; + } +} diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle new file mode 100644 index 00000000000..c5d19bad89a --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle @@ -0,0 +1,58 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion "26.0.1" + defaultConfig { + applicationId "android.example.com.ovicbenchmarker" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. 
+ jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "lite", "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'com.android.support:appcompat-v7:25.2.0' + compile 'com.android.support.constraint:constraint-layout:1.0.2' + compile 'com.android.support:design:25.2.0' + compile 'com.android.support:support-annotations:25.3.1' + compile 'com.android.support:support-v13:25.2.0' + + compile 'org.tensorflow:tensorflow-lite:+' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 00000000000..715d1b6d69c Binary files /dev/null and b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png new file mode 100644 index 00000000000..9beff0885fd Binary files /dev/null and b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml new file mode 100644 index 00000000000..93f5c6a016b --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/drawable/start_button_color.xml @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml new file mode 100644 index 00000000000..e9d83bae543 --- /dev/null +++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml @@ -0,0 +1,54 @@ + + + + + + +